termtap/internal/proxy/handlers_test.go
Hayden Hargreaves 002773e77f test: AI generated all of these tests
Just for the MVP of course. Need to validate the idea.
2026-04-23 19:47:04 -07:00

331 lines
9.4 KiB
Go

package proxy
import (
"bufio"
"bytes"
"context"
"fmt"
"io"
"net"
"net/http"
"net/http/httptest"
"strings"
"testing"
"time"
"termtap.dev/internal/model"
)
// failingResponseWriter is an http.ResponseWriter test double whose Write
// always fails. It remembers the status code given to WriteHeader and
// allocates its header map on demand.
type failingResponseWriter struct {
	header http.Header // created by the first Header call
	code   int         // last value passed to WriteHeader; zero until then
}

// Header returns the writer's header map, creating it on first use.
func (fw *failingResponseWriter) Header() http.Header {
	if fw.header != nil {
		return fw.header
	}
	fw.header = http.Header{}
	return fw.header
}

// WriteHeader stores statusCode so a test can inspect it afterwards.
func (fw *failingResponseWriter) WriteHeader(statusCode int) {
	fw.code = statusCode
}

// Write ignores its input and always reports zero bytes written together
// with io.ErrClosedPipe.
func (fw *failingResponseWriter) Write([]byte) (int, error) {
	var written int
	return written, io.ErrClosedPipe
}
// hijackFailWriter is an http.ResponseWriter test double that also offers a
// Hijack method, but that method always returns an error. Writes succeed and
// are discarded; the status code passed to WriteHeader is recorded.
type hijackFailWriter struct {
	header http.Header // created by the first Header call
	code   int         // last value passed to WriteHeader; zero until then
}

// Header returns the writer's header map, creating it on first use.
func (hw *hijackFailWriter) Header() http.Header {
	if hw.header != nil {
		return hw.header
	}
	hw.header = http.Header{}
	return hw.header
}

// WriteHeader stores statusCode so a test can inspect it afterwards.
func (hw *hijackFailWriter) WriteHeader(statusCode int) {
	hw.code = statusCode
}

// Write discards p and reports that all of it was written.
func (hw *hijackFailWriter) Write(p []byte) (int, error) {
	accepted := len(p)
	return accepted, nil
}

// Hijack never yields a connection; it always returns a "hijack failed"
// error with nil connection and nil buffered reader/writer.
func (hw *hijackFailWriter) Hijack() (net.Conn, *bufio.ReadWriter, error) {
	err := fmt.Errorf("hijack failed")
	return nil, nil, err
}
// dummyConn is a stateless stand-in for a net.Conn: reads report EOF, writes
// are discarded, and closing or setting deadlines always succeeds.
type dummyConn struct{}

// Read never produces data; it returns 0 and io.EOF.
func (dummyConn) Read([]byte) (int, error) {
	return 0, io.EOF
}

// Write discards p and reports that all of it was written.
func (dummyConn) Write(p []byte) (int, error) {
	accepted := len(p)
	return accepted, nil
}

// Close does nothing and returns nil.
func (dummyConn) Close() error {
	return nil
}

// LocalAddr returns a pointer to a zero-valued TCP address.
func (dummyConn) LocalAddr() net.Addr {
	return new(net.TCPAddr)
}

// RemoteAddr returns a pointer to a zero-valued TCP address.
func (dummyConn) RemoteAddr() net.Addr {
	return new(net.TCPAddr)
}

// SetDeadline ignores its argument and returns nil.
func (dummyConn) SetDeadline(time.Time) error {
	return nil
}

// SetReadDeadline ignores its argument and returns nil.
func (dummyConn) SetReadDeadline(time.Time) error {
	return nil
}

// SetWriteDeadline ignores its argument and returns nil.
func (dummyConn) SetWriteDeadline(time.Time) error {
	return nil
}
// writeConnectFailWriter is an http.ResponseWriter test double whose Hijack
// succeeds but hands back a buffered writer layered over errWriter.
// NOTE(review): errWriter is defined elsewhere; presumably its Write always
// fails, so anything flushed to the hijacked connection errors — confirm.
type writeConnectFailWriter struct {
	header http.Header // created by the first Header call
	code   int         // last value passed to WriteHeader; zero until then
}

