package api

import (
	"encoding/json"
	"net/http"
	"net/http/httptest"
	"sync"
	"sync/atomic"
	"testing"
	"time"
)

// rateLimitTestQuota is the number of requests per client IP that these tests
// expect the RateLimit middleware to allow inside one rolling hour.
const rateLimitTestQuota = 10

// rateLimitTestEpoch is the fixed instant most rate-limit tests pin timeNow
// to, so results never depend on the wall clock.
var rateLimitTestEpoch = time.Date(2026, 4, 7, 10, 0, 0, 0, time.UTC)

// resetRateLimiter replaces the package-level rateLimiter's timestamp map with
// an empty one while holding its mutex.
func resetRateLimiter() {
	rateLimiter.mu.Lock()
	rateLimiter.timestamps = make(map[string][]time.Time)
	rateLimiter.mu.Unlock()
}

// withFixedNow swaps the package-level timeNow clock for one that always
// returns now, and restores the previous clock when t completes.
func withFixedNow(t *testing.T, now time.Time) {
	t.Helper()
	prev := timeNow
	timeNow = func() time.Time { return now }
	t.Cleanup(func() { timeNow = prev })
}

// setupRateLimitTest gives t an empty rate-limiter history and a clock frozen
// at now. The history is cleared again when t completes so timestamps recorded
// by this test cannot leak into other tests that share the global rateLimiter.
func setupRateLimitTest(t *testing.T, now time.Time) {
	t.Helper()
	resetRateLimiter()
	t.Cleanup(resetRateLimiter)
	withFixedNow(t, now)
}

// newRateLimitedRequest builds a POST /api/analyze request. A non-empty ip
// becomes the host part of RemoteAddr (with a dummy port); an empty ip keeps
// httptest's default RemoteAddr.
func newRateLimitedRequest(ip string) *http.Request {
	req := httptest.NewRequest(http.MethodPost, "/api/analyze", nil)
	if ip != "" {
		req.RemoteAddr = ip + ":12345"
	}
	return req
}

// newOriginRequest builds a request for /api/analyze with the given method,
// setting the Origin header only when origin is non-empty.
func newOriginRequest(method, origin string) *http.Request {
	req := httptest.NewRequest(method, "/api/analyze", nil)
	if origin != "" {
		req.Header.Set("Origin", origin)
	}
	return req
}

// serveRecorded runs req through h and returns the recorded response.
func serveRecorded(h http.Handler, req *http.Request) *httptest.ResponseRecorder {
	rr := httptest.NewRecorder()
	h.ServeHTTP(rr, req)
	return rr
}

// statusOKHandler returns a handler that always replies 200 with an empty body.
// It is an http.HandlerFunc so it can be passed to middleware that accepts
// either http.Handler or http.HandlerFunc.
func statusOKHandler() http.HandlerFunc {
	return func(w http.ResponseWriter, _ *http.Request) {
		w.WriteHeader(http.StatusOK)
	}
}

// countingStatusOKHandler behaves like statusOKHandler but also atomically
// increments *calls on every invocation. Tests use it to observe whether a
// middleware forwarded the request.
func countingStatusOKHandler(calls *int32) http.HandlerFunc {
	return func(w http.ResponseWriter, _ *http.Request) {
		atomic.AddInt32(calls, 1)
		w.WriteHeader(http.StatusOK)
	}
}

// exhaustRateLimit sends rateLimitTestQuota requests from ip through h and
// fails the test unless every one of them gets 200. Afterwards ip's quota for
// the current window is fully used.
func exhaustRateLimit(t *testing.T, h http.Handler, ip string) {
	t.Helper()
	for i := 1; i <= rateLimitTestQuota; i++ {
		if rr := serveRecorded(h, newRateLimitedRequest(ip)); rr.Code != http.StatusOK {
			t.Fatalf("request %d from %s: expected 200, got %d", i, ip, rr.Code)
		}
	}
}

// seedRateLimitHistory installs stamps as the recorded request history for ip.
func seedRateLimitHistory(ip string, stamps []time.Time) {
	rateLimiter.mu.Lock()
	rateLimiter.timestamps[ip] = stamps
	rateLimiter.mu.Unlock()
}

func TestRateLimit_5_1_1_AllowsTenRequestsPerHour(t *testing.T) {
	setupRateLimitTest(t, rateLimitTestEpoch)
	var handlerCalls int32
	h := RateLimit(countingStatusOKHandler(&handlerCalls))

	exhaustRateLimit(t, h, "127.0.0.1")

	if got := atomic.LoadInt32(&handlerCalls); got != rateLimitTestQuota {
		t.Fatalf("expected 10 handler calls, got %d", got)
	}
}

