Merge branch 'feature/testing'

This commit is contained in:
Hayden Hargreaves 2026-04-23 19:47:58 -07:00
commit 51d526c2fe
28 changed files with 4976 additions and 29 deletions

45
TEST_COVERAGE_SUMMARY.md Normal file
View File

@ -0,0 +1,45 @@
# termtap Test Coverage Summary
Generated from:
```bash
go test -coverprofile=/tmp/termtap.cover ./...
go tool cover -func=/tmp/termtap.cover
```
## Package coverage snapshot
| Package | Coverage |
|---|---:|
| `cmd/tap` | 100.0% |
| `internal/app` | 98.1% |
| `internal/cli` | 93.4% |
| `internal/process` | 95.8% |
| `internal/proxy` | 90.0% |
| `internal/tui` | 96.2% |
| `examples/echo` | 0.0% (example app; intentionally not covered) |
| `internal/model` | no test files (pure data structs) |
Total statements in module: **77.9%**.
## Notable lower-coverage targets (production code)
- `internal/proxy/handlers.go:handleConnect` — 75.9%
- `internal/proxy/certs.go:writeFileAtomically` — 76.2%
- `internal/proxy/certs.go:load` — 90.5%
- `internal/proxy/certs.go:create` — 80.8%
- `internal/proxy/certs.go:IsTrustedBySystem` — 76.9%
- `internal/cli/run.go:runCert` — 87.5%
## Interpretation
- Core runtime paths (`internal/app`, `internal/process`, `internal/tui`) are high confidence.
- Proxy package has broad behavior coverage including HTTP and HTTPS MITM integration flow, and now clears 90% package coverage.
- CLI command routing and fatal/stdout/stderr seams are covered, including `Run` success/error branches.
- TUI pane rendering coverage now includes error/PID branch behavior.
## Next optional improvements
1. Add deeper CONNECT tunnel write/flush/read failure permutations inside `handleConnect` loop.
2. Add additional deterministic edge cases for `IsTrustedBySystem` non-unknown-authority verify failures.
3. Optionally add non-unix signal file coverage in CI matrix (currently unix-focused).

49
cmd/tap/main_test.go Normal file
View File

@ -0,0 +1,49 @@
package main
import (
"bytes"
"io"
"os"
"strings"
"testing"
"time"
)
func TestMain_SmokeInvokesCLIRun(t *testing.T) {
origArgs := os.Args
t.Cleanup(func() { os.Args = origArgs })
os.Args = []string{"tap", "invalid"}
origStderr := os.Stderr
r, w, err := os.Pipe()
if err != nil {
t.Fatalf("stderr pipe error: %v", err)
}
t.Cleanup(func() {
os.Stderr = origStderr
_ = r.Close()
})
os.Stderr = w
outCh := make(chan string, 1)
go func() {
var buf bytes.Buffer
_, _ = io.Copy(&buf, r)
outCh <- buf.String()
}()
main()
_ = w.Close()
os.Stderr = origStderr
var got string
select {
case got = <-outCh:
case <-time.After(2 * time.Second):
t.Fatal("timeout waiting for stderr capture")
}
if !strings.Contains(got, "usage:") {
t.Fatalf("stderr missing usage output, got: %q", got)
}
}

View File

@ -0,0 +1,109 @@
package app
import (
"net/http"
"net/http/httptest"
"net/url"
"testing"
"time"
"termtap.dev/internal/model"
)
// NOTE: Run with -race; this validates cross-component concurrency.
func TestSessionIntegration_LifecycleAndRequestEvents(t *testing.T) {
addr := freeTCPAddr(t)
s, err := StartSession(model.Command{Name: "sh", Args: []string{"-c", "sleep 5"}}, addr)
if err != nil {
t.Fatalf("StartSession() error = %v", err)
}
upstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)
_, _ = w.Write([]byte("ok"))
}))
t.Cleanup(upstream.Close)
startupEvents := collectUntilTypes(t, s.Events, []model.EventType{
model.EventTypeProxyStarting,
model.EventTypeProcessStarting,
model.EventTypeProcessStarted,
}, 3*time.Second)
proxyURL, err := url.Parse(s.proxy.Url)
if err != nil {
t.Fatalf("url.Parse(proxy) error = %v", err)
}
client := &http.Client{
Transport: &http.Transport{Proxy: http.ProxyURL(proxyURL)},
Timeout: 3 * time.Second,
}
resp, err := client.Get(upstream.URL + "/session")
if err != nil {
t.Fatalf("proxy request error = %v", err)
}
_ = resp.Body.Close()
if resp.StatusCode != http.StatusOK {
t.Fatalf("status = %d, want %d", resp.StatusCode, http.StatusOK)
}
requestEvents := collectUntilTypes(t, s.Events, []model.EventType{
model.EventTypeRequestStarted,
model.EventTypeRequestFinished,
}, 3*time.Second)
s.Stop()
select {
case <-s.proc.Done:
case <-time.After(4 * time.Second):
t.Fatal("timeout waiting for process stop")
}
shutdownEvents := collectUntilTypes(t, s.Events, []model.EventType{
model.EventTypeProxyStopped,
model.EventTypeProcessSignaled,
model.EventTypeProcessExited,
}, 4*time.Second)
if !isBefore(startupEvents, model.EventTypeProcessStarting, model.EventTypeProcessStarted) {
t.Fatalf("expected %s before %s in startup events: %#v", model.EventTypeProcessStarting, model.EventTypeProcessStarted, startupEvents)
}
if !isBefore(requestEvents, model.EventTypeRequestStarted, model.EventTypeRequestFinished) {
t.Fatalf("expected %s before %s in request events: %#v", model.EventTypeRequestStarted, model.EventTypeRequestFinished, requestEvents)
}
if !isBefore(shutdownEvents, model.EventTypeProcessSignaled, model.EventTypeProcessExited) {
t.Fatalf("expected %s before %s in shutdown events: %#v", model.EventTypeProcessSignaled, model.EventTypeProcessExited, shutdownEvents)
}
}
func TestSessionIntegration_RestartProcessEmitsLifecycleEvents(t *testing.T) {
addr := freeTCPAddr(t)
s, err := StartSession(model.Command{Name: "sh", Args: []string{"-c", "sleep 5"}}, addr)
if err != nil {
t.Fatalf("StartSession() error = %v", err)
}
t.Cleanup(func() { s.Stop() })
_ = collectUntilTypes(t, s.Events, []model.EventType{
model.EventTypeProcessStarted,
model.EventTypeProxyStarting,
}, 3*time.Second)
if err := s.RestartProcess(); err != nil {
t.Fatalf("RestartProcess() error = %v", err)
}
events := collectUntilTypes(t, s.Events, []model.EventType{
model.EventTypeProcessRestarting,
model.EventTypeProcessSignaled,
model.EventTypeProcessExited,
model.EventTypeProcessStarting,
model.EventTypeProcessStarted,
}, 4*time.Second)
if !isBefore(events, model.EventTypeProcessRestarting, model.EventTypeProcessStarted) {
t.Fatalf("expected restarting before process started, got %#v", events)
}
}

View File

