Merge branch 'feature/testing'

This commit is contained in:
Hayden Hargreaves 2026-04-23 19:47:58 -07:00
commit 51d526c2fe
28 changed files with 4976 additions and 29 deletions

45
TEST_COVERAGE_SUMMARY.md Normal file
View File

@ -0,0 +1,45 @@
# termtap Test Coverage Summary
Generated from:
```bash
go test -coverprofile=/tmp/termtap.cover ./...
go tool cover -func=/tmp/termtap.cover
```
## Package coverage snapshot
| Package | Coverage |
|---|---:|
| `cmd/tap` | 100.0% |
| `internal/app` | 98.1% |
| `internal/cli` | 93.4% |
| `internal/process` | 95.8% |
| `internal/proxy` | 90.0% |
| `internal/tui` | 96.2% |
| `examples/echo` | 0.0% (example app; intentionally not covered) |
| `internal/model` | no test files (pure data structs) |
Total statements in module: **77.9%**.
## Notable lower-coverage targets (production code)
- `internal/proxy/handlers.go:handleConnect` — 75.9%
- `internal/proxy/certs.go:writeFileAtomically` — 76.2%
- `internal/proxy/certs.go:load` — 90.5%
- `internal/proxy/certs.go:create` — 80.8%
- `internal/proxy/certs.go:IsTrustedBySystem` — 76.9%
- `internal/cli/run.go:runCert` — 87.5%
## Interpretation
- Core runtime paths (`internal/app`, `internal/process`, `internal/tui`) are high confidence.
- Proxy package has broad behavior coverage including HTTP and HTTPS MITM integration flow, and now clears 90% package coverage.
- CLI command routing and fatal/stdout/stderr seams are covered, including `Run` success/error branches.
- TUI pane rendering coverage now includes error/PID branch behavior.
## Next optional improvements
1. Add deeper CONNECT tunnel write/flush/read failure permutations inside `handleConnect` loop.
2. Add additional deterministic edge cases for `IsTrustedBySystem` non-unknown-authority verify failures.
3. Optionally add non-unix signal file coverage in CI matrix (currently unix-focused).

49
cmd/tap/main_test.go Normal file
View File

@ -0,0 +1,49 @@
package main
import (
"bytes"
"io"
"os"
"strings"
"testing"
"time"
)
func TestMain_SmokeInvokesCLIRun(t *testing.T) {
origArgs := os.Args
t.Cleanup(func() { os.Args = origArgs })
os.Args = []string{"tap", "invalid"}
origStderr := os.Stderr
r, w, err := os.Pipe()
if err != nil {
t.Fatalf("stderr pipe error: %v", err)
}
t.Cleanup(func() {
os.Stderr = origStderr
_ = r.Close()
})
os.Stderr = w
outCh := make(chan string, 1)
go func() {
var buf bytes.Buffer
_, _ = io.Copy(&buf, r)
outCh <- buf.String()
}()
main()
_ = w.Close()
os.Stderr = origStderr
var got string
select {
case got = <-outCh:
case <-time.After(2 * time.Second):
t.Fatal("timeout waiting for stderr capture")
}
if !strings.Contains(got, "usage:") {
t.Fatalf("stderr missing usage output, got: %q", got)
}
}

View File

@ -0,0 +1,109 @@
package app
import (
"net/http"
"net/http/httptest"
"net/url"
"testing"
"time"
"termtap.dev/internal/model"
)
// NOTE: Run with -race; this validates cross-component concurrency.
func TestSessionIntegration_LifecycleAndRequestEvents(t *testing.T) {
addr := freeTCPAddr(t)
s, err := StartSession(model.Command{Name: "sh", Args: []string{"-c", "sleep 5"}}, addr)
if err != nil {
t.Fatalf("StartSession() error = %v", err)
}
upstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)
_, _ = w.Write([]byte("ok"))
}))
t.Cleanup(upstream.Close)
startupEvents := collectUntilTypes(t, s.Events, []model.EventType{
model.EventTypeProxyStarting,
model.EventTypeProcessStarting,
model.EventTypeProcessStarted,
}, 3*time.Second)
proxyURL, err := url.Parse(s.proxy.Url)
if err != nil {
t.Fatalf("url.Parse(proxy) error = %v", err)
}
client := &http.Client{
Transport: &http.Transport{Proxy: http.ProxyURL(proxyURL)},
Timeout: 3 * time.Second,
}
resp, err := client.Get(upstream.URL + "/session")
if err != nil {
t.Fatalf("proxy request error = %v", err)
}
_ = resp.Body.Close()
if resp.StatusCode != http.StatusOK {
t.Fatalf("status = %d, want %d", resp.StatusCode, http.StatusOK)
}
requestEvents := collectUntilTypes(t, s.Events, []model.EventType{
model.EventTypeRequestStarted,
model.EventTypeRequestFinished,
}, 3*time.Second)
s.Stop()
select {
case <-s.proc.Done:
case <-time.After(4 * time.Second):
t.Fatal("timeout waiting for process stop")
}
shutdownEvents := collectUntilTypes(t, s.Events, []model.EventType{
model.EventTypeProxyStopped,
model.EventTypeProcessSignaled,
model.EventTypeProcessExited,
}, 4*time.Second)
if !isBefore(startupEvents, model.EventTypeProcessStarting, model.EventTypeProcessStarted) {
t.Fatalf("expected %s before %s in startup events: %#v", model.EventTypeProcessStarting, model.EventTypeProcessStarted, startupEvents)
}
if !isBefore(requestEvents, model.EventTypeRequestStarted, model.EventTypeRequestFinished) {
t.Fatalf("expected %s before %s in request events: %#v", model.EventTypeRequestStarted, model.EventTypeRequestFinished, requestEvents)
}
if !isBefore(shutdownEvents, model.EventTypeProcessSignaled, model.EventTypeProcessExited) {
t.Fatalf("expected %s before %s in shutdown events: %#v", model.EventTypeProcessSignaled, model.EventTypeProcessExited, shutdownEvents)
}
}
func TestSessionIntegration_RestartProcessEmitsLifecycleEvents(t *testing.T) {
addr := freeTCPAddr(t)
s, err := StartSession(model.Command{Name: "sh", Args: []string{"-c", "sleep 5"}}, addr)
if err != nil {
t.Fatalf("StartSession() error = %v", err)
}
t.Cleanup(func() { s.Stop() })
_ = collectUntilTypes(t, s.Events, []model.EventType{
model.EventTypeProcessStarted,
model.EventTypeProxyStarting,
}, 3*time.Second)
if err := s.RestartProcess(); err != nil {
t.Fatalf("RestartProcess() error = %v", err)
}
events := collectUntilTypes(t, s.Events, []model.EventType{
model.EventTypeProcessRestarting,
model.EventTypeProcessSignaled,
model.EventTypeProcessExited,
model.EventTypeProcessStarting,
model.EventTypeProcessStarted,
}, 4*time.Second)
if !isBefore(events, model.EventTypeProcessRestarting, model.EventTypeProcessStarted) {
t.Fatalf("expected restarting before process started, got %#v", events)
}
}

View File

