Merge branch 'feature/testing'
This commit is contained in:
commit
51d526c2fe
45
TEST_COVERAGE_SUMMARY.md
Normal file
45
TEST_COVERAGE_SUMMARY.md
Normal file
@ -0,0 +1,45 @@
|
||||
# termtap Test Coverage Summary
|
||||
|
||||
Generated from:
|
||||
|
||||
```bash
|
||||
go test -coverprofile=/tmp/termtap.cover ./...
|
||||
go tool cover -func=/tmp/termtap.cover
|
||||
```
|
||||
|
||||
## Package coverage snapshot
|
||||
|
||||
| Package | Coverage |
|
||||
|---|---:|
|
||||
| `cmd/tap` | 100.0% |
|
||||
| `internal/app` | 98.1% |
|
||||
| `internal/cli` | 93.4% |
|
||||
| `internal/process` | 95.8% |
|
||||
| `internal/proxy` | 90.0% |
|
||||
| `internal/tui` | 96.2% |
|
||||
| `examples/echo` | 0.0% (example app; intentionally not covered) |
|
||||
| `internal/model` | no test files (pure data structs) |
|
||||
|
||||
Total statements in module: **77.9%**.
|
||||
|
||||
## Notable lower-coverage targets (production code)
|
||||
|
||||
- `internal/proxy/handlers.go:handleConnect` — 75.9%
|
||||
- `internal/proxy/certs.go:writeFileAtomically` — 76.2%
|
||||
- `internal/proxy/certs.go:load` — 90.5%
|
||||
- `internal/proxy/certs.go:create` — 80.8%
|
||||
- `internal/proxy/certs.go:IsTrustedBySystem` — 76.9%
|
||||
- `internal/cli/run.go:runCert` — 87.5%
|
||||
|
||||
## Interpretation
|
||||
|
||||
- Core runtime paths (`internal/app`, `internal/process`, `internal/tui`) are high confidence.
|
||||
- Proxy package has broad behavior coverage including HTTP and HTTPS MITM integration flow, and now clears 90% package coverage.
|
||||
- CLI command routing and fatal/stdout/stderr seams are covered, including `Run` success/error branches.
|
||||
- TUI pane rendering coverage now includes error/PID branch behavior.
|
||||
|
||||
## Next optional improvements
|
||||
|
||||
1. Add deeper CONNECT tunnel write/flush/read failure permutations inside `handleConnect` loop.
|
||||
2. Add additional deterministic edge cases for `IsTrustedBySystem` non-unknown-authority verify failures.
|
||||
3. Optionally add non-unix signal file coverage in CI matrix (currently unix-focused).
|
||||
49
cmd/tap/main_test.go
Normal file
49
cmd/tap/main_test.go
Normal file
@ -0,0 +1,49 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"io"
|
||||
"os"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
func TestMain_SmokeInvokesCLIRun(t *testing.T) {
|
||||
origArgs := os.Args
|
||||
t.Cleanup(func() { os.Args = origArgs })
|
||||
os.Args = []string{"tap", "invalid"}
|
||||
|
||||
origStderr := os.Stderr
|
||||
r, w, err := os.Pipe()
|
||||
if err != nil {
|
||||
t.Fatalf("stderr pipe error: %v", err)
|
||||
}
|
||||
t.Cleanup(func() {
|
||||
os.Stderr = origStderr
|
||||
_ = r.Close()
|
||||
})
|
||||
os.Stderr = w
|
||||
|
||||
outCh := make(chan string, 1)
|
||||
go func() {
|
||||
var buf bytes.Buffer
|
||||
_, _ = io.Copy(&buf, r)
|
||||
outCh <- buf.String()
|
||||
}()
|
||||
|
||||
main()
|
||||
|
||||
_ = w.Close()
|
||||
os.Stderr = origStderr
|
||||
var got string
|
||||
select {
|
||||
case got = <-outCh:
|
||||
case <-time.After(2 * time.Second):
|
||||
t.Fatal("timeout waiting for stderr capture")
|
||||
}
|
||||
|
||||
if !strings.Contains(got, "usage:") {
|
||||
t.Fatalf("stderr missing usage output, got: %q", got)
|
||||
}
|
||||
}
|
||||
109
internal/app/integration_session_test.go
Normal file
109
internal/app/integration_session_test.go
Normal file
@ -0,0 +1,109 @@
|
||||
package app
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"net/url"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"termtap.dev/internal/model"
|
||||
)
|
||||
|
||||
// NOTE: Run with -race; this validates cross-component concurrency.
|
||||
|
||||
func TestSessionIntegration_LifecycleAndRequestEvents(t *testing.T) {
|
||||
addr := freeTCPAddr(t)
|
||||
s, err := StartSession(model.Command{Name: "sh", Args: []string{"-c", "sleep 5"}}, addr)
|
||||
if err != nil {
|
||||
t.Fatalf("StartSession() error = %v", err)
|
||||
}
|
||||
|
||||
upstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(http.StatusOK)
|
||||
_, _ = w.Write([]byte("ok"))
|
||||
}))
|
||||
t.Cleanup(upstream.Close)
|
||||
|
||||
startupEvents := collectUntilTypes(t, s.Events, []model.EventType{
|
||||
model.EventTypeProxyStarting,
|
||||
model.EventTypeProcessStarting,
|
||||
model.EventTypeProcessStarted,
|
||||
}, 3*time.Second)
|
||||
|
||||
proxyURL, err := url.Parse(s.proxy.Url)
|
||||
if err != nil {
|
||||
t.Fatalf("url.Parse(proxy) error = %v", err)
|
||||
}
|
||||
client := &http.Client{
|
||||
Transport: &http.Transport{Proxy: http.ProxyURL(proxyURL)},
|
||||
Timeout: 3 * time.Second,
|
||||
}
|
||||
|
||||
resp, err := client.Get(upstream.URL + "/session")
|
||||
if err != nil {
|
||||
t.Fatalf("proxy request error = %v", err)
|
||||
}
|
||||
_ = resp.Body.Close()
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
t.Fatalf("status = %d, want %d", resp.StatusCode, http.StatusOK)
|
||||
}
|
||||
|
||||
requestEvents := collectUntilTypes(t, s.Events, []model.EventType{
|
||||
model.EventTypeRequestStarted,
|
||||
model.EventTypeRequestFinished,
|
||||
}, 3*time.Second)
|
||||
|
||||
s.Stop()
|
||||
select {
|
||||
case <-s.proc.Done:
|
||||
case <-time.After(4 * time.Second):
|
||||
t.Fatal("timeout waiting for process stop")
|
||||
}
|
||||
|
||||
shutdownEvents := collectUntilTypes(t, s.Events, []model.EventType{
|
||||
model.EventTypeProxyStopped,
|
||||
model.EventTypeProcessSignaled,
|
||||
model.EventTypeProcessExited,
|
||||
}, 4*time.Second)
|
||||
|
||||
if !isBefore(startupEvents, model.EventTypeProcessStarting, model.EventTypeProcessStarted) {
|
||||
t.Fatalf("expected %s before %s in startup events: %#v", model.EventTypeProcessStarting, model.EventTypeProcessStarted, startupEvents)
|
||||
}
|
||||
if !isBefore(requestEvents, model.EventTypeRequestStarted, model.EventTypeRequestFinished) {
|
||||
t.Fatalf("expected %s before %s in request events: %#v", model.EventTypeRequestStarted, model.EventTypeRequestFinished, requestEvents)
|
||||
}
|
||||
if !isBefore(shutdownEvents, model.EventTypeProcessSignaled, model.EventTypeProcessExited) {
|
||||
t.Fatalf("expected %s before %s in shutdown events: %#v", model.EventTypeProcessSignaled, model.EventTypeProcessExited, shutdownEvents)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSessionIntegration_RestartProcessEmitsLifecycleEvents(t *testing.T) {
|
||||
addr := freeTCPAddr(t)
|
||||
s, err := StartSession(model.Command{Name: "sh", Args: []string{"-c", "sleep 5"}}, addr)
|
||||
if err != nil {
|
||||
t.Fatalf("StartSession() error = %v", err)
|
||||
}
|
||||
t.Cleanup(func() { s.Stop() })
|
||||
|
||||
_ = collectUntilTypes(t, s.Events, []model.EventType{
|
||||
model.EventTypeProcessStarted,
|
||||
model.EventTypeProxyStarting,
|
||||
}, 3*time.Second)
|
||||
|
||||
if err := s.RestartProcess(); err != nil {
|
||||
t.Fatalf("RestartProcess() error = %v", err)
|
||||
}
|
||||
|
||||
events := collectUntilTypes(t, s.Events, []model.EventType{
|
||||
model.EventTypeProcessRestarting,
|
||||
model.EventTypeProcessSignaled,
|
||||
model.EventTypeProcessExited,
|
||||
model.EventTypeProcessStarting,
|
||||
model.EventTypeProcessStarted,
|
||||
}, 4*time.Second)
|
||||
|
||||
if !isBefore(events, model.EventTypeProcessRestarting, model.EventTypeProcessStarted) {
|
||||
t.Fatalf("expected restarting before process started, got %#v", events)
|
||||
}
|
||||
}
|
||||
@ -4,6 +4,7 @@ import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"os/exec"
|
||||
"sync"
|
||||
"syscall"
|
||||
"time"
|
||||
|
||||
@ -11,6 +12,10 @@ import (
|
||||
"termtap.dev/internal/process"
|
||||
)
|
||||
|
||||
var killEscalationDelay = 1500 * time.Millisecond
|
||||
var scheduleKillEscalation = time.AfterFunc
|
||||
var killEscalationMu sync.RWMutex
|
||||
|
||||
func StartProcess(cmd model.Command, addr string, ch chan<- model.Event) (*model.Process, error) {
|
||||
ch <- model.Event{
|
||||
Time: time.Now().Local(),
|
||||
@ -44,12 +49,16 @@ func StopProcess(proc *model.Process, ch chan<- model.Event, sig syscall.Signal)
|
||||
|
||||
_ = process.SignalProcess(proc.Exec, sig)
|
||||
|
||||
go func() {
|
||||
time.Sleep(1500 * time.Millisecond)
|
||||
killEscalationMu.RLock()
|
||||
delay := killEscalationDelay
|
||||
scheduler := scheduleKillEscalation
|
||||
killEscalationMu.RUnlock()
|
||||
|
||||
scheduler(delay, func() {
|
||||
if process.ProcessAlive(proc.Exec) {
|
||||
_ = process.SignalProcess(proc.Exec, syscall.SIGKILL)
|
||||
}
|
||||
}()
|
||||
})
|
||||
}
|
||||
|
||||
func waitForProcessExit(proc *model.Process, ch chan<- model.Event) {
|
||||
|
||||
223
internal/app/process_test.go
Normal file
223
internal/app/process_test.go
Normal file
@ -0,0 +1,223 @@
|
||||
package app
|
||||
|
||||
import (
|
||||
"os/exec"
|
||||
"syscall"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"termtap.dev/internal/model"
|
||||
)
|
||||
|
||||
func TestStartProcess(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
t.Run("starts process and marks running", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
ch := make(chan model.Event, 32)
|
||||
proc, err := StartProcess(model.Command{Name: "sh", Args: []string{"-c", "sleep 0.2"}}, "127.0.0.1:8080", ch)
|
||||
if err != nil {
|
||||
t.Fatalf("StartProcess() error = %v", err)
|
||||
}
|
||||
t.Cleanup(func() {
|
||||
StopProcess(proc, ch, syscall.SIGTERM)
|
||||
select {
|
||||
case <-proc.Done:
|
||||
case <-time.After(2 * time.Second):
|
||||
t.Fatal("timeout waiting process done in cleanup")
|
||||
}
|
||||
})
|
||||
|
||||
if proc == nil || proc.Exec == nil {
|
||||
t.Fatal("StartProcess() returned nil process/exec")
|
||||
}
|
||||
|
||||
events := drainEvents(t, ch, 2, time.Second)
|
||||
if !hasType(events, model.EventTypeProcessStarting) {
|
||||
t.Fatalf("missing %s event", model.EventTypeProcessStarting)
|
||||
}
|
||||
if !hasType(events, model.EventTypeProcessStarted) {
|
||||
t.Fatalf("missing %s event", model.EventTypeProcessStarted)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("returns error when exec start fails", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
ch := make(chan model.Event, 8)
|
||||
proc, err := StartProcess(model.Command{Name: "definitely-not-a-real-command"}, "127.0.0.1:8080", ch)
|
||||
if err == nil {
|
||||
if proc != nil && proc.Exec != nil && proc.Exec.Process != nil {
|
||||
_ = proc.Exec.Process.Kill()
|
||||
}
|
||||
t.Fatal("StartProcess() error = nil, want non-nil")
|
||||
}
|
||||
|
||||
events := drainEvents(t, ch, 1, time.Second)
|
||||
if !hasType(events, model.EventTypeProcessStarting) {
|
||||
t.Fatalf("missing %s event", model.EventTypeProcessStarting)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestStopProcess(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
t.Run("nil guards", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
StopProcess(nil, make(chan model.Event, 1), syscall.SIGTERM)
|
||||
StopProcess(&model.Process{}, make(chan model.Event, 1), syscall.SIGTERM)
|
||||
StopProcess(&model.Process{Exec: &exec.Cmd{}}, make(chan model.Event, 1), syscall.SIGTERM)
|
||||
})
|
||||
|
||||
t.Run("emits signaled event", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
ch := make(chan model.Event, 32)
|
||||
proc, err := StartProcess(model.Command{Name: "sh", Args: []string{"-c", "sleep 5"}}, "127.0.0.1:8080", ch)
|
||||
if err != nil {
|
||||
t.Fatalf("StartProcess() error = %v", err)
|
||||
}
|
||||
|
||||
StopProcess(proc, ch, syscall.SIGTERM)
|
||||
if _, ok := waitForEventType(t, ch, model.EventTypeProcessSignaled, 2*time.Second); !ok {
|
||||
t.Fatalf("did not receive %s event", model.EventTypeProcessSignaled)
|
||||
}
|
||||
|
||||
select {
|
||||
case <-proc.Done:
|
||||
case <-time.After(3 * time.Second):
|
||||
t.Fatal("timeout waiting for process to exit after signal")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("schedules deterministic kill escalation hook", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
origDelay := killEscalationDelay
|
||||
origScheduler := scheduleKillEscalation
|
||||
defer func() {
|
||||
killEscalationMu.Lock()
|
||||
killEscalationDelay = origDelay
|
||||
scheduleKillEscalation = origScheduler
|
||||
killEscalationMu.Unlock()
|
||||
}()
|
||||
|
||||
ch := make(chan model.Event, 32)
|
||||
proc, err := StartProcess(model.Command{Name: "sh", Args: []string{"-c", "sleep 5"}}, "127.0.0.1:8080", ch)
|
||||
if err != nil {
|
||||
t.Fatalf("StartProcess() error = %v", err)
|
||||
}
|
||||
|
||||
killEscalationMu.Lock()
|
||||
killEscalationDelay = 25 * time.Millisecond
|
||||
scheduled := make(chan time.Duration, 1)
|
||||
scheduleKillEscalation = func(d time.Duration, fn func()) *time.Timer {
|
||||
scheduled <- d
|
||||
go fn()
|
||||
return nil
|
||||
}
|
||||
killEscalationMu.Unlock()
|
||||
|
||||
StopProcess(proc, ch, syscall.SIGTERM)
|
||||
|
||||
select {
|
||||
case d := <-scheduled:
|
||||
if d != killEscalationDelay {
|
||||
t.Fatalf("scheduled delay = %v, want %v", d, killEscalationDelay)
|
||||
}
|
||||
case <-time.After(time.Second):
|
||||
t.Fatal("kill escalation was not scheduled")
|
||||
}
|
||||
|
||||
select {
|
||||
case <-proc.Done:
|
||||
case <-time.After(3 * time.Second):
|
||||
t.Fatal("timeout waiting for process to exit after deterministic escalation")
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestWaitForProcessExit(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
t.Run("nil guards are no-op", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
waitForProcessExit(nil, make(chan model.Event, 1))
|
||||
waitForProcessExit(&model.Process{}, make(chan model.Event, 1))
|
||||
})
|
||||
|
||||
t.Run("normal exit emits process exited and closes done", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
cmd := exec.Command("sh", "-c", "exit 0")
|
||||
if err := cmd.Start(); err != nil {
|
||||
t.Fatalf("cmd.Start() error = %v", err)
|
||||
}
|
||||
|
||||
proc := &model.Process{Exec: cmd, Running: true, Done: make(chan struct{})}
|
||||
ch := make(chan model.Event, 8)
|
||||
|
||||
waitForProcessExit(proc, ch)
|
||||
|
||||
if _, ok := waitForEventType(t, ch, model.EventTypeProcessExited, time.Second); !ok {
|
||||
t.Fatalf("did not receive %s event", model.EventTypeProcessExited)
|
||||
}
|
||||
select {
|
||||
case <-proc.Done:
|
||||
case <-time.After(time.Second):
|
||||
t.Fatal("Done channel was not closed")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("exit error path carries exit code", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
cmd := exec.Command("sh", "-c", "exit 7")
|
||||
if err := cmd.Start(); err != nil {
|
||||
t.Fatalf("cmd.Start() error = %v", err)
|
||||
}
|
||||
|
||||
proc := &model.Process{Exec: cmd, Running: true, Done: make(chan struct{})}
|
||||
ch := make(chan model.Event, 8)
|
||||
|
||||
waitForProcessExit(proc, ch)
|
||||
|
||||
events := drainEvents(t, ch, 2, time.Second)
|
||||
found := false
|
||||
for _, ev := range events {
|
||||
if ev.Type == model.EventTypeProcessExited && ev.ExitCode == 7 {
|
||||
found = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if !found {
|
||||
t.Fatalf("expected ProcessExited with exit code 7, got %#v", events)
|
||||
}
|
||||
select {
|
||||
case <-proc.Done:
|
||||
case <-time.After(time.Second):
|
||||
t.Fatal("Done channel was not closed")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("unexpected wait failure emits fatal", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
proc := &model.Process{Exec: &exec.Cmd{}, Done: make(chan struct{})}
|
||||
ch := make(chan model.Event, 8)
|
||||
|
||||
waitForProcessExit(proc, ch)
|
||||
|
||||
events := drainEvents(t, ch, 1, time.Second)
|
||||
if events[0].Type != model.EventTypeFatal {
|
||||
t.Fatalf("event type = %s, want %s", events[0].Type, model.EventTypeFatal)
|
||||
}
|
||||
select {
|
||||
case <-proc.Done:
|
||||
case <-time.After(time.Second):
|
||||
t.Fatal("Done channel was not closed")
|
||||
}
|
||||
})
|
||||
}
|
||||
183
internal/app/proxy_test.go
Normal file
183
internal/app/proxy_test.go
Normal file
@ -0,0 +1,183 @@
|
||||
package app
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"net"
|
||||
"net/http"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"termtap.dev/internal/model"
|
||||
)
|
||||
|
||||
type staticErrListener struct {
|
||||
err error
|
||||
}
|
||||
|
||||
func (l *staticErrListener) Accept() (net.Conn, error) { return nil, l.err }
|
||||
func (l *staticErrListener) Close() error { return nil }
|
||||
func (l *staticErrListener) Addr() net.Addr {
|
||||
return &net.TCPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 8080}
|
||||
}
|
||||
|
||||
func TestStartProxy_NilGuards(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
StartProxy(nil, make(chan model.Event, 1))
|
||||
StartProxy(&model.ProxyServer{}, make(chan model.Event, 1))
|
||||
StartProxy(&model.ProxyServer{Server: &http.Server{}}, make(chan model.Event, 1))
|
||||
}
|
||||
|
||||
func TestStartProxy_EmitsStartingAndWarnWhenUntrustedCA(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
listenErr := errors.New("accept failed")
|
||||
var ln net.Listener = &staticErrListener{err: listenErr}
|
||||
|
||||
ch := make(chan model.Event, 8)
|
||||
ps := &model.ProxyServer{
|
||||
Listener: &ln,
|
||||
Server: &http.Server{Handler: http.HandlerFunc(func(http.ResponseWriter, *http.Request) {})},
|
||||
CAReady: true,
|
||||
CATrusted: false,
|
||||
CACreated: true,
|
||||
CACertPath: "/tmp/test-ca.pem",
|
||||
}
|
||||
|
||||
StartProxy(ps, ch)
|
||||
|
||||
events := drainEvents(t, ch, 3, time.Second)
|
||||
if !hasType(events, model.EventTypeProxyStarting) {
|
||||
t.Fatalf("missing %s event", model.EventTypeProxyStarting)
|
||||
}
|
||||
if !hasType(events, model.EventTypeWarn) {
|
||||
t.Fatalf("missing %s event", model.EventTypeWarn)
|
||||
}
|
||||
if !containsBody(events, "generated HTTPS interception CA") {
|
||||
t.Fatalf("expected generated-CA warning body, got events: %#v", events)
|
||||
}
|
||||
if !hasType(events, model.EventTypeFatal) {
|
||||
t.Fatalf("missing %s event", model.EventTypeFatal)
|
||||
}
|
||||
}
|
||||
|
||||
func TestStartProxy_WarnBodyForExistingUntrustedCA(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
listenErr := errors.New("accept failed")
|
||||
var ln net.Listener = &staticErrListener{err: listenErr}
|
||||
|
||||
ch := make(chan model.Event, 8)
|
||||
ps := &model.ProxyServer{
|
||||
Listener: &ln,
|
||||
Server: &http.Server{Handler: http.HandlerFunc(func(http.ResponseWriter, *http.Request) {})},
|
||||
CAReady: true,
|
||||
CATrusted: false,
|
||||
CACreated: false,
|
||||
CACertPath: "/tmp/test-ca.pem",
|
||||
}
|
||||
|
||||
StartProxy(ps, ch)
|
||||
|
||||
events := drainEvents(t, ch, 3, time.Second)
|
||||
if !containsBody(events, "HTTPS interception CA available at") {
|
||||
t.Fatalf("expected existing-CA warning body, got events: %#v", events)
|
||||
}
|
||||
}
|
||||
|
||||
func TestStartProxy_NoCAWarningWhenNotReady(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
listenErr := errors.New("accept failed")
|
||||
var ln net.Listener = &staticErrListener{err: listenErr}
|
||||
|
||||
ch := make(chan model.Event, 8)
|
||||
ps := &model.ProxyServer{
|
||||
Listener: &ln,
|
||||
Server: &http.Server{Handler: http.HandlerFunc(func(http.ResponseWriter, *http.Request) {})},
|
||||
CAReady: false,
|
||||
CATrusted: false,
|
||||
}
|
||||
|
||||
StartProxy(ps, ch)
|
||||
|
||||
events := drainEvents(t, ch, 2, time.Second)
|
||||
if !hasType(events, model.EventTypeProxyStarting) {
|
||||
t.Fatalf("missing %s event", model.EventTypeProxyStarting)
|
||||
}
|
||||
if hasType(events, model.EventTypeWarn) {
|
||||
t.Fatalf("unexpected %s event when CA is not ready: %#v", model.EventTypeWarn, events)
|
||||
}
|
||||
if !hasType(events, model.EventTypeFatal) {
|
||||
t.Fatalf("missing %s event", model.EventTypeFatal)
|
||||
}
|
||||
}
|
||||
|
||||
func TestStartProxy_NoCAWarningWhenTrusted(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
listenErr := errors.New("accept failed")
|
||||
var ln net.Listener = &staticErrListener{err: listenErr}
|
||||
|
||||
ch := make(chan model.Event, 8)
|
||||
ps := &model.ProxyServer{
|
||||
Listener: &ln,
|
||||
Server: &http.Server{Handler: http.HandlerFunc(func(http.ResponseWriter, *http.Request) {})},
|
||||
CAReady: true,
|
||||
CATrusted: true,
|
||||
}
|
||||
|
||||
StartProxy(ps, ch)
|
||||
|
||||
events := drainEvents(t, ch, 2, time.Second)
|
||||
if !hasType(events, model.EventTypeProxyStarting) {
|
||||
t.Fatalf("missing %s event", model.EventTypeProxyStarting)
|
||||
}
|
||||
if hasType(events, model.EventTypeWarn) {
|
||||
t.Fatalf("unexpected %s event when CA is already trusted: %#v", model.EventTypeWarn, events)
|
||||
}
|
||||
if !hasType(events, model.EventTypeFatal) {
|
||||
t.Fatalf("missing %s event", model.EventTypeFatal)
|
||||
}
|
||||
}
|
||||
|
||||
func TestStartProxy_SwallowsErrServerClosed(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
ln, err := net.Listen("tcp", "127.0.0.1:0")
|
||||
if err != nil {
|
||||
t.Fatalf("listen error = %v", err)
|
||||
}
|
||||
t.Cleanup(func() { _ = ln.Close() })
|
||||
|
||||
ch := make(chan model.Event, 8)
|
||||
server := &http.Server{Handler: http.HandlerFunc(func(http.ResponseWriter, *http.Request) {})}
|
||||
ps := &model.ProxyServer{Listener: &ln, Server: server}
|
||||
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
StartProxy(ps, ch)
|
||||
close(done)
|
||||
}()
|
||||
|
||||
shutdownCtx, cancel := context.WithTimeout(context.Background(), time.Second)
|
||||
defer cancel()
|
||||
if err := server.Shutdown(shutdownCtx); err != nil {
|
||||
t.Fatalf("shutdown error = %v", err)
|
||||
}
|
||||
|
||||
select {
|
||||
case <-done:
|
||||
case <-time.After(time.Second):
|
||||
t.Fatal("timeout waiting for StartProxy to return")
|
||||
}
|
||||
|
||||
events := drainEvents(t, ch, 1, time.Second)
|
||||
if !hasType(events, model.EventTypeProxyStarting) {
|
||||
t.Fatalf("missing %s event", model.EventTypeProxyStarting)
|
||||
}
|
||||
if hasType(events, model.EventTypeFatal) {
|
||||
t.Fatalf("unexpected %s event", model.EventTypeFatal)
|
||||
}
|
||||
}
|
||||
247
internal/app/session_test.go
Normal file
247
internal/app/session_test.go
Normal file
@ -0,0 +1,247 @@
|
||||
package app
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"net"
|
||||
"syscall"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"termtap.dev/internal/model"
|
||||
)
|
||||
|
||||
func TestStartSession(t *testing.T) {
|
||||
t.Run("happy path creates proxy and process", func(t *testing.T) {
|
||||
addr := freeTCPAddr(t)
|
||||
s, err := StartSession(model.Command{Name: "sh", Args: []string{"-c", "sleep 5"}}, addr)
|
||||
if err != nil {
|
||||
t.Fatalf("StartSession() error = %v", err)
|
||||
}
|
||||
if s == nil {
|
||||
t.Fatal("StartSession() returned nil session")
|
||||
}
|
||||
if s.proxy == nil {
|
||||
t.Fatal("session.proxy is nil")
|
||||
}
|
||||
if s.proc == nil {
|
||||
t.Fatal("session.proc is nil")
|
||||
}
|
||||
|
||||
s.Stop()
|
||||
if s.proc != nil && s.proc.Done != nil {
|
||||
select {
|
||||
case <-s.proc.Done:
|
||||
case <-time.After(4 * time.Second):
|
||||
t.Fatal("timeout waiting for process stop")
|
||||
}
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("error when proxy creation fails", func(t *testing.T) {
|
||||
t.Setenv("XDG_CONFIG_HOME", "")
|
||||
t.Setenv("HOME", "")
|
||||
|
||||
s, err := StartSession(model.Command{Name: "sh", Args: []string{"-c", "true"}}, "127.0.0.1:0")
|
||||
if err == nil {
|
||||
if s != nil {
|
||||
s.Stop()
|
||||
}
|
||||
t.Fatal("StartSession() error = nil, want non-nil")
|
||||
}
|
||||
if s != nil {
|
||||
t.Fatalf("session = %#v, want nil", s)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("destroys proxy when process startup fails", func(t *testing.T) {
|
||||
addr := freeTCPAddr(t)
|
||||
|
||||
s, err := StartSession(model.Command{Name: "definitely-not-a-real-command"}, addr)
|
||||
if err == nil {
|
||||
if s != nil {
|
||||
s.Stop()
|
||||
}
|
||||
t.Fatal("StartSession() error = nil, want non-nil")
|
||||
}
|
||||
|
||||
deadline := time.After(3 * time.Second)
|
||||
ticker := time.NewTicker(25 * time.Millisecond)
|
||||
defer ticker.Stop()
|
||||
for {
|
||||
ln, listenErr := net.Listen("tcp", addr)
|
||||
if listenErr == nil {
|
||||
_ = ln.Close()
|
||||
break
|
||||
}
|
||||
select {
|
||||
case <-deadline:
|
||||
t.Fatalf("address %s did not become reusable, got err: %v", addr, listenErr)
|
||||
case <-ticker.C:
|
||||
}
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("stop during restart stops new process and returns ErrSessionStopped", func(t *testing.T) {
|
||||
s := &Session{
|
||||
ch: make(chan model.Event, 64),
|
||||
cmd: model.Command{Name: "sh", Args: []string{"-c", "sleep 5"}},
|
||||
addr: "127.0.0.1:0",
|
||||
proc: &model.Process{Done: make(chan struct{})},
|
||||
}
|
||||
close(s.proc.Done)
|
||||
|
||||
s.restartMu.Lock()
|
||||
errCh := make(chan error, 1)
|
||||
stopDone := make(chan struct{})
|
||||
|
||||
go func() {
|
||||
errCh <- s.RestartProcess()
|
||||
}()
|
||||
go func() {
|
||||
s.Stop()
|
||||
close(stopDone)
|
||||
}()
|
||||
|
||||
s.restartMu.Unlock()
|
||||
|
||||
select {
|
||||
case err := <-errCh:
|
||||
if !errors.Is(err, ErrSessionStopped) {
|
||||
t.Fatalf("error = %v, want %v", err, ErrSessionStopped)
|
||||
}
|
||||
case <-time.After(6 * time.Second):
|
||||
t.Fatal("timeout waiting for RestartProcess result")
|
||||
}
|
||||
|
||||
select {
|
||||
case <-stopDone:
|
||||
case <-time.After(time.Second):
|
||||
t.Fatal("timeout waiting for Stop completion")
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestSessionStop_Idempotent(t *testing.T) {
|
||||
addr := freeTCPAddr(t)
|
||||
s, err := StartSession(model.Command{Name: "sh", Args: []string{"-c", "sleep 5"}}, addr)
|
||||
if err != nil {
|
||||
t.Fatalf("StartSession() error = %v", err)
|
||||
}
|
||||
|
||||
s.Stop()
|
||||
s.Stop()
|
||||
|
||||
if s.proc != nil && s.proc.Done != nil {
|
||||
select {
|
||||
case <-s.proc.Done:
|
||||
case <-time.After(4 * time.Second):
|
||||
t.Fatal("timeout waiting for process to stop")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestRestartProcess(t *testing.T) {
|
||||
t.Run("nil session returns error", func(t *testing.T) {
|
||||
var s *Session
|
||||
err := s.RestartProcess()
|
||||
if err == nil {
|
||||
t.Fatal("RestartProcess() error = nil, want non-nil")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("stopped session returns ErrSessionStopped", func(t *testing.T) {
|
||||
s := &Session{stopped: true}
|
||||
err := s.RestartProcess()
|
||||
if !errors.Is(err, ErrSessionStopped) {
|
||||
t.Fatalf("error = %v, want %v", err, ErrSessionStopped)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("concurrent restart returns ErrRestartInProgress", func(t *testing.T) {
|
||||
s := &Session{restarting: true}
|
||||
err := s.RestartProcess()
|
||||
if !errors.Is(err, ErrRestartInProgress) {
|
||||
t.Fatalf("error = %v, want %v", err, ErrRestartInProgress)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("timeout waiting for stop returns error", func(t *testing.T) {
|
||||
s := &Session{
|
||||
ch: make(chan model.Event, 8),
|
||||
cmd: model.Command{Name: "sh", Args: []string{"-c", "true"}},
|
||||
addr: "127.0.0.1:0",
|
||||
proc: &model.Process{Done: make(chan struct{})},
|
||||
}
|
||||
|
||||
err := s.RestartProcess()
|
||||
if err == nil {
|
||||
t.Fatal("RestartProcess() error = nil, want non-nil")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("successful restart swaps process", func(t *testing.T) {
|
||||
addr := freeTCPAddr(t)
|
||||
oldProc := &model.Process{Done: make(chan struct{})}
|
||||
close(oldProc.Done)
|
||||
|
||||
s := &Session{
|
||||
ch: make(chan model.Event, 32),
|
||||
cmd: model.Command{Name: "sh", Args: []string{"-c", "sleep 5"}},
|
||||
addr: addr,
|
||||
proc: oldProc,
|
||||
}
|
||||
|
||||
err := s.RestartProcess()
|
||||
if err != nil {
|
||||
t.Fatalf("RestartProcess() error = %v", err)
|
||||
}
|
||||
if _, ok := waitForEventType(t, s.ch, model.EventTypeProcessRestarting, time.Second); !ok {
|
||||
t.Fatalf("did not receive %s event", model.EventTypeProcessRestarting)
|
||||
}
|
||||
if s.proc == nil {
|
||||
t.Fatal("session process is nil after restart")
|
||||
}
|
||||
if s.proc == oldProc {
|
||||
t.Fatal("session process pointer was not swapped")
|
||||
}
|
||||
|
||||
StopProcess(s.proc, s.ch, syscall.SIGTERM)
|
||||
select {
|
||||
case <-s.proc.Done:
|
||||
case <-time.After(4 * time.Second):
|
||||
t.Fatal("timeout waiting for restarted process to stop")
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestWaitForProcessStop(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
t.Run("nil process and nil done are true", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
if !waitForProcessStop(nil, time.Millisecond) {
|
||||
t.Fatal("waitForProcessStop(nil) = false, want true")
|
||||
}
|
||||
if !waitForProcessStop(&model.Process{}, time.Millisecond) {
|
||||
t.Fatal("waitForProcessStop(process with nil done) = false, want true")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("done channel closes before timeout", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
proc := &model.Process{Done: make(chan struct{})}
|
||||
close(proc.Done)
|
||||
|
||||
if !waitForProcessStop(proc, time.Second) {
|
||||
t.Fatal("waitForProcessStop() = false, want true")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("timeout returns false", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
proc := &model.Process{Done: make(chan struct{})}
|
||||
if waitForProcessStop(proc, 20*time.Millisecond) {
|
||||
t.Fatal("waitForProcessStop() = true, want false")
|
||||
}
|
||||
})
|
||||
}
|
||||
111
internal/app/test_helpers_test.go
Normal file
111
internal/app/test_helpers_test.go
Normal file
@ -0,0 +1,111 @@
|
||||
package app
|
||||
|
||||
import (
|
||||
"net"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"termtap.dev/internal/model"
|
||||
)
|
||||
|
||||
func drainEvents(t *testing.T, ch <-chan model.Event, n int, timeout time.Duration) []model.Event {
|
||||
t.Helper()
|
||||
|
||||
out := make([]model.Event, 0, n)
|
||||
deadline := time.After(timeout)
|
||||
for len(out) < n {
|
||||
select {
|
||||
case ev := <-ch:
|
||||
out = append(out, ev)
|
||||
case <-deadline:
|
||||
t.Fatalf("timeout waiting for %d events, got %d", n, len(out))
|
||||
}
|
||||
}
|
||||
|
||||
return out
|
||||
}
|
||||
|
||||
func hasType(events []model.Event, typ model.EventType) bool {
|
||||
for _, ev := range events {
|
||||
if ev.Type == typ {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func containsBody(events []model.Event, part string) bool {
|
||||
for _, ev := range events {
|
||||
if strings.Contains(ev.Body, part) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func waitForEventType(t *testing.T, ch <-chan model.Event, typ model.EventType, timeout time.Duration) (model.Event, bool) {
|
||||
t.Helper()
|
||||
|
||||
deadline := time.After(timeout)
|
||||
for {
|
||||
select {
|
||||
case ev := <-ch:
|
||||
if ev.Type == typ {
|
||||
return ev, true
|
||||
}
|
||||
case <-deadline:
|
||||
return model.Event{}, false
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func collectUntilTypes(t *testing.T, ch <-chan model.Event, required []model.EventType, timeout time.Duration) []model.Event {
|
||||
t.Helper()
|
||||
|
||||
need := make(map[model.EventType]bool, len(required))
|
||||
for _, typ := range required {
|
||||
need[typ] = true
|
||||
}
|
||||
|
||||
events := make([]model.Event, 0, len(required)+8)
|
||||
deadline := time.After(timeout)
|
||||
for len(need) > 0 {
|
||||
select {
|
||||
case ev := <-ch:
|
||||
events = append(events, ev)
|
||||
delete(need, ev.Type)
|
||||
case <-deadline:
|
||||
t.Fatalf("timeout waiting for required events: remaining=%v, collected=%#v", need, events)
|
||||
}
|
||||
}
|
||||
|
||||
return events
|
||||
}
|
||||
|
||||
func isBefore(events []model.Event, first, second model.EventType) bool {
|
||||
firstIdx := -1
|
||||
secondIdx := -1
|
||||
for i, ev := range events {
|
||||
if ev.Type == first && firstIdx == -1 {
|
||||
firstIdx = i
|
||||
}
|
||||
if ev.Type == second && secondIdx == -1 {
|
||||
secondIdx = i
|
||||
}
|
||||
}
|
||||
|
||||
return firstIdx >= 0 && secondIdx >= 0 && firstIdx < secondIdx
|
||||
}
|
||||
|
||||
func freeTCPAddr(t *testing.T) string {
|
||||
t.Helper()
|
||||
|
||||
ln, err := net.Listen("tcp", "127.0.0.1:0")
|
||||
if err != nil {
|
||||
t.Fatalf("listen error = %v", err)
|
||||
}
|
||||
addr := ln.Addr().String()
|
||||
_ = ln.Close()
|
||||
return addr
|
||||
}
|
||||
@ -2,6 +2,7 @@ package cli
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"io"
|
||||
"log"
|
||||
"os"
|
||||
"runtime"
|
||||
@ -16,6 +17,23 @@ import (
|
||||
// This should be configurable at some point, just in case they build on 8080
|
||||
const proxy_addr = "127.0.0.1:8080"
|
||||
|
||||
var fatalExit = log.Fatalln
|
||||
var stdoutWriter io.Writer = stdioRef{isErr: false}
|
||||
var stderrWriter io.Writer = stdioRef{isErr: true}
|
||||
var startSessionFn = app.StartSession
|
||||
var runTUIFn = tui.Run
|
||||
|
||||
type stdioRef struct {
|
||||
isErr bool
|
||||
}
|
||||
|
||||
func (w stdioRef) Write(p []byte) (int, error) {
|
||||
if w.isErr {
|
||||
return os.Stderr.Write(p)
|
||||
}
|
||||
return os.Stdout.Write(p)
|
||||
}
|
||||
|
||||
func Run(args []string) {
|
||||
if len(args) >= 2 && args[1] == "cert" {
|
||||
runCert()
|
||||
@ -28,9 +46,10 @@ func Run(args []string) {
|
||||
return
|
||||
}
|
||||
|
||||
session, err := app.StartSession(cmd, proxy_addr)
|
||||
session, err := startSessionFn(cmd, proxy_addr)
|
||||
if err != nil {
|
||||
log.Fatalln(err)
|
||||
fatalExit(err)
|
||||
return
|
||||
}
|
||||
defer session.Stop()
|
||||
|
||||
@ -38,8 +57,9 @@ func Run(args []string) {
|
||||
Restart: session.RestartProcess,
|
||||
}
|
||||
|
||||
if err := tui.Run(session.Events, controls); err != nil {
|
||||
log.Fatalln(err)
|
||||
if err := runTUIFn(session.Events, controls); err != nil {
|
||||
fatalExit(err)
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
@ -67,49 +87,50 @@ usage:
|
||||
tap run -- <command> [args...]
|
||||
`
|
||||
|
||||
fmt.Fprintln(os.Stderr, helpText)
|
||||
fmt.Fprintln(stderrWriter, helpText)
|
||||
}
|
||||
|
||||
func runCert() {
|
||||
ca, err := proxy.EnsureCertificateAuthority()
|
||||
if err != nil {
|
||||
log.Fatalln(err)
|
||||
fatalExit(err)
|
||||
return
|
||||
}
|
||||
|
||||
certPath := ca.CertPath()
|
||||
quotedCertPath := strconv.Quote(certPath)
|
||||
fmt.Printf("Certificate path: %s\n", certPath)
|
||||
fmt.Fprintf(stdoutWriter, "Certificate path: %s\n", certPath)
|
||||
if ca.WasCreated() {
|
||||
fmt.Println("Created a new local HTTPS interception CA.")
|
||||
fmt.Fprintln(stdoutWriter, "Created a new local HTTPS interception CA.")
|
||||
} else {
|
||||
fmt.Println("Using existing local HTTPS interception CA.")
|
||||
fmt.Fprintln(stdoutWriter, "Using existing local HTTPS interception CA.")
|
||||
}
|
||||
|
||||
trusted, err := ca.IsTrustedBySystem()
|
||||
if err != nil {
|
||||
fmt.Printf("System trust check failed: %v\n", err)
|
||||
fmt.Fprintf(stdoutWriter, "System trust check failed: %v\n", err)
|
||||
} else if trusted {
|
||||
fmt.Println("System trust store: trusted")
|
||||
fmt.Fprintln(stdoutWriter, "System trust store: trusted")
|
||||
} else {
|
||||
fmt.Println("System trust store: not trusted")
|
||||
fmt.Fprintln(stdoutWriter, "System trust store: not trusted")
|
||||
}
|
||||
|
||||
if runtime.GOOS != "linux" {
|
||||
fmt.Println("Install this certificate into your OS or client trust store to inspect HTTPS traffic.")
|
||||
fmt.Fprintln(stdoutWriter, "Install this certificate into your OS or client trust store to inspect HTTPS traffic.")
|
||||
return
|
||||
}
|
||||
|
||||
fmt.Println()
|
||||
fmt.Println("Trust instructions (Linux):")
|
||||
fmt.Println("Debian/Ubuntu:")
|
||||
fmt.Printf(" sudo cp %s /usr/local/share/ca-certificates/termtap.crt\n", quotedCertPath)
|
||||
fmt.Println(" sudo update-ca-certificates")
|
||||
fmt.Println("Fedora/RHEL/CentOS:")
|
||||
fmt.Printf(" sudo cp %s /etc/pki/ca-trust/source/anchors/termtap.crt\n", quotedCertPath)
|
||||
fmt.Println(" sudo update-ca-trust")
|
||||
fmt.Println("Arch:")
|
||||
fmt.Printf(" sudo trust anchor %s\n", quotedCertPath)
|
||||
fmt.Println()
|
||||
fmt.Println("Quick curl test:")
|
||||
fmt.Printf(" curl --proxy http://%s --cacert %s https://example.com\n", proxy_addr, quotedCertPath)
|
||||
fmt.Fprintln(stdoutWriter)
|
||||
fmt.Fprintln(stdoutWriter, "Trust instructions (Linux):")
|
||||
fmt.Fprintln(stdoutWriter, "Debian/Ubuntu:")
|
||||
fmt.Fprintf(stdoutWriter, " sudo cp %s /usr/local/share/ca-certificates/termtap.crt\n", quotedCertPath)
|
||||
fmt.Fprintln(stdoutWriter, " sudo update-ca-certificates")
|
||||
fmt.Fprintln(stdoutWriter, "Fedora/RHEL/CentOS:")
|
||||
fmt.Fprintf(stdoutWriter, " sudo cp %s /etc/pki/ca-trust/source/anchors/termtap.crt\n", quotedCertPath)
|
||||
fmt.Fprintln(stdoutWriter, " sudo update-ca-trust")
|
||||
fmt.Fprintln(stdoutWriter, "Arch:")
|
||||
fmt.Fprintf(stdoutWriter, " sudo trust anchor %s\n", quotedCertPath)
|
||||
fmt.Fprintln(stdoutWriter)
|
||||
fmt.Fprintln(stdoutWriter, "Quick curl test:")
|
||||
fmt.Fprintf(stdoutWriter, " curl --proxy http://%s --cacert %s https://example.com\n", proxy_addr, quotedCertPath)
|
||||
}
|
||||
|
||||
356
internal/cli/run_test.go
Normal file
356
internal/cli/run_test.go
Normal file
@ -0,0 +1,356 @@
|
||||
package cli
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"errors"
|
||||
"io"
|
||||
"os"
|
||||
"runtime"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"termtap.dev/internal/app"
|
||||
"termtap.dev/internal/model"
|
||||
"termtap.dev/internal/tui"
|
||||
)
|
||||
|
||||
func TestParseCommand(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
args []string
|
||||
ok bool
|
||||
nameWant string
|
||||
argsWant []string
|
||||
}{
|
||||
{name: "too few args", args: []string{"tap"}, ok: false},
|
||||
{name: "missing run token", args: []string{"tap", "oops", "--", "echo"}, ok: false},
|
||||
{name: "missing separator", args: []string{"tap", "run", "echo"}, ok: false},
|
||||
{name: "single command", args: []string{"tap", "run", "--", "echo"}, ok: true, nameWant: "echo", argsWant: []string{}},
|
||||
{name: "command with args", args: []string{"tap", "run", "--", "curl", "-s", "https://example.com"}, ok: true, nameWant: "curl", argsWant: []string{"-s", "https://example.com"}},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
tt := tt
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
cmd, ok := parseCommand(tt.args)
|
||||
if ok != tt.ok {
|
||||
t.Fatalf("ok = %v, want %v", ok, tt.ok)
|
||||
}
|
||||
if !tt.ok {
|
||||
return
|
||||
}
|
||||
if cmd.Name != tt.nameWant {
|
||||
t.Fatalf("cmd.Name = %q, want %q", cmd.Name, tt.nameWant)
|
||||
}
|
||||
if strings.Join(cmd.Args, "|") != strings.Join(tt.argsWant, "|") {
|
||||
t.Fatalf("cmd.Args = %#v, want %#v", cmd.Args, tt.argsWant)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestDisplayHelpWritesToStderr(t *testing.T) {
|
||||
_, stderr := captureOutput(t, func() {
|
||||
displayHelp()
|
||||
})
|
||||
|
||||
if !strings.Contains(stderr, "tap cert") || !strings.Contains(stderr, "tap run --") {
|
||||
t.Fatalf("stderr missing usage text: %q", stderr)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRun_InvalidCommandShowsHelp(t *testing.T) {
|
||||
_, stderr := captureOutput(t, func() {
|
||||
Run([]string{"tap", "wat"})
|
||||
})
|
||||
|
||||
if !strings.Contains(stderr, "usage:") {
|
||||
t.Fatalf("stderr missing usage output: %q", stderr)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRun_RoutesCertCommand(t *testing.T) {
|
||||
configRoot := t.TempDir()
|
||||
t.Setenv("XDG_CONFIG_HOME", configRoot)
|
||||
|
||||
stdout, _ := captureOutput(t, func() {
|
||||
Run([]string{"tap", "cert"})
|
||||
})
|
||||
|
||||
if !strings.Contains(stdout, "Certificate path:") {
|
||||
t.Fatalf("stdout missing certificate path output: %q", stdout)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRun_StartSessionFailureCallsFatalExit(t *testing.T) {
|
||||
restore := installRunSeams(t)
|
||||
defer restore()
|
||||
|
||||
startSessionFn = func(model.Command, string) (*app.Session, error) {
|
||||
return nil, errors.New("boom")
|
||||
}
|
||||
called := installFatalSpy(t)
|
||||
|
||||
Run([]string{"tap", "run", "--", "definitely-not-a-real-command"})
|
||||
if !*called {
|
||||
t.Fatal("expected fatalExit to be called on StartSession failure")
|
||||
}
|
||||
}
|
||||
|
||||
func TestRun_TUIFailureCallsFatalExit(t *testing.T) {
|
||||
restore := installRunSeams(t)
|
||||
defer restore()
|
||||
|
||||
startSessionFn = func(model.Command, string) (*app.Session, error) {
|
||||
return &app.Session{Events: make(chan model.Event)}, nil
|
||||
}
|
||||
runTUIFn = func(<-chan model.Event, tui.Controls) error {
|
||||
return errors.New("tui failed")
|
||||
}
|
||||
called := installFatalSpy(t)
|
||||
|
||||
Run([]string{"tap", "run", "--", "echo"})
|
||||
if !*called {
|
||||
t.Fatal("expected fatalExit to be called on tui failure")
|
||||
}
|
||||
}
|
||||
|
||||
func TestRun_SuccessPathDoesNotCallFatal(t *testing.T) {
|
||||
restore := installRunSeams(t)
|
||||
defer restore()
|
||||
|
||||
startSessionFn = func(model.Command, string) (*app.Session, error) {
|
||||
return &app.Session{Events: make(chan model.Event)}, nil
|
||||
}
|
||||
runTUIFn = func(<-chan model.Event, tui.Controls) error {
|
||||
return nil
|
||||
}
|
||||
called := installFatalSpy(t)
|
||||
|
||||
Run([]string{"tap", "run", "--", "echo"})
|
||||
if *called {
|
||||
t.Fatal("fatalExit should not be called on success path")
|
||||
}
|
||||
}
|
||||
|
||||
func TestRunCert_EnsureCAFailureCallsFatalExit(t *testing.T) {
|
||||
t.Setenv("XDG_CONFIG_HOME", "")
|
||||
t.Setenv("HOME", "")
|
||||
called := installFatalSpy(t)
|
||||
|
||||
runCert()
|
||||
if !*called {
|
||||
t.Fatal("expected fatalExit to be called when EnsureCertificateAuthority fails")
|
||||
}
|
||||
}
|
||||
|
||||
func TestRunCertOutputContract(t *testing.T) {
|
||||
configRoot := t.TempDir()
|
||||
t.Setenv("XDG_CONFIG_HOME", configRoot)
|
||||
|
||||
stdout, _ := captureOutput(t, func() {
|
||||
runCert()
|
||||
})
|
||||
|
||||
if !strings.Contains(stdout, "Certificate path:") {
|
||||
t.Fatalf("stdout missing certificate path line: %q", stdout)
|
||||
}
|
||||
if !strings.Contains(stdout, "local HTTPS interception CA") {
|
||||
t.Fatalf("stdout missing CA create/existing line: %q", stdout)
|
||||
}
|
||||
if !strings.Contains(stdout, "System trust store:") && !strings.Contains(stdout, "System trust check failed:") {
|
||||
t.Fatalf("stdout missing trust check line: %q", stdout)
|
||||
}
|
||||
|
||||
if runtime.GOOS == "linux" {
|
||||
if !strings.Contains(stdout, "Trust instructions (Linux):") {
|
||||
t.Fatalf("stdout missing linux trust instructions: %q", stdout)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestRunCert_CreatedThenExistingMessage(t *testing.T) {
|
||||
configRoot := t.TempDir()
|
||||
t.Setenv("XDG_CONFIG_HOME", configRoot)
|
||||
|
||||
firstOut, _ := captureOutput(t, func() {
|
||||
runCert()
|
||||
})
|
||||
if !strings.Contains(firstOut, "Created a new local HTTPS interception CA.") {
|
||||
t.Fatalf("first run should indicate created CA, got: %q", firstOut)
|
||||
}
|
||||
|
||||
secondOut, _ := captureOutput(t, func() {
|
||||
runCert()
|
||||
})
|
||||
if !strings.Contains(secondOut, "Using existing local HTTPS interception CA.") {
|
||||
t.Fatalf("second run should indicate existing CA, got: %q", secondOut)
|
||||
}
|
||||
}
|
||||
|
||||
func captureOutput(t *testing.T, fn func()) (stdout string, stderr string) {
|
||||
t.Helper()
|
||||
|
||||
origStdoutWriter := stdoutWriter
|
||||
origStderrWriter := stderrWriter
|
||||
t.Cleanup(func() {
|
||||
stdoutWriter = origStdoutWriter
|
||||
stderrWriter = origStderrWriter
|
||||
})
|
||||
|
||||
origStdout := os.Stdout
|
||||
origStderr := os.Stderr
|
||||
|
||||
outR, outW, err := os.Pipe()
|
||||
if err != nil {
|
||||
t.Fatalf("stdout pipe error: %v", err)
|
||||
}
|
||||
errR, errW, err := os.Pipe()
|
||||
if err != nil {
|
||||
_ = outR.Close()
|
||||
_ = outW.Close()
|
||||
t.Fatalf("stderr pipe error: %v", err)
|
||||
}
|
||||
|
||||
os.Stdout = outW
|
||||
os.Stderr = errW
|
||||
stdoutWriter = outW
|
||||
stderrWriter = errW
|
||||
|
||||
outCh := make(chan string, 1)
|
||||
errCh := make(chan string, 1)
|
||||
go func() {
|
||||
var buf bytes.Buffer
|
||||
_, _ = io.Copy(&buf, outR)
|
||||
outCh <- buf.String()
|
||||
}()
|
||||
go func() {
|
||||
var buf bytes.Buffer
|
||||
_, _ = io.Copy(&buf, errR)
|
||||
errCh <- buf.String()
|
||||
}()
|
||||
|
||||
fn()
|
||||
|
||||
_ = outW.Close()
|
||||
_ = errW.Close()
|
||||
stdoutWriter = origStdoutWriter
|
||||
stderrWriter = origStderrWriter
|
||||
os.Stdout = origStdout
|
||||
os.Stderr = origStderr
|
||||
|
||||
select {
|
||||
case stdout = <-outCh:
|
||||
case <-time.After(2 * time.Second):
|
||||
t.Fatal("timeout waiting for stdout capture")
|
||||
}
|
||||
select {
|
||||
case stderr = <-errCh:
|
||||
case <-time.After(2 * time.Second):
|
||||
t.Fatal("timeout waiting for stderr capture")
|
||||
}
|
||||
_ = outR.Close()
|
||||
_ = errR.Close()
|
||||
|
||||
return stdout, stderr
|
||||
}
|
||||
|
||||
func TestDisplayHelp_UsesInjectedStderrWriter(t *testing.T) {
|
||||
var buf bytes.Buffer
|
||||
orig := stderrWriter
|
||||
t.Cleanup(func() { stderrWriter = orig })
|
||||
stderrWriter = &buf
|
||||
|
||||
displayHelp()
|
||||
|
||||
if got := buf.String(); !strings.Contains(got, "usage:") {
|
||||
t.Fatalf("help output missing usage, got: %q", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRunCert_UsesInjectedStdoutWriter(t *testing.T) {
|
||||
configRoot := t.TempDir()
|
||||
t.Setenv("XDG_CONFIG_HOME", configRoot)
|
||||
|
||||
var buf bytes.Buffer
|
||||
orig := stdoutWriter
|
||||
t.Cleanup(func() { stdoutWriter = orig })
|
||||
stdoutWriter = &buf
|
||||
|
||||
runCert()
|
||||
|
||||
if got := buf.String(); !strings.Contains(got, "Certificate path:") {
|
||||
t.Fatalf("cert output missing path line, got: %q", got)
|
||||
}
|
||||
}
|
||||
|
||||
func installRunSeams(t *testing.T) func() {
|
||||
t.Helper()
|
||||
|
||||
origStartSession := startSessionFn
|
||||
origRunTUI := runTUIFn
|
||||
return func() {
|
||||
startSessionFn = origStartSession
|
||||
runTUIFn = origRunTUI
|
||||
}
|
||||
}
|
||||
|
||||
func installFatalSpy(t *testing.T) *bool {
|
||||
t.Helper()
|
||||
|
||||
origFatal := fatalExit
|
||||
called := false
|
||||
fatalExit = func(v ...any) {
|
||||
called = true
|
||||
}
|
||||
t.Cleanup(func() {
|
||||
fatalExit = origFatal
|
||||
})
|
||||
|
||||
return &called
|
||||
}
|
||||
|
||||
func TestStdioRefWrite(t *testing.T) {
|
||||
t.Run("writes to stdout", func(t *testing.T) {
|
||||
assertStdioRefWrite(t, false, "hello")
|
||||
})
|
||||
|
||||
t.Run("writes to stderr", func(t *testing.T) {
|
||||
assertStdioRefWrite(t, true, "boom")
|
||||
})
|
||||
}
|
||||
|
||||
func assertStdioRefWrite(t *testing.T, isErr bool, payload string) {
|
||||
t.Helper()
|
||||
|
||||
r, w, err := os.Pipe()
|
||||
if err != nil {
|
||||
t.Fatalf("pipe error: %v", err)
|
||||
}
|
||||
defer func() { _ = r.Close() }()
|
||||
|
||||
if isErr {
|
||||
orig := os.Stderr
|
||||
os.Stderr = w
|
||||
defer func() { os.Stderr = orig }()
|
||||
} else {
|
||||
orig := os.Stdout
|
||||
os.Stdout = w
|
||||
defer func() { os.Stdout = orig }()
|
||||
}
|
||||
|
||||
if _, err := (stdioRef{isErr: isErr}).Write([]byte(payload)); err != nil {
|
||||
_ = w.Close()
|
||||
t.Fatalf("stdioRef write error: %v", err)
|
||||
}
|
||||
_ = w.Close()
|
||||
|
||||
data, err := io.ReadAll(r)
|
||||
if err != nil {
|
||||
t.Fatalf("ReadAll(pipe) error: %v", err)
|
||||
}
|
||||
if got := string(data); got != payload {
|
||||
t.Fatalf("pipe write = %q, want %q", got, payload)
|
||||
}
|
||||
}
|
||||
219
internal/process/runner_test.go
Normal file
219
internal/process/runner_test.go
Normal file
@ -0,0 +1,219 @@
|
||||
package process
|
||||
|
||||
import (
|
||||
"os"
|
||||
"os/exec"
|
||||
"reflect"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"termtap.dev/internal/model"
|
||||
)
|
||||
|
||||
func TestCommandString(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
cmd model.Command
|
||||
want string
|
||||
}{
|
||||
{
|
||||
name: "empty args",
|
||||
cmd: model.Command{Name: "go", Args: []string{}},
|
||||
want: "go ",
|
||||
},
|
||||
{
|
||||
name: "multiple args",
|
||||
cmd: model.Command{Name: "curl", Args: []string{"-s", "https://example.com"}},
|
||||
want: "curl -s https://example.com",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
tt := tt
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
if got := CommandString(tt.cmd); got != tt.want {
|
||||
t.Fatalf("CommandString() = %q, want %q", got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestInjectEnv(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
cmd := exec.Command("sh", "-c", "true")
|
||||
injectEnv(cmd, "127.0.0.1:8080")
|
||||
|
||||
mustContain := []string{
|
||||
"HTTP_PROXY=http://127.0.0.1:8080",
|
||||
"http_proxy=http://127.0.0.1:8080",
|
||||
"HTTPS_PROXY=http://127.0.0.1:8080",
|
||||
"https_proxy=http://127.0.0.1:8080",
|
||||
"NO_PROXY=",
|
||||
"no_proxy=",
|
||||
}
|
||||
|
||||
for _, kv := range mustContain {
|
||||
if !containsEnvEntry(cmd.Env, kv) {
|
||||
t.Fatalf("injectEnv() missing env entry %q", kv)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestReadPipe(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
t.Run("stdout lines emit stdout events", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
ch := make(chan model.Event, 4)
|
||||
input := strings.NewReader("line1\nline2\n")
|
||||
|
||||
readPipe(input, model.EventTypeProcessStdout, ch)
|
||||
|
||||
events := drainEvents(t, ch, 2, time.Second)
|
||||
if events[0].Type != model.EventTypeProcessStdout || events[0].Body != "line1" {
|
||||
t.Fatalf("event[0] = %#v, want stdout line1", events[0])
|
||||
}
|
||||
if events[1].Type != model.EventTypeProcessStdout || events[1].Body != "line2" {
|
||||
t.Fatalf("event[1] = %#v, want stdout line2", events[1])
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("stderr lines emit stderr events", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
ch := make(chan model.Event, 4)
|
||||
input := strings.NewReader("err1\nerr2\n")
|
||||
|
||||
readPipe(input, model.EventTypeProcessStderr, ch)
|
||||
|
||||
events := drainEvents(t, ch, 2, time.Second)
|
||||
if events[0].Type != model.EventTypeProcessStderr || events[0].Body != "err1" {
|
||||
t.Fatalf("event[0] = %#v, want stderr err1", events[0])
|
||||
}
|
||||
if events[1].Type != model.EventTypeProcessStderr || events[1].Body != "err2" {
|
||||
t.Fatalf("event[1] = %#v, want stderr err2", events[1])
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestUpdateStatus(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
t.Run("nil process is no-op", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
UpdateStatus(nil, true, make(chan model.Event, 1))
|
||||
})
|
||||
|
||||
t.Run("state unchanged is no-op", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
ch := make(chan model.Event, 1)
|
||||
proc := &model.Process{Running: true}
|
||||
UpdateStatus(proc, true, ch)
|
||||
|
||||
select {
|
||||
case ev := <-ch:
|
||||
t.Fatalf("unexpected event for unchanged state: %#v", ev)
|
||||
default:
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("emits started and stopped events with pid", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
ch := make(chan model.Event, 2)
|
||||
proc := &model.Process{
|
||||
Exec: &exec.Cmd{Process: &os.Process{Pid: 4321}},
|
||||
}
|
||||
|
||||
UpdateStatus(proc, true, ch)
|
||||
events := drainEvents(t, ch, 1, time.Second)
|
||||
started := events[0]
|
||||
if started.Type != model.EventTypeProcessStarted {
|
||||
t.Fatalf("started type = %s, want %s", started.Type, model.EventTypeProcessStarted)
|
||||
}
|
||||
if started.PID != 4321 {
|
||||
t.Fatalf("started PID = %d, want %d", started.PID, 4321)
|
||||
}
|
||||
if !proc.Running {
|
||||
t.Fatal("proc.Running = false, want true")
|
||||
}
|
||||
|
||||
UpdateStatus(proc, false, ch)
|
||||
events = drainEvents(t, ch, 1, time.Second)
|
||||
stopped := events[0]
|
||||
if stopped.Type != model.EventTypeProcessExited {
|
||||
t.Fatalf("stopped type = %s, want %s", stopped.Type, model.EventTypeProcessExited)
|
||||
}
|
||||
if stopped.PID != 4321 {
|
||||
t.Fatalf("stopped PID = %d, want %d", stopped.PID, 4321)
|
||||
}
|
||||
if proc.Running {
|
||||
t.Fatal("proc.Running = true, want false")
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestNewProcess(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
ch := make(chan model.Event, 8)
|
||||
cmd := model.Command{Name: "sh", Args: []string{"-c", "printf test"}}
|
||||
|
||||
proc := NewProcess(cmd, "127.0.0.1:8080", ch)
|
||||
if proc == nil {
|
||||
t.Fatal("NewProcess() returned nil")
|
||||
}
|
||||
if !reflect.DeepEqual(proc.Command, cmd) {
|
||||
t.Fatalf("process command = %#v, want %#v", proc.Command, cmd)
|
||||
}
|
||||
if proc.Exec == nil {
|
||||
t.Fatal("process Exec is nil")
|
||||
}
|
||||
if proc.Running {
|
||||
t.Fatal("new process should not be running")
|
||||
}
|
||||
if proc.Done == nil {
|
||||
t.Fatal("Done channel is nil")
|
||||
}
|
||||
|
||||
if got, want := proc.Exec.Args[0], "sh"; got != want {
|
||||
t.Fatalf("Exec.Args[0] = %q, want %q", got, want)
|
||||
}
|
||||
if !containsEnvEntry(proc.Exec.Env, "HTTP_PROXY=http://127.0.0.1:8080") {
|
||||
t.Fatal("process env missing injected HTTP_PROXY")
|
||||
}
|
||||
}
|
||||
|
||||
func containsEnvEntry(env []string, want string) bool {
|
||||
for _, entry := range env {
|
||||
if entry == want {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func drainEvents(t *testing.T, ch <-chan model.Event, n int, timeout time.Duration) []model.Event {
|
||||
t.Helper()
|
||||
|
||||
events := make([]model.Event, 0, n)
|
||||
deadline := time.After(timeout)
|
||||
for len(events) < n {
|
||||
select {
|
||||
case ev := <-ch:
|
||||
events = append(events, ev)
|
||||
case <-deadline:
|
||||
t.Fatalf("timeout waiting for %d events; got %d", n, len(events))
|
||||
}
|
||||
}
|
||||
|
||||
return events
|
||||
}
|
||||
147
internal/process/signal_unix_test.go
Normal file
147
internal/process/signal_unix_test.go
Normal file
@ -0,0 +1,147 @@
|
||||
//go:build unix
|
||||
|
||||
package process
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"os"
|
||||
"os/exec"
|
||||
"syscall"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
type customSignal struct{}
|
||||
|
||||
func (customSignal) String() string { return "custom" }
|
||||
func (customSignal) Signal() {}
|
||||
|
||||
// NOTE: Run these tests with -race in CI for signal/process safety.
|
||||
|
||||
func TestConfigureProcessForSignals(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
cmd := exec.Command("sh", "-c", "sleep 0.1")
|
||||
configureProcessForSignals(cmd)
|
||||
|
||||
if cmd.SysProcAttr == nil {
|
||||
t.Fatal("SysProcAttr is nil")
|
||||
}
|
||||
if !cmd.SysProcAttr.Setpgid {
|
||||
t.Fatal("Setpgid = false, want true")
|
||||
}
|
||||
}
|
||||
|
||||
func TestSignalProcess_NilSafe(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
if err := SignalProcess(nil, syscall.SIGTERM); err != nil {
|
||||
t.Fatalf("SignalProcess(nil) error = %v, want nil", err)
|
||||
}
|
||||
|
||||
cmd := &exec.Cmd{}
|
||||
if err := SignalProcess(cmd, syscall.SIGTERM); err != nil {
|
||||
t.Fatalf("SignalProcess(cmd without process) error = %v, want nil", err)
|
||||
}
|
||||
|
||||
cmd.Process = &os.Process{Pid: 0}
|
||||
if err := SignalProcess(cmd, syscall.SIGTERM); err != nil {
|
||||
t.Fatalf("SignalProcess(pid<=0) error = %v, want nil", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSignalProcess_ESRCHIsTreatedAsSuccess(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
cmd := &exec.Cmd{Process: &os.Process{Pid: 999999}}
|
||||
if err := SignalProcess(cmd, syscall.SIGTERM); err != nil {
|
||||
t.Fatalf("SignalProcess() error = %v, want nil when process group not found", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSignalProcess_UsesFallbackOnKillError(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
cmd := exec.Command("sh", "-c", "sleep 5")
|
||||
configureProcessForSignals(cmd)
|
||||
if err := cmd.Start(); err != nil {
|
||||
t.Fatalf("start command error = %v", err)
|
||||
}
|
||||
t.Cleanup(func() {
|
||||
_ = cmd.Process.Kill()
|
||||
_, _ = cmd.Process.Wait()
|
||||
})
|
||||
|
||||
// Invalid signal causes syscall.Kill to fail with EINVAL, then fallback to cmd.Process.Signal.
|
||||
err := SignalProcess(cmd, syscall.Signal(9999))
|
||||
if err == nil {
|
||||
t.Fatal("SignalProcess() error = nil, want non-nil for invalid signal")
|
||||
}
|
||||
if !(errors.Is(err, syscall.EINVAL) || errors.Is(err, os.ErrProcessDone)) {
|
||||
// OS/process timing can vary; ensure we at least failed predictably.
|
||||
t.Fatalf("SignalProcess() unexpected error: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSignalProcess_NonSyscallSignalUsesProcessSignal(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
cmd := exec.Command("sh", "-c", "sleep 1")
|
||||
configureProcessForSignals(cmd)
|
||||
if err := cmd.Start(); err != nil {
|
||||
t.Fatalf("start command error = %v", err)
|
||||
}
|
||||
t.Cleanup(func() {
|
||||
_ = cmd.Process.Kill()
|
||||
_, _ = cmd.Process.Wait()
|
||||
})
|
||||
|
||||
err := SignalProcess(cmd, customSignal{})
|
||||
if err == nil {
|
||||
t.Fatal("SignalProcess(custom signal) error = nil, want non-nil")
|
||||
}
|
||||
if errors.Is(err, os.ErrProcessDone) {
|
||||
return
|
||||
}
|
||||
if msg := err.Error(); msg == "" {
|
||||
t.Fatalf("unexpected empty error for custom signal: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestProcessAlive(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
if ProcessAlive(nil) {
|
||||
t.Fatal("ProcessAlive(nil) = true, want false")
|
||||
}
|
||||
if ProcessAlive(&exec.Cmd{}) {
|
||||
t.Fatal("ProcessAlive(cmd without process) = true, want false")
|
||||
}
|
||||
|
||||
cmd := exec.Command("sh", "-c", "sleep 0.2")
|
||||
configureProcessForSignals(cmd)
|
||||
if err := cmd.Start(); err != nil {
|
||||
t.Fatalf("start command error = %v", err)
|
||||
}
|
||||
|
||||
if !ProcessAlive(cmd) {
|
||||
_ = cmd.Process.Kill()
|
||||
_, _ = cmd.Process.Wait()
|
||||
t.Fatal("ProcessAlive(running) = false, want true")
|
||||
}
|
||||
|
||||
_ = cmd.Process.Kill()
|
||||
_, _ = cmd.Process.Wait()
|
||||
|
||||
deadline := time.After(time.Second)
|
||||
for {
|
||||
if !ProcessAlive(cmd) {
|
||||
return
|
||||
}
|
||||
select {
|
||||
case <-deadline:
|
||||
t.Fatal("ProcessAlive(exited) stayed true")
|
||||
default:
|
||||
}
|
||||
}
|
||||
}
|
||||
242
internal/proxy/certs_error_test.go
Normal file
242
internal/proxy/certs_error_test.go
Normal file
@ -0,0 +1,242 @@
|
||||
package proxy
|
||||
|
||||
import (
|
||||
"crypto/tls"
|
||||
"encoding/pem"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestLoadOrCreateCertificateAuthority_RecreatesWhenKeyMissing(t *testing.T) {
|
||||
configRoot := t.TempDir()
|
||||
t.Setenv("XDG_CONFIG_HOME", configRoot)
|
||||
|
||||
baseDir := filepath.Join(configRoot, caDirName)
|
||||
if err := os.MkdirAll(baseDir, 0o700); err != nil {
|
||||
t.Fatalf("MkdirAll() error = %v", err)
|
||||
}
|
||||
|
||||
certPath := filepath.Join(baseDir, caCertName)
|
||||
if err := os.WriteFile(certPath, []byte("stale-cert"), 0o600); err != nil {
|
||||
t.Fatalf("WriteFile(cert) error = %v", err)
|
||||
}
|
||||
|
||||
ca, err := loadOrCreateCertificateAuthority()
|
||||
if err != nil {
|
||||
t.Fatalf("loadOrCreateCertificateAuthority() error = %v", err)
|
||||
}
|
||||
if !ca.WasCreated() {
|
||||
t.Fatal("WasCreated = false, want true when key is missing")
|
||||
}
|
||||
if _, err := os.Stat(filepath.Join(baseDir, caKeyName)); err != nil {
|
||||
t.Fatalf("expected key file to be created, stat error = %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestLoadOrCreateCertificateAuthority_LoadErrorOnCorruptFiles(t *testing.T) {
|
||||
configRoot := t.TempDir()
|
||||
t.Setenv("XDG_CONFIG_HOME", configRoot)
|
||||
|
||||
baseDir := filepath.Join(configRoot, caDirName)
|
||||
if err := os.MkdirAll(baseDir, 0o700); err != nil {
|
||||
t.Fatalf("MkdirAll() error = %v", err)
|
||||
}
|
||||
|
||||
certPath := filepath.Join(baseDir, caCertName)
|
||||
keyPath := filepath.Join(baseDir, caKeyName)
|
||||
if err := os.WriteFile(certPath, []byte("not-a-pem"), 0o600); err != nil {
|
||||
t.Fatalf("WriteFile(cert) error = %v", err)
|
||||
}
|
||||
if err := os.WriteFile(keyPath, []byte("not-a-pem"), 0o600); err != nil {
|
||||
t.Fatalf("WriteFile(key) error = %v", err)
|
||||
}
|
||||
|
||||
_, err := loadOrCreateCertificateAuthority()
|
||||
if err == nil {
|
||||
t.Fatal("loadOrCreateCertificateAuthority() error = nil, want non-nil")
|
||||
}
|
||||
}
|
||||
|
||||
func TestCertificateAuthorityLoad_ErrorPaths(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
certBytes []byte
|
||||
keyBytes []byte
|
||||
wantPart string
|
||||
}{
|
||||
{
|
||||
name: "invalid cert pem",
|
||||
certBytes: []byte("bad-cert"),
|
||||
keyBytes: []byte("bad-key"),
|
||||
wantPart: "decode ca cert pem",
|
||||
},
|
||||
{
|
||||
name: "parse cert fails",
|
||||
certBytes: pem.EncodeToMemory(&pem.Block{Type: "CERTIFICATE", Bytes: []byte("bogus")}),
|
||||
keyBytes: []byte("bad-key"),
|
||||
wantPart: "parse ca cert",
|
||||
},
|
||||
{
|
||||
name: "invalid key pem",
|
||||
certBytes: pem.EncodeToMemory(&pem.Block{Type: "CERTIFICATE", Bytes: newTestCA(t).cert.Raw}),
|
||||
keyBytes: []byte("bad-key"),
|
||||
wantPart: "decode ca key pem",
|
||||
},
|
||||
{
|
||||
name: "parse key fails",
|
||||
certBytes: pem.EncodeToMemory(&pem.Block{Type: "CERTIFICATE", Bytes: newTestCA(t).cert.Raw}),
|
||||
keyBytes: pem.EncodeToMemory(&pem.Block{Type: "EC PRIVATE KEY", Bytes: []byte("bogus")}),
|
||||
wantPart: "parse ca key",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
tt := tt
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
dir := t.TempDir()
|
||||
ca := &CertificateAuthority{
|
||||
certPath: filepath.Join(dir, caCertName),
|
||||
keyPath: filepath.Join(dir, caKeyName),
|
||||
leafCert: make(map[string]*tls.Certificate),
|
||||
}
|
||||
|
||||
if err := os.WriteFile(ca.certPath, tt.certBytes, 0o600); err != nil {
|
||||
t.Fatalf("write cert file error = %v", err)
|
||||
}
|
||||
if err := os.WriteFile(ca.keyPath, tt.keyBytes, 0o600); err != nil {
|
||||
t.Fatalf("write key file error = %v", err)
|
||||
}
|
||||
|
||||
err := ca.load()
|
||||
if err == nil {
|
||||
t.Fatal("load() error = nil, want non-nil")
|
||||
}
|
||||
if !strings.Contains(err.Error(), tt.wantPart) {
|
||||
t.Fatalf("load() error = %q, want contains %q", err.Error(), tt.wantPart)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestCertificateAuthorityCreate_ErrorWhenWritePathInvalid(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
ca := &CertificateAuthority{
|
||||
certPath: filepath.Join("/nope", "missing", "ca-cert.pem"),
|
||||
keyPath: filepath.Join("/nope", "missing", "ca-key.pem"),
|
||||
leafCert: make(map[string]*tls.Certificate),
|
||||
}
|
||||
|
||||
err := ca.create()
|
||||
if err == nil {
|
||||
t.Fatal("create() error = nil, want non-nil")
|
||||
}
|
||||
}
|
||||
|
||||
func TestCertificateAuthorityCreate_WriteErrorPaths(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
t.Run("write ca cert wraps error", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
dir := t.TempDir()
|
||||
badCertPath := filepath.Join(dir, "cert-as-dir")
|
||||
if err := os.MkdirAll(badCertPath, 0o700); err != nil {
|
||||
t.Fatalf("MkdirAll(cert dir) error = %v", err)
|
||||
}
|
||||
|
||||
ca := &CertificateAuthority{
|
||||
certPath: badCertPath,
|
||||
keyPath: filepath.Join(dir, "ca-key.pem"),
|
||||
leafCert: make(map[string]*tls.Certificate),
|
||||
}
|
||||
|
||||
err := ca.create()
|
||||
if err == nil {
|
||||
t.Fatal("create() error = nil, want non-nil")
|
||||
}
|
||||
if !strings.Contains(err.Error(), "write ca cert") {
|
||||
t.Fatalf("create() error = %q, want contains %q", err.Error(), "write ca cert")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("write ca key wraps error", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
dir := t.TempDir()
|
||||
badKeyPath := filepath.Join(dir, "key-as-dir")
|
||||
if err := os.MkdirAll(badKeyPath, 0o700); err != nil {
|
||||
t.Fatalf("MkdirAll(key dir) error = %v", err)
|
||||
}
|
||||
|
||||
ca := &CertificateAuthority{
|
||||
certPath: filepath.Join(dir, "ca-cert.pem"),
|
||||
keyPath: badKeyPath,
|
||||
leafCert: make(map[string]*tls.Certificate),
|
||||
}
|
||||
|
||||
err := ca.create()
|
||||
if err == nil {
|
||||
t.Fatal("create() error = nil, want non-nil")
|
||||
}
|
||||
if !strings.Contains(err.Error(), "write ca key") {
|
||||
t.Fatalf("create() error = %q, want contains %q", err.Error(), "write ca key")
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestCertificateAuthorityLoad_ReadErrorPaths(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
t.Run("read cert failure", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
dir := t.TempDir()
|
||||
ca := &CertificateAuthority{
|
||||
certPath: filepath.Join(dir, "missing-cert.pem"),
|
||||
keyPath: filepath.Join(dir, "missing-key.pem"),
|
||||
leafCert: make(map[string]*tls.Certificate),
|
||||
}
|
||||
|
||||
err := ca.load()
|
||||
if err == nil {
|
||||
t.Fatal("load() error = nil, want non-nil")
|
||||
}
|
||||
if !strings.Contains(err.Error(), "read ca cert") {
|
||||
t.Fatalf("load() error = %q, want contains %q", err.Error(), "read ca cert")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("read key failure", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
dir := t.TempDir()
|
||||
goodCA := newTestCA(t)
|
||||
certPEM := pem.EncodeToMemory(&pem.Block{Type: "CERTIFICATE", Bytes: goodCA.cert.Raw})
|
||||
|
||||
certPath := filepath.Join(dir, "cert.pem")
|
||||
if err := os.WriteFile(certPath, certPEM, 0o600); err != nil {
|
||||
t.Fatalf("WriteFile(cert) error = %v", err)
|
||||
}
|
||||
|
||||
ca := &CertificateAuthority{
|
||||
certPath: certPath,
|
||||
keyPath: filepath.Join(dir, "missing-key.pem"),
|
||||
leafCert: make(map[string]*tls.Certificate),
|
||||
}
|
||||
|
||||
err := ca.load()
|
||||
if err == nil {
|
||||
t.Fatal("load() error = nil, want non-nil")
|
||||
}
|
||||
if !strings.Contains(err.Error(), "read ca key") {
|
||||
t.Fatalf("load() error = %q, want contains %q", err.Error(), "read ca key")
|
||||
}
|
||||
})
|
||||
}
|
||||
45
internal/proxy/certs_lifecycle_test.go
Normal file
45
internal/proxy/certs_lifecycle_test.go
Normal file
@ -0,0 +1,45 @@
|
||||
package proxy
|
||||
|
||||
import (
|
||||
"os"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestLoadOrCreateCertificateAuthority_CreateThenLoad(t *testing.T) {
|
||||
configRoot := t.TempDir()
|
||||
t.Setenv("XDG_CONFIG_HOME", configRoot)
|
||||
|
||||
ca1, err := loadOrCreateCertificateAuthority()
|
||||
if err != nil {
|
||||
t.Fatalf("first loadOrCreateCertificateAuthority() error = %v", err)
|
||||
}
|
||||
if ca1 == nil {
|
||||
t.Fatal("first CA is nil")
|
||||
}
|
||||
if !ca1.WasCreated() {
|
||||
t.Fatal("first CA should report WasCreated=true")
|
||||
}
|
||||
if ca1.CertPath() == "" {
|
||||
t.Fatal("first CA CertPath is empty")
|
||||
}
|
||||
if _, err := os.Stat(ca1.CertPath()); err != nil {
|
||||
t.Fatalf("first CA cert file missing: %v", err)
|
||||
}
|
||||
|
||||
ca2, err := loadOrCreateCertificateAuthority()
|
||||
if err != nil {
|
||||
t.Fatalf("second loadOrCreateCertificateAuthority() error = %v", err)
|
||||
}
|
||||
if ca2 == nil {
|
||||
t.Fatal("second CA is nil")
|
||||
}
|
||||
if ca2.WasCreated() {
|
||||
t.Fatal("second CA should report WasCreated=false (loaded existing)")
|
||||
}
|
||||
if ca2.CertPath() != ca1.CertPath() {
|
||||
t.Fatalf("cert path mismatch: first=%q second=%q", ca1.CertPath(), ca2.CertPath())
|
||||
}
|
||||
if ca2.cert == nil || ca2.key == nil {
|
||||
t.Fatal("loaded CA should include cert and key")
|
||||
}
|
||||
}
|
||||
266
internal/proxy/certs_test.go
Normal file
266
internal/proxy/certs_test.go
Normal file
@ -0,0 +1,266 @@
|
||||
package proxy
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"math/big"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestNormalizeCertHost(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
in string
|
||||
want string
|
||||
}{
|
||||
{name: "host and port", in: "example.com:443", want: "example.com"},
|
||||
{name: "plain host", in: "example.com", want: "example.com"},
|
||||
{name: "whitespace trims", in: " example.com:8443 ", want: "example.com"},
|
||||
{name: "empty", in: " ", want: ""},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
tt := tt
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
if got := normalizeCertHost(tt.in); got != tt.want {
|
||||
t.Fatalf("normalizeCertHost(%q) = %q, want %q", tt.in, got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestRandSerialNumber(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
serial, err := randSerialNumber()
|
||||
if err != nil {
|
||||
t.Fatalf("randSerialNumber() error = %v", err)
|
||||
}
|
||||
if serial == nil {
|
||||
t.Fatal("serial is nil")
|
||||
}
|
||||
if serial.Sign() < 0 {
|
||||
t.Fatalf("serial must be non-negative, got %v", serial)
|
||||
}
|
||||
|
||||
limit := new(big.Int).Lsh(big.NewInt(1), 128)
|
||||
if serial.Cmp(limit) >= 0 {
|
||||
t.Fatalf("serial must be < 2^128, got %v", serial)
|
||||
}
|
||||
}
|
||||
|
||||
func TestWriteFileAtomically(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
dir := t.TempDir()
|
||||
path := filepath.Join(dir, "cert.pem")
|
||||
|
||||
if err := writeFileAtomically(path, []byte("first"), 0o600); err != nil {
|
||||
t.Fatalf("first writeFileAtomically() error = %v", err)
|
||||
}
|
||||
if err := writeFileAtomically(path, []byte("second"), 0o600); err != nil {
|
||||
t.Fatalf("second writeFileAtomically() error = %v", err)
|
||||
}
|
||||
|
||||
data, err := os.ReadFile(path)
|
||||
if err != nil {
|
||||
t.Fatalf("ReadFile() error = %v", err)
|
||||
}
|
||||
if got, want := string(data), "second"; got != want {
|
||||
t.Fatalf("file contents = %q, want %q", got, want)
|
||||
}
|
||||
|
||||
info, err := os.Stat(path)
|
||||
if err != nil {
|
||||
t.Fatalf("Stat() error = %v", err)
|
||||
}
|
||||
if got := info.Mode().Perm(); got != 0o600 {
|
||||
t.Fatalf("file permissions = %#o, want %#o", got, 0o600)
|
||||
}
|
||||
}
|
||||
|
||||
func TestCertificateAuthority_Basics(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
var nilCA *CertificateAuthority
|
||||
if got := nilCA.CertPath(); got != "" {
|
||||
t.Fatalf("nil CertPath() = %q, want empty", got)
|
||||
}
|
||||
if got := nilCA.WasCreated(); got {
|
||||
t.Fatalf("nil WasCreated() = %v, want false", got)
|
||||
}
|
||||
|
||||
ca := newTestCA(t)
|
||||
ca.certPath = "/tmp/test-ca.pem"
|
||||
ca.wasCreated = true
|
||||
if got, want := ca.CertPath(), "/tmp/test-ca.pem"; got != want {
|
||||
t.Fatalf("CertPath() = %q, want %q", got, want)
|
||||
}
|
||||
if !ca.WasCreated() {
|
||||
t.Fatal("WasCreated() = false, want true")
|
||||
}
|
||||
}
|
||||
|
||||
func TestCertificateForHost(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
ca := newTestCA(t)
|
||||
|
||||
t.Run("empty host returns error", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
cert, err := ca.CertificateForHost(" ")
|
||||
if err == nil {
|
||||
t.Fatal("CertificateForHost() error = nil, want non-nil")
|
||||
}
|
||||
if cert != nil {
|
||||
t.Fatalf("cert = %#v, want nil", cert)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("cache hit returns same pointer", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
c1, err := ca.CertificateForHost("example.com:443")
|
||||
if err != nil {
|
||||
t.Fatalf("first CertificateForHost() error = %v", err)
|
||||
}
|
||||
c2, err := ca.CertificateForHost("example.com")
|
||||
if err != nil {
|
||||
t.Fatalf("second CertificateForHost() error = %v", err)
|
||||
}
|
||||
|
||||
if c1 != c2 {
|
||||
t.Fatal("expected same certificate pointer from cache")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("ip and dns SAN selection", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
ipCert, err := ca.CertificateForHost("127.0.0.1:443")
|
||||
if err != nil {
|
||||
t.Fatalf("ip CertificateForHost() error = %v", err)
|
||||
}
|
||||
if ipCert.Leaf == nil {
|
||||
t.Fatal("ip cert leaf is nil")
|
||||
}
|
||||
if len(ipCert.Leaf.IPAddresses) == 0 {
|
||||
t.Fatal("ip cert should contain IP SAN")
|
||||
}
|
||||
if len(ipCert.Leaf.DNSNames) != 0 {
|
||||
t.Fatalf("ip cert DNSNames = %v, want empty", ipCert.Leaf.DNSNames)
|
||||
}
|
||||
|
||||
dnsCert, err := ca.CertificateForHost("service.local")
|
||||
if err != nil {
|
||||
t.Fatalf("dns CertificateForHost() error = %v", err)
|
||||
}
|
||||
if dnsCert.Leaf == nil {
|
||||
t.Fatal("dns cert leaf is nil")
|
||||
}
|
||||
if len(dnsCert.Leaf.DNSNames) == 0 {
|
||||
t.Fatal("dns cert should contain DNS SAN")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("evicts oldest entry over maxLeafCerts", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
ca2 := newTestCA(t)
|
||||
for i := 0; i < maxLeafCerts+1; i++ {
|
||||
host := filepath.Base(filepath.Join("h", big.NewInt(int64(i)).String()+".example"))
|
||||
if _, err := ca2.CertificateForHost(host); err != nil {
|
||||
t.Fatalf("CertificateForHost(%q) error = %v", host, err)
|
||||
}
|
||||
}
|
||||
|
||||
if len(ca2.leafOrder) != maxLeafCerts {
|
||||
t.Fatalf("leafOrder len = %d, want %d", len(ca2.leafOrder), maxLeafCerts)
|
||||
}
|
||||
if _, ok := ca2.leafCert["0.example"]; ok {
|
||||
t.Fatal("expected oldest cert to be evicted")
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestIsTrustedBySystem(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
var nilCA *CertificateAuthority
|
||||
_, err := nilCA.IsTrustedBySystem()
|
||||
if err == nil {
|
||||
t.Fatal("nil IsTrustedBySystem() error = nil, want non-nil")
|
||||
}
|
||||
|
||||
ca := &CertificateAuthority{}
|
||||
_, err = ca.IsTrustedBySystem()
|
||||
if err == nil {
|
||||
t.Fatal("missing-cert IsTrustedBySystem() error = nil, want non-nil")
|
||||
}
|
||||
|
||||
t.Run("untrusted generated CA returns false without error", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
ca := newTestCA(t)
|
||||
|
||||
trusted, err := ca.IsTrustedBySystem()
|
||||
if err != nil {
|
||||
t.Fatalf("IsTrustedBySystem() error = %v, want nil for unknown authority", err)
|
||||
}
|
||||
if trusted {
|
||||
t.Fatal("trusted = true, want false for generated test CA")
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestEnsureCertificateAuthority(t *testing.T) {
|
||||
configRoot := t.TempDir()
|
||||
t.Setenv("XDG_CONFIG_HOME", configRoot)
|
||||
|
||||
ca, err := EnsureCertificateAuthority()
|
||||
if err != nil {
|
||||
t.Fatalf("EnsureCertificateAuthority() error = %v", err)
|
||||
}
|
||||
if ca == nil {
|
||||
t.Fatal("EnsureCertificateAuthority() returned nil CA")
|
||||
}
|
||||
if ca.CertPath() == "" {
|
||||
t.Fatal("EnsureCertificateAuthority() returned empty cert path")
|
||||
}
|
||||
if _, statErr := os.Stat(ca.CertPath()); statErr != nil {
|
||||
t.Fatalf("expected cert on disk, stat error = %v", statErr)
|
||||
}
|
||||
}
|
||||
|
||||
func TestWriteFileAtomically_ErrorPath(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
err := writeFileAtomically(filepath.Join("/nope", "bad", "path.pem"), []byte("x"), 0o600)
|
||||
if err == nil {
|
||||
t.Fatal("writeFileAtomically() error = nil, want non-nil")
|
||||
}
|
||||
if errors.Is(err, os.ErrNotExist) {
|
||||
return
|
||||
}
|
||||
// Accept platform-dependent fs errors as long as function fails.
|
||||
}
|
||||
|
||||
func TestWriteFileAtomically_RenameErrorWhenTargetIsDirectory(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
dir := t.TempDir()
|
||||
targetDir := filepath.Join(dir, "target-as-dir")
|
||||
if err := os.MkdirAll(targetDir, 0o700); err != nil {
|
||||
t.Fatalf("MkdirAll(targetDir) error = %v", err)
|
||||
}
|
||||
|
||||
err := writeFileAtomically(targetDir, []byte("x"), 0o600)
|
||||
if err == nil {
|
||||
t.Fatal("writeFileAtomically() error = nil, want non-nil")
|
||||
}
|
||||
}
|
||||
|
||||
// TODO: Add deterministic tests for loadOrCreateCertificateAuthority trust-store interactions.
|
||||
29
internal/proxy/certs_trust_test.go
Normal file
29
internal/proxy/certs_trust_test.go
Normal file
@ -0,0 +1,29 @@
|
||||
package proxy
|
||||
|
||||
import (
|
||||
"encoding/pem"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestIsTrustedBySystem_TrustedViaCertEnv(t *testing.T) {
|
||||
ca := newTestCA(t)
|
||||
|
||||
rootFile := filepath.Join(t.TempDir(), "roots.pem")
|
||||
pemBytes := pem.EncodeToMemory(&pem.Block{Type: "CERTIFICATE", Bytes: ca.cert.Raw})
|
||||
if err := os.WriteFile(rootFile, pemBytes, 0o600); err != nil {
|
||||
t.Fatalf("WriteFile(root cert) error = %v", err)
|
||||
}
|
||||
|
||||
t.Setenv("SSL_CERT_FILE", rootFile)
|
||||
t.Setenv("SSL_CERT_DIR", t.TempDir())
|
||||
|
||||
trusted, err := ca.IsTrustedBySystem()
|
||||
if err != nil {
|
||||
t.Fatalf("IsTrustedBySystem() error = %v", err)
|
||||
}
|
||||
if !trusted {
|
||||
t.Fatal("trusted = false, want true when CA is in configured cert file")
|
||||
}
|
||||
}
|
||||
330
internal/proxy/handlers_test.go
Normal file
330
internal/proxy/handlers_test.go
Normal file
@ -0,0 +1,330 @@
|
||||
package proxy
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"bytes"
|
||||
"context"
|
||||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"termtap.dev/internal/model"
|
||||
)
|
||||
|
||||
type failingResponseWriter struct {
|
||||
header http.Header
|
||||
code int
|
||||
}
|
||||
|
||||
func (w *failingResponseWriter) Header() http.Header {
|
||||
if w.header == nil {
|
||||
w.header = make(http.Header)
|
||||
}
|
||||
return w.header
|
||||
}
|
||||
|
||||
func (w *failingResponseWriter) WriteHeader(statusCode int) {
|
||||
w.code = statusCode
|
||||
}
|
||||
|
||||
func (w *failingResponseWriter) Write(_ []byte) (int, error) {
|
||||
return 0, io.ErrClosedPipe
|
||||
}
|
||||
|
||||
type hijackFailWriter struct {
|
||||
header http.Header
|
||||
code int
|
||||
}
|
||||
|
||||
func (w *hijackFailWriter) Header() http.Header {
|
||||
if w.header == nil {
|
||||
w.header = make(http.Header)
|
||||
}
|
||||
return w.header
|
||||
}
|
||||
|
||||
func (w *hijackFailWriter) WriteHeader(statusCode int) {
|
||||
w.code = statusCode
|
||||
}
|
||||
|
||||
func (w *hijackFailWriter) Write(p []byte) (int, error) {
|
||||
return len(p), nil
|
||||
}
|
||||
|
||||
func (w *hijackFailWriter) Hijack() (net.Conn, *bufio.ReadWriter, error) {
|
||||
return nil, nil, fmt.Errorf("hijack failed")
|
||||
}
|
||||
|
||||
type dummyConn struct{}
|
||||
|
||||
func (dummyConn) Read(_ []byte) (int, error) { return 0, io.EOF }
|
||||
func (dummyConn) Write(p []byte) (int, error) { return len(p), nil }
|
||||
func (dummyConn) Close() error { return nil }
|
||||
func (dummyConn) LocalAddr() net.Addr { return &net.TCPAddr{} }
|
||||
func (dummyConn) RemoteAddr() net.Addr { return &net.TCPAddr{} }
|
||||
func (dummyConn) SetDeadline(_ time.Time) error { return nil }
|
||||
func (dummyConn) SetReadDeadline(_ time.Time) error { return nil }
|
||||
func (dummyConn) SetWriteDeadline(_ time.Time) error { return nil }
|
||||
|
||||
type writeConnectFailWriter struct {
|
||||
header http.Header
|
||||
code int
|
||||
}
|
||||
|
||||
func (w *writeConnectFailWriter) Header() http.Header {
|
||||
if w.header == nil {
|
||||
w.header = make(http.Header)
|
||||
}
|
||||
return w.header
|
||||
}
|
||||
|
||||
func (w *writeConnectFailWriter) WriteHeader(statusCode int) {
|
||||
w.code = statusCode
|
||||
}
|
||||
|
||||
func (w *writeConnectFailWriter) Write(p []byte) (int, error) {
|
||||
return len(p), nil
|
||||
}
|
||||
|
||||
func (w *writeConnectFailWriter) Hijack() (net.Conn, *bufio.ReadWriter, error) {
|
||||
rw := bufio.NewReadWriter(
|
||||
bufio.NewReader(strings.NewReader("")),
|
||||
bufio.NewWriter(errWriter{}),
|
||||
)
|
||||
return dummyConn{}, rw, nil
|
||||
}
|
||||
|
||||
func TestProxyHandler_NonConnectRejectsNonAbsoluteURL(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
ch := make(chan model.Event, 8)
|
||||
ps := &model.ProxyServer{Conns: make(map[net.Conn]struct{})}
|
||||
h := proxyHandler(ch, nil, ps)
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "/not-proxy-form", nil)
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
h.ServeHTTP(w, req)
|
||||
|
||||
if w.Code != http.StatusBadRequest {
|
||||
t.Fatalf("status = %d, want %d", w.Code, http.StatusBadRequest)
|
||||
}
|
||||
|
||||
events := drainEvents(t, ch, 1, time.Second)
|
||||
if !hasEventType(events, model.EventTypeWarn) {
|
||||
t.Fatalf("expected %s event, got %#v", model.EventTypeWarn, events)
|
||||
}
|
||||
}
|
||||
|
||||
func TestProxyHandler_NonConnectSuccess(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
upstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.Header().Set("X-Upstream", "yes")
|
||||
w.WriteHeader(http.StatusAccepted)
|
||||
_, _ = w.Write([]byte("pong"))
|
||||
}))
|
||||
t.Cleanup(upstream.Close)
|
||||
|
||||
ch := make(chan model.Event, 8)
|
||||
ps := &model.ProxyServer{Conns: make(map[net.Conn]struct{})}
|
||||
h := proxyHandler(ch, nil, ps)
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, upstream.URL+"/ping", nil)
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
h.ServeHTTP(w, req)
|
||||
|
||||
if w.Code != http.StatusAccepted {
|
||||
t.Fatalf("status = %d, want %d", w.Code, http.StatusAccepted)
|
||||
}
|
||||
if got, want := w.Body.String(), "pong"; got != want {
|
||||
t.Fatalf("body = %q, want %q", got, want)
|
||||
}
|
||||
if got := w.Header().Get("X-Upstream"); got != "yes" {
|
||||
t.Fatalf("X-Upstream header = %q, want yes", got)
|
||||
}
|
||||
|
||||
events := drainEvents(t, ch, 2, time.Second)
|
||||
if !hasEventType(events, model.EventTypeRequestStarted) {
|
||||
t.Fatalf("expected %s event", model.EventTypeRequestStarted)
|
||||
}
|
||||
if !hasEventType(events, model.EventTypeRequestFinished) {
|
||||
t.Fatalf("expected %s event", model.EventTypeRequestFinished)
|
||||
}
|
||||
}
|
||||
|
||||
func TestProxyHandler_NonConnectUpstreamError(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
ch := make(chan model.Event, 8)
|
||||
ps := &model.ProxyServer{Conns: make(map[net.Conn]struct{})}
|
||||
h := proxyHandler(ch, nil, ps)
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "http://127.0.0.1:1/fail", nil)
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
h.ServeHTTP(w, req)
|
||||
|
||||
if w.Code != http.StatusBadGateway {
|
||||
t.Fatalf("status = %d, want %d", w.Code, http.StatusBadGateway)
|
||||
}
|
||||
|
||||
events := drainEvents(t, ch, 2, time.Second)
|
||||
if !hasEventType(events, model.EventTypeRequestStarted) {
|
||||
t.Fatalf("expected %s event", model.EventTypeRequestStarted)
|
||||
}
|
||||
if !hasEventType(events, model.EventTypeRequestFailed) {
|
||||
t.Fatalf("expected %s event", model.EventTypeRequestFailed)
|
||||
}
|
||||
}
|
||||
|
||||
func TestProxyHandler_NonConnectWriteFailureEmitsFailedEvent(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
upstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(http.StatusOK)
|
||||
_, _ = io.Copy(w, bytes.NewBufferString("response-body"))
|
||||
}))
|
||||
t.Cleanup(upstream.Close)
|
||||
|
||||
ch := make(chan model.Event, 8)
|
||||
ps := &model.ProxyServer{Conns: make(map[net.Conn]struct{})}
|
||||
h := proxyHandler(ch, nil, ps)
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, upstream.URL+"/copy-fail", nil)
|
||||
w := &failingResponseWriter{}
|
||||
|
||||
h.ServeHTTP(w, req)
|
||||
|
||||
events := drainEvents(t, ch, 2, time.Second)
|
||||
if !hasEventType(events, model.EventTypeRequestStarted) {
|
||||
t.Fatalf("expected %s event", model.EventTypeRequestStarted)
|
||||
}
|
||||
if !hasEventType(events, model.EventTypeRequestFailed) {
|
||||
t.Fatalf("expected %s event", model.EventTypeRequestFailed)
|
||||
}
|
||||
}
|
||||
|
||||
func TestHandleConnect_EarlyFailures(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
transport := &mockTransport{fn: func(req *http.Request) (*http.Response, error) {
|
||||
return nil, context.Canceled
|
||||
}}
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
writer http.ResponseWriter
|
||||
req *http.Request
|
||||
ca *CertificateAuthority
|
||||
wantCode int
|
||||
wantBody string
|
||||
wantStat int
|
||||
}{
|
||||
{
|
||||
name: "nil CA returns 502 and failed event",
|
||||
writer: httptest.NewRecorder(),
|
||||
req: httptest.NewRequest(http.MethodConnect, "http://example.com:443", nil),
|
||||
ca: nil,
|
||||
wantCode: http.StatusBadGateway,
|
||||
wantBody: "certificate authority is unavailable",
|
||||
wantStat: http.StatusBadGateway,
|
||||
},
|
||||
{
|
||||
name: "nil CA with host without port still fails predictably",
|
||||
writer: httptest.NewRecorder(),
|
||||
req: &http.Request{Method: http.MethodConnect, Host: "example.com"},
|
||||
ca: nil,
|
||||
wantCode: http.StatusBadGateway,
|
||||
wantBody: "certificate authority is unavailable",
|
||||
wantStat: http.StatusBadGateway,
|
||||
},
|
||||
{
|
||||
name: "cert mint failure returns 502 and failed event",
|
||||
writer: httptest.NewRecorder(),
|
||||
req: &http.Request{Method: http.MethodConnect, Host: ""},
|
||||
ca: &CertificateAuthority{},
|
||||
wantCode: http.StatusBadGateway,
|
||||
wantBody: "failed to mint interception certificate",
|
||||
wantStat: http.StatusBadGateway,
|
||||
},
|
||||
{
|
||||
name: "non-hijacker writer returns 500 and failed event",
|
||||
writer: httptest.NewRecorder(),
|
||||
req: httptest.NewRequest(http.MethodConnect, "http://example.com:443", nil),
|
||||
ca: newTestCA(t),
|
||||
wantCode: http.StatusInternalServerError,
|
||||
wantBody: "hijack is unavailable",
|
||||
wantStat: http.StatusInternalServerError,
|
||||
},
|
||||
{
|
||||
name: "hijack failure returns 500 and failed event",
|
||||
writer: &hijackFailWriter{},
|
||||
req: httptest.NewRequest(http.MethodConnect, "http://example.com:443", nil),
|
||||
ca: newTestCA(t),
|
||||
wantCode: http.StatusInternalServerError,
|
||||
wantBody: "CONNECT hijack failed",
|
||||
wantStat: http.StatusInternalServerError,
|
||||
},
|
||||
{
|
||||
name: "write connect established failure emits failed event",
|
||||
writer: &writeConnectFailWriter{},
|
||||
req: httptest.NewRequest(http.MethodConnect, "http://example.com:443", nil),
|
||||
ca: newTestCA(t),
|
||||
wantCode: 0,
|
||||
wantBody: "CONNECT setup failed",
|
||||
wantStat: http.StatusBadGateway,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
tt := tt
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
ch := make(chan model.Event, 8)
|
||||
ps := &model.ProxyServer{Conns: make(map[net.Conn]struct{})}
|
||||
|
||||
handleConnect(tt.writer, tt.req, ch, transport, tt.ca, ps)
|
||||
|
||||
events := drainEvents(t, ch, 2, time.Second)
|
||||
if events[0].Type != model.EventTypeRequestStarted {
|
||||
t.Fatalf("expected %s event", model.EventTypeRequestStarted)
|
||||
}
|
||||
if events[1].Type != model.EventTypeRequestFailed {
|
||||
t.Fatalf("expected %s event", model.EventTypeRequestFailed)
|
||||
}
|
||||
if events[1].Request.Method != http.MethodConnect {
|
||||
t.Fatalf("failed event request method = %q, want CONNECT", events[1].Request.Method)
|
||||
}
|
||||
if events[1].Request.Status != tt.wantStat {
|
||||
t.Fatalf("failed event request status = %d, want %d", events[1].Request.Status, tt.wantStat)
|
||||
}
|
||||
if !events[1].Request.Failed || events[1].Request.Pending {
|
||||
t.Fatalf("failed event request flags = pending:%v failed:%v, want pending:false failed:true", events[1].Request.Pending, events[1].Request.Failed)
|
||||
}
|
||||
if !strings.Contains(events[1].Body, tt.wantBody) {
|
||||
t.Fatalf("failed event body %q does not contain %q", events[1].Body, tt.wantBody)
|
||||
}
|
||||
|
||||
if recorder, ok := tt.writer.(*httptest.ResponseRecorder); ok {
|
||||
if recorder.Code != tt.wantCode {
|
||||
t.Fatalf("status = %d, want %d", recorder.Code, tt.wantCode)
|
||||
}
|
||||
}
|
||||
if w, ok := tt.writer.(*hijackFailWriter); ok {
|
||||
if w.code != tt.wantCode {
|
||||
t.Fatalf("status = %d, want %d", w.code, tt.wantCode)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TODO: Add full TLS MITM loop integration test.
|
||||
157
internal/proxy/headers_test.go
Normal file
157
internal/proxy/headers_test.go
Normal file
@ -0,0 +1,157 @@
|
||||
package proxy
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"reflect"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestStripHopByHopHeaders(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
in http.Header
|
||||
want http.Header
|
||||
}{
|
||||
{
|
||||
name: "removes static hop-by-hop headers",
|
||||
in: http.Header{
|
||||
"Connection": {"keep-alive"},
|
||||
"Keep-Alive": {"timeout=5"},
|
||||
"Proxy-Connection": {"keep-alive"},
|
||||
"Transfer-Encoding": {"chunked"},
|
||||
"Upgrade": {"websocket"},
|
||||
"X-Custom": {"ok"},
|
||||
},
|
||||
want: http.Header{
|
||||
"X-Custom": {"ok"},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "removes headers listed in connection with spaces and commas",
|
||||
in: http.Header{
|
||||
"Connection": {" keep-alive, X-Foo ,X-Bar"},
|
||||
"Keep-Alive": {"timeout=5"},
|
||||
"X-Foo": {"foo"},
|
||||
"X-Bar": {"bar"},
|
||||
"Content-Type": {"application/json"},
|
||||
"Content-Length": {"12"},
|
||||
},
|
||||
want: http.Header{
|
||||
"Content-Type": {"application/json"},
|
||||
"Content-Length": {"12"},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "keeps unrelated headers when no connection header",
|
||||
in: http.Header{
|
||||
"Accept": {"*/*"},
|
||||
"Authorization": {"Bearer abc"},
|
||||
},
|
||||
want: http.Header{
|
||||
"Accept": {"*/*"},
|
||||
"Authorization": {"Bearer abc"},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
tt := tt
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
stripHopByHopHeaders(tt.in)
|
||||
|
||||
if !reflect.DeepEqual(tt.in, tt.want) {
|
||||
t.Fatalf("stripHopByHopHeaders() = %#v, want %#v", tt.in, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestStripHopByHopHeaders_NilHeader(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
stripHopByHopHeaders(nil)
|
||||
}
|
||||
|
||||
func TestRedactHeaders(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
input http.Header
|
||||
want http.Header
|
||||
}{
|
||||
{
|
||||
name: "redacts sensitive canonical headers",
|
||||
input: http.Header{
|
||||
"Authorization": {"Bearer token123"},
|
||||
"Cookie": {"session=abc"},
|
||||
"Proxy-Authorization": {"Basic abc"},
|
||||
"Set-Cookie": {"a=1"},
|
||||
"X-Api-Key": {"secret"},
|
||||
},
|
||||
want: http.Header{
|
||||
"Authorization": {"[REDACTED]"},
|
||||
"Cookie": {"[REDACTED]"},
|
||||
"Proxy-Authorization": {"[REDACTED]"},
|
||||
"Set-Cookie": {"[REDACTED]"},
|
||||
"X-Api-Key": {"[REDACTED]"},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "leaves non-sensitive headers untouched",
|
||||
input: http.Header{
|
||||
"Content-Type": {"application/json"},
|
||||
"X-Trace-ID": {"trace-1"},
|
||||
},
|
||||
want: http.Header{
|
||||
"Content-Type": {"application/json"},
|
||||
"X-Trace-ID": {"trace-1"},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
tt := tt
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
got := redactHeaders(tt.input)
|
||||
if !reflect.DeepEqual(got, tt.want) {
|
||||
t.Fatalf("redactHeaders() = %#v, want %#v", got, tt.want)
|
||||
}
|
||||
|
||||
got.Set("Content-Type", "modified")
|
||||
if reflect.DeepEqual(tt.input, got) {
|
||||
t.Fatal("redactHeaders() appears to mutate input or return aliased map")
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestCopyHeaders(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
src := http.Header{
|
||||
"X-Multi": {"a", "b", "c"},
|
||||
"Content-Type": {"application/json"},
|
||||
}
|
||||
dest := http.Header{
|
||||
"Existing": {"keep"},
|
||||
}
|
||||
|
||||
copyHeaders(src, dest)
|
||||
|
||||
want := http.Header{
|
||||
"Existing": {"keep"},
|
||||
"X-Multi": {"a", "b", "c"},
|
||||
"Content-Type": {"application/json"},
|
||||
}
|
||||
|
||||
if !reflect.DeepEqual(dest, want) {
|
||||
t.Fatalf("copyHeaders() dest = %#v, want %#v", dest, want)
|
||||
}
|
||||
}
|
||||
129
internal/proxy/integration_http_test.go
Normal file
129
internal/proxy/integration_http_test.go
Normal file
@ -0,0 +1,129 @@
|
||||
package proxy
|
||||
|
||||
import (
|
||||
"io"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"net/url"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"termtap.dev/internal/model"
|
||||
)
|
||||
|
||||
func TestHTTPProxyE2E_WithClientProxyConfig(t *testing.T) {
|
||||
configRoot := t.TempDir()
|
||||
t.Setenv("XDG_CONFIG_HOME", configRoot)
|
||||
|
||||
upstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(http.StatusOK)
|
||||
_, _ = w.Write([]byte("pong"))
|
||||
}))
|
||||
t.Cleanup(upstream.Close)
|
||||
|
||||
ch := make(chan model.Event, 64)
|
||||
ps, err := NewProxyServer("127.0.0.1:0", ch)
|
||||
if err != nil {
|
||||
t.Fatalf("NewProxyServer() error = %v", err)
|
||||
}
|
||||
|
||||
serveDone := make(chan error, 1)
|
||||
go func() {
|
||||
serveDone <- ps.Server.Serve(*ps.Listener)
|
||||
}()
|
||||
|
||||
t.Cleanup(func() {
|
||||
Destroy(ps, ch)
|
||||
select {
|
||||
case <-serveDone:
|
||||
case <-time.After(2 * time.Second):
|
||||
t.Fatal("timeout waiting for proxy server to stop")
|
||||
}
|
||||
})
|
||||
|
||||
proxyURL, err := url.Parse(ps.Url)
|
||||
if err != nil {
|
||||
t.Fatalf("url.Parse(proxy) error = %v", err)
|
||||
}
|
||||
client := &http.Client{
|
||||
Transport: &http.Transport{Proxy: http.ProxyURL(proxyURL)},
|
||||
Timeout: 3 * time.Second,
|
||||
}
|
||||
|
||||
resp, err := client.Get(upstream.URL + "/ping")
|
||||
if err != nil {
|
||||
t.Fatalf("client.Get() error = %v", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
body, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
t.Fatalf("ReadAll(response) error = %v", err)
|
||||
}
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
t.Fatalf("status = %d, want %d", resp.StatusCode, http.StatusOK)
|
||||
}
|
||||
if got, want := string(body), "pong"; got != want {
|
||||
t.Fatalf("body = %q, want %q", got, want)
|
||||
}
|
||||
|
||||
events := drainEvents(t, ch, 2, 2*time.Second)
|
||||
if !hasEventType(events, model.EventTypeRequestStarted) {
|
||||
t.Fatalf("missing %s event", model.EventTypeRequestStarted)
|
||||
}
|
||||
if !hasEventType(events, model.EventTypeRequestFinished) {
|
||||
t.Fatalf("missing %s event", model.EventTypeRequestFinished)
|
||||
}
|
||||
}
|
||||
|
||||
func TestHTTPProxyE2E_UpstreamFailureEmitsRequestFailed(t *testing.T) {
|
||||
configRoot := t.TempDir()
|
||||
t.Setenv("XDG_CONFIG_HOME", configRoot)
|
||||
|
||||
ch := make(chan model.Event, 64)
|
||||
ps, err := NewProxyServer("127.0.0.1:0", ch)
|
||||
if err != nil {
|
||||
t.Fatalf("NewProxyServer() error = %v", err)
|
||||
}
|
||||
|
||||
serveDone := make(chan error, 1)
|
||||
go func() {
|
||||
serveDone <- ps.Server.Serve(*ps.Listener)
|
||||
}()
|
||||
|
||||
t.Cleanup(func() {
|
||||
Destroy(ps, ch)
|
||||
select {
|
||||
case <-serveDone:
|
||||
case <-time.After(2 * time.Second):
|
||||
t.Fatal("timeout waiting for proxy server to stop")
|
||||
}
|
||||
})
|
||||
|
||||
proxyURL, err := url.Parse(ps.Url)
|
||||
if err != nil {
|
||||
t.Fatalf("url.Parse(proxy) error = %v", err)
|
||||
}
|
||||
client := &http.Client{
|
||||
Transport: &http.Transport{Proxy: http.ProxyURL(proxyURL)},
|
||||
Timeout: 3 * time.Second,
|
||||
}
|
||||
|
||||
resp, reqErr := client.Get("http://127.0.0.1:1/unreachable")
|
||||
if reqErr != nil {
|
||||
t.Fatalf("client.Get() error = %v; proxy should reply with mapped HTTP status", reqErr)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode != http.StatusBadGateway {
|
||||
t.Fatalf("status = %d, want %d", resp.StatusCode, http.StatusBadGateway)
|
||||
}
|
||||
|
||||
events := drainEvents(t, ch, 2, 3*time.Second)
|
||||
if !hasEventType(events, model.EventTypeRequestStarted) {
|
||||
t.Fatalf("missing %s event", model.EventTypeRequestStarted)
|
||||
}
|
||||
if !hasEventType(events, model.EventTypeRequestFailed) {
|
||||
t.Fatalf("missing %s event", model.EventTypeRequestFailed)
|
||||
}
|
||||
}
|
||||
301
internal/proxy/integration_https_mitm_test.go
Normal file
301
internal/proxy/integration_https_mitm_test.go
Normal file
@ -0,0 +1,301 @@
|
||||
package proxy
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"crypto/tls"
|
||||
"crypto/x509"
|
||||
"errors"
|
||||
"io"
|
||||
"net"
|
||||
"net/http"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"termtap.dev/internal/model"
|
||||
)
|
||||
|
||||
type hijackPipeWriter struct {
|
||||
conn net.Conn
|
||||
|
||||
header http.Header
|
||||
code int
|
||||
}
|
||||
|
||||
func (w *hijackPipeWriter) Header() http.Header {
|
||||
if w.header == nil {
|
||||
w.header = make(http.Header)
|
||||
}
|
||||
return w.header
|
||||
}
|
||||
|
||||
func (w *hijackPipeWriter) WriteHeader(statusCode int) {
|
||||
w.code = statusCode
|
||||
}
|
||||
|
||||
func (w *hijackPipeWriter) Write(p []byte) (int, error) {
|
||||
return len(p), nil
|
||||
}
|
||||
|
||||
func (w *hijackPipeWriter) Hijack() (net.Conn, *bufio.ReadWriter, error) {
|
||||
return w.conn, nil, nil
|
||||
}
|
||||
|
||||
func startMITMConnect(t *testing.T, transport http.RoundTripper) (net.Conn, *tls.Conn, chan model.Event, chan struct{}) {
|
||||
t.Helper()
|
||||
|
||||
ca := newTestCA(t)
|
||||
clientConn, serverConn := net.Pipe()
|
||||
|
||||
writer := &hijackPipeWriter{conn: serverConn}
|
||||
req, err := http.NewRequest(http.MethodConnect, "http://example.com:443", nil)
|
||||
if err != nil {
|
||||
t.Fatalf("NewRequest(CONNECT) error = %v", err)
|
||||
}
|
||||
req.Host = "example.com:443"
|
||||
|
||||
ch := make(chan model.Event, 16)
|
||||
ps := &model.ProxyServer{Conns: make(map[net.Conn]struct{})}
|
||||
|
||||
handleDone := make(chan struct{})
|
||||
go func() {
|
||||
handleConnect(writer, req, ch, transport, ca, ps)
|
||||
close(handleDone)
|
||||
}()
|
||||
|
||||
reader := bufio.NewReader(clientConn)
|
||||
connectResp, err := http.ReadResponse(reader, &http.Request{Method: http.MethodConnect})
|
||||
if err != nil {
|
||||
_ = clientConn.Close()
|
||||
_ = serverConn.Close()
|
||||
t.Fatalf("ReadResponse(CONNECT established) error = %v", err)
|
||||
}
|
||||
if connectResp.StatusCode != http.StatusOK {
|
||||
_ = clientConn.Close()
|
||||
_ = serverConn.Close()
|
||||
t.Fatalf("CONNECT established status = %d, want %d", connectResp.StatusCode, http.StatusOK)
|
||||
}
|
||||
|
||||
pool := x509.NewCertPool()
|
||||
pool.AddCert(ca.cert)
|
||||
tlsClient := tls.Client(clientConn, &tls.Config{
|
||||
ServerName: "example.com",
|
||||
RootCAs: pool,
|
||||
MinVersion: tls.VersionTLS12,
|
||||
})
|
||||
|
||||
if err := tlsClient.Handshake(); err != nil {
|
||||
_ = tlsClient.Close()
|
||||
_ = serverConn.Close()
|
||||
t.Fatalf("tls handshake error = %v", err)
|
||||
}
|
||||
|
||||
t.Cleanup(func() {
|
||||
_ = tlsClient.Close()
|
||||
_ = serverConn.Close()
|
||||
})
|
||||
|
||||
return clientConn, tlsClient, ch, handleDone
|
||||
}
|
||||
|
||||
func TestHTTPSE2E_MITMHandleConnectFlow(t *testing.T) {
|
||||
transport := &mockTransport{fn: func(r *http.Request) (*http.Response, error) {
|
||||
if r.URL.Scheme != "https" {
|
||||
t.Fatalf("inner request scheme = %q, want https", r.URL.Scheme)
|
||||
}
|
||||
if r.URL.Host == "" {
|
||||
t.Fatal("inner request host should be populated")
|
||||
}
|
||||
|
||||
return &http.Response{
|
||||
StatusCode: http.StatusCreated,
|
||||
ProtoMajor: 1,
|
||||
ProtoMinor: 1,
|
||||
Header: http.Header{"Content-Type": {"text/plain"}},
|
||||
Body: io.NopCloser(strings.NewReader("mitm-ok")),
|
||||
}, nil
|
||||
}}
|
||||
clientConn, tlsClient, ch, handleDone := startMITMConnect(t, transport)
|
||||
|
||||
// 3) Send decrypted inner request over tunnel and read MITM response
|
||||
innerReq, err := http.NewRequest(http.MethodGet, "https://example.com/inside", nil)
|
||||
if err != nil {
|
||||
t.Fatalf("NewRequest(inner) error = %v", err)
|
||||
}
|
||||
innerReq.Host = "example.com"
|
||||
innerReq.Close = true
|
||||
if err := innerReq.Write(tlsClient); err != nil {
|
||||
t.Fatalf("innerReq.Write() error = %v", err)
|
||||
}
|
||||
|
||||
innerResp, err := http.ReadResponse(bufio.NewReader(tlsClient), innerReq)
|
||||
if err != nil {
|
||||
t.Fatalf("ReadResponse(inner) error = %v", err)
|
||||
}
|
||||
defer innerResp.Body.Close()
|
||||
|
||||
body, err := io.ReadAll(innerResp.Body)
|
||||
if err != nil {
|
||||
t.Fatalf("ReadAll(inner body) error = %v", err)
|
||||
}
|
||||
if innerResp.StatusCode != http.StatusCreated {
|
||||
t.Fatalf("inner status = %d, want %d", innerResp.StatusCode, http.StatusCreated)
|
||||
}
|
||||
if got, want := string(body), "mitm-ok"; got != want {
|
||||
t.Fatalf("inner body = %q, want %q", got, want)
|
||||
}
|
||||
|
||||
_ = tlsClient.Close()
|
||||
_ = clientConn.Close()
|
||||
|
||||
select {
|
||||
case <-handleDone:
|
||||
case <-time.After(3 * time.Second):
|
||||
t.Fatal("timeout waiting for handleConnect to return")
|
||||
}
|
||||
|
||||
events := drainEvents(t, ch, 4, 2*time.Second)
|
||||
if events[0].Type != model.EventTypeRequestStarted {
|
||||
t.Fatalf("event[0] = %s, want %s", events[0].Type, model.EventTypeRequestStarted)
|
||||
}
|
||||
if events[1].Type != model.EventTypeRequestStarted {
|
||||
t.Fatalf("event[1] = %s, want %s", events[1].Type, model.EventTypeRequestStarted)
|
||||
}
|
||||
if events[2].Type != model.EventTypeRequestFinished {
|
||||
t.Fatalf("event[2] = %s, want %s", events[2].Type, model.EventTypeRequestFinished)
|
||||
}
|
||||
if events[3].Type != model.EventTypeRequestFinished {
|
||||
t.Fatalf("event[3] = %s, want %s", events[3].Type, model.EventTypeRequestFinished)
|
||||
}
|
||||
}
|
||||
|
||||
func TestHTTPSE2E_MITMUpstreamErrorReturnsHTTPErrorInsideTunnel(t *testing.T) {
|
||||
transport := &mockTransport{fn: func(*http.Request) (*http.Response, error) {
|
||||
return nil, errors.New("upstream exploded")
|
||||
}}
|
||||
|
||||
clientConn, tlsClient, ch, handleDone := startMITMConnect(t, transport)
|
||||
|
||||
innerReq, err := http.NewRequest(http.MethodGet, "https://example.com/fail", nil)
|
||||
if err != nil {
|
||||
t.Fatalf("NewRequest(inner) error = %v", err)
|
||||
}
|
||||
innerReq.Host = "example.com"
|
||||
if err := innerReq.Write(tlsClient); err != nil {
|
||||
t.Fatalf("innerReq.Write() error = %v", err)
|
||||
}
|
||||
|
||||
innerResp, err := http.ReadResponse(bufio.NewReader(tlsClient), innerReq)
|
||||
if err != nil {
|
||||
t.Fatalf("ReadResponse(inner) error = %v", err)
|
||||
}
|
||||
defer innerResp.Body.Close()
|
||||
|
||||
if innerResp.StatusCode != http.StatusBadGateway {
|
||||
t.Fatalf("inner status = %d, want %d", innerResp.StatusCode, http.StatusBadGateway)
|
||||
}
|
||||
|
||||
_ = tlsClient.Close()
|
||||
_ = clientConn.Close()
|
||||
|
||||
select {
|
||||
case <-handleDone:
|
||||
case <-time.After(3 * time.Second):
|
||||
t.Fatal("timeout waiting for handleConnect to return")
|
||||
}
|
||||
|
||||
events := drainEvents(t, ch, 4, 2*time.Second)
|
||||
if events[0].Type != model.EventTypeRequestStarted ||
|
||||
events[1].Type != model.EventTypeRequestStarted ||
|
||||
events[2].Type != model.EventTypeRequestFailed ||
|
||||
events[3].Type != model.EventTypeRequestFailed {
|
||||
t.Fatalf("unexpected event sequence: %#v", events)
|
||||
}
|
||||
}
|
||||
|
||||
func TestHTTPSE2E_MITMHandshakeFailureEmitsConnectFailed(t *testing.T) {
|
||||
ca := newTestCA(t)
|
||||
|
||||
clientConn, serverConn := net.Pipe()
|
||||
t.Cleanup(func() {
|
||||
_ = clientConn.Close()
|
||||
_ = serverConn.Close()
|
||||
})
|
||||
|
||||
writer := &hijackPipeWriter{conn: serverConn}
|
||||
req, err := http.NewRequest(http.MethodConnect, "http://example.com:443", nil)
|
||||
if err != nil {
|
||||
t.Fatalf("NewRequest(CONNECT) error = %v", err)
|
||||
}
|
||||
req.Host = "example.com:443"
|
||||
|
||||
ch := make(chan model.Event, 16)
|
||||
ps := &model.ProxyServer{Conns: make(map[net.Conn]struct{})}
|
||||
transport := &mockTransport{fn: func(*http.Request) (*http.Response, error) {
|
||||
return nil, errors.New("should not be called")
|
||||
}}
|
||||
|
||||
handleDone := make(chan struct{})
|
||||
go func() {
|
||||
handleConnect(writer, req, ch, transport, ca, ps)
|
||||
close(handleDone)
|
||||
}()
|
||||
|
||||
reader := bufio.NewReader(clientConn)
|
||||
connectResp, err := http.ReadResponse(reader, &http.Request{Method: http.MethodConnect})
|
||||
if err != nil {
|
||||
t.Fatalf("ReadResponse(CONNECT established) error = %v", err)
|
||||
}
|
||||
if connectResp.StatusCode != http.StatusOK {
|
||||
t.Fatalf("CONNECT established status = %d, want %d", connectResp.StatusCode, http.StatusOK)
|
||||
}
|
||||
|
||||
// Write plaintext (not TLS handshake) to force handshake failure.
|
||||
if _, err := clientConn.Write([]byte("not tls handshake")); err != nil {
|
||||
t.Fatalf("client write error = %v", err)
|
||||
}
|
||||
_ = clientConn.Close()
|
||||
|
||||
select {
|
||||
case <-handleDone:
|
||||
case <-time.After(3 * time.Second):
|
||||
t.Fatal("timeout waiting for handleConnect to return on handshake failure")
|
||||
}
|
||||
|
||||
events := drainEvents(t, ch, 2, 2*time.Second)
|
||||
if events[0].Type != model.EventTypeRequestStarted || events[1].Type != model.EventTypeRequestFailed {
|
||||
t.Fatalf("unexpected event sequence: %#v", events)
|
||||
}
|
||||
if !strings.Contains(events[1].Body, "TLS handshake with client failed") {
|
||||
t.Fatalf("failed event body = %q, want handshake failure details", events[1].Body)
|
||||
}
|
||||
}
|
||||
|
||||
func TestHTTPSE2E_MITMDecryptedReadFailureEmitsConnectFailed(t *testing.T) {
|
||||
transport := &mockTransport{fn: func(*http.Request) (*http.Response, error) {
|
||||
return nil, errors.New("should not be called")
|
||||
}}
|
||||
|
||||
clientConn, tlsClient, ch, handleDone := startMITMConnect(t, transport)
|
||||
|
||||
// Send malformed decrypted payload so ReadRequest fails (non-EOF branch).
|
||||
if _, err := tlsClient.Write([]byte("BAD REQUEST\r\n\r\n")); err != nil {
|
||||
t.Fatalf("tlsClient.Write() error = %v", err)
|
||||
}
|
||||
_ = tlsClient.Close()
|
||||
_ = clientConn.Close()
|
||||
|
||||
select {
|
||||
case <-handleDone:
|
||||
case <-time.After(3 * time.Second):
|
||||
t.Fatal("timeout waiting for handleConnect to return on decrypted read failure")
|
||||
}
|
||||
|
||||
events := drainEvents(t, ch, 2, 2*time.Second)
|
||||
if events[0].Type != model.EventTypeRequestStarted || events[1].Type != model.EventTypeRequestFailed {
|
||||
t.Fatalf("unexpected event sequence: %#v", events)
|
||||
}
|
||||
if !strings.Contains(events[1].Body, "failed to read decrypted HTTPS request") {
|
||||
t.Fatalf("failed event body = %q, want read failure details", events[1].Body)
|
||||
}
|
||||
}
|
||||
142
internal/proxy/preview_test.go
Normal file
142
internal/proxy/preview_test.go
Normal file
@ -0,0 +1,142 @@
|
||||
package proxy
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"io"
|
||||
"net"
|
||||
"strings"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestNewBodyPreview(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
contentType string
|
||||
wantEnabled bool
|
||||
}{
|
||||
{name: "text content enabled", contentType: "text/plain", wantEnabled: true},
|
||||
{name: "json content enabled", contentType: "application/json", wantEnabled: true},
|
||||
{name: "binary content disabled", contentType: "application/octet-stream", wantEnabled: false},
|
||||
{name: "empty content type disabled", contentType: "", wantEnabled: false},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
tt := tt
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
p := newBodyPreview(tt.contentType)
|
||||
if p.enabled != tt.wantEnabled {
|
||||
t.Fatalf("newBodyPreview(%q).enabled = %v, want %v", tt.contentType, p.enabled, tt.wantEnabled)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestBodyPreviewWriteAndPreview(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
t.Run("nil receiver is safe", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
var p *bodyPreview
|
||||
p.Write([]byte("abc"))
|
||||
})
|
||||
|
||||
t.Run("disabled preview ignores data", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
p := &bodyPreview{enabled: false}
|
||||
p.Write([]byte("abc"))
|
||||
if got := string(p.Preview()); got != "" {
|
||||
t.Fatalf("Preview() = %q, want empty", got)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("empty write does nothing", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
p := &bodyPreview{enabled: true}
|
||||
p.Write(nil)
|
||||
if got := string(p.Preview()); got != "" {
|
||||
t.Fatalf("Preview() = %q, want empty", got)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("escapes newlines", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
p := &bodyPreview{enabled: true}
|
||||
p.Write([]byte("a\nb"))
|
||||
if got, want := string(p.Preview()), `a\nb`; got != want {
|
||||
t.Fatalf("Preview() = %q, want %q", got, want)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("truncates at max preview bytes and appends ellipsis", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
p := &bodyPreview{enabled: true}
|
||||
p.Write([]byte(strings.Repeat("a", maxPreviewBytes)))
|
||||
p.Write([]byte("b"))
|
||||
|
||||
got := string(p.Preview())
|
||||
if !strings.HasSuffix(got, "...") {
|
||||
t.Fatalf("Preview() must end with ellipsis when truncated: %q", got[len(got)-10:])
|
||||
}
|
||||
if len(got) != maxPreviewBytes+3 {
|
||||
t.Fatalf("len(Preview()) = %d, want %d", len(got), maxPreviewBytes+3)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestWrapBufferedConn(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
client, server := net.Pipe()
|
||||
t.Cleanup(func() {
|
||||
_ = client.Close()
|
||||
_ = server.Close()
|
||||
})
|
||||
|
||||
t.Run("returns original conn when readWriter nil", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
got := wrapBufferedConn(client, nil)
|
||||
if got != client {
|
||||
t.Fatal("wrapBufferedConn should return original conn when readWriter is nil")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("read uses buffered readWriter", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
rw := bufio.NewReadWriter(bufio.NewReader(strings.NewReader("xyz")), bufio.NewWriter(io.Discard))
|
||||
got := wrapBufferedConn(client, rw)
|
||||
|
||||
buf := make([]byte, 3)
|
||||
n, err := got.Read(buf)
|
||||
if err != nil {
|
||||
t.Fatalf("Read() error = %v", err)
|
||||
}
|
||||
if n != 3 || string(buf) != "xyz" {
|
||||
t.Fatalf("Read() = (%d, %q), want (3, %q)", n, string(buf), "xyz")
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestPreviewReadCloserRead(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
preview := newBodyPreview("text/plain")
|
||||
rc := &previewReadCloser{
|
||||
ReadCloser: io.NopCloser(strings.NewReader("hello\nworld")),
|
||||
preview: preview,
|
||||
}
|
||||
|
||||
data, err := io.ReadAll(rc)
|
||||
if err != nil {
|
||||
t.Fatalf("ReadAll() error = %v", err)
|
||||
}
|
||||
if got, want := string(data), "hello\nworld"; got != want {
|
||||
t.Fatalf("read content = %q, want %q", got, want)
|
||||
}
|
||||
if got, want := string(preview.Preview()), `hello\nworld`; got != want {
|
||||
t.Fatalf("preview = %q, want %q", got, want)
|
||||
}
|
||||
}
|
||||
275
internal/proxy/requests_test.go
Normal file
275
internal/proxy/requests_test.go
Normal file
@ -0,0 +1,275 @@
|
||||
package proxy
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"termtap.dev/internal/model"
|
||||
)
|
||||
|
||||
type mockTransport struct {
|
||||
fn func(*http.Request) (*http.Response, error)
|
||||
}
|
||||
|
||||
func (m *mockTransport) RoundTrip(req *http.Request) (*http.Response, error) {
|
||||
return m.fn(req)
|
||||
}
|
||||
|
||||
func TestRoundTripCapturedRequest_Success(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
ch := make(chan model.Event, 8)
|
||||
|
||||
reqURL, err := url.Parse("http://example.com/path?q=1")
|
||||
if err != nil {
|
||||
t.Fatalf("url.Parse() error = %v", err)
|
||||
}
|
||||
req := &http.Request{
|
||||
Method: http.MethodPost,
|
||||
URL: reqURL,
|
||||
Host: "example.com",
|
||||
Header: http.Header{
|
||||
"Connection": {"X-Hop"},
|
||||
"X-Hop": {"drop"},
|
||||
"Authorization": {"Bearer token"},
|
||||
"Content-Type": {"text/plain"},
|
||||
},
|
||||
Body: io.NopCloser(strings.NewReader("req\nbody")),
|
||||
}
|
||||
|
||||
transport := &mockTransport{fn: func(outReq *http.Request) (*http.Response, error) {
|
||||
if got := outReq.Header.Get("Connection"); got != "" {
|
||||
t.Fatalf("Connection header should be stripped, got %q", got)
|
||||
}
|
||||
if got := outReq.Header.Get("X-Hop"); got != "" {
|
||||
t.Fatalf("header listed in Connection should be stripped, got %q", got)
|
||||
}
|
||||
|
||||
data, readErr := io.ReadAll(outReq.Body)
|
||||
if readErr != nil {
|
||||
t.Fatalf("ReadAll(outReq.Body) error = %v", readErr)
|
||||
}
|
||||
if got, want := string(data), "req\nbody"; got != want {
|
||||
t.Fatalf("request body = %q, want %q", got, want)
|
||||
}
|
||||
|
||||
return &http.Response{
|
||||
StatusCode: http.StatusCreated,
|
||||
Header: http.Header{
|
||||
"Set-Cookie": {"session=top-secret"},
|
||||
"Connection": {"close"},
|
||||
"Content-Type": {"text/plain"},
|
||||
},
|
||||
Body: io.NopCloser(strings.NewReader("resp\nbody")),
|
||||
}, nil
|
||||
}}
|
||||
|
||||
resp, captured, responsePreview, err := roundTripCapturedRequest(req, transport, ch, "", false)
|
||||
if err != nil {
|
||||
t.Fatalf("roundTripCapturedRequest() error = %v", err)
|
||||
}
|
||||
if resp == nil {
|
||||
t.Fatal("roundTripCapturedRequest() returned nil response")
|
||||
}
|
||||
if responsePreview == nil {
|
||||
t.Fatal("roundTripCapturedRequest() returned nil response preview")
|
||||
}
|
||||
|
||||
if _, readErr := io.ReadAll(resp.Body); readErr != nil {
|
||||
t.Fatalf("ReadAll(resp.Body) error = %v", readErr)
|
||||
}
|
||||
_ = resp.Body.Close()
|
||||
|
||||
if got, want := string(captured.RequestData), `req\nbody`; got != want {
|
||||
t.Fatalf("captured.RequestData = %q, want %q", got, want)
|
||||
}
|
||||
if got, want := string(responsePreview.Preview()), `resp\nbody`; got != want {
|
||||
t.Fatalf("responsePreview = %q, want %q", got, want)
|
||||
}
|
||||
if got := captured.RequestHeaders.Get("Authorization"); got != "[REDACTED]" {
|
||||
t.Fatalf("Authorization header = %q, want [REDACTED]", got)
|
||||
}
|
||||
if got := captured.ResponseHeaders.Get("Set-Cookie"); got != "[REDACTED]" {
|
||||
t.Fatalf("Set-Cookie header = %q, want [REDACTED]", got)
|
||||
}
|
||||
if got := captured.ResponseHeaders.Get("Connection"); got != "" {
|
||||
t.Fatalf("Connection should be stripped from response headers, got %q", got)
|
||||
}
|
||||
|
||||
events := drainEvents(t, ch, 1, time.Second)
|
||||
if events[0].Type != model.EventTypeRequestStarted {
|
||||
t.Fatalf("event type = %s, want %s", events[0].Type, model.EventTypeRequestStarted)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRoundTripCapturedRequest_ErrorPath(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
ch := make(chan model.Event, 8)
|
||||
|
||||
reqURL, err := url.Parse("http://example.com/fail")
|
||||
if err != nil {
|
||||
t.Fatalf("url.Parse() error = %v", err)
|
||||
}
|
||||
req := &http.Request{
|
||||
Method: http.MethodPost,
|
||||
URL: reqURL,
|
||||
Host: "example.com",
|
||||
Header: http.Header{"Content-Type": {"text/plain"}},
|
||||
Body: io.NopCloser(strings.NewReader("boom\nbody")),
|
||||
}
|
||||
|
||||
wantErr := errors.New("upstream failed")
|
||||
transport := &mockTransport{fn: func(outReq *http.Request) (*http.Response, error) {
|
||||
_, _ = io.ReadAll(outReq.Body)
|
||||
return nil, wantErr
|
||||
}}
|
||||
|
||||
resp, captured, responsePreview, gotErr := roundTripCapturedRequest(req, transport, ch, "", false)
|
||||
if !errors.Is(gotErr, wantErr) {
|
||||
t.Fatalf("error = %v, want %v", gotErr, wantErr)
|
||||
}
|
||||
if resp != nil {
|
||||
t.Fatalf("response = %#v, want nil", resp)
|
||||
}
|
||||
if responsePreview != nil {
|
||||
t.Fatalf("responsePreview = %#v, want nil", responsePreview)
|
||||
}
|
||||
if got, want := string(captured.RequestData), `boom\nbody`; got != want {
|
||||
t.Fatalf("captured.RequestData = %q, want %q", got, want)
|
||||
}
|
||||
|
||||
events := drainEvents(t, ch, 1, time.Second)
|
||||
if events[0].Type != model.EventTypeRequestStarted {
|
||||
t.Fatalf("event type = %s, want %s", events[0].Type, model.EventTypeRequestStarted)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRoundTripCapturedRequest_InterceptedTLSDefaults(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
ch := make(chan model.Event, 8)
|
||||
|
||||
u, err := url.Parse("/secure?p=1")
|
||||
if err != nil {
|
||||
t.Fatalf("url.Parse() error = %v", err)
|
||||
}
|
||||
req := &http.Request{
|
||||
Method: http.MethodGet,
|
||||
URL: u,
|
||||
Header: http.Header{},
|
||||
}
|
||||
|
||||
const defaultHost = "api.example.com:443"
|
||||
transport := &mockTransport{fn: func(outReq *http.Request) (*http.Response, error) {
|
||||
if got := outReq.URL.Scheme; got != "https" {
|
||||
t.Fatalf("URL.Scheme = %q, want https", got)
|
||||
}
|
||||
if got := outReq.URL.Host; got != defaultHost {
|
||||
t.Fatalf("URL.Host = %q, want %q", got, defaultHost)
|
||||
}
|
||||
if got := outReq.Host; got != defaultHost {
|
||||
t.Fatalf("Host = %q, want %q", got, defaultHost)
|
||||
}
|
||||
|
||||
return &http.Response{
|
||||
StatusCode: http.StatusNoContent,
|
||||
Header: http.Header{"Content-Type": {"text/plain"}},
|
||||
Body: io.NopCloser(strings.NewReader("")),
|
||||
}, nil
|
||||
}}
|
||||
|
||||
_, _, _, gotErr := roundTripCapturedRequest(req, transport, ch, defaultHost, true)
|
||||
if gotErr != nil {
|
||||
t.Fatalf("roundTripCapturedRequest() error = %v", gotErr)
|
||||
}
|
||||
}
|
||||
|
||||
func TestNewConnectRequest(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
now := time.Now().Add(-time.Second)
|
||||
req := &http.Request{Method: http.MethodConnect, Host: "example.com:443"}
|
||||
|
||||
got := newConnectRequest(req, now)
|
||||
|
||||
if got.Method != http.MethodConnect {
|
||||
t.Fatalf("Method = %q, want CONNECT", got.Method)
|
||||
}
|
||||
if got.Host != req.Host || got.URL != req.Host || got.RawURL != req.Host {
|
||||
t.Fatalf("connect request host/url/raw mismatch: %#v", got)
|
||||
}
|
||||
if !got.Pending || got.Failed {
|
||||
t.Fatalf("Pending/Failed = (%v,%v), want (true,false)", got.Pending, got.Failed)
|
||||
}
|
||||
if got.Status != -1 {
|
||||
t.Fatalf("Status = %d, want -1", got.Status)
|
||||
}
|
||||
if got.StartTime != now {
|
||||
t.Fatalf("StartTime = %v, want %v", got.StartTime, now)
|
||||
}
|
||||
if got.ID == uuid.Nil {
|
||||
t.Fatal("ID must be non-zero UUID")
|
||||
}
|
||||
}
|
||||
|
||||
func TestStartFinishFailRequestEvents(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
ch := make(chan model.Event, 4)
|
||||
req := model.Request{
|
||||
ID: uuid.New(),
|
||||
Method: http.MethodGet,
|
||||
RawURL: "http://example.com/a",
|
||||
StartTime: time.Now().Add(-3 * time.Millisecond),
|
||||
Pending: true,
|
||||
}
|
||||
|
||||
startRequest(ch, req)
|
||||
events := drainEvents(t, ch, 1, time.Second)
|
||||
startEv := events[0]
|
||||
if startEv.Type != model.EventTypeRequestStarted {
|
||||
t.Fatalf("start event type = %s, want %s", startEv.Type, model.EventTypeRequestStarted)
|
||||
}
|
||||
if startEv.Request.Pending != true {
|
||||
t.Fatalf("start request pending = %v, want true", startEv.Request.Pending)
|
||||
}
|
||||
|
||||
finishRequest(ch, req, http.StatusOK)
|
||||
events = drainEvents(t, ch, 1, time.Second)
|
||||
finishEv := events[0]
|
||||
if finishEv.Type != model.EventTypeRequestFinished {
|
||||
t.Fatalf("finish event type = %s, want %s", finishEv.Type, model.EventTypeRequestFinished)
|
||||
}
|
||||
if finishEv.Request.Pending {
|
||||
t.Fatal("finished request should not be pending")
|
||||
}
|
||||
if finishEv.Request.Failed {
|
||||
t.Fatal("finished request should not be failed")
|
||||
}
|
||||
if finishEv.Request.Status != http.StatusOK {
|
||||
t.Fatalf("finished status = %d, want %d", finishEv.Request.Status, http.StatusOK)
|
||||
}
|
||||
|
||||
failRequest(ch, req, http.StatusBadGateway, "upstream error")
|
||||
events = drainEvents(t, ch, 1, time.Second)
|
||||
failEv := events[0]
|
||||
if failEv.Type != model.EventTypeRequestFailed {
|
||||
t.Fatalf("fail event type = %s, want %s", failEv.Type, model.EventTypeRequestFailed)
|
||||
}
|
||||
if failEv.Request.Pending {
|
||||
t.Fatal("failed request should not be pending")
|
||||
}
|
||||
if !failEv.Request.Failed {
|
||||
t.Fatal("failed request should be marked failed")
|
||||
}
|
||||
if failEv.Request.Status != http.StatusBadGateway {
|
||||
t.Fatalf("failed status = %d, want %d", failEv.Request.Status, http.StatusBadGateway)
|
||||
}
|
||||
}
|
||||
138
internal/proxy/secure_utils_test.go
Normal file
138
internal/proxy/secure_utils_test.go
Normal file
@ -0,0 +1,138 @@
|
||||
package proxy
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"bytes"
|
||||
"io"
|
||||
"net"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
type captureConn struct {
|
||||
bytes.Buffer
|
||||
}
|
||||
|
||||
func (c *captureConn) Read(_ []byte) (int, error) { return 0, io.EOF }
|
||||
func (c *captureConn) Close() error { return nil }
|
||||
func (c *captureConn) LocalAddr() net.Addr { return &net.TCPAddr{} }
|
||||
func (c *captureConn) RemoteAddr() net.Addr { return &net.TCPAddr{} }
|
||||
func (c *captureConn) SetDeadline(_ time.Time) error { return nil }
|
||||
func (c *captureConn) SetReadDeadline(_ time.Time) error { return nil }
|
||||
func (c *captureConn) SetWriteDeadline(_ time.Time) error { return nil }
|
||||
|
||||
type failWriteConn struct{}
|
||||
|
||||
func (failWriteConn) Read(_ []byte) (int, error) { return 0, io.EOF }
|
||||
func (failWriteConn) Write(_ []byte) (int, error) { return 0, io.ErrClosedPipe }
|
||||
func (failWriteConn) Close() error { return nil }
|
||||
func (failWriteConn) LocalAddr() net.Addr { return &net.TCPAddr{} }
|
||||
func (failWriteConn) RemoteAddr() net.Addr { return &net.TCPAddr{} }
|
||||
func (failWriteConn) SetDeadline(_ time.Time) error { return nil }
|
||||
func (failWriteConn) SetReadDeadline(_ time.Time) error { return nil }
|
||||
func (failWriteConn) SetWriteDeadline(_ time.Time) error { return nil }
|
||||
|
||||
type trackingBody struct {
|
||||
data *bytes.Reader
|
||||
readN int
|
||||
closed bool
|
||||
}
|
||||
|
||||
func (b *trackingBody) Read(p []byte) (int, error) {
|
||||
n, err := b.data.Read(p)
|
||||
b.readN += n
|
||||
return n, err
|
||||
}
|
||||
|
||||
func (b *trackingBody) Close() error {
|
||||
b.closed = true
|
||||
return nil
|
||||
}
|
||||
|
||||
func TestWriteConnectEstablished(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
t.Run("writes directly to raw conn", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
conn := &captureConn{}
|
||||
if err := writeConnectEstablished(conn, nil); err != nil {
|
||||
t.Fatalf("writeConnectEstablished() error = %v", err)
|
||||
}
|
||||
|
||||
if got, want := conn.String(), "HTTP/1.1 200 Connection Established\r\n\r\n"; got != want {
|
||||
t.Fatalf("raw write = %q, want %q", got, want)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("writes and flushes with buffered readWriter", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
conn := &captureConn{}
|
||||
rw := bufio.NewReadWriter(bufio.NewReader(strings.NewReader("")), bufio.NewWriter(conn))
|
||||
if err := writeConnectEstablished(conn, rw); err != nil {
|
||||
t.Fatalf("writeConnectEstablished() error = %v", err)
|
||||
}
|
||||
|
||||
if got, want := conn.String(), "HTTP/1.1 200 Connection Established\r\n\r\n"; got != want {
|
||||
t.Fatalf("buffered write = %q, want %q", got, want)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("returns flush error", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
rw := bufio.NewReadWriter(bufio.NewReader(strings.NewReader("")), bufio.NewWriter(errWriter{}))
|
||||
err := writeConnectEstablished(&captureConn{}, rw)
|
||||
if err == nil {
|
||||
t.Fatal("writeConnectEstablished() error = nil, want non-nil")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("returns buffered write error when writer already failed", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
bw := bufio.NewWriter(errWriter{})
|
||||
_ = bw.Flush() // set sticky error to force WriteString error path
|
||||
rw := bufio.NewReadWriter(bufio.NewReader(strings.NewReader("")), bw)
|
||||
|
||||
err := writeConnectEstablished(&captureConn{}, rw)
|
||||
if err == nil {
|
||||
t.Fatal("writeConnectEstablished() error = nil, want non-nil")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("returns raw conn write error", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
err := writeConnectEstablished(failWriteConn{}, nil)
|
||||
if err == nil {
|
||||
t.Fatal("writeConnectEstablished() error = nil, want non-nil")
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestDiscardAndCloseBody(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
t.Run("nil body is safe", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
discardAndCloseBody(nil)
|
||||
})
|
||||
|
||||
t.Run("closes body and discards at most limit", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
payload := bytes.Repeat([]byte("x"), maxDiscardBodyBytes+128)
|
||||
body := &trackingBody{data: bytes.NewReader(payload)}
|
||||
|
||||
discardAndCloseBody(body)
|
||||
|
||||
if !body.closed {
|
||||
t.Fatal("body was not closed")
|
||||
}
|
||||
if body.readN != maxDiscardBodyBytes {
|
||||
t.Fatalf("bytes read = %d, want %d", body.readN, maxDiscardBodyBytes)
|
||||
}
|
||||
})
|
||||
}
|
||||
235
internal/proxy/server_test.go
Normal file
235
internal/proxy/server_test.go
Normal file
@ -0,0 +1,235 @@
|
||||
package proxy
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net"
|
||||
"net/http"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"termtap.dev/internal/model"
|
||||
)
|
||||
|
||||
// NOTE: Run these tests with -race; they cover concurrent state.
|
||||
|
||||
type closeTrackingConn struct {
|
||||
closed bool
|
||||
mu sync.Mutex
|
||||
}
|
||||
|
||||
func (c *closeTrackingConn) Read(_ []byte) (int, error) { return 0, net.ErrClosed }
|
||||
func (c *closeTrackingConn) Write(b []byte) (int, error) { return len(b), nil }
|
||||
func (c *closeTrackingConn) Close() error { c.mu.Lock(); c.closed = true; c.mu.Unlock(); return nil }
|
||||
func (c *closeTrackingConn) LocalAddr() net.Addr { return &net.TCPAddr{} }
|
||||
func (c *closeTrackingConn) RemoteAddr() net.Addr { return &net.TCPAddr{} }
|
||||
func (c *closeTrackingConn) SetDeadline(_ time.Time) error { return nil }
|
||||
func (c *closeTrackingConn) SetReadDeadline(_ time.Time) error { return nil }
|
||||
func (c *closeTrackingConn) SetWriteDeadline(_ time.Time) error { return nil }
|
||||
|
||||
func (c *closeTrackingConn) isClosed() bool {
|
||||
c.mu.Lock()
|
||||
defer c.mu.Unlock()
|
||||
return c.closed
|
||||
}
|
||||
|
||||
func TestNewProxyServer_Success(t *testing.T) {
|
||||
configRoot := t.TempDir()
|
||||
t.Setenv("XDG_CONFIG_HOME", configRoot)
|
||||
|
||||
ch := make(chan model.Event, 16)
|
||||
ps, err := NewProxyServer("127.0.0.1:0", ch)
|
||||
if err != nil {
|
||||
t.Fatalf("NewProxyServer() error = %v", err)
|
||||
}
|
||||
t.Cleanup(func() { Destroy(ps, ch) })
|
||||
|
||||
if ps.Listener == nil || *ps.Listener == nil {
|
||||
t.Fatal("Listener is nil")
|
||||
}
|
||||
if ps.Server == nil {
|
||||
t.Fatal("Server is nil")
|
||||
}
|
||||
if ps.Server.Handler == nil {
|
||||
t.Fatal("Server.Handler is nil")
|
||||
}
|
||||
if !strings.HasPrefix(ps.Url, "http://") {
|
||||
t.Fatalf("Url = %q, want http:// prefix", ps.Url)
|
||||
}
|
||||
if !ps.CAReady {
|
||||
t.Fatal("CAReady = false, want true")
|
||||
}
|
||||
if ps.CACertPath == "" {
|
||||
t.Fatal("CACertPath should not be empty")
|
||||
}
|
||||
if got, want := filepath.Dir(ps.CACertPath), filepath.Join(configRoot, caDirName); got != want {
|
||||
t.Fatalf("CACertPath dir = %q, want %q", got, want)
|
||||
}
|
||||
if _, statErr := os.Stat(ps.CACertPath); statErr != nil {
|
||||
t.Fatalf("expected CA cert on disk, stat error = %v", statErr)
|
||||
}
|
||||
if ps.Conns == nil {
|
||||
t.Fatal("Conns map is nil")
|
||||
}
|
||||
}
|
||||
|
||||
func TestNewProxyServer_CACreatedFlagAcrossRuns(t *testing.T) {
|
||||
configRoot := t.TempDir()
|
||||
t.Setenv("XDG_CONFIG_HOME", configRoot)
|
||||
|
||||
ch := make(chan model.Event, 16)
|
||||
ps1, err := NewProxyServer("127.0.0.1:0", ch)
|
||||
if err != nil {
|
||||
t.Fatalf("first NewProxyServer() error = %v", err)
|
||||
}
|
||||
Destroy(ps1, ch)
|
||||
|
||||
ps2, err := NewProxyServer("127.0.0.1:0", ch)
|
||||
if err != nil {
|
||||
t.Fatalf("second NewProxyServer() error = %v", err)
|
||||
}
|
||||
t.Cleanup(func() { Destroy(ps2, ch) })
|
||||
|
||||
if !ps1.CACreated {
|
||||
t.Fatal("first NewProxyServer should report CACreated=true")
|
||||
}
|
||||
if ps2.CACreated {
|
||||
t.Fatal("second NewProxyServer should report CACreated=false when CA already exists")
|
||||
}
|
||||
}
|
||||
|
||||
func TestNewProxyServer_ErrorWhenCASetupFails(t *testing.T) {
|
||||
t.Setenv("XDG_CONFIG_HOME", "")
|
||||
t.Setenv("HOME", "")
|
||||
|
||||
ch := make(chan model.Event, 8)
|
||||
ps, err := NewProxyServer("127.0.0.1:0", ch)
|
||||
if err == nil {
|
||||
Destroy(ps, ch)
|
||||
t.Fatal("NewProxyServer() error = nil, want non-nil")
|
||||
}
|
||||
if ps != nil {
|
||||
t.Fatalf("proxy server = %#v, want nil", ps)
|
||||
}
|
||||
}
|
||||
|
||||
func TestNewProxyServer_ErrorWhenListenFails(t *testing.T) {
|
||||
configRoot := t.TempDir()
|
||||
t.Setenv("XDG_CONFIG_HOME", configRoot)
|
||||
|
||||
occupied, err := net.Listen("tcp", "127.0.0.1:0")
|
||||
if err != nil {
|
||||
t.Fatalf("net.Listen() error = %v", err)
|
||||
}
|
||||
t.Cleanup(func() { _ = occupied.Close() })
|
||||
|
||||
addr := occupied.Addr().String()
|
||||
ch := make(chan model.Event, 8)
|
||||
ps, gotErr := NewProxyServer(addr, ch)
|
||||
if gotErr == nil {
|
||||
Destroy(ps, ch)
|
||||
t.Fatal("NewProxyServer() error = nil, want non-nil")
|
||||
}
|
||||
if ps != nil {
|
||||
t.Fatalf("proxy server = %#v, want nil", ps)
|
||||
}
|
||||
}
|
||||
|
||||
func TestDestroy(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
t.Run("nil-safe", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
Destroy(nil, make(chan model.Event, 1))
|
||||
})
|
||||
|
||||
t.Run("emits ProxyStopped when server exists", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
ch := make(chan model.Event, 2)
|
||||
ps := &model.ProxyServer{Server: &http.Server{}, Conns: make(map[net.Conn]struct{})}
|
||||
|
||||
Destroy(ps, ch)
|
||||
|
||||
select {
|
||||
case ev := <-ch:
|
||||
if ev.Type != model.EventTypeProxyStopped {
|
||||
t.Fatalf("event type = %s, want %s", ev.Type, model.EventTypeProxyStopped)
|
||||
}
|
||||
case <-time.After(time.Second):
|
||||
t.Fatal("timeout waiting for ProxyStopped event")
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestConnectionTrackingHelpers(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
t.Run("track and untrack are nil-safe", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
trackConnection(nil, nil)
|
||||
untrackConnection(nil, nil)
|
||||
closeTrackedConnections(nil)
|
||||
})
|
||||
|
||||
t.Run("track/untrack mutate connection set", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
ps := &model.ProxyServer{Conns: make(map[net.Conn]struct{})}
|
||||
c := &closeTrackingConn{}
|
||||
|
||||
trackConnection(ps, c)
|
||||
if _, ok := ps.Conns[c]; !ok {
|
||||
t.Fatal("connection not tracked")
|
||||
}
|
||||
|
||||
untrackConnection(ps, c)
|
||||
if _, ok := ps.Conns[c]; ok {
|
||||
t.Fatal("connection still tracked after untrack")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("closeTrackedConnections closes all tracked conns", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
ps := &model.ProxyServer{Conns: make(map[net.Conn]struct{})}
|
||||
c1 := &closeTrackingConn{}
|
||||
c2 := &closeTrackingConn{}
|
||||
ps.Conns[c1] = struct{}{}
|
||||
ps.Conns[c2] = struct{}{}
|
||||
|
||||
closeTrackedConnections(ps)
|
||||
|
||||
if !c1.isClosed() || !c2.isClosed() {
|
||||
t.Fatalf("expected all tracked conns closed, got c1=%v c2=%v", c1.isClosed(), c2.isClosed())
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestDestroy_ShutdownsServerContext(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
// Basic smoke test: ensure a real server can be shut down via Destroy.
|
||||
ln, err := net.Listen("tcp", "127.0.0.1:0")
|
||||
if err != nil {
|
||||
t.Fatalf("listen error = %v", err)
|
||||
}
|
||||
|
||||
srv := &http.Server{Handler: http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { _, _ = w.Write([]byte("ok")) })}
|
||||
go func() {
|
||||
_ = srv.Serve(ln)
|
||||
}()
|
||||
|
||||
ch := make(chan model.Event, 2)
|
||||
ps := &model.ProxyServer{Server: srv, Conns: make(map[net.Conn]struct{})}
|
||||
Destroy(ps, ch)
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), time.Second)
|
||||
defer cancel()
|
||||
if err := srv.Shutdown(ctx); err != nil && err != http.ErrServerClosed {
|
||||
t.Fatalf("server should be closed after Destroy, got shutdown error %v", err)
|
||||
}
|
||||
}
|
||||
77
internal/proxy/test_helpers_test.go
Normal file
77
internal/proxy/test_helpers_test.go
Normal file
@ -0,0 +1,77 @@
|
||||
package proxy
|
||||
|
||||
import (
|
||||
"crypto/ecdsa"
|
||||
"crypto/elliptic"
|
||||
"crypto/rand"
|
||||
"crypto/tls"
|
||||
"crypto/x509"
|
||||
"crypto/x509/pkix"
|
||||
"math/big"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"termtap.dev/internal/model"
|
||||
)
|
||||
|
||||
func drainEvents(t *testing.T, ch <-chan model.Event, n int, timeout time.Duration) []model.Event {
|
||||
t.Helper()
|
||||
|
||||
events := make([]model.Event, 0, n)
|
||||
deadline := time.After(timeout)
|
||||
for len(events) < n {
|
||||
select {
|
||||
case ev := <-ch:
|
||||
events = append(events, ev)
|
||||
case <-deadline:
|
||||
t.Fatalf("timeout waiting for %d events, got %d", n, len(events))
|
||||
}
|
||||
}
|
||||
|
||||
return events
|
||||
}
|
||||
|
||||
func hasEventType(events []model.Event, typ model.EventType) bool {
|
||||
for _, ev := range events {
|
||||
if ev.Type == typ {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func newTestCA(t *testing.T) *CertificateAuthority {
|
||||
t.Helper()
|
||||
|
||||
key, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
|
||||
if err != nil {
|
||||
t.Fatalf("GenerateKey() error = %v", err)
|
||||
}
|
||||
|
||||
now := time.Now()
|
||||
tmpl := &x509.Certificate{
|
||||
SerialNumber: big.NewInt(1),
|
||||
Subject: pkix.Name{CommonName: "test-ca"},
|
||||
NotBefore: now.Add(-time.Minute),
|
||||
NotAfter: now.Add(time.Hour),
|
||||
KeyUsage: x509.KeyUsageCertSign | x509.KeyUsageDigitalSignature,
|
||||
IsCA: true,
|
||||
BasicConstraintsValid: true,
|
||||
}
|
||||
|
||||
der, err := x509.CreateCertificate(rand.Reader, tmpl, tmpl, &key.PublicKey, key)
|
||||
if err != nil {
|
||||
t.Fatalf("CreateCertificate() error = %v", err)
|
||||
}
|
||||
|
||||
cert, err := x509.ParseCertificate(der)
|
||||
if err != nil {
|
||||
t.Fatalf("ParseCertificate() error = %v", err)
|
||||
}
|
||||
|
||||
return &CertificateAuthority{
|
||||
cert: cert,
|
||||
key: key,
|
||||
leafCert: make(map[string]*tls.Certificate),
|
||||
}
|
||||
}
|
||||
256
internal/proxy/utils_test.go
Normal file
256
internal/proxy/utils_test.go
Normal file
@ -0,0 +1,256 @@
|
||||
package proxy
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"context"
|
||||
"errors"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"strconv"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/google/uuid"
|
||||
)
|
||||
|
||||
type timeoutErr struct{}
|
||||
|
||||
func (timeoutErr) Error() string { return "timeout" }
|
||||
func (timeoutErr) Timeout() bool { return true }
|
||||
func (timeoutErr) Temporary() bool { return true }
|
||||
|
||||
type errWriter struct{}
|
||||
|
||||
func (errWriter) Write(_ []byte) (int, error) {
|
||||
return 0, errors.New("write failed")
|
||||
}
|
||||
|
||||
func TestCanDisplayContent(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
contentType string
|
||||
want bool
|
||||
}{
|
||||
{name: "empty content type", contentType: "", want: false},
|
||||
{name: "text type", contentType: "text/plain", want: true},
|
||||
{name: "json type", contentType: "application/json", want: true},
|
||||
{name: "xml suffix", contentType: "application/problem+xml", want: true},
|
||||
{name: "unknown binary", contentType: "application/octet-stream", want: false},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
if got := canDisplayContent(tt.contentType); got != tt.want {
|
||||
t.Fatalf("canDisplayContent(%q) = %v, want %v", tt.contentType, got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestFormatHeaders(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
headers http.Header
|
||||
want string
|
||||
}{
|
||||
{
|
||||
name: "empty returns none",
|
||||
headers: http.Header{},
|
||||
want: "<none>",
|
||||
},
|
||||
{
|
||||
name: "sorts keys stably",
|
||||
headers: http.Header{
|
||||
"B-Key": {"b1"},
|
||||
"A-Key": {"a1", "a2"},
|
||||
},
|
||||
want: `A-Key="a1,a2", B-Key="b1"`,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
if got := formatHeaders(tt.headers); got != tt.want {
|
||||
t.Fatalf("formatHeaders() = %q, want %q", got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetEndOfUUID(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
id := uuid.MustParse("123e4567-e89b-12d3-a456-426614174000")
|
||||
if got, want := getEndOfUUID(id), "426614174000"; got != want {
|
||||
t.Fatalf("getEndOfUUID() = %q, want %q", got, want)
|
||||
}
|
||||
}
|
||||
|
||||
func TestStatusFromUpstreamError(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
newReq := func(ctx context.Context) *http.Request {
|
||||
reqURL, err := url.Parse("http://example.com")
|
||||
if err != nil {
|
||||
t.Fatalf("url parse failed: %v", err)
|
||||
}
|
||||
return (&http.Request{Method: http.MethodGet, URL: reqURL}).WithContext(ctx)
|
||||
}
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
req *http.Request
|
||||
resp *http.Response
|
||||
err error
|
||||
want int
|
||||
}{
|
||||
{
|
||||
name: "prefers upstream response status",
|
||||
req: newReq(context.Background()),
|
||||
resp: &http.Response{StatusCode: http.StatusTeapot},
|
||||
err: errors.New("ignored"),
|
||||
want: http.StatusTeapot,
|
||||
},
|
||||
{
|
||||
name: "context canceled maps to bad gateway",
|
||||
req: func() *http.Request {
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
cancel()
|
||||
return newReq(ctx)
|
||||
}(),
|
||||
err: context.Canceled,
|
||||
want: http.StatusBadGateway,
|
||||
},
|
||||
{
|
||||
name: "deadline exceeded maps to gateway timeout",
|
||||
req: newReq(context.Background()),
|
||||
err: context.DeadlineExceeded,
|
||||
want: http.StatusGatewayTimeout,
|
||||
},
|
||||
{
|
||||
name: "net timeout maps to gateway timeout",
|
||||
req: newReq(context.Background()),
|
||||
err: timeoutErr{},
|
||||
want: http.StatusGatewayTimeout,
|
||||
},
|
||||
{
|
||||
name: "default maps to bad gateway",
|
||||
req: newReq(context.Background()),
|
||||
err: errors.New("dial failed"),
|
||||
want: http.StatusBadGateway,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
if got := statusFromUpstreamError(tt.req, tt.resp, tt.err); got != tt.want {
|
||||
t.Fatalf("statusFromUpstreamError() = %d, want %d", got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestNewUpstreamTransport(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
got := newUpstreamTransport()
|
||||
transport, ok := got.(*http.Transport)
|
||||
if !ok {
|
||||
t.Fatalf("newUpstreamTransport() type = %T, want *http.Transport", got)
|
||||
}
|
||||
|
||||
if transport.Proxy != nil {
|
||||
t.Fatal("newUpstreamTransport() Proxy must be nil")
|
||||
}
|
||||
}
|
||||
|
||||
func TestWritePlainHTTPError(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
status int
|
||||
writer *bufio.Writer
|
||||
wantErr bool
|
||||
wantStatus int
|
||||
}{
|
||||
{
|
||||
name: "writes valid response and flushes",
|
||||
status: http.StatusBadGateway,
|
||||
writer: bufio.NewWriter(&strings.Builder{}),
|
||||
wantErr: false,
|
||||
wantStatus: http.StatusBadGateway,
|
||||
},
|
||||
{
|
||||
name: "returns write error",
|
||||
status: http.StatusBadGateway,
|
||||
writer: bufio.NewWriter(errWriter{}),
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
name: "returns response write error when writer already failed",
|
||||
status: http.StatusBadGateway,
|
||||
writer: bufio.NewWriter(errWriter{}),
|
||||
wantErr: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
var sb strings.Builder
|
||||
w := tt.writer
|
||||
if !tt.wantErr {
|
||||
w = bufio.NewWriter(&sb)
|
||||
}
|
||||
|
||||
err := writePlainHTTPError(w, tt.status)
|
||||
if tt.name == "returns response write error when writer already failed" {
|
||||
_ = w.Flush() // set sticky writer error so resp.Write fails immediately
|
||||
err = writePlainHTTPError(w, tt.status)
|
||||
}
|
||||
if (err != nil) != tt.wantErr {
|
||||
t.Fatalf("writePlainHTTPError() error = %v, wantErr %v", err, tt.wantErr)
|
||||
}
|
||||
if tt.wantErr {
|
||||
return
|
||||
}
|
||||
|
||||
resp, readErr := http.ReadResponse(bufio.NewReader(strings.NewReader(sb.String())), &http.Request{Method: http.MethodGet})
|
||||
if readErr != nil {
|
||||
t.Fatalf("ReadResponse() error = %v", readErr)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode != tt.wantStatus {
|
||||
t.Fatalf("status = %d, want %d", resp.StatusCode, tt.wantStatus)
|
||||
}
|
||||
|
||||
if gotCT := resp.Header.Get("Content-Type"); gotCT != "text/plain; charset=utf-8" {
|
||||
t.Fatalf("Content-Type = %q, want %q", gotCT, "text/plain; charset=utf-8")
|
||||
}
|
||||
|
||||
wantBody := http.StatusText(tt.status)
|
||||
if gotCL := resp.Header.Get("Content-Length"); gotCL != strconv.Itoa(len(wantBody)) {
|
||||
t.Fatalf("Content-Length = %q, want %q", gotCL, strconv.Itoa(len(wantBody)))
|
||||
}
|
||||
|
||||
body, bodyErr := io.ReadAll(resp.Body)
|
||||
if bodyErr != nil {
|
||||
t.Fatalf("ReadAll(body) error = %v", bodyErr)
|
||||
}
|
||||
if string(body) != wantBody {
|
||||
t.Fatalf("body = %q, want %q", string(body), wantBody)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
361
internal/tui/model_update_test.go
Normal file
361
internal/tui/model_update_test.go
Normal file
@ -0,0 +1,361 @@
|
||||
package tui
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"net/http"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
tea "github.com/charmbracelet/bubbletea"
|
||||
"github.com/google/uuid"
|
||||
"termtap.dev/internal/model"
|
||||
)
|
||||
|
||||
func TestNewModelDefaults(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
ch := make(chan model.Event)
|
||||
m := NewModel(ch, Controls{})
|
||||
|
||||
if m.channel != ch {
|
||||
t.Fatal("channel not set")
|
||||
}
|
||||
if len(m.events) != 0 || len(m.requests) != 0 {
|
||||
t.Fatal("events/requests should initialize empty")
|
||||
}
|
||||
if m.width != 0 || m.height != 0 {
|
||||
t.Fatal("width/height should initialize zero")
|
||||
}
|
||||
if m.showEvents || m.showStd || m.showSearch || m.restarting {
|
||||
t.Fatal("toggle flags should initialize false")
|
||||
}
|
||||
}
|
||||
|
||||
func TestInitBatchesEventAndTick(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
ch := make(chan model.Event)
|
||||
m := NewModel(ch, Controls{})
|
||||
|
||||
cmd := m.Init()
|
||||
if cmd == nil {
|
||||
t.Fatal("Init() returned nil cmd")
|
||||
}
|
||||
if _, ok := cmd().(tea.BatchMsg); !ok {
|
||||
t.Fatalf("Init cmd message type = %T, want tea.BatchMsg", cmd())
|
||||
}
|
||||
}
|
||||
|
||||
func TestWaitForEvent(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
t.Run("returns EventMsg when channel has value", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
ch := make(chan model.Event, 1)
|
||||
ch <- model.Event{Type: model.EventTypeWarn, Body: "hello"}
|
||||
|
||||
msg := waitForEvent(ch)()
|
||||
ev, ok := msg.(EventMsg)
|
||||
if !ok {
|
||||
t.Fatalf("msg type = %T, want EventMsg", msg)
|
||||
}
|
||||
if ev.value.Body != "hello" {
|
||||
t.Fatalf("event body = %q, want %q", ev.value.Body, "hello")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("returns ErrMsg when channel closed", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
ch := make(chan model.Event)
|
||||
close(ch)
|
||||
|
||||
msg := waitForEvent(ch)()
|
||||
if _, ok := msg.(ErrMsg); !ok {
|
||||
t.Fatalf("msg type = %T, want ErrMsg", msg)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestMessagesCommands(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
t.Run("restartCmd nil restart returns nil", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
if cmd := restartCmd(nil); cmd != nil {
|
||||
t.Fatal("restartCmd(nil) should return nil")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("restartCmd wraps restart result", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
wantErr := errors.New("boom")
|
||||
msg := restartCmd(func() error { return wantErr })()
|
||||
rm, ok := msg.(RestartResultMsg)
|
||||
if !ok {
|
||||
t.Fatalf("msg type = %T, want RestartResultMsg", msg)
|
||||
}
|
||||
if !errors.Is(rm.err, wantErr) {
|
||||
t.Fatalf("restart result error = %v, want %v", rm.err, wantErr)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("tickCmd emits TickMsg", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
cmd := tickCmd()
|
||||
if cmd == nil {
|
||||
t.Fatal("tickCmd returned nil")
|
||||
}
|
||||
|
||||
msgCh := make(chan tea.Msg, 1)
|
||||
go func() { msgCh <- cmd() }()
|
||||
|
||||
select {
|
||||
case msg := <-msgCh:
|
||||
if _, ok := msg.(TickMsg); !ok {
|
||||
t.Fatalf("msg type = %T, want TickMsg", msg)
|
||||
}
|
||||
case <-time.After(200 * time.Millisecond):
|
||||
t.Fatal("timeout waiting for tick message")
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestUpdate(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
t.Run("window size updates dimensions", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
m := NewModel(make(chan model.Event), Controls{})
|
||||
next, _ := m.Update(tea.WindowSizeMsg{Width: 120, Height: 40})
|
||||
got := next.(Model)
|
||||
if got.width != 120 || got.height != 40 {
|
||||
t.Fatalf("dimensions = (%d,%d), want (120,40)", got.width, got.height)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("tick updates now and reschedules only with pending requests", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
now := time.Now()
|
||||
|
||||
m1 := NewModel(make(chan model.Event), Controls{})
|
||||
next1, cmd1 := m1.Update(TickMsg{Now: now})
|
||||
got1 := next1.(Model)
|
||||
if !got1.now.Equal(now) {
|
||||
t.Fatal("tick should update now")
|
||||
}
|
||||
if cmd1 != nil {
|
||||
t.Fatal("tick without pending requests should not reschedule")
|
||||
}
|
||||
|
||||
m2 := NewModel(make(chan model.Event), Controls{})
|
||||
m2.requests = append(m2.requests, model.Request{ID: uuid.New(), Pending: true})
|
||||
_, cmd2 := m2.Update(TickMsg{Now: now})
|
||||
if cmd2 == nil {
|
||||
t.Fatal("tick with pending requests should reschedule")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("key handling toggles and quit", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
m := NewModel(make(chan model.Event), Controls{})
|
||||
next, quitCmd := m.Update(tea.KeyMsg{Type: tea.KeyRunes, Runes: []rune("q")})
|
||||
_ = next
|
||||
if quitCmd == nil {
|
||||
t.Fatal("q should return quit cmd")
|
||||
}
|
||||
|
||||
nextCtrlC, quitCtrlC := m.Update(tea.KeyMsg{Type: tea.KeyCtrlC})
|
||||
_ = nextCtrlC
|
||||
if quitCtrlC == nil {
|
||||
t.Fatal("ctrl+c should return quit cmd")
|
||||
}
|
||||
|
||||
next2, _ := m.Update(tea.KeyMsg{Type: tea.KeyRunes, Runes: []rune("e")})
|
||||
if !next2.(Model).showEvents {
|
||||
t.Fatal("e should toggle showEvents")
|
||||
}
|
||||
|
||||
next3, _ := next2.(Model).Update(tea.KeyMsg{Type: tea.KeyRunes, Runes: []rune("o")})
|
||||
if !next3.(Model).showStd {
|
||||
t.Fatal("o should toggle showStd")
|
||||
}
|
||||
|
||||
next4, _ := next3.(Model).Update(tea.KeyMsg{Type: tea.KeyRunes, Runes: []rune("/")})
|
||||
if !next4.(Model).showSearch {
|
||||
t.Fatal("/ should enable search")
|
||||
}
|
||||
|
||||
next5, _ := next4.(Model).Update(tea.KeyMsg{Type: tea.KeyEsc})
|
||||
if next5.(Model).showSearch {
|
||||
t.Fatal("esc should disable search")
|
||||
}
|
||||
|
||||
next6, cmd6 := next5.(Model).Update(tea.KeyMsg{Type: tea.KeyRunes, Runes: []rune("x")})
|
||||
if cmd6 != nil {
|
||||
t.Fatal("unknown key should not return command")
|
||||
}
|
||||
if next6.(Model).showEvents != next5.(Model).showEvents ||
|
||||
next6.(Model).showStd != next5.(Model).showStd ||
|
||||
next6.(Model).showSearch != next5.(Model).showSearch {
|
||||
t.Fatal("unknown key should not alter toggle state")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("restart key guarded by state and control fn", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
m := NewModel(make(chan model.Event), Controls{})
|
||||
next, cmd := m.Update(tea.KeyMsg{Type: tea.KeyCtrlR})
|
||||
if cmd != nil {
|
||||
t.Fatal("ctrl+r with nil restart control should return nil cmd")
|
||||
}
|
||||
if next.(Model).restarting {
|
||||
t.Fatal("restarting should remain false when restart control missing")
|
||||
}
|
||||
|
||||
m2 := NewModel(make(chan model.Event), Controls{Restart: func() error { return nil }})
|
||||
next2, cmd2 := m2.Update(tea.KeyMsg{Type: tea.KeyCtrlR})
|
||||
if cmd2 == nil {
|
||||
t.Fatal("ctrl+r with restart control should return cmd")
|
||||
}
|
||||
if !next2.(Model).restarting {
|
||||
t.Fatal("restarting should be true after ctrl+r")
|
||||
}
|
||||
|
||||
next3, cmd3 := next2.(Model).Update(tea.KeyMsg{Type: tea.KeyCtrlR})
|
||||
if cmd3 != nil {
|
||||
t.Fatal("ctrl+r while restarting should return nil cmd")
|
||||
}
|
||||
if !next3.(Model).restarting {
|
||||
t.Fatal("restarting should stay true while guarded")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("ErrMsg pushes warning event", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
m := NewModel(make(chan model.Event), Controls{})
|
||||
next, _ := m.Update(ErrMsg{err: errors.New("closed")})
|
||||
got := next.(Model)
|
||||
if len(got.events) != 1 {
|
||||
t.Fatalf("event len = %d, want 1", len(got.events))
|
||||
}
|
||||
if got.events[0].Type != model.EventTypeWarn {
|
||||
t.Fatalf("event type = %s, want %s", got.events[0].Type, model.EventTypeWarn)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("RestartResultMsg clears restarting and warns on error", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
m := NewModel(make(chan model.Event), Controls{})
|
||||
m.restarting = true
|
||||
|
||||
next1, _ := m.Update(RestartResultMsg{err: nil})
|
||||
if next1.(Model).restarting {
|
||||
t.Fatal("restarting should clear on restart result")
|
||||
}
|
||||
|
||||
next2, _ := m.Update(RestartResultMsg{err: errors.New("fail")})
|
||||
got2 := next2.(Model)
|
||||
if len(got2.events) != 1 || got2.events[0].Type != model.EventTypeWarn {
|
||||
t.Fatalf("expected warn event on restart error, got %#v", got2.events)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("EventMsg updates events and requests", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
ch := make(chan model.Event, 2)
|
||||
m := NewModel(ch, Controls{})
|
||||
|
||||
reqID := uuid.New()
|
||||
startEv := EventMsg{value: model.Event{Type: model.EventTypeRequestStarted, Request: model.Request{ID: reqID, Method: http.MethodGet, Pending: true}}}
|
||||
next1, cmd1 := m.Update(startEv)
|
||||
if cmd1 == nil {
|
||||
t.Fatal("EventMsg should return wait/tick cmd")
|
||||
}
|
||||
got1 := next1.(Model)
|
||||
if len(got1.events) != 1 || len(got1.requests) != 1 {
|
||||
t.Fatalf("expected one event and one request, got events=%d requests=%d", len(got1.events), len(got1.requests))
|
||||
}
|
||||
|
||||
finishReq := got1.requests[0]
|
||||
finishReq.Pending = false
|
||||
finishReq.Status = 200
|
||||
finishEv := EventMsg{value: model.Event{Type: model.EventTypeRequestFinished, Request: finishReq}}
|
||||
next2, cmd2 := got1.Update(finishEv)
|
||||
got2 := next2.(Model)
|
||||
if got2.requests[0].Pending {
|
||||
t.Fatal("request should be updated to non-pending")
|
||||
}
|
||||
if got2.requests[0].Status != 200 {
|
||||
t.Fatalf("request status = %d, want 200", got2.requests[0].Status)
|
||||
}
|
||||
|
||||
if cmd2 == nil {
|
||||
t.Fatal("expected waitForEvent cmd after finished request")
|
||||
}
|
||||
ch <- model.Event{Type: model.EventTypeWarn, Body: "next"}
|
||||
msg := cmd2()
|
||||
if _, ok := msg.(EventMsg); !ok {
|
||||
t.Fatalf("cmd2 message type = %T, want EventMsg", msg)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestModelHelpers(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
t.Run("pushEvent trims to maxEvents", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
m := NewModel(make(chan model.Event), Controls{})
|
||||
for i := 0; i < maxEvents+5; i++ {
|
||||
m.pushEvent(model.Event{Body: "x"})
|
||||
}
|
||||
if len(m.events) != maxEvents {
|
||||
t.Fatalf("events len = %d, want %d", len(m.events), maxEvents)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("createRequest ignores CONNECT and trims", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
m := NewModel(make(chan model.Event), Controls{})
|
||||
m.createRequest(model.Request{Method: http.MethodConnect})
|
||||
if len(m.requests) != 0 {
|
||||
t.Fatal("CONNECT request should be ignored")
|
||||
}
|
||||
|
||||
for i := 0; i < maxRequests+3; i++ {
|
||||
m.createRequest(model.Request{ID: uuid.New(), Method: http.MethodGet})
|
||||
}
|
||||
if len(m.requests) != maxRequests {
|
||||
t.Fatalf("requests len = %d, want %d", len(m.requests), maxRequests)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("updateRequest updates only matching request", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
id1 := uuid.New()
|
||||
id2 := uuid.New()
|
||||
m := NewModel(make(chan model.Event), Controls{})
|
||||
m.requests = []model.Request{{ID: id1, Status: 100}, {ID: id2, Status: 101}}
|
||||
|
||||
m.updateRequest(model.Request{ID: id2, Status: 202})
|
||||
if m.requests[0].Status != 100 || m.requests[1].Status != 202 {
|
||||
t.Fatalf("unexpected statuses after update: %#v", m.requests)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("hasPendingRequests true and false", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
m := NewModel(make(chan model.Event), Controls{})
|
||||
if m.hasPendingRequests() {
|
||||
t.Fatal("empty model should not have pending requests")
|
||||
}
|
||||
m.requests = []model.Request{{ID: uuid.New(), Pending: false}, {ID: uuid.New(), Pending: true}}
|
||||
if !m.hasPendingRequests() {
|
||||
t.Fatal("expected pending requests to be true")
|
||||
}
|
||||
})
|
||||
}
|
||||
245
internal/tui/view_split_panes_style_test.go
Normal file
245
internal/tui/view_split_panes_style_test.go
Normal file
@ -0,0 +1,245 @@
|
||||
package tui
|
||||
|
||||
import (
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"termtap.dev/internal/model"
|
||||
)
|
||||
|
||||
func TestFormatDuration(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
in time.Duration
|
||||
want string
|
||||
}{
|
||||
{name: "pending zero", in: 0, want: "PENDING"},
|
||||
{name: "microseconds", in: 750 * time.Microsecond, want: "750us"},
|
||||
{name: "milliseconds", in: 20 * time.Millisecond, want: "20ms"},
|
||||
{name: "seconds", in: 11 * time.Second, want: "11.00s"},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
tt := tt
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
if got := formatDuration(tt.in); got != tt.want {
|
||||
t.Fatalf("formatDuration(%v) = %q, want %q", tt.in, got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestTruncate(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
s string
|
||||
max int
|
||||
want string
|
||||
}{
|
||||
{name: "short unchanged", s: "abc", max: 3, want: "abc"},
|
||||
{name: "max small no ellipsis", s: "abcdef", max: 3, want: "abc"},
|
||||
{name: "ellipsis", s: "abcdef", max: 5, want: "ab..."},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
tt := tt
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
if got := truncate(tt.s, tt.max); got != tt.want {
|
||||
t.Fatalf("truncate(%q,%d) = %q, want %q", tt.s, tt.max, got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestClampRendered(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
if got := clampRendered("abcdef", 0); got != "" {
|
||||
t.Fatalf("clampRendered max=0 = %q, want empty", got)
|
||||
}
|
||||
if got := clampRendered("abc", 10); got != "abc" {
|
||||
t.Fatalf("clampRendered no truncation = %q, want %q", got, "abc")
|
||||
}
|
||||
if got := clampRendered("abcdef", 4); !strings.Contains(got, "...") {
|
||||
t.Fatalf("clampRendered truncation should include ellipsis, got %q", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetEventColor(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
theme := newTheme()
|
||||
tests := []struct {
|
||||
name string
|
||||
typ model.EventType
|
||||
want string
|
||||
}{
|
||||
{name: "session", typ: model.EventTypeSessionStarted, want: theme.EventSession.Render("x")},
|
||||
{name: "proxy", typ: model.EventTypeProxyStarted, want: theme.EventProxy.Render("x")},
|
||||
{name: "request in flight", typ: model.EventTypeRequestStarted, want: theme.EventRequestInFlight.Render("x")},
|
||||
{name: "request success", typ: model.EventTypeRequestFinished, want: theme.EventSuccess.Render("x")},
|
||||
{name: "warn", typ: model.EventTypeWarn, want: theme.EventWarn.Render("x")},
|
||||
{name: "error", typ: model.EventTypeRequestFailed, want: theme.EventError.Render("x")},
|
||||
{name: "fatal", typ: model.EventTypeFatal, want: theme.EventFatal.Render("x")},
|
||||
{name: "default", typ: model.EventType("UnknownType"), want: theme.EventDefault.Render("x")},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
tt := tt
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
got := getEventColor(theme, tt.typ).Render("x")
|
||||
if got != tt.want {
|
||||
t.Fatalf("unexpected style for %s", tt.typ)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestViewAndPaneStructure(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
m := NewModel(make(chan model.Event), Controls{})
|
||||
|
||||
t.Run("view returns raw pane when unset size", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
got := m.View()
|
||||
if got != m.renderAppPane() {
|
||||
t.Fatal("View should return raw pane when width/height are unset")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("renderAppPane line count matches height", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
m2 := m
|
||||
m2.width = 80
|
||||
m2.height = 12
|
||||
got := m2.renderAppPane()
|
||||
if got == "height of request and details did not match" || got == "height of screen does not match terminal height" {
|
||||
t.Fatalf("unexpected renderAppPane invariant error: %q", got)
|
||||
}
|
||||
if lines := strings.Count(got, "\n") + 1; lines != m2.height {
|
||||
t.Fatalf("line count = %d, want %d", lines, m2.height)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("renderAppPane supports toggles", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
m2 := m
|
||||
m2.width = 90
|
||||
m2.height = 14
|
||||
m2.showEvents = true
|
||||
m2.showStd = true
|
||||
m2.showSearch = true
|
||||
got := m2.renderAppPane()
|
||||
if got == "height of request and details did not match" || got == "height of screen does not match terminal height" {
|
||||
t.Fatalf("unexpected renderAppPane invariant error with toggles: %q", got)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("view applies configured terminal height", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
m2 := m
|
||||
m2.width = 70
|
||||
m2.height = 10
|
||||
|
||||
got := m2.View()
|
||||
if lines := strings.Count(got, "\n") + 1; lines < m2.height {
|
||||
t.Fatalf("View line count = %d, want at least %d", lines, m2.height)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestPaneRenderersAndStatusBar(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
m := NewModel(make(chan model.Event), Controls{})
|
||||
m.width = 100
|
||||
m.height = 12
|
||||
m.requests = []model.Request{
|
||||
{ID: uuid.New(), Method: "GET", Host: "a", URL: "/a", Status: 200, Duration: 5 * time.Millisecond},
|
||||
{ID: uuid.New(), Method: "POST", Host: "b", URL: "/b", Status: 500, Duration: 10 * time.Millisecond, Failed: true},
|
||||
}
|
||||
m.events = []model.Event{
|
||||
{Type: model.EventTypeWarn, Body: "warn"},
|
||||
{Type: model.EventTypeProcessStdout, Body: "out"},
|
||||
{Type: model.EventTypeProcessStderr, Body: "err"},
|
||||
}
|
||||
|
||||
status := m.renderStatusBar(100)
|
||||
if !strings.Contains(status, "2 reqs") || !strings.Contains(status, "1 err") {
|
||||
t.Fatalf("status bar missing expected counters: %q", status)
|
||||
}
|
||||
|
||||
search := m.renderSearchPane(20, 3)
|
||||
if len(search) != 3 {
|
||||
t.Fatalf("search pane len = %d, want 3", len(search))
|
||||
}
|
||||
for i, line := range search {
|
||||
if len(line) != 20 {
|
||||
t.Fatalf("search pane line %d len = %d, want %d", i, len(line), 20)
|
||||
}
|
||||
}
|
||||
|
||||
requests := m.renderRequestPane(50, 4)
|
||||
if len(requests) != 4 {
|
||||
t.Fatalf("request pane len = %d, want 4", len(requests))
|
||||
}
|
||||
|
||||
details := m.renderDetailsPane(30, 4)
|
||||
if len(details) != 4 {
|
||||
t.Fatalf("details pane len = %d, want 4", len(details))
|
||||
}
|
||||
|
||||
events := m.renderEventsPane(60, 4)
|
||||
if len(events) != 4 {
|
||||
t.Fatalf("events pane len = %d, want 4", len(events))
|
||||
}
|
||||
if strings.Contains(strings.Join(events, "\n"), "out") || strings.Contains(strings.Join(events, "\n"), "err") {
|
||||
t.Fatal("events pane should filter stdout/stderr events")
|
||||
}
|
||||
|
||||
std := m.renderStdPane(60, 4)
|
||||
if len(std) != 4 {
|
||||
t.Fatalf("std pane len = %d, want 4", len(std))
|
||||
}
|
||||
joined := strings.Join(std, "\n")
|
||||
if !strings.Contains(joined, "out") || !strings.Contains(joined, "err") {
|
||||
t.Fatal("std pane should include stdout/stderr logs")
|
||||
}
|
||||
}
|
||||
|
||||
func TestRenderEventsPane_ErrorAndPIDBranches(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
m := NewModel(make(chan model.Event), Controls{})
|
||||
m.events = []model.Event{
|
||||
{Type: model.EventTypeWarn, Body: "old"},
|
||||
{Type: model.EventTypeRequestFailed, Body: "failed body", PID: 123, Time: time.Now()},
|
||||
{Type: model.EventTypeFatal, Body: "fatal body", Time: time.Now()},
|
||||
}
|
||||
|
||||
lines := m.renderEventsPane(60, 3)
|
||||
if len(lines) != 3 {
|
||||
t.Fatalf("events pane len = %d, want 3", len(lines))
|
||||
}
|
||||
|
||||
joined := strings.Join(lines, "\n")
|
||||
if !strings.Contains(joined, "123") {
|
||||
t.Fatalf("expected PID to appear in events pane, got: %q", joined)
|
||||
}
|
||||
if !strings.Contains(joined, "failed body") {
|
||||
t.Fatalf("expected failed body to appear in events pane, got: %q", joined)
|
||||
}
|
||||
if !strings.Contains(joined, "fatal body") {
|
||||
t.Fatalf("expected fatal body to appear in events pane, got: %q", joined)
|
||||
}
|
||||
}
|
||||
Loading…
x
Reference in New Issue
Block a user