120 lines
3.0 KiB
Go
120 lines
3.0 KiB
Go
package main
|
|
|
|
import (
|
|
"bytes"
|
|
"log"
|
|
"net/http"
|
|
"net/http/httptest"
|
|
"strings"
|
|
"sync/atomic"
|
|
"testing"
|
|
|
|
"github.com/go-chi/chi/v5"
|
|
"github.com/go-chi/chi/v5/middleware"
|
|
)
|
|
|
|
func TestServer_7_3_1_ConfiguredForPort3000AndStartupLog(t *testing.T) {
|
|
if serverAddr != ":3000" {
|
|
t.Fatalf("expected serverAddr :3000, got %q", serverAddr)
|
|
}
|
|
if serverStartupMessage != "Server listening on :3000" {
|
|
t.Fatalf("unexpected startup log message: %q", serverStartupMessage)
|
|
}
|
|
}
|
|
|
|
func TestServer_7_3_2_HandlesRequestsAfterStartup(t *testing.T) {
|
|
r := setupRouter()
|
|
|
|
req := httptest.NewRequest(http.MethodPost, "/api/analyze", nil)
|
|
rr := httptest.NewRecorder()
|
|
r.ServeHTTP(rr, req)
|
|
|
|
if rr.Code == http.StatusNotFound {
|
|
t.Fatalf("expected mounted route to be reachable, got 404")
|
|
}
|
|
}
|
|
|
|
func TestServer_7_3_3_LoggerMiddlewareActive(t *testing.T) {
|
|
var loggerInvoked int32
|
|
prevLogger := middleware.DefaultLogger
|
|
middleware.DefaultLogger = func(next http.Handler) http.Handler {
|
|
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
atomic.AddInt32(&loggerInvoked, 1)
|
|
next.ServeHTTP(w, r)
|
|
})
|
|
}
|
|
t.Cleanup(func() {
|
|
middleware.DefaultLogger = prevLogger
|
|
})
|
|
|
|
r := setupRouter()
|
|
|
|
var buf bytes.Buffer
|
|
prevWriter := log.Writer()
|
|
log.SetOutput(&buf)
|
|
t.Cleanup(func() {
|
|
log.SetOutput(prevWriter)
|
|
})
|
|
|
|
req := httptest.NewRequest(http.MethodPost, "/api/analyze", nil)
|
|
rr := httptest.NewRecorder()
|
|
r.ServeHTTP(rr, req)
|
|
|
|
out := buf.String()
|
|
if atomic.LoadInt32(&loggerInvoked) == 0 {
|
|
t.Fatalf("expected logger middleware to be invoked")
|
|
}
|
|
if out != "" && !strings.Contains(out, "/api/analyze") {
|
|
t.Fatalf("unexpected logger output content: %q", out)
|
|
}
|
|
}
|
|
|
|
func TestServer_7_2_3_MiddlewareChainOrderBehavior(t *testing.T) {
|
|
order := make([]string, 0, 5)
|
|
mark := func(name string) func(http.Handler) http.Handler {
|
|
return func(next http.Handler) http.Handler {
|
|
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
order = append(order, name)
|
|
next.ServeHTTP(w, r)
|
|
})
|
|
}
|
|
}
|
|
|
|
r := chi.NewRouter()
|
|
r.Use(mark("logger"))
|
|
r.Use(mark("recoverer"))
|
|
r.Use(mark("cors"))
|
|
r.Use(mark("ratelimit"))
|
|
r.Get("/probe", func(w http.ResponseWriter, _ *http.Request) {
|
|
order = append(order, "handler")
|
|
w.WriteHeader(http.StatusOK)
|
|
})
|
|
|
|
req := httptest.NewRequest(http.MethodGet, "/probe", nil)
|
|
rr := httptest.NewRecorder()
|
|
r.ServeHTTP(rr, req)
|
|
|
|
if rr.Code != http.StatusOK {
|
|
t.Fatalf("expected 200, got %d", rr.Code)
|
|
}
|
|
joined := strings.Join(order, ",")
|
|
if joined != "logger,recoverer,cors,ratelimit,handler" {
|
|
t.Fatalf("expected middleware chain order, got %q", joined)
|
|
}
|
|
}
|
|
|
|
func TestServer_7_2_4_PanicRecovery(t *testing.T) {
|
|
r := setupRouter().(*chi.Mux)
|
|
r.Get("/panic", func(http.ResponseWriter, *http.Request) {
|
|
panic("boom")
|
|
})
|
|
|
|
req := httptest.NewRequest(http.MethodGet, "/panic", nil)
|
|
rr := httptest.NewRecorder()
|
|
r.ServeHTTP(rr, req)
|
|
|
|
if rr.Code != http.StatusInternalServerError {
|
|
t.Fatalf("expected 500 from recoverer, got %d", rr.Code)
|
|
}
|
|
}
|