@ -4,6 +4,7 @@ import (
"errors" "errors"
"fmt" "fmt"
"os/exec" "os/exec"
"sync"
"syscall" "syscall"
"time" "time"
@ -11,6 +12,10 @@ import (
"termtap.dev/internal/process" "termtap.dev/internal/process"
) )
var killEscalationDelay = 1500 * time.Millisecond
var scheduleKillEscalation = time.AfterFunc
var killEscalationMu sync.RWMutex
func StartProcess(cmd model.Command, addr string, ch chan<- model.Event) (*model.Process, error) { func StartProcess(cmd model.Command, addr string, ch chan<- model.Event) (*model.Process, error) {
ch <- model.Event{ ch <- model.Event{
Time: time.Now().Local(), Time: time.Now().Local(),
@ -44,12 +49,16 @@ func StopProcess(proc *model.Process, ch chan<- model.Event, sig syscall.Signal)
_ = process.SignalProcess(proc.Exec, sig) _ = process.SignalProcess(proc.Exec, sig)
go func() { killEscalationMu.RLock()
time.Sleep(1500 * time.Millisecond) delay := killEscalationDelay
scheduler := scheduleKillEscalation
killEscalationMu.RUnlock()
scheduler(delay, func() {
if process.ProcessAlive(proc.Exec) { if process.ProcessAlive(proc.Exec) {
_ = process.SignalProcess(proc.Exec, syscall.SIGKILL) _ = process.SignalProcess(proc.Exec, syscall.SIGKILL)
} }
}() })
} }
func waitForProcessExit(proc *model.Process, ch chan<- model.Event) { func waitForProcessExit(proc *model.Process, ch chan<- model.Event) {

View File

@ -0,0 +1,223 @@
package app
import (
"os/exec"
"syscall"
"testing"
"time"
"termtap.dev/internal/model"
)
func TestStartProcess(t *testing.T) {
t.Parallel()
t.Run("starts process and marks running", func(t *testing.T) {
t.Parallel()
ch := make(chan model.Event, 32)
proc, err := StartProcess(model.Command{Name: "sh", Args: []string{"-c", "sleep 0.2"}}, "127.0.0.1:8080", ch)
if err != nil {
t.Fatalf("StartProcess() error = %v", err)
}
t.Cleanup(func() {
StopProcess(proc, ch, syscall.SIGTERM)
select {
case <-proc.Done:
case <-time.After(2 * time.Second):
t.Fatal("timeout waiting process done in cleanup")
}
})
if proc == nil || proc.Exec == nil {
t.Fatal("StartProcess() returned nil process/exec")
}
events := drainEvents(t, ch, 2, time.Second)
if !hasType(events, model.EventTypeProcessStarting) {
t.Fatalf("missing %s event", model.EventTypeProcessStarting)
}
if !hasType(events, model.EventTypeProcessStarted) {
t.Fatalf("missing %s event", model.EventTypeProcessStarted)
}
})
t.Run("returns error when exec start fails", func(t *testing.T) {
t.Parallel()
ch := make(chan model.Event, 8)
proc, err := StartProcess(model.Command{Name: "definitely-not-a-real-command"}, "127.0.0.1:8080", ch)
if err == nil {
if proc != nil && proc.Exec != nil && proc.Exec.Process != nil {
_ = proc.Exec.Process.Kill()
}
t.Fatal("StartProcess() error = nil, want non-nil")
}
events := drainEvents(t, ch, 1, time.Second)
if !hasType(events, model.EventTypeProcessStarting) {
t.Fatalf("missing %s event", model.EventTypeProcessStarting)
}
})
}
func TestStopProcess(t *testing.T) {
t.Parallel()
t.Run("nil guards", func(t *testing.T) {
t.Parallel()
StopProcess(nil, make(chan model.Event, 1), syscall.SIGTERM)
StopProcess(&model.Process{}, make(chan model.Event, 1), syscall.SIGTERM)
StopProcess(&model.Process{Exec: &exec.Cmd{}}, make(chan model.Event, 1), syscall.SIGTERM)
})
t.Run("emits signaled event", func(t *testing.T) {
t.Parallel()
ch := make(chan model.Event, 32)
proc, err := StartProcess(model.Command{Name: "sh", Args: []string{"-c", "sleep 5"}}, "127.0.0.1:8080", ch)
if err != nil {
t.Fatalf("StartProcess() error = %v", err)
}
StopProcess(proc, ch, syscall.SIGTERM)
if _, ok := waitForEventType(t, ch, model.EventTypeProcessSignaled, 2*time.Second); !ok {
t.Fatalf("did not receive %s event", model.EventTypeProcessSignaled)
}
select {
case <-proc.Done:
case <-time.After(3 * time.Second):
t.Fatal("timeout waiting for process to exit after signal")
}
})
t.Run("schedules deterministic kill escalation hook", func(t *testing.T) {
t.Parallel()
origDelay := killEscalationDelay
origScheduler := scheduleKillEscalation
defer func() {
killEscalationMu.Lock()
killEscalationDelay = origDelay
scheduleKillEscalation = origScheduler
killEscalationMu.Unlock()
}()
ch := make(chan model.Event, 32)
proc, err := StartProcess(model.Command{Name: "sh", Args: []string{"-c", "sleep 5"}}, "127.0.0.1:8080", ch)
if err != nil {
t.Fatalf("StartProcess() error = %v", err)
}
killEscalationMu.Lock()
killEscalationDelay = 25 * time.Millisecond
scheduled := make(chan time.Duration, 1)
scheduleKillEscalation = func(d time.Duration, fn func()) *time.Timer {
scheduled <- d
go fn()
return nil
}
killEscalationMu.Unlock()
StopProcess(proc, ch, syscall.SIGTERM)
select {
case d := <-scheduled:
if d != killEscalationDelay {
t.Fatalf("scheduled delay = %v, want %v", d, killEscalationDelay)
}
case <-time.After(time.Second):
t.Fatal("kill escalation was not scheduled")
}
select {
case <-proc.Done:
case <-time.After(3 * time.Second):
t.Fatal("timeout waiting for process to exit after deterministic escalation")
}
})
}
func TestWaitForProcessExit(t *testing.T) {
t.Parallel()
t.Run("nil guards are no-op", func(t *testing.T) {
t.Parallel()
waitForProcessExit(nil, make(chan model.Event, 1))
waitForProcessExit(&model.Process{}, make(chan model.Event, 1))
})
t.Run("normal exit emits process exited and closes done", func(t *testing.T) {
t.Parallel()
cmd := exec.Command("sh", "-c", "exit 0")
if err := cmd.Start(); err != nil {
t.Fatalf("cmd.Start() error = %v", err)
}
proc := &model.Process{Exec: cmd, Running: true, Done: make(chan struct{})}
ch := make(chan model.Event, 8)
waitForProcessExit(proc, ch)
if _, ok := waitForEventType(t, ch, model.EventTypeProcessExited, time.Second); !ok {
t.Fatalf("did not receive %s event", model.EventTypeProcessExited)
}
select {
case <-proc.Done:
case <-time.After(time.Second):
t.Fatal("Done channel was not closed")
}
})
t.Run("exit error path carries exit code", func(t *testing.T) {
t.Parallel()
cmd := exec.Command("sh", "-c", "exit 7")
if err := cmd.Start(); err != nil {
t.Fatalf("cmd.Start() error = %v", err)
}
proc := &model.Process{Exec: cmd, Running: true, Done: make(chan struct{})}
ch := make(chan model.Event, 8)
waitForProcessExit(proc, ch)
events := drainEvents(t, ch, 2, time.Second)
found := false
for _, ev := range events {
if ev.Type == model.EventTypeProcessExited && ev.ExitCode == 7 {
found = true
break
}
}
if !found {
t.Fatalf("expected ProcessExited with exit code 7, got %#v", events)
}
select {
case <-proc.Done:
case <-time.After(time.Second):
t.Fatal("Done channel was not closed")
}
})
t.Run("unexpected wait failure emits fatal", func(t *testing.T) {
t.Parallel()
proc := &model.Process{Exec: &exec.Cmd{}, Done: make(chan struct{})}
ch := make(chan model.Event, 8)
waitForProcessExit(proc, ch)
events := drainEvents(t, ch, 1, time.Second)
if events[0].Type != model.EventTypeFatal {
t.Fatalf("event type = %s, want %s", events[0].Type, model.EventTypeFatal)
}
select {
case <-proc.Done:
case <-time.After(time.Second):
t.Fatal("Done channel was not closed")
}
})
}

183
internal/app/proxy_test.go Normal file
View File

@ -0,0 +1,183 @@
package app
import (
"context"
"errors"
"net"
"net/http"
"testing"
"time"
"termtap.dev/internal/model"
)
type staticErrListener struct {
err error
}
func (l *staticErrListener) Accept() (net.Conn, error) { return nil, l.err }
func (l *staticErrListener) Close() error { return nil }
func (l *staticErrListener) Addr() net.Addr {
return &net.TCPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 8080}
}
func TestStartProxy_NilGuards(t *testing.T) {
t.Parallel()
StartProxy(nil, make(chan model.Event, 1))
StartProxy(&model.ProxyServer{}, make(chan model.Event, 1))
StartProxy(&model.ProxyServer{Server: &http.Server{}}, make(chan model.Event, 1))
}
func TestStartProxy_EmitsStartingAndWarnWhenUntrustedCA(t *testing.T) {
t.Parallel()
listenErr := errors.New("accept failed")
var ln net.Listener = &staticErrListener{err: listenErr}
ch := make(chan model.Event, 8)
ps := &model.ProxyServer{
Listener: &ln,
Server: &http.Server{Handler: http.HandlerFunc(func(http.ResponseWriter, *http.Request) {})},
CAReady: true,
CATrusted: false,
CACreated: true,
CACertPath: "/tmp/test-ca.pem",
}
StartProxy(ps, ch)
events := drainEvents(t, ch, 3, time.Second)
if !hasType(events, model.EventTypeProxyStarting) {
t.Fatalf("missing %s event", model.EventTypeProxyStarting)
}
if !hasType(events, model.EventTypeWarn) {
t.Fatalf("missing %s event", model.EventTypeWarn)
}
if !containsBody(events, "generated HTTPS interception CA") {
t.Fatalf("expected generated-CA warning body, got events: %#v", events)
}
if !hasType(events, model.EventTypeFatal) {
t.Fatalf("missing %s event", model.EventTypeFatal)
}
}
func TestStartProxy_WarnBodyForExistingUntrustedCA(t *testing.T) {
t.Parallel()
listenErr := errors.New("accept failed")
var ln net.Listener = &staticErrListener{err: listenErr}
ch := make(chan model.Event, 8)
ps := &model.ProxyServer{
Listener: &ln,
Server: &http.Server{Handler: http.HandlerFunc(func(http.ResponseWriter, *http.Request) {})},
CAReady: true,
CATrusted: false,
CACreated: false,
CACertPath: "/tmp/test-ca.pem",
}
StartProxy(ps, ch)
events := drainEvents(t, ch, 3, time.Second)
if !containsBody(events, "HTTPS interception CA available at") {
t.Fatalf("expected existing-CA warning body, got events: %#v", events)
}
}
func TestStartProxy_NoCAWarningWhenNotReady(t *testing.T) {
t.Parallel()
listenErr := errors.New("accept failed")
var ln net.Listener = &staticErrListener{err: listenErr}
ch := make(chan model.Event, 8)
ps := &model.ProxyServer{
Listener: &ln,
Server: &http.Server{Handler: http.HandlerFunc(func(http.ResponseWriter, *http.Request) {})},
CAReady: false,
CATrusted: false,
}
StartProxy(ps, ch)
events := drainEvents(t, ch, 2, time.Second)
if !hasType(events, model.EventTypeProxyStarting) {
t.Fatalf("missing %s event", model.EventTypeProxyStarting)
}
if hasType(events, model.EventTypeWarn) {
t.Fatalf("unexpected %s event when CA is not ready: %#v", model.EventTypeWarn, events)
}
if !hasType(events, model.EventTypeFatal) {
t.Fatalf("missing %s event", model.EventTypeFatal)
}
}
func TestStartProxy_NoCAWarningWhenTrusted(t *testing.T) {
t.Parallel()
listenErr := errors.New("accept failed")
var ln net.Listener = &staticErrListener{err: listenErr}
ch := make(chan model.Event, 8)
ps := &model.ProxyServer{
Listener: &ln,
Server: &http.Server{Handler: http.HandlerFunc(func(http.ResponseWriter, *http.Request) {})},
CAReady: true,
CATrusted: true,
}
StartProxy(ps, ch)
events := drainEvents(t, ch, 2, time.Second)
if !hasType(events, model.EventTypeProxyStarting) {
t.Fatalf("missing %s event", model.EventTypeProxyStarting)
}
if hasType(events, model.EventTypeWarn) {
t.Fatalf("unexpected %s event when CA is already trusted: %#v", model.EventTypeWarn, events)
}
if !hasType(events, model.EventTypeFatal) {
t.Fatalf("missing %s event", model.EventTypeFatal)
}
}
func TestStartProxy_SwallowsErrServerClosed(t *testing.T) {
t.Parallel()
ln, err := net.Listen("tcp", "127.0.0.1:0")
if err != nil {
t.Fatalf("listen error = %v", err)
}
t.Cleanup(func() { _ = ln.Close() })
ch := make(chan model.Event, 8)
server := &http.Server{Handler: http.HandlerFunc(func(http.ResponseWriter, *http.Request) {})}
ps := &model.ProxyServer{Listener: &ln, Server: server}
done := make(chan struct{})
go func() {
StartProxy(ps, ch)
close(done)
}()
shutdownCtx, cancel := context.WithTimeout(context.Background(), time.Second)
defer cancel()
if err := server.Shutdown(shutdownCtx); err != nil {
t.Fatalf("shutdown error = %v", err)
}
select {
case <-done:
case <-time.After(time.Second):
t.Fatal("timeout waiting for StartProxy to return")
}
events := drainEvents(t, ch, 1, time.Second)
if !hasType(events, model.EventTypeProxyStarting) {
t.Fatalf("missing %s event", model.EventTypeProxyStarting)
}
if hasType(events, model.EventTypeFatal) {
t.Fatalf("unexpected %s event", model.EventTypeFatal)
}
}

View File

@ -0,0 +1,247 @@
package app
import (
"errors"
"net"
"syscall"
"testing"
"time"
"termtap.dev/internal/model"
)
func TestStartSession(t *testing.T) {
t.Run("happy path creates proxy and process", func(t *testing.T) {
addr := freeTCPAddr(t)
s, err := StartSession(model.Command{Name: "sh", Args: []string{"-c", "sleep 5"}}, addr)
if err != nil {
t.Fatalf("StartSession() error = %v", err)
}
if s == nil {
t.Fatal("StartSession() returned nil session")
}
if s.proxy == nil {
t.Fatal("session.proxy is nil")
}
if s.proc == nil {
t.Fatal("session.proc is nil")
}
s.Stop()
if s.proc != nil && s.proc.Done != nil {
select {
case <-s.proc.Done:
case <-time.After(4 * time.Second):
t.Fatal("timeout waiting for process stop")
}
}
})
t.Run("error when proxy creation fails", func(t *testing.T) {
t.Setenv("XDG_CONFIG_HOME", "")
t.Setenv("HOME", "")
s, err := StartSession(model.Command{Name: "sh", Args: []string{"-c", "true"}}, "127.0.0.1:0")
if err == nil {
if s != nil {
s.Stop()
}
t.Fatal("StartSession() error = nil, want non-nil")
}
if s != nil {
t.Fatalf("session = %#v, want nil", s)
}
})
t.Run("destroys proxy when process startup fails", func(t *testing.T) {
addr := freeTCPAddr(t)
s, err := StartSession(model.Command{Name: "definitely-not-a-real-command"}, addr)
if err == nil {
if s != nil {
s.Stop()
}
t.Fatal("StartSession() error = nil, want non-nil")
}
deadline := time.After(3 * time.Second)
ticker := time.NewTicker(25 * time.Millisecond)
defer ticker.Stop()
for {
ln, listenErr := net.Listen("tcp", addr)
if listenErr == nil {
_ = ln.Close()
break
}
select {
case <-deadline:
t.Fatalf("address %s did not become reusable, got err: %v", addr, listenErr)
case <-ticker.C:
}
}
})
t.Run("stop during restart stops new process and returns ErrSessionStopped", func(t *testing.T) {
s := &Session{
ch: make(chan model.Event, 64),
cmd: model.Command{Name: "sh", Args: []string{"-c", "sleep 5"}},
addr: "127.0.0.1:0",
proc: &model.Process{Done: make(chan struct{})},
}
close(s.proc.Done)
s.restartMu.Lock()
errCh := make(chan error, 1)
stopDone := make(chan struct{})
go func() {
errCh <- s.RestartProcess()
}()
go func() {
s.Stop()
close(stopDone)
}()
s.restartMu.Unlock()
select {
case err := <-errCh:
if !errors.Is(err, ErrSessionStopped) {
t.Fatalf("error = %v, want %v", err, ErrSessionStopped)
}
case <-time.After(6 * time.Second):
t.Fatal("timeout waiting for RestartProcess result")
}
select {
case <-stopDone:
case <-time.After(time.Second):
t.Fatal("timeout waiting for Stop completion")
}
})
}
func TestSessionStop_Idempotent(t *testing.T) {
addr := freeTCPAddr(t)
s, err := StartSession(model.Command{Name: "sh", Args: []string{"-c", "sleep 5"}}, addr)
if err != nil {
t.Fatalf("StartSession() error = %v", err)
}
s.Stop()
s.Stop()
if s.proc != nil && s.proc.Done != nil {
select {
case <-s.proc.Done:
case <-time.After(4 * time.Second):
t.Fatal("timeout waiting for process to stop")
}
}
}
func TestRestartProcess(t *testing.T) {
t.Run("nil session returns error", func(t *testing.T) {
var s *Session
err := s.RestartProcess()
if err == nil {
t.Fatal("RestartProcess() error = nil, want non-nil")
}
})
t.Run("stopped session returns ErrSessionStopped", func(t *testing.T) {
s := &Session{stopped: true}
err := s.RestartProcess()
if !errors.Is(err, ErrSessionStopped) {
t.Fatalf("error = %v, want %v", err, ErrSessionStopped)
}
})
t.Run("concurrent restart returns ErrRestartInProgress", func(t *testing.T) {
s := &Session{restarting: true}
err := s.RestartProcess()
if !errors.Is(err, ErrRestartInProgress) {
t.Fatalf("error = %v, want %v", err, ErrRestartInProgress)
}
})
t.Run("timeout waiting for stop returns error", func(t *testing.T) {
s := &Session{
ch: make(chan model.Event, 8),
cmd: model.Command{Name: "sh", Args: []string{"-c", "true"}},
addr: "127.0.0.1:0",
proc: &model.Process{Done: make(chan struct{})},
}
err := s.RestartProcess()
if err == nil {
t.Fatal("RestartProcess() error = nil, want non-nil")
}
})
t.Run("successful restart swaps process", func(t *testing.T) {
addr := freeTCPAddr(t)
oldProc := &model.Process{Done: make(chan struct{})}
close(oldProc.Done)
s := &Session{
ch: make(chan model.Event, 32),
cmd: model.Command{Name: "sh", Args: []string{"-c", "sleep 5"}},
addr: addr,
proc: oldProc,
}
err := s.RestartProcess()
if err != nil {
t.Fatalf("RestartProcess() error = %v", err)
}
if _, ok := waitForEventType(t, s.ch, model.EventTypeProcessRestarting, time.Second); !ok {
t.Fatalf("did not receive %s event", model.EventTypeProcessRestarting)
}
if s.proc == nil {
t.Fatal("session process is nil after restart")
}
if s.proc == oldProc {
t.Fatal("session process pointer was not swapped")
}
StopProcess(s.proc, s.ch, syscall.SIGTERM)
select {
case <-s.proc.Done:
case <-time.After(4 * time.Second):
t.Fatal("timeout waiting for restarted process to stop")
}
})
}
func TestWaitForProcessStop(t *testing.T) {
t.Parallel()
t.Run("nil process and nil done are true", func(t *testing.T) {
t.Parallel()
if !waitForProcessStop(nil, time.Millisecond) {
t.Fatal("waitForProcessStop(nil) = false, want true")
}
if !waitForProcessStop(&model.Process{}, time.Millisecond) {
t.Fatal("waitForProcessStop(process with nil done) = false, want true")
}
})
t.Run("done channel closes before timeout", func(t *testing.T) {
t.Parallel()
proc := &model.Process{Done: make(chan struct{})}
close(proc.Done)
if !waitForProcessStop(proc, time.Second) {
t.Fatal("waitForProcessStop() = false, want true")
}
})
t.Run("timeout returns false", func(t *testing.T) {
t.Parallel()
proc := &model.Process{Done: make(chan struct{})}
if waitForProcessStop(proc, 20*time.Millisecond) {
t.Fatal("waitForProcessStop() = true, want false")
}
})
}

View File

@ -0,0 +1,111 @@
package app
import (
"net"
"strings"
"testing"
"time"
"termtap.dev/internal/model"
)
func drainEvents(t *testing.T, ch <-chan model.Event, n int, timeout time.Duration) []model.Event {
t.Helper()
out := make([]model.Event, 0, n)
deadline := time.After(timeout)
for len(out) < n {
select {
case ev := <-ch:
out = append(out, ev)
case <-deadline:
t.Fatalf("timeout waiting for %d events, got %d", n, len(out))
}
}
return out
}
func hasType(events []model.Event, typ model.EventType) bool {
for _, ev := range events {
if ev.Type == typ {
return true
}
}
return false
}
func containsBody(events []model.Event, part string) bool {
for _, ev := range events {
if strings.Contains(ev.Body, part) {
return true
}
}
return false
}
func waitForEventType(t *testing.T, ch <-chan model.Event, typ model.EventType, timeout time.Duration) (model.Event, bool) {
t.Helper()
deadline := time.After(timeout)
for {
select {
case ev := <-ch:
if ev.Type == typ {
return ev, true
}
case <-deadline:
return model.Event{}, false
}
}
}
func collectUntilTypes(t *testing.T, ch <-chan model.Event, required []model.EventType, timeout time.Duration) []model.Event {
t.Helper()
need := make(map[model.EventType]bool, len(required))
for _, typ := range required {
need[typ] = true
}
events := make([]model.Event, 0, len(required)+8)
deadline := time.After(timeout)
for len(need) > 0 {
select {
case ev := <-ch:
events = append(events, ev)
delete(need, ev.Type)
case <-deadline:
t.Fatalf("timeout waiting for required events: remaining=%v, collected=%#v", need, events)
}
}
return events
}
func isBefore(events []model.Event, first, second model.EventType) bool {
firstIdx := -1
secondIdx := -1
for i, ev := range events {
if ev.Type == first && firstIdx == -1 {
firstIdx = i
}
if ev.Type == second && secondIdx == -1 {
secondIdx = i
}
}
return firstIdx >= 0 && secondIdx >= 0 && firstIdx < secondIdx
}
func freeTCPAddr(t *testing.T) string {
t.Helper()
ln, err := net.Listen("tcp", "127.0.0.1:0")
if err != nil {
t.Fatalf("listen error = %v", err)
}
addr := ln.Addr().String()
_ = ln.Close()
return addr
}

View File

@ -2,6 +2,7 @@ package cli
import ( import (
"fmt" "fmt"
"io"
"log" "log"
"os" "os"
"runtime" "runtime"
@ -16,6 +17,23 @@ import (
// This should be configurable at some point, just in case they build on 8080 // This should be configurable at some point, just in case they build on 8080
const proxy_addr = "127.0.0.1:8080" const proxy_addr = "127.0.0.1:8080"
var fatalExit = log.Fatalln
var stdoutWriter io.Writer = stdioRef{isErr: false}
var stderrWriter io.Writer = stdioRef{isErr: true}
var startSessionFn = app.StartSession
var runTUIFn = tui.Run
type stdioRef struct {
isErr bool
}
func (w stdioRef) Write(p []byte) (int, error) {
if w.isErr {
return os.Stderr.Write(p)
}
return os.Stdout.Write(p)
}
func Run(args []string) { func Run(args []string) {
if len(args) >= 2 && args[1] == "cert" { if len(args) >= 2 && args[1] == "cert" {
runCert() runCert()
@ -28,9 +46,10 @@ func Run(args []string) {
return return
} }
session, err := app.StartSession(cmd, proxy_addr) session, err := startSessionFn(cmd, proxy_addr)
if err != nil { if err != nil {
log.Fatalln(err) fatalExit(err)
return
} }
defer session.Stop() defer session.Stop()
@ -38,8 +57,9 @@ func Run(args []string) {
Restart: session.RestartProcess, Restart: session.RestartProcess,
} }
if err := tui.Run(session.Events, controls); err != nil { if err := runTUIFn(session.Events, controls); err != nil {
log.Fatalln(err) fatalExit(err)
return
} }
} }
@ -67,49 +87,50 @@ usage:
tap run -- <command> [args...] tap run -- <command> [args...]
` `
fmt.Fprintln(os.Stderr, helpText) fmt.Fprintln(stderrWriter, helpText)
} }
func runCert() { func runCert() {
ca, err := proxy.EnsureCertificateAuthority() ca, err := proxy.EnsureCertificateAuthority()
if err != nil { if err != nil {
log.Fatalln(err) fatalExit(err)
return
} }
certPath := ca.CertPath() certPath := ca.CertPath()
quotedCertPath := strconv.Quote(certPath) quotedCertPath := strconv.Quote(certPath)
fmt.Printf("Certificate path: %s\n", certPath) fmt.Fprintf(stdoutWriter, "Certificate path: %s\n", certPath)
if ca.WasCreated() { if ca.WasCreated() {
fmt.Println("Created a new local HTTPS interception CA.") fmt.Fprintln(stdoutWriter, "Created a new local HTTPS interception CA.")
} else { } else {
fmt.Println("Using existing local HTTPS interception CA.") fmt.Fprintln(stdoutWriter, "Using existing local HTTPS interception CA.")
} }
trusted, err := ca.IsTrustedBySystem() trusted, err := ca.IsTrustedBySystem()
if err != nil { if err != nil {
fmt.Printf("System trust check failed: %v\n", err) fmt.Fprintf(stdoutWriter, "System trust check failed: %v\n", err)
} else if trusted { } else if trusted {
fmt.Println("System trust store: trusted") fmt.Fprintln(stdoutWriter, "System trust store: trusted")
} else { } else {
fmt.Println("System trust store: not trusted") fmt.Fprintln(stdoutWriter, "System trust store: not trusted")
} }
if runtime.GOOS != "linux" { if runtime.GOOS != "linux" {
fmt.Println("Install this certificate into your OS or client trust store to inspect HTTPS traffic.") fmt.Fprintln(stdoutWriter, "Install this certificate into your OS or client trust store to inspect HTTPS traffic.")
return return
} }
fmt.Println() fmt.Fprintln(stdoutWriter)
fmt.Println("Trust instructions (Linux):") fmt.Fprintln(stdoutWriter, "Trust instructions (Linux):")
fmt.Println("Debian/Ubuntu:") fmt.Fprintln(stdoutWriter, "Debian/Ubuntu:")
fmt.Printf(" sudo cp %s /usr/local/share/ca-certificates/termtap.crt\n", quotedCertPath) fmt.Fprintf(stdoutWriter, " sudo cp %s /usr/local/share/ca-certificates/termtap.crt\n", quotedCertPath)
fmt.Println(" sudo update-ca-certificates") fmt.Fprintln(stdoutWriter, " sudo update-ca-certificates")
fmt.Println("Fedora/RHEL/CentOS:") fmt.Fprintln(stdoutWriter, "Fedora/RHEL/CentOS:")
fmt.Printf(" sudo cp %s /etc/pki/ca-trust/source/anchors/termtap.crt\n", quotedCertPath) fmt.Fprintf(stdoutWriter, " sudo cp %s /etc/pki/ca-trust/source/anchors/termtap.crt\n", quotedCertPath)
fmt.Println(" sudo update-ca-trust") fmt.Fprintln(stdoutWriter, " sudo update-ca-trust")
fmt.Println("Arch:") fmt.Fprintln(stdoutWriter, "Arch:")
fmt.Printf(" sudo trust anchor %s\n", quotedCertPath) fmt.Fprintf(stdoutWriter, " sudo trust anchor %s\n", quotedCertPath)
fmt.Println() fmt.Fprintln(stdoutWriter)
fmt.Println("Quick curl test:") fmt.Fprintln(stdoutWriter, "Quick curl test:")
fmt.Printf(" curl --proxy http://%s --cacert %s https://example.com\n", proxy_addr, quotedCertPath) fmt.Fprintf(stdoutWriter, " curl --proxy http://%s --cacert %s https://example.com\n", proxy_addr, quotedCertPath)
} }

356
internal/cli/run_test.go Normal file
View File

@ -0,0 +1,356 @@
package cli
import (
"bytes"
"errors"
"io"
"os"
"runtime"
"strings"
"testing"
"time"
"termtap.dev/internal/app"
"termtap.dev/internal/model"
"termtap.dev/internal/tui"
)
func TestParseCommand(t *testing.T) {
tests := []struct {
name string
args []string
ok bool
nameWant string
argsWant []string
}{
{name: "too few args", args: []string{"tap"}, ok: false},
{name: "missing run token", args: []string{"tap", "oops", "--", "echo"}, ok: false},
{name: "missing separator", args: []string{"tap", "run", "echo"}, ok: false},
{name: "single command", args: []string{"tap", "run", "--", "echo"}, ok: true, nameWant: "echo", argsWant: []string{}},
{name: "command with args", args: []string{"tap", "run", "--", "curl", "-s", "https://example.com"}, ok: true, nameWant: "curl", argsWant: []string{"-s", "https://example.com"}},
}
for _, tt := range tests {
tt := tt
t.Run(tt.name, func(t *testing.T) {
cmd, ok := parseCommand(tt.args)
if ok != tt.ok {
t.Fatalf("ok = %v, want %v", ok, tt.ok)
}
if !tt.ok {
return
}
if cmd.Name != tt.nameWant {
t.Fatalf("cmd.Name = %q, want %q", cmd.Name, tt.nameWant)
}
if strings.Join(cmd.Args, "|") != strings.Join(tt.argsWant, "|") {
t.Fatalf("cmd.Args = %#v, want %#v", cmd.Args, tt.argsWant)
}
})
}
}
func TestDisplayHelpWritesToStderr(t *testing.T) {
_, stderr := captureOutput(t, func() {
displayHelp()
})
if !strings.Contains(stderr, "tap cert") || !strings.Contains(stderr, "tap run --") {
t.Fatalf("stderr missing usage text: %q", stderr)
}
}
func TestRun_InvalidCommandShowsHelp(t *testing.T) {
_, stderr := captureOutput(t, func() {
Run([]string{"tap", "wat"})
})
if !strings.Contains(stderr, "usage:") {
t.Fatalf("stderr missing usage output: %q", stderr)
}
}
func TestRun_RoutesCertCommand(t *testing.T) {
configRoot := t.TempDir()
t.Setenv("XDG_CONFIG_HOME", configRoot)
stdout, _ := captureOutput(t, func() {
Run([]string{"tap", "cert"})
})
if !strings.Contains(stdout, "Certificate path:") {
t.Fatalf("stdout missing certificate path output: %q", stdout)
}
}
func TestRun_StartSessionFailureCallsFatalExit(t *testing.T) {
restore := installRunSeams(t)
defer restore()
startSessionFn = func(model.Command, string) (*app.Session, error) {
return nil, errors.New("boom")
}
called := installFatalSpy(t)
Run([]string{"tap", "run", "--", "definitely-not-a-real-command"})
if !*called {
t.Fatal("expected fatalExit to be called on StartSession failure")
}
}
func TestRun_TUIFailureCallsFatalExit(t *testing.T) {
restore := installRunSeams(t)
defer restore()
startSessionFn = func(model.Command, string) (*app.Session, error) {
return &app.Session{Events: make(chan model.Event)}, nil
}
runTUIFn = func(<-chan model.Event, tui.Controls) error {
return errors.New("tui failed")
}
called := installFatalSpy(t)
Run([]string{"tap", "run", "--", "echo"})
if !*called {
t.Fatal("expected fatalExit to be called on tui failure")
}
}
func TestRun_SuccessPathDoesNotCallFatal(t *testing.T) {
restore := installRunSeams(t)
defer restore()
startSessionFn = func(model.Command, string) (*app.Session, error) {
return &app.Session{Events: make(chan model.Event)}, nil
}
runTUIFn = func(<-chan model.Event, tui.Controls) error {
return nil
}
called := installFatalSpy(t)
Run([]string{"tap", "run", "--", "echo"})
if *called {
t.Fatal("fatalExit should not be called on success path")
}
}
func TestRunCert_EnsureCAFailureCallsFatalExit(t *testing.T) {
t.Setenv("XDG_CONFIG_HOME", "")
t.Setenv("HOME", "")
called := installFatalSpy(t)
runCert()
if !*called {
t.Fatal("expected fatalExit to be called when EnsureCertificateAuthority fails")
}
}
func TestRunCertOutputContract(t *testing.T) {
configRoot := t.TempDir()
t.Setenv("XDG_CONFIG_HOME", configRoot)
stdout, _ := captureOutput(t, func() {
runCert()
})
if !strings.Contains(stdout, "Certificate path:") {
t.Fatalf("stdout missing certificate path line: %q", stdout)
}
if !strings.Contains(stdout, "local HTTPS interception CA") {
t.Fatalf("stdout missing CA create/existing line: %q", stdout)
}
if !strings.Contains(stdout, "System trust store:") && !strings.Contains(stdout, "System trust check failed:") {
t.Fatalf("stdout missing trust check line: %q", stdout)
}
if runtime.GOOS == "linux" {
if !strings.Contains(stdout, "Trust instructions (Linux):") {
t.Fatalf("stdout missing linux trust instructions: %q", stdout)
}
}
}
func TestRunCert_CreatedThenExistingMessage(t *testing.T) {
configRoot := t.TempDir()
t.Setenv("XDG_CONFIG_HOME", configRoot)
firstOut, _ := captureOutput(t, func() {
runCert()
})
if !strings.Contains(firstOut, "Created a new local HTTPS interception CA.") {
t.Fatalf("first run should indicate created CA, got: %q", firstOut)
}
secondOut, _ := captureOutput(t, func() {
runCert()
})
if !strings.Contains(secondOut, "Using existing local HTTPS interception CA.") {
t.Fatalf("second run should indicate existing CA, got: %q", secondOut)
}
}
func captureOutput(t *testing.T, fn func()) (stdout string, stderr string) {
t.Helper()
origStdoutWriter := stdoutWriter
origStderrWriter := stderrWriter
t.Cleanup(func() {
stdoutWriter = origStdoutWriter
stderrWriter = origStderrWriter
})
origStdout := os.Stdout
origStderr := os.Stderr
outR, outW, err := os.Pipe()
if err != nil {
t.Fatalf("stdout pipe error: %v", err)
}
errR, errW, err := os.Pipe()
if err != nil {
_ = outR.Close()
_ = outW.Close()
t.Fatalf("stderr pipe error: %v", err)
}
os.Stdout = outW
os.Stderr = errW
stdoutWriter = outW
stderrWriter = errW
outCh := make(chan string, 1)
errCh := make(chan string, 1)
go func() {
var buf bytes.Buffer
_, _ = io.Copy(&buf, outR)
outCh <- buf.String()
}()
go func() {
var buf bytes.Buffer
_, _ = io.Copy(&buf, errR)
errCh <- buf.String()
}()
fn()
_ = outW.Close()
_ = errW.Close()
stdoutWriter = origStdoutWriter
stderrWriter = origStderrWriter
os.Stdout = origStdout
os.Stderr = origStderr
select {
case stdout = <-outCh:
case <-time.After(2 * time.Second):
t.Fatal("timeout waiting for stdout capture")
}
select {
case stderr = <-errCh:
case <-time.After(2 * time.Second):
t.Fatal("timeout waiting for stderr capture")
}
_ = outR.Close()
_ = errR.Close()
return stdout, stderr
}
func TestDisplayHelp_UsesInjectedStderrWriter(t *testing.T) {
var buf bytes.Buffer
orig := stderrWriter
t.Cleanup(func() { stderrWriter = orig })
stderrWriter = &buf
displayHelp()
if got := buf.String(); !strings.Contains(got, "usage:") {
t.Fatalf("help output missing usage, got: %q", got)
}
}
func TestRunCert_UsesInjectedStdoutWriter(t *testing.T) {
configRoot := t.TempDir()
t.Setenv("XDG_CONFIG_HOME", configRoot)
var buf bytes.Buffer
orig := stdoutWriter
t.Cleanup(func() { stdoutWriter = orig })
stdoutWriter = &buf
runCert()
if got := buf.String(); !strings.Contains(got, "Certificate path:") {
t.Fatalf("cert output missing path line, got: %q", got)
}
}
func installRunSeams(t *testing.T) func() {
t.Helper()
origStartSession := startSessionFn
origRunTUI := runTUIFn
return func() {
startSessionFn = origStartSession
runTUIFn = origRunTUI
}
}
func installFatalSpy(t *testing.T) *bool {
t.Helper()
origFatal := fatalExit
called := false
fatalExit = func(v ...any) {
called = true
}
t.Cleanup(func() {
fatalExit = origFatal
})
return &called
}
func TestStdioRefWrite(t *testing.T) {
t.Run("writes to stdout", func(t *testing.T) {
assertStdioRefWrite(t, false, "hello")
})
t.Run("writes to stderr", func(t *testing.T) {
assertStdioRefWrite(t, true, "boom")
})
}
func assertStdioRefWrite(t *testing.T, isErr bool, payload string) {
t.Helper()
r, w, err := os.Pipe()
if err != nil {
t.Fatalf("pipe error: %v", err)
}
defer func() { _ = r.Close() }()
if isErr {
orig := os.Stderr
os.Stderr = w
defer func() { os.Stderr = orig }()
} else {
orig := os.Stdout
os.Stdout = w
defer func() { os.Stdout = orig }()
}
if _, err := (stdioRef{isErr: isErr}).Write([]byte(payload)); err != nil {
_ = w.Close()
t.Fatalf("stdioRef write error: %v", err)
}
_ = w.Close()
data, err := io.ReadAll(r)
if err != nil {
t.Fatalf("ReadAll(pipe) error: %v", err)
}
if got := string(data); got != payload {
t.Fatalf("pipe write = %q, want %q", got, payload)
}
}

View File

@ -0,0 +1,219 @@
package process
import (
"os"
"os/exec"
"reflect"
"strings"
"testing"
"time"
"termtap.dev/internal/model"
)
func TestCommandString(t *testing.T) {
t.Parallel()
tests := []struct {
name string
cmd model.Command
want string
}{
{
name: "empty args",
cmd: model.Command{Name: "go", Args: []string{}},
want: "go ",
},
{
name: "multiple args",
cmd: model.Command{Name: "curl", Args: []string{"-s", "https://example.com"}},
want: "curl -s https://example.com",
},
}
for _, tt := range tests {
tt := tt
t.Run(tt.name, func(t *testing.T) {
t.Parallel()
if got := CommandString(tt.cmd); got != tt.want {
t.Fatalf("CommandString() = %q, want %q", got, tt.want)
}
})
}
}
func TestInjectEnv(t *testing.T) {
t.Parallel()
cmd := exec.Command("sh", "-c", "true")
injectEnv(cmd, "127.0.0.1:8080")
mustContain := []string{
"HTTP_PROXY=http://127.0.0.1:8080",
"http_proxy=http://127.0.0.1:8080",
"HTTPS_PROXY=http://127.0.0.1:8080",
"https_proxy=http://127.0.0.1:8080",
"NO_PROXY=",
"no_proxy=",
}
for _, kv := range mustContain {
if !containsEnvEntry(cmd.Env, kv) {
t.Fatalf("injectEnv() missing env entry %q", kv)
}
}
}
func TestReadPipe(t *testing.T) {
t.Parallel()
t.Run("stdout lines emit stdout events", func(t *testing.T) {
t.Parallel()
ch := make(chan model.Event, 4)
input := strings.NewReader("line1\nline2\n")
readPipe(input, model.EventTypeProcessStdout, ch)
events := drainEvents(t, ch, 2, time.Second)
if events[0].Type != model.EventTypeProcessStdout || events[0].Body != "line1" {
t.Fatalf("event[0] = %#v, want stdout line1", events[0])
}
if events[1].Type != model.EventTypeProcessStdout || events[1].Body != "line2" {
t.Fatalf("event[1] = %#v, want stdout line2", events[1])
}
})
t.Run("stderr lines emit stderr events", func(t *testing.T) {
t.Parallel()
ch := make(chan model.Event, 4)
input := strings.NewReader("err1\nerr2\n")
readPipe(input, model.EventTypeProcessStderr, ch)
events := drainEvents(t, ch, 2, time.Second)
if events[0].Type != model.EventTypeProcessStderr || events[0].Body != "err1" {
t.Fatalf("event[0] = %#v, want stderr err1", events[0])
}
if events[1].Type != model.EventTypeProcessStderr || events[1].Body != "err2" {
t.Fatalf("event[1] = %#v, want stderr err2", events[1])
}
})
}
func TestUpdateStatus(t *testing.T) {
t.Parallel()
t.Run("nil process is no-op", func(t *testing.T) {
t.Parallel()
UpdateStatus(nil, true, make(chan model.Event, 1))
})
t.Run("state unchanged is no-op", func(t *testing.T) {
t.Parallel()
ch := make(chan model.Event, 1)
proc := &model.Process{Running: true}
UpdateStatus(proc, true, ch)
select {
case ev := <-ch:
t.Fatalf("unexpected event for unchanged state: %#v", ev)
default:
}
})
t.Run("emits started and stopped events with pid", func(t *testing.T) {
t.Parallel()
ch := make(chan model.Event, 2)
proc := &model.Process{
Exec: &exec.Cmd{Process: &os.Process{Pid: 4321}},
}
UpdateStatus(proc, true, ch)
events := drainEvents(t, ch, 1, time.Second)
started := events[0]
if started.Type != model.EventTypeProcessStarted {
t.Fatalf("started type = %s, want %s", started.Type, model.EventTypeProcessStarted)
}
if started.PID != 4321 {
t.Fatalf("started PID = %d, want %d", started.PID, 4321)
}
if !proc.Running {
t.Fatal("proc.Running = false, want true")
}
UpdateStatus(proc, false, ch)
events = drainEvents(t, ch, 1, time.Second)
stopped := events[0]
if stopped.Type != model.EventTypeProcessExited {
t.Fatalf("stopped type = %s, want %s", stopped.Type, model.EventTypeProcessExited)
}
if stopped.PID != 4321 {
t.Fatalf("stopped PID = %d, want %d", stopped.PID, 4321)
}
if proc.Running {
t.Fatal("proc.Running = true, want false")
}
})
}
func TestNewProcess(t *testing.T) {
t.Parallel()
ch := make(chan model.Event, 8)
cmd := model.Command{Name: "sh", Args: []string{"-c", "printf test"}}
proc := NewProcess(cmd, "127.0.0.1:8080", ch)
if proc == nil {
t.Fatal("NewProcess() returned nil")
}
if !reflect.DeepEqual(proc.Command, cmd) {
t.Fatalf("process command = %#v, want %#v", proc.Command, cmd)
}
if proc.Exec == nil {
t.Fatal("process Exec is nil")
}
if proc.Running {
t.Fatal("new process should not be running")
}
if proc.Done == nil {
t.Fatal("Done channel is nil")
}
if got, want := proc.Exec.Args[0], "sh"; got != want {
t.Fatalf("Exec.Args[0] = %q, want %q", got, want)
}
if !containsEnvEntry(proc.Exec.Env, "HTTP_PROXY=http://127.0.0.1:8080") {
t.Fatal("process env missing injected HTTP_PROXY")
}
}
func containsEnvEntry(env []string, want string) bool {
for _, entry := range env {
if entry == want {
return true
}
}
return false
}
func drainEvents(t *testing.T, ch <-chan model.Event, n int, timeout time.Duration) []model.Event {
t.Helper()
events := make([]model.Event, 0, n)
deadline := time.After(timeout)
for len(events) < n {
select {
case ev := <-ch:
events = append(events, ev)
case <-deadline:
t.Fatalf("timeout waiting for %d events; got %d", n, len(events))
}
}
return events
}

View File

@ -0,0 +1,147 @@
//go:build unix
package process
import (
"errors"
"os"
"os/exec"
"syscall"
"testing"
"time"
)
type customSignal struct{}
func (customSignal) String() string { return "custom" }
func (customSignal) Signal() {}
// NOTE: Run these tests with -race in CI for signal/process safety.
func TestConfigureProcessForSignals(t *testing.T) {
t.Parallel()
cmd := exec.Command("sh", "-c", "sleep 0.1")
configureProcessForSignals(cmd)
if cmd.SysProcAttr == nil {
t.Fatal("SysProcAttr is nil")
}
if !cmd.SysProcAttr.Setpgid {
t.Fatal("Setpgid = false, want true")
}
}
func TestSignalProcess_NilSafe(t *testing.T) {
t.Parallel()
if err := SignalProcess(nil, syscall.SIGTERM); err != nil {
t.Fatalf("SignalProcess(nil) error = %v, want nil", err)
}
cmd := &exec.Cmd{}
if err := SignalProcess(cmd, syscall.SIGTERM); err != nil {
t.Fatalf("SignalProcess(cmd without process) error = %v, want nil", err)
}
cmd.Process = &os.Process{Pid: 0}
if err := SignalProcess(cmd, syscall.SIGTERM); err != nil {
t.Fatalf("SignalProcess(pid<=0) error = %v, want nil", err)
}
}
func TestSignalProcess_ESRCHIsTreatedAsSuccess(t *testing.T) {
t.Parallel()
cmd := &exec.Cmd{Process: &os.Process{Pid: 999999}}
if err := SignalProcess(cmd, syscall.SIGTERM); err != nil {
t.Fatalf("SignalProcess() error = %v, want nil when process group not found", err)
}
}
func TestSignalProcess_UsesFallbackOnKillError(t *testing.T) {
t.Parallel()
cmd := exec.Command("sh", "-c", "sleep 5")
configureProcessForSignals(cmd)
if err := cmd.Start(); err != nil {
t.Fatalf("start command error = %v", err)
}
t.Cleanup(func() {
_ = cmd.Process.Kill()
_, _ = cmd.Process.Wait()
})
// Invalid signal causes syscall.Kill to fail with EINVAL, then fallback to cmd.Process.Signal.
err := SignalProcess(cmd, syscall.Signal(9999))
if err == nil {
t.Fatal("SignalProcess() error = nil, want non-nil for invalid signal")
}
if !(errors.Is(err, syscall.EINVAL) || errors.Is(err, os.ErrProcessDone)) {
// OS/process timing can vary; ensure we at least failed predictably.
t.Fatalf("SignalProcess() unexpected error: %v", err)
}
}
func TestSignalProcess_NonSyscallSignalUsesProcessSignal(t *testing.T) {
t.Parallel()
cmd := exec.Command("sh", "-c", "sleep 1")
configureProcessForSignals(cmd)
if err := cmd.Start(); err != nil {
t.Fatalf("start command error = %v", err)
}
t.Cleanup(func() {
_ = cmd.Process.Kill()
_, _ = cmd.Process.Wait()
})
err := SignalProcess(cmd, customSignal{})
if err == nil {
t.Fatal("SignalProcess(custom signal) error = nil, want non-nil")
}
if errors.Is(err, os.ErrProcessDone) {
return
}
if msg := err.Error(); msg == "" {
t.Fatalf("unexpected empty error for custom signal: %v", err)
}
}
func TestProcessAlive(t *testing.T) {
t.Parallel()
if ProcessAlive(nil) {
t.Fatal("ProcessAlive(nil) = true, want false")
}
if ProcessAlive(&exec.Cmd{}) {
t.Fatal("ProcessAlive(cmd without process) = true, want false")
}
cmd := exec.Command("sh", "-c", "sleep 0.2")
configureProcessForSignals(cmd)
if err := cmd.Start(); err != nil {
t.Fatalf("start command error = %v", err)
}
if !ProcessAlive(cmd) {
_ = cmd.Process.Kill()
_, _ = cmd.Process.Wait()
t.Fatal("ProcessAlive(running) = false, want true")
}
_ = cmd.Process.Kill()
_, _ = cmd.Process.Wait()
deadline := time.After(time.Second)
for {
if !ProcessAlive(cmd) {
return
}
select {
case <-deadline:
t.Fatal("ProcessAlive(exited) stayed true")
default:
}
}
}

View File

@ -0,0 +1,242 @@
package proxy
import (
"crypto/tls"
"encoding/pem"
"os"
"path/filepath"
"strings"
"testing"
)
func TestLoadOrCreateCertificateAuthority_RecreatesWhenKeyMissing(t *testing.T) {
configRoot := t.TempDir()
t.Setenv("XDG_CONFIG_HOME", configRoot)
baseDir := filepath.Join(configRoot, caDirName)
if err := os.MkdirAll(baseDir, 0o700); err != nil {
t.Fatalf("MkdirAll() error = %v", err)
}
certPath := filepath.Join(baseDir, caCertName)
if err := os.WriteFile(certPath, []byte("stale-cert"), 0o600); err != nil {
t.Fatalf("WriteFile(cert) error = %v", err)
}
ca, err := loadOrCreateCertificateAuthority()
if err != nil {
t.Fatalf("loadOrCreateCertificateAuthority() error = %v", err)
}
if !ca.WasCreated() {
t.Fatal("WasCreated = false, want true when key is missing")
}
if _, err := os.Stat(filepath.Join(baseDir, caKeyName)); err != nil {
t.Fatalf("expected key file to be created, stat error = %v", err)
}
}
func TestLoadOrCreateCertificateAuthority_LoadErrorOnCorruptFiles(t *testing.T) {
configRoot := t.TempDir()
t.Setenv("XDG_CONFIG_HOME", configRoot)
baseDir := filepath.Join(configRoot, caDirName)
if err := os.MkdirAll(baseDir, 0o700); err != nil {
t.Fatalf("MkdirAll() error = %v", err)
}
certPath := filepath.Join(baseDir, caCertName)
keyPath := filepath.Join(baseDir, caKeyName)
if err := os.WriteFile(certPath, []byte("not-a-pem"), 0o600); err != nil {
t.Fatalf("WriteFile(cert) error = %v", err)
}
if err := os.WriteFile(keyPath, []byte("not-a-pem"), 0o600); err != nil {
t.Fatalf("WriteFile(key) error = %v", err)
}
_, err := loadOrCreateCertificateAuthority()
if err == nil {
t.Fatal("loadOrCreateCertificateAuthority() error = nil, want non-nil")
}
}
func TestCertificateAuthorityLoad_ErrorPaths(t *testing.T) {
t.Parallel()
tests := []struct {
name string
certBytes []byte
keyBytes []byte
wantPart string
}{
{
name: "invalid cert pem",
certBytes: []byte("bad-cert"),
keyBytes: []byte("bad-key"),
wantPart: "decode ca cert pem",
},
{
name: "parse cert fails",
certBytes: pem.EncodeToMemory(&pem.Block{Type: "CERTIFICATE", Bytes: []byte("bogus")}),
keyBytes: []byte("bad-key"),
wantPart: "parse ca cert",
},
{
name: "invalid key pem",
certBytes: pem.EncodeToMemory(&pem.Block{Type: "CERTIFICATE", Bytes: newTestCA(t).cert.Raw}),
keyBytes: []byte("bad-key"),
wantPart: "decode ca key pem",
},
{
name: "parse key fails",
certBytes: pem.EncodeToMemory(&pem.Block{Type: "CERTIFICATE", Bytes: newTestCA(t).cert.Raw}),
keyBytes: pem.EncodeToMemory(&pem.Block{Type: "EC PRIVATE KEY", Bytes: []byte("bogus")}),
wantPart: "parse ca key",
},
}
for _, tt := range tests {
tt := tt
t.Run(tt.name, func(t *testing.T) {
t.Parallel()
dir := t.TempDir()
ca := &CertificateAuthority{
certPath: filepath.Join(dir, caCertName),
keyPath: filepath.Join(dir, caKeyName),
leafCert: make(map[string]*tls.Certificate),
}
if err := os.WriteFile(ca.certPath, tt.certBytes, 0o600); err != nil {
t.Fatalf("write cert file error = %v", err)
}
if err := os.WriteFile(ca.keyPath, tt.keyBytes, 0o600); err != nil {
t.Fatalf("write key file error = %v", err)
}
err := ca.load()
if err == nil {
t.Fatal("load() error = nil, want non-nil")
}
if !strings.Contains(err.Error(), tt.wantPart) {
t.Fatalf("load() error = %q, want contains %q", err.Error(), tt.wantPart)
}
})
}
}
func TestCertificateAuthorityCreate_ErrorWhenWritePathInvalid(t *testing.T) {
t.Parallel()
ca := &CertificateAuthority{
certPath: filepath.Join("/nope", "missing", "ca-cert.pem"),
keyPath: filepath.Join("/nope", "missing", "ca-key.pem"),
leafCert: make(map[string]*tls.Certificate),
}
err := ca.create()
if err == nil {
t.Fatal("create() error = nil, want non-nil")
}
}
func TestCertificateAuthorityCreate_WriteErrorPaths(t *testing.T) {
t.Parallel()
t.Run("write ca cert wraps error", func(t *testing.T) {
t.Parallel()
dir := t.TempDir()
badCertPath := filepath.Join(dir, "cert-as-dir")
if err := os.MkdirAll(badCertPath, 0o700); err != nil {
t.Fatalf("MkdirAll(cert dir) error = %v", err)
}
ca := &CertificateAuthority{
certPath: badCertPath,
keyPath: filepath.Join(dir, "ca-key.pem"),
leafCert: make(map[string]*tls.Certificate),
}
err := ca.create()
if err == nil {
t.Fatal("create() error = nil, want non-nil")
}
if !strings.Contains(err.Error(), "write ca cert") {
t.Fatalf("create() error = %q, want contains %q", err.Error(), "write ca cert")
}
})
t.Run("write ca key wraps error", func(t *testing.T) {
t.Parallel()
dir := t.TempDir()
badKeyPath := filepath.Join(dir, "key-as-dir")
if err := os.MkdirAll(badKeyPath, 0o700); err != nil {
t.Fatalf("MkdirAll(key dir) error = %v", err)
}
ca := &CertificateAuthority{
certPath: filepath.Join(dir, "ca-cert.pem"),
keyPath: badKeyPath,
leafCert: make(map[string]*tls.Certificate),
}
err := ca.create()
if err == nil {
t.Fatal("create() error = nil, want non-nil")
}
if !strings.Contains(err.Error(), "write ca key") {
t.Fatalf("create() error = %q, want contains %q", err.Error(), "write ca key")
}
})
}
func TestCertificateAuthorityLoad_ReadErrorPaths(t *testing.T) {
t.Parallel()
t.Run("read cert failure", func(t *testing.T) {
t.Parallel()
dir := t.TempDir()
ca := &CertificateAuthority{
certPath: filepath.Join(dir, "missing-cert.pem"),
keyPath: filepath.Join(dir, "missing-key.pem"),
leafCert: make(map[string]*tls.Certificate),
}
err := ca.load()
if err == nil {
t.Fatal("load() error = nil, want non-nil")
}
if !strings.Contains(err.Error(), "read ca cert") {
t.Fatalf("load() error = %q, want contains %q", err.Error(), "read ca cert")
}
})
t.Run("read key failure", func(t *testing.T) {
t.Parallel()
dir := t.TempDir()
goodCA := newTestCA(t)
certPEM := pem.EncodeToMemory(&pem.Block{Type: "CERTIFICATE", Bytes: goodCA.cert.Raw})
certPath := filepath.Join(dir, "cert.pem")
if err := os.WriteFile(certPath, certPEM, 0o600); err != nil {
t.Fatalf("WriteFile(cert) error = %v", err)
}
ca := &CertificateAuthority{
certPath: certPath,
keyPath: filepath.Join(dir, "missing-key.pem"),
leafCert: make(map[string]*tls.Certificate),
}
err := ca.load()
if err == nil {
t.Fatal("load() error = nil, want non-nil")
}
if !strings.Contains(err.Error(), "read ca key") {
t.Fatalf("load() error = %q, want contains %q", err.Error(), "read ca key")
}
})
}

View File

@ -0,0 +1,45 @@
package proxy
import (
"os"
"testing"
)
func TestLoadOrCreateCertificateAuthority_CreateThenLoad(t *testing.T) {
configRoot := t.TempDir()
t.Setenv("XDG_CONFIG_HOME", configRoot)
ca1, err := loadOrCreateCertificateAuthority()
if err != nil {
t.Fatalf("first loadOrCreateCertificateAuthority() error = %v", err)
}
if ca1 == nil {
t.Fatal("first CA is nil")
}
if !ca1.WasCreated() {
t.Fatal("first CA should report WasCreated=true")
}
if ca1.CertPath() == "" {
t.Fatal("first CA CertPath is empty")
}
if _, err := os.Stat(ca1.CertPath()); err != nil {
t.Fatalf("first CA cert file missing: %v", err)
}
ca2, err := loadOrCreateCertificateAuthority()
if err != nil {
t.Fatalf("second loadOrCreateCertificateAuthority() error = %v", err)
}
if ca2 == nil {
t.Fatal("second CA is nil")
}
if ca2.WasCreated() {
t.Fatal("second CA should report WasCreated=false (loaded existing)")
}
if ca2.CertPath() != ca1.CertPath() {
t.Fatalf("cert path mismatch: first=%q second=%q", ca1.CertPath(), ca2.CertPath())
}
if ca2.cert == nil || ca2.key == nil {
t.Fatal("loaded CA should include cert and key")
}
}

View File

@ -0,0 +1,266 @@
package proxy
import (
"errors"
"math/big"
"os"
"path/filepath"
"testing"
)
func TestNormalizeCertHost(t *testing.T) {
t.Parallel()
tests := []struct {
name string
in string
want string
}{
{name: "host and port", in: "example.com:443", want: "example.com"},
{name: "plain host", in: "example.com", want: "example.com"},
{name: "whitespace trims", in: " example.com:8443 ", want: "example.com"},
{name: "empty", in: " ", want: ""},
}
for _, tt := range tests {
tt := tt
t.Run(tt.name, func(t *testing.T) {
t.Parallel()
if got := normalizeCertHost(tt.in); got != tt.want {
t.Fatalf("normalizeCertHost(%q) = %q, want %q", tt.in, got, tt.want)
}
})
}
}
func TestRandSerialNumber(t *testing.T) {
t.Parallel()
serial, err := randSerialNumber()
if err != nil {
t.Fatalf("randSerialNumber() error = %v", err)
}
if serial == nil {
t.Fatal("serial is nil")
}
if serial.Sign() < 0 {
t.Fatalf("serial must be non-negative, got %v", serial)
}
limit := new(big.Int).Lsh(big.NewInt(1), 128)
if serial.Cmp(limit) >= 0 {
t.Fatalf("serial must be < 2^128, got %v", serial)
}
}
func TestWriteFileAtomically(t *testing.T) {
t.Parallel()
dir := t.TempDir()
path := filepath.Join(dir, "cert.pem")
if err := writeFileAtomically(path, []byte("first"), 0o600); err != nil {
t.Fatalf("first writeFileAtomically() error = %v", err)
}
if err := writeFileAtomically(path, []byte("second"), 0o600); err != nil {
t.Fatalf("second writeFileAtomically() error = %v", err)
}
data, err := os.ReadFile(path)
if err != nil {
t.Fatalf("ReadFile() error = %v", err)
}
if got, want := string(data), "second"; got != want {
t.Fatalf("file contents = %q, want %q", got, want)
}
info, err := os.Stat(path)
if err != nil {
t.Fatalf("Stat() error = %v", err)
}
if got := info.Mode().Perm(); got != 0o600 {
t.Fatalf("file permissions = %#o, want %#o", got, 0o600)
}
}
func TestCertificateAuthority_Basics(t *testing.T) {
t.Parallel()
var nilCA *CertificateAuthority
if got := nilCA.CertPath(); got != "" {
t.Fatalf("nil CertPath() = %q, want empty", got)
}
if got := nilCA.WasCreated(); got {
t.Fatalf("nil WasCreated() = %v, want false", got)
}
ca := newTestCA(t)
ca.certPath = "/tmp/test-ca.pem"
ca.wasCreated = true
if got, want := ca.CertPath(), "/tmp/test-ca.pem"; got != want {
t.Fatalf("CertPath() = %q, want %q", got, want)
}
if !ca.WasCreated() {
t.Fatal("WasCreated() = false, want true")
}
}
func TestCertificateForHost(t *testing.T) {
t.Parallel()
ca := newTestCA(t)
t.Run("empty host returns error", func(t *testing.T) {
t.Parallel()
cert, err := ca.CertificateForHost(" ")
if err == nil {
t.Fatal("CertificateForHost() error = nil, want non-nil")
}
if cert != nil {
t.Fatalf("cert = %#v, want nil", cert)
}
})
t.Run("cache hit returns same pointer", func(t *testing.T) {
t.Parallel()
c1, err := ca.CertificateForHost("example.com:443")
if err != nil {
t.Fatalf("first CertificateForHost() error = %v", err)
}
c2, err := ca.CertificateForHost("example.com")
if err != nil {
t.Fatalf("second CertificateForHost() error = %v", err)
}
if c1 != c2 {
t.Fatal("expected same certificate pointer from cache")
}
})
t.Run("ip and dns SAN selection", func(t *testing.T) {
t.Parallel()
ipCert, err := ca.CertificateForHost("127.0.0.1:443")
if err != nil {
t.Fatalf("ip CertificateForHost() error = %v", err)
}
if ipCert.Leaf == nil {
t.Fatal("ip cert leaf is nil")
}
if len(ipCert.Leaf.IPAddresses) == 0 {
t.Fatal("ip cert should contain IP SAN")
}
if len(ipCert.Leaf.DNSNames) != 0 {
t.Fatalf("ip cert DNSNames = %v, want empty", ipCert.Leaf.DNSNames)
}
dnsCert, err := ca.CertificateForHost("service.local")
if err != nil {
t.Fatalf("dns CertificateForHost() error = %v", err)
}
if dnsCert.Leaf == nil {
t.Fatal("dns cert leaf is nil")
}
if len(dnsCert.Leaf.DNSNames) == 0 {
t.Fatal("dns cert should contain DNS SAN")
}
})
t.Run("evicts oldest entry over maxLeafCerts", func(t *testing.T) {
t.Parallel()
ca2 := newTestCA(t)
for i := 0; i < maxLeafCerts+1; i++ {
host := filepath.Base(filepath.Join("h", big.NewInt(int64(i)).String()+".example"))
if _, err := ca2.CertificateForHost(host); err != nil {
t.Fatalf("CertificateForHost(%q) error = %v", host, err)
}
}
if len(ca2.leafOrder) != maxLeafCerts {
t.Fatalf("leafOrder len = %d, want %d", len(ca2.leafOrder), maxLeafCerts)
}
if _, ok := ca2.leafCert["0.example"]; ok {
t.Fatal("expected oldest cert to be evicted")
}
})
}
func TestIsTrustedBySystem(t *testing.T) {
t.Parallel()
var nilCA *CertificateAuthority
_, err := nilCA.IsTrustedBySystem()
if err == nil {
t.Fatal("nil IsTrustedBySystem() error = nil, want non-nil")
}
ca := &CertificateAuthority{}
_, err = ca.IsTrustedBySystem()
if err == nil {
t.Fatal("missing-cert IsTrustedBySystem() error = nil, want non-nil")
}
t.Run("untrusted generated CA returns false without error", func(t *testing.T) {
t.Parallel()
ca := newTestCA(t)
trusted, err := ca.IsTrustedBySystem()
if err != nil {
t.Fatalf("IsTrustedBySystem() error = %v, want nil for unknown authority", err)
}
if trusted {
t.Fatal("trusted = true, want false for generated test CA")
}
})
}
func TestEnsureCertificateAuthority(t *testing.T) {
configRoot := t.TempDir()
t.Setenv("XDG_CONFIG_HOME", configRoot)
ca, err := EnsureCertificateAuthority()
if err != nil {
t.Fatalf("EnsureCertificateAuthority() error = %v", err)
}
if ca == nil {
t.Fatal("EnsureCertificateAuthority() returned nil CA")
}
if ca.CertPath() == "" {
t.Fatal("EnsureCertificateAuthority() returned empty cert path")
}
if _, statErr := os.Stat(ca.CertPath()); statErr != nil {
t.Fatalf("expected cert on disk, stat error = %v", statErr)
}
}
func TestWriteFileAtomically_ErrorPath(t *testing.T) {
t.Parallel()
err := writeFileAtomically(filepath.Join("/nope", "bad", "path.pem"), []byte("x"), 0o600)
if err == nil {
t.Fatal("writeFileAtomically() error = nil, want non-nil")
}
if errors.Is(err, os.ErrNotExist) {
return
}
// Accept platform-dependent fs errors as long as function fails.
}
func TestWriteFileAtomically_RenameErrorWhenTargetIsDirectory(t *testing.T) {
t.Parallel()
dir := t.TempDir()
targetDir := filepath.Join(dir, "target-as-dir")
if err := os.MkdirAll(targetDir, 0o700); err != nil {
t.Fatalf("MkdirAll(targetDir) error = %v", err)
}
err := writeFileAtomically(targetDir, []byte("x"), 0o600)
if err == nil {
t.Fatal("writeFileAtomically() error = nil, want non-nil")
}
}
// TODO: Add deterministic tests for loadOrCreateCertificateAuthority trust-store interactions.

View File

@ -0,0 +1,29 @@
package proxy
import (
"encoding/pem"
"os"
"path/filepath"
"testing"
)
func TestIsTrustedBySystem_TrustedViaCertEnv(t *testing.T) {
ca := newTestCA(t)
rootFile := filepath.Join(t.TempDir(), "roots.pem")
pemBytes := pem.EncodeToMemory(&pem.Block{Type: "CERTIFICATE", Bytes: ca.cert.Raw})
if err := os.WriteFile(rootFile, pemBytes, 0o600); err != nil {
t.Fatalf("WriteFile(root cert) error = %v", err)
}
t.Setenv("SSL_CERT_FILE", rootFile)
t.Setenv("SSL_CERT_DIR", t.TempDir())
trusted, err := ca.IsTrustedBySystem()
if err != nil {
t.Fatalf("IsTrustedBySystem() error = %v", err)
}
if !trusted {
t.Fatal("trusted = false, want true when CA is in configured cert file")
}
}

View File

@ -0,0 +1,330 @@
package proxy
import (
"bufio"
"bytes"
"context"
"fmt"
"io"
"net"
"net/http"
"net/http/httptest"
"strings"
"testing"
"time"
"termtap.dev/internal/model"
)
type failingResponseWriter struct {
header http.Header
code int
}
func (w *failingResponseWriter) Header() http.Header {
if w.header == nil {
w.header = make(http.Header)
}
return w.header
}
func (w *failingResponseWriter) WriteHeader(statusCode int) {
w.code = statusCode
}
func (w *failingResponseWriter) Write(_ []byte) (int, error) {
return 0, io.ErrClosedPipe
}
type hijackFailWriter struct {
header http.Header
code int
}
func (w *hijackFailWriter) Header() http.Header {
if w.header == nil {
w.header = make(http.Header)
}
return w.header
}
func (w *hijackFailWriter) WriteHeader(statusCode int) {
w.code = statusCode
}
func (w *hijackFailWriter) Write(p []byte) (int, error) {
return len(p), nil
}
func (w *hijackFailWriter) Hijack() (net.Conn, *bufio.ReadWriter, error) {
return nil, nil, fmt.Errorf("hijack failed")
}
type dummyConn struct{}
func (dummyConn) Read(_ []byte) (int, error) { return 0, io.EOF }
func (dummyConn) Write(p []byte) (int, error) { return len(p), nil }
func (dummyConn) Close() error { return nil }
func (dummyConn) LocalAddr() net.Addr { return &net.TCPAddr{} }
func (dummyConn) RemoteAddr() net.Addr { return &net.TCPAddr{} }
func (dummyConn) SetDeadline(_ time.Time) error { return nil }
func (dummyConn) SetReadDeadline(_ time.Time) error { return nil }
func (dummyConn) SetWriteDeadline(_ time.Time) error { return nil }
type writeConnectFailWriter struct {
header http.Header
code int
}
func (w *writeConnectFailWriter) Header() http.Header {
if w.header == nil {
w.header = make(http.Header)
}
return w.header
}
func (w *writeConnectFailWriter) WriteHeader(statusCode int) {
w.code = statusCode
}
func (w *writeConnectFailWriter) Write(p []byte) (int, error) {
return len(p), nil
}
func (w *writeConnectFailWriter) Hijack() (net.Conn, *bufio.ReadWriter, error) {
rw := bufio.NewReadWriter(
bufio.NewReader(strings.NewReader("")),
bufio.NewWriter(errWriter{}),
)
return dummyConn{}, rw, nil
}
func TestProxyHandler_NonConnectRejectsNonAbsoluteURL(t *testing.T) {
t.Parallel()
ch := make(chan model.Event, 8)
ps := &model.ProxyServer{Conns: make(map[net.Conn]struct{})}
h := proxyHandler(ch, nil, ps)
req := httptest.NewRequest(http.MethodGet, "/not-proxy-form", nil)
w := httptest.NewRecorder()
h.ServeHTTP(w, req)
if w.Code != http.StatusBadRequest {
t.Fatalf("status = %d, want %d", w.Code, http.StatusBadRequest)
}
events := drainEvents(t, ch, 1, time.Second)
if !hasEventType(events, model.EventTypeWarn) {
t.Fatalf("expected %s event, got %#v", model.EventTypeWarn, events)
}
}
func TestProxyHandler_NonConnectSuccess(t *testing.T) {
t.Parallel()
upstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("X-Upstream", "yes")
w.WriteHeader(http.StatusAccepted)
_, _ = w.Write([]byte("pong"))
}))
t.Cleanup(upstream.Close)
ch := make(chan model.Event, 8)
ps := &model.ProxyServer{Conns: make(map[net.Conn]struct{})}
h := proxyHandler(ch, nil, ps)
req := httptest.NewRequest(http.MethodGet, upstream.URL+"/ping", nil)
w := httptest.NewRecorder()
h.ServeHTTP(w, req)
if w.Code != http.StatusAccepted {
t.Fatalf("status = %d, want %d", w.Code, http.StatusAccepted)
}
if got, want := w.Body.String(), "pong"; got != want {
t.Fatalf("body = %q, want %q", got, want)
}
if got := w.Header().Get("X-Upstream"); got != "yes" {
t.Fatalf("X-Upstream header = %q, want yes", got)
}
events := drainEvents(t, ch, 2, time.Second)
if !hasEventType(events, model.EventTypeRequestStarted) {
t.Fatalf("expected %s event", model.EventTypeRequestStarted)
}
if !hasEventType(events, model.EventTypeRequestFinished) {
t.Fatalf("expected %s event", model.EventTypeRequestFinished)
}
}
func TestProxyHandler_NonConnectUpstreamError(t *testing.T) {
t.Parallel()
ch := make(chan model.Event, 8)
ps := &model.ProxyServer{Conns: make(map[net.Conn]struct{})}
h := proxyHandler(ch, nil, ps)
req := httptest.NewRequest(http.MethodGet, "http://127.0.0.1:1/fail", nil)
w := httptest.NewRecorder()
h.ServeHTTP(w, req)
if w.Code != http.StatusBadGateway {
t.Fatalf("status = %d, want %d", w.Code, http.StatusBadGateway)
}
events := drainEvents(t, ch, 2, time.Second)
if !hasEventType(events, model.EventTypeRequestStarted) {
t.Fatalf("expected %s event", model.EventTypeRequestStarted)
}
if !hasEventType(events, model.EventTypeRequestFailed) {
t.Fatalf("expected %s event", model.EventTypeRequestFailed)
}
}
func TestProxyHandler_NonConnectWriteFailureEmitsFailedEvent(t *testing.T) {
t.Parallel()
upstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)
_, _ = io.Copy(w, bytes.NewBufferString("response-body"))
}))
t.Cleanup(upstream.Close)
ch := make(chan model.Event, 8)
ps := &model.ProxyServer{Conns: make(map[net.Conn]struct{})}
h := proxyHandler(ch, nil, ps)
req := httptest.NewRequest(http.MethodGet, upstream.URL+"/copy-fail", nil)
w := &failingResponseWriter{}
h.ServeHTTP(w, req)
events := drainEvents(t, ch, 2, time.Second)
if !hasEventType(events, model.EventTypeRequestStarted) {
t.Fatalf("expected %s event", model.EventTypeRequestStarted)
}
if !hasEventType(events, model.EventTypeRequestFailed) {
t.Fatalf("expected %s event", model.EventTypeRequestFailed)
}
}
func TestHandleConnect_EarlyFailures(t *testing.T) {
t.Parallel()
transport := &mockTransport{fn: func(req *http.Request) (*http.Response, error) {
return nil, context.Canceled
}}
tests := []struct {
name string
writer http.ResponseWriter
req *http.Request
ca *CertificateAuthority
wantCode int
wantBody string
wantStat int
}{
{
name: "nil CA returns 502 and failed event",
writer: httptest.NewRecorder(),
req: httptest.NewRequest(http.MethodConnect, "http://example.com:443", nil),
ca: nil,
wantCode: http.StatusBadGateway,
wantBody: "certificate authority is unavailable",
wantStat: http.StatusBadGateway,
},
{
name: "nil CA with host without port still fails predictably",
writer: httptest.NewRecorder(),
req: &http.Request{Method: http.MethodConnect, Host: "example.com"},
ca: nil,
wantCode: http.StatusBadGateway,
wantBody: "certificate authority is unavailable",
wantStat: http.StatusBadGateway,
},
{
name: "cert mint failure returns 502 and failed event",
writer: httptest.NewRecorder(),
req: &http.Request{Method: http.MethodConnect, Host: ""},
ca: &CertificateAuthority{},
wantCode: http.StatusBadGateway,
wantBody: "failed to mint interception certificate",
wantStat: http.StatusBadGateway,
},
{
name: "non-hijacker writer returns 500 and failed event",
writer: httptest.NewRecorder(),
req: httptest.NewRequest(http.MethodConnect, "http://example.com:443", nil),
ca: newTestCA(t),
wantCode: http.StatusInternalServerError,
wantBody: "hijack is unavailable",
wantStat: http.StatusInternalServerError,
},
{
name: "hijack failure returns 500 and failed event",
writer: &hijackFailWriter{},
req: httptest.NewRequest(http.MethodConnect, "http://example.com:443", nil),
ca: newTestCA(t),
wantCode: http.StatusInternalServerError,
wantBody: "CONNECT hijack failed",
wantStat: http.StatusInternalServerError,
},
{
name: "write connect established failure emits failed event",
writer: &writeConnectFailWriter{},
req: httptest.NewRequest(http.MethodConnect, "http://example.com:443", nil),
ca: newTestCA(t),
wantCode: 0,
wantBody: "CONNECT setup failed",
wantStat: http.StatusBadGateway,
},
}
for _, tt := range tests {
tt := tt
t.Run(tt.name, func(t *testing.T) {
t.Parallel()
ch := make(chan model.Event, 8)
ps := &model.ProxyServer{Conns: make(map[net.Conn]struct{})}
handleConnect(tt.writer, tt.req, ch, transport, tt.ca, ps)
events := drainEvents(t, ch, 2, time.Second)
if events[0].Type != model.EventTypeRequestStarted {
t.Fatalf("expected %s event", model.EventTypeRequestStarted)
}
if events[1].Type != model.EventTypeRequestFailed {
t.Fatalf("expected %s event", model.EventTypeRequestFailed)
}
if events[1].Request.Method != http.MethodConnect {
t.Fatalf("failed event request method = %q, want CONNECT", events[1].Request.Method)
}
if events[1].Request.Status != tt.wantStat {
t.Fatalf("failed event request status = %d, want %d", events[1].Request.Status, tt.wantStat)
}
if !events[1].Request.Failed || events[1].Request.Pending {
t.Fatalf("failed event request flags = pending:%v failed:%v, want pending:false failed:true", events[1].Request.Pending, events[1].Request.Failed)
}
if !strings.Contains(events[1].Body, tt.wantBody) {
t.Fatalf("failed event body %q does not contain %q", events[1].Body, tt.wantBody)
}
if recorder, ok := tt.writer.(*httptest.ResponseRecorder); ok {
if recorder.Code != tt.wantCode {
t.Fatalf("status = %d, want %d", recorder.Code, tt.wantCode)
}
}
if w, ok := tt.writer.(*hijackFailWriter); ok {
if w.code != tt.wantCode {
t.Fatalf("status = %d, want %d", w.code, tt.wantCode)
}
}
})
}
}
// TODO: Add full TLS MITM loop integration test.