func TestRateLimit_5_1_2_EleventhRequestBlocked(t *testing.T) {
	setupRateLimitTest(t, rateLimitTestEpoch)
	var handlerCalls int32
	h := RateLimit(countingStatusOKHandler(&handlerCalls))

	exhaustRateLimit(t, h, "127.0.0.1")

	if rr := serveRecorded(h, newRateLimitedRequest("127.0.0.1")); rr.Code != http.StatusTooManyRequests {
		t.Fatalf("expected 429 on 11th request, got %d", rr.Code)
	}
	if got := atomic.LoadInt32(&handlerCalls); got != rateLimitTestQuota {
		t.Fatalf("expected handler to be called only 10 times, got %d", got)
	}
}

func TestRateLimit_5_1_3_DifferentIPsUnaffected(t *testing.T) {
	setupRateLimitTest(t, rateLimitTestEpoch)
	h := RateLimit(statusOKHandler())

	exhaustRateLimit(t, h, "127.0.0.1")

	if blocked := serveRecorded(h, newRateLimitedRequest("127.0.0.1")); blocked.Code != http.StatusTooManyRequests {
		t.Fatalf("expected IP A to be blocked, got %d", blocked.Code)
	}
	if allowed := serveRecorded(h, newRateLimitedRequest("192.168.1.100")); allowed.Code != http.StatusOK {
		t.Fatalf("expected IP B to be allowed, got %d", allowed.Code)
	}
}

func TestRateLimit_5_2_1_RequestsOlderThanOneHourDontCount(t *testing.T) {
	now := rateLimitTestEpoch
	setupRateLimitTest(t, now)

	// A full quota of timestamps 61 minutes old, just outside the one-hour window.
	stale := make([]time.Time, rateLimitTestQuota)
	for i := range stale {
		stale[i] = now.Add(-61 * time.Minute)
	}
	seedRateLimitHistory("127.0.0.1", stale)

	h := RateLimit(statusOKHandler())
	exhaustRateLimit(t, h, "127.0.0.1")

	// Ignoring stale entries must not disable the limit: the fresh quota is now used up.
	if rr := serveRecorded(h, newRateLimitedRequest("127.0.0.1")); rr.Code != http.StatusTooManyRequests {
		t.Fatalf("expected 429 once the fresh quota is used, got %d", rr.Code)
	}
}

func TestRateLimit_5_2_2_RollingWindowAllowsAfterExpiry(t *testing.T) {
	base := rateLimitTestEpoch
	setupRateLimitTest(t, base.Add(61*time.Minute))

	// A full quota spaced 3 minutes apart (base, base+3m, ..., base+27m). At
	// base+61m only the first entry has left the one-hour window.
	recent := make([]time.Time, rateLimitTestQuota)
	for i := range recent {
		recent[i] = base.Add(time.Duration(i) * 3 * time.Minute)
	}
	seedRateLimitHistory("127.0.0.1", recent)

	h := RateLimit(statusOKHandler())

	if rr := serveRecorded(h, newRateLimitedRequest("127.0.0.1")); rr.Code != http.StatusOK {
		t.Fatalf("expected request after rolling expiry to pass, got %d", rr.Code)
	}
	// Exactly one slot was freed, so the window is full again. A fixed hourly
	// window (reset at the top of the hour) would wrongly let this through.
	if rr := serveRecorded(h, newRateLimitedRequest("127.0.0.1")); rr.Code != http.StatusTooManyRequests {
		t.Fatalf("expected 429 once the single freed slot is used, got %d", rr.Code)
	}
}

func TestRateLimit_5_2_3_ConcurrentRequestsThreadSafety(t *testing.T) {
	setupRateLimitTest(t, rateLimitTestEpoch)
	h := RateLimit(statusOKHandler())

	const total = 2 * rateLimitTestQuota
	// Buffered so no goroutine blocks on send before the results are drained.
	results := make(chan int, total)
	var wg sync.WaitGroup
	for i := 0; i < total; i++ {
		wg.Add(1)
		go func() {
			defer wg.Done()
			results <- serveRecorded(h, newRateLimitedRequest("127.0.0.1")).Code
		}()
	}
	wg.Wait()
	close(results)

	okCount, tooManyCount := 0, 0
	for code := range results {
		switch code {
		case http.StatusOK:
			okCount++
		case http.StatusTooManyRequests:
			tooManyCount++
		default:
			t.Fatalf("unexpected status code: %d", code)
		}
	}
	if okCount != rateLimitTestQuota || tooManyCount != total-rateLimitTestQuota {
		t.Fatalf("expected exactly 10 allowed and 10 blocked, got %d allowed and %d blocked", okCount, tooManyCount)
	}
}