// Header returns the writer's header map, creating it on first use.
func (cw *writeConnectFailWriter) Header() http.Header {
	if cw.header != nil {
		return cw.header
	}
	cw.header = http.Header{}
	return cw.header
}

// WriteHeader stores statusCode so a test can inspect it afterwards.
func (cw *writeConnectFailWriter) WriteHeader(statusCode int) {
	cw.code = statusCode
}

// Write discards p and reports that all of it was written.
func (cw *writeConnectFailWriter) Write(p []byte) (int, error) {
	accepted := len(p)
	return accepted, nil
}

// Hijack returns a dummyConn plus a ReadWriter whose read side is empty and
// whose write side is buffered over an errWriter. It never returns an error.
func (cw *writeConnectFailWriter) Hijack() (net.Conn, *bufio.ReadWriter, error) {
	emptyIn := bufio.NewReader(strings.NewReader(""))
	brokenOut := bufio.NewWriter(errWriter{})
	return dummyConn{}, bufio.NewReadWriter(emptyIn, brokenOut), nil
}
// TestProxyHandler_NonConnectRejectsNonAbsoluteURL sends a GET whose target
// is a bare path rather than an absolute URL and expects a 400 response plus
// a warn event on the event channel.
func TestProxyHandler_NonConnectRejectsNonAbsoluteURL(t *testing.T) {
	t.Parallel()

	eventCh := make(chan model.Event, 8)
	server := &model.ProxyServer{Conns: make(map[net.Conn]struct{})}
	handler := proxyHandler(eventCh, nil, server)

	rec := httptest.NewRecorder()
	handler.ServeHTTP(rec, httptest.NewRequest(http.MethodGet, "/not-proxy-form", nil))

	if got, want := rec.Code, http.StatusBadRequest; got != want {
		t.Fatalf("status = %d, want %d", got, want)
	}

	seen := drainEvents(t, eventCh, 1, time.Second)
	if hasEventType(seen, model.EventTypeWarn) {
		return
	}
	t.Fatalf("expected %s event, got %#v", model.EventTypeWarn, seen)
}
// TestProxyHandler_NonConnectSuccess proxies a GET to a local upstream and
// verifies that the upstream's status, body and header reach the client, and
// that both a request-started and a request-finished event are emitted.
func TestProxyHandler_NonConnectSuccess(t *testing.T) {
	t.Parallel()

	// Upstream replies 202 with a marker header and a fixed body.
	origin := httptest.NewServer(http.HandlerFunc(func(rw http.ResponseWriter, _ *http.Request) {
		rw.Header().Set("X-Upstream", "yes")
		rw.WriteHeader(http.StatusAccepted)
		_, _ = rw.Write([]byte("pong"))
	}))
	t.Cleanup(origin.Close)

	eventCh := make(chan model.Event, 8)
	server := &model.ProxyServer{Conns: make(map[net.Conn]struct{})}
	handler := proxyHandler(eventCh, nil, server)

	rec := httptest.NewRecorder()
	handler.ServeHTTP(rec, httptest.NewRequest(http.MethodGet, origin.URL+"/ping", nil))

	if rec.Code != http.StatusAccepted {
		t.Fatalf("status = %d, want %d", rec.Code, http.StatusAccepted)
	}
	if body := rec.Body.String(); body != "pong" {
		t.Fatalf("body = %q, want %q", body, "pong")
	}
	if marker := rec.Header().Get("X-Upstream"); marker != "yes" {
		t.Fatalf("X-Upstream header = %q, want yes", marker)
	}

	seen := drainEvents(t, eventCh, 2, time.Second)
	switch {
	case !hasEventType(seen, model.EventTypeRequestStarted):
		t.Fatalf("expected %s event", model.EventTypeRequestStarted)
	case !hasEventType(seen, model.EventTypeRequestFinished):
		t.Fatalf("expected %s event", model.EventTypeRequestFinished)
	}
}
// TestProxyHandler_NonConnectUpstreamError proxies a GET to an upstream
// address that is expected to be unreachable and verifies a 502 response
// plus request-started and request-failed events.
func TestProxyHandler_NonConnectUpstreamError(t *testing.T) {
	t.Parallel()

	eventCh := make(chan model.Event, 8)
	server := &model.ProxyServer{Conns: make(map[net.Conn]struct{})}
	handler := proxyHandler(eventCh, nil, server)

	// NOTE(review): presumably nothing listens on loopback port 1, so the
	// upstream round trip fails — confirm this holds on CI hosts.
	rec := httptest.NewRecorder()
	handler.ServeHTTP(rec, httptest.NewRequest(http.MethodGet, "http://127.0.0.1:1/fail", nil))

	if got, want := rec.Code, http.StatusBadGateway; got != want {
		t.Fatalf("status = %d, want %d", got, want)
	}

	seen := drainEvents(t, eventCh, 2, time.Second)
	switch {
	case !hasEventType(seen, model.EventTypeRequestStarted):
		t.Fatalf("expected %s event", model.EventTypeRequestStarted)
	case !hasEventType(seen, model.EventTypeRequestFailed):
		t.Fatalf("expected %s event", model.EventTypeRequestFailed)
	}
}
// TestProxyHandler_NonConnectWriteFailureEmitsFailedEvent proxies a GET to a
// healthy upstream but gives the handler a response writer whose Write always
// errors, then verifies request-started and request-failed events appear.
func TestProxyHandler_NonConnectWriteFailureEmitsFailedEvent(t *testing.T) {
	t.Parallel()

	// Upstream replies 200 with a small body for the proxy to copy.
	origin := httptest.NewServer(http.HandlerFunc(func(rw http.ResponseWriter, _ *http.Request) {
		rw.WriteHeader(http.StatusOK)
		_, _ = io.Copy(rw, bytes.NewBufferString("response-body"))
	}))
	t.Cleanup(origin.Close)

	eventCh := make(chan model.Event, 8)
	server := &model.ProxyServer{Conns: make(map[net.Conn]struct{})}
	handler := proxyHandler(eventCh, nil, server)

	// The client-facing writer rejects every Write call.
	sink := &failingResponseWriter{}
	handler.ServeHTTP(sink, httptest.NewRequest(http.MethodGet, origin.URL+"/copy-fail", nil))

	seen := drainEvents(t, eventCh, 2, time.Second)
	switch {
	case !hasEventType(seen, model.EventTypeRequestStarted):
		t.Fatalf("expected %s event", model.EventTypeRequestStarted)
	case !hasEventType(seen, model.EventTypeRequestFailed):
		t.Fatalf("expected %s event", model.EventTypeRequestFailed)
	}
}
// TestHandleConnect_EarlyFailures runs handleConnect against a table of
// writer/request/CA combinations that are each expected to fail, and checks
// that every case emits a request-started event followed by a request-failed
// event carrying the expected method, status, flags and message. For writers
// that record a status code, the recorded code is checked as well.
func TestHandleConnect_EarlyFailures(t *testing.T) {
	t.Parallel()

	// Shared transport stub configured with a function that returns
	// context.Canceled.
	transport := &mockTransport{fn: func(req *http.Request) (*http.Response, error) {
		return nil, context.Canceled
	}}

	tests := []struct {
		name     string
		writer   http.ResponseWriter
		req      *http.Request
		ca       *CertificateAuthority
		wantCode int    // status code recorded by the writer
		wantBody string // substring expected in the failed event's body
		wantStat int    // status carried on the failed event's request
	}{
		{
			name:     "nil CA returns 502 and failed event",
			writer:   httptest.NewRecorder(),
			req:      httptest.NewRequest(http.MethodConnect, "http://example.com:443", nil),
			ca:       nil,
			wantCode: http.StatusBadGateway,
			wantBody: "certificate authority is unavailable",
			wantStat: http.StatusBadGateway,
		},
		{
			name:     "nil CA with host without port still fails predictably",
			writer:   httptest.NewRecorder(),
			req:      &http.Request{Method: http.MethodConnect, Host: "example.com"},
			ca:       nil,
			wantCode: http.StatusBadGateway,
			wantBody: "certificate authority is unavailable",
			wantStat: http.StatusBadGateway,
		},
		{
			name:     "cert mint failure returns 502 and failed event",
			writer:   httptest.NewRecorder(),
			req:      &http.Request{Method: http.MethodConnect, Host: ""},
			ca:       &CertificateAuthority{},
			wantCode: http.StatusBadGateway,
			wantBody: "failed to mint interception certificate",
			wantStat: http.StatusBadGateway,
		},
		{
			name:     "non-hijacker writer returns 500 and failed event",
			writer:   httptest.NewRecorder(),
			req:      httptest.NewRequest(http.MethodConnect, "http://example.com:443", nil),
			ca:       newTestCA(t),
			wantCode: http.StatusInternalServerError,
			wantBody: "hijack is unavailable",
			wantStat: http.StatusInternalServerError,
		},
		{
			name:     "hijack failure returns 500 and failed event",
			writer:   &hijackFailWriter{},
			req:      httptest.NewRequest(http.MethodConnect, "http://example.com:443", nil),
			ca:       newTestCA(t),
			wantCode: http.StatusInternalServerError,
			wantBody: "CONNECT hijack failed",
			wantStat: http.StatusInternalServerError,
		},
		{
			name:     "write connect established failure emits failed event",
			writer:   &writeConnectFailWriter{},
			req:      httptest.NewRequest(http.MethodConnect, "http://example.com:443", nil),
			ca:       newTestCA(t),
			wantCode: 0,
			wantBody: "CONNECT setup failed",
			wantStat: http.StatusBadGateway,
		},
	}

	for _, tt := range tests {
		tt := tt // per-iteration copy so parallel subtests on pre-1.22 toolchains do not share it
		t.Run(tt.name, func(t *testing.T) {
			t.Parallel()

			ch := make(chan model.Event, 8)
			ps := &model.ProxyServer{Conns: make(map[net.Conn]struct{})}
			handleConnect(tt.writer, tt.req, ch, transport, tt.ca, ps)

			events := drainEvents(t, ch, 2, time.Second)
			// Fail cleanly instead of panicking on the index expressions
			// below if fewer events than requested come back.
			if len(events) < 2 {
				t.Fatalf("got %d events, want at least 2: %#v", len(events), events)
			}
			started, failed := events[0], events[1]

			if started.Type != model.EventTypeRequestStarted {
				t.Fatalf("expected %s event, got %v", model.EventTypeRequestStarted, started.Type)
			}
			if failed.Type != model.EventTypeRequestFailed {
				t.Fatalf("expected %s event, got %v", model.EventTypeRequestFailed, failed.Type)
			}
			if failed.Request.Method != http.MethodConnect {
				t.Fatalf("failed event request method = %q, want CONNECT", failed.Request.Method)
			}
			if failed.Request.Status != tt.wantStat {
				t.Fatalf("failed event request status = %d, want %d", failed.Request.Status, tt.wantStat)
			}
			if !failed.Request.Failed || failed.Request.Pending {
				t.Fatalf("failed event request flags = pending:%v failed:%v, want pending:false failed:true", failed.Request.Pending, failed.Request.Failed)
			}
			if !strings.Contains(failed.Body, tt.wantBody) {
				t.Fatalf("failed event body %q does not contain %q", failed.Body, tt.wantBody)
			}

			// Only the recorder and hijackFailWriter have their recorded
			// status asserted; other writer types are checked through the
			// failed event alone.
			switch w := tt.writer.(type) {
			case *httptest.ResponseRecorder:
				if w.Code != tt.wantCode {
					t.Fatalf("status = %d, want %d", w.Code, tt.wantCode)
				}
			case *hijackFailWriter:
				if w.code != tt.wantCode {
					t.Fatalf("status = %d, want %d", w.code, tt.wantCode)
				}
			}
		})
	}
}
// TODO: Add full TLS MITM loop integration test.