View File

@ -0,0 +1,157 @@
package proxy
import (
"net/http"
"reflect"
"testing"
)
func TestStripHopByHopHeaders(t *testing.T) {
t.Parallel()
tests := []struct {
name string
in http.Header
want http.Header
}{
{
name: "removes static hop-by-hop headers",
in: http.Header{
"Connection": {"keep-alive"},
"Keep-Alive": {"timeout=5"},
"Proxy-Connection": {"keep-alive"},
"Transfer-Encoding": {"chunked"},
"Upgrade": {"websocket"},
"X-Custom": {"ok"},
},
want: http.Header{
"X-Custom": {"ok"},
},
},
{
name: "removes headers listed in connection with spaces and commas",
in: http.Header{
"Connection": {" keep-alive, X-Foo ,X-Bar"},
"Keep-Alive": {"timeout=5"},
"X-Foo": {"foo"},
"X-Bar": {"bar"},
"Content-Type": {"application/json"},
"Content-Length": {"12"},
},
want: http.Header{
"Content-Type": {"application/json"},
"Content-Length": {"12"},
},
},
{
name: "keeps unrelated headers when no connection header",
in: http.Header{
"Accept": {"*/*"},
"Authorization": {"Bearer abc"},
},
want: http.Header{
"Accept": {"*/*"},
"Authorization": {"Bearer abc"},
},
},
}
for _, tt := range tests {
tt := tt
t.Run(tt.name, func(t *testing.T) {
t.Parallel()
stripHopByHopHeaders(tt.in)
if !reflect.DeepEqual(tt.in, tt.want) {
t.Fatalf("stripHopByHopHeaders() = %#v, want %#v", tt.in, tt.want)
}
})
}
}
func TestStripHopByHopHeaders_NilHeader(t *testing.T) {
t.Parallel()
stripHopByHopHeaders(nil)
}
func TestRedactHeaders(t *testing.T) {
t.Parallel()
tests := []struct {
name string
input http.Header
want http.Header
}{
{
name: "redacts sensitive canonical headers",
input: http.Header{
"Authorization": {"Bearer token123"},
"Cookie": {"session=abc"},
"Proxy-Authorization": {"Basic abc"},
"Set-Cookie": {"a=1"},
"X-Api-Key": {"secret"},
},
want: http.Header{
"Authorization": {"[REDACTED]"},
"Cookie": {"[REDACTED]"},
"Proxy-Authorization": {"[REDACTED]"},
"Set-Cookie": {"[REDACTED]"},
"X-Api-Key": {"[REDACTED]"},
},
},
{
name: "leaves non-sensitive headers untouched",
input: http.Header{
"Content-Type": {"application/json"},
"X-Trace-ID": {"trace-1"},
},
want: http.Header{
"Content-Type": {"application/json"},
"X-Trace-ID": {"trace-1"},
},
},
}
for _, tt := range tests {
tt := tt
t.Run(tt.name, func(t *testing.T) {
t.Parallel()
got := redactHeaders(tt.input)
if !reflect.DeepEqual(got, tt.want) {
t.Fatalf("redactHeaders() = %#v, want %#v", got, tt.want)
}
got.Set("Content-Type", "modified")
if reflect.DeepEqual(tt.input, got) {
t.Fatal("redactHeaders() appears to mutate input or return aliased map")
}
})
}
}
func TestCopyHeaders(t *testing.T) {
t.Parallel()
src := http.Header{
"X-Multi": {"a", "b", "c"},
"Content-Type": {"application/json"},
}
dest := http.Header{
"Existing": {"keep"},
}
copyHeaders(src, dest)
want := http.Header{
"Existing": {"keep"},
"X-Multi": {"a", "b", "c"},
"Content-Type": {"application/json"},
}
if !reflect.DeepEqual(dest, want) {
t.Fatalf("copyHeaders() dest = %#v, want %#v", dest, want)
}
}

