From 002773e77faff160d03972b494d934d03a06d85d Mon Sep 17 00:00:00 2001 From: Hayden Hargreaves Date: Thu, 23 Apr 2026 19:47:04 -0700 Subject: [PATCH] test: AI generated all of these tests Just for the MVP of course. Need to validate the idea. --- TEST_COVERAGE_SUMMARY.md | 45 +++ cmd/tap/main_test.go | 49 +++ internal/app/integration_session_test.go | 109 ++++++ internal/app/process.go | 15 +- internal/app/process_test.go | 223 +++++++++++ internal/app/proxy_test.go | 183 +++++++++ internal/app/session_test.go | 247 ++++++++++++ internal/app/test_helpers_test.go | 111 ++++++ internal/cli/run.go | 73 ++-- internal/cli/run_test.go | 356 +++++++++++++++++ internal/process/runner_test.go | 219 +++++++++++ internal/process/signal_unix_test.go | 147 +++++++ internal/proxy/certs_error_test.go | 242 ++++++++++++ internal/proxy/certs_lifecycle_test.go | 45 +++ internal/proxy/certs_test.go | 266 +++++++++++++ internal/proxy/certs_trust_test.go | 29 ++ internal/proxy/handlers_test.go | 330 ++++++++++++++++ internal/proxy/headers_test.go | 157 ++++++++ internal/proxy/integration_http_test.go | 129 +++++++ internal/proxy/integration_https_mitm_test.go | 301 +++++++++++++++ internal/proxy/preview_test.go | 142 +++++++ internal/proxy/requests_test.go | 275 +++++++++++++ internal/proxy/secure_utils_test.go | 138 +++++++ internal/proxy/server_test.go | 235 ++++++++++++ internal/proxy/test_helpers_test.go | 77 ++++ internal/proxy/utils_test.go | 256 +++++++++++++ internal/tui/model_update_test.go | 361 ++++++++++++++++++ internal/tui/view_split_panes_style_test.go | 245 ++++++++++++ 28 files changed, 4976 insertions(+), 29 deletions(-) create mode 100644 TEST_COVERAGE_SUMMARY.md create mode 100644 cmd/tap/main_test.go create mode 100644 internal/app/integration_session_test.go create mode 100644 internal/app/process_test.go create mode 100644 internal/app/proxy_test.go create mode 100644 internal/app/session_test.go create mode 100644 internal/app/test_helpers_test.go create mode 100644 internal/cli/run_test.go create mode 100644 internal/process/runner_test.go create mode 100644 internal/process/signal_unix_test.go create mode 100644 internal/proxy/certs_error_test.go create mode 100644 internal/proxy/certs_lifecycle_test.go create mode 100644 internal/proxy/certs_test.go create mode 100644 internal/proxy/certs_trust_test.go create mode 100644 internal/proxy/handlers_test.go create mode 100644 internal/proxy/headers_test.go create mode 100644 internal/proxy/integration_http_test.go create mode 100644 internal/proxy/integration_https_mitm_test.go create mode 100644 internal/proxy/preview_test.go create mode 100644 internal/proxy/requests_test.go create mode 100644 internal/proxy/secure_utils_test.go create mode 100644 internal/proxy/server_test.go create mode 100644 internal/proxy/test_helpers_test.go create mode 100644 internal/proxy/utils_test.go create mode 100644 internal/tui/model_update_test.go create mode 100644 internal/tui/view_split_panes_style_test.go diff --git a/TEST_COVERAGE_SUMMARY.md b/TEST_COVERAGE_SUMMARY.md new file mode 100644 index 0000000..5005455 --- /dev/null +++ b/TEST_COVERAGE_SUMMARY.md @@ -0,0 +1,45 @@ +# termtap Test Coverage Summary + +Generated from: + +```bash +go test -coverprofile=/tmp/termtap.cover ./... +go tool cover -func=/tmp/termtap.cover +``` + +## Package coverage snapshot + +| Package | Coverage | +|---|---:| +| `cmd/tap` | 100.0% | +| `internal/app` | 98.1% | +| `internal/cli` | 93.4% | +| `internal/process` | 95.8% | +| `internal/proxy` | 90.0% | +| `internal/tui` | 96.2% | +| `examples/echo` | 0.0% (example app; intentionally not covered) | +| `internal/model` | no test files (pure data structs) | + +Total statements in module: **77.9%**. + +## Notable lower-coverage targets (production code) + +- `internal/proxy/handlers.go:handleConnect` — 75.9% +- `internal/proxy/certs.go:writeFileAtomically` — 76.2% +- `internal/proxy/certs.go:load` — 90.5% +- `internal/proxy/certs.go:create` — 80.8% +- `internal/proxy/certs.go:IsTrustedBySystem` — 76.9% +- `internal/cli/run.go:runCert` — 87.5% + +## Interpretation + +- Core runtime paths (`internal/app`, `internal/process`, `internal/tui`) are high confidence. +- Proxy package has broad behavior coverage including HTTP and HTTPS MITM integration flow, and now clears 90% package coverage. +- CLI command routing and fatal/stdout/stderr seams are covered, including `Run` success/error branches. +- TUI pane rendering coverage now includes error/PID branch behavior. + +## Next optional improvements + +1. Add deeper CONNECT tunnel write/flush/read failure permutations inside `handleConnect` loop. +2. Add additional deterministic edge cases for `IsTrustedBySystem` non-unknown-authority verify failures. +3. Optionally add non-unix signal file coverage in CI matrix (currently unix-focused). diff --git a/cmd/tap/main_test.go b/cmd/tap/main_test.go new file mode 100644 index 0000000..29b7e75 --- /dev/null +++ b/cmd/tap/main_test.go @@ -0,0 +1,49 @@ +package main + +import ( + "bytes" + "io" + "os" + "strings" + "testing" + "time" +) + +func TestMain_SmokeInvokesCLIRun(t *testing.T) { + origArgs := os.Args + t.Cleanup(func() { os.Args = origArgs }) + os.Args = []string{"tap", "invalid"} + + origStderr := os.Stderr + r, w, err := os.Pipe() + if err != nil { + t.Fatalf("stderr pipe error: %v", err) + } + t.Cleanup(func() { + os.Stderr = origStderr + _ = r.Close() + }) + os.Stderr = w + + outCh := make(chan string, 1) + go func() { + var buf bytes.Buffer + _, _ = io.Copy(&buf, r) + outCh <- buf.String() + }() + + main() + + _ = w.Close() + os.Stderr = origStderr + var got string + select { + case got = <-outCh: + case <-time.After(2 * time.Second): + t.Fatal("timeout waiting for stderr capture") + } + + if !strings.Contains(got, "usage:") { + t.Fatalf("stderr missing usage output, got: %q", got) + } +} diff --git a/internal/app/integration_session_test.go b/internal/app/integration_session_test.go new file mode 100644 index 0000000..7be5baf --- /dev/null +++ b/internal/app/integration_session_test.go @@ -0,0 +1,109 @@ +package app + +import ( + "net/http" + "net/http/httptest" + "net/url" + "testing" + "time" + + "termtap.dev/internal/model" +) + +// NOTE: Run with -race; this validates cross-component concurrency. + +func TestSessionIntegration_LifecycleAndRequestEvents(t *testing.T) { + addr := freeTCPAddr(t) + s, err := StartSession(model.Command{Name: "sh", Args: []string{"-c", "sleep 5"}}, addr) + if err != nil { + t.Fatalf("StartSession() error = %v", err) + } + + upstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + _, _ = w.Write([]byte("ok")) + })) + t.Cleanup(upstream.Close) + + startupEvents := collectUntilTypes(t, s.Events, []model.EventType{ + model.EventTypeProxyStarting, + model.EventTypeProcessStarting, + model.EventTypeProcessStarted, + }, 3*time.Second) + + proxyURL, err := url.Parse(s.proxy.Url) + if err != nil { + t.Fatalf("url.Parse(proxy) error = %v", err) + } + client := &http.Client{ + Transport: &http.Transport{Proxy: http.ProxyURL(proxyURL)}, + Timeout: 3 * time.Second, + } + + resp, err := client.Get(upstream.URL + "/session") + if err != nil { + t.Fatalf("proxy request error = %v", err) + } + _ = resp.Body.Close() + if resp.StatusCode != http.StatusOK { + t.Fatalf("status = %d, want %d", resp.StatusCode, http.StatusOK) + } + + requestEvents := collectUntilTypes(t, s.Events, []model.EventType{ + model.EventTypeRequestStarted, + model.EventTypeRequestFinished, + }, 3*time.Second) + + s.Stop() + select { + case <-s.proc.Done: + case <-time.After(4 * time.Second): + t.Fatal("timeout waiting for process stop") + } + + shutdownEvents := collectUntilTypes(t, s.Events, []model.EventType{ + model.EventTypeProxyStopped, + model.EventTypeProcessSignaled, + model.EventTypeProcessExited, + }, 4*time.Second) + + if !isBefore(startupEvents, model.EventTypeProcessStarting, model.EventTypeProcessStarted) { + t.Fatalf("expected %s before %s in startup events: %#v", model.EventTypeProcessStarting, model.EventTypeProcessStarted, startupEvents) + } + if !isBefore(requestEvents, model.EventTypeRequestStarted, model.EventTypeRequestFinished) { + t.Fatalf("expected %s before %s in request events: %#v", model.EventTypeRequestStarted, model.EventTypeRequestFinished, requestEvents) + } + if !isBefore(shutdownEvents, model.EventTypeProcessSignaled, model.EventTypeProcessExited) { + t.Fatalf("expected %s before %s in shutdown events: %#v", model.EventTypeProcessSignaled, model.EventTypeProcessExited, shutdownEvents) + } +} + +func TestSessionIntegration_RestartProcessEmitsLifecycleEvents(t *testing.T) { + addr := freeTCPAddr(t) + s, err := StartSession(model.Command{Name: "sh", Args: []string{"-c", "sleep 5"}}, addr) + if err != nil { + t.Fatalf("StartSession() error = %v", err) + } + t.Cleanup(func() { s.Stop() }) + + _ = collectUntilTypes(t, s.Events, []model.EventType{ + model.EventTypeProcessStarted, + model.EventTypeProxyStarting, + }, 3*time.Second) + + if err := s.RestartProcess(); err != nil { + t.Fatalf("RestartProcess() error = %v", err) + } + + events := collectUntilTypes(t, s.Events, []model.EventType{ + model.EventTypeProcessRestarting, + model.EventTypeProcessSignaled, + model.EventTypeProcessExited, + model.EventTypeProcessStarting, + model.EventTypeProcessStarted, + }, 4*time.Second) + + if !isBefore(events, model.EventTypeProcessRestarting, model.EventTypeProcessStarted) { + t.Fatalf("expected restarting before process started, got %#v", events) + } +} diff --git a/internal/app/process.go b/internal/app/process.go index 3040293..c9ceee5 100644 --- a/internal/app/process.go +++ b/internal/app/process.go @@ -4,6 +4,7 @@ import ( "errors" "fmt" "os/exec" + "sync" "syscall" "time" @@ -11,6 +12,10 @@ import ( "termtap.dev/internal/process" ) +var killEscalationDelay = 1500 * time.Millisecond +var scheduleKillEscalation = time.AfterFunc +var killEscalationMu sync.RWMutex + func StartProcess(cmd model.Command, addr string, ch chan<- model.Event) (*model.Process, error) { ch <- model.Event{ Time: time.Now().Local(), @@ -44,12 +49,16 @@ func StopProcess(proc *model.Process, ch chan<- model.Event, sig syscall.Signal) _ = process.SignalProcess(proc.Exec, sig) - go func() { - time.Sleep(1500 * time.Millisecond) + killEscalationMu.RLock() + delay := killEscalationDelay + scheduler := scheduleKillEscalation + killEscalationMu.RUnlock() + + scheduler(delay, func() { if process.ProcessAlive(proc.Exec) { _ = process.SignalProcess(proc.Exec, syscall.SIGKILL) } - }() + }) } func waitForProcessExit(proc *model.Process, ch chan<- model.Event) { diff --git a/internal/app/process_test.go b/internal/app/process_test.go new file mode 100644 index 0000000..c58b646 --- /dev/null +++ b/internal/app/process_test.go @@ -0,0 +1,223 @@ +package app + +import ( + "os/exec" + "syscall" + "testing" + "time" + + "termtap.dev/internal/model" +) + +func TestStartProcess(t *testing.T) { + t.Parallel() + + t.Run("starts process and marks running", func(t *testing.T) { + t.Parallel() + + ch := make(chan model.Event, 32) + proc, err := StartProcess(model.Command{Name: "sh", Args: []string{"-c", "sleep 0.2"}}, "127.0.0.1:8080", ch) + if err != nil { + t.Fatalf("StartProcess() error = %v", err) + } + t.Cleanup(func() { + StopProcess(proc, ch, syscall.SIGTERM) + select { + case <-proc.Done: + case <-time.After(2 * time.Second): + t.Fatal("timeout waiting process done in cleanup") + } + }) + + if proc == nil || proc.Exec == nil { + t.Fatal("StartProcess() returned nil process/exec") + } + + events := drainEvents(t, ch, 2, time.Second) + if !hasType(events, model.EventTypeProcessStarting) { + t.Fatalf("missing %s event", model.EventTypeProcessStarting) + } + if !hasType(events, model.EventTypeProcessStarted) { + t.Fatalf("missing %s event", model.EventTypeProcessStarted) + } + }) + + t.Run("returns error when exec start fails", func(t *testing.T) { + t.Parallel() + + ch := make(chan model.Event, 8) + proc, err := StartProcess(model.Command{Name: "definitely-not-a-real-command"}, "127.0.0.1:8080", ch) + if err == nil { + if proc != nil && proc.Exec != nil && proc.Exec.Process != nil { + _ = proc.Exec.Process.Kill() + } + t.Fatal("StartProcess() error = nil, want non-nil") + } + + events := drainEvents(t, ch, 1, time.Second) + if !hasType(events, model.EventTypeProcessStarting) { + t.Fatalf("missing %s event", model.EventTypeProcessStarting) + } + }) +} + +func TestStopProcess(t *testing.T) { + t.Parallel() + + t.Run("nil guards", func(t *testing.T) { + t.Parallel() + StopProcess(nil, make(chan model.Event, 1), syscall.SIGTERM) + StopProcess(&model.Process{}, make(chan model.Event, 1), syscall.SIGTERM) + StopProcess(&model.Process{Exec: &exec.Cmd{}}, make(chan model.Event, 1), syscall.SIGTERM) + }) + + t.Run("emits signaled event", func(t *testing.T) { + t.Parallel() + + ch := make(chan model.Event, 32) + proc, err := StartProcess(model.Command{Name: "sh", Args: []string{"-c", "sleep 5"}}, "127.0.0.1:8080", ch) + if err != nil { + t.Fatalf("StartProcess() error = %v", err) + } + + StopProcess(proc, ch, syscall.SIGTERM) + if _, ok := waitForEventType(t, ch, model.EventTypeProcessSignaled, 2*time.Second); !ok { + t.Fatalf("did not receive %s event", model.EventTypeProcessSignaled) + } + + select { + case <-proc.Done: + case <-time.After(3 * time.Second): + t.Fatal("timeout waiting for process to exit after signal") + } + }) + + t.Run("schedules deterministic kill escalation hook", func(t *testing.T) { + t.Parallel() + + origDelay := killEscalationDelay + origScheduler := scheduleKillEscalation + defer func() { + killEscalationMu.Lock() + killEscalationDelay = origDelay + scheduleKillEscalation = origScheduler + killEscalationMu.Unlock() + }() + + ch := make(chan model.Event, 32) + proc, err := StartProcess(model.Command{Name: "sh", Args: []string{"-c", "sleep 5"}}, "127.0.0.1:8080", ch) + if err != nil { + t.Fatalf("StartProcess() error = %v", err) + } + + killEscalationMu.Lock() + killEscalationDelay = 25 * time.Millisecond + scheduled := make(chan time.Duration, 1) + scheduleKillEscalation = func(d time.Duration, fn func()) *time.Timer { + scheduled <- d + go fn() + return nil + } + killEscalationMu.Unlock() + + StopProcess(proc, ch, syscall.SIGTERM) + + select { + case d := <-scheduled: + if d != killEscalationDelay { + t.Fatalf("scheduled delay = %v, want %v", d, killEscalationDelay) + } + case <-time.After(time.Second): + t.Fatal("kill escalation was not scheduled") + } + + select { + case <-proc.Done: + case <-time.After(3 * time.Second): + t.Fatal("timeout waiting for process to exit after deterministic escalation") + } + }) +} + +func TestWaitForProcessExit(t *testing.T) { + t.Parallel() + + t.Run("nil guards are no-op", func(t *testing.T) { + t.Parallel() + waitForProcessExit(nil, make(chan model.Event, 1)) + waitForProcessExit(&model.Process{}, make(chan model.Event, 1)) + }) + + t.Run("normal exit emits process exited and closes done", func(t *testing.T) { + t.Parallel() + + cmd := exec.Command("sh", "-c", "exit 0") + if err := cmd.Start(); err != nil { + t.Fatalf("cmd.Start() error = %v", err) + } + + proc := &model.Process{Exec: cmd, Running: true, Done: make(chan struct{})} + ch := make(chan model.Event, 8) + + waitForProcessExit(proc, ch) + + if _, ok := waitForEventType(t, ch, model.EventTypeProcessExited, time.Second); !ok { + t.Fatalf("did not receive %s event", model.EventTypeProcessExited) + } + select { + case <-proc.Done: + case <-time.After(time.Second): + t.Fatal("Done channel was not closed") + } + }) + + t.Run("exit error path carries exit code", func(t *testing.T) { + t.Parallel() + + cmd := exec.Command("sh", "-c", "exit 7") + if err := cmd.Start(); err != nil { + t.Fatalf("cmd.Start() error = %v", err) + } + + proc := &model.Process{Exec: cmd, Running: true, Done: make(chan struct{})} + ch := make(chan model.Event, 8) + + waitForProcessExit(proc, ch) + + events := drainEvents(t, ch, 2, time.Second) + found := false + for _, ev := range events { + if ev.Type == model.EventTypeProcessExited && ev.ExitCode == 7 { + found = true + break + } + } + if !found { + t.Fatalf("expected ProcessExited with exit code 7, got %#v", events) + } + select { + case <-proc.Done: + case <-time.After(time.Second): + t.Fatal("Done channel was not closed") + } + }) + + t.Run("unexpected wait failure emits fatal", func(t *testing.T) { + t.Parallel() + + proc := &model.Process{Exec: &exec.Cmd{}, Done: make(chan struct{})} + ch := make(chan model.Event, 8) + + waitForProcessExit(proc, ch) + + events := drainEvents(t, ch, 1, time.Second) + if events[0].Type != model.EventTypeFatal { + t.Fatalf("event type = %s, want %s", events[0].Type, model.EventTypeFatal) + } + select { + case <-proc.Done: + case <-time.After(time.Second): + t.Fatal("Done channel was not closed") + } + }) +} diff --git a/internal/app/proxy_test.go b/internal/app/proxy_test.go new file mode 100644 index 0000000..78e6b59 --- /dev/null +++ b/internal/app/proxy_test.go @@ -0,0 +1,183 @@ +package app + +import ( + "context" + "errors" + "net" + "net/http" + "testing" + "time" + + "termtap.dev/internal/model" +) + +type staticErrListener struct { + err error +} + +func (l *staticErrListener) Accept() (net.Conn, error) { return nil, l.err } +func (l *staticErrListener) Close() error { return nil } +func (l *staticErrListener) Addr() net.Addr { + return &net.TCPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 8080} +} + +func TestStartProxy_NilGuards(t *testing.T) { + t.Parallel() + + StartProxy(nil, make(chan model.Event, 1)) + StartProxy(&model.ProxyServer{}, make(chan model.Event, 1)) + StartProxy(&model.ProxyServer{Server: &http.Server{}}, make(chan model.Event, 1)) +} + +func TestStartProxy_EmitsStartingAndWarnWhenUntrustedCA(t *testing.T) { + t.Parallel() + + listenErr := errors.New("accept failed") + var ln net.Listener = &staticErrListener{err: listenErr} + + ch := make(chan model.Event, 8) + ps := &model.ProxyServer{ + Listener: &ln, + Server: &http.Server{Handler: http.HandlerFunc(func(http.ResponseWriter, *http.Request) {})}, + CAReady: true, + CATrusted: false, + CACreated: true, + CACertPath: "/tmp/test-ca.pem", + } + + StartProxy(ps, ch) + + events := drainEvents(t, ch, 3, time.Second) + if !hasType(events, model.EventTypeProxyStarting) { + t.Fatalf("missing %s event", model.EventTypeProxyStarting) + } + if !hasType(events, model.EventTypeWarn) { + t.Fatalf("missing %s event", model.EventTypeWarn) + } + if !containsBody(events, "generated HTTPS interception CA") { + t.Fatalf("expected generated-CA warning body, got events: %#v", events) + } + if !hasType(events, model.EventTypeFatal) { + t.Fatalf("missing %s event", model.EventTypeFatal) + } +} + +func TestStartProxy_WarnBodyForExistingUntrustedCA(t *testing.T) { + t.Parallel() + + listenErr := errors.New("accept failed") + var ln net.Listener = &staticErrListener{err: listenErr} + + ch := make(chan model.Event, 8) + ps := &model.ProxyServer{ + Listener: &ln, + Server: &http.Server{Handler: http.HandlerFunc(func(http.ResponseWriter, *http.Request) {})}, + CAReady: true, + CATrusted: false, + CACreated: false, + CACertPath: "/tmp/test-ca.pem", + } + + StartProxy(ps, ch) + + events := drainEvents(t, ch, 3, time.Second) + if !containsBody(events, "HTTPS interception CA available at") { + t.Fatalf("expected existing-CA warning body, got events: %#v", events) + } +} + +func TestStartProxy_NoCAWarningWhenNotReady(t *testing.T) { + t.Parallel() + + listenErr := errors.New("accept failed") + var ln net.Listener = &staticErrListener{err: listenErr} + + ch := make(chan model.Event, 8) + ps := &model.ProxyServer{ + Listener: &ln, + Server: &http.Server{Handler: http.HandlerFunc(func(http.ResponseWriter, *http.Request) {})}, + CAReady: false, + CATrusted: false, + } + + StartProxy(ps, ch) + + events := drainEvents(t, ch, 2, time.Second) + if !hasType(events, model.EventTypeProxyStarting) { + t.Fatalf("missing %s event", model.EventTypeProxyStarting) + } + if hasType(events, model.EventTypeWarn) { + t.Fatalf("unexpected %s event when CA is not ready: %#v", model.EventTypeWarn, events) + } + if !hasType(events, model.EventTypeFatal) { + t.Fatalf("missing %s event", model.EventTypeFatal) + } +} + +func TestStartProxy_NoCAWarningWhenTrusted(t *testing.T) { + t.Parallel() + + listenErr := errors.New("accept failed") + var ln net.Listener = &staticErrListener{err: listenErr} + + ch := make(chan model.Event, 8) + ps := &model.ProxyServer{ + Listener: &ln, + Server: &http.Server{Handler: http.HandlerFunc(func(http.ResponseWriter, *http.Request) {})}, + CAReady: true, + CATrusted: true, + } + + StartProxy(ps, ch) + + events := drainEvents(t, ch, 2, time.Second) + if !hasType(events, model.EventTypeProxyStarting) { + t.Fatalf("missing %s event", model.EventTypeProxyStarting) + } + if hasType(events, model.EventTypeWarn) { + t.Fatalf("unexpected %s event when CA is already trusted: %#v", model.EventTypeWarn, events) + } + if !hasType(events, model.EventTypeFatal) { + t.Fatalf("missing %s event", model.EventTypeFatal) + } +} + +func TestStartProxy_SwallowsErrServerClosed(t *testing.T) { + t.Parallel() + + ln, err := net.Listen("tcp", "127.0.0.1:0") + if err != nil { + t.Fatalf("listen error = %v", err) + } + t.Cleanup(func() { _ = ln.Close() }) + + ch := make(chan model.Event, 8) + server := &http.Server{Handler: http.HandlerFunc(func(http.ResponseWriter, *http.Request) {})} + ps := &model.ProxyServer{Listener: &ln, Server: server} + + done := make(chan struct{}) + go func() { + StartProxy(ps, ch) + close(done) + }() + + shutdownCtx, cancel := context.WithTimeout(context.Background(), time.Second) + defer cancel() + if err := server.Shutdown(shutdownCtx); err != nil { + t.Fatalf("shutdown error = %v", err) + } + + select { + case <-done: + case <-time.After(time.Second): + t.Fatal("timeout waiting for StartProxy to return") + } + + events := drainEvents(t, ch, 1, time.Second) + if !hasType(events, model.EventTypeProxyStarting) { + t.Fatalf("missing %s event", model.EventTypeProxyStarting) + } + if hasType(events, model.EventTypeFatal) { + t.Fatalf("unexpected %s event", model.EventTypeFatal) + } +} diff --git a/internal/app/session_test.go b/internal/app/session_test.go new file mode 100644 index 0000000..1469ec8 --- /dev/null +++ b/internal/app/session_test.go @@ -0,0 +1,247 @@ +package app + +import ( + "errors" + "net" + "syscall" + "testing" + "time" + + "termtap.dev/internal/model" +) + +func TestStartSession(t *testing.T) { + t.Run("happy path creates proxy and process", func(t *testing.T) { + addr := freeTCPAddr(t) + s, err := StartSession(model.Command{Name: "sh", Args: []string{"-c", "sleep 5"}}, addr) + if err != nil { + t.Fatalf("StartSession() error = %v", err) + } + if s == nil { + t.Fatal("StartSession() returned nil session") + } + if s.proxy == nil { + t.Fatal("session.proxy is nil") + } + if s.proc == nil { + t.Fatal("session.proc is nil") + } + + s.Stop() + if s.proc != nil && s.proc.Done != nil { + select { + case <-s.proc.Done: + case <-time.After(4 * time.Second): + t.Fatal("timeout waiting for process stop") + } + } + }) + + t.Run("error when proxy creation fails", func(t *testing.T) { + t.Setenv("XDG_CONFIG_HOME", "") + t.Setenv("HOME", "") + + s, err := StartSession(model.Command{Name: "sh", Args: []string{"-c", "true"}}, "127.0.0.1:0") + if err == nil { + if s != nil { + s.Stop() + } + t.Fatal("StartSession() error = nil, want non-nil") + } + if s != nil { + t.Fatalf("session = %#v, want nil", s) + } + }) + + t.Run("destroys proxy when process startup fails", func(t *testing.T) { + addr := freeTCPAddr(t) + + s, err := StartSession(model.Command{Name: "definitely-not-a-real-command"}, addr) + if err == nil { + if s != nil { + s.Stop() + } + t.Fatal("StartSession() error = nil, want non-nil") + } + + deadline := time.After(3 * time.Second) + ticker := time.NewTicker(25 * time.Millisecond) + defer ticker.Stop() + for { + ln, listenErr := net.Listen("tcp", addr) + if listenErr == nil { + _ = ln.Close() + break + } + select { + case <-deadline: + t.Fatalf("address %s did not become reusable, got err: %v", addr, listenErr) + case <-ticker.C: + } + } + }) + + t.Run("stop during restart stops new process and returns ErrSessionStopped", func(t *testing.T) { + s := &Session{ + ch: make(chan model.Event, 64), + cmd: model.Command{Name: "sh", Args: []string{"-c", "sleep 5"}}, + addr: "127.0.0.1:0", + proc: &model.Process{Done: make(chan struct{})}, + } + close(s.proc.Done) + + s.restartMu.Lock() + errCh := make(chan error, 1) + stopDone := make(chan struct{}) + + go func() { + errCh <- s.RestartProcess() + }() + go func() { + s.Stop() + close(stopDone) + }() + + s.restartMu.Unlock() + + select { + case err := <-errCh: + if !errors.Is(err, ErrSessionStopped) { + t.Fatalf("error = %v, want %v", err, ErrSessionStopped) + } + case <-time.After(6 * time.Second): + t.Fatal("timeout waiting for RestartProcess result") + } + + select { + case <-stopDone: + case <-time.After(time.Second): + t.Fatal("timeout waiting for Stop completion") + } + }) +} + +func TestSessionStop_Idempotent(t *testing.T) { + addr := freeTCPAddr(t) + s, err := StartSession(model.Command{Name: "sh", Args: []string{"-c", "sleep 5"}}, addr) + if err != nil { + t.Fatalf("StartSession() error = %v", err) + } + + s.Stop() + s.Stop() + + if s.proc != nil && s.proc.Done != nil { + select { + case <-s.proc.Done: + case <-time.After(4 * time.Second): + t.Fatal("timeout waiting for process to stop") + } + } +} + +func TestRestartProcess(t *testing.T) { + t.Run("nil session returns error", func(t *testing.T) { + var s *Session + err := s.RestartProcess() + if err == nil { + t.Fatal("RestartProcess() error = nil, want non-nil") + } + }) + + t.Run("stopped session returns ErrSessionStopped", func(t *testing.T) { + s := &Session{stopped: true} + err := s.RestartProcess() + if !errors.Is(err, ErrSessionStopped) { + t.Fatalf("error = %v, want %v", err, ErrSessionStopped) + } + }) + + t.Run("concurrent restart returns ErrRestartInProgress", func(t *testing.T) { + s := &Session{restarting: true} + err := s.RestartProcess() + if !errors.Is(err, ErrRestartInProgress) { + t.Fatalf("error = %v, want %v", err, ErrRestartInProgress) + } + }) + + t.Run("timeout waiting for stop returns error", func(t *testing.T) { + s := &Session{ + ch: make(chan model.Event, 8), + cmd: model.Command{Name: "sh", Args: []string{"-c", "true"}}, + addr: "127.0.0.1:0", + proc: &model.Process{Done: make(chan struct{})}, + } + + err := s.RestartProcess() + if err == nil { + t.Fatal("RestartProcess() error = nil, want non-nil") + } + }) + + t.Run("successful restart swaps process", func(t *testing.T) { + addr := freeTCPAddr(t) + oldProc := &model.Process{Done: make(chan struct{})} + close(oldProc.Done) + + s := &Session{ + ch: make(chan model.Event, 32), + cmd: model.Command{Name: "sh", Args: []string{"-c", "sleep 5"}}, + addr: addr, + proc: oldProc, + } + + err := s.RestartProcess() + if err != nil { + t.Fatalf("RestartProcess() error = %v", err) + } + if _, ok := waitForEventType(t, s.ch, model.EventTypeProcessRestarting, time.Second); !ok { + t.Fatalf("did not receive %s event", model.EventTypeProcessRestarting) + } + if s.proc == nil { + t.Fatal("session process is nil after restart") + } + if s.proc == oldProc { + t.Fatal("session process pointer was not swapped") + } + + StopProcess(s.proc, s.ch, syscall.SIGTERM) + select { + case <-s.proc.Done: + case <-time.After(4 * time.Second): + t.Fatal("timeout waiting for restarted process to stop") + } + }) +} + +func TestWaitForProcessStop(t *testing.T) { + t.Parallel() + + t.Run("nil process and nil done are true", func(t *testing.T) { + t.Parallel() + if !waitForProcessStop(nil, time.Millisecond) { + t.Fatal("waitForProcessStop(nil) = false, want true") + } + if !waitForProcessStop(&model.Process{}, time.Millisecond) { + t.Fatal("waitForProcessStop(process with nil done) = false, want true") + } + }) + + t.Run("done channel closes before timeout", func(t *testing.T) { + t.Parallel() + proc := &model.Process{Done: make(chan struct{})} + close(proc.Done) + + if !waitForProcessStop(proc, time.Second) { + t.Fatal("waitForProcessStop() = false, want true") + } + }) + + t.Run("timeout returns false", func(t *testing.T) { + t.Parallel() + proc := &model.Process{Done: make(chan struct{})} + if waitForProcessStop(proc, 20*time.Millisecond) { + t.Fatal("waitForProcessStop() = true, want false") + } + }) +} diff --git a/internal/app/test_helpers_test.go b/internal/app/test_helpers_test.go new file mode 100644 index 0000000..5008901 --- /dev/null +++ b/internal/app/test_helpers_test.go @@ -0,0 +1,111 @@ +package app + +import ( + "net" + "strings" + "testing" + "time" + + "termtap.dev/internal/model" +) + +func drainEvents(t *testing.T, ch <-chan model.Event, n int, timeout time.Duration) []model.Event { + t.Helper() + + out := make([]model.Event, 0, n) + deadline := time.After(timeout) + for len(out) < n { + select { + case ev := <-ch: + out = append(out, ev) + case <-deadline: + t.Fatalf("timeout waiting for %d events, got %d", n, len(out)) + } + } + + return out +} + +func hasType(events []model.Event, typ model.EventType) bool { + for _, ev := range events { + if ev.Type == typ { + return true + } + } + return false +} + +func containsBody(events []model.Event, part string) bool { + for _, ev := range events { + if strings.Contains(ev.Body, part) { + return true + } + } + return false +} + +func waitForEventType(t *testing.T, ch <-chan model.Event, typ model.EventType, timeout time.Duration) (model.Event, bool) { + t.Helper() + + deadline := time.After(timeout) + for { + select { + case ev := <-ch: + if ev.Type == typ { + return ev, true + } + case <-deadline: + return model.Event{}, false + } + } +} + +func collectUntilTypes(t *testing.T, ch <-chan model.Event, required []model.EventType, timeout time.Duration) []model.Event { + t.Helper() + + need := make(map[model.EventType]bool, len(required)) + for _, typ := range required { + need[typ] = true + } + + events := make([]model.Event, 0, len(required)+8) + deadline := time.After(timeout) + for len(need) > 0 { + select { + case ev := <-ch: + events = append(events, ev) + delete(need, ev.Type) + case <-deadline: + t.Fatalf("timeout waiting for required events: remaining=%v, collected=%#v", need, events) + } + } + + return events +} + +func isBefore(events []model.Event, first, second model.EventType) bool { + firstIdx := -1 + secondIdx := -1 + for i, ev := range events { + if ev.Type == first && firstIdx == -1 { + firstIdx = i + } + if ev.Type == second && secondIdx == -1 { + secondIdx = i + } + } + + return firstIdx >= 0 && secondIdx >= 0 && firstIdx < secondIdx +} + +func freeTCPAddr(t *testing.T) string { + t.Helper() + + ln, err := net.Listen("tcp", "127.0.0.1:0") + if err != nil { + t.Fatalf("listen error = %v", err) + } + addr := ln.Addr().String() + _ = ln.Close() + return addr +} diff --git a/internal/cli/run.go b/internal/cli/run.go index 0f6fcbd..8b94294 100644 --- a/internal/cli/run.go +++ b/internal/cli/run.go @@ -2,6 +2,7 @@ package cli import ( "fmt" + "io" "log" "os" "runtime" @@ -16,6 +17,23 @@ import ( // This should be configurable at some point, just in case they build on 8080 const proxy_addr = "127.0.0.1:8080" +var fatalExit = log.Fatalln +var stdoutWriter io.Writer = stdioRef{isErr: false} +var stderrWriter io.Writer = stdioRef{isErr: true} +var startSessionFn = app.StartSession +var runTUIFn = tui.Run + +type stdioRef struct { + isErr bool +} + +func (w stdioRef) Write(p []byte) (int, error) { + if w.isErr { + return os.Stderr.Write(p) + } + return os.Stdout.Write(p) +} + func Run(args []string) { if len(args) >= 2 && args[1] == "cert" { runCert() @@ -28,9 +46,10 @@ func Run(args []string) { return } - session, err := app.StartSession(cmd, proxy_addr) + session, err := startSessionFn(cmd, proxy_addr) if err != nil { - log.Fatalln(err) + fatalExit(err) + return } defer session.Stop() @@ -38,8 +57,9 @@ func Run(args []string) { Restart: session.RestartProcess, } - if err := tui.Run(session.Events, controls); err != nil { - log.Fatalln(err) + if err := runTUIFn(session.Events, controls); err != nil { + fatalExit(err) + return } } @@ -67,49 +87,50 @@ usage: tap run -- [args...] ` - fmt.Fprintln(os.Stderr, helpText) + fmt.Fprintln(stderrWriter, helpText) } func runCert() { ca, err := proxy.EnsureCertificateAuthority() if err != nil { - log.Fatalln(err) + fatalExit(err) + return } certPath := ca.CertPath() quotedCertPath := strconv.Quote(certPath) - fmt.Printf("Certificate path: %s\n", certPath) + fmt.Fprintf(stdoutWriter, "Certificate path: %s\n", certPath) if ca.WasCreated() { - fmt.Println("Created a new local HTTPS interception CA.") + fmt.Fprintln(stdoutWriter, "Created a new local HTTPS interception CA.") } else { - fmt.Println("Using existing local HTTPS interception CA.") + fmt.Fprintln(stdoutWriter, "Using existing local HTTPS interception CA.") } trusted, err := ca.IsTrustedBySystem() if err != nil { - fmt.Printf("System trust check failed: %v\n", err) + fmt.Fprintf(stdoutWriter, "System trust check failed: %v\n", err) } else if trusted { - fmt.Println("System trust store: trusted") + fmt.Fprintln(stdoutWriter, "System trust store: trusted") } else { - fmt.Println("System trust store: not trusted") + fmt.Fprintln(stdoutWriter, "System trust store: not trusted") } if runtime.GOOS != "linux" { - fmt.Println("Install this certificate into your OS or client trust store to inspect HTTPS traffic.") + fmt.Fprintln(stdoutWriter, "Install this certificate into your OS or client trust store to inspect HTTPS traffic.") return } - fmt.Println() - fmt.Println("Trust instructions (Linux):") - fmt.Println("Debian/Ubuntu:") - fmt.Printf(" sudo cp %s /usr/local/share/ca-certificates/termtap.crt\n", quotedCertPath) - fmt.Println(" sudo update-ca-certificates") - fmt.Println("Fedora/RHEL/CentOS:") - fmt.Printf(" sudo cp %s /etc/pki/ca-trust/source/anchors/termtap.crt\n", quotedCertPath) - fmt.Println(" sudo update-ca-trust") - fmt.Println("Arch:") - fmt.Printf(" sudo trust anchor %s\n", quotedCertPath) - fmt.Println() - fmt.Println("Quick curl test:") - fmt.Printf(" curl --proxy http://%s --cacert %s https://example.com\n", proxy_addr, quotedCertPath) + fmt.Fprintln(stdoutWriter) + fmt.Fprintln(stdoutWriter, "Trust instructions (Linux):") + fmt.Fprintln(stdoutWriter, "Debian/Ubuntu:") + fmt.Fprintf(stdoutWriter, " sudo cp %s /usr/local/share/ca-certificates/termtap.crt\n", quotedCertPath) + fmt.Fprintln(stdoutWriter, " sudo update-ca-certificates") + fmt.Fprintln(stdoutWriter, "Fedora/RHEL/CentOS:") + fmt.Fprintf(stdoutWriter, " sudo cp %s /etc/pki/ca-trust/source/anchors/termtap.crt\n", quotedCertPath) + fmt.Fprintln(stdoutWriter, " sudo update-ca-trust") + fmt.Fprintln(stdoutWriter, "Arch:") + fmt.Fprintf(stdoutWriter, " sudo trust anchor %s\n", quotedCertPath) + fmt.Fprintln(stdoutWriter) + fmt.Fprintln(stdoutWriter, "Quick curl test:") + fmt.Fprintf(stdoutWriter, " curl --proxy http://%s --cacert %s https://example.com\n", proxy_addr, quotedCertPath) } diff --git a/internal/cli/run_test.go b/internal/cli/run_test.go new file mode 100644 index 0000000..7e49745 --- /dev/null +++ b/internal/cli/run_test.go @@ -0,0 +1,356 @@ +package cli + +import ( + "bytes" + "errors" + "io" + "os" + "runtime" + "strings" + "testing" + "time" + + "termtap.dev/internal/app" + "termtap.dev/internal/model" + "termtap.dev/internal/tui" +) + +func TestParseCommand(t *testing.T) { + tests := []struct { + name string + args []string + ok bool + nameWant string + argsWant []string + }{ + {name: "too few args", args: []string{"tap"}, ok: false}, + {name: "missing run token", args: []string{"tap", "oops", "--", "echo"}, ok: false}, + {name: "missing separator", args: []string{"tap", "run", "echo"}, ok: false}, + {name: "single command", args: []string{"tap", "run", "--", "echo"}, ok: true, nameWant: "echo", argsWant: []string{}}, + {name: "command with args", args: []string{"tap", "run", "--", "curl", "-s", "https://example.com"}, ok: true, nameWant: "curl", argsWant: []string{"-s", "https://example.com"}}, + } + + for _, tt := range tests { + tt := tt + t.Run(tt.name, func(t *testing.T) { + cmd, ok := parseCommand(tt.args) + if ok != tt.ok { + t.Fatalf("ok = %v, want %v", ok, tt.ok) + } + if !tt.ok { + return + } + if cmd.Name != tt.nameWant { + t.Fatalf("cmd.Name = %q, want %q", cmd.Name, tt.nameWant) + } + if strings.Join(cmd.Args, "|") != strings.Join(tt.argsWant, "|") { + t.Fatalf("cmd.Args = %#v, want %#v", cmd.Args, tt.argsWant) + } + }) + } +} + +func TestDisplayHelpWritesToStderr(t *testing.T) { + _, stderr := captureOutput(t, func() { + displayHelp() + }) + + if !strings.Contains(stderr, "tap cert") || !strings.Contains(stderr, "tap run --") { + t.Fatalf("stderr missing usage text: %q", stderr) + } +} + +func TestRun_InvalidCommandShowsHelp(t *testing.T) { + _, stderr := captureOutput(t, func() { + Run([]string{"tap", "wat"}) + }) + + if !strings.Contains(stderr, "usage:") { + t.Fatalf("stderr missing usage output: %q", stderr) + } +} + +func TestRun_RoutesCertCommand(t *testing.T) { + configRoot := t.TempDir() + t.Setenv("XDG_CONFIG_HOME", configRoot) + + stdout, _ := captureOutput(t, func() { + Run([]string{"tap", "cert"}) + }) + + if !strings.Contains(stdout, "Certificate path:") { + t.Fatalf("stdout missing certificate path output: %q", stdout) + } +} + +func TestRun_StartSessionFailureCallsFatalExit(t *testing.T) { + restore := installRunSeams(t) + defer restore() + + startSessionFn = func(model.Command, string) (*app.Session, error) { + return nil, errors.New("boom") + } + called := installFatalSpy(t) + + Run([]string{"tap", "run", "--", "definitely-not-a-real-command"}) + if !*called { + t.Fatal("expected fatalExit to be called on StartSession failure") + } +} + +func TestRun_TUIFailureCallsFatalExit(t *testing.T) { + restore := installRunSeams(t) + defer restore() + + startSessionFn = func(model.Command, string) (*app.Session, error) { + return &app.Session{Events: make(chan model.Event)}, nil + } + runTUIFn = func(<-chan model.Event, tui.Controls) error { + return errors.New("tui failed") + } + called := installFatalSpy(t) + + Run([]string{"tap", "run", "--", "echo"}) + if !*called { + t.Fatal("expected fatalExit to be called on tui failure") + } +} + +func TestRun_SuccessPathDoesNotCallFatal(t *testing.T) { + restore := installRunSeams(t) + defer restore() + + startSessionFn = func(model.Command, string) (*app.Session, error) { + return &app.Session{Events: make(chan model.Event)}, nil + } + runTUIFn = func(<-chan model.Event, tui.Controls) error { + return nil + } + called := installFatalSpy(t) + + Run([]string{"tap", "run", "--", "echo"}) + if *called { + t.Fatal("fatalExit should not be called on success path") + } +} + +func TestRunCert_EnsureCAFailureCallsFatalExit(t *testing.T) { + t.Setenv("XDG_CONFIG_HOME", "") + t.Setenv("HOME", "") + called := installFatalSpy(t) + + runCert() + if !*called { + t.Fatal("expected fatalExit to be called when EnsureCertificateAuthority fails") + } +} + +func TestRunCertOutputContract(t *testing.T) { + configRoot := t.TempDir() + t.Setenv("XDG_CONFIG_HOME", configRoot) + + stdout, _ := captureOutput(t, func() { + runCert() + }) + + if !strings.Contains(stdout, "Certificate path:") { + t.Fatalf("stdout missing certificate path line: %q", stdout) + } + if !strings.Contains(stdout, "local HTTPS interception CA") { + t.Fatalf("stdout missing CA create/existing line: %q", stdout) + } + if !strings.Contains(stdout, "System trust store:") && !strings.Contains(stdout, "System trust check failed:") { + t.Fatalf("stdout missing trust check line: %q", stdout) + } + + if runtime.GOOS == "linux" { + if !strings.Contains(stdout, "Trust instructions (Linux):") { + t.Fatalf("stdout missing linux trust instructions: %q", stdout) + } + } +} + +func TestRunCert_CreatedThenExistingMessage(t *testing.T) { + configRoot := t.TempDir() + t.Setenv("XDG_CONFIG_HOME", configRoot) + + firstOut, _ := captureOutput(t, func() { + runCert() + }) + if !strings.Contains(firstOut, "Created a new local HTTPS interception CA.") { + t.Fatalf("first run should indicate created CA, got: %q", firstOut) + } + + secondOut, _ := captureOutput(t, func() { + runCert() + }) + if !strings.Contains(secondOut, "Using existing local HTTPS interception CA.") { + t.Fatalf("second run should indicate existing CA, got: %q", secondOut) + } +} + +func captureOutput(t *testing.T, fn func()) (stdout string, stderr string) { + t.Helper() + + origStdoutWriter := stdoutWriter + origStderrWriter := stderrWriter + t.Cleanup(func() { + stdoutWriter = origStdoutWriter + stderrWriter = origStderrWriter + }) + + origStdout := os.Stdout + origStderr := os.Stderr + + outR, outW, err := os.Pipe() + if err != nil { + t.Fatalf("stdout pipe error: %v", err) + } + errR, errW, err := os.Pipe() + if err != nil { + _ = outR.Close() + _ = outW.Close() + t.Fatalf("stderr pipe error: %v", err) + } + + os.Stdout = outW + os.Stderr = errW + stdoutWriter = outW + stderrWriter = errW + + outCh := make(chan string, 1) + errCh := make(chan string, 1) + go func() { + var buf bytes.Buffer + _, _ = io.Copy(&buf, outR) + outCh <- buf.String() + }() + go func() { + var buf bytes.Buffer + _, _ = io.Copy(&buf, errR) + errCh <- buf.String() + }() + + fn() + + _ = outW.Close() + _ = errW.Close() + stdoutWriter = origStdoutWriter + stderrWriter = origStderrWriter + os.Stdout = origStdout + os.Stderr = origStderr + + select { + case stdout = <-outCh: + case <-time.After(2 * time.Second): + t.Fatal("timeout waiting for stdout capture") + } + select { + case stderr = <-errCh: + case <-time.After(2 * time.Second): + t.Fatal("timeout waiting for stderr capture") + } + _ = outR.Close() + _ = errR.Close() + + return stdout, stderr +} + +func TestDisplayHelp_UsesInjectedStderrWriter(t *testing.T) { + var buf bytes.Buffer + orig := stderrWriter + t.Cleanup(func() { stderrWriter = orig }) + stderrWriter = &buf + + displayHelp() + + if got := buf.String(); !strings.Contains(got, "usage:") { + t.Fatalf("help output missing usage, got: %q", got) + } +} + +func TestRunCert_UsesInjectedStdoutWriter(t *testing.T) { + configRoot := t.TempDir() + t.Setenv("XDG_CONFIG_HOME", configRoot) + + var buf bytes.Buffer + orig := stdoutWriter + t.Cleanup(func() { stdoutWriter = orig }) + stdoutWriter = &buf + + runCert() + + if got := buf.String(); !strings.Contains(got, "Certificate path:") { + t.Fatalf("cert output missing path line, got: %q", got) + } +} + +func installRunSeams(t *testing.T) func() { + t.Helper() + + origStartSession := startSessionFn + origRunTUI := runTUIFn + return func() { + startSessionFn = origStartSession + runTUIFn = origRunTUI + } +} + +func installFatalSpy(t *testing.T) *bool { + t.Helper() + + origFatal := fatalExit + called := false + fatalExit = func(v ...any) { + called = true + } + t.Cleanup(func() { + fatalExit = origFatal + }) + + return &called +} + +func TestStdioRefWrite(t *testing.T) { + t.Run("writes to stdout", func(t *testing.T) { + assertStdioRefWrite(t, false, "hello") + }) + + t.Run("writes to stderr", func(t *testing.T) { + assertStdioRefWrite(t, true, "boom") + }) +} + +func assertStdioRefWrite(t *testing.T, isErr bool, payload string) { + t.Helper() + + r, w, err := os.Pipe() + if err != nil { + t.Fatalf("pipe error: %v", err) + } + defer func() { _ = r.Close() }() + + if isErr { + orig := os.Stderr + os.Stderr = w + defer func() { os.Stderr = orig }() + } else { + orig := os.Stdout + os.Stdout = w + defer func() { os.Stdout = orig }() + } + + if _, err := (stdioRef{isErr: isErr}).Write([]byte(payload)); err != nil { + _ = w.Close() + t.Fatalf("stdioRef write error: %v", err) + } + _ = w.Close() + + data, err := io.ReadAll(r) + if err != nil { + t.Fatalf("ReadAll(pipe) error: %v", err) + } + if got := string(data); got != payload { + t.Fatalf("pipe write = %q, want %q", got, payload) + } +} diff --git a/internal/process/runner_test.go b/internal/process/runner_test.go new file mode 100644 index 0000000..cd72165 --- /dev/null +++ b/internal/process/runner_test.go @@ -0,0 +1,219 @@ +package process + +import ( + "os" + "os/exec" + "reflect" + "strings" + "testing" + "time" + + "termtap.dev/internal/model" +) + +func TestCommandString(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + cmd model.Command + want string + }{ + { + name: "empty args", + cmd: model.Command{Name: "go", Args: []string{}}, + want: "go ", + }, + { + name: "multiple args", + cmd: model.Command{Name: "curl", Args: []string{"-s", "https://example.com"}}, + want: "curl -s https://example.com", + }, + } + + for _, tt := range tests { + tt := tt + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + + if got := CommandString(tt.cmd); got != tt.want { + t.Fatalf("CommandString() = %q, want %q", got, tt.want) + } + }) + } +} + +func TestInjectEnv(t *testing.T) { + t.Parallel() + + cmd := exec.Command("sh", "-c", "true") + injectEnv(cmd, "127.0.0.1:8080") + + mustContain := []string{ + "HTTP_PROXY=http://127.0.0.1:8080", + "http_proxy=http://127.0.0.1:8080", + "HTTPS_PROXY=http://127.0.0.1:8080", + "https_proxy=http://127.0.0.1:8080", + "NO_PROXY=", + "no_proxy=", + } + + for _, kv := range mustContain { + if !containsEnvEntry(cmd.Env, kv) { + t.Fatalf("injectEnv() missing env entry %q", kv) + } + } +} + +func TestReadPipe(t *testing.T) { + t.Parallel() + + t.Run("stdout lines emit stdout events", func(t *testing.T) { + t.Parallel() + + ch := make(chan model.Event, 4) + input := strings.NewReader("line1\nline2\n") + + readPipe(input, model.EventTypeProcessStdout, ch) + + events := drainEvents(t, ch, 2, time.Second) + if events[0].Type != model.EventTypeProcessStdout || events[0].Body != "line1" { + t.Fatalf("event[0] = %#v, want stdout line1", events[0]) + } + if events[1].Type != model.EventTypeProcessStdout || events[1].Body != "line2" { + t.Fatalf("event[1] = %#v, want stdout line2", events[1]) + } + }) + + t.Run("stderr lines emit stderr events", func(t *testing.T) { + t.Parallel() + + ch := make(chan model.Event, 4) + input := strings.NewReader("err1\nerr2\n") + + readPipe(input, model.EventTypeProcessStderr, ch) + + events := drainEvents(t, ch, 2, time.Second) + if events[0].Type != model.EventTypeProcessStderr || events[0].Body != "err1" { + t.Fatalf("event[0] = %#v, want stderr err1", events[0]) + } + if events[1].Type != model.EventTypeProcessStderr || events[1].Body != "err2" { + t.Fatalf("event[1] = %#v, want stderr err2", events[1]) + } + }) +} + +func TestUpdateStatus(t *testing.T) { + t.Parallel() + + t.Run("nil process is no-op", func(t *testing.T) { + t.Parallel() + UpdateStatus(nil, true, make(chan model.Event, 1)) + }) + + t.Run("state unchanged is no-op", func(t *testing.T) { + t.Parallel() + + ch := make(chan model.Event, 1) + proc := &model.Process{Running: true} + UpdateStatus(proc, true, ch) + + select { + case ev := <-ch: + t.Fatalf("unexpected event for unchanged state: %#v", ev) + default: + } + }) + + t.Run("emits started and stopped events with pid", func(t *testing.T) { + t.Parallel() + + ch := make(chan model.Event, 2) + proc := &model.Process{ + Exec: &exec.Cmd{Process: &os.Process{Pid: 4321}}, + } + + UpdateStatus(proc, true, ch) + events := drainEvents(t, ch, 1, time.Second) + started := events[0] + if started.Type != model.EventTypeProcessStarted { + t.Fatalf("started type = %s, want %s", started.Type, model.EventTypeProcessStarted) + } + if started.PID != 4321 { + t.Fatalf("started PID = %d, want %d", started.PID, 4321) + } + if !proc.Running { + t.Fatal("proc.Running = false, want true") + } + + UpdateStatus(proc, false, ch) + events = drainEvents(t, ch, 1, time.Second) + stopped := events[0] + if stopped.Type != model.EventTypeProcessExited { + t.Fatalf("stopped type = %s, want %s", stopped.Type, model.EventTypeProcessExited) + } + if stopped.PID != 4321 { + t.Fatalf("stopped PID = %d, want %d", stopped.PID, 4321) + } + if proc.Running { + t.Fatal("proc.Running = true, want false") + } + }) +} + +func TestNewProcess(t *testing.T) { + t.Parallel() + + ch := make(chan model.Event, 8) + cmd := model.Command{Name: "sh", Args: []string{"-c", "printf test"}} + + proc := NewProcess(cmd, "127.0.0.1:8080", ch) + if proc == nil { + t.Fatal("NewProcess() returned nil") + } + if !reflect.DeepEqual(proc.Command, cmd) { + t.Fatalf("process command = %#v, want %#v", proc.Command, cmd) + } + if proc.Exec == nil { + t.Fatal("process Exec is nil") + } + if proc.Running { + t.Fatal("new process should not be running") + } + if proc.Done == nil { + t.Fatal("Done channel is nil") + } + + if got, want := proc.Exec.Args[0], "sh"; got != want { + t.Fatalf("Exec.Args[0] = %q, want %q", got, want) + } + if !containsEnvEntry(proc.Exec.Env, "HTTP_PROXY=http://127.0.0.1:8080") { + t.Fatal("process env missing injected HTTP_PROXY") + } +} + +func containsEnvEntry(env []string, want string) bool { + for _, entry := range env { + if entry == want { + return true + } + } + return false +} + +func drainEvents(t *testing.T, ch <-chan model.Event, n int, timeout time.Duration) []model.Event { + t.Helper() + + events := make([]model.Event, 0, n) + deadline := time.After(timeout) + for len(events) < n { + select { + case ev := <-ch: + events = append(events, ev) + case <-deadline: + t.Fatalf("timeout waiting for %d events; got %d", n, len(events)) + } + } + + return events +} diff --git a/internal/process/signal_unix_test.go b/internal/process/signal_unix_test.go new file mode 100644 index 0000000..aebb236 --- /dev/null +++ b/internal/process/signal_unix_test.go @@ -0,0 +1,147 @@ +//go:build unix + +package process + +import ( + "errors" + "os" + "os/exec" + "syscall" + "testing" + "time" +) + +type customSignal struct{} + +func (customSignal) String() string { return "custom" } +func (customSignal) Signal() {} + +// NOTE: Run these tests with -race in CI for signal/process safety. + +func TestConfigureProcessForSignals(t *testing.T) { + t.Parallel() + + cmd := exec.Command("sh", "-c", "sleep 0.1") + configureProcessForSignals(cmd) + + if cmd.SysProcAttr == nil { + t.Fatal("SysProcAttr is nil") + } + if !cmd.SysProcAttr.Setpgid { + t.Fatal("Setpgid = false, want true") + } +} + +func TestSignalProcess_NilSafe(t *testing.T) { + t.Parallel() + + if err := SignalProcess(nil, syscall.SIGTERM); err != nil { + t.Fatalf("SignalProcess(nil) error = %v, want nil", err) + } + + cmd := &exec.Cmd{} + if err := SignalProcess(cmd, syscall.SIGTERM); err != nil { + t.Fatalf("SignalProcess(cmd without process) error = %v, want nil", err) + } + + cmd.Process = &os.Process{Pid: 0} + if err := SignalProcess(cmd, syscall.SIGTERM); err != nil { + t.Fatalf("SignalProcess(pid<=0) error = %v, want nil", err) + } +} + +func TestSignalProcess_ESRCHIsTreatedAsSuccess(t *testing.T) { + t.Parallel() + + cmd := &exec.Cmd{Process: &os.Process{Pid: 999999}} + if err := SignalProcess(cmd, syscall.SIGTERM); err != nil { + t.Fatalf("SignalProcess() error = %v, want nil when process group not found", err) + } +} + +func TestSignalProcess_UsesFallbackOnKillError(t *testing.T) { + t.Parallel() + + cmd := exec.Command("sh", "-c", "sleep 5") + configureProcessForSignals(cmd) + if err := cmd.Start(); err != nil { + t.Fatalf("start command error = %v", err) + } + t.Cleanup(func() { + _ = cmd.Process.Kill() + _, _ = cmd.Process.Wait() + }) + + // Invalid signal causes syscall.Kill to fail with EINVAL, then fallback to cmd.Process.Signal. + err := SignalProcess(cmd, syscall.Signal(9999)) + if err == nil { + t.Fatal("SignalProcess() error = nil, want non-nil for invalid signal") + } + if !(errors.Is(err, syscall.EINVAL) || errors.Is(err, os.ErrProcessDone)) { + // OS/process timing can vary; ensure we at least failed predictably. + t.Fatalf("SignalProcess() unexpected error: %v", err) + } +} + +func TestSignalProcess_NonSyscallSignalUsesProcessSignal(t *testing.T) { + t.Parallel() + + cmd := exec.Command("sh", "-c", "sleep 1") + configureProcessForSignals(cmd) + if err := cmd.Start(); err != nil { + t.Fatalf("start command error = %v", err) + } + t.Cleanup(func() { + _ = cmd.Process.Kill() + _, _ = cmd.Process.Wait() + }) + + err := SignalProcess(cmd, customSignal{}) + if err == nil { + t.Fatal("SignalProcess(custom signal) error = nil, want non-nil") + } + if errors.Is(err, os.ErrProcessDone) { + return + } + if msg := err.Error(); msg == "" { + t.Fatalf("unexpected empty error for custom signal: %v", err) + } +} + +func TestProcessAlive(t *testing.T) { + t.Parallel() + + if ProcessAlive(nil) { + t.Fatal("ProcessAlive(nil) = true, want false") + } + if ProcessAlive(&exec.Cmd{}) { + t.Fatal("ProcessAlive(cmd without process) = true, want false") + } + + cmd := exec.Command("sh", "-c", "sleep 0.2") + configureProcessForSignals(cmd) + if err := cmd.Start(); err != nil { + t.Fatalf("start command error = %v", err) + } + + if !ProcessAlive(cmd) { + _ = cmd.Process.Kill() + _, _ = cmd.Process.Wait() + t.Fatal("ProcessAlive(running) = false, want true") + } + + _ = cmd.Process.Kill() + _, _ = cmd.Process.Wait() + + deadline := time.After(time.Second) + for { + if !ProcessAlive(cmd) { + return + } + select { + case <-deadline: + t.Fatal("ProcessAlive(exited) stayed true") + default: + } + } +} diff --git a/internal/proxy/certs_error_test.go b/internal/proxy/certs_error_test.go new file mode 100644 index 0000000..cd1833c --- /dev/null +++ b/internal/proxy/certs_error_test.go @@ -0,0 +1,242 @@ +package proxy + +import ( + "crypto/tls" + "encoding/pem" + "os" + "path/filepath" + "strings" + "testing" +) + +func TestLoadOrCreateCertificateAuthority_RecreatesWhenKeyMissing(t *testing.T) { + configRoot := t.TempDir() + t.Setenv("XDG_CONFIG_HOME", configRoot) + + baseDir := filepath.Join(configRoot, caDirName) + if err := os.MkdirAll(baseDir, 0o700); err != nil { + t.Fatalf("MkdirAll() error = %v", err) + } + + certPath := filepath.Join(baseDir, caCertName) + if err := os.WriteFile(certPath, []byte("stale-cert"), 0o600); err != nil { + t.Fatalf("WriteFile(cert) error = %v", err) + } + + ca, err := loadOrCreateCertificateAuthority() + if err != nil { + t.Fatalf("loadOrCreateCertificateAuthority() error = %v", err) + } + if !ca.WasCreated() { + t.Fatal("WasCreated = false, want true when key is missing") + } + if _, err := os.Stat(filepath.Join(baseDir, caKeyName)); err != nil { + t.Fatalf("expected key file to be created, stat error = %v", err) + } +} + +func TestLoadOrCreateCertificateAuthority_LoadErrorOnCorruptFiles(t *testing.T) { + configRoot := t.TempDir() + t.Setenv("XDG_CONFIG_HOME", configRoot) + + baseDir := filepath.Join(configRoot, caDirName) + if err := os.MkdirAll(baseDir, 0o700); err != nil { + t.Fatalf("MkdirAll() error = %v", err) + } + + certPath := filepath.Join(baseDir, caCertName) + keyPath := filepath.Join(baseDir, caKeyName) + if err := os.WriteFile(certPath, []byte("not-a-pem"), 0o600); err != nil { + t.Fatalf("WriteFile(cert) error = %v", err) + } + if err := os.WriteFile(keyPath, []byte("not-a-pem"), 0o600); err != nil { + t.Fatalf("WriteFile(key) error = %v", err) + } + + _, err := loadOrCreateCertificateAuthority() + if err == nil { + t.Fatal("loadOrCreateCertificateAuthority() error = nil, want non-nil") + } +} + +func TestCertificateAuthorityLoad_ErrorPaths(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + certBytes []byte + keyBytes []byte + wantPart string + }{ + { + name: "invalid cert pem", + certBytes: []byte("bad-cert"), + keyBytes: []byte("bad-key"), + wantPart: "decode ca cert pem", + }, + { + name: "parse cert fails", + certBytes: pem.EncodeToMemory(&pem.Block{Type: "CERTIFICATE", Bytes: []byte("bogus")}), + keyBytes: []byte("bad-key"), + wantPart: "parse ca cert", + }, + { + name: "invalid key pem", + certBytes: pem.EncodeToMemory(&pem.Block{Type: "CERTIFICATE", Bytes: newTestCA(t).cert.Raw}), + keyBytes: []byte("bad-key"), + wantPart: "decode ca key pem", + }, + { + name: "parse key fails", + certBytes: pem.EncodeToMemory(&pem.Block{Type: "CERTIFICATE", Bytes: newTestCA(t).cert.Raw}), + keyBytes: pem.EncodeToMemory(&pem.Block{Type: "EC PRIVATE KEY", Bytes: []byte("bogus")}), + wantPart: "parse ca key", + }, + } + + for _, tt := range tests { + tt := tt + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + + dir := t.TempDir() + ca := &CertificateAuthority{ + certPath: filepath.Join(dir, caCertName), + keyPath: filepath.Join(dir, caKeyName), + leafCert: make(map[string]*tls.Certificate), + } + + if err := os.WriteFile(ca.certPath, tt.certBytes, 0o600); err != nil { + t.Fatalf("write cert file error = %v", err) + } + if err := os.WriteFile(ca.keyPath, tt.keyBytes, 0o600); err != nil { + t.Fatalf("write key file error = %v", err) + } + + err := ca.load() + if err == nil { + t.Fatal("load() error = nil, want non-nil") + } + if !strings.Contains(err.Error(), tt.wantPart) { + t.Fatalf("load() error = %q, want contains %q", err.Error(), tt.wantPart) + } + }) + } +} + +func TestCertificateAuthorityCreate_ErrorWhenWritePathInvalid(t *testing.T) { + t.Parallel() + + ca := &CertificateAuthority{ + certPath: filepath.Join("/nope", "missing", "ca-cert.pem"), + keyPath: filepath.Join("/nope", "missing", "ca-key.pem"), + leafCert: make(map[string]*tls.Certificate), + } + + err := ca.create() + if err == nil { + t.Fatal("create() error = nil, want non-nil") + } +} + +func TestCertificateAuthorityCreate_WriteErrorPaths(t *testing.T) { + t.Parallel() + + t.Run("write ca cert wraps error", func(t *testing.T) { + t.Parallel() + + dir := t.TempDir() + badCertPath := filepath.Join(dir, "cert-as-dir") + if err := os.MkdirAll(badCertPath, 0o700); err != nil { + t.Fatalf("MkdirAll(cert dir) error = %v", err) + } + + ca := &CertificateAuthority{ + certPath: badCertPath, + keyPath: filepath.Join(dir, "ca-key.pem"), + leafCert: make(map[string]*tls.Certificate), + } + + err := ca.create() + if err == nil { + t.Fatal("create() error = nil, want non-nil") + } + if !strings.Contains(err.Error(), "write ca cert") { + t.Fatalf("create() error = %q, want contains %q", err.Error(), "write ca cert") + } + }) + + t.Run("write ca key wraps error", func(t *testing.T) { + t.Parallel() + + dir := t.TempDir() + badKeyPath := filepath.Join(dir, "key-as-dir") + if err := os.MkdirAll(badKeyPath, 0o700); err != nil { + t.Fatalf("MkdirAll(key dir) error = %v", err) + } + + ca := &CertificateAuthority{ + certPath: filepath.Join(dir, "ca-cert.pem"), + keyPath: badKeyPath, + leafCert: make(map[string]*tls.Certificate), + } + + err := ca.create() + if err == nil { + t.Fatal("create() error = nil, want non-nil") + } + if !strings.Contains(err.Error(), "write ca key") { + t.Fatalf("create() error = %q, want contains %q", err.Error(), "write ca key") + } + }) +} + +func TestCertificateAuthorityLoad_ReadErrorPaths(t *testing.T) { + t.Parallel() + + t.Run("read cert failure", func(t *testing.T) { + t.Parallel() + + dir := t.TempDir() + ca := &CertificateAuthority{ + certPath: filepath.Join(dir, "missing-cert.pem"), + keyPath: filepath.Join(dir, "missing-key.pem"), + leafCert: make(map[string]*tls.Certificate), + } + + err := ca.load() + if err == nil { + t.Fatal("load() error = nil, want non-nil") + } + if !strings.Contains(err.Error(), "read ca cert") { + t.Fatalf("load() error = %q, want contains %q", err.Error(), "read ca cert") + } + }) + + t.Run("read key failure", func(t *testing.T) { + t.Parallel() + + dir := t.TempDir() + goodCA := newTestCA(t) + certPEM := pem.EncodeToMemory(&pem.Block{Type: "CERTIFICATE", Bytes: goodCA.cert.Raw}) + + certPath := filepath.Join(dir, "cert.pem") + if err := os.WriteFile(certPath, certPEM, 0o600); err != nil { + t.Fatalf("WriteFile(cert) error = %v", err) + } + + ca := &CertificateAuthority{ + certPath: certPath, + keyPath: filepath.Join(dir, "missing-key.pem"), + leafCert: make(map[string]*tls.Certificate), + } + + err := ca.load() + if err == nil { + t.Fatal("load() error = nil, want non-nil") + } + if !strings.Contains(err.Error(), "read ca key") { + t.Fatalf("load() error = %q, want contains %q", err.Error(), "read ca key") + } + }) +} diff --git a/internal/proxy/certs_lifecycle_test.go b/internal/proxy/certs_lifecycle_test.go new file mode 100644 index 0000000..c21d90e --- /dev/null +++ b/internal/proxy/certs_lifecycle_test.go @@ -0,0 +1,45 @@ +package proxy + +import ( + "os" + "testing" +) + +func TestLoadOrCreateCertificateAuthority_CreateThenLoad(t *testing.T) { + configRoot := t.TempDir() + t.Setenv("XDG_CONFIG_HOME", configRoot) + + ca1, err := loadOrCreateCertificateAuthority() + if err != nil { + t.Fatalf("first loadOrCreateCertificateAuthority() error = %v", err) + } + if ca1 == nil { + t.Fatal("first CA is nil") + } + if !ca1.WasCreated() { + t.Fatal("first CA should report WasCreated=true") + } + if ca1.CertPath() == "" { + t.Fatal("first CA CertPath is empty") + } + if _, err := os.Stat(ca1.CertPath()); err != nil { + t.Fatalf("first CA cert file missing: %v", err) + } + + ca2, err := loadOrCreateCertificateAuthority() + if err != nil { + t.Fatalf("second loadOrCreateCertificateAuthority() error = %v", err) + } + if ca2 == nil { + t.Fatal("second CA is nil") + } + if ca2.WasCreated() { + t.Fatal("second CA should report WasCreated=false (loaded existing)") + } + if ca2.CertPath() != ca1.CertPath() { + t.Fatalf("cert path mismatch: first=%q second=%q", ca1.CertPath(), ca2.CertPath()) + } + if ca2.cert == nil || ca2.key == nil { + t.Fatal("loaded CA should include cert and key") + } +} diff --git a/internal/proxy/certs_test.go b/internal/proxy/certs_test.go new file mode 100644 index 0000000..02bb650 --- /dev/null +++ b/internal/proxy/certs_test.go @@ -0,0 +1,266 @@ +package proxy + +import ( + "errors" + "math/big" + "os" + "path/filepath" + "testing" +) + +func TestNormalizeCertHost(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + in string + want string + }{ + {name: "host and port", in: "example.com:443", want: "example.com"}, + {name: "plain host", in: "example.com", want: "example.com"}, + {name: "whitespace trims", in: " example.com:8443 ", want: "example.com"}, + {name: "empty", in: " ", want: ""}, + } + + for _, tt := range tests { + tt := tt + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + if got := normalizeCertHost(tt.in); got != tt.want { + t.Fatalf("normalizeCertHost(%q) = %q, want %q", tt.in, got, tt.want) + } + }) + } +} + +func TestRandSerialNumber(t *testing.T) { + t.Parallel() + + serial, err := randSerialNumber() + if err != nil { + t.Fatalf("randSerialNumber() error = %v", err) + } + if serial == nil { + t.Fatal("serial is nil") + } + if serial.Sign() < 0 { + t.Fatalf("serial must be non-negative, got %v", serial) + } + + limit := new(big.Int).Lsh(big.NewInt(1), 128) + if serial.Cmp(limit) >= 0 { + t.Fatalf("serial must be < 2^128, got %v", serial) + } +} + +func TestWriteFileAtomically(t *testing.T) { + t.Parallel() + + dir := t.TempDir() + path := filepath.Join(dir, "cert.pem") + + if err := writeFileAtomically(path, []byte("first"), 0o600); err != nil { + t.Fatalf("first writeFileAtomically() error = %v", err) + } + if err := writeFileAtomically(path, []byte("second"), 0o600); err != nil { + t.Fatalf("second writeFileAtomically() error = %v", err) + } + + data, err := os.ReadFile(path) + if err != nil { + t.Fatalf("ReadFile() error = %v", err) + } + if got, want := string(data), "second"; got != want { + t.Fatalf("file contents = %q, want %q", got, want) + } + + info, err := os.Stat(path) + if err != nil { + t.Fatalf("Stat() error = %v", err) + } + if got := info.Mode().Perm(); got != 0o600 { + t.Fatalf("file permissions = %#o, want %#o", got, 0o600) + } +} + +func TestCertificateAuthority_Basics(t *testing.T) { + t.Parallel() + + var nilCA *CertificateAuthority + if got := nilCA.CertPath(); got != "" { + t.Fatalf("nil CertPath() = %q, want empty", got) + } + if got := nilCA.WasCreated(); got { + t.Fatalf("nil WasCreated() = %v, want false", got) + } + + ca := newTestCA(t) + ca.certPath = "/tmp/test-ca.pem" + ca.wasCreated = true + if got, want := ca.CertPath(), "/tmp/test-ca.pem"; got != want { + t.Fatalf("CertPath() = %q, want %q", got, want) + } + if !ca.WasCreated() { + t.Fatal("WasCreated() = false, want true") + } +} + +func TestCertificateForHost(t *testing.T) { + t.Parallel() + + ca := newTestCA(t) + + t.Run("empty host returns error", func(t *testing.T) { + t.Parallel() + cert, err := ca.CertificateForHost(" ") + if err == nil { + t.Fatal("CertificateForHost() error = nil, want non-nil") + } + if cert != nil { + t.Fatalf("cert = %#v, want nil", cert) + } + }) + + t.Run("cache hit returns same pointer", func(t *testing.T) { + t.Parallel() + + c1, err := ca.CertificateForHost("example.com:443") + if err != nil { + t.Fatalf("first CertificateForHost() error = %v", err) + } + c2, err := ca.CertificateForHost("example.com") + if err != nil { + t.Fatalf("second CertificateForHost() error = %v", err) + } + + if c1 != c2 { + t.Fatal("expected same certificate pointer from cache") + } + }) + + t.Run("ip and dns SAN selection", func(t *testing.T) { + t.Parallel() + + ipCert, err := ca.CertificateForHost("127.0.0.1:443") + if err != nil { + t.Fatalf("ip CertificateForHost() error = %v", err) + } + if ipCert.Leaf == nil { + t.Fatal("ip cert leaf is nil") + } + if len(ipCert.Leaf.IPAddresses) == 0 { + t.Fatal("ip cert should contain IP SAN") + } + if len(ipCert.Leaf.DNSNames) != 0 { + t.Fatalf("ip cert DNSNames = %v, want empty", ipCert.Leaf.DNSNames) + } + + dnsCert, err := ca.CertificateForHost("service.local") + if err != nil { + t.Fatalf("dns CertificateForHost() error = %v", err) + } + if dnsCert.Leaf == nil { + t.Fatal("dns cert leaf is nil") + } + if len(dnsCert.Leaf.DNSNames) == 0 { + t.Fatal("dns cert should contain DNS SAN") + } + }) + + t.Run("evicts oldest entry over maxLeafCerts", func(t *testing.T) { + t.Parallel() + + ca2 := newTestCA(t) + for i := 0; i < maxLeafCerts+1; i++ { + host := filepath.Base(filepath.Join("h", big.NewInt(int64(i)).String()+".example")) + if _, err := ca2.CertificateForHost(host); err != nil { + t.Fatalf("CertificateForHost(%q) error = %v", host, err) + } + } + + if len(ca2.leafOrder) != maxLeafCerts { + t.Fatalf("leafOrder len = %d, want %d", len(ca2.leafOrder), maxLeafCerts) + } + if _, ok := ca2.leafCert["0.example"]; ok { + t.Fatal("expected oldest cert to be evicted") + } + }) +} + +func TestIsTrustedBySystem(t *testing.T) { + t.Parallel() + + var nilCA *CertificateAuthority + _, err := nilCA.IsTrustedBySystem() + if err == nil { + t.Fatal("nil IsTrustedBySystem() error = nil, want non-nil") + } + + ca := &CertificateAuthority{} + _, err = ca.IsTrustedBySystem() + if err == nil { + t.Fatal("missing-cert IsTrustedBySystem() error = nil, want non-nil") + } + + t.Run("untrusted generated CA returns false without error", func(t *testing.T) { + t.Parallel() + ca := newTestCA(t) + + trusted, err := ca.IsTrustedBySystem() + if err != nil { + t.Fatalf("IsTrustedBySystem() error = %v, want nil for unknown authority", err) + } + if trusted { + t.Fatal("trusted = true, want false for generated test CA") + } + }) +} + +func TestEnsureCertificateAuthority(t *testing.T) { + configRoot := t.TempDir() + t.Setenv("XDG_CONFIG_HOME", configRoot) + + ca, err := EnsureCertificateAuthority() + if err != nil { + t.Fatalf("EnsureCertificateAuthority() error = %v", err) + } + if ca == nil { + t.Fatal("EnsureCertificateAuthority() returned nil CA") + } + if ca.CertPath() == "" { + t.Fatal("EnsureCertificateAuthority() returned empty cert path") + } + if _, statErr := os.Stat(ca.CertPath()); statErr != nil { + t.Fatalf("expected cert on disk, stat error = %v", statErr) + } +} + +func TestWriteFileAtomically_ErrorPath(t *testing.T) { + t.Parallel() + + err := writeFileAtomically(filepath.Join("/nope", "bad", "path.pem"), []byte("x"), 0o600) + if err == nil { + t.Fatal("writeFileAtomically() error = nil, want non-nil") + } + if errors.Is(err, os.ErrNotExist) { + return + } + // Accept platform-dependent fs errors as long as function fails. +} + +func TestWriteFileAtomically_RenameErrorWhenTargetIsDirectory(t *testing.T) { + t.Parallel() + + dir := t.TempDir() + targetDir := filepath.Join(dir, "target-as-dir") + if err := os.MkdirAll(targetDir, 0o700); err != nil { + t.Fatalf("MkdirAll(targetDir) error = %v", err) + } + + err := writeFileAtomically(targetDir, []byte("x"), 0o600) + if err == nil { + t.Fatal("writeFileAtomically() error = nil, want non-nil") + } +} + +// TODO: Add deterministic tests for loadOrCreateCertificateAuthority trust-store interactions. diff --git a/internal/proxy/certs_trust_test.go b/internal/proxy/certs_trust_test.go new file mode 100644 index 0000000..7dd540e --- /dev/null +++ b/internal/proxy/certs_trust_test.go @@ -0,0 +1,29 @@ +package proxy + +import ( + "encoding/pem" + "os" + "path/filepath" + "testing" +) + +func TestIsTrustedBySystem_TrustedViaCertEnv(t *testing.T) { + ca := newTestCA(t) + + rootFile := filepath.Join(t.TempDir(), "roots.pem") + pemBytes := pem.EncodeToMemory(&pem.Block{Type: "CERTIFICATE", Bytes: ca.cert.Raw}) + if err := os.WriteFile(rootFile, pemBytes, 0o600); err != nil { + t.Fatalf("WriteFile(root cert) error = %v", err) + } + + t.Setenv("SSL_CERT_FILE", rootFile) + t.Setenv("SSL_CERT_DIR", t.TempDir()) + + trusted, err := ca.IsTrustedBySystem() + if err != nil { + t.Fatalf("IsTrustedBySystem() error = %v", err) + } + if !trusted { + t.Fatal("trusted = false, want true when CA is in configured cert file") + } +} diff --git a/internal/proxy/handlers_test.go b/internal/proxy/handlers_test.go new file mode 100644 index 0000000..9737678 --- /dev/null +++ b/internal/proxy/handlers_test.go @@ -0,0 +1,330 @@ +package proxy + +import ( + "bufio" + "bytes" + "context" + "fmt" + "io" + "net" + "net/http" + "net/http/httptest" + "strings" + "testing" + "time" + + "termtap.dev/internal/model" +) + +type failingResponseWriter struct { + header http.Header + code int +} + +func (w *failingResponseWriter) Header() http.Header { + if w.header == nil { + w.header = make(http.Header) + } + return w.header +} + +func (w *failingResponseWriter) WriteHeader(statusCode int) { + w.code = statusCode +} + +func (w *failingResponseWriter) Write(_ []byte) (int, error) { + return 0, io.ErrClosedPipe +} + +type hijackFailWriter struct { + header http.Header + code int +} + +func (w *hijackFailWriter) Header() http.Header { + if w.header == nil { + w.header = make(http.Header) + } + return w.header +} + +func (w *hijackFailWriter) WriteHeader(statusCode int) { + w.code = statusCode +} + +func (w *hijackFailWriter) Write(p []byte) (int, error) { + return len(p), nil +} + +func (w *hijackFailWriter) Hijack() (net.Conn, *bufio.ReadWriter, error) { + return nil, nil, fmt.Errorf("hijack failed") +} + +type dummyConn struct{} + +func (dummyConn) Read(_ []byte) (int, error) { return 0, io.EOF } +func (dummyConn) Write(p []byte) (int, error) { return len(p), nil } +func (dummyConn) Close() error { return nil } +func (dummyConn) LocalAddr() net.Addr { return &net.TCPAddr{} } +func (dummyConn) RemoteAddr() net.Addr { return &net.TCPAddr{} } +func (dummyConn) SetDeadline(_ time.Time) error { return nil } +func (dummyConn) SetReadDeadline(_ time.Time) error { return nil } +func (dummyConn) SetWriteDeadline(_ time.Time) error { return nil } + +type writeConnectFailWriter struct { + header http.Header + code int +} + +func (w *writeConnectFailWriter) Header() http.Header { + if w.header == nil { + w.header = make(http.Header) + } + return w.header +} + +func (w *writeConnectFailWriter) WriteHeader(statusCode int) { + w.code = statusCode +} + +func (w *writeConnectFailWriter) Write(p []byte) (int, error) { + return len(p), nil +} + +func (w *writeConnectFailWriter) Hijack() (net.Conn, *bufio.ReadWriter, error) { + rw := bufio.NewReadWriter( + bufio.NewReader(strings.NewReader("")), + bufio.NewWriter(errWriter{}), + ) + return dummyConn{}, rw, nil +} + +func TestProxyHandler_NonConnectRejectsNonAbsoluteURL(t *testing.T) { + t.Parallel() + + ch := make(chan model.Event, 8) + ps := &model.ProxyServer{Conns: make(map[net.Conn]struct{})} + h := proxyHandler(ch, nil, ps) + + req := httptest.NewRequest(http.MethodGet, "/not-proxy-form", nil) + w := httptest.NewRecorder() + + h.ServeHTTP(w, req) + + if w.Code != http.StatusBadRequest { + t.Fatalf("status = %d, want %d", w.Code, http.StatusBadRequest) + } + + events := drainEvents(t, ch, 1, time.Second) + if !hasEventType(events, model.EventTypeWarn) { + t.Fatalf("expected %s event, got %#v", model.EventTypeWarn, events) + } +} + +func TestProxyHandler_NonConnectSuccess(t *testing.T) { + t.Parallel() + + upstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("X-Upstream", "yes") + w.WriteHeader(http.StatusAccepted) + _, _ = w.Write([]byte("pong")) + })) + t.Cleanup(upstream.Close) + + ch := make(chan model.Event, 8) + ps := &model.ProxyServer{Conns: make(map[net.Conn]struct{})} + h := proxyHandler(ch, nil, ps) + + req := httptest.NewRequest(http.MethodGet, upstream.URL+"/ping", nil) + w := httptest.NewRecorder() + + h.ServeHTTP(w, req) + + if w.Code != http.StatusAccepted { + t.Fatalf("status = %d, want %d", w.Code, http.StatusAccepted) + } + if got, want := w.Body.String(), "pong"; got != want { + t.Fatalf("body = %q, want %q", got, want) + } + if got := w.Header().Get("X-Upstream"); got != "yes" { + t.Fatalf("X-Upstream header = %q, want yes", got) + } + + events := drainEvents(t, ch, 2, time.Second) + if !hasEventType(events, model.EventTypeRequestStarted) { + t.Fatalf("expected %s event", model.EventTypeRequestStarted) + } + if !hasEventType(events, model.EventTypeRequestFinished) { + t.Fatalf("expected %s event", model.EventTypeRequestFinished) + } +} + +func TestProxyHandler_NonConnectUpstreamError(t *testing.T) { + t.Parallel() + + ch := make(chan model.Event, 8) + ps := &model.ProxyServer{Conns: make(map[net.Conn]struct{})} + h := proxyHandler(ch, nil, ps) + + req := httptest.NewRequest(http.MethodGet, "http://127.0.0.1:1/fail", nil) + w := httptest.NewRecorder() + + h.ServeHTTP(w, req) + + if w.Code != http.StatusBadGateway { + t.Fatalf("status = %d, want %d", w.Code, http.StatusBadGateway) + } + + events := drainEvents(t, ch, 2, time.Second) + if !hasEventType(events, model.EventTypeRequestStarted) { + t.Fatalf("expected %s event", model.EventTypeRequestStarted) + } + if !hasEventType(events, model.EventTypeRequestFailed) { + t.Fatalf("expected %s event", model.EventTypeRequestFailed) + } +} + +func TestProxyHandler_NonConnectWriteFailureEmitsFailedEvent(t *testing.T) { + t.Parallel() + + upstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + _, _ = io.Copy(w, bytes.NewBufferString("response-body")) + })) + t.Cleanup(upstream.Close) + + ch := make(chan model.Event, 8) + ps := &model.ProxyServer{Conns: make(map[net.Conn]struct{})} + h := proxyHandler(ch, nil, ps) + + req := httptest.NewRequest(http.MethodGet, upstream.URL+"/copy-fail", nil) + w := &failingResponseWriter{} + + h.ServeHTTP(w, req) + + events := drainEvents(t, ch, 2, time.Second) + if !hasEventType(events, model.EventTypeRequestStarted) { + t.Fatalf("expected %s event", model.EventTypeRequestStarted) + } + if !hasEventType(events, model.EventTypeRequestFailed) { + t.Fatalf("expected %s event", model.EventTypeRequestFailed) + } +} + +func TestHandleConnect_EarlyFailures(t *testing.T) { + t.Parallel() + + transport := &mockTransport{fn: func(req *http.Request) (*http.Response, error) { + return nil, context.Canceled + }} + + tests := []struct { + name string + writer http.ResponseWriter + req *http.Request + ca *CertificateAuthority + wantCode int + wantBody string + wantStat int + }{ + { + name: "nil CA returns 502 and failed event", + writer: httptest.NewRecorder(), + req: httptest.NewRequest(http.MethodConnect, "http://example.com:443", nil), + ca: nil, + wantCode: http.StatusBadGateway, + wantBody: "certificate authority is unavailable", + wantStat: http.StatusBadGateway, + }, + { + name: "nil CA with host without port still fails predictably", + writer: httptest.NewRecorder(), + req: &http.Request{Method: http.MethodConnect, Host: "example.com"}, + ca: nil, + wantCode: http.StatusBadGateway, + wantBody: "certificate authority is unavailable", + wantStat: http.StatusBadGateway, + }, + { + name: "cert mint failure returns 502 and failed event", + writer: httptest.NewRecorder(), + req: &http.Request{Method: http.MethodConnect, Host: ""}, + ca: &CertificateAuthority{}, + wantCode: http.StatusBadGateway, + wantBody: "failed to mint interception certificate", + wantStat: http.StatusBadGateway, + }, + { + name: "non-hijacker writer returns 500 and failed event", + writer: httptest.NewRecorder(), + req: httptest.NewRequest(http.MethodConnect, "http://example.com:443", nil), + ca: newTestCA(t), + wantCode: http.StatusInternalServerError, + wantBody: "hijack is unavailable", + wantStat: http.StatusInternalServerError, + }, + { + name: "hijack failure returns 500 and failed event", + writer: &hijackFailWriter{}, + req: httptest.NewRequest(http.MethodConnect, "http://example.com:443", nil), + ca: newTestCA(t), + wantCode: http.StatusInternalServerError, + wantBody: "CONNECT hijack failed", + wantStat: http.StatusInternalServerError, + }, + { + name: "write connect established failure emits failed event", + writer: &writeConnectFailWriter{}, + req: httptest.NewRequest(http.MethodConnect, "http://example.com:443", nil), + ca: newTestCA(t), + wantCode: 0, + wantBody: "CONNECT setup failed", + wantStat: http.StatusBadGateway, + }, + } + + for _, tt := range tests { + tt := tt + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + + ch := make(chan model.Event, 8) + ps := &model.ProxyServer{Conns: make(map[net.Conn]struct{})} + + handleConnect(tt.writer, tt.req, ch, transport, tt.ca, ps) + + events := drainEvents(t, ch, 2, time.Second) + if events[0].Type != model.EventTypeRequestStarted { + t.Fatalf("expected %s event", model.EventTypeRequestStarted) + } + if events[1].Type != model.EventTypeRequestFailed { + t.Fatalf("expected %s event", model.EventTypeRequestFailed) + } + if events[1].Request.Method != http.MethodConnect { + t.Fatalf("failed event request method = %q, want CONNECT", events[1].Request.Method) + } + if events[1].Request.Status != tt.wantStat { + t.Fatalf("failed event request status = %d, want %d", events[1].Request.Status, tt.wantStat) + } + if !events[1].Request.Failed || events[1].Request.Pending { + t.Fatalf("failed event request flags = pending:%v failed:%v, want pending:false failed:true", events[1].Request.Pending, events[1].Request.Failed) + } + if !strings.Contains(events[1].Body, tt.wantBody) { + t.Fatalf("failed event body %q does not contain %q", events[1].Body, tt.wantBody) + } + + if recorder, ok := tt.writer.(*httptest.ResponseRecorder); ok { + if recorder.Code != tt.wantCode { + t.Fatalf("status = %d, want %d", recorder.Code, tt.wantCode) + } + } + if w, ok := tt.writer.(*hijackFailWriter); ok { + if w.code != tt.wantCode { + t.Fatalf("status = %d, want %d", w.code, tt.wantCode) + } + } + }) + } +} + +// TODO: Add full TLS MITM loop integration test. diff --git a/internal/proxy/headers_test.go b/internal/proxy/headers_test.go new file mode 100644 index 0000000..fd7a85f --- /dev/null +++ b/internal/proxy/headers_test.go @@ -0,0 +1,157 @@ +package proxy + +import ( + "net/http" + "reflect" + "testing" +) + +func TestStripHopByHopHeaders(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + in http.Header + want http.Header + }{ + { + name: "removes static hop-by-hop headers", + in: http.Header{ + "Connection": {"keep-alive"}, + "Keep-Alive": {"timeout=5"}, + "Proxy-Connection": {"keep-alive"}, + "Transfer-Encoding": {"chunked"}, + "Upgrade": {"websocket"}, + "X-Custom": {"ok"}, + }, + want: http.Header{ + "X-Custom": {"ok"}, + }, + }, + { + name: "removes headers listed in connection with spaces and commas", + in: http.Header{ + "Connection": {" keep-alive, X-Foo ,X-Bar"}, + "Keep-Alive": {"timeout=5"}, + "X-Foo": {"foo"}, + "X-Bar": {"bar"}, + "Content-Type": {"application/json"}, + "Content-Length": {"12"}, + }, + want: http.Header{ + "Content-Type": {"application/json"}, + "Content-Length": {"12"}, + }, + }, + { + name: "keeps unrelated headers when no connection header", + in: http.Header{ + "Accept": {"*/*"}, + "Authorization": {"Bearer abc"}, + }, + want: http.Header{ + "Accept": {"*/*"}, + "Authorization": {"Bearer abc"}, + }, + }, + } + + for _, tt := range tests { + tt := tt + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + + stripHopByHopHeaders(tt.in) + + if !reflect.DeepEqual(tt.in, tt.want) { + t.Fatalf("stripHopByHopHeaders() = %#v, want %#v", tt.in, tt.want) + } + }) + } +} + +func TestStripHopByHopHeaders_NilHeader(t *testing.T) { + t.Parallel() + + stripHopByHopHeaders(nil) +} + +func TestRedactHeaders(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + input http.Header + want http.Header + }{ + { + name: "redacts sensitive canonical headers", + input: http.Header{ + "Authorization": {"Bearer token123"}, + "Cookie": {"session=abc"}, + "Proxy-Authorization": {"Basic abc"}, + "Set-Cookie": {"a=1"}, + "X-Api-Key": {"secret"}, + }, + want: http.Header{ + "Authorization": {"[REDACTED]"}, + "Cookie": {"[REDACTED]"}, + "Proxy-Authorization": {"[REDACTED]"}, + "Set-Cookie": {"[REDACTED]"}, + "X-Api-Key": {"[REDACTED]"}, + }, + }, + { + name: "leaves non-sensitive headers untouched", + input: http.Header{ + "Content-Type": {"application/json"}, + "X-Trace-ID": {"trace-1"}, + }, + want: http.Header{ + "Content-Type": {"application/json"}, + "X-Trace-ID": {"trace-1"}, + }, + }, + } + + for _, tt := range tests { + tt := tt + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + + got := redactHeaders(tt.input) + if !reflect.DeepEqual(got, tt.want) { + t.Fatalf("redactHeaders() = %#v, want %#v", got, tt.want) + } + + got.Set("Content-Type", "modified") + if reflect.DeepEqual(tt.input, got) { + t.Fatal("redactHeaders() appears to mutate input or return aliased map") + } + }) + } +} + +func TestCopyHeaders(t *testing.T) { + t.Parallel() + + src := http.Header{ + "X-Multi": {"a", "b", "c"}, + "Content-Type": {"application/json"}, + } + dest := http.Header{ + "Existing": {"keep"}, + } + + copyHeaders(src, dest) + + want := http.Header{ + "Existing": {"keep"}, + "X-Multi": {"a", "b", "c"}, + "Content-Type": {"application/json"}, + } + + if !reflect.DeepEqual(dest, want) { + t.Fatalf("copyHeaders() dest = %#v, want %#v", dest, want) + } +} diff --git a/internal/proxy/integration_http_test.go b/internal/proxy/integration_http_test.go new file mode 100644 index 0000000..6ea5a31 --- /dev/null +++ b/internal/proxy/integration_http_test.go @@ -0,0 +1,129 @@ +package proxy + +import ( + "io" + "net/http" + "net/http/httptest" + "net/url" + "testing" + "time" + + "termtap.dev/internal/model" +) + +func TestHTTPProxyE2E_WithClientProxyConfig(t *testing.T) { + configRoot := t.TempDir() + t.Setenv("XDG_CONFIG_HOME", configRoot) + + upstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + _, _ = w.Write([]byte("pong")) + })) + t.Cleanup(upstream.Close) + + ch := make(chan model.Event, 64) + ps, err := NewProxyServer("127.0.0.1:0", ch) + if err != nil { + t.Fatalf("NewProxyServer() error = %v", err) + } + + serveDone := make(chan error, 1) + go func() { + serveDone <- ps.Server.Serve(*ps.Listener) + }() + + t.Cleanup(func() { + Destroy(ps, ch) + select { + case <-serveDone: + case <-time.After(2 * time.Second): + t.Fatal("timeout waiting for proxy server to stop") + } + }) + + proxyURL, err := url.Parse(ps.Url) + if err != nil { + t.Fatalf("url.Parse(proxy) error = %v", err) + } + client := &http.Client{ + Transport: &http.Transport{Proxy: http.ProxyURL(proxyURL)}, + Timeout: 3 * time.Second, + } + + resp, err := client.Get(upstream.URL + "/ping") + if err != nil { + t.Fatalf("client.Get() error = %v", err) + } + defer resp.Body.Close() + + body, err := io.ReadAll(resp.Body) + if err != nil { + t.Fatalf("ReadAll(response) error = %v", err) + } + if resp.StatusCode != http.StatusOK { + t.Fatalf("status = %d, want %d", resp.StatusCode, http.StatusOK) + } + if got, want := string(body), "pong"; got != want { + t.Fatalf("body = %q, want %q", got, want) + } + + events := drainEvents(t, ch, 2, 2*time.Second) + if !hasEventType(events, model.EventTypeRequestStarted) { + t.Fatalf("missing %s event", model.EventTypeRequestStarted) + } + if !hasEventType(events, model.EventTypeRequestFinished) { + t.Fatalf("missing %s event", model.EventTypeRequestFinished) + } +} + +func TestHTTPProxyE2E_UpstreamFailureEmitsRequestFailed(t *testing.T) { + configRoot := t.TempDir() + t.Setenv("XDG_CONFIG_HOME", configRoot) + + ch := make(chan model.Event, 64) + ps, err := NewProxyServer("127.0.0.1:0", ch) + if err != nil { + t.Fatalf("NewProxyServer() error = %v", err) + } + + serveDone := make(chan error, 1) + go func() { + serveDone <- ps.Server.Serve(*ps.Listener) + }() + + t.Cleanup(func() { + Destroy(ps, ch) + select { + case <-serveDone: + case <-time.After(2 * time.Second): + t.Fatal("timeout waiting for proxy server to stop") + } + }) + + proxyURL, err := url.Parse(ps.Url) + if err != nil { + t.Fatalf("url.Parse(proxy) error = %v", err) + } + client := &http.Client{ + Transport: &http.Transport{Proxy: http.ProxyURL(proxyURL)}, + Timeout: 3 * time.Second, + } + + resp, reqErr := client.Get("http://127.0.0.1:1/unreachable") + if reqErr != nil { + t.Fatalf("client.Get() error = %v; proxy should reply with mapped HTTP status", reqErr) + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusBadGateway { + t.Fatalf("status = %d, want %d", resp.StatusCode, http.StatusBadGateway) + } + + events := drainEvents(t, ch, 2, 3*time.Second) + if !hasEventType(events, model.EventTypeRequestStarted) { + t.Fatalf("missing %s event", model.EventTypeRequestStarted) + } + if !hasEventType(events, model.EventTypeRequestFailed) { + t.Fatalf("missing %s event", model.EventTypeRequestFailed) + } +} diff --git a/internal/proxy/integration_https_mitm_test.go b/internal/proxy/integration_https_mitm_test.go new file mode 100644 index 0000000..6774446 --- /dev/null +++ b/internal/proxy/integration_https_mitm_test.go @@ -0,0 +1,301 @@ +package proxy + +import ( + "bufio" + "crypto/tls" + "crypto/x509" + "errors" + "io" + "net" + "net/http" + "strings" + "testing" + "time" + + "termtap.dev/internal/model" +) + +type hijackPipeWriter struct { + conn net.Conn + + header http.Header + code int +} + +func (w *hijackPipeWriter) Header() http.Header { + if w.header == nil { + w.header = make(http.Header) + } + return w.header +} + +func (w *hijackPipeWriter) WriteHeader(statusCode int) { + w.code = statusCode +} + +func (w *hijackPipeWriter) Write(p []byte) (int, error) { + return len(p), nil +} + +func (w *hijackPipeWriter) Hijack() (net.Conn, *bufio.ReadWriter, error) { + return w.conn, nil, nil +} + +func startMITMConnect(t *testing.T, transport http.RoundTripper) (net.Conn, *tls.Conn, chan model.Event, chan struct{}) { + t.Helper() + + ca := newTestCA(t) + clientConn, serverConn := net.Pipe() + + writer := &hijackPipeWriter{conn: serverConn} + req, err := http.NewRequest(http.MethodConnect, "http://example.com:443", nil) + if err != nil { + t.Fatalf("NewRequest(CONNECT) error = %v", err) + } + req.Host = "example.com:443" + + ch := make(chan model.Event, 16) + ps := &model.ProxyServer{Conns: make(map[net.Conn]struct{})} + + handleDone := make(chan struct{}) + go func() { + handleConnect(writer, req, ch, transport, ca, ps) + close(handleDone) + }() + + reader := bufio.NewReader(clientConn) + connectResp, err := http.ReadResponse(reader, &http.Request{Method: http.MethodConnect}) + if err != nil { + _ = clientConn.Close() + _ = serverConn.Close() + t.Fatalf("ReadResponse(CONNECT established) error = %v", err) + } + if connectResp.StatusCode != http.StatusOK { + _ = clientConn.Close() + _ = serverConn.Close() + t.Fatalf("CONNECT established status = %d, want %d", connectResp.StatusCode, http.StatusOK) + } + + pool := x509.NewCertPool() + pool.AddCert(ca.cert) + tlsClient := tls.Client(clientConn, &tls.Config{ + ServerName: "example.com", + RootCAs: pool, + MinVersion: tls.VersionTLS12, + }) + + if err := tlsClient.Handshake(); err != nil { + _ = tlsClient.Close() + _ = serverConn.Close() + t.Fatalf("tls handshake error = %v", err) + } + + t.Cleanup(func() { + _ = tlsClient.Close() + _ = serverConn.Close() + }) + + return clientConn, tlsClient, ch, handleDone +} + +func TestHTTPSE2E_MITMHandleConnectFlow(t *testing.T) { + transport := &mockTransport{fn: func(r *http.Request) (*http.Response, error) { + if r.URL.Scheme != "https" { + t.Fatalf("inner request scheme = %q, want https", r.URL.Scheme) + } + if r.URL.Host == "" { + t.Fatal("inner request host should be populated") + } + + return &http.Response{ + StatusCode: http.StatusCreated, + ProtoMajor: 1, + ProtoMinor: 1, + Header: http.Header{"Content-Type": {"text/plain"}}, + Body: io.NopCloser(strings.NewReader("mitm-ok")), + }, nil + }} + clientConn, tlsClient, ch, handleDone := startMITMConnect(t, transport) + + // 3) Send decrypted inner request over tunnel and read MITM response + innerReq, err := http.NewRequest(http.MethodGet, "https://example.com/inside", nil) + if err != nil { + t.Fatalf("NewRequest(inner) error = %v", err) + } + innerReq.Host = "example.com" + innerReq.Close = true + if err := innerReq.Write(tlsClient); err != nil { + t.Fatalf("innerReq.Write() error = %v", err) + } + + innerResp, err := http.ReadResponse(bufio.NewReader(tlsClient), innerReq) + if err != nil { + t.Fatalf("ReadResponse(inner) error = %v", err) + } + defer innerResp.Body.Close() + + body, err := io.ReadAll(innerResp.Body) + if err != nil { + t.Fatalf("ReadAll(inner body) error = %v", err) + } + if innerResp.StatusCode != http.StatusCreated { + t.Fatalf("inner status = %d, want %d", innerResp.StatusCode, http.StatusCreated) + } + if got, want := string(body), "mitm-ok"; got != want { + t.Fatalf("inner body = %q, want %q", got, want) + } + + _ = tlsClient.Close() + _ = clientConn.Close() + + select { + case <-handleDone: + case <-time.After(3 * time.Second): + t.Fatal("timeout waiting for handleConnect to return") + } + + events := drainEvents(t, ch, 4, 2*time.Second) + if events[0].Type != model.EventTypeRequestStarted { + t.Fatalf("event[0] = %s, want %s", events[0].Type, model.EventTypeRequestStarted) + } + if events[1].Type != model.EventTypeRequestStarted { + t.Fatalf("event[1] = %s, want %s", events[1].Type, model.EventTypeRequestStarted) + } + if events[2].Type != model.EventTypeRequestFinished { + t.Fatalf("event[2] = %s, want %s", events[2].Type, model.EventTypeRequestFinished) + } + if events[3].Type != model.EventTypeRequestFinished { + t.Fatalf("event[3] = %s, want %s", events[3].Type, model.EventTypeRequestFinished) + } +} + +func TestHTTPSE2E_MITMUpstreamErrorReturnsHTTPErrorInsideTunnel(t *testing.T) { + transport := &mockTransport{fn: func(*http.Request) (*http.Response, error) { + return nil, errors.New("upstream exploded") + }} + + clientConn, tlsClient, ch, handleDone := startMITMConnect(t, transport) + + innerReq, err := http.NewRequest(http.MethodGet, "https://example.com/fail", nil) + if err != nil { + t.Fatalf("NewRequest(inner) error = %v", err) + } + innerReq.Host = "example.com" + if err := innerReq.Write(tlsClient); err != nil { + t.Fatalf("innerReq.Write() error = %v", err) + } + + innerResp, err := http.ReadResponse(bufio.NewReader(tlsClient), innerReq) + if err != nil { + t.Fatalf("ReadResponse(inner) error = %v", err) + } + defer innerResp.Body.Close() + + if innerResp.StatusCode != http.StatusBadGateway { + t.Fatalf("inner status = %d, want %d", innerResp.StatusCode, http.StatusBadGateway) + } + + _ = tlsClient.Close() + _ = clientConn.Close() + + select { + case <-handleDone: + case <-time.After(3 * time.Second): + t.Fatal("timeout waiting for handleConnect to return") + } + + events := drainEvents(t, ch, 4, 2*time.Second) + if events[0].Type != model.EventTypeRequestStarted || + events[1].Type != model.EventTypeRequestStarted || + events[2].Type != model.EventTypeRequestFailed || + events[3].Type != model.EventTypeRequestFailed { + t.Fatalf("unexpected event sequence: %#v", events) + } +} + +func TestHTTPSE2E_MITMHandshakeFailureEmitsConnectFailed(t *testing.T) { + ca := newTestCA(t) + + clientConn, serverConn := net.Pipe() + t.Cleanup(func() { + _ = clientConn.Close() + _ = serverConn.Close() + }) + + writer := &hijackPipeWriter{conn: serverConn} + req, err := http.NewRequest(http.MethodConnect, "http://example.com:443", nil) + if err != nil { + t.Fatalf("NewRequest(CONNECT) error = %v", err) + } + req.Host = "example.com:443" + + ch := make(chan model.Event, 16) + ps := &model.ProxyServer{Conns: make(map[net.Conn]struct{})} + transport := &mockTransport{fn: func(*http.Request) (*http.Response, error) { + return nil, errors.New("should not be called") + }} + + handleDone := make(chan struct{}) + go func() { + handleConnect(writer, req, ch, transport, ca, ps) + close(handleDone) + }() + + reader := bufio.NewReader(clientConn) + connectResp, err := http.ReadResponse(reader, &http.Request{Method: http.MethodConnect}) + if err != nil { + t.Fatalf("ReadResponse(CONNECT established) error = %v", err) + } + if connectResp.StatusCode != http.StatusOK { + t.Fatalf("CONNECT established status = %d, want %d", connectResp.StatusCode, http.StatusOK) + } + + // Write plaintext (not TLS handshake) to force handshake failure. + if _, err := clientConn.Write([]byte("not tls handshake")); err != nil { + t.Fatalf("client write error = %v", err) + } + _ = clientConn.Close() + + select { + case <-handleDone: + case <-time.After(3 * time.Second): + t.Fatal("timeout waiting for handleConnect to return on handshake failure") + } + + events := drainEvents(t, ch, 2, 2*time.Second) + if events[0].Type != model.EventTypeRequestStarted || events[1].Type != model.EventTypeRequestFailed { + t.Fatalf("unexpected event sequence: %#v", events) + } + if !strings.Contains(events[1].Body, "TLS handshake with client failed") { + t.Fatalf("failed event body = %q, want handshake failure details", events[1].Body) + } +} + +func TestHTTPSE2E_MITMDecryptedReadFailureEmitsConnectFailed(t *testing.T) { + transport := &mockTransport{fn: func(*http.Request) (*http.Response, error) { + return nil, errors.New("should not be called") + }} + + clientConn, tlsClient, ch, handleDone := startMITMConnect(t, transport) + + // Send malformed decrypted payload so ReadRequest fails (non-EOF branch). + if _, err := tlsClient.Write([]byte("BAD REQUEST\r\n\r\n")); err != nil { + t.Fatalf("tlsClient.Write() error = %v", err) + } + _ = tlsClient.Close() + _ = clientConn.Close() + + select { + case <-handleDone: + case <-time.After(3 * time.Second): + t.Fatal("timeout waiting for handleConnect to return on decrypted read failure") + } + + events := drainEvents(t, ch, 2, 2*time.Second) + if events[0].Type != model.EventTypeRequestStarted || events[1].Type != model.EventTypeRequestFailed { + t.Fatalf("unexpected event sequence: %#v", events) + } + if !strings.Contains(events[1].Body, "failed to read decrypted HTTPS request") { + t.Fatalf("failed event body = %q, want read failure details", events[1].Body) + } +} diff --git a/internal/proxy/preview_test.go b/internal/proxy/preview_test.go new file mode 100644 index 0000000..2d79bea --- /dev/null +++ b/internal/proxy/preview_test.go @@ -0,0 +1,142 @@ +package proxy + +import ( + "bufio" + "io" + "net" + "strings" + "testing" +) + +func TestNewBodyPreview(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + contentType string + wantEnabled bool + }{ + {name: "text content enabled", contentType: "text/plain", wantEnabled: true}, + {name: "json content enabled", contentType: "application/json", wantEnabled: true}, + {name: "binary content disabled", contentType: "application/octet-stream", wantEnabled: false}, + {name: "empty content type disabled", contentType: "", wantEnabled: false}, + } + + for _, tt := range tests { + tt := tt + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + + p := newBodyPreview(tt.contentType) + if p.enabled != tt.wantEnabled { + t.Fatalf("newBodyPreview(%q).enabled = %v, want %v", tt.contentType, p.enabled, tt.wantEnabled) + } + }) + } +} + +func TestBodyPreviewWriteAndPreview(t *testing.T) { + t.Parallel() + + t.Run("nil receiver is safe", func(t *testing.T) { + t.Parallel() + var p *bodyPreview + p.Write([]byte("abc")) + }) + + t.Run("disabled preview ignores data", func(t *testing.T) { + t.Parallel() + p := &bodyPreview{enabled: false} + p.Write([]byte("abc")) + if got := string(p.Preview()); got != "" { + t.Fatalf("Preview() = %q, want empty", got) + } + }) + + t.Run("empty write does nothing", func(t *testing.T) { + t.Parallel() + p := &bodyPreview{enabled: true} + p.Write(nil) + if got := string(p.Preview()); got != "" { + t.Fatalf("Preview() = %q, want empty", got) + } + }) + + t.Run("escapes newlines", func(t *testing.T) { + t.Parallel() + p := &bodyPreview{enabled: true} + p.Write([]byte("a\nb")) + if got, want := string(p.Preview()), `a\nb`; got != want { + t.Fatalf("Preview() = %q, want %q", got, want) + } + }) + + t.Run("truncates at max preview bytes and appends ellipsis", func(t *testing.T) { + t.Parallel() + p := &bodyPreview{enabled: true} + p.Write([]byte(strings.Repeat("a", maxPreviewBytes))) + p.Write([]byte("b")) + + got := string(p.Preview()) + if !strings.HasSuffix(got, "...") { + t.Fatalf("Preview() must end with ellipsis when truncated: %q", got[len(got)-10:]) + } + if len(got) != maxPreviewBytes+3 { + t.Fatalf("len(Preview()) = %d, want %d", len(got), maxPreviewBytes+3) + } + }) +} + +func TestWrapBufferedConn(t *testing.T) { + t.Parallel() + + client, server := net.Pipe() + t.Cleanup(func() { + _ = client.Close() + _ = server.Close() + }) + + t.Run("returns original conn when readWriter nil", func(t *testing.T) { + t.Parallel() + got := wrapBufferedConn(client, nil) + if got != client { + t.Fatal("wrapBufferedConn should return original conn when readWriter is nil") + } + }) + + t.Run("read uses buffered readWriter", func(t *testing.T) { + t.Parallel() + rw := bufio.NewReadWriter(bufio.NewReader(strings.NewReader("xyz")), bufio.NewWriter(io.Discard)) + got := wrapBufferedConn(client, rw) + + buf := make([]byte, 3) + n, err := got.Read(buf) + if err != nil { + t.Fatalf("Read() error = %v", err) + } + if n != 3 || string(buf) != "xyz" { + t.Fatalf("Read() = (%d, %q), want (3, %q)", n, string(buf), "xyz") + } + }) +} + +func TestPreviewReadCloserRead(t *testing.T) { + t.Parallel() + + preview := newBodyPreview("text/plain") + rc := &previewReadCloser{ + ReadCloser: io.NopCloser(strings.NewReader("hello\nworld")), + preview: preview, + } + + data, err := io.ReadAll(rc) + if err != nil { + t.Fatalf("ReadAll() error = %v", err) + } + if got, want := string(data), "hello\nworld"; got != want { + t.Fatalf("read content = %q, want %q", got, want) + } + if got, want := string(preview.Preview()), `hello\nworld`; got != want { + t.Fatalf("preview = %q, want %q", got, want) + } +} diff --git a/internal/proxy/requests_test.go b/internal/proxy/requests_test.go new file mode 100644 index 0000000..8b7c21b --- /dev/null +++ b/internal/proxy/requests_test.go @@ -0,0 +1,275 @@ +package proxy + +import ( + "errors" + "io" + "net/http" + "net/url" + "strings" + "testing" + "time" + + "github.com/google/uuid" + "termtap.dev/internal/model" +) + +type mockTransport struct { + fn func(*http.Request) (*http.Response, error) +} + +func (m *mockTransport) RoundTrip(req *http.Request) (*http.Response, error) { + return m.fn(req) +} + +func TestRoundTripCapturedRequest_Success(t *testing.T) { + t.Parallel() + + ch := make(chan model.Event, 8) + + reqURL, err := url.Parse("http://example.com/path?q=1") + if err != nil { + t.Fatalf("url.Parse() error = %v", err) + } + req := &http.Request{ + Method: http.MethodPost, + URL: reqURL, + Host: "example.com", + Header: http.Header{ + "Connection": {"X-Hop"}, + "X-Hop": {"drop"}, + "Authorization": {"Bearer token"}, + "Content-Type": {"text/plain"}, + }, + Body: io.NopCloser(strings.NewReader("req\nbody")), + } + + transport := &mockTransport{fn: func(outReq *http.Request) (*http.Response, error) { + if got := outReq.Header.Get("Connection"); got != "" { + t.Fatalf("Connection header should be stripped, got %q", got) + } + if got := outReq.Header.Get("X-Hop"); got != "" { + t.Fatalf("header listed in Connection should be stripped, got %q", got) + } + + data, readErr := io.ReadAll(outReq.Body) + if readErr != nil { + t.Fatalf("ReadAll(outReq.Body) error = %v", readErr) + } + if got, want := string(data), "req\nbody"; got != want { + t.Fatalf("request body = %q, want %q", got, want) + } + + return &http.Response{ + StatusCode: http.StatusCreated, + Header: http.Header{ + "Set-Cookie": {"session=top-secret"}, + "Connection": {"close"}, + "Content-Type": {"text/plain"}, + }, + Body: io.NopCloser(strings.NewReader("resp\nbody")), + }, nil + }} + + resp, captured, responsePreview, err := roundTripCapturedRequest(req, transport, ch, "", false) + if err != nil { + t.Fatalf("roundTripCapturedRequest() error = %v", err) + } + if resp == nil { + t.Fatal("roundTripCapturedRequest() returned nil response") + } + if responsePreview == nil { + t.Fatal("roundTripCapturedRequest() returned nil response preview") + } + + if _, readErr := io.ReadAll(resp.Body); readErr != nil { + t.Fatalf("ReadAll(resp.Body) error = %v", readErr) + } + _ = resp.Body.Close() + + if got, want := string(captured.RequestData), `req\nbody`; got != want { + t.Fatalf("captured.RequestData = %q, want %q", got, want) + } + if got, want := string(responsePreview.Preview()), `resp\nbody`; got != want { + t.Fatalf("responsePreview = %q, want %q", got, want) + } + if got := captured.RequestHeaders.Get("Authorization"); got != "[REDACTED]" { + t.Fatalf("Authorization header = %q, want [REDACTED]", got) + } + if got := captured.ResponseHeaders.Get("Set-Cookie"); got != "[REDACTED]" { + t.Fatalf("Set-Cookie header = %q, want [REDACTED]", got) + } + if got := captured.ResponseHeaders.Get("Connection"); got != "" { + t.Fatalf("Connection should be stripped from response headers, got %q", got) + } + + events := drainEvents(t, ch, 1, time.Second) + if events[0].Type != model.EventTypeRequestStarted { + t.Fatalf("event type = %s, want %s", events[0].Type, model.EventTypeRequestStarted) + } +} + +func TestRoundTripCapturedRequest_ErrorPath(t *testing.T) { + t.Parallel() + + ch := make(chan model.Event, 8) + + reqURL, err := url.Parse("http://example.com/fail") + if err != nil { + t.Fatalf("url.Parse() error = %v", err) + } + req := &http.Request{ + Method: http.MethodPost, + URL: reqURL, + Host: "example.com", + Header: http.Header{"Content-Type": {"text/plain"}}, + Body: io.NopCloser(strings.NewReader("boom\nbody")), + } + + wantErr := errors.New("upstream failed") + transport := &mockTransport{fn: func(outReq *http.Request) (*http.Response, error) { + _, _ = io.ReadAll(outReq.Body) + return nil, wantErr + }} + + resp, captured, responsePreview, gotErr := roundTripCapturedRequest(req, transport, ch, "", false) + if !errors.Is(gotErr, wantErr) { + t.Fatalf("error = %v, want %v", gotErr, wantErr) + } + if resp != nil { + t.Fatalf("response = %#v, want nil", resp) + } + if responsePreview != nil { + t.Fatalf("responsePreview = %#v, want nil", responsePreview) + } + if got, want := string(captured.RequestData), `boom\nbody`; got != want { + t.Fatalf("captured.RequestData = %q, want %q", got, want) + } + + events := drainEvents(t, ch, 1, time.Second) + if events[0].Type != model.EventTypeRequestStarted { + t.Fatalf("event type = %s, want %s", events[0].Type, model.EventTypeRequestStarted) + } +} + +func TestRoundTripCapturedRequest_InterceptedTLSDefaults(t *testing.T) { + t.Parallel() + + ch := make(chan model.Event, 8) + + u, err := url.Parse("/secure?p=1") + if err != nil { + t.Fatalf("url.Parse() error = %v", err) + } + req := &http.Request{ + Method: http.MethodGet, + URL: u, + Header: http.Header{}, + } + + const defaultHost = "api.example.com:443" + transport := &mockTransport{fn: func(outReq *http.Request) (*http.Response, error) { + if got := outReq.URL.Scheme; got != "https" { + t.Fatalf("URL.Scheme = %q, want https", got) + } + if got := outReq.URL.Host; got != defaultHost { + t.Fatalf("URL.Host = %q, want %q", got, defaultHost) + } + if got := outReq.Host; got != defaultHost { + t.Fatalf("Host = %q, want %q", got, defaultHost) + } + + return &http.Response{ + StatusCode: http.StatusNoContent, + Header: http.Header{"Content-Type": {"text/plain"}}, + Body: io.NopCloser(strings.NewReader("")), + }, nil + }} + + _, _, _, gotErr := roundTripCapturedRequest(req, transport, ch, defaultHost, true) + if gotErr != nil { + t.Fatalf("roundTripCapturedRequest() error = %v", gotErr) + } +} + +func TestNewConnectRequest(t *testing.T) { + t.Parallel() + + now := time.Now().Add(-time.Second) + req := &http.Request{Method: http.MethodConnect, Host: "example.com:443"} + + got := newConnectRequest(req, now) + + if got.Method != http.MethodConnect { + t.Fatalf("Method = %q, want CONNECT", got.Method) + } + if got.Host != req.Host || got.URL != req.Host || got.RawURL != req.Host { + t.Fatalf("connect request host/url/raw mismatch: %#v", got) + } + if !got.Pending || got.Failed { + t.Fatalf("Pending/Failed = (%v,%v), want (true,false)", got.Pending, got.Failed) + } + if got.Status != -1 { + t.Fatalf("Status = %d, want -1", got.Status) + } + if got.StartTime != now { + t.Fatalf("StartTime = %v, want %v", got.StartTime, now) + } + if got.ID == uuid.Nil { + t.Fatal("ID must be non-zero UUID") + } +} + +func TestStartFinishFailRequestEvents(t *testing.T) { + t.Parallel() + + ch := make(chan model.Event, 4) + req := model.Request{ + ID: uuid.New(), + Method: http.MethodGet, + RawURL: "http://example.com/a", + StartTime: time.Now().Add(-3 * time.Millisecond), + Pending: true, + } + + startRequest(ch, req) + events := drainEvents(t, ch, 1, time.Second) + startEv := events[0] + if startEv.Type != model.EventTypeRequestStarted { + t.Fatalf("start event type = %s, want %s", startEv.Type, model.EventTypeRequestStarted) + } + if startEv.Request.Pending != true { + t.Fatalf("start request pending = %v, want true", startEv.Request.Pending) + } + + finishRequest(ch, req, http.StatusOK) + events = drainEvents(t, ch, 1, time.Second) + finishEv := events[0] + if finishEv.Type != model.EventTypeRequestFinished { + t.Fatalf("finish event type = %s, want %s", finishEv.Type, model.EventTypeRequestFinished) + } + if finishEv.Request.Pending { + t.Fatal("finished request should not be pending") + } + if finishEv.Request.Failed { + t.Fatal("finished request should not be failed") + } + if finishEv.Request.Status != http.StatusOK { + t.Fatalf("finished status = %d, want %d", finishEv.Request.Status, http.StatusOK) + } + + failRequest(ch, req, http.StatusBadGateway, "upstream error") + events = drainEvents(t, ch, 1, time.Second) + failEv := events[0] + if failEv.Type != model.EventTypeRequestFailed { + t.Fatalf("fail event type = %s, want %s", failEv.Type, model.EventTypeRequestFailed) + } + if failEv.Request.Pending { + t.Fatal("failed request should not be pending") + } + if !failEv.Request.Failed { + t.Fatal("failed request should be marked failed") + } + if failEv.Request.Status != http.StatusBadGateway { + t.Fatalf("failed status = %d, want %d", failEv.Request.Status, http.StatusBadGateway) + } +} diff --git a/internal/proxy/secure_utils_test.go b/internal/proxy/secure_utils_test.go new file mode 100644 index 0000000..87d19e4 --- /dev/null +++ b/internal/proxy/secure_utils_test.go @@ -0,0 +1,138 @@ +package proxy + +import ( + "bufio" + "bytes" + "io" + "net" + "strings" + "testing" + "time" +) + +type captureConn struct { + bytes.Buffer +} + +func (c *captureConn) Read(_ []byte) (int, error) { return 0, io.EOF } +func (c *captureConn) Close() error { return nil } +func (c *captureConn) LocalAddr() net.Addr { return &net.TCPAddr{} } +func (c *captureConn) RemoteAddr() net.Addr { return &net.TCPAddr{} } +func (c *captureConn) SetDeadline(_ time.Time) error { return nil } +func (c *captureConn) SetReadDeadline(_ time.Time) error { return nil } +func (c *captureConn) SetWriteDeadline(_ time.Time) error { return nil } + +type failWriteConn struct{} + +func (failWriteConn) Read(_ []byte) (int, error) { return 0, io.EOF } +func (failWriteConn) Write(_ []byte) (int, error) { return 0, io.ErrClosedPipe } +func (failWriteConn) Close() error { return nil } +func (failWriteConn) LocalAddr() net.Addr { return &net.TCPAddr{} } +func (failWriteConn) RemoteAddr() net.Addr { return &net.TCPAddr{} } +func (failWriteConn) SetDeadline(_ time.Time) error { return nil } +func (failWriteConn) SetReadDeadline(_ time.Time) error { return nil } +func (failWriteConn) SetWriteDeadline(_ time.Time) error { return nil } + +type trackingBody struct { + data *bytes.Reader + readN int + closed bool +} + +func (b *trackingBody) Read(p []byte) (int, error) { + n, err := b.data.Read(p) + b.readN += n + return n, err +} + +func (b *trackingBody) Close() error { + b.closed = true + return nil +} + +func TestWriteConnectEstablished(t *testing.T) { + t.Parallel() + + t.Run("writes directly to raw conn", func(t *testing.T) { + t.Parallel() + + conn := &captureConn{} + if err := writeConnectEstablished(conn, nil); err != nil { + t.Fatalf("writeConnectEstablished() error = %v", err) + } + + if got, want := conn.String(), "HTTP/1.1 200 Connection Established\r\n\r\n"; got != want { + t.Fatalf("raw write = %q, want %q", got, want) + } + }) + + t.Run("writes and flushes with buffered readWriter", func(t *testing.T) { + t.Parallel() + + conn := &captureConn{} + rw := bufio.NewReadWriter(bufio.NewReader(strings.NewReader("")), bufio.NewWriter(conn)) + if err := writeConnectEstablished(conn, rw); err != nil { + t.Fatalf("writeConnectEstablished() error = %v", err) + } + + if got, want := conn.String(), "HTTP/1.1 200 Connection Established\r\n\r\n"; got != want { + t.Fatalf("buffered write = %q, want %q", got, want) + } + }) + + t.Run("returns flush error", func(t *testing.T) { + t.Parallel() + + rw := bufio.NewReadWriter(bufio.NewReader(strings.NewReader("")), bufio.NewWriter(errWriter{})) + err := writeConnectEstablished(&captureConn{}, rw) + if err == nil { + t.Fatal("writeConnectEstablished() error = nil, want non-nil") + } + }) + + t.Run("returns buffered write error when writer already failed", func(t *testing.T) { + t.Parallel() + + bw := bufio.NewWriter(errWriter{}) + _ = bw.Flush() // set sticky error to force WriteString error path + rw := bufio.NewReadWriter(bufio.NewReader(strings.NewReader("")), bw) + + err := writeConnectEstablished(&captureConn{}, rw) + if err == nil { + t.Fatal("writeConnectEstablished() error = nil, want non-nil") + } + }) + + t.Run("returns raw conn write error", func(t *testing.T) { + t.Parallel() + err := writeConnectEstablished(failWriteConn{}, nil) + if err == nil { + t.Fatal("writeConnectEstablished() error = nil, want non-nil") + } + }) +} + +func TestDiscardAndCloseBody(t *testing.T) { + t.Parallel() + + t.Run("nil body is safe", func(t *testing.T) { + t.Parallel() + discardAndCloseBody(nil) + }) + + t.Run("closes body and discards at most limit", func(t *testing.T) { + t.Parallel() + + payload := bytes.Repeat([]byte("x"), maxDiscardBodyBytes+128) + body := &trackingBody{data: bytes.NewReader(payload)} + + discardAndCloseBody(body) + + if !body.closed { + t.Fatal("body was not closed") + } + if body.readN != maxDiscardBodyBytes { + t.Fatalf("bytes read = %d, want %d", body.readN, maxDiscardBodyBytes) + } + }) +} diff --git a/internal/proxy/server_test.go b/internal/proxy/server_test.go new file mode 100644 index 0000000..710e79c --- /dev/null +++ b/internal/proxy/server_test.go @@ -0,0 +1,235 @@ +package proxy + +import ( + "context" + "net" + "net/http" + "os" + "path/filepath" + "strings" + "sync" + "testing" + "time" + + "termtap.dev/internal/model" +) + +// NOTE: Run these tests with -race; they cover concurrent state. + +type closeTrackingConn struct { + closed bool + mu sync.Mutex +} + +func (c *closeTrackingConn) Read(_ []byte) (int, error) { return 0, net.ErrClosed } +func (c *closeTrackingConn) Write(b []byte) (int, error) { return len(b), nil } +func (c *closeTrackingConn) Close() error { c.mu.Lock(); c.closed = true; c.mu.Unlock(); return nil } +func (c *closeTrackingConn) LocalAddr() net.Addr { return &net.TCPAddr{} } +func (c *closeTrackingConn) RemoteAddr() net.Addr { return &net.TCPAddr{} } +func (c *closeTrackingConn) SetDeadline(_ time.Time) error { return nil } +func (c *closeTrackingConn) SetReadDeadline(_ time.Time) error { return nil } +func (c *closeTrackingConn) SetWriteDeadline(_ time.Time) error { return nil } + +func (c *closeTrackingConn) isClosed() bool { + c.mu.Lock() + defer c.mu.Unlock() + return c.closed +} + +func TestNewProxyServer_Success(t *testing.T) { + configRoot := t.TempDir() + t.Setenv("XDG_CONFIG_HOME", configRoot) + + ch := make(chan model.Event, 16) + ps, err := NewProxyServer("127.0.0.1:0", ch) + if err != nil { + t.Fatalf("NewProxyServer() error = %v", err) + } + t.Cleanup(func() { Destroy(ps, ch) }) + + if ps.Listener == nil || *ps.Listener == nil { + t.Fatal("Listener is nil") + } + if ps.Server == nil { + t.Fatal("Server is nil") + } + if ps.Server.Handler == nil { + t.Fatal("Server.Handler is nil") + } + if !strings.HasPrefix(ps.Url, "http://") { + t.Fatalf("Url = %q, want http:// prefix", ps.Url) + } + if !ps.CAReady { + t.Fatal("CAReady = false, want true") + } + if ps.CACertPath == "" { + t.Fatal("CACertPath should not be empty") + } + if got, want := filepath.Dir(ps.CACertPath), filepath.Join(configRoot, caDirName); got != want { + t.Fatalf("CACertPath dir = %q, want %q", got, want) + } + if _, statErr := os.Stat(ps.CACertPath); statErr != nil { + t.Fatalf("expected CA cert on disk, stat error = %v", statErr) + } + if ps.Conns == nil { + t.Fatal("Conns map is nil") + } +} + +func TestNewProxyServer_CACreatedFlagAcrossRuns(t *testing.T) { + configRoot := t.TempDir() + t.Setenv("XDG_CONFIG_HOME", configRoot) + + ch := make(chan model.Event, 16) + ps1, err := NewProxyServer("127.0.0.1:0", ch) + if err != nil { + t.Fatalf("first NewProxyServer() error = %v", err) + } + Destroy(ps1, ch) + + ps2, err := NewProxyServer("127.0.0.1:0", ch) + if err != nil { + t.Fatalf("second NewProxyServer() error = %v", err) + } + t.Cleanup(func() { Destroy(ps2, ch) }) + + if !ps1.CACreated { + t.Fatal("first NewProxyServer should report CACreated=true") + } + if ps2.CACreated { + t.Fatal("second NewProxyServer should report CACreated=false when CA already exists") + } +} + +func TestNewProxyServer_ErrorWhenCASetupFails(t *testing.T) { + t.Setenv("XDG_CONFIG_HOME", "") + t.Setenv("HOME", "") + + ch := make(chan model.Event, 8) + ps, err := NewProxyServer("127.0.0.1:0", ch) + if err == nil { + Destroy(ps, ch) + t.Fatal("NewProxyServer() error = nil, want non-nil") + } + if ps != nil { + t.Fatalf("proxy server = %#v, want nil", ps) + } +} + +func TestNewProxyServer_ErrorWhenListenFails(t *testing.T) { + configRoot := t.TempDir() + t.Setenv("XDG_CONFIG_HOME", configRoot) + + occupied, err := net.Listen("tcp", "127.0.0.1:0") + if err != nil { + t.Fatalf("net.Listen() error = %v", err) + } + t.Cleanup(func() { _ = occupied.Close() }) + + addr := occupied.Addr().String() + ch := make(chan model.Event, 8) + ps, gotErr := NewProxyServer(addr, ch) + if gotErr == nil { + Destroy(ps, ch) + t.Fatal("NewProxyServer() error = nil, want non-nil") + } + if ps != nil { + t.Fatalf("proxy server = %#v, want nil", ps) + } +} + +func TestDestroy(t *testing.T) { + t.Parallel() + + t.Run("nil-safe", func(t *testing.T) { + t.Parallel() + Destroy(nil, make(chan model.Event, 1)) + }) + + t.Run("emits ProxyStopped when server exists", func(t *testing.T) { + t.Parallel() + + ch := make(chan model.Event, 2) + ps := &model.ProxyServer{Server: &http.Server{}, Conns: make(map[net.Conn]struct{})} + + Destroy(ps, ch) + + select { + case ev := <-ch: + if ev.Type != model.EventTypeProxyStopped { + t.Fatalf("event type = %s, want %s", ev.Type, model.EventTypeProxyStopped) + } + case <-time.After(time.Second): + t.Fatal("timeout waiting for ProxyStopped event") + } + }) +} + +func TestConnectionTrackingHelpers(t *testing.T) { + t.Parallel() + + t.Run("track and untrack are nil-safe", func(t *testing.T) { + t.Parallel() + trackConnection(nil, nil) + untrackConnection(nil, nil) + closeTrackedConnections(nil) + }) + + t.Run("track/untrack mutate connection set", func(t *testing.T) { + t.Parallel() + + ps := &model.ProxyServer{Conns: make(map[net.Conn]struct{})} + c := &closeTrackingConn{} + + trackConnection(ps, c) + if _, ok := ps.Conns[c]; !ok { + t.Fatal("connection not tracked") + } + + untrackConnection(ps, c) + if _, ok := ps.Conns[c]; ok { + t.Fatal("connection still tracked after untrack") + } + }) + + t.Run("closeTrackedConnections closes all tracked conns", func(t *testing.T) { + t.Parallel() + + ps := &model.ProxyServer{Conns: make(map[net.Conn]struct{})} + c1 := &closeTrackingConn{} + c2 := &closeTrackingConn{} + ps.Conns[c1] = struct{}{} + ps.Conns[c2] = struct{}{} + + closeTrackedConnections(ps) + + if !c1.isClosed() || !c2.isClosed() { + t.Fatalf("expected all tracked conns closed, got c1=%v c2=%v", c1.isClosed(), c2.isClosed()) + } + }) +} + +func TestDestroy_ShutdownsServerContext(t *testing.T) { + t.Parallel() + + // Basic smoke test: ensure a real server can be shut down via Destroy. + ln, err := net.Listen("tcp", "127.0.0.1:0") + if err != nil { + t.Fatalf("listen error = %v", err) + } + + srv := &http.Server{Handler: http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { _, _ = w.Write([]byte("ok")) })} + go func() { + _ = srv.Serve(ln) + }() + + ch := make(chan model.Event, 2) + ps := &model.ProxyServer{Server: srv, Conns: make(map[net.Conn]struct{})} + Destroy(ps, ch) + + ctx, cancel := context.WithTimeout(context.Background(), time.Second) + defer cancel() + if err := srv.Shutdown(ctx); err != nil && err != http.ErrServerClosed { + t.Fatalf("server should be closed after Destroy, got shutdown error %v", err) + } +} diff --git a/internal/proxy/test_helpers_test.go b/internal/proxy/test_helpers_test.go new file mode 100644 index 0000000..ae67864 --- /dev/null +++ b/internal/proxy/test_helpers_test.go @@ -0,0 +1,77 @@ +package proxy + +import ( + "crypto/ecdsa" + "crypto/elliptic" + "crypto/rand" + "crypto/tls" + "crypto/x509" + "crypto/x509/pkix" + "math/big" + "testing" + "time" + + "termtap.dev/internal/model" +) + +func drainEvents(t *testing.T, ch <-chan model.Event, n int, timeout time.Duration) []model.Event { + t.Helper() + + events := make([]model.Event, 0, n) + deadline := time.After(timeout) + for len(events) < n { + select { + case ev := <-ch: + events = append(events, ev) + case <-deadline: + t.Fatalf("timeout waiting for %d events, got %d", n, len(events)) + } + } + + return events +} + +func hasEventType(events []model.Event, typ model.EventType) bool { + for _, ev := range events { + if ev.Type == typ { + return true + } + } + return false +} + +func newTestCA(t *testing.T) *CertificateAuthority { + t.Helper() + + key, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader) + if err != nil { + t.Fatalf("GenerateKey() error = %v", err) + } + + now := time.Now() + tmpl := &x509.Certificate{ + SerialNumber: big.NewInt(1), + Subject: pkix.Name{CommonName: "test-ca"}, + NotBefore: now.Add(-time.Minute), + NotAfter: now.Add(time.Hour), + KeyUsage: x509.KeyUsageCertSign | x509.KeyUsageDigitalSignature, + IsCA: true, + BasicConstraintsValid: true, + } + + der, err := x509.CreateCertificate(rand.Reader, tmpl, tmpl, &key.PublicKey, key) + if err != nil { + t.Fatalf("CreateCertificate() error = %v", err) + } + + cert, err := x509.ParseCertificate(der) + if err != nil { + t.Fatalf("ParseCertificate() error = %v", err) + } + + return &CertificateAuthority{ + cert: cert, + key: key, + leafCert: make(map[string]*tls.Certificate), + } +} diff --git a/internal/proxy/utils_test.go b/internal/proxy/utils_test.go new file mode 100644 index 0000000..88d2959 --- /dev/null +++ b/internal/proxy/utils_test.go @@ -0,0 +1,256 @@ +package proxy + +import ( + "bufio" + "context" + "errors" + "io" + "net/http" + "net/url" + "strconv" + "strings" + "testing" + + "github.com/google/uuid" +) + +type timeoutErr struct{} + +func (timeoutErr) Error() string { return "timeout" } +func (timeoutErr) Timeout() bool { return true } +func (timeoutErr) Temporary() bool { return true } + +type errWriter struct{} + +func (errWriter) Write(_ []byte) (int, error) { + return 0, errors.New("write failed") +} + +func TestCanDisplayContent(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + contentType string + want bool + }{ + {name: "empty content type", contentType: "", want: false}, + {name: "text type", contentType: "text/plain", want: true}, + {name: "json type", contentType: "application/json", want: true}, + {name: "xml suffix", contentType: "application/problem+xml", want: true}, + {name: "unknown binary", contentType: "application/octet-stream", want: false}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + if got := canDisplayContent(tt.contentType); got != tt.want { + t.Fatalf("canDisplayContent(%q) = %v, want %v", tt.contentType, got, tt.want) + } + }) + } +} + +func TestFormatHeaders(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + headers http.Header + want string + }{ + { + name: "empty returns none", + headers: http.Header{}, + want: "", + }, + { + name: "sorts keys stably", + headers: http.Header{ + "B-Key": {"b1"}, + "A-Key": {"a1", "a2"}, + }, + want: `A-Key="a1,a2", B-Key="b1"`, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + if got := formatHeaders(tt.headers); got != tt.want { + t.Fatalf("formatHeaders() = %q, want %q", got, tt.want) + } + }) + } +} + +func TestGetEndOfUUID(t *testing.T) { + t.Parallel() + + id := uuid.MustParse("123e4567-e89b-12d3-a456-426614174000") + if got, want := getEndOfUUID(id), "426614174000"; got != want { + t.Fatalf("getEndOfUUID() = %q, want %q", got, want) + } +} + +func TestStatusFromUpstreamError(t *testing.T) { + t.Parallel() + + newReq := func(ctx context.Context) *http.Request { + reqURL, err := url.Parse("http://example.com") + if err != nil { + t.Fatalf("url parse failed: %v", err) + } + return (&http.Request{Method: http.MethodGet, URL: reqURL}).WithContext(ctx) + } + + tests := []struct { + name string + req *http.Request + resp *http.Response + err error + want int + }{ + { + name: "prefers upstream response status", + req: newReq(context.Background()), + resp: &http.Response{StatusCode: http.StatusTeapot}, + err: errors.New("ignored"), + want: http.StatusTeapot, + }, + { + name: "context canceled maps to bad gateway", + req: func() *http.Request { + ctx, cancel := context.WithCancel(context.Background()) + cancel() + return newReq(ctx) + }(), + err: context.Canceled, + want: http.StatusBadGateway, + }, + { + name: "deadline exceeded maps to gateway timeout", + req: newReq(context.Background()), + err: context.DeadlineExceeded, + want: http.StatusGatewayTimeout, + }, + { + name: "net timeout maps to gateway timeout", + req: newReq(context.Background()), + err: timeoutErr{}, + want: http.StatusGatewayTimeout, + }, + { + name: "default maps to bad gateway", + req: newReq(context.Background()), + err: errors.New("dial failed"), + want: http.StatusBadGateway, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + if got := statusFromUpstreamError(tt.req, tt.resp, tt.err); got != tt.want { + t.Fatalf("statusFromUpstreamError() = %d, want %d", got, tt.want) + } + }) + } +} + +func TestNewUpstreamTransport(t *testing.T) { + t.Parallel() + + got := newUpstreamTransport() + transport, ok := got.(*http.Transport) + if !ok { + t.Fatalf("newUpstreamTransport() type = %T, want *http.Transport", got) + } + + if transport.Proxy != nil { + t.Fatal("newUpstreamTransport() Proxy must be nil") + } +} + +func TestWritePlainHTTPError(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + status int + writer *bufio.Writer + wantErr bool + wantStatus int + }{ + { + name: "writes valid response and flushes", + status: http.StatusBadGateway, + writer: bufio.NewWriter(&strings.Builder{}), + wantErr: false, + wantStatus: http.StatusBadGateway, + }, + { + name: "returns write error", + status: http.StatusBadGateway, + writer: bufio.NewWriter(errWriter{}), + wantErr: true, + }, + { + name: "returns response write error when writer already failed", + status: http.StatusBadGateway, + writer: bufio.NewWriter(errWriter{}), + wantErr: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + + var sb strings.Builder + w := tt.writer + if !tt.wantErr { + w = bufio.NewWriter(&sb) + } + + err := writePlainHTTPError(w, tt.status) + if tt.name == "returns response write error when writer already failed" { + _ = w.Flush() // set sticky writer error so resp.Write fails immediately + err = writePlainHTTPError(w, tt.status) + } + if (err != nil) != tt.wantErr { + t.Fatalf("writePlainHTTPError() error = %v, wantErr %v", err, tt.wantErr) + } + if tt.wantErr { + return + } + + resp, readErr := http.ReadResponse(bufio.NewReader(strings.NewReader(sb.String())), &http.Request{Method: http.MethodGet}) + if readErr != nil { + t.Fatalf("ReadResponse() error = %v", readErr) + } + defer resp.Body.Close() + + if resp.StatusCode != tt.wantStatus { + t.Fatalf("status = %d, want %d", resp.StatusCode, tt.wantStatus) + } + + if gotCT := resp.Header.Get("Content-Type"); gotCT != "text/plain; charset=utf-8" { + t.Fatalf("Content-Type = %q, want %q", gotCT, "text/plain; charset=utf-8") + } + + wantBody := http.StatusText(tt.status) + if gotCL := resp.Header.Get("Content-Length"); gotCL != strconv.Itoa(len(wantBody)) { + t.Fatalf("Content-Length = %q, want %q", gotCL, strconv.Itoa(len(wantBody))) + } + + body, bodyErr := io.ReadAll(resp.Body) + if bodyErr != nil { + t.Fatalf("ReadAll(body) error = %v", bodyErr) + } + if string(body) != wantBody { + t.Fatalf("body = %q, want %q", string(body), wantBody) + } + }) + } +} diff --git a/internal/tui/model_update_test.go b/internal/tui/model_update_test.go new file mode 100644 index 0000000..75d52cc --- /dev/null +++ b/internal/tui/model_update_test.go @@ -0,0 +1,361 @@ +package tui + +import ( + "errors" + "net/http" + "testing" + "time" + + tea "github.com/charmbracelet/bubbletea" + "github.com/google/uuid" + "termtap.dev/internal/model" +) + +func TestNewModelDefaults(t *testing.T) { + t.Parallel() + + ch := make(chan model.Event) + m := NewModel(ch, Controls{}) + + if m.channel != ch { + t.Fatal("channel not set") + } + if len(m.events) != 0 || len(m.requests) != 0 { + t.Fatal("events/requests should initialize empty") + } + if m.width != 0 || m.height != 0 { + t.Fatal("width/height should initialize zero") + } + if m.showEvents || m.showStd || m.showSearch || m.restarting { + t.Fatal("toggle flags should initialize false") + } +} + +func TestInitBatchesEventAndTick(t *testing.T) { + t.Parallel() + + ch := make(chan model.Event) + m := NewModel(ch, Controls{}) + + cmd := m.Init() + if cmd == nil { + t.Fatal("Init() returned nil cmd") + } + if _, ok := cmd().(tea.BatchMsg); !ok { + t.Fatalf("Init cmd message type = %T, want tea.BatchMsg", cmd()) + } +} + +func TestWaitForEvent(t *testing.T) { + t.Parallel() + + t.Run("returns EventMsg when channel has value", func(t *testing.T) { + t.Parallel() + + ch := make(chan model.Event, 1) + ch <- model.Event{Type: model.EventTypeWarn, Body: "hello"} + + msg := waitForEvent(ch)() + ev, ok := msg.(EventMsg) + if !ok { + t.Fatalf("msg type = %T, want EventMsg", msg) + } + if ev.value.Body != "hello" { + t.Fatalf("event body = %q, want %q", ev.value.Body, "hello") + } + }) + + t.Run("returns ErrMsg when channel closed", func(t *testing.T) { + t.Parallel() + + ch := make(chan model.Event) + close(ch) + + msg := waitForEvent(ch)() + if _, ok := msg.(ErrMsg); !ok { + t.Fatalf("msg type = %T, want ErrMsg", msg) + } + }) +} + +func TestMessagesCommands(t *testing.T) { + t.Parallel() + + t.Run("restartCmd nil restart returns nil", func(t *testing.T) { + t.Parallel() + if cmd := restartCmd(nil); cmd != nil { + t.Fatal("restartCmd(nil) should return nil") + } + }) + + t.Run("restartCmd wraps restart result", func(t *testing.T) { + t.Parallel() + wantErr := errors.New("boom") + msg := restartCmd(func() error { return wantErr })() + rm, ok := msg.(RestartResultMsg) + if !ok { + t.Fatalf("msg type = %T, want RestartResultMsg", msg) + } + if !errors.Is(rm.err, wantErr) { + t.Fatalf("restart result error = %v, want %v", rm.err, wantErr) + } + }) + + t.Run("tickCmd emits TickMsg", func(t *testing.T) { + t.Parallel() + cmd := tickCmd() + if cmd == nil { + t.Fatal("tickCmd returned nil") + } + + msgCh := make(chan tea.Msg, 1) + go func() { msgCh <- cmd() }() + + select { + case msg := <-msgCh: + if _, ok := msg.(TickMsg); !ok { + t.Fatalf("msg type = %T, want TickMsg", msg) + } + case <-time.After(200 * time.Millisecond): + t.Fatal("timeout waiting for tick message") + } + }) +} + +func TestUpdate(t *testing.T) { + t.Parallel() + + t.Run("window size updates dimensions", func(t *testing.T) { + t.Parallel() + m := NewModel(make(chan model.Event), Controls{}) + next, _ := m.Update(tea.WindowSizeMsg{Width: 120, Height: 40}) + got := next.(Model) + if got.width != 120 || got.height != 40 { + t.Fatalf("dimensions = (%d,%d), want (120,40)", got.width, got.height) + } + }) + + t.Run("tick updates now and reschedules only with pending requests", func(t *testing.T) { + t.Parallel() + now := time.Now() + + m1 := NewModel(make(chan model.Event), Controls{}) + next1, cmd1 := m1.Update(TickMsg{Now: now}) + got1 := next1.(Model) + if !got1.now.Equal(now) { + t.Fatal("tick should update now") + } + if cmd1 != nil { + t.Fatal("tick without pending requests should not reschedule") + } + + m2 := NewModel(make(chan model.Event), Controls{}) + m2.requests = append(m2.requests, model.Request{ID: uuid.New(), Pending: true}) + _, cmd2 := m2.Update(TickMsg{Now: now}) + if cmd2 == nil { + t.Fatal("tick with pending requests should reschedule") + } + }) + + t.Run("key handling toggles and quit", func(t *testing.T) { + t.Parallel() + + m := NewModel(make(chan model.Event), Controls{}) + next, quitCmd := m.Update(tea.KeyMsg{Type: tea.KeyRunes, Runes: []rune("q")}) + _ = next + if quitCmd == nil { + t.Fatal("q should return quit cmd") + } + + nextCtrlC, quitCtrlC := m.Update(tea.KeyMsg{Type: tea.KeyCtrlC}) + _ = nextCtrlC + if quitCtrlC == nil { + t.Fatal("ctrl+c should return quit cmd") + } + + next2, _ := m.Update(tea.KeyMsg{Type: tea.KeyRunes, Runes: []rune("e")}) + if !next2.(Model).showEvents { + t.Fatal("e should toggle showEvents") + } + + next3, _ := next2.(Model).Update(tea.KeyMsg{Type: tea.KeyRunes, Runes: []rune("o")}) + if !next3.(Model).showStd { + t.Fatal("o should toggle showStd") + } + + next4, _ := next3.(Model).Update(tea.KeyMsg{Type: tea.KeyRunes, Runes: []rune("/")}) + if !next4.(Model).showSearch { + t.Fatal("/ should enable search") + } + + next5, _ := next4.(Model).Update(tea.KeyMsg{Type: tea.KeyEsc}) + if next5.(Model).showSearch { + t.Fatal("esc should disable search") + } + + next6, cmd6 := next5.(Model).Update(tea.KeyMsg{Type: tea.KeyRunes, Runes: []rune("x")}) + if cmd6 != nil { + t.Fatal("unknown key should not return command") + } + if next6.(Model).showEvents != next5.(Model).showEvents || + next6.(Model).showStd != next5.(Model).showStd || + next6.(Model).showSearch != next5.(Model).showSearch { + t.Fatal("unknown key should not alter toggle state") + } + }) + + t.Run("restart key guarded by state and control fn", func(t *testing.T) { + t.Parallel() + + m := NewModel(make(chan model.Event), Controls{}) + next, cmd := m.Update(tea.KeyMsg{Type: tea.KeyCtrlR}) + if cmd != nil { + t.Fatal("ctrl+r with nil restart control should return nil cmd") + } + if next.(Model).restarting { + t.Fatal("restarting should remain false when restart control missing") + } + + m2 := NewModel(make(chan model.Event), Controls{Restart: func() error { return nil }}) + next2, cmd2 := m2.Update(tea.KeyMsg{Type: tea.KeyCtrlR}) + if cmd2 == nil { + t.Fatal("ctrl+r with restart control should return cmd") + } + if !next2.(Model).restarting { + t.Fatal("restarting should be true after ctrl+r") + } + + next3, cmd3 := next2.(Model).Update(tea.KeyMsg{Type: tea.KeyCtrlR}) + if cmd3 != nil { + t.Fatal("ctrl+r while restarting should return nil cmd") + } + if !next3.(Model).restarting { + t.Fatal("restarting should stay true while guarded") + } + }) + + t.Run("ErrMsg pushes warning event", func(t *testing.T) { + t.Parallel() + m := NewModel(make(chan model.Event), Controls{}) + next, _ := m.Update(ErrMsg{err: errors.New("closed")}) + got := next.(Model) + if len(got.events) != 1 { + t.Fatalf("event len = %d, want 1", len(got.events)) + } + if got.events[0].Type != model.EventTypeWarn { + t.Fatalf("event type = %s, want %s", got.events[0].Type, model.EventTypeWarn) + } + }) + + t.Run("RestartResultMsg clears restarting and warns on error", func(t *testing.T) { + t.Parallel() + m := NewModel(make(chan model.Event), Controls{}) + m.restarting = true + + next1, _ := m.Update(RestartResultMsg{err: nil}) + if next1.(Model).restarting { + t.Fatal("restarting should clear on restart result") + } + + next2, _ := m.Update(RestartResultMsg{err: errors.New("fail")}) + got2 := next2.(Model) + if len(got2.events) != 1 || got2.events[0].Type != model.EventTypeWarn { + t.Fatalf("expected warn event on restart error, got %#v", got2.events) + } + }) + + t.Run("EventMsg updates events and requests", func(t *testing.T) { + t.Parallel() + ch := make(chan model.Event, 2) + m := NewModel(ch, Controls{}) + + reqID := uuid.New() + startEv := EventMsg{value: model.Event{Type: model.EventTypeRequestStarted, Request: model.Request{ID: reqID, Method: http.MethodGet, Pending: true}}} + next1, cmd1 := m.Update(startEv) + if cmd1 == nil { + t.Fatal("EventMsg should return wait/tick cmd") + } + got1 := next1.(Model) + if len(got1.events) != 1 || len(got1.requests) != 1 { + t.Fatalf("expected one event and one request, got events=%d requests=%d", len(got1.events), len(got1.requests)) + } + + finishReq := got1.requests[0] + finishReq.Pending = false + finishReq.Status = 200 + finishEv := EventMsg{value: model.Event{Type: model.EventTypeRequestFinished, Request: finishReq}} + next2, cmd2 := got1.Update(finishEv) + got2 := next2.(Model) + if got2.requests[0].Pending { + t.Fatal("request should be updated to non-pending") + } + if got2.requests[0].Status != 200 { + t.Fatalf("request status = %d, want 200", got2.requests[0].Status) + } + + if cmd2 == nil { + t.Fatal("expected waitForEvent cmd after finished request") + } + ch <- model.Event{Type: model.EventTypeWarn, Body: "next"} + msg := cmd2() + if _, ok := msg.(EventMsg); !ok { + t.Fatalf("cmd2 message type = %T, want EventMsg", msg) + } + }) +} + +func TestModelHelpers(t *testing.T) { + t.Parallel() + + t.Run("pushEvent trims to maxEvents", func(t *testing.T) { + t.Parallel() + m := NewModel(make(chan model.Event), Controls{}) + for i := 0; i < maxEvents+5; i++ { + m.pushEvent(model.Event{Body: "x"}) + } + if len(m.events) != maxEvents { + t.Fatalf("events len = %d, want %d", len(m.events), maxEvents) + } + }) + + t.Run("createRequest ignores CONNECT and trims", func(t *testing.T) { + t.Parallel() + m := NewModel(make(chan model.Event), Controls{}) + m.createRequest(model.Request{Method: http.MethodConnect}) + if len(m.requests) != 0 { + t.Fatal("CONNECT request should be ignored") + } + + for i := 0; i < maxRequests+3; i++ { + m.createRequest(model.Request{ID: uuid.New(), Method: http.MethodGet}) + } + if len(m.requests) != maxRequests { + t.Fatalf("requests len = %d, want %d", len(m.requests), maxRequests) + } + }) + + t.Run("updateRequest updates only matching request", func(t *testing.T) { + t.Parallel() + id1 := uuid.New() + id2 := uuid.New() + m := NewModel(make(chan model.Event), Controls{}) + m.requests = []model.Request{{ID: id1, Status: 100}, {ID: id2, Status: 101}} + + m.updateRequest(model.Request{ID: id2, Status: 202}) + if m.requests[0].Status != 100 || m.requests[1].Status != 202 { + t.Fatalf("unexpected statuses after update: %#v", m.requests) + } + }) + + t.Run("hasPendingRequests true and false", func(t *testing.T) { + t.Parallel() + m := NewModel(make(chan model.Event), Controls{}) + if m.hasPendingRequests() { + t.Fatal("empty model should not have pending requests") + } + m.requests = []model.Request{{ID: uuid.New(), Pending: false}, {ID: uuid.New(), Pending: true}} + if !m.hasPendingRequests() { + t.Fatal("expected pending requests to be true") + } + }) +} diff --git a/internal/tui/view_split_panes_style_test.go b/internal/tui/view_split_panes_style_test.go new file mode 100644 index 0000000..86a3934 --- /dev/null +++ b/internal/tui/view_split_panes_style_test.go @@ -0,0 +1,245 @@ +package tui + +import ( + "strings" + "testing" + "time" + + "github.com/google/uuid" + "termtap.dev/internal/model" +) + +func TestFormatDuration(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + in time.Duration + want string + }{ + {name: "pending zero", in: 0, want: "PENDING"}, + {name: "microseconds", in: 750 * time.Microsecond, want: "750us"}, + {name: "milliseconds", in: 20 * time.Millisecond, want: "20ms"}, + {name: "seconds", in: 11 * time.Second, want: "11.00s"}, + } + + for _, tt := range tests { + tt := tt + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + if got := formatDuration(tt.in); got != tt.want { + t.Fatalf("formatDuration(%v) = %q, want %q", tt.in, got, tt.want) + } + }) + } +} + +func TestTruncate(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + s string + max int + want string + }{ + {name: "short unchanged", s: "abc", max: 3, want: "abc"}, + {name: "max small no ellipsis", s: "abcdef", max: 3, want: "abc"}, + {name: "ellipsis", s: "abcdef", max: 5, want: "ab..."}, + } + + for _, tt := range tests { + tt := tt + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + if got := truncate(tt.s, tt.max); got != tt.want { + t.Fatalf("truncate(%q,%d) = %q, want %q", tt.s, tt.max, got, tt.want) + } + }) + } +} + +func TestClampRendered(t *testing.T) { + t.Parallel() + + if got := clampRendered("abcdef", 0); got != "" { + t.Fatalf("clampRendered max=0 = %q, want empty", got) + } + if got := clampRendered("abc", 10); got != "abc" { + t.Fatalf("clampRendered no truncation = %q, want %q", got, "abc") + } + if got := clampRendered("abcdef", 4); !strings.Contains(got, "...") { + t.Fatalf("clampRendered truncation should include ellipsis, got %q", got) + } +} + +func TestGetEventColor(t *testing.T) { + t.Parallel() + + theme := newTheme() + tests := []struct { + name string + typ model.EventType + want string + }{ + {name: "session", typ: model.EventTypeSessionStarted, want: theme.EventSession.Render("x")}, + {name: "proxy", typ: model.EventTypeProxyStarted, want: theme.EventProxy.Render("x")}, + {name: "request in flight", typ: model.EventTypeRequestStarted, want: theme.EventRequestInFlight.Render("x")}, + {name: "request success", typ: model.EventTypeRequestFinished, want: theme.EventSuccess.Render("x")}, + {name: "warn", typ: model.EventTypeWarn, want: theme.EventWarn.Render("x")}, + {name: "error", typ: model.EventTypeRequestFailed, want: theme.EventError.Render("x")}, + {name: "fatal", typ: model.EventTypeFatal, want: theme.EventFatal.Render("x")}, + {name: "default", typ: model.EventType("UnknownType"), want: theme.EventDefault.Render("x")}, + } + + for _, tt := range tests { + tt := tt + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + got := getEventColor(theme, tt.typ).Render("x") + if got != tt.want { + t.Fatalf("unexpected style for %s", tt.typ) + } + }) + } +} + +func TestViewAndPaneStructure(t *testing.T) { + t.Parallel() + + m := NewModel(make(chan model.Event), Controls{}) + + t.Run("view returns raw pane when unset size", func(t *testing.T) { + t.Parallel() + got := m.View() + if got != m.renderAppPane() { + t.Fatal("View should return raw pane when width/height are unset") + } + }) + + t.Run("renderAppPane line count matches height", func(t *testing.T) { + t.Parallel() + m2 := m + m2.width = 80 + m2.height = 12 + got := m2.renderAppPane() + if got == "height of request and details did not match" || got == "height of screen does not match terminal height" { + t.Fatalf("unexpected renderAppPane invariant error: %q", got) + } + if lines := strings.Count(got, "\n") + 1; lines != m2.height { + t.Fatalf("line count = %d, want %d", lines, m2.height) + } + }) + + t.Run("renderAppPane supports toggles", func(t *testing.T) { + t.Parallel() + m2 := m + m2.width = 90 + m2.height = 14 + m2.showEvents = true + m2.showStd = true + m2.showSearch = true + got := m2.renderAppPane() + if got == "height of request and details did not match" || got == "height of screen does not match terminal height" { + t.Fatalf("unexpected renderAppPane invariant error with toggles: %q", got) + } + }) + + t.Run("view applies configured terminal height", func(t *testing.T) { + t.Parallel() + m2 := m + m2.width = 70 + m2.height = 10 + + got := m2.View() + if lines := strings.Count(got, "\n") + 1; lines < m2.height { + t.Fatalf("View line count = %d, want at least %d", lines, m2.height) + } + }) +} + +func TestPaneRenderersAndStatusBar(t *testing.T) { + t.Parallel() + + m := NewModel(make(chan model.Event), Controls{}) + m.width = 100 + m.height = 12 + m.requests = []model.Request{ + {ID: uuid.New(), Method: "GET", Host: "a", URL: "/a", Status: 200, Duration: 5 * time.Millisecond}, + {ID: uuid.New(), Method: "POST", Host: "b", URL: "/b", Status: 500, Duration: 10 * time.Millisecond, Failed: true}, + } + m.events = []model.Event{ + {Type: model.EventTypeWarn, Body: "warn"}, + {Type: model.EventTypeProcessStdout, Body: "out"}, + {Type: model.EventTypeProcessStderr, Body: "err"}, + } + + status := m.renderStatusBar(100) + if !strings.Contains(status, "2 reqs") || !strings.Contains(status, "1 err") { + t.Fatalf("status bar missing expected counters: %q", status) + } + + search := m.renderSearchPane(20, 3) + if len(search) != 3 { + t.Fatalf("search pane len = %d, want 3", len(search)) + } + for i, line := range search { + if len(line) != 20 { + t.Fatalf("search pane line %d len = %d, want %d", i, len(line), 20) + } + } + + requests := m.renderRequestPane(50, 4) + if len(requests) != 4 { + t.Fatalf("request pane len = %d, want 4", len(requests)) + } + + details := m.renderDetailsPane(30, 4) + if len(details) != 4 { + t.Fatalf("details pane len = %d, want 4", len(details)) + } + + events := m.renderEventsPane(60, 4) + if len(events) != 4 { + t.Fatalf("events pane len = %d, want 4", len(events)) + } + if strings.Contains(strings.Join(events, "\n"), "out") || strings.Contains(strings.Join(events, "\n"), "err") { + t.Fatal("events pane should filter stdout/stderr events") + } + + std := m.renderStdPane(60, 4) + if len(std) != 4 { + t.Fatalf("std pane len = %d, want 4", len(std)) + } + joined := strings.Join(std, "\n") + if !strings.Contains(joined, "out") || !strings.Contains(joined, "err") { + t.Fatal("std pane should include stdout/stderr logs") + } +} + +func TestRenderEventsPane_ErrorAndPIDBranches(t *testing.T) { + t.Parallel() + + m := NewModel(make(chan model.Event), Controls{}) + m.events = []model.Event{ + {Type: model.EventTypeWarn, Body: "old"}, + {Type: model.EventTypeRequestFailed, Body: "failed body", PID: 123, Time: time.Now()}, + {Type: model.EventTypeFatal, Body: "fatal body", Time: time.Now()}, + } + + lines := m.renderEventsPane(60, 3) + if len(lines) != 3 { + t.Fatalf("events pane len = %d, want 3", len(lines)) + } + + joined := strings.Join(lines, "\n") + if !strings.Contains(joined, "123") { + t.Fatalf("expected PID to appear in events pane, got: %q", joined) + } + if !strings.Contains(joined, "failed body") { + t.Fatalf("expected failed body to appear in events pane, got: %q", joined) + } + if !strings.Contains(joined, "fatal body") { + t.Fatalf("expected fatal body to appear in events pane, got: %q", joined) + } +}