ResumeLens/internal/api/middleware_test.go
2026-04-07 13:14:52 -07:00

503 lines
14 KiB
Go

package api
import (
"encoding/json"
"net/http"
"net/http/httptest"
"sync"
"sync/atomic"
"testing"
"time"
)
// resetRateLimiter replaces the shared rateLimiter's per-key timestamp map with
// a fresh empty one (under its mutex) so each test starts from a clean state.
func resetRateLimiter() {
	fresh := make(map[string][]time.Time)
	rateLimiter.mu.Lock()
	defer rateLimiter.mu.Unlock()
	rateLimiter.timestamps = fresh
}
// withFixedNow swaps the package-level timeNow clock for one that always
// returns now, and registers a t.Cleanup that restores the previous clock
// when the test finishes.
func withFixedNow(t *testing.T, now time.Time) {
	t.Helper()
	original := timeNow
	t.Cleanup(func() { timeNow = original })
	timeNow = func() time.Time { return now }
}
// newRateLimitedRequest builds a POST /api/analyze test request. A non-empty
// ip becomes RemoteAddr "<ip>:12345"; an empty ip leaves httptest's default
// RemoteAddr untouched.
func newRateLimitedRequest(ip string) *http.Request {
	r := httptest.NewRequest(http.MethodPost, "/api/analyze", nil)
	if ip == "" {
		return r
	}
	r.RemoteAddr = ip + ":12345"
	return r
}
// TestRateLimit_5_1_1_AllowsTenRequestsPerHour sends ten requests from one IP
// at a frozen instant and expects every one to return 200 and reach the
// wrapped handler.
func TestRateLimit_5_1_1_AllowsTenRequestsPerHour(t *testing.T) {
	resetRateLimiter()
	withFixedNow(t, time.Date(2026, 4, 7, 10, 0, 0, 0, time.UTC))

	// Count how many requests actually make it past the limiter.
	var calls int32
	handler := RateLimit(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
		atomic.AddInt32(&calls, 1)
		w.WriteHeader(http.StatusOK)
	}))

	for n := 1; n <= 10; n++ {
		rec := httptest.NewRecorder()
		handler.ServeHTTP(rec, newRateLimitedRequest("127.0.0.1"))
		if rec.Code != http.StatusOK {
			t.Fatalf("request %d: expected 200, got %d", n, rec.Code)
		}
	}

	if got := atomic.LoadInt32(&calls); got != 10 {
		t.Fatalf("expected 10 handler calls, got %d", got)
	}
}
// TestRateLimit_5_1_2_EleventhRequestBlocked uses up ten requests for one IP,
// then expects the eleventh to get 429 without reaching the handler.
func TestRateLimit_5_1_2_EleventhRequestBlocked(t *testing.T) {
	resetRateLimiter()
	withFixedNow(t, time.Date(2026, 4, 7, 10, 0, 0, 0, time.UTC))

	var calls int32
	handler := RateLimit(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
		atomic.AddInt32(&calls, 1)
		w.WriteHeader(http.StatusOK)
	}))
	send := func() *httptest.ResponseRecorder {
		rec := httptest.NewRecorder()
		handler.ServeHTTP(rec, newRateLimitedRequest("127.0.0.1"))
		return rec
	}

	// Use up the allowance for this IP.
	for n := 0; n < 10; n++ {
		send()
	}

	if rec := send(); rec.Code != http.StatusTooManyRequests {
		t.Fatalf("expected 429 on 11th request, got %d", rec.Code)
	}
	if got := atomic.LoadInt32(&calls); got != 10 {
		t.Fatalf("expected handler to be called only 10 times, got %d", got)
	}
}
// TestRateLimit_5_1_3_DifferentIPsUnaffected checks that once one IP is
// blocked, a request from a different IP is still allowed.
func TestRateLimit_5_1_3_DifferentIPsUnaffected(t *testing.T) {
	resetRateLimiter()
	withFixedNow(t, time.Date(2026, 4, 7, 10, 0, 0, 0, time.UTC))

	handler := RateLimit(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
		w.WriteHeader(http.StatusOK)
	}))
	hit := func(ip string) int {
		rec := httptest.NewRecorder()
		handler.ServeHTTP(rec, newRateLimitedRequest(ip))
		return rec.Code
	}

	const ipA, ipB = "127.0.0.1", "192.168.1.100"
	for n := 0; n < 10; n++ {
		hit(ipA)
	}

	if code := hit(ipA); code != http.StatusTooManyRequests {
		t.Fatalf("expected IP A to be blocked, got %d", code)
	}
	// IP B has sent nothing yet, so its request must go through.
	if code := hit(ipB); code != http.StatusOK {
		t.Fatalf("expected IP B to be allowed, got %d", code)
	}
}
// TestRateLimit_5_2_1_RequestsOlderThanOneHourDontCount seeds ten timestamps
// 61 minutes before the frozen clock. It expects none of them to count
// against the IP, so ten fresh requests all return 200.
func TestRateLimit_5_2_1_RequestsOlderThanOneHourDontCount(t *testing.T) {
	resetRateLimiter()
	now := time.Date(2026, 4, 7, 10, 0, 0, 0, time.UTC)
	withFixedNow(t, now)

	stale := now.Add(-61 * time.Minute)
	history := make([]time.Time, 0, 10)
	for len(history) < 10 {
		history = append(history, stale)
	}
	rateLimiter.mu.Lock()
	rateLimiter.timestamps["127.0.0.1"] = history
	rateLimiter.mu.Unlock()

	handler := RateLimit(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
		w.WriteHeader(http.StatusOK)
	}))
	for n := 1; n <= 10; n++ {
		rec := httptest.NewRecorder()
		handler.ServeHTTP(rec, newRateLimitedRequest("127.0.0.1"))
		if rec.Code != http.StatusOK {
			t.Fatalf("request %d expected 200, got %d", n, rec.Code)
		}
	}
}
// TestRateLimit_5_2_2_RollingWindowAllowsAfterExpiry seeds ten timestamps for
// one IP, three minutes apart from base (10:00, 10:03, ..., 10:27), and
// evaluates at base+61m (11:01).
//
// Only the 10:00 entry is more than an hour old, so exactly one slot frees up:
// the next request must pass, and the one after it must be blocked again. The
// second check is what proves the window is rolling. A limiter that simply
// discarded all history when the hour rolled over would let both through.
func TestRateLimit_5_2_2_RollingWindowAllowsAfterExpiry(t *testing.T) {
	resetRateLimiter()
	base := time.Date(2026, 4, 7, 10, 0, 0, 0, time.UTC)
	withFixedNow(t, base.Add(61*time.Minute))
	recent := make([]time.Time, 10)
	for i := range recent {
		recent[i] = base.Add(time.Duration(i) * 3 * time.Minute)
	}
	rateLimiter.mu.Lock()
	rateLimiter.timestamps["127.0.0.1"] = recent
	rateLimiter.mu.Unlock()
	next := http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
		w.WriteHeader(http.StatusOK)
	})
	h := RateLimit(next)
	rr := httptest.NewRecorder()
	h.ServeHTTP(rr, newRateLimitedRequest("127.0.0.1"))
	if rr.Code != http.StatusOK {
		t.Fatalf("expected request after rolling expiry to pass, got %d", rr.Code)
	}
	// The nine entries from 10:03 to 10:27 are still within the hour. Together
	// with the request just admitted, the IP is back at the 10-request quota
	// (see 5_1_1 / 5_1_2), so the next request must be rejected.
	again := httptest.NewRecorder()
	h.ServeHTTP(again, newRateLimitedRequest("127.0.0.1"))
	if again.Code != http.StatusTooManyRequests {
		t.Fatalf("expected 429 once unexpired entries refill the window, got %d", again.Code)
	}
}
// TestRateLimit_5_2_3_ConcurrentRequestsThreadSafety fires 20 concurrent
// requests from one IP. It expects exactly 10 to get 200 and 10 to get 429,
// and any other status code fails the test.
func TestRateLimit_5_2_3_ConcurrentRequestsThreadSafety(t *testing.T) {
	resetRateLimiter()
	withFixedNow(t, time.Date(2026, 4, 7, 10, 0, 0, 0, time.UTC))

	handler := RateLimit(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
		w.WriteHeader(http.StatusOK)
	}))

	const workers = 20
	// The buffer holds every result, so no goroutine blocks on send.
	codes := make(chan int, workers)
	var wg sync.WaitGroup
	wg.Add(workers)
	for n := 0; n < workers; n++ {
		go func() {
			defer wg.Done()
			rec := httptest.NewRecorder()
			handler.ServeHTTP(rec, newRateLimitedRequest("127.0.0.1"))
			codes <- rec.Code
		}()
	}
	wg.Wait()
	close(codes)

	tally := map[int]int{}
	for code := range codes {
		if code != http.StatusOK && code != http.StatusTooManyRequests {
			t.Fatalf("unexpected status code: %d", code)
		}
		tally[code]++
	}
	allowed, blocked := tally[http.StatusOK], tally[http.StatusTooManyRequests]
	if allowed != 10 || blocked != 10 {
		t.Fatalf("expected exactly 10 allowed and 10 blocked, got %d allowed and %d blocked", allowed, blocked)
	}
}
// TestGetClientIP_5_3_1_XForwardedForSingleIP expects a lone X-Forwarded-For
// address to be returned as the client IP.
func TestGetClientIP_5_3_1_XForwardedForSingleIP(t *testing.T) {
	const want = "203.0.113.45"
	r := httptest.NewRequest(http.MethodPost, "/api/analyze", nil)
	r.Header.Set("X-Forwarded-For", want)
	got := getClientIP(r)
	if got != want {
		t.Fatalf("expected first XFF IP, got %q", got)
	}
}
// TestGetClientIP_5_3_2_XForwardedForMultipleIPs expects the first entry of
// a comma-separated X-Forwarded-For chain to be returned.
func TestGetClientIP_5_3_2_XForwardedForMultipleIPs(t *testing.T) {
	r := httptest.NewRequest(http.MethodPost, "/api/analyze", nil)
	r.Header.Set("X-Forwarded-For", "203.0.113.45, 198.51.100.67")
	if got, want := getClientIP(r), "203.0.113.45"; got != want {
		t.Fatalf("expected first XFF IP, got %q", got)
	}
}
// TestGetClientIP_5_3_3_XRealIP expects X-Real-IP to be used when
// X-Forwarded-For is absent.
func TestGetClientIP_5_3_3_XRealIP(t *testing.T) {
	const want = "192.0.2.123"
	r := httptest.NewRequest(http.MethodPost, "/api/analyze", nil)
	r.Header.Set("X-Real-IP", want)
	got := getClientIP(r)
	if got != want {
		t.Fatalf("expected X-Real-IP, got %q", got)
	}
}
// TestGetClientIP_5_3_4_RemoteAddrFallback expects that, with no proxy
// headers, the host part of RemoteAddr (port stripped) is returned.
func TestGetClientIP_5_3_4_RemoteAddrFallback(t *testing.T) {
	r := httptest.NewRequest(http.MethodPost, "/api/analyze", nil)
	r.RemoteAddr = "127.0.0.1:54321"
	if got, want := getClientIP(r), "127.0.0.1"; got != want {
		t.Fatalf("expected remote addr IP without port, got %q", got)
	}
}
// TestGetClientIP_5_3_5_XForwardedForWhitespaceHandling expects the first
// X-Forwarded-For entry to have its surrounding spaces trimmed.
func TestGetClientIP_5_3_5_XForwardedForWhitespaceHandling(t *testing.T) {
	const want = "203.0.113.45"
	r := httptest.NewRequest(http.MethodPost, "/api/analyze", nil)
	r.Header.Set("X-Forwarded-For", " 203.0.113.45 , 198.51.100.67")
	got := getClientIP(r)
	if got != want {
		t.Fatalf("expected trimmed first XFF IP, got %q", got)
	}
}
// TestRateLimit_5_4_1_ErrorResponseFormat checks the rejection response
// returned on the eleventh request from one IP: status 429, Content-Type
// application/json, and a JSON object whose "error" field matches the exact
// message.
func TestRateLimit_5_4_1_ErrorResponseFormat(t *testing.T) {
	resetRateLimiter()
	withFixedNow(t, time.Date(2026, 4, 7, 10, 0, 0, 0, time.UTC))

	handler := RateLimit(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
		w.WriteHeader(http.StatusOK)
	}))

	// Send eleven requests and keep only the last (rejected) response.
	var rec *httptest.ResponseRecorder
	for n := 0; n <= 10; n++ {
		rec = httptest.NewRecorder()
		handler.ServeHTTP(rec, newRateLimitedRequest("127.0.0.1"))
	}

	if rec.Code != http.StatusTooManyRequests {
		t.Fatalf("expected 429, got %d", rec.Code)
	}
	if ct := rec.Header().Get("Content-Type"); ct != "application/json" {
		t.Fatalf("expected application/json content type, got %q", ct)
	}
	var payload map[string]string
	if err := json.Unmarshal(rec.Body.Bytes(), &payload); err != nil {
		t.Fatalf("expected valid JSON body, got parse error: %v", err)
	}
	const want = "Rate limit exceeded. You can make up to 10 requests per hour. Please try again later."
	if got := payload["error"]; got != want {
		t.Fatalf("expected exact error message %q, got %q", want, got)
	}
}
// TestRateLimit_5_4_2_NoHandlerCallWhenRateLimited checks that a request
// rejected with 429 never reaches the wrapped handler. The call count must
// stay at 10.
func TestRateLimit_5_4_2_NoHandlerCallWhenRateLimited(t *testing.T) {
	resetRateLimiter()
	withFixedNow(t, time.Date(2026, 4, 7, 10, 0, 0, 0, time.UTC))

	var reached int32
	inner := func(w http.ResponseWriter, _ *http.Request) {
		atomic.AddInt32(&reached, 1)
		w.WriteHeader(http.StatusOK)
	}
	limited := RateLimit(http.HandlerFunc(inner))

	for remaining := 10; remaining > 0; remaining-- {
		limited.ServeHTTP(httptest.NewRecorder(), newRateLimitedRequest("127.0.0.1"))
	}

	blocked := httptest.NewRecorder()
	limited.ServeHTTP(blocked, newRateLimitedRequest("127.0.0.1"))
	if blocked.Code != http.StatusTooManyRequests {
		t.Fatalf("expected 429, got %d", blocked.Code)
	}
	if got := atomic.LoadInt32(&reached); got != 10 {
		t.Fatalf("expected handler to stay at 10 calls after block, got %d", got)
	}
}
// TestCORS_6_1_1_AllowedOriginViteDevServer expects the Vite dev-server origin
// to be echoed back in Access-Control-Allow-Origin.
func TestCORS_6_1_1_AllowedOriginViteDevServer(t *testing.T) {
	const origin = "http://localhost:5173"
	handler := CORS(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
		w.WriteHeader(http.StatusOK)
	}))

	r := httptest.NewRequest(http.MethodPost, "/api/analyze", nil)
	r.Header.Set("Origin", origin)
	rec := httptest.NewRecorder()
	handler.ServeHTTP(rec, r)

	if got := rec.Header().Get("Access-Control-Allow-Origin"); got != origin {
		t.Fatalf("expected allowed origin header, got %q", got)
	}
}
// TestCORS_6_1_2_AllowedOriginDockerNginx expects the port-less
// "http://localhost" origin to be echoed back in Access-Control-Allow-Origin.
func TestCORS_6_1_2_AllowedOriginDockerNginx(t *testing.T) {
	const origin = "http://localhost"
	ok := http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
		w.WriteHeader(http.StatusOK)
	})
	handler := CORS(ok)

	r := httptest.NewRequest(http.MethodPost, "/api/analyze", nil)
	r.Header.Set("Origin", origin)
	rec := httptest.NewRecorder()
	handler.ServeHTTP(rec, r)

	got := rec.Header().Get("Access-Control-Allow-Origin")
	if got != origin {
		t.Fatalf("expected allowed origin header, got %q", got)
	}
}
// TestCORS_6_1_3_AllowedOriginDockerNginxWithPort expects the explicit-port
// origin "http://localhost:80" to be echoed back verbatim.
func TestCORS_6_1_3_AllowedOriginDockerNginxWithPort(t *testing.T) {
	const origin = "http://localhost:80"
	handler := CORS(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
		w.WriteHeader(http.StatusOK)
	}))

	r := httptest.NewRequest(http.MethodPost, "/api/analyze", nil)
	r.Header.Set("Origin", origin)
	rec := httptest.NewRecorder()
	handler.ServeHTTP(rec, r)

	if got := rec.Header().Get("Access-Control-Allow-Origin"); got != origin {
		t.Fatalf("expected allowed origin header, got %q", got)
	}
}
// TestCORS_6_1_4_DisallowedOrigin expects no Access-Control-Allow-Origin
// header when the request comes from an origin outside the allow-list.
func TestCORS_6_1_4_DisallowedOrigin(t *testing.T) {
	handler := CORS(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
		w.WriteHeader(http.StatusOK)
	}))

	r := httptest.NewRequest(http.MethodPost, "/api/analyze", nil)
	r.Header.Set("Origin", "http://malicious-site.com")
	rec := httptest.NewRecorder()
	handler.ServeHTTP(rec, r)

	if got := rec.Header().Get("Access-Control-Allow-Origin"); len(got) != 0 {
		t.Fatalf("expected no CORS allow-origin for disallowed origin, got %q", got)
	}
}
// TestCORS_6_1_5_MissingOriginHeader expects a request without an Origin
// header to be passed through to the next handler with no
// Access-Control-Allow-Origin header set.
func TestCORS_6_1_5_MissingOriginHeader(t *testing.T) {
	reached := false
	handler := CORS(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
		reached = true
		w.WriteHeader(http.StatusOK)
	}))

	rec := httptest.NewRecorder()
	// Deliberately no Origin header on this request.
	handler.ServeHTTP(rec, httptest.NewRequest(http.MethodPost, "/api/analyze", nil))

	if !reached {
		t.Fatalf("expected request to continue when Origin is missing")
	}
	if got := rec.Header().Get("Access-Control-Allow-Origin"); got != "" {
		t.Fatalf("expected no CORS headers when Origin missing, got %q", got)
	}
}
// TestCORS_6_2_1_AllowedMethodsHeader expects Access-Control-Allow-Methods to
// be exactly "GET, POST, OPTIONS" for an allowed origin.
func TestCORS_6_2_1_AllowedMethodsHeader(t *testing.T) {
	const want = "GET, POST, OPTIONS"
	handler := CORS(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
		w.WriteHeader(http.StatusOK)
	}))

	r := httptest.NewRequest(http.MethodPost, "/api/analyze", nil)
	r.Header.Set("Origin", "http://localhost:5173")
	rec := httptest.NewRecorder()
	handler.ServeHTTP(rec, r)

	got := rec.Header().Get("Access-Control-Allow-Methods")
	if got != want {
		t.Fatalf("expected allow-methods header, got %q", got)
	}
}
// TestCORS_6_2_2_AllowedHeadersHeader expects Access-Control-Allow-Headers to
// be exactly "Content-Type" for an allowed origin.
func TestCORS_6_2_2_AllowedHeadersHeader(t *testing.T) {
	const want = "Content-Type"
	handler := CORS(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
		w.WriteHeader(http.StatusOK)
	}))

	r := httptest.NewRequest(http.MethodPost, "/api/analyze", nil)
	r.Header.Set("Origin", "http://localhost:5173")
	rec := httptest.NewRecorder()
	handler.ServeHTTP(rec, r)

	got := rec.Header().Get("Access-Control-Allow-Headers")
	if got != want {
		t.Fatalf("expected allow-headers header, got %q", got)
	}
}
// TestCORS_6_2_3_CredentialsAllowed expects Access-Control-Allow-Credentials
// to be "true" for an allowed origin.
func TestCORS_6_2_3_CredentialsAllowed(t *testing.T) {
	handler := CORS(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
		w.WriteHeader(http.StatusOK)
	}))

	r := httptest.NewRequest(http.MethodPost, "/api/analyze", nil)
	r.Header.Set("Origin", "http://localhost:5173")
	rec := httptest.NewRecorder()
	handler.ServeHTTP(rec, r)

	if got, want := rec.Header().Get("Access-Control-Allow-Credentials"), "true"; got != want {
		t.Fatalf("expected allow-credentials header true, got %q", got)
	}
}
// TestCORS_6_3_1_OPTIONSPreflight expects an OPTIONS request from an allowed
// origin to be answered by the middleware itself. The response must be 204
// with an empty body and the CORS origin header, and the next handler must
// never run.
func TestCORS_6_3_1_OPTIONSPreflight(t *testing.T) {
	const origin = "http://localhost:5173"
	reached := false
	handler := CORS(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
		reached = true
		w.WriteHeader(http.StatusOK)
	}))

	preflight := httptest.NewRequest(http.MethodOptions, "/api/analyze", nil)
	preflight.Header.Set("Origin", origin)
	rec := httptest.NewRecorder()
	handler.ServeHTTP(rec, preflight)

	// Each Fatalf stops the test, so this switch checks the same conditions
	// in the same order as a chain of ifs.
	switch {
	case rec.Code != http.StatusNoContent:
		t.Fatalf("expected 204 for preflight, got %d", rec.Code)
	case reached:
		t.Fatalf("expected preflight to short-circuit before next handler")
	case rec.Body.Len() != 0:
		t.Fatalf("expected empty body for preflight, got %q", rec.Body.String())
	}
	if got := rec.Header().Get("Access-Control-Allow-Origin"); got != origin {
		t.Fatalf("expected CORS headers on preflight, got origin %q", got)
	}
}
// TestCORS_6_3_2_POSTAfterPreflight sends a preflight and then a real POST
// from the same allowed origin. It expects a 204 preflight, a 200 POST that
// reaches the next handler exactly once, and CORS headers on the POST.
func TestCORS_6_3_2_POSTAfterPreflight(t *testing.T) {
	const origin = "http://localhost:5173"
	var calls int32
	handler := CORS(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
		atomic.AddInt32(&calls, 1)
		w.WriteHeader(http.StatusOK)
	}))
	send := func(method string) *httptest.ResponseRecorder {
		r := httptest.NewRequest(method, "/api/analyze", nil)
		r.Header.Set("Origin", origin)
		rec := httptest.NewRecorder()
		handler.ServeHTTP(rec, r)
		return rec
	}

	if pre := send(http.MethodOptions); pre.Code != http.StatusNoContent {
		t.Fatalf("expected preflight 204, got %d", pre.Code)
	}

	post := send(http.MethodPost)
	if post.Code != http.StatusOK {
		t.Fatalf("expected POST to be processed, got %d", post.Code)
	}
	if got := atomic.LoadInt32(&calls); got != 1 {
		t.Fatalf("expected next handler called exactly once for POST, got %d", got)
	}
	if got := post.Header().Get("Access-Control-Allow-Origin"); got != origin {
		t.Fatalf("expected CORS headers on POST, got origin %q", got)
	}
}