View File

@ -0,0 +1,129 @@
package proxy
import (
"io"
"net/http"
"net/http/httptest"
"net/url"
"testing"
"time"
"termtap.dev/internal/model"
)
func TestHTTPProxyE2E_WithClientProxyConfig(t *testing.T) {
configRoot := t.TempDir()
t.Setenv("XDG_CONFIG_HOME", configRoot)
upstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)
_, _ = w.Write([]byte("pong"))
}))
t.Cleanup(upstream.Close)
ch := make(chan model.Event, 64)
ps, err := NewProxyServer("127.0.0.1:0", ch)
if err != nil {
t.Fatalf("NewProxyServer() error = %v", err)
}
serveDone := make(chan error, 1)
go func() {
serveDone <- ps.Server.Serve(*ps.Listener)
}()
t.Cleanup(func() {
Destroy(ps, ch)
select {
case <-serveDone:
case <-time.After(2 * time.Second):
t.Fatal("timeout waiting for proxy server to stop")
}
})
proxyURL, err := url.Parse(ps.Url)
if err != nil {
t.Fatalf("url.Parse(proxy) error = %v", err)
}
client := &http.Client{
Transport: &http.Transport{Proxy: http.ProxyURL(proxyURL)},
Timeout: 3 * time.Second,
}
resp, err := client.Get(upstream.URL + "/ping")
if err != nil {
t.Fatalf("client.Get() error = %v", err)
}
defer resp.Body.Close()
body, err := io.ReadAll(resp.Body)
if err != nil {
t.Fatalf("ReadAll(response) error = %v", err)
}
if resp.StatusCode != http.StatusOK {
t.Fatalf("status = %d, want %d", resp.StatusCode, http.StatusOK)
}
if got, want := string(body), "pong"; got != want {
t.Fatalf("body = %q, want %q", got, want)
}
events := drainEvents(t, ch, 2, 2*time.Second)
if !hasEventType(events, model.EventTypeRequestStarted) {
t.Fatalf("missing %s event", model.EventTypeRequestStarted)
}
if !hasEventType(events, model.EventTypeRequestFinished) {
t.Fatalf("missing %s event", model.EventTypeRequestFinished)
}
}
func TestHTTPProxyE2E_UpstreamFailureEmitsRequestFailed(t *testing.T) {
configRoot := t.TempDir()
t.Setenv("XDG_CONFIG_HOME", configRoot)
ch := make(chan model.Event, 64)
ps, err := NewProxyServer("127.0.0.1:0", ch)
if err != nil {
t.Fatalf("NewProxyServer() error = %v", err)
}
serveDone := make(chan error, 1)
go func() {
serveDone <- ps.Server.Serve(*ps.Listener)
}()
t.Cleanup(func() {
Destroy(ps, ch)
select {
case <-serveDone:
case <-time.After(2 * time.Second):
t.Fatal("timeout waiting for proxy server to stop")
}
})
proxyURL, err := url.Parse(ps.Url)
if err != nil {
t.Fatalf("url.Parse(proxy) error = %v", err)
}
client := &http.Client{
Transport: &http.Transport{Proxy: http.ProxyURL(proxyURL)},
Timeout: 3 * time.Second,
}
resp, reqErr := client.Get("http://127.0.0.1:1/unreachable")
if reqErr != nil {
t.Fatalf("client.Get() error = %v; proxy should reply with mapped HTTP status", reqErr)
}
defer resp.Body.Close()
if resp.StatusCode != http.StatusBadGateway {
t.Fatalf("status = %d, want %d", resp.StatusCode, http.StatusBadGateway)
}
events := drainEvents(t, ch, 2, 3*time.Second)
if !hasEventType(events, model.EventTypeRequestStarted) {
t.Fatalf("missing %s event", model.EventTypeRequestStarted)
}
if !hasEventType(events, model.EventTypeRequestFailed) {
t.Fatalf("missing %s event", model.EventTypeRequestFailed)
}
}

