503 lines
14 KiB
Go
503 lines
14 KiB
Go
package api
|
|
|
|
import (
|
|
"encoding/json"
|
|
"net/http"
|
|
"net/http/httptest"
|
|
"sync"
|
|
"sync/atomic"
|
|
"testing"
|
|
"time"
|
|
)
|
|
|
|
func resetRateLimiter() {
|
|
rateLimiter.mu.Lock()
|
|
rateLimiter.timestamps = make(map[string][]time.Time)
|
|
rateLimiter.mu.Unlock()
|
|
}
|
|
|
|
func withFixedNow(t *testing.T, now time.Time) {
|
|
t.Helper()
|
|
prev := timeNow
|
|
timeNow = func() time.Time { return now }
|
|
t.Cleanup(func() {
|
|
timeNow = prev
|
|
})
|
|
}
|
|
|
|
// newRateLimitedRequest builds a POST /api/analyze request for driving the
// rate limiter. A non-empty ip becomes the request's RemoteAddr with the fixed
// port 12345. An empty ip leaves httptest's default RemoteAddr untouched.
func newRateLimitedRequest(ip string) *http.Request {
	r := httptest.NewRequest(http.MethodPost, "/api/analyze", nil)
	if ip == "" {
		return r
	}
	r.RemoteAddr = ip + ":12345"
	return r
}
|
|
|
|
func TestRateLimit_5_1_1_AllowsTenRequestsPerHour(t *testing.T) {
|
|
resetRateLimiter()
|
|
withFixedNow(t, time.Date(2026, 4, 7, 10, 0, 0, 0, time.UTC))
|
|
|
|
var handlerCalls int32
|
|
next := http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
|
|
atomic.AddInt32(&handlerCalls, 1)
|
|
w.WriteHeader(http.StatusOK)
|
|
})
|
|
h := RateLimit(next)
|
|
|
|
for i := 0; i < 10; i++ {
|
|
rr := httptest.NewRecorder()
|
|
h.ServeHTTP(rr, newRateLimitedRequest("127.0.0.1"))
|
|
if rr.Code != http.StatusOK {
|
|
t.Fatalf("request %d: expected 200, got %d", i+1, rr.Code)
|
|
}
|
|
}
|
|
|
|
if got := atomic.LoadInt32(&handlerCalls); got != 10 {
|
|
t.Fatalf("expected 10 handler calls, got %d", got)
|
|
}
|
|
}
|
|
|
|
func TestRateLimit_5_1_2_EleventhRequestBlocked(t *testing.T) {
|
|
resetRateLimiter()
|
|
withFixedNow(t, time.Date(2026, 4, 7, 10, 0, 0, 0, time.UTC))
|
|
|
|
var handlerCalls int32
|
|
next := http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
|
|
atomic.AddInt32(&handlerCalls, 1)
|
|
w.WriteHeader(http.StatusOK)
|
|
})
|
|
h := RateLimit(next)
|
|
|
|
for i := 0; i < 10; i++ {
|
|
rr := httptest.NewRecorder()
|
|
h.ServeHTTP(rr, newRateLimitedRequest("127.0.0.1"))
|
|
}
|
|
|
|
rr := httptest.NewRecorder()
|
|
h.ServeHTTP(rr, newRateLimitedRequest("127.0.0.1"))
|
|
|
|
if rr.Code != http.StatusTooManyRequests {
|
|
t.Fatalf("expected 429 on 11th request, got %d", rr.Code)
|
|
}
|
|
if got := atomic.LoadInt32(&handlerCalls); got != 10 {
|
|
t.Fatalf("expected handler to be called only 10 times, got %d", got)
|
|
}
|
|
}
|
|
|
|
func TestRateLimit_5_1_3_DifferentIPsUnaffected(t *testing.T) {
|
|
resetRateLimiter()
|
|
withFixedNow(t, time.Date(2026, 4, 7, 10, 0, 0, 0, time.UTC))
|
|
|
|
next := http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
|
|
w.WriteHeader(http.StatusOK)
|
|
})
|
|
h := RateLimit(next)
|
|
|
|
for i := 0; i < 10; i++ {
|
|
rr := httptest.NewRecorder()
|
|
h.ServeHTTP(rr, newRateLimitedRequest("127.0.0.1"))
|
|
}
|
|
|
|
blocked := httptest.NewRecorder()
|
|
h.ServeHTTP(blocked, newRateLimitedRequest("127.0.0.1"))
|
|
if blocked.Code != http.StatusTooManyRequests {
|
|
t.Fatalf("expected IP A to be blocked, got %d", blocked.Code)
|
|
}
|
|
|
|
allowed := httptest.NewRecorder()
|
|
h.ServeHTTP(allowed, newRateLimitedRequest("192.168.1.100"))
|
|
if allowed.Code != http.StatusOK {
|
|
t.Fatalf("expected IP B to be allowed, got %d", allowed.Code)
|
|
}
|
|
}
|
|
|
|
func TestRateLimit_5_2_1_RequestsOlderThanOneHourDontCount(t *testing.T) {
|
|
resetRateLimiter()
|
|
now := time.Date(2026, 4, 7, 10, 0, 0, 0, time.UTC)
|
|
withFixedNow(t, now)
|
|
|
|
old := make([]time.Time, 10)
|
|
for i := range old {
|
|
old[i] = now.Add(-61 * time.Minute)
|
|
}
|
|
rateLimiter.mu.Lock()
|
|
rateLimiter.timestamps["127.0.0.1"] = old
|
|
rateLimiter.mu.Unlock()
|
|
|
|
next := http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
|
|
w.WriteHeader(http.StatusOK)
|
|
})
|
|
h := RateLimit(next)
|
|
|
|
for i := 0; i < 10; i++ {
|
|
rr := httptest.NewRecorder()
|
|
h.ServeHTTP(rr, newRateLimitedRequest("127.0.0.1"))
|
|
if rr.Code != http.StatusOK {
|
|
t.Fatalf("request %d expected 200, got %d", i+1, rr.Code)
|
|
}
|
|
}
|
|
}
|
|
|
|
func TestRateLimit_5_2_2_RollingWindowAllowsAfterExpiry(t *testing.T) {
|
|
resetRateLimiter()
|
|
base := time.Date(2026, 4, 7, 10, 0, 0, 0, time.UTC)
|
|
withFixedNow(t, base.Add(61*time.Minute))
|
|
|
|
recent := make([]time.Time, 10)
|
|
for i := range recent {
|
|
recent[i] = base.Add(time.Duration(i) * 3 * time.Minute)
|
|
}
|
|
rateLimiter.mu.Lock()
|
|
rateLimiter.timestamps["127.0.0.1"] = recent
|
|
rateLimiter.mu.Unlock()
|
|
|
|
next := http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
|
|
w.WriteHeader(http.StatusOK)
|
|
})
|
|
h := RateLimit(next)
|
|
|
|
rr := httptest.NewRecorder()
|
|
h.ServeHTTP(rr, newRateLimitedRequest("127.0.0.1"))
|
|
if rr.Code != http.StatusOK {
|
|
t.Fatalf("expected request after rolling expiry to pass, got %d", rr.Code)
|
|
}
|
|
}
|
|
|
|
func TestRateLimit_5_2_3_ConcurrentRequestsThreadSafety(t *testing.T) {
|
|
resetRateLimiter()
|
|
withFixedNow(t, time.Date(2026, 4, 7, 10, 0, 0, 0, time.UTC))
|
|
|
|
next := http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
|
|
w.WriteHeader(http.StatusOK)
|
|
})
|
|
h := RateLimit(next)
|
|
|
|
var wg sync.WaitGroup
|
|
results := make(chan int, 20)
|
|
for i := 0; i < 20; i++ {
|
|
wg.Add(1)
|
|
go func() {
|
|
defer wg.Done()
|
|
rr := httptest.NewRecorder()
|
|
h.ServeHTTP(rr, newRateLimitedRequest("127.0.0.1"))
|
|
results <- rr.Code
|
|
}()
|
|
}
|
|
wg.Wait()
|
|
close(results)
|
|
|
|
okCount := 0
|
|
tooManyCount := 0
|
|
for code := range results {
|
|
switch code {
|
|
case http.StatusOK:
|
|
okCount++
|
|
case http.StatusTooManyRequests:
|
|
tooManyCount++
|
|
default:
|
|
t.Fatalf("unexpected status code: %d", code)
|
|
}
|
|
}
|
|
|
|
if okCount != 10 || tooManyCount != 10 {
|
|
t.Fatalf("expected exactly 10 allowed and 10 blocked, got %d allowed and %d blocked", okCount, tooManyCount)
|
|
}
|
|
}
|
|
|
|
func TestGetClientIP_5_3_1_XForwardedForSingleIP(t *testing.T) {
|
|
req := httptest.NewRequest(http.MethodPost, "/api/analyze", nil)
|
|
req.Header.Set("X-Forwarded-For", "203.0.113.45")
|
|
|
|
if ip := getClientIP(req); ip != "203.0.113.45" {
|
|
t.Fatalf("expected first XFF IP, got %q", ip)
|
|
}
|
|
}
|
|
|
|
func TestGetClientIP_5_3_2_XForwardedForMultipleIPs(t *testing.T) {
|
|
req := httptest.NewRequest(http.MethodPost, "/api/analyze", nil)
|
|
req.Header.Set("X-Forwarded-For", "203.0.113.45, 198.51.100.67")
|
|
|
|
if ip := getClientIP(req); ip != "203.0.113.45" {
|
|
t.Fatalf("expected first XFF IP, got %q", ip)
|
|
}
|
|
}
|
|
|
|
func TestGetClientIP_5_3_3_XRealIP(t *testing.T) {
|
|
req := httptest.NewRequest(http.MethodPost, "/api/analyze", nil)
|
|
req.Header.Set("X-Real-IP", "192.0.2.123")
|
|
|
|
if ip := getClientIP(req); ip != "192.0.2.123" {
|
|
t.Fatalf("expected X-Real-IP, got %q", ip)
|
|
}
|
|
}
|
|
|
|
func TestGetClientIP_5_3_4_RemoteAddrFallback(t *testing.T) {
|
|
req := httptest.NewRequest(http.MethodPost, "/api/analyze", nil)
|
|
req.RemoteAddr = "127.0.0.1:54321"
|
|
|
|
if ip := getClientIP(req); ip != "127.0.0.1" {
|
|
t.Fatalf("expected remote addr IP without port, got %q", ip)
|
|
}
|
|
}
|
|
|
|
func TestGetClientIP_5_3_5_XForwardedForWhitespaceHandling(t *testing.T) {
|
|
req := httptest.NewRequest(http.MethodPost, "/api/analyze", nil)
|
|
req.Header.Set("X-Forwarded-For", " 203.0.113.45 , 198.51.100.67")
|
|
|
|
if ip := getClientIP(req); ip != "203.0.113.45" {
|
|
t.Fatalf("expected trimmed first XFF IP, got %q", ip)
|
|
}
|
|
}
|
|
|
|
func TestRateLimit_5_4_1_ErrorResponseFormat(t *testing.T) {
|
|
resetRateLimiter()
|
|
withFixedNow(t, time.Date(2026, 4, 7, 10, 0, 0, 0, time.UTC))
|
|
|
|
next := http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
|
|
w.WriteHeader(http.StatusOK)
|
|
})
|
|
h := RateLimit(next)
|
|
|
|
for i := 0; i < 10; i++ {
|
|
rr := httptest.NewRecorder()
|
|
h.ServeHTTP(rr, newRateLimitedRequest("127.0.0.1"))
|
|
}
|
|
|
|
rr := httptest.NewRecorder()
|
|
h.ServeHTTP(rr, newRateLimitedRequest("127.0.0.1"))
|
|
|
|
if rr.Code != http.StatusTooManyRequests {
|
|
t.Fatalf("expected 429, got %d", rr.Code)
|
|
}
|
|
if ct := rr.Header().Get("Content-Type"); ct != "application/json" {
|
|
t.Fatalf("expected application/json content type, got %q", ct)
|
|
}
|
|
|
|
var body map[string]string
|
|
if err := json.Unmarshal(rr.Body.Bytes(), &body); err != nil {
|
|
t.Fatalf("expected valid JSON body, got parse error: %v", err)
|
|
}
|
|
expected := "Rate limit exceeded. You can make up to 10 requests per hour. Please try again later."
|
|
if body["error"] != expected {
|
|
t.Fatalf("expected exact error message %q, got %q", expected, body["error"])
|
|
}
|
|
}
|
|
|
|
func TestRateLimit_5_4_2_NoHandlerCallWhenRateLimited(t *testing.T) {
|
|
resetRateLimiter()
|
|
withFixedNow(t, time.Date(2026, 4, 7, 10, 0, 0, 0, time.UTC))
|
|
|
|
var handlerCalls int32
|
|
next := http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
|
|
atomic.AddInt32(&handlerCalls, 1)
|
|
w.WriteHeader(http.StatusOK)
|
|
})
|
|
h := RateLimit(next)
|
|
|
|
for i := 0; i < 10; i++ {
|
|
rr := httptest.NewRecorder()
|
|
h.ServeHTTP(rr, newRateLimitedRequest("127.0.0.1"))
|
|
}
|
|
|
|
blocked := httptest.NewRecorder()
|
|
h.ServeHTTP(blocked, newRateLimitedRequest("127.0.0.1"))
|
|
|
|
if blocked.Code != http.StatusTooManyRequests {
|
|
t.Fatalf("expected 429, got %d", blocked.Code)
|
|
}
|
|
if got := atomic.LoadInt32(&handlerCalls); got != 10 {
|
|
t.Fatalf("expected handler to stay at 10 calls after block, got %d", got)
|
|
}
|
|
}
|
|
|
|
func TestCORS_6_1_1_AllowedOriginViteDevServer(t *testing.T) {
|
|
next := http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
|
|
w.WriteHeader(http.StatusOK)
|
|
})
|
|
h := CORS(next)
|
|
|
|
req := httptest.NewRequest(http.MethodPost, "/api/analyze", nil)
|
|
req.Header.Set("Origin", "http://localhost:5173")
|
|
rr := httptest.NewRecorder()
|
|
h.ServeHTTP(rr, req)
|
|
|
|
if got := rr.Header().Get("Access-Control-Allow-Origin"); got != "http://localhost:5173" {
|
|
t.Fatalf("expected allowed origin header, got %q", got)
|
|
}
|
|
}
|
|
|
|
func TestCORS_6_1_2_AllowedOriginDockerNginx(t *testing.T) {
|
|
next := http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
|
|
w.WriteHeader(http.StatusOK)
|
|
})
|
|
h := CORS(next)
|
|
|
|
req := httptest.NewRequest(http.MethodPost, "/api/analyze", nil)
|
|
req.Header.Set("Origin", "http://localhost")
|
|
rr := httptest.NewRecorder()
|
|
h.ServeHTTP(rr, req)
|
|
|
|
if got := rr.Header().Get("Access-Control-Allow-Origin"); got != "http://localhost" {
|
|
t.Fatalf("expected allowed origin header, got %q", got)
|
|
}
|
|
}
|
|
|
|
func TestCORS_6_1_3_AllowedOriginDockerNginxWithPort(t *testing.T) {
|
|
next := http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
|
|
w.WriteHeader(http.StatusOK)
|
|
})
|
|
h := CORS(next)
|
|
|
|
req := httptest.NewRequest(http.MethodPost, "/api/analyze", nil)
|
|
req.Header.Set("Origin", "http://localhost:80")
|
|
rr := httptest.NewRecorder()
|
|
h.ServeHTTP(rr, req)
|
|
|
|
if got := rr.Header().Get("Access-Control-Allow-Origin"); got != "http://localhost:80" {
|
|
t.Fatalf("expected allowed origin header, got %q", got)
|
|
}
|
|
}
|
|
|
|
func TestCORS_6_1_4_DisallowedOrigin(t *testing.T) {
|
|
next := http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
|
|
w.WriteHeader(http.StatusOK)
|
|
})
|
|
h := CORS(next)
|
|
|
|
req := httptest.NewRequest(http.MethodPost, "/api/analyze", nil)
|
|
req.Header.Set("Origin", "http://malicious-site.com")
|
|
rr := httptest.NewRecorder()
|
|
h.ServeHTTP(rr, req)
|
|
|
|
if got := rr.Header().Get("Access-Control-Allow-Origin"); got != "" {
|
|
t.Fatalf("expected no CORS allow-origin for disallowed origin, got %q", got)
|
|
}
|
|
}
|
|
|
|
func TestCORS_6_1_5_MissingOriginHeader(t *testing.T) {
|
|
called := false
|
|
next := http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
|
|
called = true
|
|
w.WriteHeader(http.StatusOK)
|
|
})
|
|
h := CORS(next)
|
|
|
|
req := httptest.NewRequest(http.MethodPost, "/api/analyze", nil)
|
|
rr := httptest.NewRecorder()
|
|
h.ServeHTTP(rr, req)
|
|
|
|
if !called {
|
|
t.Fatalf("expected request to continue when Origin is missing")
|
|
}
|
|
if got := rr.Header().Get("Access-Control-Allow-Origin"); got != "" {
|
|
t.Fatalf("expected no CORS headers when Origin missing, got %q", got)
|
|
}
|
|
}
|
|
|
|
func TestCORS_6_2_1_AllowedMethodsHeader(t *testing.T) {
|
|
next := http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
|
|
w.WriteHeader(http.StatusOK)
|
|
})
|
|
h := CORS(next)
|
|
|
|
req := httptest.NewRequest(http.MethodPost, "/api/analyze", nil)
|
|
req.Header.Set("Origin", "http://localhost:5173")
|
|
rr := httptest.NewRecorder()
|
|
h.ServeHTTP(rr, req)
|
|
|
|
if got := rr.Header().Get("Access-Control-Allow-Methods"); got != "GET, POST, OPTIONS" {
|
|
t.Fatalf("expected allow-methods header, got %q", got)
|
|
}
|
|
}
|
|
|
|
func TestCORS_6_2_2_AllowedHeadersHeader(t *testing.T) {
|
|
next := http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
|
|
w.WriteHeader(http.StatusOK)
|
|
})
|
|
h := CORS(next)
|
|
|
|
req := httptest.NewRequest(http.MethodPost, "/api/analyze", nil)
|
|
req.Header.Set("Origin", "http://localhost:5173")
|
|
rr := httptest.NewRecorder()
|
|
h.ServeHTTP(rr, req)
|
|
|
|
if got := rr.Header().Get("Access-Control-Allow-Headers"); got != "Content-Type" {
|
|
t.Fatalf("expected allow-headers header, got %q", got)
|
|
}
|
|
}
|
|
|
|
func TestCORS_6_2_3_CredentialsAllowed(t *testing.T) {
|
|
next := http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
|
|
w.WriteHeader(http.StatusOK)
|
|
})
|
|
h := CORS(next)
|
|
|
|
req := httptest.NewRequest(http.MethodPost, "/api/analyze", nil)
|
|
req.Header.Set("Origin", "http://localhost:5173")
|
|
rr := httptest.NewRecorder()
|
|
h.ServeHTTP(rr, req)
|
|
|
|
if got := rr.Header().Get("Access-Control-Allow-Credentials"); got != "true" {
|
|
t.Fatalf("expected allow-credentials header true, got %q", got)
|
|
}
|
|
}
|
|
|
|
func TestCORS_6_3_1_OPTIONSPreflight(t *testing.T) {
|
|
called := false
|
|
next := http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
|
|
called = true
|
|
w.WriteHeader(http.StatusOK)
|
|
})
|
|
h := CORS(next)
|
|
|
|
req := httptest.NewRequest(http.MethodOptions, "/api/analyze", nil)
|
|
req.Header.Set("Origin", "http://localhost:5173")
|
|
rr := httptest.NewRecorder()
|
|
h.ServeHTTP(rr, req)
|
|
|
|
if rr.Code != http.StatusNoContent {
|
|
t.Fatalf("expected 204 for preflight, got %d", rr.Code)
|
|
}
|
|
if called {
|
|
t.Fatalf("expected preflight to short-circuit before next handler")
|
|
}
|
|
if rr.Body.Len() != 0 {
|
|
t.Fatalf("expected empty body for preflight, got %q", rr.Body.String())
|
|
}
|
|
if got := rr.Header().Get("Access-Control-Allow-Origin"); got != "http://localhost:5173" {
|
|
t.Fatalf("expected CORS headers on preflight, got origin %q", got)
|
|
}
|
|
}
|
|
|
|
func TestCORS_6_3_2_POSTAfterPreflight(t *testing.T) {
|
|
var calls int32
|
|
next := http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
|
|
atomic.AddInt32(&calls, 1)
|
|
w.WriteHeader(http.StatusOK)
|
|
})
|
|
h := CORS(next)
|
|
|
|
preflight := httptest.NewRequest(http.MethodOptions, "/api/analyze", nil)
|
|
preflight.Header.Set("Origin", "http://localhost:5173")
|
|
preflightResp := httptest.NewRecorder()
|
|
h.ServeHTTP(preflightResp, preflight)
|
|
if preflightResp.Code != http.StatusNoContent {
|
|
t.Fatalf("expected preflight 204, got %d", preflightResp.Code)
|
|
}
|
|
|
|
post := httptest.NewRequest(http.MethodPost, "/api/analyze", nil)
|
|
post.Header.Set("Origin", "http://localhost:5173")
|
|
postResp := httptest.NewRecorder()
|
|
h.ServeHTTP(postResp, post)
|
|
|
|
if postResp.Code != http.StatusOK {
|
|
t.Fatalf("expected POST to be processed, got %d", postResp.Code)
|
|
}
|
|
if got := atomic.LoadInt32(&calls); got != 1 {
|
|
t.Fatalf("expected next handler called exactly once for POST, got %d", got)
|
|
}
|
|
if got := postResp.Header().Get("Access-Control-Allow-Origin"); got != "http://localhost:5173" {
|
|
t.Fatalf("expected CORS headers on POST, got origin %q", got)
|
|
}
|
|
}
|