func TestGetClientIP_5_3_1_XForwardedForSingleIP(t *testing.T) {
	req := httptest.NewRequest(http.MethodPost, "/api/analyze", nil)
	req.Header.Set("X-Forwarded-For", "203.0.113.45")
	if ip := getClientIP(req); ip != "203.0.113.45" {
		t.Fatalf("expected first XFF IP, got %q", ip)
	}
}

func TestGetClientIP_5_3_2_XForwardedForMultipleIPs(t *testing.T) {
	req := httptest.NewRequest(http.MethodPost, "/api/analyze", nil)
	req.Header.Set("X-Forwarded-For", "203.0.113.45, 198.51.100.67")
	if ip := getClientIP(req); ip != "203.0.113.45" {
		t.Fatalf("expected first XFF IP, got %q", ip)
	}
}

func TestGetClientIP_5_3_3_XRealIP(t *testing.T) {
	req := httptest.NewRequest(http.MethodPost, "/api/analyze", nil)
	req.Header.Set("X-Real-IP", "192.0.2.123")
	if ip := getClientIP(req); ip != "192.0.2.123" {
		t.Fatalf("expected X-Real-IP, got %q", ip)
	}
}

func TestGetClientIP_5_3_4_RemoteAddrFallback(t *testing.T) {
	req := httptest.NewRequest(http.MethodPost, "/api/analyze", nil)
	req.RemoteAddr = "127.0.0.1:54321"
	if ip := getClientIP(req); ip != "127.0.0.1" {
		t.Fatalf("expected remote addr IP without port, got %q", ip)
	}
}

func TestGetClientIP_5_3_5_XForwardedForWhitespaceHandling(t *testing.T) {
	req := httptest.NewRequest(http.MethodPost, "/api/analyze", nil)
	req.Header.Set("X-Forwarded-For", " 203.0.113.45 , 198.51.100.67")
	if ip := getClientIP(req); ip != "203.0.113.45" {
		t.Fatalf("expected trimmed first XFF IP, got %q", ip)
	}
}

func TestRateLimit_5_4_1_ErrorResponseFormat(t *testing.T) {
	setupRateLimitTest(t, rateLimitTestEpoch)
	h := RateLimit(statusOKHandler())

	exhaustRateLimit(t, h, "127.0.0.1")

	rr := serveRecorded(h, newRateLimitedRequest("127.0.0.1"))
	if rr.Code != http.StatusTooManyRequests {
		t.Fatalf("expected 429, got %d", rr.Code)
	}
	if ct := rr.Header().Get("Content-Type"); ct != "application/json" {
		t.Fatalf("expected application/json content type, got %q", ct)
	}
	var body map[string]string
	if err := json.Unmarshal(rr.Body.Bytes(), &body); err != nil {
		t.Fatalf("expected valid JSON body, got parse error: %v", err)
	}
	expected := "Rate limit exceeded. You can make up to 10 requests per hour. Please try again later."
	if body["error"] != expected {
		t.Fatalf("expected exact error message %q, got %q", expected, body["error"])
	}
}

func TestRateLimit_5_4_2_NoHandlerCallWhenRateLimited(t *testing.T) {
	setupRateLimitTest(t, rateLimitTestEpoch)
	var handlerCalls int32
	h := RateLimit(countingStatusOKHandler(&handlerCalls))

	exhaustRateLimit(t, h, "127.0.0.1")

	if blocked := serveRecorded(h, newRateLimitedRequest("127.0.0.1")); blocked.Code != http.StatusTooManyRequests {
		t.Fatalf("expected 429, got %d", blocked.Code)
	}
	if got := atomic.LoadInt32(&handlerCalls); got != rateLimitTestQuota {
		t.Fatalf("expected handler to stay at 10 calls after block, got %d", got)
	}
}

func TestCORS_6_1_1_AllowedOriginViteDevServer(t *testing.T) {
	rr := serveRecorded(CORS(statusOKHandler()), newOriginRequest(http.MethodPost, "http://localhost:5173"))
	if got := rr.Header().Get("Access-Control-Allow-Origin"); got != "http://localhost:5173" {
		t.Fatalf("expected allowed origin header, got %q", got)
	}
}