View File

@ -0,0 +1,301 @@
package proxy
import (
"bufio"
"crypto/tls"
"crypto/x509"
"errors"
"io"
"net"
"net/http"
"strings"
"testing"
"time"
"termtap.dev/internal/model"
)
type hijackPipeWriter struct {
conn net.Conn
header http.Header
code int
}
func (w *hijackPipeWriter) Header() http.Header {
if w.header == nil {
w.header = make(http.Header)
}
return w.header
}
func (w *hijackPipeWriter) WriteHeader(statusCode int) {
w.code = statusCode
}
func (w *hijackPipeWriter) Write(p []byte) (int, error) {
return len(p), nil
}
func (w *hijackPipeWriter) Hijack() (net.Conn, *bufio.ReadWriter, error) {
return w.conn, nil, nil
}
func startMITMConnect(t *testing.T, transport http.RoundTripper) (net.Conn, *tls.Conn, chan model.Event, chan struct{}) {
t.Helper()
ca := newTestCA(t)
clientConn, serverConn := net.Pipe()
writer := &hijackPipeWriter{conn: serverConn}
req, err := http.NewRequest(http.MethodConnect, "http://example.com:443", nil)
if err != nil {
t.Fatalf("NewRequest(CONNECT) error = %v", err)
}
req.Host = "example.com:443"
ch := make(chan model.Event, 16)
ps := &model.ProxyServer{Conns: make(map[net.Conn]struct{})}
handleDone := make(chan struct{})
go func() {
handleConnect(writer, req, ch, transport, ca, ps)
close(handleDone)
}()
reader := bufio.NewReader(clientConn)
connectResp, err := http.ReadResponse(reader, &http.Request{Method: http.MethodConnect})
if err != nil {
_ = clientConn.Close()
_ = serverConn.Close()
t.Fatalf("ReadResponse(CONNECT established) error = %v", err)
}
if connectResp.StatusCode != http.StatusOK {
_ = clientConn.Close()
_ = serverConn.Close()
t.Fatalf("CONNECT established status = %d, want %d", connectResp.StatusCode, http.StatusOK)
}
pool := x509.NewCertPool()
pool.AddCert(ca.cert)
tlsClient := tls.Client(clientConn, &tls.Config{
ServerName: "example.com",
RootCAs: pool,
MinVersion: tls.VersionTLS12,
})
if err := tlsClient.Handshake(); err != nil {
_ = tlsClient.Close()
_ = serverConn.Close()
t.Fatalf("tls handshake error = %v", err)
}
t.Cleanup(func() {
_ = tlsClient.Close()
_ = serverConn.Close()
})
return clientConn, tlsClient, ch, handleDone
}
func TestHTTPSE2E_MITMHandleConnectFlow(t *testing.T) {
transport := &mockTransport{fn: func(r *http.Request) (*http.Response, error) {
if r.URL.Scheme != "https" {
t.Fatalf("inner request scheme = %q, want https", r.URL.Scheme)
}
if r.URL.Host == "" {
t.Fatal("inner request host should be populated")
}
return &http.Response{
StatusCode: http.StatusCreated,
ProtoMajor: 1,
ProtoMinor: 1,
Header: http.Header{"Content-Type": {"text/plain"}},
Body: io.NopCloser(strings.NewReader("mitm-ok")),
}, nil
}}
clientConn, tlsClient, ch, handleDone := startMITMConnect(t, transport)
// 3) Send decrypted inner request over tunnel and read MITM response
innerReq, err := http.NewRequest(http.MethodGet, "https://example.com/inside", nil)
if err != nil {
t.Fatalf("NewRequest(inner) error = %v", err)
}
innerReq.Host = "example.com"
innerReq.Close = true
if err := innerReq.Write(tlsClient); err != nil {
t.Fatalf("innerReq.Write() error = %v", err)
}
innerResp, err := http.ReadResponse(bufio.NewReader(tlsClient), innerReq)
if err != nil {
t.Fatalf("ReadResponse(inner) error = %v", err)
}
defer innerResp.Body.Close()
body, err := io.ReadAll(innerResp.Body)
if err != nil {
t.Fatalf("ReadAll(inner body) error = %v", err)
}
if innerResp.StatusCode != http.StatusCreated {
t.Fatalf("inner status = %d, want %d", innerResp.StatusCode, http.StatusCreated)
}
if got, want := string(body), "mitm-ok"; got != want {
t.Fatalf("inner body = %q, want %q", got, want)
}
_ = tlsClient.Close()
_ = clientConn.Close()
select {
case <-handleDone:
case <-time.After(3 * time.Second):
t.Fatal("timeout waiting for handleConnect to return")
}
events := drainEvents(t, ch, 4, 2*time.Second)
if events[0].Type != model.EventTypeRequestStarted {
t.Fatalf("event[0] = %s, want %s", events[0].Type, model.EventTypeRequestStarted)
}
if events[1].Type != model.EventTypeRequestStarted {
t.Fatalf("event[1] = %s, want %s", events[1].Type, model.EventTypeRequestStarted)
}
if events[2].Type != model.EventTypeRequestFinished {
t.Fatalf("event[2] = %s, want %s", events[2].Type, model.EventTypeRequestFinished)
}
if events[3].Type != model.EventTypeRequestFinished {
t.Fatalf("event[3] = %s, want %s", events[3].Type, model.EventTypeRequestFinished)
}
}
func TestHTTPSE2E_MITMUpstreamErrorReturnsHTTPErrorInsideTunnel(t *testing.T) {
transport := &mockTransport{fn: func(*http.Request) (*http.Response, error) {
return nil, errors.New("upstream exploded")
}}
clientConn, tlsClient, ch, handleDone := startMITMConnect(t, transport)
innerReq, err := http.NewRequest(http.MethodGet, "https://example.com/fail", nil)
if err != nil {
t.Fatalf("NewRequest(inner) error = %v", err)
}
innerReq.Host = "example.com"
if err := innerReq.Write(tlsClient); err != nil {
t.Fatalf("innerReq.Write() error = %v", err)
}
innerResp, err := http.ReadResponse(bufio.NewReader(tlsClient), innerReq)
if err != nil {
t.Fatalf("ReadResponse(inner) error = %v", err)
}
defer innerResp.Body.Close()
if innerResp.StatusCode != http.StatusBadGateway {
t.Fatalf("inner status = %d, want %d", innerResp.StatusCode, http.StatusBadGateway)
}
_ = tlsClient.Close()
_ = clientConn.Close()
select {
case <-handleDone:
case <-time.After(3 * time.Second):
t.Fatal("timeout waiting for handleConnect to return")
}
events := drainEvents(t, ch, 4, 2*time.Second)
if events[0].Type != model.EventTypeRequestStarted ||
events[1].Type != model.EventTypeRequestStarted ||
events[2].Type != model.EventTypeRequestFailed ||
events[3].Type != model.EventTypeRequestFailed {
t.Fatalf("unexpected event sequence: %#v", events)
}
}
func TestHTTPSE2E_MITMHandshakeFailureEmitsConnectFailed(t *testing.T) {
ca := newTestCA(t)
clientConn, serverConn := net.Pipe()
t.Cleanup(func() {
_ = clientConn.Close()
_ = serverConn.Close()
})
writer := &hijackPipeWriter{conn: serverConn}
req, err := http.NewRequest(http.MethodConnect, "http://example.com:443", nil)
if err != nil {
t.Fatalf("NewRequest(CONNECT) error = %v", err)
}
req.Host = "example.com:443"
ch := make(chan model.Event, 16)
ps := &model.ProxyServer{Conns: make(map[net.Conn]struct{})}
transport := &mockTransport{fn: func(*http.Request) (*http.Response, error) {
return nil, errors.New("should not be called")
}}
handleDone := make(chan struct{})
go func() {
handleConnect(writer, req, ch, transport, ca, ps)
close(handleDone)
}()
reader := bufio.NewReader(clientConn)
connectResp, err := http.ReadResponse(reader, &http.Request{Method: http.MethodConnect})
if err != nil {
t.Fatalf("ReadResponse(CONNECT established) error = %v", err)
}
if connectResp.StatusCode != http.StatusOK {
t.Fatalf("CONNECT established status = %d, want %d", connectResp.StatusCode, http.StatusOK)
}
// Write plaintext (not TLS handshake) to force handshake failure.
if _, err := clientConn.Write([]byte("not tls handshake")); err != nil {
t.Fatalf("client write error = %v", err)
}
_ = clientConn.Close()
select {
case <-handleDone:
case <-time.After(3 * time.Second):
t.Fatal("timeout waiting for handleConnect to return on handshake failure")
}
events := drainEvents(t, ch, 2, 2*time.Second)
if events[0].Type != model.EventTypeRequestStarted || events[1].Type != model.EventTypeRequestFailed {
t.Fatalf("unexpected event sequence: %#v", events)
}
if !strings.Contains(events[1].Body, "TLS handshake with client failed") {
t.Fatalf("failed event body = %q, want handshake failure details", events[1].Body)
}
}
func TestHTTPSE2E_MITMDecryptedReadFailureEmitsConnectFailed(t *testing.T) {
transport := &mockTransport{fn: func(*http.Request) (*http.Response, error) {
return nil, errors.New("should not be called")
}}
clientConn, tlsClient, ch, handleDone := startMITMConnect(t, transport)
// Send malformed decrypted payload so ReadRequest fails (non-EOF branch).
if _, err := tlsClient.Write([]byte("BAD REQUEST\r\n\r\n")); err != nil {
t.Fatalf("tlsClient.Write() error = %v", err)
}
_ = tlsClient.Close()
_ = clientConn.Close()
select {
case <-handleDone:
case <-time.After(3 * time.Second):
t.Fatal("timeout waiting for handleConnect to return on decrypted read failure")
}
events := drainEvents(t, ch, 2, 2*time.Second)
if events[0].Type != model.EventTypeRequestStarted || events[1].Type != model.EventTypeRequestFailed {
t.Fatalf("unexpected event sequence: %#v", events)
}
if !strings.Contains(events[1].Body, "failed to read decrypted HTTPS request") {
t.Fatalf("failed event body = %q, want read failure details", events[1].Body)
}
}

