termtap/internal/proxy/requests_test.go
Hayden Hargreaves 002773e77f test: AI generated all of these tests
Just for the MVP of course. Need to validate the idea.
2026-04-23 19:47:04 -07:00

276 lines
7.9 KiB
Go

package proxy
import (
"errors"
"io"
"net/http"
"net/url"
"strings"
"testing"
"time"
"github.com/google/uuid"
"termtap.dev/internal/model"
)
type mockTransport struct {
fn func(*http.Request) (*http.Response, error)
}
func (m *mockTransport) RoundTrip(req *http.Request) (*http.Response, error) {
return m.fn(req)
}
func TestRoundTripCapturedRequest_Success(t *testing.T) {
t.Parallel()
ch := make(chan model.Event, 8)
reqURL, err := url.Parse("http://example.com/path?q=1")
if err != nil {
t.Fatalf("url.Parse() error = %v", err)
}
req := &http.Request{
Method: http.MethodPost,
URL: reqURL,
Host: "example.com",
Header: http.Header{
"Connection": {"X-Hop"},
"X-Hop": {"drop"},
"Authorization": {"Bearer token"},
"Content-Type": {"text/plain"},
},
Body: io.NopCloser(strings.NewReader("req\nbody")),
}
transport := &mockTransport{fn: func(outReq *http.Request) (*http.Response, error) {
if got := outReq.Header.Get("Connection"); got != "" {
t.Fatalf("Connection header should be stripped, got %q", got)
}
if got := outReq.Header.Get("X-Hop"); got != "" {
t.Fatalf("header listed in Connection should be stripped, got %q", got)
}
data, readErr := io.ReadAll(outReq.Body)
if readErr != nil {
t.Fatalf("ReadAll(outReq.Body) error = %v", readErr)
}
if got, want := string(data), "req\nbody"; got != want {
t.Fatalf("request body = %q, want %q", got, want)
}
return &http.Response{
StatusCode: http.StatusCreated,
Header: http.Header{
"Set-Cookie": {"session=top-secret"},
"Connection": {"close"},
"Content-Type": {"text/plain"},
},
Body: io.NopCloser(strings.NewReader("resp\nbody")),
}, nil
}}
resp, captured, responsePreview, err := roundTripCapturedRequest(req, transport, ch, "", false)
if err != nil {
t.Fatalf("roundTripCapturedRequest() error = %v", err)
}
if resp == nil {
t.Fatal("roundTripCapturedRequest() returned nil response")
}
if responsePreview == nil {
t.Fatal("roundTripCapturedRequest() returned nil response preview")
}
if _, readErr := io.ReadAll(resp.Body); readErr != nil {
t.Fatalf("ReadAll(resp.Body) error = %v", readErr)
}
_ = resp.Body.Close()
if got, want := string(captured.RequestData), `req\nbody`; got != want {
t.Fatalf("captured.RequestData = %q, want %q", got, want)
}
if got, want := string(responsePreview.Preview()), `resp\nbody`; got != want {
t.Fatalf("responsePreview = %q, want %q", got, want)
}
if got := captured.RequestHeaders.Get("Authorization"); got != "[REDACTED]" {
t.Fatalf("Authorization header = %q, want [REDACTED]", got)
}
if got := captured.ResponseHeaders.Get("Set-Cookie"); got != "[REDACTED]" {
t.Fatalf("Set-Cookie header = %q, want [REDACTED]", got)
}
if got := captured.ResponseHeaders.Get("Connection"); got != "" {
t.Fatalf("Connection should be stripped from response headers, got %q", got)
}
events := drainEvents(t, ch, 1, time.Second)
if events[0].Type != model.EventTypeRequestStarted {
t.Fatalf("event type = %s, want %s", events[0].Type, model.EventTypeRequestStarted)
}
}
func TestRoundTripCapturedRequest_ErrorPath(t *testing.T) {
t.Parallel()
ch := make(chan model.Event, 8)
reqURL, err := url.Parse("http://example.com/fail")
if err != nil {
t.Fatalf("url.Parse() error = %v", err)
}
req := &http.Request{
Method: http.MethodPost,
URL: reqURL,
Host: "example.com",
Header: http.Header{"Content-Type": {"text/plain"}},
Body: io.NopCloser(strings.NewReader("boom\nbody")),
}
wantErr := errors.New("upstream failed")
transport := &mockTransport{fn: func(outReq *http.Request) (*http.Response, error) {
_, _ = io.ReadAll(outReq.Body)
return nil, wantErr
}}
resp, captured, responsePreview, gotErr := roundTripCapturedRequest(req, transport, ch, "", false)
if !errors.Is(gotErr, wantErr) {
t.Fatalf("error = %v, want %v", gotErr, wantErr)
}
if resp != nil {
t.Fatalf("response = %#v, want nil", resp)
}
if responsePreview != nil {
t.Fatalf("responsePreview = %#v, want nil", responsePreview)
}
if got, want := string(captured.RequestData), `boom\nbody`; got != want {
t.Fatalf("captured.RequestData = %q, want %q", got, want)
}
events := drainEvents(t, ch, 1, time.Second)
if events[0].Type != model.EventTypeRequestStarted {
t.Fatalf("event type = %s, want %s", events[0].Type, model.EventTypeRequestStarted)
}
}
func TestRoundTripCapturedRequest_InterceptedTLSDefaults(t *testing.T) {
t.Parallel()
ch := make(chan model.Event, 8)
u, err := url.Parse("/secure?p=1")
if err != nil {
t.Fatalf("url.Parse() error = %v", err)
}
req := &http.Request{
Method: http.MethodGet,
URL: u,
Header: http.Header{},
}
const defaultHost = "api.example.com:443"
transport := &mockTransport{fn: func(outReq *http.Request) (*http.Response, error) {
if got := outReq.URL.Scheme; got != "https" {
t.Fatalf("URL.Scheme = %q, want https", got)
}
if got := outReq.URL.Host; got != defaultHost {
t.Fatalf("URL.Host = %q, want %q", got, defaultHost)
}
if got := outReq.Host; got != defaultHost {
t.Fatalf("Host = %q, want %q", got, defaultHost)
}
return &http.Response{
StatusCode: http.StatusNoContent,
Header: http.Header{"Content-Type": {"text/plain"}},
Body: io.NopCloser(strings.NewReader("")),
}, nil
}}
_, _, _, gotErr := roundTripCapturedRequest(req, transport, ch, defaultHost, true)
if gotErr != nil {
t.Fatalf("roundTripCapturedRequest() error = %v", gotErr)
}
}
func TestNewConnectRequest(t *testing.T) {
t.Parallel()
now := time.Now().Add(-time.Second)
req := &http.Request{Method: http.MethodConnect, Host: "example.com:443"}
got := newConnectRequest(req, now)
if got.Method != http.MethodConnect {
t.Fatalf("Method = %q, want CONNECT", got.Method)
}
if got.Host != req.Host || got.URL != req.Host || got.RawURL != req.Host {
t.Fatalf("connect request host/url/raw mismatch: %#v", got)
}
if !got.Pending || got.Failed {
t.Fatalf("Pending/Failed = (%v,%v), want (true,false)", got.Pending, got.Failed)
}
if got.Status != -1 {
t.Fatalf("Status = %d, want -1", got.Status)
}
if got.StartTime != now {
t.Fatalf("StartTime = %v, want %v", got.StartTime, now)
}
if got.ID == uuid.Nil {
t.Fatal("ID must be non-zero UUID")
}
}
func TestStartFinishFailRequestEvents(t *testing.T) {
t.Parallel()
ch := make(chan model.Event, 4)
req := model.Request{
ID: uuid.New(),
Method: http.MethodGet,
RawURL: "http://example.com/a",
StartTime: time.Now().Add(-3 * time.Millisecond),
Pending: true,
}
startRequest(ch, req)
events := drainEvents(t, ch, 1, time.Second)
startEv := events[0]
if startEv.Type != model.EventTypeRequestStarted {
t.Fatalf("start event type = %s, want %s", startEv.Type, model.EventTypeRequestStarted)
}
if startEv.Request.Pending != true {
t.Fatalf("start request pending = %v, want true", startEv.Request.Pending)
}
finishRequest(ch, req, http.StatusOK)
events = drainEvents(t, ch, 1, time.Second)
finishEv := events[0]
if finishEv.Type != model.EventTypeRequestFinished {
t.Fatalf("finish event type = %s, want %s", finishEv.Type, model.EventTypeRequestFinished)
}
if finishEv.Request.Pending {
t.Fatal("finished request should not be pending")
}
if finishEv.Request.Failed {
t.Fatal("finished request should not be failed")
}
if finishEv.Request.Status != http.StatusOK {
t.Fatalf("finished status = %d, want %d", finishEv.Request.Status, http.StatusOK)
}
failRequest(ch, req, http.StatusBadGateway, "upstream error")
events = drainEvents(t, ch, 1, time.Second)
failEv := events[0]
if failEv.Type != model.EventTypeRequestFailed {
t.Fatalf("fail event type = %s, want %s", failEv.Type, model.EventTypeRequestFailed)
}
if failEv.Request.Pending {
t.Fatal("failed request should not be pending")
}
if !failEv.Request.Failed {
t.Fatal("failed request should be marked failed")
}
if failEv.Request.Status != http.StatusBadGateway {
t.Fatalf("failed status = %d, want %d", failEv.Request.Status, http.StatusBadGateway)
}
}