@ -4,6 +4,7 @@ import (
"errors"
"fmt"
"os/exec"
"sync"
"syscall"
"time"
@ -11,6 +12,10 @@ import (
"termtap.dev/internal/process"
)
var killEscalationDelay = 1500 * time.Millisecond
var scheduleKillEscalation = time.AfterFunc
var killEscalationMu sync.RWMutex
func StartProcess(cmd model.Command, addr string, ch chan<- model.Event) (*model.Process, error) {
ch <- model.Event{
Time: time.Now().Local(),
@ -44,12 +49,16 @@ func StopProcess(proc *model.Process, ch chan<- model.Event, sig syscall.Signal)
_ = process.SignalProcess(proc.Exec, sig)
go func() {
time.Sleep(1500 * time.Millisecond)
killEscalationMu.RLock()
delay := killEscalationDelay
scheduler := scheduleKillEscalation
killEscalationMu.RUnlock()
scheduler(delay, func() {
if process.ProcessAlive(proc.Exec) {
_ = process.SignalProcess(proc.Exec, syscall.SIGKILL)
}
}()
})
}
func waitForProcessExit(proc *model.Process, ch chan<- model.Event) {

View File

@ -0,0 +1,223 @@
package app
import (
"os/exec"
"syscall"
"testing"
"time"
"termtap.dev/internal/model"
)
func TestStartProcess(t *testing.T) {
t.Parallel()
t.Run("starts process and marks running", func(t *testing.T) {
t.Parallel()
ch := make(chan model.Event, 32)
proc, err := StartProcess(model.Command{Name: "sh", Args: []string{"-c", "sleep 0.2"}}, "127.0.0.1:8080", ch)
if err != nil {
t.Fatalf("StartProcess() error = %v", err)
}
t.Cleanup(func() {
StopProcess(proc, ch, syscall.SIGTERM)
select {
case <-proc.Done:
case <-time.After(2 * time.Second):
t.Fatal("timeout waiting process done in cleanup")
}
})
if proc == nil || proc.Exec == nil {
t.Fatal("StartProcess() returned nil process/exec")
}
events := drainEvents(t, ch, 2, time.Second)
if !hasType(events, model.EventTypeProcessStarting) {
t.Fatalf("missing %s event", model.EventTypeProcessStarting)
}
if !hasType(events, model.EventTypeProcessStarted) {
t.Fatalf("missing %s event", model.EventTypeProcessStarted)
}
})
t.Run("returns error when exec start fails", func(t *testing.T) {
t.Parallel()
ch := make(chan model.Event, 8)
proc, err := StartProcess(model.Command{Name: "definitely-not-a-real-command"}, "127.0.0.1:8080", ch)
if err == nil {
if proc != nil && proc.Exec != nil && proc.Exec.Process != nil {
_ = proc.Exec.Process.Kill()
}
t.Fatal("StartProcess() error = nil, want non-nil")
}
events := drainEvents(t, ch, 1, time.Second)
if !hasType(events, model.EventTypeProcessStarting) {
t.Fatalf("missing %s event", model.EventTypeProcessStarting)
}
})
}
func TestStopProcess(t *testing.T) {
t.Parallel()
t.Run("nil guards", func(t *testing.T) {
t.Parallel()
StopProcess(nil, make(chan model.Event, 1), syscall.SIGTERM)
StopProcess(&model.Process{}, make(chan model.Event, 1), syscall.SIGTERM)
StopProcess(&model.Process{Exec: &exec.Cmd{}}, make(chan model.Event, 1), syscall.SIGTERM)
})
t.Run("emits signaled event", func(t *testing.T) {
t.Parallel()
ch := make(chan model.Event, 32)
proc, err := StartProcess(model.Command{Name: "sh", Args: []string{"-c", "sleep 5"}}, "127.0.0.1:8080", ch)
if err != nil {
t.Fatalf("StartProcess() error = %v", err)
}
StopProcess(proc, ch, syscall.SIGTERM)
if _, ok := waitForEventType(t, ch, model.EventTypeProcessSignaled, 2*time.Second); !ok {
t.Fatalf("did not receive %s event", model.EventTypeProcessSignaled)
}
select {
case <-proc.Done:
case <-time.After(3 * time.Second):
t.Fatal("timeout waiting for process to exit after signal")
}
})
t.Run("schedules deterministic kill escalation hook", func(t *testing.T) {
t.Parallel()
origDelay := killEscalationDelay
origScheduler := scheduleKillEscalation
defer func() {
killEscalationMu.Lock()
killEscalationDelay = origDelay
scheduleKillEscalation = origScheduler
killEscalationMu.Unlock()
}()
ch := make(chan model.Event, 32)
proc, err := StartProcess(model.Command{Name: "sh", Args: []string{"-c", "sleep 5"}}, "127.0.0.1:8080", ch)
if err != nil {
t.Fatalf("StartProcess() error = %v", err)
}
killEscalationMu.Lock()
killEscalationDelay = 25 * time.Millisecond
scheduled := make(chan time.Duration, 1)
scheduleKillEscalation = func(d time.Duration, fn func()) *time.Timer {
scheduled <- d
go fn()
return nil
}
killEscalationMu.Unlock()
StopProcess(proc, ch, syscall.SIGTERM)
select {
case d := <-scheduled:
if d != killEscalationDelay {
t.Fatalf("scheduled delay = %v, want %v", d, killEscalationDelay)
}
case <-time.After(time.Second):
t.Fatal("kill escalation was not scheduled")
}
select {
case <-proc.Done:
case <-time.After(3 * time.Second):
t.Fatal("timeout waiting for process to exit after deterministic escalation")
}
})
}
func TestWaitForProcessExit(t *testing.T) {
t.Parallel()
t.Run("nil guards are no-op", func(t *testing.T) {
t.Parallel()
waitForProcessExit(nil, make(chan model.Event, 1))
waitForProcessExit(&model.Process{}, make(chan model.Event, 1))
})
t.Run("normal exit emits process exited and closes done", func(t *testing.T) {
t.Parallel()
cmd := exec.Command("sh", "-c", "exit 0")
if err := cmd.Start(); err != nil {
t.Fatalf("cmd.Start() error = %v", err)
}
proc := &model.Process{Exec: cmd, Running: true, Done: make(chan struct{})}
ch := make(chan model.Event, 8)
waitForProcessExit(proc, ch)
if _, ok := waitForEventType(t, ch, model.EventTypeProcessExited, time.Second); !ok {
t.Fatalf("did not receive %s event", model.EventTypeProcessExited)
}
select {
case <-proc.Done:
case <-time.After(time.Second):
t.Fatal("Done channel was not closed")
}
})
t.Run("exit error path carries exit code", func(t *testing.T) {
t.Parallel()
cmd := exec.Command("sh", "-c", "exit 7")
if err := cmd.Start(); err != nil {
t.Fatalf("cmd.Start() error = %v", err)
}
proc := &model.Process{Exec: cmd, Running: true, Done: make(chan struct{})}
ch := make(chan model.Event, 8)
waitForProcessExit(proc, ch)
events := drainEvents(t, ch, 2, time.Second)
found := false
for _, ev := range events {
if ev.Type == model.EventTypeProcessExited && ev.ExitCode == 7 {
found = true
break
}
}
if !found {
t.Fatalf("expected ProcessExited with exit code 7, got %#v", events)
}
select {
case <-proc.Done:
case <-time.After(time.Second):
t.Fatal("Done channel was not closed")
}
})
t.Run("unexpected wait failure emits fatal", func(t *testing.T) {
t.Parallel()
proc := &model.Process{Exec: &exec.Cmd{}, Done: make(chan struct{})}
ch := make(chan model.Event, 8)
waitForProcessExit(proc, ch)
events := drainEvents(t, ch, 1, time.Second)
if events[0].Type != model.EventTypeFatal {
t.Fatalf("event type = %s, want %s", events[0].Type, model.EventTypeFatal)
}
select {
case <-proc.Done:
case <-time.After(time.Second):
t.Fatal("Done channel was not closed")
}
})
}

183
internal/app/proxy_test.go Normal file
View File

@ -0,0 +1,183 @@
package app
import (
"context"
"errors"
"net"
"net/http"
"testing"
"time"
"termtap.dev/internal/model"
)
type staticErrListener struct {
err error
}
func (l *staticErrListener) Accept() (net.Conn, error) { return nil, l.err }
func (l *staticErrListener) Close() error { return nil }
func (l *staticErrListener) Addr() net.Addr {
return &net.TCPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 8080}
}
func TestStartProxy_NilGuards(t *testing.T) {
t.Parallel()
StartProxy(nil, make(chan model.Event, 1))
StartProxy(&model.ProxyServer{}, make(chan model.Event, 1))
StartProxy(&model.ProxyServer{Server: &http.Server{}}, make(chan model.Event, 1))
}
func TestStartProxy_EmitsStartingAndWarnWhenUntrustedCA(t *testing.T) {
t.Parallel()
listenErr := errors.New("accept failed")
var ln net.Listener = &staticErrListener{err: listenErr}
ch := make(chan model.Event, 8)
ps := &model.ProxyServer{
Listener: &ln,
Server: &http.Server{Handler: http.HandlerFunc(func(http.ResponseWriter, *http.Request) {})},
CAReady: true,
CATrusted: false,
CACreated: true,
CACertPath: "/tmp/test-ca.pem",
}
StartProxy(ps, ch)
events := drainEvents(t, ch, 3, time.Second)
if !hasType(events, model.EventTypeProxyStarting) {
t.Fatalf("missing %s event", model.EventTypeProxyStarting)
}
if !hasType(events, model.EventTypeWarn) {
t.Fatalf("missing %s event", model.EventTypeWarn)
}
if !containsBody(events, "generated HTTPS interception CA") {
t.Fatalf("expected generated-CA warning body, got events: %#v", events)
}
if !hasType(events, model.EventTypeFatal) {
t.Fatalf("missing %s event", model.EventTypeFatal)
}
}
func TestStartProxy_WarnBodyForExistingUntrustedCA(t *testing.T) {
t.Parallel()
listenErr := errors.New("accept failed")
var ln net.Listener = &staticErrListener{err: listenErr}
ch := make(chan model.Event, 8)
ps := &model.ProxyServer{
Listener: &ln,
Server: &http.Server{Handler: http.HandlerFunc(func(http.ResponseWriter, *http.Request) {})},
CAReady: true,
CATrusted: false,
CACreated: false,
CACertPath: "/tmp/test-ca.pem",
}
StartProxy(ps, ch)
events := drainEvents(t, ch, 3, time.Second)
if !containsBody(events, "HTTPS interception CA available at") {
t.Fatalf("expected existing-CA warning body, got events: %#v", events)
}
}
func TestStartProxy_NoCAWarningWhenNotReady(t *testing.T) {
t.Parallel()
listenErr := errors.New("accept failed")
var ln net.Listener = &staticErrListener{err: listenErr}
ch := make(chan model.Event, 8)
ps := &model.ProxyServer{
Listener: &ln,
Server: &http.Server{Handler: http.HandlerFunc(func(http.ResponseWriter, *http.Request) {})},
CAReady: false,
CATrusted: false,
}
StartProxy(ps, ch)
events := drainEvents(t, ch, 2, time.Second)
if !hasType(events, model.EventTypeProxyStarting) {
t.Fatalf("missing %s event", model.EventTypeProxyStarting)
}
if hasType(events, model.EventTypeWarn) {
t.Fatalf("unexpected %s event when CA is not ready: %#v", model.EventTypeWarn, events)
}
if !hasType(events, model.EventTypeFatal) {
t.Fatalf("missing %s event", model.EventTypeFatal)
}
}
func TestStartProxy_NoCAWarningWhenTrusted(t *testing.T) {
t.Parallel()
listenErr := errors.New("accept failed")
var ln net.Listener = &staticErrListener{err: listenErr}
ch := make(chan model.Event, 8)
ps := &model.ProxyServer{
Listener: &ln,
Server: &http.Server{Handler: http.HandlerFunc(func(http.ResponseWriter, *http.Request) {})},
CAReady: true,
CATrusted: true,
}
StartProxy(ps, ch)
events := drainEvents(t, ch, 2, time.Second)
if !hasType(events, model.EventTypeProxyStarting) {
t.Fatalf("missing %s event", model.EventTypeProxyStarting)
}
if hasType(events, model.EventTypeWarn) {
t.Fatalf("unexpected %s event when CA is already trusted: %#v", model.EventTypeWarn, events)
}
if !hasType(events, model.EventTypeFatal) {
t.Fatalf("missing %s event", model.EventTypeFatal)
}
}
func TestStartProxy_SwallowsErrServerClosed(t *testing.T) {
t.Parallel()
ln, err := net.Listen("tcp", "127.0.0.1:0")
if err != nil {
t.Fatalf("listen error = %v", err)
}
t.Cleanup(func() { _ = ln.Close() })
ch := make(chan model.Event, 8)
server := &http.Server{Handler: http.HandlerFunc(func(http.ResponseWriter, *http.Request) {})}
ps := &model.ProxyServer{Listener: &ln, Server: server}
done := make(chan struct{})
go func() {
StartProxy(ps, ch)
close(done)
}()
shutdownCtx, cancel := context.WithTimeout(context.Background(), time.Second)
defer cancel()
if err := server.Shutdown(shutdownCtx); err != nil {
t.Fatalf("shutdown error = %v", err)
}
select {
case <-done:
case <-time.After(time.Second):
t.Fatal("timeout waiting for StartProxy to return")
}
events := drainEvents(t, ch, 1, time.Second)
if !hasType(events, model.EventTypeProxyStarting) {
t.Fatalf("missing %s event", model.EventTypeProxyStarting)
}
if hasType(events, model.EventTypeFatal) {
t.Fatalf("unexpected %s event", model.EventTypeFatal)
}
}

View File

@ -0,0 +1,247 @@
package app
import (
"errors"
"net"
"syscall"
"testing"
"time"
"termtap.dev/internal/model"
)
func TestStartSession(t *testing.T) {
t.Run("happy path creates proxy and process", func(t *testing.T) {
addr := freeTCPAddr(t)
s, err := StartSession(model.Command{Name: "sh", Args: []string{"-c", "sleep 5"}}, addr)
if err != nil {
t.Fatalf("StartSession() error = %v", err)
}
if s == nil {
t.Fatal("StartSession() returned nil session")
}
if s.proxy == nil {
t.Fatal("session.proxy is nil")
}
if s.proc == nil {
t.Fatal("session.proc is nil")
}
s.Stop()
if s.proc != nil && s.proc.Done != nil {
select {
case <-s.proc.Done:
case <-time.After(4 * time.Second):
t.Fatal("timeout waiting for process stop")
}
}
})
t.Run("error when proxy creation fails", func(t *testing.T) {
t.Setenv("XDG_CONFIG_HOME", "")
t.Setenv("HOME", "")
s, err := StartSession(model.Command{Name: "sh", Args: []string{"-c", "true"}}, "127.0.0.1:0")
if err == nil {
if s != nil {
s.Stop()
}
t.Fatal("StartSession() error = nil, want non-nil")
}
if s != nil {
t.Fatalf("session = %#v, want nil", s)
}
})
t.Run("destroys proxy when process startup fails", func(t *testing.T) {
addr := freeTCPAddr(t)
s, err := StartSession(model.Command{Name: "definitely-not-a-real-command"}, addr)
if err == nil {
if s != nil {
s.Stop()
}
t.Fatal("StartSession() error = nil, want non-nil")
}
deadline := time.After(3 * time.Second)
ticker := time.NewTicker(25 * time.Millisecond)
defer ticker.Stop()
for {
ln, listenErr := net.Listen("tcp", addr)
if listenErr == nil {
_ = ln.Close()
break
}
select {
case <-deadline:
t.Fatalf("address %s did not become reusable, got err: %v", addr, listenErr)
case <-ticker.C:
}
}
})
t.Run("stop during restart stops new process and returns ErrSessionStopped", func(t *testing.T) {
s := &Session{
ch: make(chan model.Event, 64),
cmd: model.Command{Name: "sh", Args: []string{"-c", "sleep 5"}},
addr: "127.0.0.1:0",
proc: &model.Process{Done: make(chan struct{})},
}
close(s.proc.Done)
s.restartMu.Lock()
errCh := make(chan error, 1)
stopDone := make(chan struct{})
go func() {
errCh <- s.RestartProcess()
}()
go func() {
s.Stop()
close(stopDone)
}()
s.restartMu.Unlock()
select {
case err := <-errCh:
if !errors.Is(err, ErrSessionStopped) {
t.Fatalf("error = %v, want %v", err, ErrSessionStopped)
}
case <-time.After(6 * time.Second):
t.Fatal("timeout waiting for RestartProcess result")
}
select {
case <-stopDone:
case <-time.After(time.Second):
t.Fatal("timeout waiting for Stop completion")
}
})
}
func TestSessionStop_Idempotent(t *testing.T) {
addr := freeTCPAddr(t)
s, err := StartSession(model.Command{Name: "sh", Args: []string{"-c", "sleep 5"}}, addr)
if err != nil {
t.Fatalf("StartSession() error = %v", err)
}
s.Stop()
s.Stop()
if s.proc != nil && s.proc.Done != nil {
select {
case <-s.proc.Done:
case <-time.After(4 * time.Second):
t.Fatal("timeout waiting for process to stop")
}
}
}
func TestRestartProcess(t *testing.T) {
t.Run("nil session returns error", func(t *testing.T) {
var s *Session
err := s.RestartProcess()
if err == nil {
t.Fatal("RestartProcess() error = nil, want non-nil")
}
})
t.Run("stopped session returns ErrSessionStopped", func(t *testing.T) {
s := &Session{stopped: true}
err := s.RestartProcess()
if !errors.Is(err, ErrSessionStopped) {
t.Fatalf("error = %v, want %v", err, ErrSessionStopped)
}
})
t.Run("concurrent restart returns ErrRestartInProgress", func(t *testing.T) {
s := &Session{restarting: true}
err := s.RestartProcess()
if !errors.Is(err, ErrRestartInProgress) {
t.Fatalf("error = %v, want %v", err, ErrRestartInProgress)
}
})
t.Run("timeout waiting for stop returns error", func(t *testing.T) {
s := &Session{
ch: make(chan model.Event, 8),
cmd: model.Command{Name: "sh", Args: []string{"-c", "true"}},
addr: "127.0.0.1:0",
proc: &model.Process{Done: make(chan struct{})},
}
err := s.RestartProcess()
if err == nil {
t.Fatal("RestartProcess() error = nil, want non-nil")
}
})
t.Run("successful restart swaps process", func(t *testing.T) {
addr := freeTCPAddr(t)
oldProc := &model.Process{Done: make(chan struct{})}
close(oldProc.Done)
s := &Session{
ch: make(chan model.Event, 32),
cmd: model.Command{Name: "sh", Args: []string{"-c", "sleep 5"}},
addr: addr,
proc: oldProc,
}
err := s.RestartProcess()
if err != nil {
t.Fatalf("RestartProcess() error = %v", err)
}
if _, ok := waitForEventType(t, s.ch, model.EventTypeProcessRestarting, time.Second); !ok {
t.Fatalf("did not receive %s event", model.EventTypeProcessRestarting)
}
if s.proc == nil {
t.Fatal("session process is nil after restart")
}
if s.proc == oldProc {
t.Fatal("session process pointer was not swapped")
}
StopProcess(s.proc, s.ch, syscall.SIGTERM)
select {
case <-s.proc.Done:
case <-time.After(4 * time.Second):
t.Fatal("timeout waiting for restarted process to stop")
}
})
}
func TestWaitForProcessStop(t *testing.T) {
t.Parallel()
t.Run("nil process and nil done are true", func(t *testing.T) {
t.Parallel()
if !waitForProcessStop(nil, time.Millisecond) {
t.Fatal("waitForProcessStop(nil) = false, want true")
}
if !waitForProcessStop(&model.Process{}, time.Millisecond) {
t.Fatal("waitForProcessStop(process with nil done) = false, want true")
}
})
t.Run("done channel closes before timeout", func(t *testing.T) {
t.Parallel()
proc := &model.Process{Done: make(chan struct{})}
close(proc.Done)
if !waitForProcessStop(proc, time.Second) {
t.Fatal("waitForProcessStop() = false, want true")
}
})
t.Run("timeout returns false", func(t *testing.T) {
t.Parallel()
proc := &model.Process{Done: make(chan struct{})}
if waitForProcessStop(proc, 20*time.Millisecond) {
t.Fatal("waitForProcessStop() = true, want false")
}
})
}

View File

@ -0,0 +1,111 @@
package app
import (
"net"
"strings"
"testing"
"time"
"termtap.dev/internal/model"
)
func drainEvents(t *testing.T, ch <-chan model.Event, n int, timeout time.Duration) []model.Event {
t.Helper()
out := make([]model.Event, 0, n)
deadline := time.After(timeout)
for len(out) < n {
select {
case ev := <-ch:
out = append(out, ev)
case <-deadline:
t.Fatalf("timeout waiting for %d events, got %d", n, len(out))
}
}
return out
}
func hasType(events []model.Event, typ model.EventType) bool {
for _, ev := range events {
if ev.Type == typ {
return true
}
}
return false
}
func containsBody(events []model.Event, part string) bool {
for _, ev := range events {
if strings.Contains(ev.Body, part) {
return true
}
}
return false
}
func waitForEventType(t *testing.T, ch <-chan model.Event, typ model.EventType, timeout time.Duration) (model.Event, bool) {
t.Helper()
deadline := time.After(timeout)
for {
select {
case ev := <-ch:
if ev.Type == typ {
return ev, true
}
case <-deadline:
return model.Event{}, false
}
}
}
func collectUntilTypes(t *testing.T, ch <-chan model.Event, required []model.EventType, timeout time.Duration) []model.Event {
t.Helper()
need := make(map[model.EventType]bool, len(required))
for _, typ := range required {
need[typ] = true
}
events := make([]model.Event, 0, len(required)+8)
deadline := time.After(timeout)
for len(need) > 0 {
select {
case ev := <-ch:
events = append(events, ev)
delete(need, ev.Type)
case <-deadline:
t.Fatalf("timeout waiting for required events: remaining=%v, collected=%#v", need, events)
}
}
return events
}
func isBefore(events []model.Event, first, second model.EventType) bool {
firstIdx := -1
secondIdx := -1
for i, ev := range events {
if ev.Type == first && firstIdx == -1 {
firstIdx = i
}
if ev.Type == second && secondIdx == -1 {
secondIdx = i
}
}
return firstIdx >= 0 && secondIdx >= 0 && firstIdx < secondIdx
}
func freeTCPAddr(t *testing.T) string {
t.Helper()
ln, err := net.Listen("tcp", "127.0.0.1:0")
if err != nil {
t.Fatalf("listen error = %v", err)
}
addr := ln.Addr().String()
_ = ln.Close()
return addr
}

View File

@ -2,6 +2,7 @@ package cli
import (
"fmt"
"io"
"log"
"os"
"runtime"
@ -16,6 +17,23 @@ import (
// This should be configurable at some point, just in case they build on 8080
const proxy_addr = "127.0.0.1:8080"
var fatalExit = log.Fatalln
var stdoutWriter io.Writer = stdioRef{isErr: false}
var stderrWriter io.Writer = stdioRef{isErr: true}
var startSessionFn = app.StartSession
var runTUIFn = tui.Run
type stdioRef struct {
isErr bool
}
func (w stdioRef) Write(p []byte) (int, error) {
if w.isErr {
return os.Stderr.Write(p)
}
return os.Stdout.Write(p)
}
func Run(args []string) {
if len(args) >= 2 && args[1] == "cert" {
runCert()
@ -28,9 +46,10 @@ func Run(args []string) {
return
}
session, err := app.StartSession(cmd, proxy_addr)
session, err := startSessionFn(cmd, proxy_addr)
if err != nil {
log.Fatalln(err)
fatalExit(err)
return
}
defer session.Stop()
@ -38,8 +57,9 @@ func Run(args []string) {
Restart: session.RestartProcess,
}
if err := tui.Run(session.Events, controls); err != nil {
log.Fatalln(err)
if err := runTUIFn(session.Events, controls); err != nil {
fatalExit(err)
return
}
}
@ -67,49 +87,50 @@ usage:
tap run -- <command> [args...]
`
fmt.Fprintln(os.Stderr, helpText)
fmt.Fprintln(stderrWriter, helpText)
}
func runCert() {
ca, err := proxy.EnsureCertificateAuthority()
if err != nil {
log.Fatalln(err)
fatalExit(err)
return
}
certPath := ca.CertPath()
quotedCertPath := strconv.Quote(certPath)
fmt.Printf("Certificate path: %s\n", certPath)
fmt.Fprintf(stdoutWriter, "Certificate path: %s\n", certPath)
if ca.WasCreated() {
fmt.Println("Created a new local HTTPS interception CA.")
fmt.Fprintln(stdoutWriter, "Created a new local HTTPS interception CA.")
} else {
fmt.Println("Using existing local HTTPS interception CA.")
fmt.Fprintln(stdoutWriter, "Using existing local HTTPS interception CA.")
}
trusted, err := ca.IsTrustedBySystem()
if err != nil {
fmt.Printf("System trust check failed: %v\n", err)
fmt.Fprintf(stdoutWriter, "System trust check failed: %v\n", err)
} else if trusted {
fmt.Println("System trust store: trusted")
fmt.Fprintln(stdoutWriter, "System trust store: trusted")
} else {
fmt.Println("System trust store: not trusted")
fmt.Fprintln(stdoutWriter, "System trust store: not trusted")
}
if runtime.GOOS != "linux" {
fmt.Println("Install this certificate into your OS or client trust store to inspect HTTPS traffic.")
fmt.Fprintln(stdoutWriter, "Install this certificate into your OS or client trust store to inspect HTTPS traffic.")
return
}
fmt.Println()
fmt.Println("Trust instructions (Linux):")
fmt.Println("Debian/Ubuntu:")
fmt.Printf(" sudo cp %s /usr/local/share/ca-certificates/termtap.crt\n", quotedCertPath)
fmt.Println(" sudo update-ca-certificates")
fmt.Println("Fedora/RHEL/CentOS:")
fmt.Printf(" sudo cp %s /etc/pki/ca-trust/source/anchors/termtap.crt\n", quotedCertPath)
fmt.Println(" sudo update-ca-trust")
fmt.Println("Arch:")
fmt.Printf(" sudo trust anchor %s\n", quotedCertPath)
fmt.Println()
fmt.Println("Quick curl test:")
fmt.Printf(" curl --proxy http://%s --cacert %s https://example.com\n", proxy_addr, quotedCertPath)
fmt.Fprintln(stdoutWriter)
fmt.Fprintln(stdoutWriter, "Trust instructions (Linux):")
fmt.Fprintln(stdoutWriter, "Debian/Ubuntu:")
fmt.Fprintf(stdoutWriter, " sudo cp %s /usr/local/share/ca-certificates/termtap.crt\n", quotedCertPath)
fmt.Fprintln(stdoutWriter, " sudo update-ca-certificates")
fmt.Fprintln(stdoutWriter, "Fedora/RHEL/CentOS:")
fmt.Fprintf(stdoutWriter, " sudo cp %s /etc/pki/ca-trust/source/anchors/termtap.crt\n", quotedCertPath)
fmt.Fprintln(stdoutWriter, " sudo update-ca-trust")
fmt.Fprintln(stdoutWriter, "Arch:")
fmt.Fprintf(stdoutWriter, " sudo trust anchor %s\n", quotedCertPath)
fmt.Fprintln(stdoutWriter)
fmt.Fprintln(stdoutWriter, "Quick curl test:")
fmt.Fprintf(stdoutWriter, " curl --proxy http://%s --cacert %s https://example.com\n", proxy_addr, quotedCertPath)
}

356
internal/cli/run_test.go Normal file
View File

@ -0,0 +1,356 @@
package cli
import (
"bytes"
"errors"
"io"
"os"
"runtime"
"strings"
"testing"
"time"
"termtap.dev/internal/app"
"termtap.dev/internal/model"
"termtap.dev/internal/tui"
)
func TestParseCommand(t *testing.T) {
tests := []struct {
name string
args []string
ok bool
nameWant string
argsWant []string
}{
{name: "too few args", args: []string{"tap"}, ok: false},
{name: "missing run token", args: []string{"tap", "oops", "--", "echo"}, ok: false},
{name: "missing separator", args: []string{"tap", "run", "echo"}, ok: false},
{name: "single command", args: []string{"tap", "run", "--", "echo"}, ok: true, nameWant: "echo", argsWant: []string{}},
{name: "command with args", args: []string{"tap", "run", "--", "curl", "-s", "https://example.com"}, ok: true, nameWant: "curl", argsWant: []string{"-s", "https://example.com"}},
}
for _, tt := range tests {
tt := tt
t.Run(tt.name, func(t *testing.T) {
cmd, ok := parseCommand(tt.args)
if ok != tt.ok {
t.Fatalf("ok = %v, want %v", ok, tt.ok)
}
if !tt.ok {
return
}
if cmd.Name != tt.nameWant {
t.Fatalf("cmd.Name = %q, want %q", cmd.Name, tt.nameWant)
}
if strings.Join(cmd.Args, "|") != strings.Join(tt.argsWant, "|") {
t.Fatalf("cmd.Args = %#v, want %#v", cmd.Args, tt.argsWant)
}
})
}
}
func TestDisplayHelpWritesToStderr(t *testing.T) {
_, stderr := captureOutput(t, func() {
displayHelp()
})
if !strings.Contains(stderr, "tap cert") || !strings.Contains(stderr, "tap run --") {
t.Fatalf("stderr missing usage text: %q", stderr)
}
}
func TestRun_InvalidCommandShowsHelp(t *testing.T) {
_, stderr := captureOutput(t, func() {
Run([]string{"tap", "wat"})
})
if !strings.Contains(stderr, "usage:") {
t.Fatalf("stderr missing usage output: %q", stderr)
}
}
func TestRun_RoutesCertCommand(t *testing.T) {
configRoot := t.TempDir()
t.Setenv("XDG_CONFIG_HOME", configRoot)
stdout, _ := captureOutput(t, func() {
Run([]string{"tap", "cert"})
})
if !strings.Contains(stdout, "Certificate path:") {
t.Fatalf("stdout missing certificate path output: %q", stdout)
}
}
func TestRun_StartSessionFailureCallsFatalExit(t *testing.T) {
restore := installRunSeams(t)
defer restore()
startSessionFn = func(model.Command, string) (*app.Session, error) {
return nil, errors.New("boom")
}
called := installFatalSpy(t)
Run([]string{"tap", "run", "--", "definitely-not-a-real-command"})
if !*called {
t.Fatal("expected fatalExit to be called on StartSession failure")
}
}
func TestRun_TUIFailureCallsFatalExit(t *testing.T) {
restore := installRunSeams(t)
defer restore()
startSessionFn = func(model.Command, string) (*app.Session, error) {
return &app.Session{Events: make(chan model.Event)}, nil
}
runTUIFn = func(<-chan model.Event, tui.Controls) error {
return errors.New("tui failed")
}
called := installFatalSpy(t)
Run([]string{"tap", "run", "--", "echo"})
if !*called {
t.Fatal("expected fatalExit to be called on tui failure")
}
}
func TestRun_SuccessPathDoesNotCallFatal(t *testing.T) {
restore := installRunSeams(t)
defer restore()
startSessionFn = func(model.Command, string) (*app.Session, error) {
return &app.Session{Events: make(chan model.Event)}, nil
}
runTUIFn = func(<-chan model.Event, tui.Controls) error {
return nil
}
called := installFatalSpy(t)
Run([]string{"tap", "run", "--", "echo"})
if *called {
t.Fatal("fatalExit should not be called on success path")
}
}
func TestRunCert_EnsureCAFailureCallsFatalExit(t *testing.T) {
t.Setenv("XDG_CONFIG_HOME", "")
t.Setenv("HOME", "")
called := installFatalSpy(t)
runCert()
if !*called {
t.Fatal("expected fatalExit to be called when EnsureCertificateAuthority fails")
}
}
func TestRunCertOutputContract(t *testing.T) {
configRoot := t.TempDir()
t.Setenv("XDG_CONFIG_HOME", configRoot)
stdout, _ := captureOutput(t, func() {
runCert()
})
if !strings.Contains(stdout, "Certificate path:") {
t.Fatalf("stdout missing certificate path line: %q", stdout)
}
if !strings.Contains(stdout, "local HTTPS interception CA") {
t.Fatalf("stdout missing CA create/existing line: %q", stdout)
}
if !strings.Contains(stdout, "System trust store:") && !strings.Contains(stdout, "System trust check failed:") {
t.Fatalf("stdout missing trust check line: %q", stdout)
}
if runtime.GOOS == "linux" {
if !strings.Contains(stdout, "Trust instructions (Linux):") {
t.Fatalf("stdout missing linux trust instructions: %q", stdout)
}
}
}
func TestRunCert_CreatedThenExistingMessage(t *testing.T) {
configRoot := t.TempDir()
t.Setenv("XDG_CONFIG_HOME", configRoot)
firstOut, _ := captureOutput(t, func() {
runCert()
})
if !strings.Contains(firstOut, "Created a new local HTTPS interception CA.") {
t.Fatalf("first run should indicate created CA, got: %q", firstOut)
}
secondOut, _ := captureOutput(t, func() {
runCert()
})
if !strings.Contains(secondOut, "Using existing local HTTPS interception CA.") {
t.Fatalf("second run should indicate existing CA, got: %q", secondOut)
}
}
func captureOutput(t *testing.T, fn func()) (stdout string, stderr string) {
t.Helper()
origStdoutWriter := stdoutWriter
origStderrWriter := stderrWriter
t.Cleanup(func() {
stdoutWriter = origStdoutWriter
stderrWriter = origStderrWriter
})
origStdout := os.Stdout
origStderr := os.Stderr
outR, outW, err := os.Pipe()
if err != nil {
t.Fatalf("stdout pipe error: %v", err)
}
errR, errW, err := os.Pipe()
if err != nil {
_ = outR.Close()
_ = outW.Close()
t.Fatalf("stderr pipe error: %v", err)
}
os.Stdout = outW
os.Stderr = errW
stdoutWriter = outW
stderrWriter = errW
outCh := make(chan string, 1)
errCh := make(chan string, 1)
go func() {
var buf bytes.Buffer
_, _ = io.Copy(&buf, outR)
outCh <- buf.String()
}()
go func() {
var buf bytes.Buffer
_, _ = io.Copy(&buf, errR)
errCh <- buf.String()
}()
fn()
_ = outW.Close()
_ = errW.Close()
stdoutWriter = origStdoutWriter
stderrWriter = origStderrWriter
os.Stdout = origStdout
os.Stderr = origStderr
select {
case stdout = <-outCh:
case <-time.After(2 * time.Second):
t.Fatal("timeout waiting for stdout capture")
}
select {
case stderr = <-errCh:
case <-time.After(2 * time.Second):
t.Fatal("timeout waiting for stderr capture")
}
_ = outR.Close()
_ = errR.Close()
return stdout, stderr
}
func TestDisplayHelp_UsesInjectedStderrWriter(t *testing.T) {
var buf bytes.Buffer
orig := stderrWriter
t.Cleanup(func() { stderrWriter = orig })
stderrWriter = &buf
displayHelp()
if got := buf.String(); !strings.Contains(got, "usage:") {
t.Fatalf("help output missing usage, got: %q", got)
}
}
func TestRunCert_UsesInjectedStdoutWriter(t *testing.T) {
configRoot := t.TempDir()
t.Setenv("XDG_CONFIG_HOME", configRoot)
var buf bytes.Buffer
orig := stdoutWriter
t.Cleanup(func() { stdoutWriter = orig })
stdoutWriter = &buf
runCert()
if got := buf.String(); !strings.Contains(got, "Certificate path:") {
t.Fatalf("cert output missing path line, got: %q", got)
}
}
func installRunSeams(t *testing.T) func() {
t.Helper()
origStartSession := startSessionFn
origRunTUI := runTUIFn
return func() {
startSessionFn = origStartSession
runTUIFn = origRunTUI
}
}
func installFatalSpy(t *testing.T) *bool {
t.Helper()
origFatal := fatalExit
called := false
fatalExit = func(v ...any) {
called = true
}
t.Cleanup(func() {
fatalExit = origFatal
})
return &called
}
func TestStdioRefWrite(t *testing.T) {
t.Run("writes to stdout", func(t *testing.T) {
assertStdioRefWrite(t, false, "hello")
})
t.Run("writes to stderr", func(t *testing.T) {
assertStdioRefWrite(t, true, "boom")
})
}
func assertStdioRefWrite(t *testing.T, isErr bool, payload string) {
t.Helper()
r, w, err := os.Pipe()
if err != nil {
t.Fatalf("pipe error: %v", err)
}
defer func() { _ = r.Close() }()
if isErr {
orig := os.Stderr
os.Stderr = w
defer func() { os.Stderr = orig }()
} else {
orig := os.Stdout
os.Stdout = w
defer func() { os.Stdout = orig }()
}
if _, err := (stdioRef{isErr: isErr}).Write([]byte(payload)); err != nil {
_ = w.Close()
t.Fatalf("stdioRef write error: %v", err)
}
_ = w.Close()
data, err := io.ReadAll(r)
if err != nil {
t.Fatalf("ReadAll(pipe) error: %v", err)
}
if got := string(data); got != payload {
t.Fatalf("pipe write = %q, want %q", got, payload)
}
}

View File

@ -0,0 +1,219 @@
package process
import (
"os"
"os/exec"
"reflect"
"strings"
"testing"
"time"
"termtap.dev/internal/model"
)
func TestCommandString(t *testing.T) {
t.Parallel()
tests := []struct {
name string
cmd model.Command
want string
}{
{
name: "empty args",
cmd: model.Command{Name: "go", Args: []string{}},
want: "go ",
},
{
name: "multiple args",
cmd: model.Command{Name: "curl", Args: []string{"-s", "https://example.com"}},
want: "curl -s https://example.com",
},
}
for _, tt := range tests {
tt := tt
t.Run(tt.name, func(t *testing.T) {
t.Parallel()
if got := CommandString(tt.cmd); got != tt.want {
t.Fatalf("CommandString() = %q, want %q", got, tt.want)
}
})
}
}
func TestInjectEnv(t *testing.T) {
t.Parallel()
cmd := exec.Command("sh", "-c", "true")
injectEnv(cmd, "127.0.0.1:8080")
mustContain := []string{
"HTTP_PROXY=http://127.0.0.1:8080",
"http_proxy=http://127.0.0.1:8080",
"HTTPS_PROXY=http://127.0.0.1:8080",
"https_proxy=http://127.0.0.1:8080",
"NO_PROXY=",
"no_proxy=",
}
for _, kv := range mustContain {
if !containsEnvEntry(cmd.Env, kv) {
t.Fatalf("injectEnv() missing env entry %q", kv)
}
}
}
func TestReadPipe(t *testing.T) {
t.Parallel()
t.Run("stdout lines emit stdout events", func(t *testing.T) {
t.Parallel()
ch := make(chan model.Event, 4)
input := strings.NewReader("line1\nline2\n")
readPipe(input, model.EventTypeProcessStdout, ch)
events := drainEvents(t, ch, 2, time.Second)
if events[0].Type != model.EventTypeProcessStdout || events[0].Body != "line1" {
t.Fatalf("event[0] = %#v, want stdout line1", events[0])
}
if events[1].Type != model.EventTypeProcessStdout || events[1].Body != "line2" {
t.Fatalf("event[1] = %#v, want stdout line2", events[1])
}
})
t.Run("stderr lines emit stderr events", func(t *testing.T) {
t.Parallel()
ch := make(chan model.Event, 4)
input := strings.NewReader("err1\nerr2\n")
readPipe(input, model.EventTypeProcessStderr, ch)
events := drainEvents(t, ch, 2, time.Second)
if events[0].Type != model.EventTypeProcessStderr || events[0].Body != "err1" {
t.Fatalf("event[0] = %#v, want stderr err1", events[0])
}
if events[1].Type != model.EventTypeProcessStderr || events[1].Body != "err2" {
t.Fatalf("event[1] = %#v, want stderr err2", events[1])
}
})
}
func TestUpdateStatus(t *testing.T) {
t.Parallel()
t.Run("nil process is no-op", func(t *testing.T) {
t.Parallel()
UpdateStatus(nil, true, make(chan model.Event, 1))
})
t.Run("state unchanged is no-op", func(t *testing.T) {
t.Parallel()
ch := make(chan model.Event, 1)
proc := &model.Process{Running: true}
UpdateStatus(proc, true, ch)
select {
case ev := <-ch:
t.Fatalf("unexpected event for unchanged state: %#v", ev)
default:
}
})
t.Run("emits started and stopped events with pid", func(t *testing.T) {
t.Parallel()
ch := make(chan model.Event, 2)
proc := &model.Process{
Exec: &exec.Cmd{Process: &os.Process{Pid: 4321}},
}
UpdateStatus(proc, true, ch)
events := drainEvents(t, ch, 1, time.Second)
started := events[0]
if started.Type != model.EventTypeProcessStarted {
t.Fatalf("started type = %s, want %s", started.Type, model.EventTypeProcessStarted)
}
if started.PID != 4321 {
t.Fatalf("started PID = %d, want %d", started.PID, 4321)
}
if !proc.Running {
t.Fatal("proc.Running = false, want true")
}
UpdateStatus(proc, false, ch)
events = drainEvents(t, ch, 1, time.Second)
stopped := events[0]
if stopped.Type != model.EventTypeProcessExited {
t.Fatalf("stopped type = %s, want %s", stopped.Type, model.EventTypeProcessExited)
}
if stopped.PID != 4321 {
t.Fatalf("stopped PID = %d, want %d", stopped.PID, 4321)
}
if proc.Running {
t.Fatal("proc.Running = true, want false")
}
})
}
func TestNewProcess(t *testing.T) {
t.Parallel()
ch := make(chan model.Event, 8)
cmd := model.Command{Name: "sh", Args: []string{"-c", "printf test"}}
proc := NewProcess(cmd, "127.0.0.1:8080", ch)
if proc == nil {
t.Fatal("NewProcess() returned nil")
}
if !reflect.DeepEqual(proc.Command, cmd) {
t.Fatalf("process command = %#v, want %#v", proc.Command, cmd)
}
if proc.Exec == nil {
t.Fatal("process Exec is nil")
}
if proc.Running {
t.Fatal("new process should not be running")
}
if proc.Done == nil {
t.Fatal("Done channel is nil")
}
if got, want := proc.Exec.Args[0], "sh"; got != want {
t.Fatalf("Exec.Args[0] = %q, want %q", got, want)
}
if !containsEnvEntry(proc.Exec.Env, "HTTP_PROXY=http://127.0.0.1:8080") {
t.Fatal("process env missing injected HTTP_PROXY")
}
}
func containsEnvEntry(env []string, want string) bool {
for _, entry := range env {
if entry == want {
return true
}
}
return false
}
func drainEvents(t *testing.T, ch <-chan model.Event, n int, timeout time.Duration) []model.Event {
t.Helper()
events := make([]model.Event, 0, n)
deadline := time.After(timeout)
for len(events) < n {
select {
case ev := <-ch:
events = append(events, ev)
case <-deadline:
t.Fatalf("timeout waiting for %d events; got %d", n, len(events))
}
}
return events
}

View File

@ -0,0 +1,147 @@
//go:build unix
package process
import (
"errors"
"os"
"os/exec"
"syscall"
"testing"
"time"
)
type customSignal struct{}
func (customSignal) String() string { return "custom" }
func (customSignal) Signal() {}
// NOTE: Run these tests with -race in CI for signal/process safety.
func TestConfigureProcessForSignals(t *testing.T) {
t.Parallel()
cmd := exec.Command("sh", "-c", "sleep 0.1")
configureProcessForSignals(cmd)
if cmd.SysProcAttr == nil {
t.Fatal("SysProcAttr is nil")
}
if !cmd.SysProcAttr.Setpgid {
t.Fatal("Setpgid = false, want true")
}
}
func TestSignalProcess_NilSafe(t *testing.T) {
t.Parallel()
if err := SignalProcess(nil, syscall.SIGTERM); err != nil {
t.Fatalf("SignalProcess(nil) error = %v, want nil", err)
}
cmd := &exec.Cmd{}
if err := SignalProcess(cmd, syscall.SIGTERM); err != nil {
t.Fatalf("SignalProcess(cmd without process) error = %v, want nil", err)
}
cmd.Process = &os.Process{Pid: 0}
if err := SignalProcess(cmd, syscall.SIGTERM); err != nil {
t.Fatalf("SignalProcess(pid<=0) error = %v, want nil", err)
}
}
func TestSignalProcess_ESRCHIsTreatedAsSuccess(t *testing.T) {
t.Parallel()
cmd := &exec.Cmd{Process: &os.Process{Pid: 999999}}
if err := SignalProcess(cmd, syscall.SIGTERM); err != nil {
t.Fatalf("SignalProcess() error = %v, want nil when process group not found", err)
}
}
func TestSignalProcess_UsesFallbackOnKillError(t *testing.T) {
t.Parallel()
cmd := exec.Command("sh", "-c", "sleep 5")
configureProcessForSignals(cmd)
if err := cmd.Start(); err != nil {
t.Fatalf("start command error = %v", err)
}
t.Cleanup(func() {
_ = cmd.Process.Kill()
_, _ = cmd.Process.Wait()
})
// Invalid signal causes syscall.Kill to fail with EINVAL, then fallback to cmd.Process.Signal.
err := SignalProcess(cmd, syscall.Signal(9999))
if err == nil {
t.Fatal("SignalProcess() error = nil, want non-nil for invalid signal")
}
if !(errors.Is(err, syscall.EINVAL) || errors.Is(err, os.ErrProcessDone)) {
// OS/process timing can vary; ensure we at least failed predictably.
t.Fatalf("SignalProcess() unexpected error: %v", err)
}
}
func TestSignalProcess_NonSyscallSignalUsesProcessSignal(t *testing.T) {
t.Parallel()
cmd := exec.Command("sh", "-c", "sleep 1")
configureProcessForSignals(cmd)
if err := cmd.Start(); err != nil {
t.Fatalf("start command error = %v", err)
}
t.Cleanup(func() {
_ = cmd.Process.Kill()
_, _ = cmd.Process.Wait()
})
err := SignalProcess(cmd, customSignal{})
if err == nil {
t.Fatal("SignalProcess(custom signal) error = nil, want non-nil")
}
if errors.Is(err, os.ErrProcessDone) {
return
}
if msg := err.Error(); msg == "" {
t.Fatalf("unexpected empty error for custom signal: %v", err)
}
}
func TestProcessAlive(t *testing.T) {
t.Parallel()
if ProcessAlive(nil) {
t.Fatal("ProcessAlive(nil) = true, want false")
}
if ProcessAlive(&exec.Cmd{}) {
t.Fatal("ProcessAlive(cmd without process) = true, want false")
}
cmd := exec.Command("sh", "-c", "sleep 0.2")
configureProcessForSignals(cmd)
if err := cmd.Start(); err != nil {
t.Fatalf("start command error = %v", err)
}
if !ProcessAlive(cmd) {
_ = cmd.Process.Kill()
_, _ = cmd.Process.Wait()
t.Fatal("ProcessAlive(running) = false, want true")
}
_ = cmd.Process.Kill()
_, _ = cmd.Process.Wait()
deadline := time.After(time.Second)
for {
if !ProcessAlive(cmd) {
return
}
select {
case <-deadline:
t.Fatal("ProcessAlive(exited) stayed true")
default:
}
}
}

View File

@ -0,0 +1,242 @@
package proxy
import (
"crypto/tls"
"encoding/pem"
"os"
"path/filepath"
"strings"
"testing"
)
func TestLoadOrCreateCertificateAuthority_RecreatesWhenKeyMissing(t *testing.T) {
configRoot := t.TempDir()
t.Setenv("XDG_CONFIG_HOME", configRoot)
baseDir := filepath.Join(configRoot, caDirName)
if err := os.MkdirAll(baseDir, 0o700); err != nil {
t.Fatalf("MkdirAll() error = %v", err)
}
certPath := filepath.Join(baseDir, caCertName)
if err := os.WriteFile(certPath, []byte("stale-cert"), 0o600); err != nil {
t.Fatalf("WriteFile(cert) error = %v", err)
}
ca, err := loadOrCreateCertificateAuthority()
if err != nil {
t.Fatalf("loadOrCreateCertificateAuthority() error = %v", err)
}
if !ca.WasCreated() {
t.Fatal("WasCreated = false, want true when key is missing")
}
if _, err := os.Stat(filepath.Join(baseDir, caKeyName)); err != nil {
t.Fatalf("expected key file to be created, stat error = %v", err)
}
}
func TestLoadOrCreateCertificateAuthority_LoadErrorOnCorruptFiles(t *testing.T) {
configRoot := t.TempDir()
t.Setenv("XDG_CONFIG_HOME", configRoot)
baseDir := filepath.Join(configRoot, caDirName)
if err := os.MkdirAll(baseDir, 0o700); err != nil {
t.Fatalf("MkdirAll() error = %v", err)
}
certPath := filepath.Join(baseDir, caCertName)
keyPath := filepath.Join(baseDir, caKeyName)
if err := os.WriteFile(certPath, []byte("not-a-pem"), 0o600); err != nil {
t.Fatalf("WriteFile(cert) error = %v", err)
}
if err := os.WriteFile(keyPath, []byte("not-a-pem"), 0o600); err != nil {
t.Fatalf("WriteFile(key) error = %v", err)
}
_, err := loadOrCreateCertificateAuthority()
if err == nil {
t.Fatal("loadOrCreateCertificateAuthority() error = nil, want non-nil")
}
}
func TestCertificateAuthorityLoad_ErrorPaths(t *testing.T) {
t.Parallel()
tests := []struct {
name string
certBytes []byte
keyBytes []byte
wantPart string
}{
{
name: "invalid cert pem",
certBytes: []byte("bad-cert"),
keyBytes: []byte("bad-key"),
wantPart: "decode ca cert pem",
},
{
name: "parse cert fails",
certBytes: pem.EncodeToMemory(&pem.Block{Type: "CERTIFICATE", Bytes: []byte("bogus")}),
keyBytes: []byte("bad-key"),
wantPart: "parse ca cert",
},
{
name: "invalid key pem",
certBytes: pem.EncodeToMemory(&pem.Block{Type: "CERTIFICATE", Bytes: newTestCA(t).cert.Raw}),
keyBytes: []byte("bad-key"),
wantPart: "decode ca key pem",
},
{
name: "parse key fails",
certBytes: pem.EncodeToMemory(&pem.Block{Type: "CERTIFICATE", Bytes: newTestCA(t).cert.Raw}),
keyBytes: pem.EncodeToMemory(&pem.Block{Type: "EC PRIVATE KEY", Bytes: []byte("bogus")}),
wantPart: "parse ca key",
},
}
for _, tt := range tests {
tt := tt
t.Run(tt.name, func(t *testing.T) {
t.Parallel()
dir := t.TempDir()
ca := &CertificateAuthority{
certPath: filepath.Join(dir, caCertName),
keyPath: filepath.Join(dir, caKeyName),
leafCert: make(map[string]*tls.Certificate),
}
if err := os.WriteFile(ca.certPath, tt.certBytes, 0o600); err != nil {
t.Fatalf("write cert file error = %v", err)
}
if err := os.WriteFile(ca.keyPath, tt.keyBytes, 0o600); err != nil {
t.Fatalf("write key file error = %v", err)
}
err := ca.load()
if err == nil {
t.Fatal("load() error = nil, want non-nil")
}
if !strings.Contains(err.Error(), tt.wantPart) {
t.Fatalf("load() error = %q, want contains %q", err.Error(), tt.wantPart)
}
})
}
}
func TestCertificateAuthorityCreate_ErrorWhenWritePathInvalid(t *testing.T) {
t.Parallel()
ca := &CertificateAuthority{
certPath: filepath.Join("/nope", "missing", "ca-cert.pem"),
keyPath: filepath.Join("/nope", "missing", "ca-key.pem"),
leafCert: make(map[string]*tls.Certificate),
}
err := ca.create()
if err == nil {
t.Fatal("create() error = nil, want non-nil")
}
}
func TestCertificateAuthorityCreate_WriteErrorPaths(t *testing.T) {
t.Parallel()
t.Run("write ca cert wraps error", func(t *testing.T) {
t.Parallel()
dir := t.TempDir()
badCertPath := filepath.Join(dir, "cert-as-dir")
if err := os.MkdirAll(badCertPath, 0o700); err != nil {
t.Fatalf("MkdirAll(cert dir) error = %v", err)
}
ca := &CertificateAuthority{
certPath: badCertPath,
keyPath: filepath.Join(dir, "ca-key.pem"),
leafCert: make(map[string]*tls.Certificate),
}
err := ca.create()
if err == nil {
t.Fatal("create() error = nil, want non-nil")
}
if !strings.Contains(err.Error(), "write ca cert") {
t.Fatalf("create() error = %q, want contains %q", err.Error(), "write ca cert")
}
})
t.Run("write ca key wraps error", func(t *testing.T) {
t.Parallel()
dir := t.TempDir()
badKeyPath := filepath.Join(dir, "key-as-dir")
if err := os.MkdirAll(badKeyPath, 0o700); err != nil {
t.Fatalf("MkdirAll(key dir) error = %v", err)
}
ca := &CertificateAuthority{
certPath: filepath.Join(dir, "ca-cert.pem"),
keyPath: badKeyPath,
leafCert: make(map[string]*tls.Certificate),
}
err := ca.create()
if err == nil {
t.Fatal("create() error = nil, want non-nil")
}
if !strings.Contains(err.Error(), "write ca key") {
t.Fatalf("create() error = %q, want contains %q", err.Error(), "write ca key")
}
})
}
func TestCertificateAuthorityLoad_ReadErrorPaths(t *testing.T) {
t.Parallel()
t.Run("read cert failure", func(t *testing.T) {
t.Parallel()
dir := t.TempDir()
ca := &CertificateAuthority{
certPath: filepath.Join(dir, "missing-cert.pem"),
keyPath: filepath.Join(dir, "missing-key.pem"),
leafCert: make(map[string]*tls.Certificate),
}
err := ca.load()
if err == nil {
t.Fatal("load() error = nil, want non-nil")
}
if !strings.Contains(err.Error(), "read ca cert") {
t.Fatalf("load() error = %q, want contains %q", err.Error(), "read ca cert")
}
})
t.Run("read key failure", func(t *testing.T) {
t.Parallel()
dir := t.TempDir()
goodCA := newTestCA(t)
certPEM := pem.EncodeToMemory(&pem.Block{Type: "CERTIFICATE", Bytes: goodCA.cert.Raw})
certPath := filepath.Join(dir, "cert.pem")
if err := os.WriteFile(certPath, certPEM, 0o600); err != nil {
t.Fatalf("WriteFile(cert) error = %v", err)
}
ca := &CertificateAuthority{
certPath: certPath,
keyPath: filepath.Join(dir, "missing-key.pem"),
leafCert: make(map[string]*tls.Certificate),
}
err := ca.load()
if err == nil {
t.Fatal("load() error = nil, want non-nil")
}
if !strings.Contains(err.Error(), "read ca key") {
t.Fatalf("load() error = %q, want contains %q", err.Error(), "read ca key")
}
})
}

View File

@ -0,0 +1,45 @@
package proxy
import (
"os"
"testing"
)
func TestLoadOrCreateCertificateAuthority_CreateThenLoad(t *testing.T) {
configRoot := t.TempDir()
t.Setenv("XDG_CONFIG_HOME", configRoot)
ca1, err := loadOrCreateCertificateAuthority()
if err != nil {
t.Fatalf("first loadOrCreateCertificateAuthority() error = %v", err)
}
if ca1 == nil {
t.Fatal("first CA is nil")
}
if !ca1.WasCreated() {
t.Fatal("first CA should report WasCreated=true")
}
if ca1.CertPath() == "" {
t.Fatal("first CA CertPath is empty")
}
if _, err := os.Stat(ca1.CertPath()); err != nil {
t.Fatalf("first CA cert file missing: %v", err)
}
ca2, err := loadOrCreateCertificateAuthority()
if err != nil {
t.Fatalf("second loadOrCreateCertificateAuthority() error = %v", err)
}
if ca2 == nil {
t.Fatal("second CA is nil")
}
if ca2.WasCreated() {
t.Fatal("second CA should report WasCreated=false (loaded existing)")
}
if ca2.CertPath() != ca1.CertPath() {
t.Fatalf("cert path mismatch: first=%q second=%q", ca1.CertPath(), ca2.CertPath())
}
if ca2.cert == nil || ca2.key == nil {
t.Fatal("loaded CA should include cert and key")
}
}

View File

@ -0,0 +1,266 @@
package proxy
import (
"errors"
"math/big"
"os"
"path/filepath"
"testing"
)
func TestNormalizeCertHost(t *testing.T) {
t.Parallel()
tests := []struct {
name string
in string
want string
}{
{name: "host and port", in: "example.com:443", want: "example.com"},
{name: "plain host", in: "example.com", want: "example.com"},
{name: "whitespace trims", in: " example.com:8443 ", want: "example.com"},
{name: "empty", in: " ", want: ""},
}
for _, tt := range tests {
tt := tt
t.Run(tt.name, func(t *testing.T) {
t.Parallel()
if got := normalizeCertHost(tt.in); got != tt.want {
t.Fatalf("normalizeCertHost(%q) = %q, want %q", tt.in, got, tt.want)
}
})
}
}
func TestRandSerialNumber(t *testing.T) {
t.Parallel()
serial, err := randSerialNumber()
if err != nil {
t.Fatalf("randSerialNumber() error = %v", err)
}
if serial == nil {
t.Fatal("serial is nil")
}
if serial.Sign() < 0 {
t.Fatalf("serial must be non-negative, got %v", serial)
}
limit := new(big.Int).Lsh(big.NewInt(1), 128)
if serial.Cmp(limit) >= 0 {
t.Fatalf("serial must be < 2^128, got %v", serial)
}
}
func TestWriteFileAtomically(t *testing.T) {
t.Parallel()
dir := t.TempDir()
path := filepath.Join(dir, "cert.pem")
if err := writeFileAtomically(path, []byte("first"), 0o600); err != nil {
t.Fatalf("first writeFileAtomically() error = %v", err)
}
if err := writeFileAtomically(path, []byte("second"), 0o600); err != nil {
t.Fatalf("second writeFileAtomically() error = %v", err)
}
data, err := os.ReadFile(path)
if err != nil {
t.Fatalf("ReadFile() error = %v", err)
}
if got, want := string(data), "second"; got != want {
t.Fatalf("file contents = %q, want %q", got, want)
}
info, err := os.Stat(path)
if err != nil {
t.Fatalf("Stat() error = %v", err)
}
if got := info.Mode().Perm(); got != 0o600 {
t.Fatalf("file permissions = %#o, want %#o", got, 0o600)
}
}
func TestCertificateAuthority_Basics(t *testing.T) {
t.Parallel()
var nilCA *CertificateAuthority
if got := nilCA.CertPath(); got != "" {
t.Fatalf("nil CertPath() = %q, want empty", got)
}
if got := nilCA.WasCreated(); got {
t.Fatalf("nil WasCreated() = %v, want false", got)
}
ca := newTestCA(t)
ca.certPath = "/tmp/test-ca.pem"
ca.wasCreated = true
if got, want := ca.CertPath(), "/tmp/test-ca.pem"; got != want {
t.Fatalf("CertPath() = %q, want %q", got, want)
}
if !ca.WasCreated() {
t.Fatal("WasCreated() = false, want true")
}
}
func TestCertificateForHost(t *testing.T) {
t.Parallel()
ca := newTestCA(t)
t.Run("empty host returns error", func(t *testing.T) {
t.Parallel()
cert, err := ca.CertificateForHost(" ")
if err == nil {
t.Fatal("CertificateForHost() error = nil, want non-nil")
}
if cert != nil {
t.Fatalf("cert = %#v, want nil", cert)
}
})
t.Run("cache hit returns same pointer", func(t *testing.T) {
t.Parallel()
c1, err := ca.CertificateForHost("example.com:443")
if err != nil {
t.Fatalf("first CertificateForHost() error = %v", err)
}
c2, err := ca.CertificateForHost("example.com")
if err != nil {
t.Fatalf("second CertificateForHost() error = %v", err)
}
if c1 != c2 {
t.Fatal("expected same certificate pointer from cache")
}
})
t.Run("ip and dns SAN selection", func(t *testing.T) {
t.Parallel()
ipCert, err := ca.CertificateForHost("127.0.0.1:443")
if err != nil {
t.Fatalf("ip CertificateForHost() error = %v", err)
}
if ipCert.Leaf == nil {
t.Fatal("ip cert leaf is nil")
}
if len(ipCert.Leaf.IPAddresses) == 0 {
t.Fatal("ip cert should contain IP SAN")
}
if len(ipCert.Leaf.DNSNames) != 0 {
t.Fatalf("ip cert DNSNames = %v, want empty", ipCert.Leaf.DNSNames)
}
dnsCert, err := ca.CertificateForHost("service.local")
if err != nil {
t.Fatalf("dns CertificateForHost() error = %v", err)
}
if dnsCert.Leaf == nil {
t.Fatal("dns cert leaf is nil")
}
if len(dnsCert.Leaf.DNSNames) == 0 {
t.Fatal("dns cert should contain DNS SAN")
}
})
t.Run("evicts oldest entry over maxLeafCerts", func(t *testing.T) {
t.Parallel()
ca2 := newTestCA(t)
for i := 0; i < maxLeafCerts+1; i++ {
host := filepath.Base(filepath.Join("h", big.NewInt(int64(i)).String()+".example"))
if _, err := ca2.CertificateForHost(host); err != nil {
t.Fatalf("CertificateForHost(%q) error = %v", host, err)
}
}
if len(ca2.leafOrder) != maxLeafCerts {
t.Fatalf("leafOrder len = %d, want %d", len(ca2.leafOrder), maxLeafCerts)
}
if _, ok := ca2.leafCert["0.example"]; ok {
t.Fatal("expected oldest cert to be evicted")
}
})
}
func TestIsTrustedBySystem(t *testing.T) {
t.Parallel()
var nilCA *CertificateAuthority
_, err := nilCA.IsTrustedBySystem()
if err == nil {
t.Fatal("nil IsTrustedBySystem() error = nil, want non-nil")
}
ca := &CertificateAuthority{}
_, err = ca.IsTrustedBySystem()
if err == nil {
t.Fatal("missing-cert IsTrustedBySystem() error = nil, want non-nil")
}
t.Run("untrusted generated CA returns false without error", func(t *testing.T) {
t.Parallel()
ca := newTestCA(t)
trusted, err := ca.IsTrustedBySystem()
if err != nil {
t.Fatalf("IsTrustedBySystem() error = %v, want nil for unknown authority", err)
}
if trusted {
t.Fatal("trusted = true, want false for generated test CA")
}
})
}
func TestEnsureCertificateAuthority(t *testing.T) {
configRoot := t.TempDir()
t.Setenv("XDG_CONFIG_HOME", configRoot)
ca, err := EnsureCertificateAuthority()
if err != nil {
t.Fatalf("EnsureCertificateAuthority() error = %v", err)
}
if ca == nil {
t.Fatal("EnsureCertificateAuthority() returned nil CA")
}
if ca.CertPath() == "" {
t.Fatal("EnsureCertificateAuthority() returned empty cert path")
}
if _, statErr := os.Stat(ca.CertPath()); statErr != nil {
t.Fatalf("expected cert on disk, stat error = %v", statErr)
}
}
func TestWriteFileAtomically_ErrorPath(t *testing.T) {
t.Parallel()
err := writeFileAtomically(filepath.Join("/nope", "bad", "path.pem"), []byte("x"), 0o600)
if err == nil {
t.Fatal("writeFileAtomically() error = nil, want non-nil")
}
if errors.Is(err, os.ErrNotExist) {
return
}
// Accept platform-dependent fs errors as long as function fails.
}
func TestWriteFileAtomically_RenameErrorWhenTargetIsDirectory(t *testing.T) {
t.Parallel()
dir := t.TempDir()
targetDir := filepath.Join(dir, "target-as-dir")
if err := os.MkdirAll(targetDir, 0o700); err != nil {
t.Fatalf("MkdirAll(targetDir) error = %v", err)
}
err := writeFileAtomically(targetDir, []byte("x"), 0o600)
if err == nil {
t.Fatal("writeFileAtomically() error = nil, want non-nil")
}
}
// TODO: Add deterministic tests for loadOrCreateCertificateAuthority trust-store interactions.

View File

@ -0,0 +1,29 @@
package proxy
import (
"encoding/pem"
"os"
"path/filepath"
"testing"
)
func TestIsTrustedBySystem_TrustedViaCertEnv(t *testing.T) {
ca := newTestCA(t)
rootFile := filepath.Join(t.TempDir(), "roots.pem")
pemBytes := pem.EncodeToMemory(&pem.Block{Type: "CERTIFICATE", Bytes: ca.cert.Raw})
if err := os.WriteFile(rootFile, pemBytes, 0o600); err != nil {
t.Fatalf("WriteFile(root cert) error = %v", err)
}
t.Setenv("SSL_CERT_FILE", rootFile)
t.Setenv("SSL_CERT_DIR", t.TempDir())
trusted, err := ca.IsTrustedBySystem()
if err != nil {
t.Fatalf("IsTrustedBySystem() error = %v", err)
}
if !trusted {
t.Fatal("trusted = false, want true when CA is in configured cert file")
}
}

View File

@ -0,0 +1,330 @@
package proxy
import (
"bufio"
"bytes"
"context"
"fmt"
"io"
"net"
"net/http"
"net/http/httptest"
"strings"
"testing"
"time"
"termtap.dev/internal/model"
)
type failingResponseWriter struct {
header http.Header
code int
}
func (w *failingResponseWriter) Header() http.Header {
if w.header == nil {
w.header = make(http.Header)
}
return w.header
}
func (w *failingResponseWriter) WriteHeader(statusCode int) {
w.code = statusCode
}
func (w *failingResponseWriter) Write(_ []byte) (int, error) {
return 0, io.ErrClosedPipe
}
type hijackFailWriter struct {
header http.Header
code int
}
func (w *hijackFailWriter) Header() http.Header {
if w.header == nil {
w.header = make(http.Header)
}
return w.header
}
func (w *hijackFailWriter) WriteHeader(statusCode int) {
w.code = statusCode
}
func (w *hijackFailWriter) Write(p []byte) (int, error) {
return len(p), nil
}
func (w *hijackFailWriter) Hijack() (net.Conn, *bufio.ReadWriter, error) {
return nil, nil, fmt.Errorf("hijack failed")
}
type dummyConn struct{}
func (dummyConn) Read(_ []byte) (int, error) { return 0, io.EOF }
func (dummyConn) Write(p []byte) (int, error) { return len(p), nil }
func (dummyConn) Close() error { return nil }
func (dummyConn) LocalAddr() net.Addr { return &net.TCPAddr{} }
func (dummyConn) RemoteAddr() net.Addr { return &net.TCPAddr{} }
func (dummyConn) SetDeadline(_ time.Time) error { return nil }
func (dummyConn) SetReadDeadline(_ time.Time) error { return nil }
func (dummyConn) SetWriteDeadline(_ time.Time) error { return nil }
type writeConnectFailWriter struct {
header http.Header
code int
}
func (w *writeConnectFailWriter) Header() http.Header {
if w.header == nil {
w.header = make(http.Header)
}
return w.header
}
func (w *writeConnectFailWriter) WriteHeader(statusCode int) {
w.code = statusCode
}
func (w *writeConnectFailWriter) Write(p []byte) (int, error) {
return len(p), nil
}
func (w *writeConnectFailWriter) Hijack() (net.Conn, *bufio.ReadWriter, error) {
rw := bufio.NewReadWriter(
bufio.NewReader(strings.NewReader("")),
bufio.NewWriter(errWriter{}),
)
return dummyConn{}, rw, nil
}
func TestProxyHandler_NonConnectRejectsNonAbsoluteURL(t *testing.T) {
t.Parallel()
ch := make(chan model.Event, 8)
ps := &model.ProxyServer{Conns: make(map[net.Conn]struct{})}
h := proxyHandler(ch, nil, ps)
req := httptest.NewRequest(http.MethodGet, "/not-proxy-form", nil)
w := httptest.NewRecorder()
h.ServeHTTP(w, req)
if w.Code != http.StatusBadRequest {
t.Fatalf("status = %d, want %d", w.Code, http.StatusBadRequest)
}
events := drainEvents(t, ch, 1, time.Second)
if !hasEventType(events, model.EventTypeWarn) {
t.Fatalf("expected %s event, got %#v", model.EventTypeWarn, events)
}
}
func TestProxyHandler_NonConnectSuccess(t *testing.T) {
t.Parallel()
upstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("X-Upstream", "yes")
w.WriteHeader(http.StatusAccepted)
_, _ = w.Write([]byte("pong"))
}))
t.Cleanup(upstream.Close)
ch := make(chan model.Event, 8)
ps := &model.ProxyServer{Conns: make(map[net.Conn]struct{})}
h := proxyHandler(ch, nil, ps)
req := httptest.NewRequest(http.MethodGet, upstream.URL+"/ping", nil)
w := httptest.NewRecorder()
h.ServeHTTP(w, req)
if w.Code != http.StatusAccepted {
t.Fatalf("status = %d, want %d", w.Code, http.StatusAccepted)
}
if got, want := w.Body.String(), "pong"; got != want {
t.Fatalf("body = %q, want %q", got, want)
}
if got := w.Header().Get("X-Upstream"); got != "yes" {
t.Fatalf("X-Upstream header = %q, want yes", got)
}
events := drainEvents(t, ch, 2, time.Second)
if !hasEventType(events, model.EventTypeRequestStarted) {
t.Fatalf("expected %s event", model.EventTypeRequestStarted)
}
if !hasEventType(events, model.EventTypeRequestFinished) {
t.Fatalf("expected %s event", model.EventTypeRequestFinished)
}
}
func TestProxyHandler_NonConnectUpstreamError(t *testing.T) {
t.Parallel()
ch := make(chan model.Event, 8)
ps := &model.ProxyServer{Conns: make(map[net.Conn]struct{})}
h := proxyHandler(ch, nil, ps)
req := httptest.NewRequest(http.MethodGet, "http://127.0.0.1:1/fail", nil)
w := httptest.NewRecorder()
h.ServeHTTP(w, req)
if w.Code != http.StatusBadGateway {
t.Fatalf("status = %d, want %d", w.Code, http.StatusBadGateway)
}
events := drainEvents(t, ch, 2, time.Second)
if !hasEventType(events, model.EventTypeRequestStarted) {
t.Fatalf("expected %s event", model.EventTypeRequestStarted)
}
if !hasEventType(events, model.EventTypeRequestFailed) {
t.Fatalf("expected %s event", model.EventTypeRequestFailed)
}
}
func TestProxyHandler_NonConnectWriteFailureEmitsFailedEvent(t *testing.T) {
t.Parallel()
upstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)
_, _ = io.Copy(w, bytes.NewBufferString("response-body"))
}))
t.Cleanup(upstream.Close)
ch := make(chan model.Event, 8)
ps := &model.ProxyServer{Conns: make(map[net.Conn]struct{})}
h := proxyHandler(ch, nil, ps)
req := httptest.NewRequest(http.MethodGet, upstream.URL+"/copy-fail", nil)
w := &failingResponseWriter{}
h.ServeHTTP(w, req)
events := drainEvents(t, ch, 2, time.Second)
if !hasEventType(events, model.EventTypeRequestStarted) {
t.Fatalf("expected %s event", model.EventTypeRequestStarted)
}
if !hasEventType(events, model.EventTypeRequestFailed) {
t.Fatalf("expected %s event", model.EventTypeRequestFailed)
}
}
func TestHandleConnect_EarlyFailures(t *testing.T) {
t.Parallel()
transport := &mockTransport{fn: func(req *http.Request) (*http.Response, error) {
return nil, context.Canceled
}}
tests := []struct {
name string
writer http.ResponseWriter
req *http.Request
ca *CertificateAuthority
wantCode int
wantBody string
wantStat int
}{
{
name: "nil CA returns 502 and failed event",
writer: httptest.NewRecorder(),
req: httptest.NewRequest(http.MethodConnect, "http://example.com:443", nil),
ca: nil,
wantCode: http.StatusBadGateway,
wantBody: "certificate authority is unavailable",
wantStat: http.StatusBadGateway,
},
{
name: "nil CA with host without port still fails predictably",
writer: httptest.NewRecorder(),
req: &http.Request{Method: http.MethodConnect, Host: "example.com"},
ca: nil,
wantCode: http.StatusBadGateway,
wantBody: "certificate authority is unavailable",
wantStat: http.StatusBadGateway,
},
{
name: "cert mint failure returns 502 and failed event",
writer: httptest.NewRecorder(),
req: &http.Request{Method: http.MethodConnect, Host: ""},
ca: &CertificateAuthority{},
wantCode: http.StatusBadGateway,
wantBody: "failed to mint interception certificate",
wantStat: http.StatusBadGateway,
},
{
name: "non-hijacker writer returns 500 and failed event",
writer: httptest.NewRecorder(),
req: httptest.NewRequest(http.MethodConnect, "http://example.com:443", nil),
ca: newTestCA(t),
wantCode: http.StatusInternalServerError,
wantBody: "hijack is unavailable",
wantStat: http.StatusInternalServerError,
},
{
name: "hijack failure returns 500 and failed event",
writer: &hijackFailWriter{},
req: httptest.NewRequest(http.MethodConnect, "http://example.com:443", nil),
ca: newTestCA(t),
wantCode: http.StatusInternalServerError,
wantBody: "CONNECT hijack failed",
wantStat: http.StatusInternalServerError,
},
{
name: "write connect established failure emits failed event",
writer: &writeConnectFailWriter{},
req: httptest.NewRequest(http.MethodConnect, "http://example.com:443", nil),
ca: newTestCA(t),
wantCode: 0,
wantBody: "CONNECT setup failed",
wantStat: http.StatusBadGateway,
},
}
for _, tt := range tests {
tt := tt
t.Run(tt.name, func(t *testing.T) {
t.Parallel()
ch := make(chan model.Event, 8)
ps := &model.ProxyServer{Conns: make(map[net.Conn]struct{})}
handleConnect(tt.writer, tt.req, ch, transport, tt.ca, ps)
events := drainEvents(t, ch, 2, time.Second)
if events[0].Type != model.EventTypeRequestStarted {
t.Fatalf("expected %s event", model.EventTypeRequestStarted)
}
if events[1].Type != model.EventTypeRequestFailed {
t.Fatalf("expected %s event", model.EventTypeRequestFailed)
}
if events[1].Request.Method != http.MethodConnect {
t.Fatalf("failed event request method = %q, want CONNECT", events[1].Request.Method)
}
if events[1].Request.Status != tt.wantStat {
t.Fatalf("failed event request status = %d, want %d", events[1].Request.Status, tt.wantStat)
}
if !events[1].Request.Failed || events[1].Request.Pending {
t.Fatalf("failed event request flags = pending:%v failed:%v, want pending:false failed:true", events[1].Request.Pending, events[1].Request.Failed)
}
if !strings.Contains(events[1].Body, tt.wantBody) {
t.Fatalf("failed event body %q does not contain %q", events[1].Body, tt.wantBody)
}
if recorder, ok := tt.writer.(*httptest.ResponseRecorder); ok {
if recorder.Code != tt.wantCode {
t.Fatalf("status = %d, want %d", recorder.Code, tt.wantCode)
}
}
if w, ok := tt.writer.(*hijackFailWriter); ok {
if w.code != tt.wantCode {
t.Fatalf("status = %d, want %d", w.code, tt.wantCode)
}
}
})
}
}
// TODO: Add full TLS MITM loop integration test.

View File

@ -0,0 +1,157 @@
package proxy
import (
"net/http"
"reflect"
"testing"
)
func TestStripHopByHopHeaders(t *testing.T) {
t.Parallel()
tests := []struct {
name string
in http.Header
want http.Header
}{
{
name: "removes static hop-by-hop headers",
in: http.Header{
"Connection": {"keep-alive"},
"Keep-Alive": {"timeout=5"},
"Proxy-Connection": {"keep-alive"},
"Transfer-Encoding": {"chunked"},
"Upgrade": {"websocket"},
"X-Custom": {"ok"},
},
want: http.Header{
"X-Custom": {"ok"},
},
},
{
name: "removes headers listed in connection with spaces and commas",
in: http.Header{
"Connection": {" keep-alive, X-Foo ,X-Bar"},
"Keep-Alive": {"timeout=5"},
"X-Foo": {"foo"},
"X-Bar": {"bar"},
"Content-Type": {"application/json"},
"Content-Length": {"12"},
},
want: http.Header{
"Content-Type": {"application/json"},
"Content-Length": {"12"},
},
},
{
name: "keeps unrelated headers when no connection header",
in: http.Header{
"Accept": {"*/*"},
"Authorization": {"Bearer abc"},
},
want: http.Header{
"Accept": {"*/*"},
"Authorization": {"Bearer abc"},
},
},
}
for _, tt := range tests {
tt := tt
t.Run(tt.name, func(t *testing.T) {
t.Parallel()
stripHopByHopHeaders(tt.in)
if !reflect.DeepEqual(tt.in, tt.want) {
t.Fatalf("stripHopByHopHeaders() = %#v, want %#v", tt.in, tt.want)
}
})
}
}
func TestStripHopByHopHeaders_NilHeader(t *testing.T) {
t.Parallel()
stripHopByHopHeaders(nil)
}
func TestRedactHeaders(t *testing.T) {
t.Parallel()
tests := []struct {
name string
input http.Header
want http.Header
}{
{
name: "redacts sensitive canonical headers",
input: http.Header{
"Authorization": {"Bearer token123"},
"Cookie": {"session=abc"},
"Proxy-Authorization": {"Basic abc"},
"Set-Cookie": {"a=1"},
"X-Api-Key": {"secret"},
},
want: http.Header{
"Authorization": {"[REDACTED]"},
"Cookie": {"[REDACTED]"},
"Proxy-Authorization": {"[REDACTED]"},
"Set-Cookie": {"[REDACTED]"},
"X-Api-Key": {"[REDACTED]"},
},
},
{
name: "leaves non-sensitive headers untouched",
input: http.Header{
"Content-Type": {"application/json"},
"X-Trace-ID": {"trace-1"},
},
want: http.Header{
"Content-Type": {"application/json"},
"X-Trace-ID": {"trace-1"},
},
},
}
for _, tt := range tests {
tt := tt
t.Run(tt.name, func(t *testing.T) {
t.Parallel()
got := redactHeaders(tt.input)
if !reflect.DeepEqual(got, tt.want) {
t.Fatalf("redactHeaders() = %#v, want %#v", got, tt.want)
}
got.Set("Content-Type", "modified")
if reflect.DeepEqual(tt.input, got) {
t.Fatal("redactHeaders() appears to mutate input or return aliased map")
}
})
}
}
func TestCopyHeaders(t *testing.T) {
t.Parallel()
src := http.Header{
"X-Multi": {"a", "b", "c"},
"Content-Type": {"application/json"},
}
dest := http.Header{
"Existing": {"keep"},
}
copyHeaders(src, dest)
want := http.Header{
"Existing": {"keep"},
"X-Multi": {"a", "b", "c"},
"Content-Type": {"application/json"},
}
if !reflect.DeepEqual(dest, want) {
t.Fatalf("copyHeaders() dest = %#v, want %#v", dest, want)
}
}

View File

@ -0,0 +1,129 @@
package proxy
import (
"io"
"net/http"
"net/http/httptest"
"net/url"
"testing"
"time"
"termtap.dev/internal/model"
)
func TestHTTPProxyE2E_WithClientProxyConfig(t *testing.T) {
configRoot := t.TempDir()
t.Setenv("XDG_CONFIG_HOME", configRoot)
upstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)
_, _ = w.Write([]byte("pong"))
}))
t.Cleanup(upstream.Close)
ch := make(chan model.Event, 64)
ps, err := NewProxyServer("127.0.0.1:0", ch)
if err != nil {
t.Fatalf("NewProxyServer() error = %v", err)
}
serveDone := make(chan error, 1)
go func() {
serveDone <- ps.Server.Serve(*ps.Listener)
}()
t.Cleanup(func() {
Destroy(ps, ch)
select {
case <-serveDone:
case <-time.After(2 * time.Second):
t.Fatal("timeout waiting for proxy server to stop")
}
})
proxyURL, err := url.Parse(ps.Url)
if err != nil {
t.Fatalf("url.Parse(proxy) error = %v", err)
}
client := &http.Client{
Transport: &http.Transport{Proxy: http.ProxyURL(proxyURL)},
Timeout: 3 * time.Second,
}
resp, err := client.Get(upstream.URL + "/ping")
if err != nil {
t.Fatalf("client.Get() error = %v", err)
}
defer resp.Body.Close()
body, err := io.ReadAll(resp.Body)
if err != nil {
t.Fatalf("ReadAll(response) error = %v", err)
}
if resp.StatusCode != http.StatusOK {
t.Fatalf("status = %d, want %d", resp.StatusCode, http.StatusOK)
}
if got, want := string(body), "pong"; got != want {
t.Fatalf("body = %q, want %q", got, want)
}
events := drainEvents(t, ch, 2, 2*time.Second)
if !hasEventType(events, model.EventTypeRequestStarted) {
t.Fatalf("missing %s event", model.EventTypeRequestStarted)
}
if !hasEventType(events, model.EventTypeRequestFinished) {
t.Fatalf("missing %s event", model.EventTypeRequestFinished)
}
}
func TestHTTPProxyE2E_UpstreamFailureEmitsRequestFailed(t *testing.T) {
configRoot := t.TempDir()
t.Setenv("XDG_CONFIG_HOME", configRoot)
ch := make(chan model.Event, 64)
ps, err := NewProxyServer("127.0.0.1:0", ch)
if err != nil {
t.Fatalf("NewProxyServer() error = %v", err)
}
serveDone := make(chan error, 1)
go func() {
serveDone <- ps.Server.Serve(*ps.Listener)
}()
t.Cleanup(func() {
Destroy(ps, ch)
select {
case <-serveDone:
case <-time.After(2 * time.Second):
t.Fatal("timeout waiting for proxy server to stop")
}
})
proxyURL, err := url.Parse(ps.Url)
if err != nil {
t.Fatalf("url.Parse(proxy) error = %v", err)
}
client := &http.Client{
Transport: &http.Transport{Proxy: http.ProxyURL(proxyURL)},
Timeout: 3 * time.Second,
}
resp, reqErr := client.Get("http://127.0.0.1:1/unreachable")
if reqErr != nil {
t.Fatalf("client.Get() error = %v; proxy should reply with mapped HTTP status", reqErr)
}
defer resp.Body.Close()
if resp.StatusCode != http.StatusBadGateway {
t.Fatalf("status = %d, want %d", resp.StatusCode, http.StatusBadGateway)
}
events := drainEvents(t, ch, 2, 3*time.Second)
if !hasEventType(events, model.EventTypeRequestStarted) {
t.Fatalf("missing %s event", model.EventTypeRequestStarted)
}
if !hasEventType(events, model.EventTypeRequestFailed) {
t.Fatalf("missing %s event", model.EventTypeRequestFailed)
}
}

View File

@ -0,0 +1,301 @@
package proxy
import (
"bufio"
"crypto/tls"
"crypto/x509"
"errors"
"io"
"net"
"net/http"
"strings"
"testing"
"time"
"termtap.dev/internal/model"
)
type hijackPipeWriter struct {
conn net.Conn
header http.Header
code int
}
func (w *hijackPipeWriter) Header() http.Header {
if w.header == nil {
w.header = make(http.Header)
}
return w.header
}
func (w *hijackPipeWriter) WriteHeader(statusCode int) {
w.code = statusCode
}
func (w *hijackPipeWriter) Write(p []byte) (int, error) {
return len(p), nil
}
func (w *hijackPipeWriter) Hijack() (net.Conn, *bufio.ReadWriter, error) {
return w.conn, nil, nil
}
func startMITMConnect(t *testing.T, transport http.RoundTripper) (net.Conn, *tls.Conn, chan model.Event, chan struct{}) {
t.Helper()
ca := newTestCA(t)
clientConn, serverConn := net.Pipe()
writer := &hijackPipeWriter{conn: serverConn}
req, err := http.NewRequest(http.MethodConnect, "http://example.com:443", nil)
if err != nil {
t.Fatalf("NewRequest(CONNECT) error = %v", err)
}
req.Host = "example.com:443"
ch := make(chan model.Event, 16)
ps := &model.ProxyServer{Conns: make(map[net.Conn]struct{})}
handleDone := make(chan struct{})
go func() {
handleConnect(writer, req, ch, transport, ca, ps)
close(handleDone)
}()
reader := bufio.NewReader(clientConn)
connectResp, err := http.ReadResponse(reader, &http.Request{Method: http.MethodConnect})
if err != nil {
_ = clientConn.Close()
_ = serverConn.Close()
t.Fatalf("ReadResponse(CONNECT established) error = %v", err)
}
if connectResp.StatusCode != http.StatusOK {
_ = clientConn.Close()
_ = serverConn.Close()
t.Fatalf("CONNECT established status = %d, want %d", connectResp.StatusCode, http.StatusOK)
}
pool := x509.NewCertPool()
pool.AddCert(ca.cert)
tlsClient := tls.Client(clientConn, &tls.Config{
ServerName: "example.com",
RootCAs: pool,
MinVersion: tls.VersionTLS12,
})
if err := tlsClient.Handshake(); err != nil {
_ = tlsClient.Close()
_ = serverConn.Close()
t.Fatalf("tls handshake error = %v", err)
}
t.Cleanup(func() {
_ = tlsClient.Close()
_ = serverConn.Close()
})
return clientConn, tlsClient, ch, handleDone
}
func TestHTTPSE2E_MITMHandleConnectFlow(t *testing.T) {
transport := &mockTransport{fn: func(r *http.Request) (*http.Response, error) {
if r.URL.Scheme != "https" {
t.Fatalf("inner request scheme = %q, want https", r.URL.Scheme)
}
if r.URL.Host == "" {
t.Fatal("inner request host should be populated")
}
return &http.Response{
StatusCode: http.StatusCreated,
ProtoMajor: 1,
ProtoMinor: 1,
Header: http.Header{"Content-Type": {"text/plain"}},
Body: io.NopCloser(strings.NewReader("mitm-ok")),
}, nil
}}
clientConn, tlsClient, ch, handleDone := startMITMConnect(t, transport)
// 3) Send decrypted inner request over tunnel and read MITM response
innerReq, err := http.NewRequest(http.MethodGet, "https://example.com/inside", nil)
if err != nil {
t.Fatalf("NewRequest(inner) error = %v", err)
}
innerReq.Host = "example.com"
innerReq.Close = true
if err := innerReq.Write(tlsClient); err != nil {
t.Fatalf("innerReq.Write() error = %v", err)
}
innerResp, err := http.ReadResponse(bufio.NewReader(tlsClient), innerReq)
if err != nil {
t.Fatalf("ReadResponse(inner) error = %v", err)
}
defer innerResp.Body.Close()
body, err := io.ReadAll(innerResp.Body)
if err != nil {
t.Fatalf("ReadAll(inner body) error = %v", err)
}
if innerResp.StatusCode != http.StatusCreated {
t.Fatalf("inner status = %d, want %d", innerResp.StatusCode, http.StatusCreated)
}
if got, want := string(body), "mitm-ok"; got != want {
t.Fatalf("inner body = %q, want %q", got, want)
}
_ = tlsClient.Close()
_ = clientConn.Close()
select {
case <-handleDone:
case <-time.After(3 * time.Second):
t.Fatal("timeout waiting for handleConnect to return")
}
events := drainEvents(t, ch, 4, 2*time.Second)
if events[0].Type != model.EventTypeRequestStarted {
t.Fatalf("event[0] = %s, want %s", events[0].Type, model.EventTypeRequestStarted)
}
if events[1].Type != model.EventTypeRequestStarted {
t.Fatalf("event[1] = %s, want %s", events[1].Type, model.EventTypeRequestStarted)
}
if events[2].Type != model.EventTypeRequestFinished {
t.Fatalf("event[2] = %s, want %s", events[2].Type, model.EventTypeRequestFinished)
}
if events[3].Type != model.EventTypeRequestFinished {
t.Fatalf("event[3] = %s, want %s", events[3].Type, model.EventTypeRequestFinished)
}
}
func TestHTTPSE2E_MITMUpstreamErrorReturnsHTTPErrorInsideTunnel(t *testing.T) {
transport := &mockTransport{fn: func(*http.Request) (*http.Response, error) {
return nil, errors.New("upstream exploded")
}}
clientConn, tlsClient, ch, handleDone := startMITMConnect(t, transport)
innerReq, err := http.NewRequest(http.MethodGet, "https://example.com/fail", nil)
if err != nil {
t.Fatalf("NewRequest(inner) error = %v", err)
}
innerReq.Host = "example.com"
if err := innerReq.Write(tlsClient); err != nil {
t.Fatalf("innerReq.Write() error = %v", err)
}
innerResp, err := http.ReadResponse(bufio.NewReader(tlsClient), innerReq)
if err != nil {
t.Fatalf("ReadResponse(inner) error = %v", err)
}
defer innerResp.Body.Close()
if innerResp.StatusCode != http.StatusBadGateway {
t.Fatalf("inner status = %d, want %d", innerResp.StatusCode, http.StatusBadGateway)
}
_ = tlsClient.Close()
_ = clientConn.Close()
select {
case <-handleDone:
case <-time.After(3 * time.Second):
t.Fatal("timeout waiting for handleConnect to return")
}
events := drainEvents(t, ch, 4, 2*time.Second)
if events[0].Type != model.EventTypeRequestStarted ||
events[1].Type != model.EventTypeRequestStarted ||
events[2].Type != model.EventTypeRequestFailed ||
events[3].Type != model.EventTypeRequestFailed {
t.Fatalf("unexpected event sequence: %#v", events)
}
}
func TestHTTPSE2E_MITMHandshakeFailureEmitsConnectFailed(t *testing.T) {
ca := newTestCA(t)
clientConn, serverConn := net.Pipe()
t.Cleanup(func() {
_ = clientConn.Close()
_ = serverConn.Close()
})
writer := &hijackPipeWriter{conn: serverConn}
req, err := http.NewRequest(http.MethodConnect, "http://example.com:443", nil)
if err != nil {
t.Fatalf("NewRequest(CONNECT) error = %v", err)
}
req.Host = "example.com:443"
ch := make(chan model.Event, 16)
ps := &model.ProxyServer{Conns: make(map[net.Conn]struct{})}
transport := &mockTransport{fn: func(*http.Request) (*http.Response, error) {
return nil, errors.New("should not be called")
}}
handleDone := make(chan struct{})
go func() {
handleConnect(writer, req, ch, transport, ca, ps)
close(handleDone)
}()
reader := bufio.NewReader(clientConn)
connectResp, err := http.ReadResponse(reader, &http.Request{Method: http.MethodConnect})
if err != nil {
t.Fatalf("ReadResponse(CONNECT established) error = %v", err)
}
if connectResp.StatusCode != http.StatusOK {
t.Fatalf("CONNECT established status = %d, want %d", connectResp.StatusCode, http.StatusOK)
}
// Write plaintext (not TLS handshake) to force handshake failure.
if _, err := clientConn.Write([]byte("not tls handshake")); err != nil {
t.Fatalf("client write error = %v", err)
}
_ = clientConn.Close()
select {
case <-handleDone:
case <-time.After(3 * time.Second):
t.Fatal("timeout waiting for handleConnect to return on handshake failure")
}
events := drainEvents(t, ch, 2, 2*time.Second)
if events[0].Type != model.EventTypeRequestStarted || events[1].Type != model.EventTypeRequestFailed {
t.Fatalf("unexpected event sequence: %#v", events)
}
if !strings.Contains(events[1].Body, "TLS handshake with client failed") {
t.Fatalf("failed event body = %q, want handshake failure details", events[1].Body)
}
}
func TestHTTPSE2E_MITMDecryptedReadFailureEmitsConnectFailed(t *testing.T) {
transport := &mockTransport{fn: func(*http.Request) (*http.Response, error) {
return nil, errors.New("should not be called")
}}
clientConn, tlsClient, ch, handleDone := startMITMConnect(t, transport)
// Send malformed decrypted payload so ReadRequest fails (non-EOF branch).
if _, err := tlsClient.Write([]byte("BAD REQUEST\r\n\r\n")); err != nil {
t.Fatalf("tlsClient.Write() error = %v", err)
}
_ = tlsClient.Close()
_ = clientConn.Close()
select {
case <-handleDone:
case <-time.After(3 * time.Second):
t.Fatal("timeout waiting for handleConnect to return on decrypted read failure")
}
events := drainEvents(t, ch, 2, 2*time.Second)
if events[0].Type != model.EventTypeRequestStarted || events[1].Type != model.EventTypeRequestFailed {
t.Fatalf("unexpected event sequence: %#v", events)
}
if !strings.Contains(events[1].Body, "failed to read decrypted HTTPS request") {
t.Fatalf("failed event body = %q, want read failure details", events[1].Body)
}
}

View File

@ -0,0 +1,142 @@
package proxy
import (
"bufio"
"io"
"net"
"strings"
"testing"
)
func TestNewBodyPreview(t *testing.T) {
t.Parallel()
tests := []struct {
name string
contentType string
wantEnabled bool
}{
{name: "text content enabled", contentType: "text/plain", wantEnabled: true},
{name: "json content enabled", contentType: "application/json", wantEnabled: true},
{name: "binary content disabled", contentType: "application/octet-stream", wantEnabled: false},
{name: "empty content type disabled", contentType: "", wantEnabled: false},
}
for _, tt := range tests {
tt := tt
t.Run(tt.name, func(t *testing.T) {
t.Parallel()
p := newBodyPreview(tt.contentType)
if p.enabled != tt.wantEnabled {
t.Fatalf("newBodyPreview(%q).enabled = %v, want %v", tt.contentType, p.enabled, tt.wantEnabled)
}
})
}
}
func TestBodyPreviewWriteAndPreview(t *testing.T) {
t.Parallel()
t.Run("nil receiver is safe", func(t *testing.T) {
t.Parallel()
var p *bodyPreview
p.Write([]byte("abc"))
})
t.Run("disabled preview ignores data", func(t *testing.T) {
t.Parallel()
p := &bodyPreview{enabled: false}
p.Write([]byte("abc"))
if got := string(p.Preview()); got != "" {
t.Fatalf("Preview() = %q, want empty", got)
}
})
t.Run("empty write does nothing", func(t *testing.T) {
t.Parallel()
p := &bodyPreview{enabled: true}
p.Write(nil)
if got := string(p.Preview()); got != "" {
t.Fatalf("Preview() = %q, want empty", got)
}
})
t.Run("escapes newlines", func(t *testing.T) {
t.Parallel()
p := &bodyPreview{enabled: true}
p.Write([]byte("a\nb"))
if got, want := string(p.Preview()), `a\nb`; got != want {
t.Fatalf("Preview() = %q, want %q", got, want)
}
})
t.Run("truncates at max preview bytes and appends ellipsis", func(t *testing.T) {
t.Parallel()
p := &bodyPreview{enabled: true}
p.Write([]byte(strings.Repeat("a", maxPreviewBytes)))
p.Write([]byte("b"))
got := string(p.Preview())
if !strings.HasSuffix(got, "...") {
t.Fatalf("Preview() must end with ellipsis when truncated: %q", got[len(got)-10:])
}
if len(got) != maxPreviewBytes+3 {
t.Fatalf("len(Preview()) = %d, want %d", len(got), maxPreviewBytes+3)
}
})
}
func TestWrapBufferedConn(t *testing.T) {
t.Parallel()
client, server := net.Pipe()
t.Cleanup(func() {
_ = client.Close()
_ = server.Close()
})
t.Run("returns original conn when readWriter nil", func(t *testing.T) {
t.Parallel()
got := wrapBufferedConn(client, nil)
if got != client {
t.Fatal("wrapBufferedConn should return original conn when readWriter is nil")
}
})
t.Run("read uses buffered readWriter", func(t *testing.T) {
t.Parallel()
rw := bufio.NewReadWriter(bufio.NewReader(strings.NewReader("xyz")), bufio.NewWriter(io.Discard))
got := wrapBufferedConn(client, rw)
buf := make([]byte, 3)
n, err := got.Read(buf)
if err != nil {
t.Fatalf("Read() error = %v", err)
}
if n != 3 || string(buf) != "xyz" {
t.Fatalf("Read() = (%d, %q), want (3, %q)", n, string(buf), "xyz")
}
})
}
func TestPreviewReadCloserRead(t *testing.T) {
t.Parallel()
preview := newBodyPreview("text/plain")
rc := &previewReadCloser{
ReadCloser: io.NopCloser(strings.NewReader("hello\nworld")),
preview: preview,
}
data, err := io.ReadAll(rc)
if err != nil {
t.Fatalf("ReadAll() error = %v", err)
}
if got, want := string(data), "hello\nworld"; got != want {
t.Fatalf("read content = %q, want %q", got, want)
}
if got, want := string(preview.Preview()), `hello\nworld`; got != want {
t.Fatalf("preview = %q, want %q", got, want)
}
}

View File

@ -0,0 +1,275 @@
package proxy
import (
"errors"
"io"
"net/http"
"net/url"
"strings"
"testing"
"time"
"github.com/google/uuid"
"termtap.dev/internal/model"
)
type mockTransport struct {
fn func(*http.Request) (*http.Response, error)
}
func (m *mockTransport) RoundTrip(req *http.Request) (*http.Response, error) {
return m.fn(req)
}
func TestRoundTripCapturedRequest_Success(t *testing.T) {
t.Parallel()
ch := make(chan model.Event, 8)
reqURL, err := url.Parse("http://example.com/path?q=1")
if err != nil {
t.Fatalf("url.Parse() error = %v", err)
}
req := &http.Request{
Method: http.MethodPost,
URL: reqURL,
Host: "example.com",
Header: http.Header{
"Connection": {"X-Hop"},
"X-Hop": {"drop"},
"Authorization": {"Bearer token"},
"Content-Type": {"text/plain"},
},
Body: io.NopCloser(strings.NewReader("req\nbody")),
}
transport := &mockTransport{fn: func(outReq *http.Request) (*http.Response, error) {
if got := outReq.Header.Get("Connection"); got != "" {
t.Fatalf("Connection header should be stripped, got %q", got)
}
if got := outReq.Header.Get("X-Hop"); got != "" {
t.Fatalf("header listed in Connection should be stripped, got %q", got)
}
data, readErr := io.ReadAll(outReq.Body)
if readErr != nil {
t.Fatalf("ReadAll(outReq.Body) error = %v", readErr)
}
if got, want := string(data), "req\nbody"; got != want {
t.Fatalf("request body = %q, want %q", got, want)
}
return &http.Response{
StatusCode: http.StatusCreated,
Header: http.Header{
"Set-Cookie": {"session=top-secret"},
"Connection": {"close"},
"Content-Type": {"text/plain"},
},
Body: io.NopCloser(strings.NewReader("resp\nbody")),
}, nil
}}
resp, captured, responsePreview, err := roundTripCapturedRequest(req, transport, ch, "", false)
if err != nil {
t.Fatalf("roundTripCapturedRequest() error = %v", err)
}
if resp == nil {
t.Fatal("roundTripCapturedRequest() returned nil response")
}
if responsePreview == nil {
t.Fatal("roundTripCapturedRequest() returned nil response preview")
}
if _, readErr := io.ReadAll(resp.Body); readErr != nil {
t.Fatalf("ReadAll(resp.Body) error = %v", readErr)
}
_ = resp.Body.Close()
if got, want := string(captured.RequestData), `req\nbody`; got != want {
t.Fatalf("captured.RequestData = %q, want %q", got, want)
}
if got, want := string(responsePreview.Preview()), `resp\nbody`; got != want {
t.Fatalf("responsePreview = %q, want %q", got, want)
}
if got := captured.RequestHeaders.Get("Authorization"); got != "[REDACTED]" {
t.Fatalf("Authorization header = %q, want [REDACTED]", got)
}
if got := captured.ResponseHeaders.Get("Set-Cookie"); got != "[REDACTED]" {
t.Fatalf("Set-Cookie header = %q, want [REDACTED]", got)
}
if got := captured.ResponseHeaders.Get("Connection"); got != "" {
t.Fatalf("Connection should be stripped from response headers, got %q", got)
}
events := drainEvents(t, ch, 1, time.Second)
if events[0].Type != model.EventTypeRequestStarted {
t.Fatalf("event type = %s, want %s", events[0].Type, model.EventTypeRequestStarted)
}
}
func TestRoundTripCapturedRequest_ErrorPath(t *testing.T) {
t.Parallel()
ch := make(chan model.Event, 8)
reqURL, err := url.Parse("http://example.com/fail")
if err != nil {
t.Fatalf("url.Parse() error = %v", err)
}
req := &http.Request{
Method: http.MethodPost,
URL: reqURL,
Host: "example.com",
Header: http.Header{"Content-Type": {"text/plain"}},
Body: io.NopCloser(strings.NewReader("boom\nbody")),
}
wantErr := errors.New("upstream failed")
transport := &mockTransport{fn: func(outReq *http.Request) (*http.Response, error) {
_, _ = io.ReadAll(outReq.Body)
return nil, wantErr
}}
resp, captured, responsePreview, gotErr := roundTripCapturedRequest(req, transport, ch, "", false)
if !errors.Is(gotErr, wantErr) {
t.Fatalf("error = %v, want %v", gotErr, wantErr)
}
if resp != nil {
t.Fatalf("response = %#v, want nil", resp)
}
if responsePreview != nil {
t.Fatalf("responsePreview = %#v, want nil", responsePreview)
}
if got, want := string(captured.RequestData), `boom\nbody`; got != want {
t.Fatalf("captured.RequestData = %q, want %q", got, want)
}
events := drainEvents(t, ch, 1, time.Second)
if events[0].Type != model.EventTypeRequestStarted {
t.Fatalf("event type = %s, want %s", events[0].Type, model.EventTypeRequestStarted)
}
}
func TestRoundTripCapturedRequest_InterceptedTLSDefaults(t *testing.T) {
t.Parallel()
ch := make(chan model.Event, 8)
u, err := url.Parse("/secure?p=1")
if err != nil {
t.Fatalf("url.Parse() error = %v", err)
}
req := &http.Request{
Method: http.MethodGet,
URL: u,
Header: http.Header{},
}
const defaultHost = "api.example.com:443"
transport := &mockTransport{fn: func(outReq *http.Request) (*http.Response, error) {
if got := outReq.URL.Scheme; got != "https" {
t.Fatalf("URL.Scheme = %q, want https", got)
}
if got := outReq.URL.Host; got != defaultHost {
t.Fatalf("URL.Host = %q, want %q", got, defaultHost)
}
if got := outReq.Host; got != defaultHost {
t.Fatalf("Host = %q, want %q", got, defaultHost)
}
return &http.Response{
StatusCode: http.StatusNoContent,
Header: http.Header{"Content-Type": {"text/plain"}},
Body: io.NopCloser(strings.NewReader("")),
}, nil
}}
_, _, _, gotErr := roundTripCapturedRequest(req, transport, ch, defaultHost, true)
if gotErr != nil {
t.Fatalf("roundTripCapturedRequest() error = %v", gotErr)
}
}
func TestNewConnectRequest(t *testing.T) {
t.Parallel()
now := time.Now().Add(-time.Second)
req := &http.Request{Method: http.MethodConnect, Host: "example.com:443"}
got := newConnectRequest(req, now)
if got.Method != http.MethodConnect {
t.Fatalf("Method = %q, want CONNECT", got.Method)
}
if got.Host != req.Host || got.URL != req.Host || got.RawURL != req.Host {
t.Fatalf("connect request host/url/raw mismatch: %#v", got)
}
if !got.Pending || got.Failed {
t.Fatalf("Pending/Failed = (%v,%v), want (true,false)", got.Pending, got.Failed)
}
if got.Status != -1 {
t.Fatalf("Status = %d, want -1", got.Status)
}
if got.StartTime != now {
t.Fatalf("StartTime = %v, want %v", got.StartTime, now)
}
if got.ID == uuid.Nil {
t.Fatal("ID must be non-zero UUID")
}
}
func TestStartFinishFailRequestEvents(t *testing.T) {
t.Parallel()
ch := make(chan model.Event, 4)
req := model.Request{
ID: uuid.New(),
Method: http.MethodGet,
RawURL: "http://example.com/a",
StartTime: time.Now().Add(-3 * time.Millisecond),
Pending: true,
}
startRequest(ch, req)
events := drainEvents(t, ch, 1, time.Second)
startEv := events[0]
if startEv.Type != model.EventTypeRequestStarted {
t.Fatalf("start event type = %s, want %s", startEv.Type, model.EventTypeRequestStarted)
}
if startEv.Request.Pending != true {
t.Fatalf("start request pending = %v, want true", startEv.Request.Pending)
}
finishRequest(ch, req, http.StatusOK)
events = drainEvents(t, ch, 1, time.Second)
finishEv := events[0]
if finishEv.Type != model.EventTypeRequestFinished {
t.Fatalf("finish event type = %s, want %s", finishEv.Type, model.EventTypeRequestFinished)
}
if finishEv.Request.Pending {
t.Fatal("finished request should not be pending")
}
if finishEv.Request.Failed {
t.Fatal("finished request should not be failed")
}
if finishEv.Request.Status != http.StatusOK {
t.Fatalf("finished status = %d, want %d", finishEv.Request.Status, http.StatusOK)
}
failRequest(ch, req, http.StatusBadGateway, "upstream error")
events = drainEvents(t, ch, 1, time.Second)
failEv := events[0]
if failEv.Type != model.EventTypeRequestFailed {
t.Fatalf("fail event type = %s, want %s", failEv.Type, model.EventTypeRequestFailed)
}
if failEv.Request.Pending {
t.Fatal("failed request should not be pending")
}
if !failEv.Request.Failed {
t.Fatal("failed request should be marked failed")
}
if failEv.Request.Status != http.StatusBadGateway {
t.Fatalf("failed status = %d, want %d", failEv.Request.Status, http.StatusBadGateway)
}
}

View File

@ -0,0 +1,138 @@
package proxy
import (
"bufio"
"bytes"
"io"
"net"
"strings"
"testing"
"time"
)
type captureConn struct {
bytes.Buffer
}
func (c *captureConn) Read(_ []byte) (int, error) { return 0, io.EOF }
func (c *captureConn) Close() error { return nil }
func (c *captureConn) LocalAddr() net.Addr { return &net.TCPAddr{} }
func (c *captureConn) RemoteAddr() net.Addr { return &net.TCPAddr{} }
func (c *captureConn) SetDeadline(_ time.Time) error { return nil }
func (c *captureConn) SetReadDeadline(_ time.Time) error { return nil }
func (c *captureConn) SetWriteDeadline(_ time.Time) error { return nil }
type failWriteConn struct{}
func (failWriteConn) Read(_ []byte) (int, error) { return 0, io.EOF }
func (failWriteConn) Write(_ []byte) (int, error) { return 0, io.ErrClosedPipe }
func (failWriteConn) Close() error { return nil }
func (failWriteConn) LocalAddr() net.Addr { return &net.TCPAddr{} }
func (failWriteConn) RemoteAddr() net.Addr { return &net.TCPAddr{} }
func (failWriteConn) SetDeadline(_ time.Time) error { return nil }
func (failWriteConn) SetReadDeadline(_ time.Time) error { return nil }
func (failWriteConn) SetWriteDeadline(_ time.Time) error { return nil }
type trackingBody struct {
data *bytes.Reader
readN int
closed bool
}
func (b *trackingBody) Read(p []byte) (int, error) {
n, err := b.data.Read(p)
b.readN += n
return n, err
}
func (b *trackingBody) Close() error {
b.closed = true
return nil
}
func TestWriteConnectEstablished(t *testing.T) {
t.Parallel()
t.Run("writes directly to raw conn", func(t *testing.T) {
t.Parallel()
conn := &captureConn{}
if err := writeConnectEstablished(conn, nil); err != nil {
t.Fatalf("writeConnectEstablished() error = %v", err)
}
if got, want := conn.String(), "HTTP/1.1 200 Connection Established\r\n\r\n"; got != want {
t.Fatalf("raw write = %q, want %q", got, want)
}
})
t.Run("writes and flushes with buffered readWriter", func(t *testing.T) {
t.Parallel()
conn := &captureConn{}
rw := bufio.NewReadWriter(bufio.NewReader(strings.NewReader("")), bufio.NewWriter(conn))
if err := writeConnectEstablished(conn, rw); err != nil {
t.Fatalf("writeConnectEstablished() error = %v", err)
}
if got, want := conn.String(), "HTTP/1.1 200 Connection Established\r\n\r\n"; got != want {
t.Fatalf("buffered write = %q, want %q", got, want)
}
})
t.Run("returns flush error", func(t *testing.T) {
t.Parallel()
rw := bufio.NewReadWriter(bufio.NewReader(strings.NewReader("")), bufio.NewWriter(errWriter{}))
err := writeConnectEstablished(&captureConn{}, rw)
if err == nil {
t.Fatal("writeConnectEstablished() error = nil, want non-nil")
}
})
t.Run("returns buffered write error when writer already failed", func(t *testing.T) {
t.Parallel()
bw := bufio.NewWriter(errWriter{})
_ = bw.Flush() // set sticky error to force WriteString error path
rw := bufio.NewReadWriter(bufio.NewReader(strings.NewReader("")), bw)
err := writeConnectEstablished(&captureConn{}, rw)
if err == nil {
t.Fatal("writeConnectEstablished() error = nil, want non-nil")
}
})
t.Run("returns raw conn write error", func(t *testing.T) {
t.Parallel()
err := writeConnectEstablished(failWriteConn{}, nil)
if err == nil {
t.Fatal("writeConnectEstablished() error = nil, want non-nil")
}
})
}
func TestDiscardAndCloseBody(t *testing.T) {
t.Parallel()
t.Run("nil body is safe", func(t *testing.T) {
t.Parallel()
discardAndCloseBody(nil)
})
t.Run("closes body and discards at most limit", func(t *testing.T) {
t.Parallel()
payload := bytes.Repeat([]byte("x"), maxDiscardBodyBytes+128)
body := &trackingBody{data: bytes.NewReader(payload)}
discardAndCloseBody(body)
if !body.closed {
t.Fatal("body was not closed")
}
if body.readN != maxDiscardBodyBytes {
t.Fatalf("bytes read = %d, want %d", body.readN, maxDiscardBodyBytes)
}
})
}

View File

@ -0,0 +1,235 @@
package proxy
import (
"context"
"net"
"net/http"
"os"
"path/filepath"
"strings"
"sync"
"testing"
"time"
"termtap.dev/internal/model"
)
// NOTE: Run these tests with -race; they cover concurrent state.
type closeTrackingConn struct {
closed bool
mu sync.Mutex
}
func (c *closeTrackingConn) Read(_ []byte) (int, error) { return 0, net.ErrClosed }
func (c *closeTrackingConn) Write(b []byte) (int, error) { return len(b), nil }
func (c *closeTrackingConn) Close() error { c.mu.Lock(); c.closed = true; c.mu.Unlock(); return nil }
func (c *closeTrackingConn) LocalAddr() net.Addr { return &net.TCPAddr{} }
func (c *closeTrackingConn) RemoteAddr() net.Addr { return &net.TCPAddr{} }
func (c *closeTrackingConn) SetDeadline(_ time.Time) error { return nil }
func (c *closeTrackingConn) SetReadDeadline(_ time.Time) error { return nil }
func (c *closeTrackingConn) SetWriteDeadline(_ time.Time) error { return nil }
func (c *closeTrackingConn) isClosed() bool {
c.mu.Lock()
defer c.mu.Unlock()
return c.closed
}
func TestNewProxyServer_Success(t *testing.T) {
configRoot := t.TempDir()
t.Setenv("XDG_CONFIG_HOME", configRoot)
ch := make(chan model.Event, 16)
ps, err := NewProxyServer("127.0.0.1:0", ch)
if err != nil {
t.Fatalf("NewProxyServer() error = %v", err)
}
t.Cleanup(func() { Destroy(ps, ch) })
if ps.Listener == nil || *ps.Listener == nil {
t.Fatal("Listener is nil")
}
if ps.Server == nil {
t.Fatal("Server is nil")
}
if ps.Server.Handler == nil {
t.Fatal("Server.Handler is nil")
}
if !strings.HasPrefix(ps.Url, "http://") {
t.Fatalf("Url = %q, want http:// prefix", ps.Url)
}
if !ps.CAReady {
t.Fatal("CAReady = false, want true")
}
if ps.CACertPath == "" {
t.Fatal("CACertPath should not be empty")
}
if got, want := filepath.Dir(ps.CACertPath), filepath.Join(configRoot, caDirName); got != want {
t.Fatalf("CACertPath dir = %q, want %q", got, want)
}
if _, statErr := os.Stat(ps.CACertPath); statErr != nil {
t.Fatalf("expected CA cert on disk, stat error = %v", statErr)
}
if ps.Conns == nil {
t.Fatal("Conns map is nil")
}
}
func TestNewProxyServer_CACreatedFlagAcrossRuns(t *testing.T) {
configRoot := t.TempDir()
t.Setenv("XDG_CONFIG_HOME", configRoot)
ch := make(chan model.Event, 16)
ps1, err := NewProxyServer("127.0.0.1:0", ch)
if err != nil {
t.Fatalf("first NewProxyServer() error = %v", err)
}
Destroy(ps1, ch)
ps2, err := NewProxyServer("127.0.0.1:0", ch)
if err != nil {
t.Fatalf("second NewProxyServer() error = %v", err)
}
t.Cleanup(func() { Destroy(ps2, ch) })
if !ps1.CACreated {
t.Fatal("first NewProxyServer should report CACreated=true")
}
if ps2.CACreated {
t.Fatal("second NewProxyServer should report CACreated=false when CA already exists")
}
}
func TestNewProxyServer_ErrorWhenCASetupFails(t *testing.T) {
t.Setenv("XDG_CONFIG_HOME", "")
t.Setenv("HOME", "")
ch := make(chan model.Event, 8)
ps, err := NewProxyServer("127.0.0.1:0", ch)
if err == nil {
Destroy(ps, ch)
t.Fatal("NewProxyServer() error = nil, want non-nil")
}
if ps != nil {
t.Fatalf("proxy server = %#v, want nil", ps)
}
}
func TestNewProxyServer_ErrorWhenListenFails(t *testing.T) {
configRoot := t.TempDir()
t.Setenv("XDG_CONFIG_HOME", configRoot)
occupied, err := net.Listen("tcp", "127.0.0.1:0")
if err != nil {
t.Fatalf("net.Listen() error = %v", err)
}
t.Cleanup(func() { _ = occupied.Close() })
addr := occupied.Addr().String()
ch := make(chan model.Event, 8)
ps, gotErr := NewProxyServer(addr, ch)
if gotErr == nil {
Destroy(ps, ch)
t.Fatal("NewProxyServer() error = nil, want non-nil")
}
if ps != nil {
t.Fatalf("proxy server = %#v, want nil", ps)
}
}
func TestDestroy(t *testing.T) {
t.Parallel()
t.Run("nil-safe", func(t *testing.T) {
t.Parallel()
Destroy(nil, make(chan model.Event, 1))
})
t.Run("emits ProxyStopped when server exists", func(t *testing.T) {
t.Parallel()
ch := make(chan model.Event, 2)
ps := &model.ProxyServer{Server: &http.Server{}, Conns: make(map[net.Conn]struct{})}
Destroy(ps, ch)
select {
case ev := <-ch:
if ev.Type != model.EventTypeProxyStopped {
t.Fatalf("event type = %s, want %s", ev.Type, model.EventTypeProxyStopped)
}
case <-time.After(time.Second):
t.Fatal("timeout waiting for ProxyStopped event")
}
})
}
func TestConnectionTrackingHelpers(t *testing.T) {
t.Parallel()
t.Run("track and untrack are nil-safe", func(t *testing.T) {
t.Parallel()
trackConnection(nil, nil)
untrackConnection(nil, nil)
closeTrackedConnections(nil)
})
t.Run("track/untrack mutate connection set", func(t *testing.T) {
t.Parallel()
ps := &model.ProxyServer{Conns: make(map[net.Conn]struct{})}
c := &closeTrackingConn{}
trackConnection(ps, c)
if _, ok := ps.Conns[c]; !ok {
t.Fatal("connection not tracked")
}
untrackConnection(ps, c)
if _, ok := ps.Conns[c]; ok {
t.Fatal("connection still tracked after untrack")
}
})
t.Run("closeTrackedConnections closes all tracked conns", func(t *testing.T) {
t.Parallel()
ps := &model.ProxyServer{Conns: make(map[net.Conn]struct{})}
c1 := &closeTrackingConn{}
c2 := &closeTrackingConn{}
ps.Conns[c1] = struct{}{}
ps.Conns[c2] = struct{}{}
closeTrackedConnections(ps)
if !c1.isClosed() || !c2.isClosed() {
t.Fatalf("expected all tracked conns closed, got c1=%v c2=%v", c1.isClosed(), c2.isClosed())
}
})
}
func TestDestroy_ShutdownsServerContext(t *testing.T) {
t.Parallel()
// Basic smoke test: ensure a real server can be shut down via Destroy.
ln, err := net.Listen("tcp", "127.0.0.1:0")
if err != nil {
t.Fatalf("listen error = %v", err)
}
srv := &http.Server{Handler: http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { _, _ = w.Write([]byte("ok")) })}
go func() {
_ = srv.Serve(ln)
}()
ch := make(chan model.Event, 2)
ps := &model.ProxyServer{Server: srv, Conns: make(map[net.Conn]struct{})}
Destroy(ps, ch)
ctx, cancel := context.WithTimeout(context.Background(), time.Second)
defer cancel()
if err := srv.Shutdown(ctx); err != nil && err != http.ErrServerClosed {
t.Fatalf("server should be closed after Destroy, got shutdown error %v", err)
}
}

View File

@ -0,0 +1,77 @@
package proxy
import (
"crypto/ecdsa"
"crypto/elliptic"
"crypto/rand"
"crypto/tls"
"crypto/x509"
"crypto/x509/pkix"
"math/big"
"testing"
"time"
"termtap.dev/internal/model"
)
func drainEvents(t *testing.T, ch <-chan model.Event, n int, timeout time.Duration) []model.Event {
t.Helper()
events := make([]model.Event, 0, n)
deadline := time.After(timeout)
for len(events) < n {
select {
case ev := <-ch:
events = append(events, ev)
case <-deadline:
t.Fatalf("timeout waiting for %d events, got %d", n, len(events))
}
}
return events
}
func hasEventType(events []model.Event, typ model.EventType) bool {
for _, ev := range events {
if ev.Type == typ {
return true
}
}
return false
}
func newTestCA(t *testing.T) *CertificateAuthority {
t.Helper()
key, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
if err != nil {
t.Fatalf("GenerateKey() error = %v", err)
}
now := time.Now()
tmpl := &x509.Certificate{
SerialNumber: big.NewInt(1),
Subject: pkix.Name{CommonName: "test-ca"},
NotBefore: now.Add(-time.Minute),
NotAfter: now.Add(time.Hour),
KeyUsage: x509.KeyUsageCertSign | x509.KeyUsageDigitalSignature,
IsCA: true,
BasicConstraintsValid: true,
}
der, err := x509.CreateCertificate(rand.Reader, tmpl, tmpl, &key.PublicKey, key)
if err != nil {
t.Fatalf("CreateCertificate() error = %v", err)
}
cert, err := x509.ParseCertificate(der)
if err != nil {
t.Fatalf("ParseCertificate() error = %v", err)
}
return &CertificateAuthority{
cert: cert,
key: key,
leafCert: make(map[string]*tls.Certificate),
}
}

View File

@ -0,0 +1,256 @@
package proxy
import (
"bufio"
"context"
"errors"
"io"
"net/http"
"net/url"
"strconv"
"strings"
"testing"
"github.com/google/uuid"
)
type timeoutErr struct{}
func (timeoutErr) Error() string { return "timeout" }
func (timeoutErr) Timeout() bool { return true }
func (timeoutErr) Temporary() bool { return true }
type errWriter struct{}
func (errWriter) Write(_ []byte) (int, error) {
return 0, errors.New("write failed")
}
func TestCanDisplayContent(t *testing.T) {
t.Parallel()
tests := []struct {
name string
contentType string
want bool
}{
{name: "empty content type", contentType: "", want: false},
{name: "text type", contentType: "text/plain", want: true},
{name: "json type", contentType: "application/json", want: true},
{name: "xml suffix", contentType: "application/problem+xml", want: true},
{name: "unknown binary", contentType: "application/octet-stream", want: false},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
t.Parallel()
if got := canDisplayContent(tt.contentType); got != tt.want {
t.Fatalf("canDisplayContent(%q) = %v, want %v", tt.contentType, got, tt.want)
}
})
}
}
func TestFormatHeaders(t *testing.T) {
t.Parallel()
tests := []struct {
name string
headers http.Header
want string
}{
{
name: "empty returns none",
headers: http.Header{},
want: "<none>",
},
{
name: "sorts keys stably",
headers: http.Header{
"B-Key": {"b1"},
"A-Key": {"a1", "a2"},
},
want: `A-Key="a1,a2", B-Key="b1"`,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
t.Parallel()
if got := formatHeaders(tt.headers); got != tt.want {
t.Fatalf("formatHeaders() = %q, want %q", got, tt.want)
}
})
}
}
func TestGetEndOfUUID(t *testing.T) {
t.Parallel()
id := uuid.MustParse("123e4567-e89b-12d3-a456-426614174000")
if got, want := getEndOfUUID(id), "426614174000"; got != want {
t.Fatalf("getEndOfUUID() = %q, want %q", got, want)
}
}
func TestStatusFromUpstreamError(t *testing.T) {
t.Parallel()
newReq := func(ctx context.Context) *http.Request {
reqURL, err := url.Parse("http://example.com")
if err != nil {
t.Fatalf("url parse failed: %v", err)
}
return (&http.Request{Method: http.MethodGet, URL: reqURL}).WithContext(ctx)
}
tests := []struct {
name string
req *http.Request
resp *http.Response
err error
want int
}{
{
name: "prefers upstream response status",
req: newReq(context.Background()),
resp: &http.Response{StatusCode: http.StatusTeapot},
err: errors.New("ignored"),
want: http.StatusTeapot,
},
{
name: "context canceled maps to bad gateway",
req: func() *http.Request {
ctx, cancel := context.WithCancel(context.Background())
cancel()
return newReq(ctx)
}(),
err: context.Canceled,
want: http.StatusBadGateway,
},
{
name: "deadline exceeded maps to gateway timeout",
req: newReq(context.Background()),
err: context.DeadlineExceeded,
want: http.StatusGatewayTimeout,
},
{
name: "net timeout maps to gateway timeout",
req: newReq(context.Background()),
err: timeoutErr{},
want: http.StatusGatewayTimeout,
},
{
name: "default maps to bad gateway",
req: newReq(context.Background()),
err: errors.New("dial failed"),
want: http.StatusBadGateway,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
t.Parallel()
if got := statusFromUpstreamError(tt.req, tt.resp, tt.err); got != tt.want {
t.Fatalf("statusFromUpstreamError() = %d, want %d", got, tt.want)
}
})
}
}
func TestNewUpstreamTransport(t *testing.T) {
t.Parallel()
got := newUpstreamTransport()
transport, ok := got.(*http.Transport)
if !ok {
t.Fatalf("newUpstreamTransport() type = %T, want *http.Transport", got)
}
if transport.Proxy != nil {
t.Fatal("newUpstreamTransport() Proxy must be nil")
}
}
func TestWritePlainHTTPError(t *testing.T) {
t.Parallel()
tests := []struct {
name string
status int
writer *bufio.Writer
wantErr bool
wantStatus int
}{
{
name: "writes valid response and flushes",
status: http.StatusBadGateway,
writer: bufio.NewWriter(&strings.Builder{}),
wantErr: false,
wantStatus: http.StatusBadGateway,
},
{
name: "returns write error",
status: http.StatusBadGateway,
writer: bufio.NewWriter(errWriter{}),
wantErr: true,
},
{
name: "returns response write error when writer already failed",
status: http.StatusBadGateway,
writer: bufio.NewWriter(errWriter{}),
wantErr: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
t.Parallel()
var sb strings.Builder
w := tt.writer
if !tt.wantErr {
w = bufio.NewWriter(&sb)
}
err := writePlainHTTPError(w, tt.status)
if tt.name == "returns response write error when writer already failed" {
_ = w.Flush() // set sticky writer error so resp.Write fails immediately
err = writePlainHTTPError(w, tt.status)
}
if (err != nil) != tt.wantErr {
t.Fatalf("writePlainHTTPError() error = %v, wantErr %v", err, tt.wantErr)
}
if tt.wantErr {
return
}
resp, readErr := http.ReadResponse(bufio.NewReader(strings.NewReader(sb.String())), &http.Request{Method: http.MethodGet})
if readErr != nil {
t.Fatalf("ReadResponse() error = %v", readErr)
}
defer resp.Body.Close()
if resp.StatusCode != tt.wantStatus {
t.Fatalf("status = %d, want %d", resp.StatusCode, tt.wantStatus)
}
if gotCT := resp.Header.Get("Content-Type"); gotCT != "text/plain; charset=utf-8" {
t.Fatalf("Content-Type = %q, want %q", gotCT, "text/plain; charset=utf-8")
}
wantBody := http.StatusText(tt.status)
if gotCL := resp.Header.Get("Content-Length"); gotCL != strconv.Itoa(len(wantBody)) {
t.Fatalf("Content-Length = %q, want %q", gotCL, strconv.Itoa(len(wantBody)))
}
body, bodyErr := io.ReadAll(resp.Body)
if bodyErr != nil {
t.Fatalf("ReadAll(body) error = %v", bodyErr)
}
if string(body) != wantBody {
t.Fatalf("body = %q, want %q", string(body), wantBody)
}
})
}
}

View File

@ -0,0 +1,361 @@
package tui
import (
"errors"
"net/http"
"testing"
"time"
tea "github.com/charmbracelet/bubbletea"
"github.com/google/uuid"
"termtap.dev/internal/model"
)
func TestNewModelDefaults(t *testing.T) {
t.Parallel()
ch := make(chan model.Event)
m := NewModel(ch, Controls{})
if m.channel != ch {
t.Fatal("channel not set")
}
if len(m.events) != 0 || len(m.requests) != 0 {
t.Fatal("events/requests should initialize empty")
}
if m.width != 0 || m.height != 0 {
t.Fatal("width/height should initialize zero")
}
if m.showEvents || m.showStd || m.showSearch || m.restarting {
t.Fatal("toggle flags should initialize false")
}
}
func TestInitBatchesEventAndTick(t *testing.T) {
t.Parallel()
ch := make(chan model.Event)
m := NewModel(ch, Controls{})
cmd := m.Init()
if cmd == nil {
t.Fatal("Init() returned nil cmd")
}
if _, ok := cmd().(tea.BatchMsg); !ok {
t.Fatalf("Init cmd message type = %T, want tea.BatchMsg", cmd())
}
}
func TestWaitForEvent(t *testing.T) {
t.Parallel()
t.Run("returns EventMsg when channel has value", func(t *testing.T) {
t.Parallel()
ch := make(chan model.Event, 1)
ch <- model.Event{Type: model.EventTypeWarn, Body: "hello"}
msg := waitForEvent(ch)()
ev, ok := msg.(EventMsg)
if !ok {
t.Fatalf("msg type = %T, want EventMsg", msg)
}
if ev.value.Body != "hello" {
t.Fatalf("event body = %q, want %q", ev.value.Body, "hello")
}
})
t.Run("returns ErrMsg when channel closed", func(t *testing.T) {
t.Parallel()
ch := make(chan model.Event)
close(ch)
msg := waitForEvent(ch)()
if _, ok := msg.(ErrMsg); !ok {
t.Fatalf("msg type = %T, want ErrMsg", msg)
}
})
}
func TestMessagesCommands(t *testing.T) {
t.Parallel()
t.Run("restartCmd nil restart returns nil", func(t *testing.T) {
t.Parallel()
if cmd := restartCmd(nil); cmd != nil {
t.Fatal("restartCmd(nil) should return nil")
}
})
t.Run("restartCmd wraps restart result", func(t *testing.T) {
t.Parallel()
wantErr := errors.New("boom")
msg := restartCmd(func() error { return wantErr })()
rm, ok := msg.(RestartResultMsg)
if !ok {
t.Fatalf("msg type = %T, want RestartResultMsg", msg)
}
if !errors.Is(rm.err, wantErr) {
t.Fatalf("restart result error = %v, want %v", rm.err, wantErr)
}
})
t.Run("tickCmd emits TickMsg", func(t *testing.T) {
t.Parallel()
cmd := tickCmd()
if cmd == nil {
t.Fatal("tickCmd returned nil")
}
msgCh := make(chan tea.Msg, 1)
go func() { msgCh <- cmd() }()
select {
case msg := <-msgCh:
if _, ok := msg.(TickMsg); !ok {
t.Fatalf("msg type = %T, want TickMsg", msg)
}
case <-time.After(200 * time.Millisecond):
t.Fatal("timeout waiting for tick message")
}
})
}
func TestUpdate(t *testing.T) {
t.Parallel()
t.Run("window size updates dimensions", func(t *testing.T) {
t.Parallel()
m := NewModel(make(chan model.Event), Controls{})
next, _ := m.Update(tea.WindowSizeMsg{Width: 120, Height: 40})
got := next.(Model)
if got.width != 120 || got.height != 40 {
t.Fatalf("dimensions = (%d,%d), want (120,40)", got.width, got.height)
}
})
t.Run("tick updates now and reschedules only with pending requests", func(t *testing.T) {
t.Parallel()
now := time.Now()
m1 := NewModel(make(chan model.Event), Controls{})
next1, cmd1 := m1.Update(TickMsg{Now: now})
got1 := next1.(Model)
if !got1.now.Equal(now) {
t.Fatal("tick should update now")
}
if cmd1 != nil {
t.Fatal("tick without pending requests should not reschedule")
}
m2 := NewModel(make(chan model.Event), Controls{})
m2.requests = append(m2.requests, model.Request{ID: uuid.New(), Pending: true})
_, cmd2 := m2.Update(TickMsg{Now: now})
if cmd2 == nil {
t.Fatal("tick with pending requests should reschedule")
}
})
t.Run("key handling toggles and quit", func(t *testing.T) {
t.Parallel()
m := NewModel(make(chan model.Event), Controls{})
next, quitCmd := m.Update(tea.KeyMsg{Type: tea.KeyRunes, Runes: []rune("q")})
_ = next
if quitCmd == nil {
t.Fatal("q should return quit cmd")
}
nextCtrlC, quitCtrlC := m.Update(tea.KeyMsg{Type: tea.KeyCtrlC})
_ = nextCtrlC
if quitCtrlC == nil {
t.Fatal("ctrl+c should return quit cmd")
}
next2, _ := m.Update(tea.KeyMsg{Type: tea.KeyRunes, Runes: []rune("e")})
if !next2.(Model).showEvents {
t.Fatal("e should toggle showEvents")
}
next3, _ := next2.(Model).Update(tea.KeyMsg{Type: tea.KeyRunes, Runes: []rune("o")})
if !next3.(Model).showStd {
t.Fatal("o should toggle showStd")
}
next4, _ := next3.(Model).Update(tea.KeyMsg{Type: tea.KeyRunes, Runes: []rune("/")})
if !next4.(Model).showSearch {
t.Fatal("/ should enable search")
}
next5, _ := next4.(Model).Update(tea.KeyMsg{Type: tea.KeyEsc})
if next5.(Model).showSearch {
t.Fatal("esc should disable search")
}
next6, cmd6 := next5.(Model).Update(tea.KeyMsg{Type: tea.KeyRunes, Runes: []rune("x")})
if cmd6 != nil {
t.Fatal("unknown key should not return command")
}
if next6.(Model).showEvents != next5.(Model).showEvents ||
next6.(Model).showStd != next5.(Model).showStd ||
next6.(Model).showSearch != next5.(Model).showSearch {
t.Fatal("unknown key should not alter toggle state")
}
})
t.Run("restart key guarded by state and control fn", func(t *testing.T) {
t.Parallel()
m := NewModel(make(chan model.Event), Controls{})
next, cmd := m.Update(tea.KeyMsg{Type: tea.KeyCtrlR})
if cmd != nil {
t.Fatal("ctrl+r with nil restart control should return nil cmd")
}
if next.(Model).restarting {
t.Fatal("restarting should remain false when restart control missing")
}
m2 := NewModel(make(chan model.Event), Controls{Restart: func() error { return nil }})
next2, cmd2 := m2.Update(tea.KeyMsg{Type: tea.KeyCtrlR})
if cmd2 == nil {
t.Fatal("ctrl+r with restart control should return cmd")
}
if !next2.(Model).restarting {
t.Fatal("restarting should be true after ctrl+r")
}
next3, cmd3 := next2.(Model).Update(tea.KeyMsg{Type: tea.KeyCtrlR})
if cmd3 != nil {
t.Fatal("ctrl+r while restarting should return nil cmd")
}
if !next3.(Model).restarting {
t.Fatal("restarting should stay true while guarded")
}
})
t.Run("ErrMsg pushes warning event", func(t *testing.T) {
t.Parallel()
m := NewModel(make(chan model.Event), Controls{})
next, _ := m.Update(ErrMsg{err: errors.New("closed")})
got := next.(Model)
if len(got.events) != 1 {
t.Fatalf("event len = %d, want 1", len(got.events))
}
if got.events[0].Type != model.EventTypeWarn {
t.Fatalf("event type = %s, want %s", got.events[0].Type, model.EventTypeWarn)
}
})
t.Run("RestartResultMsg clears restarting and warns on error", func(t *testing.T) {
t.Parallel()
m := NewModel(make(chan model.Event), Controls{})
m.restarting = true
next1, _ := m.Update(RestartResultMsg{err: nil})
if next1.(Model).restarting {
t.Fatal("restarting should clear on restart result")
}
next2, _ := m.Update(RestartResultMsg{err: errors.New("fail")})
got2 := next2.(Model)
if len(got2.events) != 1 || got2.events[0].Type != model.EventTypeWarn {
t.Fatalf("expected warn event on restart error, got %#v", got2.events)
}
})
t.Run("EventMsg updates events and requests", func(t *testing.T) {
t.Parallel()
ch := make(chan model.Event, 2)
m := NewModel(ch, Controls{})
reqID := uuid.New()
startEv := EventMsg{value: model.Event{Type: model.EventTypeRequestStarted, Request: model.Request{ID: reqID, Method: http.MethodGet, Pending: true}}}
next1, cmd1 := m.Update(startEv)
if cmd1 == nil {
t.Fatal("EventMsg should return wait/tick cmd")
}
got1 := next1.(Model)
if len(got1.events) != 1 || len(got1.requests) != 1 {
t.Fatalf("expected one event and one request, got events=%d requests=%d", len(got1.events), len(got1.requests))
}
finishReq := got1.requests[0]
finishReq.Pending = false
finishReq.Status = 200
finishEv := EventMsg{value: model.Event{Type: model.EventTypeRequestFinished, Request: finishReq}}
next2, cmd2 := got1.Update(finishEv)
got2 := next2.(Model)
if got2.requests[0].Pending {
t.Fatal("request should be updated to non-pending")
}
if got2.requests[0].Status != 200 {
t.Fatalf("request status = %d, want 200", got2.requests[0].Status)
}
if cmd2 == nil {
t.Fatal("expected waitForEvent cmd after finished request")
}
ch <- model.Event{Type: model.EventTypeWarn, Body: "next"}
msg := cmd2()
if _, ok := msg.(EventMsg); !ok {
t.Fatalf("cmd2 message type = %T, want EventMsg", msg)
}
})
}
func TestModelHelpers(t *testing.T) {
t.Parallel()
t.Run("pushEvent trims to maxEvents", func(t *testing.T) {
t.Parallel()
m := NewModel(make(chan model.Event), Controls{})
for i := 0; i < maxEvents+5; i++ {
m.pushEvent(model.Event{Body: "x"})
}
if len(m.events) != maxEvents {
t.Fatalf("events len = %d, want %d", len(m.events), maxEvents)
}
})
t.Run("createRequest ignores CONNECT and trims", func(t *testing.T) {
t.Parallel()
m := NewModel(make(chan model.Event), Controls{})
m.createRequest(model.Request{Method: http.MethodConnect})
if len(m.requests) != 0 {
t.Fatal("CONNECT request should be ignored")
}
for i := 0; i < maxRequests+3; i++ {
m.createRequest(model.Request{ID: uuid.New(), Method: http.MethodGet})
}
if len(m.requests) != maxRequests {
t.Fatalf("requests len = %d, want %d", len(m.requests), maxRequests)
}
})
t.Run("updateRequest updates only matching request", func(t *testing.T) {
t.Parallel()
id1 := uuid.New()
id2 := uuid.New()
m := NewModel(make(chan model.Event), Controls{})
m.requests = []model.Request{{ID: id1, Status: 100}, {ID: id2, Status: 101}}
m.updateRequest(model.Request{ID: id2, Status: 202})
if m.requests[0].Status != 100 || m.requests[1].Status != 202 {
t.Fatalf("unexpected statuses after update: %#v", m.requests)
}
})
t.Run("hasPendingRequests true and false", func(t *testing.T) {
t.Parallel()
m := NewModel(make(chan model.Event), Controls{})
if m.hasPendingRequests() {
t.Fatal("empty model should not have pending requests")
}
m.requests = []model.Request{{ID: uuid.New(), Pending: false}, {ID: uuid.New(), Pending: true}}
if !m.hasPendingRequests() {
t.Fatal("expected pending requests to be true")
}
})
}

View File

@ -0,0 +1,245 @@
package tui
import (
"strings"
"testing"
"time"
"github.com/google/uuid"
"termtap.dev/internal/model"
)
func TestFormatDuration(t *testing.T) {
t.Parallel()
tests := []struct {
name string
in time.Duration
want string
}{
{name: "pending zero", in: 0, want: "PENDING"},
{name: "microseconds", in: 750 * time.Microsecond, want: "750us"},
{name: "milliseconds", in: 20 * time.Millisecond, want: "20ms"},
{name: "seconds", in: 11 * time.Second, want: "11.00s"},
}
for _, tt := range tests {
tt := tt
t.Run(tt.name, func(t *testing.T) {
t.Parallel()
if got := formatDuration(tt.in); got != tt.want {
t.Fatalf("formatDuration(%v) = %q, want %q", tt.in, got, tt.want)
}
})
}
}
func TestTruncate(t *testing.T) {
t.Parallel()
tests := []struct {
name string
s string
max int
want string
}{
{name: "short unchanged", s: "abc", max: 3, want: "abc"},
{name: "max small no ellipsis", s: "abcdef", max: 3, want: "abc"},
{name: "ellipsis", s: "abcdef", max: 5, want: "ab..."},
}
for _, tt := range tests {
tt := tt
t.Run(tt.name, func(t *testing.T) {
t.Parallel()
if got := truncate(tt.s, tt.max); got != tt.want {
t.Fatalf("truncate(%q,%d) = %q, want %q", tt.s, tt.max, got, tt.want)
}
})
}
}
func TestClampRendered(t *testing.T) {
t.Parallel()
if got := clampRendered("abcdef", 0); got != "" {
t.Fatalf("clampRendered max=0 = %q, want empty", got)
}
if got := clampRendered("abc", 10); got != "abc" {
t.Fatalf("clampRendered no truncation = %q, want %q", got, "abc")
}
if got := clampRendered("abcdef", 4); !strings.Contains(got, "...") {
t.Fatalf("clampRendered truncation should include ellipsis, got %q", got)
}
}
func TestGetEventColor(t *testing.T) {
t.Parallel()
theme := newTheme()
tests := []struct {
name string
typ model.EventType
want string
}{
{name: "session", typ: model.EventTypeSessionStarted, want: theme.EventSession.Render("x")},
{name: "proxy", typ: model.EventTypeProxyStarted, want: theme.EventProxy.Render("x")},
{name: "request in flight", typ: model.EventTypeRequestStarted, want: theme.EventRequestInFlight.Render("x")},
{name: "request success", typ: model.EventTypeRequestFinished, want: theme.EventSuccess.Render("x")},
{name: "warn", typ: model.EventTypeWarn, want: theme.EventWarn.Render("x")},
{name: "error", typ: model.EventTypeRequestFailed, want: theme.EventError.Render("x")},
{name: "fatal", typ: model.EventTypeFatal, want: theme.EventFatal.Render("x")},
{name: "default", typ: model.EventType("UnknownType"), want: theme.EventDefault.Render("x")},
}
for _, tt := range tests {
tt := tt
t.Run(tt.name, func(t *testing.T) {
t.Parallel()
got := getEventColor(theme, tt.typ).Render("x")
if got != tt.want {
t.Fatalf("unexpected style for %s", tt.typ)
}
})
}
}
func TestViewAndPaneStructure(t *testing.T) {
t.Parallel()
m := NewModel(make(chan model.Event), Controls{})
t.Run("view returns raw pane when unset size", func(t *testing.T) {
t.Parallel()
got := m.View()
if got != m.renderAppPane() {
t.Fatal("View should return raw pane when width/height are unset")
}
})
t.Run("renderAppPane line count matches height", func(t *testing.T) {
t.Parallel()
m2 := m
m2.width = 80
m2.height = 12
got := m2.renderAppPane()
if got == "height of request and details did not match" || got == "height of screen does not match terminal height" {
t.Fatalf("unexpected renderAppPane invariant error: %q", got)
}
if lines := strings.Count(got, "\n") + 1; lines != m2.height {
t.Fatalf("line count = %d, want %d", lines, m2.height)
}
})
t.Run("renderAppPane supports toggles", func(t *testing.T) {
t.Parallel()
m2 := m
m2.width = 90
m2.height = 14
m2.showEvents = true
m2.showStd = true
m2.showSearch = true
got := m2.renderAppPane()
if got == "height of request and details did not match" || got == "height of screen does not match terminal height" {
t.Fatalf("unexpected renderAppPane invariant error with toggles: %q", got)
}
})
t.Run("view applies configured terminal height", func(t *testing.T) {
t.Parallel()
m2 := m
m2.width = 70
m2.height = 10
got := m2.View()
if lines := strings.Count(got, "\n") + 1; lines < m2.height {
t.Fatalf("View line count = %d, want at least %d", lines, m2.height)
}
})
}
func TestPaneRenderersAndStatusBar(t *testing.T) {
t.Parallel()
m := NewModel(make(chan model.Event), Controls{})
m.width = 100
m.height = 12
m.requests = []model.Request{
{ID: uuid.New(), Method: "GET", Host: "a", URL: "/a", Status: 200, Duration: 5 * time.Millisecond},
{ID: uuid.New(), Method: "POST", Host: "b", URL: "/b", Status: 500, Duration: 10 * time.Millisecond, Failed: true},
}
m.events = []model.Event{
{Type: model.EventTypeWarn, Body: "warn"},
{Type: model.EventTypeProcessStdout, Body: "out"},
{Type: model.EventTypeProcessStderr, Body: "err"},
}
status := m.renderStatusBar(100)
if !strings.Contains(status, "2 reqs") || !strings.Contains(status, "1 err") {
t.Fatalf("status bar missing expected counters: %q", status)
}
search := m.renderSearchPane(20, 3)
if len(search) != 3 {
t.Fatalf("search pane len = %d, want 3", len(search))
}
for i, line := range search {
if len(line) != 20 {
t.Fatalf("search pane line %d len = %d, want %d", i, len(line), 20)
}
}
requests := m.renderRequestPane(50, 4)
if len(requests) != 4 {
t.Fatalf("request pane len = %d, want 4", len(requests))
}
details := m.renderDetailsPane(30, 4)
if len(details) != 4 {
t.Fatalf("details pane len = %d, want 4", len(details))
}
events := m.renderEventsPane(60, 4)
if len(events) != 4 {
t.Fatalf("events pane len = %d, want 4", len(events))
}
if strings.Contains(strings.Join(events, "\n"), "out") || strings.Contains(strings.Join(events, "\n"), "err") {
t.Fatal("events pane should filter stdout/stderr events")
}
std := m.renderStdPane(60, 4)
if len(std) != 4 {
t.Fatalf("std pane len = %d, want 4", len(std))
}
joined := strings.Join(std, "\n")
if !strings.Contains(joined, "out") || !strings.Contains(joined, "err") {
t.Fatal("std pane should include stdout/stderr logs")
}
}
func TestRenderEventsPane_ErrorAndPIDBranches(t *testing.T) {
t.Parallel()
m := NewModel(make(chan model.Event), Controls{})
m.events = []model.Event{
{Type: model.EventTypeWarn, Body: "old"},
{Type: model.EventTypeRequestFailed, Body: "failed body", PID: 123, Time: time.Now()},
{Type: model.EventTypeFatal, Body: "fatal body", Time: time.Now()},
}
lines := m.renderEventsPane(60, 3)
if len(lines) != 3 {
t.Fatalf("events pane len = %d, want 3", len(lines))
}
joined := strings.Join(lines, "\n")
if !strings.Contains(joined, "123") {
t.Fatalf("expected PID to appear in events pane, got: %q", joined)
}
if !strings.Contains(joined, "failed body") {
t.Fatalf("expected failed body to appear in events pane, got: %q", joined)
}
if !strings.Contains(joined, "fatal body") {
t.Fatalf("expected fatal body to appear in events pane, got: %q", joined)
}
}