View File

@ -0,0 +1,142 @@
package proxy
import (
"bufio"
"io"
"net"
"strings"
"testing"
)
func TestNewBodyPreview(t *testing.T) {
t.Parallel()
tests := []struct {
name string
contentType string
wantEnabled bool
}{
{name: "text content enabled", contentType: "text/plain", wantEnabled: true},
{name: "json content enabled", contentType: "application/json", wantEnabled: true},
{name: "binary content disabled", contentType: "application/octet-stream", wantEnabled: false},
{name: "empty content type disabled", contentType: "", wantEnabled: false},
}
for _, tt := range tests {
tt := tt
t.Run(tt.name, func(t *testing.T) {
t.Parallel()
p := newBodyPreview(tt.contentType)
if p.enabled != tt.wantEnabled {
t.Fatalf("newBodyPreview(%q).enabled = %v, want %v", tt.contentType, p.enabled, tt.wantEnabled)
}
})
}
}
func TestBodyPreviewWriteAndPreview(t *testing.T) {
t.Parallel()
t.Run("nil receiver is safe", func(t *testing.T) {
t.Parallel()
var p *bodyPreview
p.Write([]byte("abc"))
})
t.Run("disabled preview ignores data", func(t *testing.T) {
t.Parallel()
p := &bodyPreview{enabled: false}
p.Write([]byte("abc"))
if got := string(p.Preview()); got != "" {
t.Fatalf("Preview() = %q, want empty", got)
}
})
t.Run("empty write does nothing", func(t *testing.T) {
t.Parallel()
p := &bodyPreview{enabled: true}
p.Write(nil)
if got := string(p.Preview()); got != "" {
t.Fatalf("Preview() = %q, want empty", got)
}
})
t.Run("escapes newlines", func(t *testing.T) {
t.Parallel()
p := &bodyPreview{enabled: true}
p.Write([]byte("a\nb"))
if got, want := string(p.Preview()), `a\nb`; got != want {
t.Fatalf("Preview() = %q, want %q", got, want)
}
})
t.Run("truncates at max preview bytes and appends ellipsis", func(t *testing.T) {
t.Parallel()
p := &bodyPreview{enabled: true}
p.Write([]byte(strings.Repeat("a", maxPreviewBytes)))
p.Write([]byte("b"))
got := string(p.Preview())
if !strings.HasSuffix(got, "...") {
t.Fatalf("Preview() must end with ellipsis when truncated: %q", got[len(got)-10:])
}
if len(got) != maxPreviewBytes+3 {
t.Fatalf("len(Preview()) = %d, want %d", len(got), maxPreviewBytes+3)
}
})
}
func TestWrapBufferedConn(t *testing.T) {
t.Parallel()
client, server := net.Pipe()
t.Cleanup(func() {
_ = client.Close()
_ = server.Close()
})
t.Run("returns original conn when readWriter nil", func(t *testing.T) {
t.Parallel()
got := wrapBufferedConn(client, nil)
if got != client {
t.Fatal("wrapBufferedConn should return original conn when readWriter is nil")
}
})
t.Run("read uses buffered readWriter", func(t *testing.T) {
t.Parallel()
rw := bufio.NewReadWriter(bufio.NewReader(strings.NewReader("xyz")), bufio.NewWriter(io.Discard))
got := wrapBufferedConn(client, rw)
buf := make([]byte, 3)
n, err := got.Read(buf)
if err != nil {
t.Fatalf("Read() error = %v", err)
}
if n != 3 || string(buf) != "xyz" {
t.Fatalf("Read() = (%d, %q), want (3, %q)", n, string(buf), "xyz")
}
})
}
func TestPreviewReadCloserRead(t *testing.T) {
t.Parallel()
preview := newBodyPreview("text/plain")
rc := &previewReadCloser{
ReadCloser: io.NopCloser(strings.NewReader("hello\nworld")),
preview: preview,
}
data, err := io.ReadAll(rc)
if err != nil {
t.Fatalf("ReadAll() error = %v", err)
}
if got, want := string(data), "hello\nworld"; got != want {
t.Fatalf("read content = %q, want %q", got, want)
}
if got, want := string(preview.Preview()), `hello\nworld`; got != want {
t.Fatalf("preview = %q, want %q", got, want)
}
}

View File

@ -0,0 +1,275 @@
package proxy
import (
"errors"
"io"
"net/http"
"net/url"
"strings"
"testing"
"time"
"github.com/google/uuid"
"termtap.dev/internal/model"
)
type mockTransport struct {
fn func(*http.Request) (*http.Response, error)
}
func (m *mockTransport) RoundTrip(req *http.Request) (*http.Response, error) {
return m.fn(req)
}
func TestRoundTripCapturedRequest_Success(t *testing.T) {
t.Parallel()
ch := make(chan model.Event, 8)
reqURL, err := url.Parse("http://example.com/path?q=1")
if err != nil {
t.Fatalf("url.Parse() error = %v", err)
}
req := &http.Request{
Method: http.MethodPost,
URL: reqURL,
Host: "example.com",
Header: http.Header{
"Connection": {"X-Hop"},
"X-Hop": {"drop"},
"Authorization": {"Bearer token"},
"Content-Type": {"text/plain"},
},
Body: io.NopCloser(strings.NewReader("req\nbody")),
}
transport := &mockTransport{fn: func(outReq *http.Request) (*http.Response, error) {
if got := outReq.Header.Get("Connection"); got != "" {
t.Fatalf("Connection header should be stripped, got %q", got)
}
if got := outReq.Header.Get("X-Hop"); got != "" {
t.Fatalf("header listed in Connection should be stripped, got %q", got)
}
data, readErr := io.ReadAll(outReq.Body)
if readErr != nil {
t.Fatalf("ReadAll(outReq.Body) error = %v", readErr)
}
if got, want := string(data), "req\nbody"; got != want {
t.Fatalf("request body = %q, want %q", got, want)
}
return &http.Response{
StatusCode: http.StatusCreated,
Header: http.Header{
"Set-Cookie": {"session=top-secret"},
"Connection": {"close"},
"Content-Type": {"text/plain"},
},
Body: io.NopCloser(strings.NewReader("resp\nbody")),
}, nil
}}
resp, captured, responsePreview, err := roundTripCapturedRequest(req, transport, ch, "", false)
if err != nil {
t.Fatalf("roundTripCapturedRequest() error = %v", err)
}
if resp == nil {
t.Fatal("roundTripCapturedRequest() returned nil response")
}
if responsePreview == nil {
t.Fatal("roundTripCapturedRequest() returned nil response preview")
}
if _, readErr := io.ReadAll(resp.Body); readErr != nil {
t.Fatalf("ReadAll(resp.Body) error = %v", readErr)
}
_ = resp.Body.Close()
if got, want := string(captured.RequestData), `req\nbody`; got != want {
t.Fatalf("captured.RequestData = %q, want %q", got, want)
}
if got, want := string(responsePreview.Preview()), `resp\nbody`; got != want {
t.Fatalf("responsePreview = %q, want %q", got, want)
}
if got := captured.RequestHeaders.Get("Authorization"); got != "[REDACTED]" {
t.Fatalf("Authorization header = %q, want [REDACTED]", got)
}
if got := captured.ResponseHeaders.Get("Set-Cookie"); got != "[REDACTED]" {
t.Fatalf("Set-Cookie header = %q, want [REDACTED]", got)
}
if got := captured.ResponseHeaders.Get("Connection"); got != "" {
t.Fatalf("Connection should be stripped from response headers, got %q", got)
}
events := drainEvents(t, ch, 1, time.Second)
if events[0].Type != model.EventTypeRequestStarted {
t.Fatalf("event type = %s, want %s", events[0].Type, model.EventTypeRequestStarted)
}
}
func TestRoundTripCapturedRequest_ErrorPath(t *testing.T) {
t.Parallel()
ch := make(chan model.Event, 8)
reqURL, err := url.Parse("http://example.com/fail")
if err != nil {
t.Fatalf("url.Parse() error = %v", err)
}
req := &http.Request{
Method: http.MethodPost,
URL: reqURL,
Host: "example.com",
Header: http.Header{"Content-Type": {"text/plain"}},
Body: io.NopCloser(strings.NewReader("boom\nbody")),
}
wantErr := errors.New("upstream failed")
transport := &mockTransport{fn: func(outReq *http.Request) (*http.Response, error) {
_, _ = io.ReadAll(outReq.Body)
return nil, wantErr
}}
resp, captured, responsePreview, gotErr := roundTripCapturedRequest(req, transport, ch, "", false)
if !errors.Is(gotErr, wantErr) {
t.Fatalf("error = %v, want %v", gotErr, wantErr)
}
if resp != nil {
t.Fatalf("response = %#v, want nil", resp)
}
if responsePreview != nil {
t.Fatalf("responsePreview = %#v, want nil", responsePreview)
}
if got, want := string(captured.RequestData), `boom\nbody`; got != want {
t.Fatalf("captured.RequestData = %q, want %q", got, want)
}
events := drainEvents(t, ch, 1, time.Second)
if events[0].Type != model.EventTypeRequestStarted {
t.Fatalf("event type = %s, want %s", events[0].Type, model.EventTypeRequestStarted)
}
}
func TestRoundTripCapturedRequest_InterceptedTLSDefaults(t *testing.T) {
t.Parallel()
ch := make(chan model.Event, 8)
u, err := url.Parse("/secure?p=1")
if err != nil {
t.Fatalf("url.Parse() error = %v", err)
}
req := &http.Request{
Method: http.MethodGet,
URL: u,
Header: http.Header{},
}
const defaultHost = "api.example.com:443"
transport := &mockTransport{fn: func(outReq *http.Request) (*http.Response, error) {
if got := outReq.URL.Scheme; got != "https" {
t.Fatalf("URL.Scheme = %q, want https", got)
}
if got := outReq.URL.Host; got != defaultHost {
t.Fatalf("URL.Host = %q, want %q", got, defaultHost)
}
if got := outReq.Host; got != defaultHost {
t.Fatalf("Host = %q, want %q", got, defaultHost)
}
return &http.Response{
StatusCode: http.StatusNoContent,
Header: http.Header{"Content-Type": {"text/plain"}},
Body: io.NopCloser(strings.NewReader("")),
}, nil
}}
_, _, _, gotErr := roundTripCapturedRequest(req, transport, ch, defaultHost, true)
if gotErr != nil {
t.Fatalf("roundTripCapturedRequest() error = %v", gotErr)
}
}
func TestNewConnectRequest(t *testing.T) {
t.Parallel()
now := time.Now().Add(-time.Second)
req := &http.Request{Method: http.MethodConnect, Host: "example.com:443"}
got := newConnectRequest(req, now)
if got.Method != http.MethodConnect {
t.Fatalf("Method = %q, want CONNECT", got.Method)
}
if got.Host != req.Host || got.URL != req.Host || got.RawURL != req.Host {
t.Fatalf("connect request host/url/raw mismatch: %#v", got)
}
if !got.Pending || got.Failed {
t.Fatalf("Pending/Failed = (%v,%v), want (true,false)", got.Pending, got.Failed)
}
if got.Status != -1 {
t.Fatalf("Status = %d, want -1", got.Status)
}
if got.StartTime != now {
t.Fatalf("StartTime = %v, want %v", got.StartTime, now)
}
if got.ID == uuid.Nil {
t.Fatal("ID must be non-zero UUID")
}
}
func TestStartFinishFailRequestEvents(t *testing.T) {
t.Parallel()
ch := make(chan model.Event, 4)
req := model.Request{
ID: uuid.New(),
Method: http.MethodGet,
RawURL: "http://example.com/a",
StartTime: time.Now().Add(-3 * time.Millisecond),
Pending: true,
}
startRequest(ch, req)
events := drainEvents(t, ch, 1, time.Second)
startEv := events[0]
if startEv.Type != model.EventTypeRequestStarted {
t.Fatalf("start event type = %s, want %s", startEv.Type, model.EventTypeRequestStarted)
}
if startEv.Request.Pending != true {
t.Fatalf("start request pending = %v, want true", startEv.Request.Pending)
}
finishRequest(ch, req, http.StatusOK)
events = drainEvents(t, ch, 1, time.Second)
finishEv := events[0]
if finishEv.Type != model.EventTypeRequestFinished {
t.Fatalf("finish event type = %s, want %s", finishEv.Type, model.EventTypeRequestFinished)
}
if finishEv.Request.Pending {
t.Fatal("finished request should not be pending")
}
if finishEv.Request.Failed {
t.Fatal("finished request should not be failed")
}
if finishEv.Request.Status != http.StatusOK {
t.Fatalf("finished status = %d, want %d", finishEv.Request.Status, http.StatusOK)
}
failRequest(ch, req, http.StatusBadGateway, "upstream error")
events = drainEvents(t, ch, 1, time.Second)
failEv := events[0]
if failEv.Type != model.EventTypeRequestFailed {
t.Fatalf("fail event type = %s, want %s", failEv.Type, model.EventTypeRequestFailed)
}
if failEv.Request.Pending {
t.Fatal("failed request should not be pending")
}
if !failEv.Request.Failed {
t.Fatal("failed request should be marked failed")
}
if failEv.Request.Status != http.StatusBadGateway {
t.Fatalf("failed status = %d, want %d", failEv.Request.Status, http.StatusBadGateway)
}
}

View File