func TestCORS_6_1_2_AllowedOriginDockerNginx(t *testing.T) {
	rr := serveRecorded(CORS(statusOKHandler()), newOriginRequest(http.MethodPost, "http://localhost"))
	if got := rr.Header().Get("Access-Control-Allow-Origin"); got != "http://localhost" {
		t.Fatalf("expected allowed origin header, got %q", got)
	}
}

func TestCORS_6_1_3_AllowedOriginDockerNginxWithPort(t *testing.T) {
	rr := serveRecorded(CORS(statusOKHandler()), newOriginRequest(http.MethodPost, "http://localhost:80"))
	if got := rr.Header().Get("Access-Control-Allow-Origin"); got != "http://localhost:80" {
		t.Fatalf("expected allowed origin header, got %q", got)
	}
}

func TestCORS_6_1_4_DisallowedOrigin(t *testing.T) {
	rr := serveRecorded(CORS(statusOKHandler()), newOriginRequest(http.MethodPost, "http://malicious-site.com"))
	if got := rr.Header().Get("Access-Control-Allow-Origin"); got != "" {
		t.Fatalf("expected no CORS allow-origin for disallowed origin, got %q", got)
	}
}

func TestCORS_6_1_5_MissingOriginHeader(t *testing.T) {
	var calls int32
	rr := serveRecorded(CORS(countingStatusOKHandler(&calls)), newOriginRequest(http.MethodPost, ""))
	if atomic.LoadInt32(&calls) == 0 {
		t.Fatalf("expected request to continue when Origin is missing")
	}
	if got := rr.Header().Get("Access-Control-Allow-Origin"); got != "" {
		t.Fatalf("expected no CORS headers when Origin missing, got %q", got)
	}
}

func TestCORS_6_2_1_AllowedMethodsHeader(t *testing.T) {
	rr := serveRecorded(CORS(statusOKHandler()), newOriginRequest(http.MethodPost, "http://localhost:5173"))
	if got := rr.Header().Get("Access-Control-Allow-Methods"); got != "GET, POST, OPTIONS" {
		t.Fatalf("expected allow-methods header, got %q", got)
	}
}

func TestCORS_6_2_2_AllowedHeadersHeader(t *testing.T) {
	rr := serveRecorded(CORS(statusOKHandler()), newOriginRequest(http.MethodPost, "http://localhost:5173"))
	if got := rr.Header().Get("Access-Control-Allow-Headers"); got != "Content-Type" {
		t.Fatalf("expected allow-headers header, got %q", got)
	}
}

func TestCORS_6_2_3_CredentialsAllowed(t *testing.T) {
	rr := serveRecorded(CORS(statusOKHandler()), newOriginRequest(http.MethodPost, "http://localhost:5173"))
	if got := rr.Header().Get("Access-Control-Allow-Credentials"); got != "true" {
		t.Fatalf("expected allow-credentials header true, got %q", got)
	}
}

func TestCORS_6_3_1_OPTIONSPreflight(t *testing.T) {
	var calls int32
	rr := serveRecorded(CORS(countingStatusOKHandler(&calls)), newOriginRequest(http.MethodOptions, "http://localhost:5173"))
	if rr.Code != http.StatusNoContent {
		t.Fatalf("expected 204 for preflight, got %d", rr.Code)
	}
	if atomic.LoadInt32(&calls) != 0 {
		t.Fatalf("expected preflight to short-circuit before next handler")
	}
	if rr.Body.Len() != 0 {
		t.Fatalf("expected empty body for preflight, got %q", rr.Body.String())
	}
	if got := rr.Header().Get("Access-Control-Allow-Origin"); got != "http://localhost:5173" {
		t.Fatalf("expected CORS headers on preflight, got origin %q", got)
	}
}

func TestCORS_6_3_2_POSTAfterPreflight(t *testing.T) {
	var calls int32
	h := CORS(countingStatusOKHandler(&calls))

	preflightResp := serveRecorded(h, newOriginRequest(http.MethodOptions, "http://localhost:5173"))
	if preflightResp.Code != http.StatusNoContent {
		t.Fatalf("expected preflight 204, got %d", preflightResp.Code)
	}

	postResp := serveRecorded(h, newOriginRequest(http.MethodPost, "http://localhost:5173"))
	if postResp.Code != http.StatusOK {
		t.Fatalf("expected POST to be processed, got %d", postResp.Code)
	}
	if got := atomic.LoadInt32(&calls); got != 1 {
		t.Fatalf("expected next handler called exactly once for POST, got %d", got)
	}
	if got := postResp.Header().Get("Access-Control-Allow-Origin"); got != "http://localhost:5173" {
		t.Fatalf("expected CORS headers on POST, got origin %q", got)
	}
}