@ -0,0 +1,138 @@
package proxy
import (
"bufio"
"bytes"
"io"
"net"
"strings"
"testing"
"time"
)
type captureConn struct {
bytes.Buffer
}
func (c *captureConn) Read(_ []byte) (int, error) { return 0, io.EOF }
func (c *captureConn) Close() error { return nil }
func (c *captureConn) LocalAddr() net.Addr { return &net.TCPAddr{} }
func (c *captureConn) RemoteAddr() net.Addr { return &net.TCPAddr{} }
func (c *captureConn) SetDeadline(_ time.Time) error { return nil }
func (c *captureConn) SetReadDeadline(_ time.Time) error { return nil }
func (c *captureConn) SetWriteDeadline(_ time.Time) error { return nil }
type failWriteConn struct{}
func (failWriteConn) Read(_ []byte) (int, error) { return 0, io.EOF }
func (failWriteConn) Write(_ []byte) (int, error) { return 0, io.ErrClosedPipe }
func (failWriteConn) Close() error { return nil }
func (failWriteConn) LocalAddr() net.Addr { return &net.TCPAddr{} }
func (failWriteConn) RemoteAddr() net.Addr { return &net.TCPAddr{} }
func (failWriteConn) SetDeadline(_ time.Time) error { return nil }
func (failWriteConn) SetReadDeadline(_ time.Time) error { return nil }
func (failWriteConn) SetWriteDeadline(_ time.Time) error { return nil }
type trackingBody struct {
data *bytes.Reader
readN int
closed bool
}
func (b *trackingBody) Read(p []byte) (int, error) {
n, err := b.data.Read(p)
b.readN += n
return n, err
}
func (b *trackingBody) Close() error {
b.closed = true
return nil
}
func TestWriteConnectEstablished(t *testing.T) {
t.Parallel()
t.Run("writes directly to raw conn", func(t *testing.T) {
t.Parallel()
conn := &captureConn{}
if err := writeConnectEstablished(conn, nil); err != nil {
t.Fatalf("writeConnectEstablished() error = %v", err)
}
if got, want := conn.String(), "HTTP/1.1 200 Connection Established\r\n\r\n"; got != want {
t.Fatalf("raw write = %q, want %q", got, want)
}
})
t.Run("writes and flushes with buffered readWriter", func(t *testing.T) {
t.Parallel()
conn := &captureConn{}
rw := bufio.NewReadWriter(bufio.NewReader(strings.NewReader("")), bufio.NewWriter(conn))
if err := writeConnectEstablished(conn, rw); err != nil {
t.Fatalf("writeConnectEstablished() error = %v", err)
}
if got, want := conn.String(), "HTTP/1.1 200 Connection Established\r\n\r\n"; got != want {
t.Fatalf("buffered write = %q, want %q", got, want)
}
})
t.Run("returns flush error", func(t *testing.T) {
t.Parallel()
rw := bufio.NewReadWriter(bufio.NewReader(strings.NewReader("")), bufio.NewWriter(errWriter{}))
err := writeConnectEstablished(&captureConn{}, rw)
if err == nil {
t.Fatal("writeConnectEstablished() error = nil, want non-nil")
}
})
t.Run("returns buffered write error when writer already failed", func(t *testing.T) {
t.Parallel()
bw := bufio.NewWriter(errWriter{})
_ = bw.Flush() // set sticky error to force WriteString error path
rw := bufio.NewReadWriter(bufio.NewReader(strings.NewReader("")), bw)
err := writeConnectEstablished(&captureConn{}, rw)
if err == nil {
t.Fatal("writeConnectEstablished() error = nil, want non-nil")
}
})
t.Run("returns raw conn write error", func(t *testing.T) {
t.Parallel()
err := writeConnectEstablished(failWriteConn{}, nil)
if err == nil {
t.Fatal("writeConnectEstablished() error = nil, want non-nil")
}
})
}
func TestDiscardAndCloseBody(t *testing.T) {
t.Parallel()
t.Run("nil body is safe", func(t *testing.T) {
t.Parallel()
discardAndCloseBody(nil)
})
t.Run("closes body and discards at most limit", func(t *testing.T) {
t.Parallel()
payload := bytes.Repeat([]byte("x"), maxDiscardBodyBytes+128)
body := &trackingBody{data: bytes.NewReader(payload)}
discardAndCloseBody(body)
if !body.closed {
t.Fatal("body was not closed")
}
if body.readN != maxDiscardBodyBytes {
t.Fatalf("bytes read = %d, want %d", body.readN, maxDiscardBodyBytes)
}
})
}

View File

@ -0,0 +1,235 @@
package proxy
import (
"context"
"net"
"net/http"
"os"
"path/filepath"
"strings"
"sync"
"testing"
"time"
"termtap.dev/internal/model"
)
// NOTE: Run these tests with -race; they cover concurrent state.
type closeTrackingConn struct {
closed bool
mu sync.Mutex
}
func (c *closeTrackingConn) Read(_ []byte) (int, error) { return 0, net.ErrClosed }
func (c *closeTrackingConn) Write(b []byte) (int, error) { return len(b), nil }
func (c *closeTrackingConn) Close() error { c.mu.Lock(); c.closed = true; c.mu.Unlock(); return nil }
func (c *closeTrackingConn) LocalAddr() net.Addr { return &net.TCPAddr{} }
func (c *closeTrackingConn) RemoteAddr() net.Addr { return &net.TCPAddr{} }
func (c *closeTrackingConn) SetDeadline(_ time.Time) error { return nil }
func (c *closeTrackingConn) SetReadDeadline(_ time.Time) error { return nil }
func (c *closeTrackingConn) SetWriteDeadline(_ time.Time) error { return nil }
func (c *closeTrackingConn) isClosed() bool {
c.mu.Lock()
defer c.mu.Unlock()
return c.closed
}
func TestNewProxyServer_Success(t *testing.T) {
configRoot := t.TempDir()
t.Setenv("XDG_CONFIG_HOME", configRoot)
ch := make(chan model.Event, 16)
ps, err := NewProxyServer("127.0.0.1:0", ch)
if err != nil {
t.Fatalf("NewProxyServer() error = %v", err)
}
t.Cleanup(func() { Destroy(ps, ch) })
if ps.Listener == nil || *ps.Listener == nil {
t.Fatal("Listener is nil")
}
if ps.Server == nil {
t.Fatal("Server is nil")
}
if ps.Server.Handler == nil {
t.Fatal("Server.Handler is nil")
}
if !strings.HasPrefix(ps.Url, "http://") {
t.Fatalf("Url = %q, want http:// prefix", ps.Url)
}
if !ps.CAReady {
t.Fatal("CAReady = false, want true")
}
if ps.CACertPath == "" {
t.Fatal("CACertPath should not be empty")
}
if got, want := filepath.Dir(ps.CACertPath), filepath.Join(configRoot, caDirName); got != want {
t.Fatalf("CACertPath dir = %q, want %q", got, want)
}
if _, statErr := os.Stat(ps.CACertPath); statErr != nil {
t.Fatalf("expected CA cert on disk, stat error = %v", statErr)
}
if ps.Conns == nil {
t.Fatal("Conns map is nil")
}
}
func TestNewProxyServer_CACreatedFlagAcrossRuns(t *testing.T) {
configRoot := t.TempDir()
t.Setenv("XDG_CONFIG_HOME", configRoot)
ch := make(chan model.Event, 16)
ps1, err := NewProxyServer("127.0.0.1:0", ch)
if err != nil {
t.Fatalf("first NewProxyServer() error = %v", err)
}
Destroy(ps1, ch)
ps2, err := NewProxyServer("127.0.0.1:0", ch)
if err != nil {
t.Fatalf("second NewProxyServer() error = %v", err)
}
t.Cleanup(func() { Destroy(ps2, ch) })
if !ps1.CACreated {
t.Fatal("first NewProxyServer should report CACreated=true")
}
if ps2.CACreated {
t.Fatal("second NewProxyServer should report CACreated=false when CA already exists")
}
}
func TestNewProxyServer_ErrorWhenCASetupFails(t *testing.T) {
t.Setenv("XDG_CONFIG_HOME", "")
t.Setenv("HOME", "")
ch := make(chan model.Event, 8)
ps, err := NewProxyServer("127.0.0.1:0", ch)
if err == nil {
Destroy(ps, ch)
t.Fatal("NewProxyServer() error = nil, want non-nil")
}
if ps != nil {
t.Fatalf("proxy server = %#v, want nil", ps)
}
}
func TestNewProxyServer_ErrorWhenListenFails(t *testing.T) {
configRoot := t.TempDir()
t.Setenv("XDG_CONFIG_HOME", configRoot)
occupied, err := net.Listen("tcp", "127.0.0.1:0")
if err != nil {
t.Fatalf("net.Listen() error = %v", err)
}
t.Cleanup(func() { _ = occupied.Close() })
addr := occupied.Addr().String()
ch := make(chan model.Event, 8)
ps, gotErr := NewProxyServer(addr, ch)
if gotErr == nil {
Destroy(ps, ch)
t.Fatal("NewProxyServer() error = nil, want non-nil")
}
if ps != nil {
t.Fatalf("proxy server = %#v, want nil", ps)
}
}
func TestDestroy(t *testing.T) {
t.Parallel()
t.Run("nil-safe", func(t *testing.T) {
t.Parallel()
Destroy(nil, make(chan model.Event, 1))
})
t.Run("emits ProxyStopped when server exists", func(t *testing.T) {
t.Parallel()
ch := make(chan model.Event, 2)
ps := &model.ProxyServer{Server: &http.Server{}, Conns: make(map[net.Conn]struct{})}
Destroy(ps, ch)
select {
case ev := <-ch:
if ev.Type != model.EventTypeProxyStopped {
t.Fatalf("event type = %s, want %s", ev.Type, model.EventTypeProxyStopped)
}
case <-time.After(time.Second):
t.Fatal("timeout waiting for ProxyStopped event")
}
})
}
func TestConnectionTrackingHelpers(t *testing.T) {
t.Parallel()
t.Run("track and untrack are nil-safe", func(t *testing.T) {
t.Parallel()
trackConnection(nil, nil)
untrackConnection(nil, nil)
closeTrackedConnections(nil)
})
t.Run("track/untrack mutate connection set", func(t *testing.T) {
t.Parallel()
ps := &model.ProxyServer{Conns: make(map[net.Conn]struct{})}
c := &closeTrackingConn{}
trackConnection(ps, c)
if _, ok := ps.Conns[c]; !ok {
t.Fatal("connection not tracked")
}
untrackConnection(ps, c)
if _, ok := ps.Conns[c]; ok {
t.Fatal("connection still tracked after untrack")
}
})
t.Run("closeTrackedConnections closes all tracked conns", func(t *testing.T) {
t.Parallel()
ps := &model.ProxyServer{Conns: make(map[net.Conn]struct{})}
c1 := &closeTrackingConn{}
c2 := &closeTrackingConn{}
ps.Conns[c1] = struct{}{}
ps.Conns[c2] = struct{}{}
closeTrackedConnections(ps)
if !c1.isClosed() || !c2.isClosed() {
t.Fatalf("expected all tracked conns closed, got c1=%v c2=%v", c1.isClosed(), c2.isClosed())
}
})
}
func TestDestroy_ShutdownsServerContext(t *testing.T) {
t.Parallel()
// Basic smoke test: ensure a real server can be shut down via Destroy.
ln, err := net.Listen("tcp", "127.0.0.1:0")
if err != nil {
t.Fatalf("listen error = %v", err)
}
srv := &http.Server{Handler: http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { _, _ = w.Write([]byte("ok")) })}
go func() {
_ = srv.Serve(ln)
}()
ch := make(chan model.Event, 2)
ps := &model.ProxyServer{Server: srv, Conns: make(map[net.Conn]struct{})}
Destroy(ps, ch)
ctx, cancel := context.WithTimeout(context.Background(), time.Second)
defer cancel()
if err := srv.Shutdown(ctx); err != nil && err != http.ErrServerClosed {
t.Fatalf("server should be closed after Destroy, got shutdown error %v", err)
}
}

View File

@ -0,0 +1,77 @@
package proxy
import (
"crypto/ecdsa"
"crypto/elliptic"
"crypto/rand"
"crypto/tls"
"crypto/x509"
"crypto/x509/pkix"
"math/big"
"testing"
"time"
"termtap.dev/internal/model"
)
func drainEvents(t *testing.T, ch <-chan model.Event, n int, timeout time.Duration) []model.Event {
t.Helper()
events := make([]model.Event, 0, n)
deadline := time.After(timeout)
for len(events) < n {
select {
case ev := <-ch:
events = append(events, ev)
case <-deadline:
t.Fatalf("timeout waiting for %d events, got %d", n, len(events))
}
}
return events
}
func hasEventType(events []model.Event, typ model.EventType) bool {
for _, ev := range events {
if ev.Type == typ {
return true
}
}
return false
}
func newTestCA(t *testing.T) *CertificateAuthority {
t.Helper()
key, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
if err != nil {
t.Fatalf("GenerateKey() error = %v", err)
}
now := time.Now()
tmpl := &x509.Certificate{
SerialNumber: big.NewInt(1),
Subject: pkix.Name{CommonName: "test-ca"},
NotBefore: now.Add(-time.Minute),
NotAfter: now.Add(time.Hour),
KeyUsage: x509.KeyUsageCertSign | x509.KeyUsageDigitalSignature,
IsCA: true,
BasicConstraintsValid: true,
}
der, err := x509.CreateCertificate(rand.Reader, tmpl, tmpl, &key.PublicKey, key)
if err != nil {
t.Fatalf("CreateCertificate() error = %v", err)
}
cert, err := x509.ParseCertificate(der)
if err != nil {
t.Fatalf("ParseCertificate() error = %v", err)
}
return &CertificateAuthority{
cert: cert,
key: key,
leafCert: make(map[string]*tls.Certificate),
}
}

View File

@ -0,0 +1,256 @@
package proxy
import (
"bufio"
"context"
"errors"
"io"
"net/http"
"net/url"
"strconv"
"strings"
"testing"
"github.com/google/uuid"
)
type timeoutErr struct{}
func (timeoutErr) Error() string { return "timeout" }
func (timeoutErr) Timeout() bool { return true }
func (timeoutErr) Temporary() bool { return true }
type errWriter struct{}
func (errWriter) Write(_ []byte) (int, error) {
return 0, errors.New("write failed")
}
func TestCanDisplayContent(t *testing.T) {
t.Parallel()
tests := []struct {
name string
contentType string
want bool
}{
{name: "empty content type", contentType: "", want: false},
{name: "text type", contentType: "text/plain", want: true},
{name: "json type", contentType: "application/json", want: true},
{name: "xml suffix", contentType: "application/problem+xml", want: true},
{name: "unknown binary", contentType: "application/octet-stream", want: false},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
t.Parallel()
if got := canDisplayContent(tt.contentType); got != tt.want {
t.Fatalf("canDisplayContent(%q) = %v, want %v", tt.contentType, got, tt.want)
}
})
}
}
func TestFormatHeaders(t *testing.T) {
t.Parallel()
tests := []struct {
name string
headers http.Header
want string
}{
{
name: "empty returns none",
headers: http.Header{},
want: "<none>",
},
{
name: "sorts keys stably",
headers: http.Header{
"B-Key": {"b1"},
"A-Key": {"a1", "a2"},
},
want: `A-Key="a1,a2", B-Key="b1"`,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
t.Parallel()
if got := formatHeaders(tt.headers); got != tt.want {
t.Fatalf("formatHeaders() = %q, want %q", got, tt.want)
}
})
}
}
func TestGetEndOfUUID(t *testing.T) {
t.Parallel()
id := uuid.MustParse("123e4567-e89b-12d3-a456-426614174000")
if got, want := getEndOfUUID(id), "426614174000"; got != want {
t.Fatalf("getEndOfUUID() = %q, want %q", got, want)
}
}
func TestStatusFromUpstreamError(t *testing.T) {
t.Parallel()
newReq := func(ctx context.Context) *http.Request {
reqURL, err := url.Parse("http://example.com")
if err != nil {
t.Fatalf("url parse failed: %v", err)
}
return (&http.Request{Method: http.MethodGet, URL: reqURL}).WithContext(ctx)
}
tests := []struct {
name string
req *http.Request
resp *http.Response
err error
want int
}{
{
name: "prefers upstream response status",
req: newReq(context.Background()),
resp: &http.Response{StatusCode: http.StatusTeapot},
err: errors.New("ignored"),
want: http.StatusTeapot,
},
{
name: "context canceled maps to bad gateway",
req: func() *http.Request {
ctx, cancel := context.WithCancel(context.Background())
cancel()
return newReq(ctx)
}(),
err: context.Canceled,
want: http.StatusBadGateway,
},
{
name: "deadline exceeded maps to gateway timeout",
req: newReq(context.Background()),
err: context.DeadlineExceeded,
want: http.StatusGatewayTimeout,
},
{
name: "net timeout maps to gateway timeout",
req: newReq(context.Background()),
err: timeoutErr{},
want: http.StatusGatewayTimeout,
},
{
name: "default maps to bad gateway",
req: newReq(context.Background()),
err: errors.New("dial failed"),
want: http.StatusBadGateway,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
t.Parallel()
if got := statusFromUpstreamError(tt.req, tt.resp, tt.err); got != tt.want {
t.Fatalf("statusFromUpstreamError() = %d, want %d", got, tt.want)
}
})
}
}
func TestNewUpstreamTransport(t *testing.T) {
t.Parallel()
got := newUpstreamTransport()
transport, ok := got.(*http.Transport)
if !ok {
t.Fatalf("newUpstreamTransport() type = %T, want *http.Transport", got)
}
if transport.Proxy != nil {
t.Fatal("newUpstreamTransport() Proxy must be nil")
}
}
func TestWritePlainHTTPError(t *testing.T) {
t.Parallel()
tests := []struct {
name string
status int
writer *bufio.Writer
wantErr bool
wantStatus int
}{
{
name: "writes valid response and flushes",
status: http.StatusBadGateway,
writer: bufio.NewWriter(&strings.Builder{}),
wantErr: false,
wantStatus: http.StatusBadGateway,
},
{
name: "returns write error",
status: http.StatusBadGateway,
writer: bufio.NewWriter(errWriter{}),
wantErr: true,
},
{
name: "returns response write error when writer already failed",
status: http.StatusBadGateway,
writer: bufio.NewWriter(errWriter{}),
wantErr: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
t.Parallel()
var sb strings.Builder
w := tt.writer
if !tt.wantErr {
w = bufio.NewWriter(&sb)
}
err := writePlainHTTPError(w, tt.status)
if tt.name == "returns response write error when writer already failed" {
_ = w.Flush() // set sticky writer error so resp.Write fails immediately
err = writePlainHTTPError(w, tt.status)
}
if (err != nil) != tt.wantErr {
t.Fatalf("writePlainHTTPError() error = %v, wantErr %v", err, tt.wantErr)
}
if tt.wantErr {
return
}
resp, readErr := http.ReadResponse(bufio.NewReader(strings.NewReader(sb.String())), &http.Request{Method: http.MethodGet})
if readErr != nil {
t.Fatalf("ReadResponse() error = %v", readErr)
}
defer resp.Body.Close()
if resp.StatusCode != tt.wantStatus {
t.Fatalf("status = %d, want %d", resp.StatusCode, tt.wantStatus)
}
if gotCT := resp.Header.Get("Content-Type"); gotCT != "text/plain; charset=utf-8" {
t.Fatalf("Content-Type = %q, want %q", gotCT, "text/plain; charset=utf-8")
}
wantBody := http.StatusText(tt.status)
if gotCL := resp.Header.Get("Content-Length"); gotCL != strconv.Itoa(len(wantBody)) {
t.Fatalf("Content-Length = %q, want %q", gotCL, strconv.Itoa(len(wantBody)))
}
body, bodyErr := io.ReadAll(resp.Body)
if bodyErr != nil {
t.Fatalf("ReadAll(body) error = %v", bodyErr)
}
if string(body) != wantBody {
t.Fatalf("body = %q, want %q", string(body), wantBody)
}
})
}
}

View File

@ -0,0 +1,361 @@
package tui
import (
"errors"
"net/http"
"testing"
"time"
tea "github.com/charmbracelet/bubbletea"
"github.com/google/uuid"
"termtap.dev/internal/model"
)
func TestNewModelDefaults(t *testing.T) {
t.Parallel()
ch := make(chan model.Event)
m := NewModel(ch, Controls{})
if m.channel != ch {
t.Fatal("channel not set")
}
if len(m.events) != 0 || len(m.requests) != 0 {
t.Fatal("events/requests should initialize empty")
}
if m.width != 0 || m.height != 0 {
t.Fatal("width/height should initialize zero")
}
if m.showEvents || m.showStd || m.showSearch || m.restarting {
t.Fatal("toggle flags should initialize false")
}
}
func TestInitBatchesEventAndTick(t *testing.T) {
t.Parallel()
ch := make(chan model.Event)
m := NewModel(ch, Controls{})
cmd := m.Init()
if cmd == nil {
t.Fatal("Init() returned nil cmd")
}
if _, ok := cmd().(tea.BatchMsg); !ok {
t.Fatalf("Init cmd message type = %T, want tea.BatchMsg", cmd())
}
}
func TestWaitForEvent(t *testing.T) {
t.Parallel()
t.Run("returns EventMsg when channel has value", func(t *testing.T) {
t.Parallel()
ch := make(chan model.Event, 1)
ch <- model.Event{Type: model.EventTypeWarn, Body: "hello"}
msg := waitForEvent(ch)()
ev, ok := msg.(EventMsg)
if !ok {
t.Fatalf("msg type = %T, want EventMsg", msg)
}
if ev.value.Body != "hello" {
t.Fatalf("event body = %q, want %q", ev.value.Body, "hello")
}
})
t.Run("returns ErrMsg when channel closed", func(t *testing.T) {
t.Parallel()
ch := make(chan model.Event)
close(ch)
msg := waitForEvent(ch)()
if _, ok := msg.(ErrMsg); !ok {
t.Fatalf("msg type = %T, want ErrMsg", msg)
}
})
}
func TestMessagesCommands(t *testing.T) {
t.Parallel()
t.Run("restartCmd nil restart returns nil", func(t *testing.T) {
t.Parallel()
if cmd := restartCmd(nil); cmd != nil {
t.Fatal("restartCmd(nil) should return nil")
}
})
t.Run("restartCmd wraps restart result", func(t *testing.T) {
t.Parallel()
wantErr := errors.New("boom")
msg := restartCmd(func() error { return wantErr })()
rm, ok := msg.(RestartResultMsg)
if !ok {
t.Fatalf("msg type = %T, want RestartResultMsg", msg)
}
if !errors.Is(rm.err, wantErr) {
t.Fatalf("restart result error = %v, want %v", rm.err, wantErr)
}
})
t.Run("tickCmd emits TickMsg", func(t *testing.T) {
t.Parallel()
cmd := tickCmd()
if cmd == nil {
t.Fatal("tickCmd returned nil")
}
msgCh := make(chan tea.Msg, 1)
go func() { msgCh <- cmd() }()
select {
case msg := <-msgCh:
if _, ok := msg.(TickMsg); !ok {
t.Fatalf("msg type = %T, want TickMsg", msg)
}
case <-time.After(200 * time.Millisecond):
t.Fatal("timeout waiting for tick message")
}
})
}
func TestUpdate(t *testing.T) {
t.Parallel()
t.Run("window size updates dimensions", func(t *testing.T) {
t.Parallel()
m := NewModel(make(chan model.Event), Controls{})
next, _ := m.Update(tea.WindowSizeMsg{Width: 120, Height: 40})
got := next.(Model)
if got.width != 120 || got.height != 40 {
t.Fatalf("dimensions = (%d,%d), want (120,40)", got.width, got.height)
}
})
t.Run("tick updates now and reschedules only with pending requests", func(t *testing.T) {
t.Parallel()
now := time.Now()
m1 := NewModel(make(chan model.Event), Controls{})
next1, cmd1 := m1.Update(TickMsg{Now: now})
got1 := next1.(Model)
if !got1.now.Equal(now) {
t.Fatal("tick should update now")
}
if cmd1 != nil {
t.Fatal("tick without pending requests should not reschedule")
}
m2 := NewModel(make(chan model.Event), Controls{})
m2.requests = append(m2.requests, model.Request{ID: uuid.New(), Pending: true})
_, cmd2 := m2.Update(TickMsg{Now: now})
if cmd2 == nil {
t.Fatal("tick with pending requests should reschedule")
}
})
t.Run("key handling toggles and quit", func(t *testing.T) {
t.Parallel()
m := NewModel(make(chan model.Event), Controls{})
next, quitCmd := m.Update(tea.KeyMsg{Type: tea.KeyRunes, Runes: []rune("q")})
_ = next
if quitCmd == nil {
t.Fatal("q should return quit cmd")
}
nextCtrlC, quitCtrlC := m.Update(tea.KeyMsg{Type: tea.KeyCtrlC})
_ = nextCtrlC
if quitCtrlC == nil {
t.Fatal("ctrl+c should return quit cmd")
}
next2, _ := m.Update(tea.KeyMsg{Type: tea.KeyRunes, Runes: []rune("e")})
if !next2.(Model).showEvents {
t.Fatal("e should toggle showEvents")
}
next3, _ := next2.(Model).Update(tea.KeyMsg{Type: tea.KeyRunes, Runes: []rune("o")})
if !next3.(Model).showStd {
t.Fatal("o should toggle showStd")
}
next4, _ := next3.(Model).Update(tea.KeyMsg{Type: tea.KeyRunes, Runes: []rune("/")})
if !next4.(Model).showSearch {
t.Fatal("/ should enable search")
}
next5, _ := next4.(Model).Update(tea.KeyMsg{Type: tea.KeyEsc})
if next5.(Model).showSearch {
t.Fatal("esc should disable search")
}
next6, cmd6 := next5.(Model).Update(tea.KeyMsg{Type: tea.KeyRunes, Runes: []rune("x")})
if cmd6 != nil {
t.Fatal("unknown key should not return command")
}
if next6.(Model).showEvents != next5.(Model).showEvents ||
next6.(Model).showStd != next5.(Model).showStd ||
next6.(Model).showSearch != next5.(Model).showSearch {
t.Fatal("unknown key should not alter toggle state")
}
})
t.Run("restart key guarded by state and control fn", func(t *testing.T) {
t.Parallel()
m := NewModel(make(chan model.Event), Controls{})
next, cmd := m.Update(tea.KeyMsg{Type: tea.KeyCtrlR})
if cmd != nil {
t.Fatal("ctrl+r with nil restart control should return nil cmd")
}
if next.(Model).restarting {
t.Fatal("restarting should remain false when restart control missing")
}
m2 := NewModel(make(chan model.Event), Controls{Restart: func() error { return nil }})
next2, cmd2 := m2.Update(tea.KeyMsg{Type: tea.KeyCtrlR})
if cmd2 == nil {
t.Fatal("ctrl+r with restart control should return cmd")
}
if !next2.(Model).restarting {
t.Fatal("restarting should be true after ctrl+r")
}
next3, cmd3 := next2.(Model).Update(tea.KeyMsg{Type: tea.KeyCtrlR})
if cmd3 != nil {
t.Fatal("ctrl+r while restarting should return nil cmd")
}
if !next3.(Model).restarting {
t.Fatal("restarting should stay true while guarded")
}
})
t.Run("ErrMsg pushes warning event", func(t *testing.T) {
t.Parallel()
m := NewModel(make(chan model.Event), Controls{})
next, _ := m.Update(ErrMsg{err: errors.New("closed")})
got := next.(Model)
if len(got.events) != 1 {
t.Fatalf("event len = %d, want 1", len(got.events))
}
if got.events[0].Type != model.EventTypeWarn {
t.Fatalf("event type = %s, want %s", got.events[0].Type, model.EventTypeWarn)
}
})
t.Run("RestartResultMsg clears restarting and warns on error", func(t *testing.T) {
t.Parallel()
m := NewModel(make(chan model.Event), Controls{})
m.restarting = true
next1, _ := m.Update(RestartResultMsg{err: nil})
if next1.(Model).restarting {
t.Fatal("restarting should clear on restart result")
}
next2, _ := m.Update(RestartResultMsg{err: errors.New("fail")})
got2 := next2.(Model)
if len(got2.events) != 1 || got2.events[0].Type != model.EventTypeWarn {
t.Fatalf("expected warn event on restart error, got %#v", got2.events)
}
})
t.Run("EventMsg updates events and requests", func(t *testing.T) {
t.Parallel()
ch := make(chan model.Event, 2)
m := NewModel(ch, Controls{})
reqID := uuid.New()
startEv := EventMsg{value: model.Event{Type: model.EventTypeRequestStarted, Request: model.Request{ID: reqID, Method: http.MethodGet, Pending: true}}}
next1, cmd1 := m.Update(startEv)
if cmd1 == nil {
t.Fatal("EventMsg should return wait/tick cmd")
}
got1 := next1.(Model)
if len(got1.events) != 1 || len(got1.requests) != 1 {
t.Fatalf("expected one event and one request, got events=%d requests=%d", len(got1.events), len(got1.requests))
}
finishReq := got1.requests[0]
finishReq.Pending = false
finishReq.Status = 200
finishEv := EventMsg{value: model.Event{Type: model.EventTypeRequestFinished, Request: finishReq}}
next2, cmd2 := got1.Update(finishEv)
got2 := next2.(Model)
if got2.requests[0].Pending {
t.Fatal("request should be updated to non-pending")
}
if got2.requests[0].Status != 200 {
t.Fatalf("request status = %d, want 200", got2.requests[0].Status)
}
if cmd2 == nil {
t.Fatal("expected waitForEvent cmd after finished request")
}
ch <- model.Event{Type: model.EventTypeWarn, Body: "next"}
msg := cmd2()
if _, ok := msg.(EventMsg); !ok {
t.Fatalf("cmd2 message type = %T, want EventMsg", msg)
}
})
}
func TestModelHelpers(t *testing.T) {
t.Parallel()
t.Run("pushEvent trims to maxEvents", func(t *testing.T) {
t.Parallel()
m := NewModel(make(chan model.Event), Controls{})
for i := 0; i < maxEvents+5; i++ {
m.pushEvent(model.Event{Body: "x"})
}
if len(m.events) != maxEvents {
t.Fatalf("events len = %d, want %d", len(m.events), maxEvents)
}
})
t.Run("createRequest ignores CONNECT and trims", func(t *testing.T) {
t.Parallel()
m := NewModel(make(chan model.Event), Controls{})
m.createRequest(model.Request{Method: http.MethodConnect})
if len(m.requests) != 0 {
t.Fatal("CONNECT request should be ignored")
}
for i := 0; i < maxRequests+3; i++ {
m.createRequest(model.Request{ID: uuid.New(), Method: http.MethodGet})
}
if len(m.requests) != maxRequests {
t.Fatalf("requests len = %d, want %d", len(m.requests), maxRequests)
}
})
t.Run("updateRequest updates only matching request", func(t *testing.T) {
t.Parallel()
id1 := uuid.New()
id2 := uuid.New()
m := NewModel(make(chan model.Event), Controls{})
m.requests = []model.Request{{ID: id1, Status: 100}, {ID: id2, Status: 101}}
m.updateRequest(model.Request{ID: id2, Status: 202})
if m.requests[0].Status != 100 || m.requests[1].Status != 202 {
t.Fatalf("unexpected statuses after update: %#v", m.requests)
}
})
t.Run("hasPendingRequests true and false", func(t *testing.T) {
t.Parallel()
m := NewModel(make(chan model.Event), Controls{})
if m.hasPendingRequests() {
t.Fatal("empty model should not have pending requests")
}
m.requests = []model.Request{{ID: uuid.New(), Pending: false}, {ID: uuid.New(), Pending: true}}
if !m.hasPendingRequests() {
t.Fatal("expected pending requests to be true")
}
})
}

View File

@ -0,0 +1,245 @@
package tui
import (
"strings"
"testing"
"time"
"github.com/google/uuid"
"termtap.dev/internal/model"
)
func TestFormatDuration(t *testing.T) {
t.Parallel()
tests := []struct {
name string
in time.Duration
want string
}{
{name: "pending zero", in: 0, want: "PENDING"},
{name: "microseconds", in: 750 * time.Microsecond, want: "750us"},
{name: "milliseconds", in: 20 * time.Millisecond, want: "20ms"},
{name: "seconds", in: 11 * time.Second, want: "11.00s"},
}
for _, tt := range tests {
tt := tt
t.Run(tt.name, func(t *testing.T) {
t.Parallel()
if got := formatDuration(tt.in); got != tt.want {
t.Fatalf("formatDuration(%v) = %q, want %q", tt.in, got, tt.want)
}
})
}
}
func TestTruncate(t *testing.T) {
t.Parallel()
tests := []struct {
name string
s string
max int
want string
}{
{name: "short unchanged", s: "abc", max: 3, want: "abc"},
{name: "max small no ellipsis", s: "abcdef", max: 3, want: "abc"},
{name: "ellipsis", s: "abcdef", max: 5, want: "ab..."},
}
for _, tt := range tests {
tt := tt
t.Run(tt.name, func(t *testing.T) {
t.Parallel()
if got := truncate(tt.s, tt.max); got != tt.want {
t.Fatalf("truncate(%q,%d) = %q, want %q", tt.s, tt.max, got, tt.want)
}
})
}
}
func TestClampRendered(t *testing.T) {
t.Parallel()
if got := clampRendered("abcdef", 0); got != "" {
t.Fatalf("clampRendered max=0 = %q, want empty", got)
}
if got := clampRendered("abc", 10); got != "abc" {
t.Fatalf("clampRendered no truncation = %q, want %q", got, "abc")
}
if got := clampRendered("abcdef", 4); !strings.Contains(got, "...") {
t.Fatalf("clampRendered truncation should include ellipsis, got %q", got)
}
}
func TestGetEventColor(t *testing.T) {
t.Parallel()
theme := newTheme()
tests := []struct {
name string
typ model.EventType
want string
}{
{name: "session", typ: model.EventTypeSessionStarted, want: theme.EventSession.Render("x")},
{name: "proxy", typ: model.EventTypeProxyStarted, want: theme.EventProxy.Render("x")},
{name: "request in flight", typ: model.EventTypeRequestStarted, want: theme.EventRequestInFlight.Render("x")},
{name: "request success", typ: model.EventTypeRequestFinished, want: theme.EventSuccess.Render("x")},
{name: "warn", typ: model.EventTypeWarn, want: theme.EventWarn.Render("x")},
{name: "error", typ: model.EventTypeRequestFailed, want: theme.EventError.Render("x")},
{name: "fatal", typ: model.EventTypeFatal, want: theme.EventFatal.Render("x")},
{name: "default", typ: model.EventType("UnknownType"), want: theme.EventDefault.Render("x")},
}
for _, tt := range tests {
tt := tt
t.Run(tt.name, func(t *testing.T) {
t.Parallel()
got := getEventColor(theme, tt.typ).Render("x")
if got != tt.want {
t.Fatalf("unexpected style for %s", tt.typ)
}
})
}
}
func TestViewAndPaneStructure(t *testing.T) {
t.Parallel()
m := NewModel(make(chan model.Event), Controls{})
t.Run("view returns raw pane when unset size", func(t *testing.T) {
t.Parallel()
got := m.View()
if got != m.renderAppPane() {
t.Fatal("View should return raw pane when width/height are unset")
}
})
t.Run("renderAppPane line count matches height", func(t *testing.T) {
t.Parallel()
m2 := m
m2.width = 80
m2.height = 12
got := m2.renderAppPane()
if got == "height of request and details did not match" || got == "height of screen does not match terminal height" {
t.Fatalf("unexpected renderAppPane invariant error: %q", got)
}
if lines := strings.Count(got, "\n") + 1; lines != m2.height {
t.Fatalf("line count = %d, want %d", lines, m2.height)
}
})
t.Run("renderAppPane supports toggles", func(t *testing.T) {
t.Parallel()
m2 := m
m2.width = 90
m2.height = 14
m2.showEvents = true
m2.showStd = true
m2.showSearch = true
got := m2.renderAppPane()
if got == "height of request and details did not match" || got == "height of screen does not match terminal height" {
t.Fatalf("unexpected renderAppPane invariant error with toggles: %q", got)
}
})
t.Run("view applies configured terminal height", func(t *testing.T) {
t.Parallel()
m2 := m
m2.width = 70
m2.height = 10
got := m2.View()
if lines := strings.Count(got, "\n") + 1; lines < m2.height {
t.Fatalf("View line count = %d, want at least %d", lines, m2.height)
}
})
}
func TestPaneRenderersAndStatusBar(t *testing.T) {
t.Parallel()
m := NewModel(make(chan model.Event), Controls{})
m.width = 100
m.height = 12
m.requests = []model.Request{
{ID: uuid.New(), Method: "GET", Host: "a", URL: "/a", Status: 200, Duration: 5 * time.Millisecond},
{ID: uuid.New(), Method: "POST", Host: "b", URL: "/b", Status: 500, Duration: 10 * time.Millisecond, Failed: true},
}
m.events = []model.Event{
{Type: model.EventTypeWarn, Body: "warn"},
{Type: model.EventTypeProcessStdout, Body: "out"},
{Type: model.EventTypeProcessStderr, Body: "err"},
}
status := m.renderStatusBar(100)
if !strings.Contains(status, "2 reqs") || !strings.Contains(status, "1 err") {
t.Fatalf("status bar missing expected counters: %q", status)
}
search := m.renderSearchPane(20, 3)
if len(search) != 3 {
t.Fatalf("search pane len = %d, want 3", len(search))
}
for i, line := range search {
if len(line) != 20 {
t.Fatalf("search pane line %d len = %d, want %d", i, len(line), 20)
}
}
requests := m.renderRequestPane(50, 4)
if len(requests) != 4 {
t.Fatalf("request pane len = %d, want 4", len(requests))
}
details := m.renderDetailsPane(30, 4)
if len(details) != 4 {
t.Fatalf("details pane len = %d, want 4", len(details))
}
events := m.renderEventsPane(60, 4)
if len(events) != 4 {
t.Fatalf("events pane len = %d, want 4", len(events))
}
if strings.Contains(strings.Join(events, "\n"), "out") || strings.Contains(strings.Join(events, "\n"), "err") {
t.Fatal("events pane should filter stdout/stderr events")
}
std := m.renderStdPane(60, 4)
if len(std) != 4 {
t.Fatalf("std pane len = %d, want 4", len(std))
}
joined := strings.Join(std, "\n")
if !strings.Contains(joined, "out") || !strings.Contains(joined, "err") {
t.Fatal("std pane should include stdout/stderr logs")
}
}
func TestRenderEventsPane_ErrorAndPIDBranches(t *testing.T) {
t.Parallel()
m := NewModel(make(chan model.Event), Controls{})
m.events = []model.Event{
{Type: model.EventTypeWarn, Body: "old"},
{Type: model.EventTypeRequestFailed, Body: "failed body", PID: 123, Time: time.Now()},
{Type: model.EventTypeFatal, Body: "fatal body", Time: time.Now()},
}
lines := m.renderEventsPane(60, 3)
if len(lines) != 3 {
t.Fatalf("events pane len = %d, want 3", len(lines))
}
joined := strings.Join(lines, "\n")
if !strings.Contains(joined, "123") {
t.Fatalf("expected PID to appear in events pane, got: %q", joined)
}
if !strings.Contains(joined, "failed body") {
t.Fatalf("expected failed body to appear in events pane, got: %q", joined)
}
if !strings.Contains(joined, "fatal body") {
t.Fatalf("expected fatal body to appear in events pane, got: %q", joined)
}
}