Compare commits
No commits in common. "51d526c2fea0dac6eb91829625d8cf28a1b6794f" and "f87d4fc0406ba588b2109f57607f7e21dbc0de8c" have entirely different histories.
51d526c2fe
...
f87d4fc040
@ -1,45 +0,0 @@
|
|||||||
# termtap Test Coverage Summary
|
|
||||||
|
|
||||||
Generated from:
|
|
||||||
|
|
||||||
```bash
|
|
||||||
go test -coverprofile=/tmp/termtap.cover ./...
|
|
||||||
go tool cover -func=/tmp/termtap.cover
|
|
||||||
```
|
|
||||||
|
|
||||||
## Package coverage snapshot
|
|
||||||
|
|
||||||
| Package | Coverage |
|
|
||||||
|---|---:|
|
|
||||||
| `cmd/tap` | 100.0% |
|
|
||||||
| `internal/app` | 98.1% |
|
|
||||||
| `internal/cli` | 93.4% |
|
|
||||||
| `internal/process` | 95.8% |
|
|
||||||
| `internal/proxy` | 90.0% |
|
|
||||||
| `internal/tui` | 96.2% |
|
|
||||||
| `examples/echo` | 0.0% (example app; intentionally not covered) |
|
|
||||||
| `internal/model` | no test files (pure data structs) |
|
|
||||||
|
|
||||||
Total statements in module: **77.9%**.
|
|
||||||
|
|
||||||
## Notable lower-coverage targets (production code)
|
|
||||||
|
|
||||||
- `internal/proxy/handlers.go:handleConnect` — 75.9%
|
|
||||||
- `internal/proxy/certs.go:writeFileAtomically` — 76.2%
|
|
||||||
- `internal/proxy/certs.go:load` — 90.5%
|
|
||||||
- `internal/proxy/certs.go:create` — 80.8%
|
|
||||||
- `internal/proxy/certs.go:IsTrustedBySystem` — 76.9%
|
|
||||||
- `internal/cli/run.go:runCert` — 87.5%
|
|
||||||
|
|
||||||
## Interpretation
|
|
||||||
|
|
||||||
- Core runtime paths (`internal/app`, `internal/process`, `internal/tui`) are high confidence.
|
|
||||||
- Proxy package has broad behavior coverage including HTTP and HTTPS MITM integration flow, and now clears 90% package coverage.
|
|
||||||
- CLI command routing and fatal/stdout/stderr seams are covered, including `Run` success/error branches.
|
|
||||||
- TUI pane rendering coverage now includes error/PID branch behavior.
|
|
||||||
|
|
||||||
## Next optional improvements
|
|
||||||
|
|
||||||
1. Add deeper CONNECT tunnel write/flush/read failure permutations inside `handleConnect` loop.
|
|
||||||
2. Add additional deterministic edge cases for `IsTrustedBySystem` non-unknown-authority verify failures.
|
|
||||||
3. Optionally add non-unix signal file coverage in CI matrix (currently unix-focused).
|
|
||||||
@ -1,49 +0,0 @@
|
|||||||
package main
|
|
||||||
|
|
||||||
import (
|
|
||||||
"bytes"
|
|
||||||
"io"
|
|
||||||
"os"
|
|
||||||
"strings"
|
|
||||||
"testing"
|
|
||||||
"time"
|
|
||||||
)
|
|
||||||
|
|
||||||
func TestMain_SmokeInvokesCLIRun(t *testing.T) {
|
|
||||||
origArgs := os.Args
|
|
||||||
t.Cleanup(func() { os.Args = origArgs })
|
|
||||||
os.Args = []string{"tap", "invalid"}
|
|
||||||
|
|
||||||
origStderr := os.Stderr
|
|
||||||
r, w, err := os.Pipe()
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("stderr pipe error: %v", err)
|
|
||||||
}
|
|
||||||
t.Cleanup(func() {
|
|
||||||
os.Stderr = origStderr
|
|
||||||
_ = r.Close()
|
|
||||||
})
|
|
||||||
os.Stderr = w
|
|
||||||
|
|
||||||
outCh := make(chan string, 1)
|
|
||||||
go func() {
|
|
||||||
var buf bytes.Buffer
|
|
||||||
_, _ = io.Copy(&buf, r)
|
|
||||||
outCh <- buf.String()
|
|
||||||
}()
|
|
||||||
|
|
||||||
main()
|
|
||||||
|
|
||||||
_ = w.Close()
|
|
||||||
os.Stderr = origStderr
|
|
||||||
var got string
|
|
||||||
select {
|
|
||||||
case got = <-outCh:
|
|
||||||
case <-time.After(2 * time.Second):
|
|
||||||
t.Fatal("timeout waiting for stderr capture")
|
|
||||||
}
|
|
||||||
|
|
||||||
if !strings.Contains(got, "usage:") {
|
|
||||||
t.Fatalf("stderr missing usage output, got: %q", got)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
@ -1,109 +0,0 @@
|
|||||||
package app
|
|
||||||
|
|
||||||
import (
|
|
||||||
"net/http"
|
|
||||||
"net/http/httptest"
|
|
||||||
"net/url"
|
|
||||||
"testing"
|
|
||||||
"time"
|
|
||||||
|
|
||||||
"termtap.dev/internal/model"
|
|
||||||
)
|
|
||||||
|
|
||||||
// NOTE: Run with -race; this validates cross-component concurrency.
|
|
||||||
|
|
||||||
func TestSessionIntegration_LifecycleAndRequestEvents(t *testing.T) {
|
|
||||||
addr := freeTCPAddr(t)
|
|
||||||
s, err := StartSession(model.Command{Name: "sh", Args: []string{"-c", "sleep 5"}}, addr)
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("StartSession() error = %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
upstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
||||||
w.WriteHeader(http.StatusOK)
|
|
||||||
_, _ = w.Write([]byte("ok"))
|
|
||||||
}))
|
|
||||||
t.Cleanup(upstream.Close)
|
|
||||||
|
|
||||||
startupEvents := collectUntilTypes(t, s.Events, []model.EventType{
|
|
||||||
model.EventTypeProxyStarting,
|
|
||||||
model.EventTypeProcessStarting,
|
|
||||||
model.EventTypeProcessStarted,
|
|
||||||
}, 3*time.Second)
|
|
||||||
|
|
||||||
proxyURL, err := url.Parse(s.proxy.Url)
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("url.Parse(proxy) error = %v", err)
|
|
||||||
}
|
|
||||||
client := &http.Client{
|
|
||||||
Transport: &http.Transport{Proxy: http.ProxyURL(proxyURL)},
|
|
||||||
Timeout: 3 * time.Second,
|
|
||||||
}
|
|
||||||
|
|
||||||
resp, err := client.Get(upstream.URL + "/session")
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("proxy request error = %v", err)
|
|
||||||
}
|
|
||||||
_ = resp.Body.Close()
|
|
||||||
if resp.StatusCode != http.StatusOK {
|
|
||||||
t.Fatalf("status = %d, want %d", resp.StatusCode, http.StatusOK)
|
|
||||||
}
|
|
||||||
|
|
||||||
requestEvents := collectUntilTypes(t, s.Events, []model.EventType{
|
|
||||||
model.EventTypeRequestStarted,
|
|
||||||
model.EventTypeRequestFinished,
|
|
||||||
}, 3*time.Second)
|
|
||||||
|
|
||||||
s.Stop()
|
|
||||||
select {
|
|
||||||
case <-s.proc.Done:
|
|
||||||
case <-time.After(4 * time.Second):
|
|
||||||
t.Fatal("timeout waiting for process stop")
|
|
||||||
}
|
|
||||||
|
|
||||||
shutdownEvents := collectUntilTypes(t, s.Events, []model.EventType{
|
|
||||||
model.EventTypeProxyStopped,
|
|
||||||
model.EventTypeProcessSignaled,
|
|
||||||
model.EventTypeProcessExited,
|
|
||||||
}, 4*time.Second)
|
|
||||||
|
|
||||||
if !isBefore(startupEvents, model.EventTypeProcessStarting, model.EventTypeProcessStarted) {
|
|
||||||
t.Fatalf("expected %s before %s in startup events: %#v", model.EventTypeProcessStarting, model.EventTypeProcessStarted, startupEvents)
|
|
||||||
}
|
|
||||||
if !isBefore(requestEvents, model.EventTypeRequestStarted, model.EventTypeRequestFinished) {
|
|
||||||
t.Fatalf("expected %s before %s in request events: %#v", model.EventTypeRequestStarted, model.EventTypeRequestFinished, requestEvents)
|
|
||||||
}
|
|
||||||
if !isBefore(shutdownEvents, model.EventTypeProcessSignaled, model.EventTypeProcessExited) {
|
|
||||||
t.Fatalf("expected %s before %s in shutdown events: %#v", model.EventTypeProcessSignaled, model.EventTypeProcessExited, shutdownEvents)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestSessionIntegration_RestartProcessEmitsLifecycleEvents(t *testing.T) {
|
|
||||||
addr := freeTCPAddr(t)
|
|
||||||
s, err := StartSession(model.Command{Name: "sh", Args: []string{"-c", "sleep 5"}}, addr)
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("StartSession() error = %v", err)
|
|
||||||
}
|
|
||||||
t.Cleanup(func() { s.Stop() })
|
|
||||||
|
|
||||||
_ = collectUntilTypes(t, s.Events, []model.EventType{
|
|
||||||
model.EventTypeProcessStarted,
|
|
||||||
model.EventTypeProxyStarting,
|
|
||||||
}, 3*time.Second)
|
|
||||||
|
|
||||||
if err := s.RestartProcess(); err != nil {
|
|
||||||
t.Fatalf("RestartProcess() error = %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
events := collectUntilTypes(t, s.Events, []model.EventType{
|
|
||||||
model.EventTypeProcessRestarting,
|
|
||||||
model.EventTypeProcessSignaled,
|
|
||||||
model.EventTypeProcessExited,
|
|
||||||
model.EventTypeProcessStarting,
|
|
||||||
model.EventTypeProcessStarted,
|
|
||||||
}, 4*time.Second)
|
|
||||||
|
|
||||||
if !isBefore(events, model.EventTypeProcessRestarting, model.EventTypeProcessStarted) {
|
|
||||||
t.Fatalf("expected restarting before process started, got %#v", events)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
@ -4,7 +4,6 @@ import (
|
|||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"os/exec"
|
"os/exec"
|
||||||
"sync"
|
|
||||||
"syscall"
|
"syscall"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
@ -12,10 +11,6 @@ import (
|
|||||||
"termtap.dev/internal/process"
|
"termtap.dev/internal/process"
|
||||||
)
|
)
|
||||||
|
|
||||||
var killEscalationDelay = 1500 * time.Millisecond
|
|
||||||
var scheduleKillEscalation = time.AfterFunc
|
|
||||||
var killEscalationMu sync.RWMutex
|
|
||||||
|
|
||||||
func StartProcess(cmd model.Command, addr string, ch chan<- model.Event) (*model.Process, error) {
|
func StartProcess(cmd model.Command, addr string, ch chan<- model.Event) (*model.Process, error) {
|
||||||
ch <- model.Event{
|
ch <- model.Event{
|
||||||
Time: time.Now().Local(),
|
Time: time.Now().Local(),
|
||||||
@ -49,16 +44,12 @@ func StopProcess(proc *model.Process, ch chan<- model.Event, sig syscall.Signal)
|
|||||||
|
|
||||||
_ = process.SignalProcess(proc.Exec, sig)
|
_ = process.SignalProcess(proc.Exec, sig)
|
||||||
|
|
||||||
killEscalationMu.RLock()
|
go func() {
|
||||||
delay := killEscalationDelay
|
time.Sleep(1500 * time.Millisecond)
|
||||||
scheduler := scheduleKillEscalation
|
|
||||||
killEscalationMu.RUnlock()
|
|
||||||
|
|
||||||
scheduler(delay, func() {
|
|
||||||
if process.ProcessAlive(proc.Exec) {
|
if process.ProcessAlive(proc.Exec) {
|
||||||
_ = process.SignalProcess(proc.Exec, syscall.SIGKILL)
|
_ = process.SignalProcess(proc.Exec, syscall.SIGKILL)
|
||||||
}
|
}
|
||||||
})
|
}()
|
||||||
}
|
}
|
||||||
|
|
||||||
func waitForProcessExit(proc *model.Process, ch chan<- model.Event) {
|
func waitForProcessExit(proc *model.Process, ch chan<- model.Event) {
|
||||||
|
|||||||
@ -1,223 +0,0 @@
|
|||||||
package app
|
|
||||||
|
|
||||||
import (
|
|
||||||
"os/exec"
|
|
||||||
"syscall"
|
|
||||||
"testing"
|
|
||||||
"time"
|
|
||||||
|
|
||||||
"termtap.dev/internal/model"
|
|
||||||
)
|
|
||||||
|
|
||||||
func TestStartProcess(t *testing.T) {
|
|
||||||
t.Parallel()
|
|
||||||
|
|
||||||
t.Run("starts process and marks running", func(t *testing.T) {
|
|
||||||
t.Parallel()
|
|
||||||
|
|
||||||
ch := make(chan model.Event, 32)
|
|
||||||
proc, err := StartProcess(model.Command{Name: "sh", Args: []string{"-c", "sleep 0.2"}}, "127.0.0.1:8080", ch)
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("StartProcess() error = %v", err)
|
|
||||||
}
|
|
||||||
t.Cleanup(func() {
|
|
||||||
StopProcess(proc, ch, syscall.SIGTERM)
|
|
||||||
select {
|
|
||||||
case <-proc.Done:
|
|
||||||
case <-time.After(2 * time.Second):
|
|
||||||
t.Fatal("timeout waiting process done in cleanup")
|
|
||||||
}
|
|
||||||
})
|
|
||||||
|
|
||||||
if proc == nil || proc.Exec == nil {
|
|
||||||
t.Fatal("StartProcess() returned nil process/exec")
|
|
||||||
}
|
|
||||||
|
|
||||||
events := drainEvents(t, ch, 2, time.Second)
|
|
||||||
if !hasType(events, model.EventTypeProcessStarting) {
|
|
||||||
t.Fatalf("missing %s event", model.EventTypeProcessStarting)
|
|
||||||
}
|
|
||||||
if !hasType(events, model.EventTypeProcessStarted) {
|
|
||||||
t.Fatalf("missing %s event", model.EventTypeProcessStarted)
|
|
||||||
}
|
|
||||||
})
|
|
||||||
|
|
||||||
t.Run("returns error when exec start fails", func(t *testing.T) {
|
|
||||||
t.Parallel()
|
|
||||||
|
|
||||||
ch := make(chan model.Event, 8)
|
|
||||||
proc, err := StartProcess(model.Command{Name: "definitely-not-a-real-command"}, "127.0.0.1:8080", ch)
|
|
||||||
if err == nil {
|
|
||||||
if proc != nil && proc.Exec != nil && proc.Exec.Process != nil {
|
|
||||||
_ = proc.Exec.Process.Kill()
|
|
||||||
}
|
|
||||||
t.Fatal("StartProcess() error = nil, want non-nil")
|
|
||||||
}
|
|
||||||
|
|
||||||
events := drainEvents(t, ch, 1, time.Second)
|
|
||||||
if !hasType(events, model.EventTypeProcessStarting) {
|
|
||||||
t.Fatalf("missing %s event", model.EventTypeProcessStarting)
|
|
||||||
}
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestStopProcess(t *testing.T) {
|
|
||||||
t.Parallel()
|
|
||||||
|
|
||||||
t.Run("nil guards", func(t *testing.T) {
|
|
||||||
t.Parallel()
|
|
||||||
StopProcess(nil, make(chan model.Event, 1), syscall.SIGTERM)
|
|
||||||
StopProcess(&model.Process{}, make(chan model.Event, 1), syscall.SIGTERM)
|
|
||||||
StopProcess(&model.Process{Exec: &exec.Cmd{}}, make(chan model.Event, 1), syscall.SIGTERM)
|
|
||||||
})
|
|
||||||
|
|
||||||
t.Run("emits signaled event", func(t *testing.T) {
|
|
||||||
t.Parallel()
|
|
||||||
|
|
||||||
ch := make(chan model.Event, 32)
|
|
||||||
proc, err := StartProcess(model.Command{Name: "sh", Args: []string{"-c", "sleep 5"}}, "127.0.0.1:8080", ch)
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("StartProcess() error = %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
StopProcess(proc, ch, syscall.SIGTERM)
|
|
||||||
if _, ok := waitForEventType(t, ch, model.EventTypeProcessSignaled, 2*time.Second); !ok {
|
|
||||||
t.Fatalf("did not receive %s event", model.EventTypeProcessSignaled)
|
|
||||||
}
|
|
||||||
|
|
||||||
select {
|
|
||||||
case <-proc.Done:
|
|
||||||
case <-time.After(3 * time.Second):
|
|
||||||
t.Fatal("timeout waiting for process to exit after signal")
|
|
||||||
}
|
|
||||||
})
|
|
||||||
|
|
||||||
t.Run("schedules deterministic kill escalation hook", func(t *testing.T) {
|
|
||||||
t.Parallel()
|
|
||||||
|
|
||||||
origDelay := killEscalationDelay
|
|
||||||
origScheduler := scheduleKillEscalation
|
|
||||||
defer func() {
|
|
||||||
killEscalationMu.Lock()
|
|
||||||
killEscalationDelay = origDelay
|
|
||||||
scheduleKillEscalation = origScheduler
|
|
||||||
killEscalationMu.Unlock()
|
|
||||||
}()
|
|
||||||
|
|
||||||
ch := make(chan model.Event, 32)
|
|
||||||
proc, err := StartProcess(model.Command{Name: "sh", Args: []string{"-c", "sleep 5"}}, "127.0.0.1:8080", ch)
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("StartProcess() error = %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
killEscalationMu.Lock()
|
|
||||||
killEscalationDelay = 25 * time.Millisecond
|
|
||||||
scheduled := make(chan time.Duration, 1)
|
|
||||||
scheduleKillEscalation = func(d time.Duration, fn func()) *time.Timer {
|
|
||||||
scheduled <- d
|
|
||||||
go fn()
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
killEscalationMu.Unlock()
|
|
||||||
|
|
||||||
StopProcess(proc, ch, syscall.SIGTERM)
|
|
||||||
|
|
||||||
select {
|
|
||||||
case d := <-scheduled:
|
|
||||||
if d != killEscalationDelay {
|
|
||||||
t.Fatalf("scheduled delay = %v, want %v", d, killEscalationDelay)
|
|
||||||
}
|
|
||||||
case <-time.After(time.Second):
|
|
||||||
t.Fatal("kill escalation was not scheduled")
|
|
||||||
}
|
|
||||||
|
|
||||||
select {
|
|
||||||
case <-proc.Done:
|
|
||||||
case <-time.After(3 * time.Second):
|
|
||||||
t.Fatal("timeout waiting for process to exit after deterministic escalation")
|
|
||||||
}
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestWaitForProcessExit(t *testing.T) {
|
|
||||||
t.Parallel()
|
|
||||||
|
|
||||||
t.Run("nil guards are no-op", func(t *testing.T) {
|
|
||||||
t.Parallel()
|
|
||||||
waitForProcessExit(nil, make(chan model.Event, 1))
|
|
||||||
waitForProcessExit(&model.Process{}, make(chan model.Event, 1))
|
|
||||||
})
|
|
||||||
|
|
||||||
t.Run("normal exit emits process exited and closes done", func(t *testing.T) {
|
|
||||||
t.Parallel()
|
|
||||||
|
|
||||||
cmd := exec.Command("sh", "-c", "exit 0")
|
|
||||||
if err := cmd.Start(); err != nil {
|
|
||||||
t.Fatalf("cmd.Start() error = %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
proc := &model.Process{Exec: cmd, Running: true, Done: make(chan struct{})}
|
|
||||||
ch := make(chan model.Event, 8)
|
|
||||||
|
|
||||||
waitForProcessExit(proc, ch)
|
|
||||||
|
|
||||||
if _, ok := waitForEventType(t, ch, model.EventTypeProcessExited, time.Second); !ok {
|
|
||||||
t.Fatalf("did not receive %s event", model.EventTypeProcessExited)
|
|
||||||
}
|
|
||||||
select {
|
|
||||||
case <-proc.Done:
|
|
||||||
case <-time.After(time.Second):
|
|
||||||
t.Fatal("Done channel was not closed")
|
|
||||||
}
|
|
||||||
})
|
|
||||||
|
|
||||||
t.Run("exit error path carries exit code", func(t *testing.T) {
|
|
||||||
t.Parallel()
|
|
||||||
|
|
||||||
cmd := exec.Command("sh", "-c", "exit 7")
|
|
||||||
if err := cmd.Start(); err != nil {
|
|
||||||
t.Fatalf("cmd.Start() error = %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
proc := &model.Process{Exec: cmd, Running: true, Done: make(chan struct{})}
|
|
||||||
ch := make(chan model.Event, 8)
|
|
||||||
|
|
||||||
waitForProcessExit(proc, ch)
|
|
||||||
|
|
||||||
events := drainEvents(t, ch, 2, time.Second)
|
|
||||||
found := false
|
|
||||||
for _, ev := range events {
|
|
||||||
if ev.Type == model.EventTypeProcessExited && ev.ExitCode == 7 {
|
|
||||||
found = true
|
|
||||||
break
|
|
||||||
}
|
|
||||||
}
|
|
||||||
if !found {
|
|
||||||
t.Fatalf("expected ProcessExited with exit code 7, got %#v", events)
|
|
||||||
}
|
|
||||||
select {
|
|
||||||
case <-proc.Done:
|
|
||||||
case <-time.After(time.Second):
|
|
||||||
t.Fatal("Done channel was not closed")
|
|
||||||
}
|
|
||||||
})
|
|
||||||
|
|
||||||
t.Run("unexpected wait failure emits fatal", func(t *testing.T) {
|
|
||||||
t.Parallel()
|
|
||||||
|
|
||||||
proc := &model.Process{Exec: &exec.Cmd{}, Done: make(chan struct{})}
|
|
||||||
ch := make(chan model.Event, 8)
|
|
||||||
|
|
||||||
waitForProcessExit(proc, ch)
|
|
||||||
|
|
||||||
events := drainEvents(t, ch, 1, time.Second)
|
|
||||||
if events[0].Type != model.EventTypeFatal {
|
|
||||||
t.Fatalf("event type = %s, want %s", events[0].Type, model.EventTypeFatal)
|
|
||||||
}
|
|
||||||
select {
|
|
||||||
case <-proc.Done:
|
|
||||||
case <-time.After(time.Second):
|
|
||||||
t.Fatal("Done channel was not closed")
|
|
||||||
}
|
|
||||||
})
|
|
||||||
}
|
|
||||||
@ -1,183 +0,0 @@
|
|||||||
package app
|
|
||||||
|
|
||||||
import (
|
|
||||||
"context"
|
|
||||||
"errors"
|
|
||||||
"net"
|
|
||||||
"net/http"
|
|
||||||
"testing"
|
|
||||||
"time"
|
|
||||||
|
|
||||||
"termtap.dev/internal/model"
|
|
||||||
)
|
|
||||||
|
|
||||||
type staticErrListener struct {
|
|
||||||
err error
|
|
||||||
}
|
|
||||||
|
|
||||||
func (l *staticErrListener) Accept() (net.Conn, error) { return nil, l.err }
|
|
||||||
func (l *staticErrListener) Close() error { return nil }
|
|
||||||
func (l *staticErrListener) Addr() net.Addr {
|
|
||||||
return &net.TCPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 8080}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestStartProxy_NilGuards(t *testing.T) {
|
|
||||||
t.Parallel()
|
|
||||||
|
|
||||||
StartProxy(nil, make(chan model.Event, 1))
|
|
||||||
StartProxy(&model.ProxyServer{}, make(chan model.Event, 1))
|
|
||||||
StartProxy(&model.ProxyServer{Server: &http.Server{}}, make(chan model.Event, 1))
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestStartProxy_EmitsStartingAndWarnWhenUntrustedCA(t *testing.T) {
|
|
||||||
t.Parallel()
|
|
||||||
|
|
||||||
listenErr := errors.New("accept failed")
|
|
||||||
var ln net.Listener = &staticErrListener{err: listenErr}
|
|
||||||
|
|
||||||
ch := make(chan model.Event, 8)
|
|
||||||
ps := &model.ProxyServer{
|
|
||||||
Listener: &ln,
|
|
||||||
Server: &http.Server{Handler: http.HandlerFunc(func(http.ResponseWriter, *http.Request) {})},
|
|
||||||
CAReady: true,
|
|
||||||
CATrusted: false,
|
|
||||||
CACreated: true,
|
|
||||||
CACertPath: "/tmp/test-ca.pem",
|
|
||||||
}
|
|
||||||
|
|
||||||
StartProxy(ps, ch)
|
|
||||||
|
|
||||||
events := drainEvents(t, ch, 3, time.Second)
|
|
||||||
if !hasType(events, model.EventTypeProxyStarting) {
|
|
||||||
t.Fatalf("missing %s event", model.EventTypeProxyStarting)
|
|
||||||
}
|
|
||||||
if !hasType(events, model.EventTypeWarn) {
|
|
||||||
t.Fatalf("missing %s event", model.EventTypeWarn)
|
|
||||||
}
|
|
||||||
if !containsBody(events, "generated HTTPS interception CA") {
|
|
||||||
t.Fatalf("expected generated-CA warning body, got events: %#v", events)
|
|
||||||
}
|
|
||||||
if !hasType(events, model.EventTypeFatal) {
|
|
||||||
t.Fatalf("missing %s event", model.EventTypeFatal)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestStartProxy_WarnBodyForExistingUntrustedCA(t *testing.T) {
|
|
||||||
t.Parallel()
|
|
||||||
|
|
||||||
listenErr := errors.New("accept failed")
|
|
||||||
var ln net.Listener = &staticErrListener{err: listenErr}
|
|
||||||
|
|
||||||
ch := make(chan model.Event, 8)
|
|
||||||
ps := &model.ProxyServer{
|
|
||||||
Listener: &ln,
|
|
||||||
Server: &http.Server{Handler: http.HandlerFunc(func(http.ResponseWriter, *http.Request) {})},
|
|
||||||
CAReady: true,
|
|
||||||
CATrusted: false,
|
|
||||||
CACreated: false,
|
|
||||||
CACertPath: "/tmp/test-ca.pem",
|
|
||||||
}
|
|
||||||
|
|
||||||
StartProxy(ps, ch)
|
|
||||||
|
|
||||||
events := drainEvents(t, ch, 3, time.Second)
|
|
||||||
if !containsBody(events, "HTTPS interception CA available at") {
|
|
||||||
t.Fatalf("expected existing-CA warning body, got events: %#v", events)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestStartProxy_NoCAWarningWhenNotReady(t *testing.T) {
|
|
||||||
t.Parallel()
|
|
||||||
|
|
||||||
listenErr := errors.New("accept failed")
|
|
||||||
var ln net.Listener = &staticErrListener{err: listenErr}
|
|
||||||
|
|
||||||
ch := make(chan model.Event, 8)
|
|
||||||
ps := &model.ProxyServer{
|
|
||||||
Listener: &ln,
|
|
||||||
Server: &http.Server{Handler: http.HandlerFunc(func(http.ResponseWriter, *http.Request) {})},
|
|
||||||
CAReady: false,
|
|
||||||
CATrusted: false,
|
|
||||||
}
|
|
||||||
|
|
||||||
StartProxy(ps, ch)
|
|
||||||
|
|
||||||
events := drainEvents(t, ch, 2, time.Second)
|
|
||||||
if !hasType(events, model.EventTypeProxyStarting) {
|
|
||||||
t.Fatalf("missing %s event", model.EventTypeProxyStarting)
|
|
||||||
}
|
|
||||||
if hasType(events, model.EventTypeWarn) {
|
|
||||||
t.Fatalf("unexpected %s event when CA is not ready: %#v", model.EventTypeWarn, events)
|
|
||||||
}
|
|
||||||
if !hasType(events, model.EventTypeFatal) {
|
|
||||||
t.Fatalf("missing %s event", model.EventTypeFatal)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestStartProxy_NoCAWarningWhenTrusted(t *testing.T) {
|
|
||||||
t.Parallel()
|
|
||||||
|
|
||||||
listenErr := errors.New("accept failed")
|
|
||||||
var ln net.Listener = &staticErrListener{err: listenErr}
|
|
||||||
|
|
||||||
ch := make(chan model.Event, 8)
|
|
||||||
ps := &model.ProxyServer{
|
|
||||||
Listener: &ln,
|
|
||||||
Server: &http.Server{Handler: http.HandlerFunc(func(http.ResponseWriter, *http.Request) {})},
|
|
||||||
CAReady: true,
|
|
||||||
CATrusted: true,
|
|
||||||
}
|
|
||||||
|
|
||||||
StartProxy(ps, ch)
|
|
||||||
|
|
||||||
events := drainEvents(t, ch, 2, time.Second)
|
|
||||||
if !hasType(events, model.EventTypeProxyStarting) {
|
|
||||||
t.Fatalf("missing %s event", model.EventTypeProxyStarting)
|
|
||||||
}
|
|
||||||
if hasType(events, model.EventTypeWarn) {
|
|
||||||
t.Fatalf("unexpected %s event when CA is already trusted: %#v", model.EventTypeWarn, events)
|
|
||||||
}
|
|
||||||
if !hasType(events, model.EventTypeFatal) {
|
|
||||||
t.Fatalf("missing %s event", model.EventTypeFatal)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestStartProxy_SwallowsErrServerClosed(t *testing.T) {
|
|
||||||
t.Parallel()
|
|
||||||
|
|
||||||
ln, err := net.Listen("tcp", "127.0.0.1:0")
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("listen error = %v", err)
|
|
||||||
}
|
|
||||||
t.Cleanup(func() { _ = ln.Close() })
|
|
||||||
|
|
||||||
ch := make(chan model.Event, 8)
|
|
||||||
server := &http.Server{Handler: http.HandlerFunc(func(http.ResponseWriter, *http.Request) {})}
|
|
||||||
ps := &model.ProxyServer{Listener: &ln, Server: server}
|
|
||||||
|
|
||||||
done := make(chan struct{})
|
|
||||||
go func() {
|
|
||||||
StartProxy(ps, ch)
|
|
||||||
close(done)
|
|
||||||
}()
|
|
||||||
|
|
||||||
shutdownCtx, cancel := context.WithTimeout(context.Background(), time.Second)
|
|
||||||
defer cancel()
|
|
||||||
if err := server.Shutdown(shutdownCtx); err != nil {
|
|
||||||
t.Fatalf("shutdown error = %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
select {
|
|
||||||
case <-done:
|
|
||||||
case <-time.After(time.Second):
|
|
||||||
t.Fatal("timeout waiting for StartProxy to return")
|
|
||||||
}
|
|
||||||
|
|
||||||
events := drainEvents(t, ch, 1, time.Second)
|
|
||||||
if !hasType(events, model.EventTypeProxyStarting) {
|
|
||||||
t.Fatalf("missing %s event", model.EventTypeProxyStarting)
|
|
||||||
}
|
|
||||||
if hasType(events, model.EventTypeFatal) {
|
|
||||||
t.Fatalf("unexpected %s event", model.EventTypeFatal)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
@ -1,247 +0,0 @@
|
|||||||
package app
|
|
||||||
|
|
||||||
import (
|
|
||||||
"errors"
|
|
||||||
"net"
|
|
||||||
"syscall"
|
|
||||||
"testing"
|
|
||||||
"time"
|
|
||||||
|
|
||||||
"termtap.dev/internal/model"
|
|
||||||
)
|
|
||||||
|
|
||||||
func TestStartSession(t *testing.T) {
|
|
||||||
t.Run("happy path creates proxy and process", func(t *testing.T) {
|
|
||||||
addr := freeTCPAddr(t)
|
|
||||||
s, err := StartSession(model.Command{Name: "sh", Args: []string{"-c", "sleep 5"}}, addr)
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("StartSession() error = %v", err)
|
|
||||||
}
|
|
||||||
if s == nil {
|
|
||||||
t.Fatal("StartSession() returned nil session")
|
|
||||||
}
|
|
||||||
if s.proxy == nil {
|
|
||||||
t.Fatal("session.proxy is nil")
|
|
||||||
}
|
|
||||||
if s.proc == nil {
|
|
||||||
t.Fatal("session.proc is nil")
|
|
||||||
}
|
|
||||||
|
|
||||||
s.Stop()
|
|
||||||
if s.proc != nil && s.proc.Done != nil {
|
|
||||||
select {
|
|
||||||
case <-s.proc.Done:
|
|
||||||
case <-time.After(4 * time.Second):
|
|
||||||
t.Fatal("timeout waiting for process stop")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
})
|
|
||||||
|
|
||||||
t.Run("error when proxy creation fails", func(t *testing.T) {
|
|
||||||
t.Setenv("XDG_CONFIG_HOME", "")
|
|
||||||
t.Setenv("HOME", "")
|
|
||||||
|
|
||||||
s, err := StartSession(model.Command{Name: "sh", Args: []string{"-c", "true"}}, "127.0.0.1:0")
|
|
||||||
if err == nil {
|
|
||||||
if s != nil {
|
|
||||||
s.Stop()
|
|
||||||
}
|
|
||||||
t.Fatal("StartSession() error = nil, want non-nil")
|
|
||||||
}
|
|
||||||
if s != nil {
|
|
||||||
t.Fatalf("session = %#v, want nil", s)
|
|
||||||
}
|
|
||||||
})
|
|
||||||
|
|
||||||
t.Run("destroys proxy when process startup fails", func(t *testing.T) {
|
|
||||||
addr := freeTCPAddr(t)
|
|
||||||
|
|
||||||
s, err := StartSession(model.Command{Name: "definitely-not-a-real-command"}, addr)
|
|
||||||
if err == nil {
|
|
||||||
if s != nil {
|
|
||||||
s.Stop()
|
|
||||||
}
|
|
||||||
t.Fatal("StartSession() error = nil, want non-nil")
|
|
||||||
}
|
|
||||||
|
|
||||||
deadline := time.After(3 * time.Second)
|
|
||||||
ticker := time.NewTicker(25 * time.Millisecond)
|
|
||||||
defer ticker.Stop()
|
|
||||||
for {
|
|
||||||
ln, listenErr := net.Listen("tcp", addr)
|
|
||||||
if listenErr == nil {
|
|
||||||
_ = ln.Close()
|
|
||||||
break
|
|
||||||
}
|
|
||||||
select {
|
|
||||||
case <-deadline:
|
|
||||||
t.Fatalf("address %s did not become reusable, got err: %v", addr, listenErr)
|
|
||||||
case <-ticker.C:
|
|
||||||
}
|
|
||||||
}
|
|
||||||
})
|
|
||||||
|
|
||||||
t.Run("stop during restart stops new process and returns ErrSessionStopped", func(t *testing.T) {
|
|
||||||
s := &Session{
|
|
||||||
ch: make(chan model.Event, 64),
|
|
||||||
cmd: model.Command{Name: "sh", Args: []string{"-c", "sleep 5"}},
|
|
||||||
addr: "127.0.0.1:0",
|
|
||||||
proc: &model.Process{Done: make(chan struct{})},
|
|
||||||
}
|
|
||||||
close(s.proc.Done)
|
|
||||||
|
|
||||||
s.restartMu.Lock()
|
|
||||||
errCh := make(chan error, 1)
|
|
||||||
stopDone := make(chan struct{})
|
|
||||||
|
|
||||||
go func() {
|
|
||||||
errCh <- s.RestartProcess()
|
|
||||||
}()
|
|
||||||
go func() {
|
|
||||||
s.Stop()
|
|
||||||
close(stopDone)
|
|
||||||
}()
|
|
||||||
|
|
||||||
s.restartMu.Unlock()
|
|
||||||
|
|
||||||
select {
|
|
||||||
case err := <-errCh:
|
|
||||||
if !errors.Is(err, ErrSessionStopped) {
|
|
||||||
t.Fatalf("error = %v, want %v", err, ErrSessionStopped)
|
|
||||||
}
|
|
||||||
case <-time.After(6 * time.Second):
|
|
||||||
t.Fatal("timeout waiting for RestartProcess result")
|
|
||||||
}
|
|
||||||
|
|
||||||
select {
|
|
||||||
case <-stopDone:
|
|
||||||
case <-time.After(time.Second):
|
|
||||||
t.Fatal("timeout waiting for Stop completion")
|
|
||||||
}
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestSessionStop_Idempotent(t *testing.T) {
|
|
||||||
addr := freeTCPAddr(t)
|
|
||||||
s, err := StartSession(model.Command{Name: "sh", Args: []string{"-c", "sleep 5"}}, addr)
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("StartSession() error = %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
s.Stop()
|
|
||||||
s.Stop()
|
|
||||||
|
|
||||||
if s.proc != nil && s.proc.Done != nil {
|
|
||||||
select {
|
|
||||||
case <-s.proc.Done:
|
|
||||||
case <-time.After(4 * time.Second):
|
|
||||||
t.Fatal("timeout waiting for process to stop")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestRestartProcess(t *testing.T) {
|
|
||||||
t.Run("nil session returns error", func(t *testing.T) {
|
|
||||||
var s *Session
|
|
||||||
err := s.RestartProcess()
|
|
||||||
if err == nil {
|
|
||||||
t.Fatal("RestartProcess() error = nil, want non-nil")
|
|
||||||
}
|
|
||||||
})
|
|
||||||
|
|
||||||
t.Run("stopped session returns ErrSessionStopped", func(t *testing.T) {
|
|
||||||
s := &Session{stopped: true}
|
|
||||||
err := s.RestartProcess()
|
|
||||||
if !errors.Is(err, ErrSessionStopped) {
|
|
||||||
t.Fatalf("error = %v, want %v", err, ErrSessionStopped)
|
|
||||||
}
|
|
||||||
})
|
|
||||||
|
|
||||||
t.Run("concurrent restart returns ErrRestartInProgress", func(t *testing.T) {
|
|
||||||
s := &Session{restarting: true}
|
|
||||||
err := s.RestartProcess()
|
|
||||||
if !errors.Is(err, ErrRestartInProgress) {
|
|
||||||
t.Fatalf("error = %v, want %v", err, ErrRestartInProgress)
|
|
||||||
}
|
|
||||||
})
|
|
||||||
|
|
||||||
t.Run("timeout waiting for stop returns error", func(t *testing.T) {
|
|
||||||
s := &Session{
|
|
||||||
ch: make(chan model.Event, 8),
|
|
||||||
cmd: model.Command{Name: "sh", Args: []string{"-c", "true"}},
|
|
||||||
addr: "127.0.0.1:0",
|
|
||||||
proc: &model.Process{Done: make(chan struct{})},
|
|
||||||
}
|
|
||||||
|
|
||||||
err := s.RestartProcess()
|
|
||||||
if err == nil {
|
|
||||||
t.Fatal("RestartProcess() error = nil, want non-nil")
|
|
||||||
}
|
|
||||||
})
|
|
||||||
|
|
||||||
t.Run("successful restart swaps process", func(t *testing.T) {
|
|
||||||
addr := freeTCPAddr(t)
|
|
||||||
oldProc := &model.Process{Done: make(chan struct{})}
|
|
||||||
close(oldProc.Done)
|
|
||||||
|
|
||||||
s := &Session{
|
|
||||||
ch: make(chan model.Event, 32),
|
|
||||||
cmd: model.Command{Name: "sh", Args: []string{"-c", "sleep 5"}},
|
|
||||||
addr: addr,
|
|
||||||
proc: oldProc,
|
|
||||||
}
|
|
||||||
|
|
||||||
err := s.RestartProcess()
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("RestartProcess() error = %v", err)
|
|
||||||
}
|
|
||||||
if _, ok := waitForEventType(t, s.ch, model.EventTypeProcessRestarting, time.Second); !ok {
|
|
||||||
t.Fatalf("did not receive %s event", model.EventTypeProcessRestarting)
|
|
||||||
}
|
|
||||||
if s.proc == nil {
|
|
||||||
t.Fatal("session process is nil after restart")
|
|
||||||
}
|
|
||||||
if s.proc == oldProc {
|
|
||||||
t.Fatal("session process pointer was not swapped")
|
|
||||||
}
|
|
||||||
|
|
||||||
StopProcess(s.proc, s.ch, syscall.SIGTERM)
|
|
||||||
select {
|
|
||||||
case <-s.proc.Done:
|
|
||||||
case <-time.After(4 * time.Second):
|
|
||||||
t.Fatal("timeout waiting for restarted process to stop")
|
|
||||||
}
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestWaitForProcessStop(t *testing.T) {
|
|
||||||
t.Parallel()
|
|
||||||
|
|
||||||
t.Run("nil process and nil done are true", func(t *testing.T) {
|
|
||||||
t.Parallel()
|
|
||||||
if !waitForProcessStop(nil, time.Millisecond) {
|
|
||||||
t.Fatal("waitForProcessStop(nil) = false, want true")
|
|
||||||
}
|
|
||||||
if !waitForProcessStop(&model.Process{}, time.Millisecond) {
|
|
||||||
t.Fatal("waitForProcessStop(process with nil done) = false, want true")
|
|
||||||
}
|
|
||||||
})
|
|
||||||
|
|
||||||
t.Run("done channel closes before timeout", func(t *testing.T) {
|
|
||||||
t.Parallel()
|
|
||||||
proc := &model.Process{Done: make(chan struct{})}
|
|
||||||
close(proc.Done)
|
|
||||||
|
|
||||||
if !waitForProcessStop(proc, time.Second) {
|
|
||||||
t.Fatal("waitForProcessStop() = false, want true")
|
|
||||||
}
|
|
||||||
})
|
|
||||||
|
|
||||||
t.Run("timeout returns false", func(t *testing.T) {
|
|
||||||
t.Parallel()
|
|
||||||
proc := &model.Process{Done: make(chan struct{})}
|
|
||||||
if waitForProcessStop(proc, 20*time.Millisecond) {
|
|
||||||
t.Fatal("waitForProcessStop() = true, want false")
|
|
||||||
}
|
|
||||||
})
|
|
||||||
}
|
|
||||||
@ -1,111 +0,0 @@
|
|||||||
package app
|
|
||||||
|
|
||||||
import (
|
|
||||||
"net"
|
|
||||||
"strings"
|
|
||||||
"testing"
|
|
||||||
"time"
|
|
||||||
|
|
||||||
"termtap.dev/internal/model"
|
|
||||||
)
|
|
||||||
|
|
||||||
func drainEvents(t *testing.T, ch <-chan model.Event, n int, timeout time.Duration) []model.Event {
|
|
||||||
t.Helper()
|
|
||||||
|
|
||||||
out := make([]model.Event, 0, n)
|
|
||||||
deadline := time.After(timeout)
|
|
||||||
for len(out) < n {
|
|
||||||
select {
|
|
||||||
case ev := <-ch:
|
|
||||||
out = append(out, ev)
|
|
||||||
case <-deadline:
|
|
||||||
t.Fatalf("timeout waiting for %d events, got %d", n, len(out))
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return out
|
|
||||||
}
|
|
||||||
|
|
||||||
func hasType(events []model.Event, typ model.EventType) bool {
|
|
||||||
for _, ev := range events {
|
|
||||||
if ev.Type == typ {
|
|
||||||
return true
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
|
|
||||||
func containsBody(events []model.Event, part string) bool {
|
|
||||||
for _, ev := range events {
|
|
||||||
if strings.Contains(ev.Body, part) {
|
|
||||||
return true
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
|
|
||||||
func waitForEventType(t *testing.T, ch <-chan model.Event, typ model.EventType, timeout time.Duration) (model.Event, bool) {
|
|
||||||
t.Helper()
|
|
||||||
|
|
||||||
deadline := time.After(timeout)
|
|
||||||
for {
|
|
||||||
select {
|
|
||||||
case ev := <-ch:
|
|
||||||
if ev.Type == typ {
|
|
||||||
return ev, true
|
|
||||||
}
|
|
||||||
case <-deadline:
|
|
||||||
return model.Event{}, false
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func collectUntilTypes(t *testing.T, ch <-chan model.Event, required []model.EventType, timeout time.Duration) []model.Event {
|
|
||||||
t.Helper()
|
|
||||||
|
|
||||||
need := make(map[model.EventType]bool, len(required))
|
|
||||||
for _, typ := range required {
|
|
||||||
need[typ] = true
|
|
||||||
}
|
|
||||||
|
|
||||||
events := make([]model.Event, 0, len(required)+8)
|
|
||||||
deadline := time.After(timeout)
|
|
||||||
for len(need) > 0 {
|
|
||||||
select {
|
|
||||||
case ev := <-ch:
|
|
||||||
events = append(events, ev)
|
|
||||||
delete(need, ev.Type)
|
|
||||||
case <-deadline:
|
|
||||||
t.Fatalf("timeout waiting for required events: remaining=%v, collected=%#v", need, events)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return events
|
|
||||||
}
|
|
||||||
|
|
||||||
func isBefore(events []model.Event, first, second model.EventType) bool {
|
|
||||||
firstIdx := -1
|
|
||||||
secondIdx := -1
|
|
||||||
for i, ev := range events {
|
|
||||||
if ev.Type == first && firstIdx == -1 {
|
|
||||||
firstIdx = i
|
|
||||||
}
|
|
||||||
if ev.Type == second && secondIdx == -1 {
|
|
||||||
secondIdx = i
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return firstIdx >= 0 && secondIdx >= 0 && firstIdx < secondIdx
|
|
||||||
}
|
|
||||||
|
|
||||||
func freeTCPAddr(t *testing.T) string {
|
|
||||||
t.Helper()
|
|
||||||
|
|
||||||
ln, err := net.Listen("tcp", "127.0.0.1:0")
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("listen error = %v", err)
|
|
||||||
}
|
|
||||||
addr := ln.Addr().String()
|
|
||||||
_ = ln.Close()
|
|
||||||
return addr
|
|
||||||
}
|
|
||||||
@ -2,7 +2,6 @@ package cli
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"fmt"
|
"fmt"
|
||||||
"io"
|
|
||||||
"log"
|
"log"
|
||||||
"os"
|
"os"
|
||||||
"runtime"
|
"runtime"
|
||||||
@ -17,23 +16,6 @@ import (
|
|||||||
// This should be configurable at some point, just in case they build on 8080
|
// This should be configurable at some point, just in case they build on 8080
|
||||||
const proxy_addr = "127.0.0.1:8080"
|
const proxy_addr = "127.0.0.1:8080"
|
||||||
|
|
||||||
var fatalExit = log.Fatalln
|
|
||||||
var stdoutWriter io.Writer = stdioRef{isErr: false}
|
|
||||||
var stderrWriter io.Writer = stdioRef{isErr: true}
|
|
||||||
var startSessionFn = app.StartSession
|
|
||||||
var runTUIFn = tui.Run
|
|
||||||
|
|
||||||
type stdioRef struct {
|
|
||||||
isErr bool
|
|
||||||
}
|
|
||||||
|
|
||||||
func (w stdioRef) Write(p []byte) (int, error) {
|
|
||||||
if w.isErr {
|
|
||||||
return os.Stderr.Write(p)
|
|
||||||
}
|
|
||||||
return os.Stdout.Write(p)
|
|
||||||
}
|
|
||||||
|
|
||||||
func Run(args []string) {
|
func Run(args []string) {
|
||||||
if len(args) >= 2 && args[1] == "cert" {
|
if len(args) >= 2 && args[1] == "cert" {
|
||||||
runCert()
|
runCert()
|
||||||
@ -46,10 +28,9 @@ func Run(args []string) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
session, err := startSessionFn(cmd, proxy_addr)
|
session, err := app.StartSession(cmd, proxy_addr)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
fatalExit(err)
|
log.Fatalln(err)
|
||||||
return
|
|
||||||
}
|
}
|
||||||
defer session.Stop()
|
defer session.Stop()
|
||||||
|
|
||||||
@ -57,9 +38,8 @@ func Run(args []string) {
|
|||||||
Restart: session.RestartProcess,
|
Restart: session.RestartProcess,
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := runTUIFn(session.Events, controls); err != nil {
|
if err := tui.Run(session.Events, controls); err != nil {
|
||||||
fatalExit(err)
|
log.Fatalln(err)
|
||||||
return
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -87,50 +67,49 @@ usage:
|
|||||||
tap run -- <command> [args...]
|
tap run -- <command> [args...]
|
||||||
`
|
`
|
||||||
|
|
||||||
fmt.Fprintln(stderrWriter, helpText)
|
fmt.Fprintln(os.Stderr, helpText)
|
||||||
}
|
}
|
||||||
|
|
||||||
func runCert() {
|
func runCert() {
|
||||||
ca, err := proxy.EnsureCertificateAuthority()
|
ca, err := proxy.EnsureCertificateAuthority()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
fatalExit(err)
|
log.Fatalln(err)
|
||||||
return
|
|
||||||
}
|
}
|
||||||
|
|
||||||
certPath := ca.CertPath()
|
certPath := ca.CertPath()
|
||||||
quotedCertPath := strconv.Quote(certPath)
|
quotedCertPath := strconv.Quote(certPath)
|
||||||
fmt.Fprintf(stdoutWriter, "Certificate path: %s\n", certPath)
|
fmt.Printf("Certificate path: %s\n", certPath)
|
||||||
if ca.WasCreated() {
|
if ca.WasCreated() {
|
||||||
fmt.Fprintln(stdoutWriter, "Created a new local HTTPS interception CA.")
|
fmt.Println("Created a new local HTTPS interception CA.")
|
||||||
} else {
|
} else {
|
||||||
fmt.Fprintln(stdoutWriter, "Using existing local HTTPS interception CA.")
|
fmt.Println("Using existing local HTTPS interception CA.")
|
||||||
}
|
}
|
||||||
|
|
||||||
trusted, err := ca.IsTrustedBySystem()
|
trusted, err := ca.IsTrustedBySystem()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
fmt.Fprintf(stdoutWriter, "System trust check failed: %v\n", err)
|
fmt.Printf("System trust check failed: %v\n", err)
|
||||||
} else if trusted {
|
} else if trusted {
|
||||||
fmt.Fprintln(stdoutWriter, "System trust store: trusted")
|
fmt.Println("System trust store: trusted")
|
||||||
} else {
|
} else {
|
||||||
fmt.Fprintln(stdoutWriter, "System trust store: not trusted")
|
fmt.Println("System trust store: not trusted")
|
||||||
}
|
}
|
||||||
|
|
||||||
if runtime.GOOS != "linux" {
|
if runtime.GOOS != "linux" {
|
||||||
fmt.Fprintln(stdoutWriter, "Install this certificate into your OS or client trust store to inspect HTTPS traffic.")
|
fmt.Println("Install this certificate into your OS or client trust store to inspect HTTPS traffic.")
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
fmt.Fprintln(stdoutWriter)
|
fmt.Println()
|
||||||
fmt.Fprintln(stdoutWriter, "Trust instructions (Linux):")
|
fmt.Println("Trust instructions (Linux):")
|
||||||
fmt.Fprintln(stdoutWriter, "Debian/Ubuntu:")
|
fmt.Println("Debian/Ubuntu:")
|
||||||
fmt.Fprintf(stdoutWriter, " sudo cp %s /usr/local/share/ca-certificates/termtap.crt\n", quotedCertPath)
|
fmt.Printf(" sudo cp %s /usr/local/share/ca-certificates/termtap.crt\n", quotedCertPath)
|
||||||
fmt.Fprintln(stdoutWriter, " sudo update-ca-certificates")
|
fmt.Println(" sudo update-ca-certificates")
|
||||||
fmt.Fprintln(stdoutWriter, "Fedora/RHEL/CentOS:")
|
fmt.Println("Fedora/RHEL/CentOS:")
|
||||||
fmt.Fprintf(stdoutWriter, " sudo cp %s /etc/pki/ca-trust/source/anchors/termtap.crt\n", quotedCertPath)
|
fmt.Printf(" sudo cp %s /etc/pki/ca-trust/source/anchors/termtap.crt\n", quotedCertPath)
|
||||||
fmt.Fprintln(stdoutWriter, " sudo update-ca-trust")
|
fmt.Println(" sudo update-ca-trust")
|
||||||
fmt.Fprintln(stdoutWriter, "Arch:")
|
fmt.Println("Arch:")
|
||||||
fmt.Fprintf(stdoutWriter, " sudo trust anchor %s\n", quotedCertPath)
|
fmt.Printf(" sudo trust anchor %s\n", quotedCertPath)
|
||||||
fmt.Fprintln(stdoutWriter)
|
fmt.Println()
|
||||||
fmt.Fprintln(stdoutWriter, "Quick curl test:")
|
fmt.Println("Quick curl test:")
|
||||||
fmt.Fprintf(stdoutWriter, " curl --proxy http://%s --cacert %s https://example.com\n", proxy_addr, quotedCertPath)
|
fmt.Printf(" curl --proxy http://%s --cacert %s https://example.com\n", proxy_addr, quotedCertPath)
|
||||||
}
|
}
|
||||||
|
|||||||
@ -1,356 +0,0 @@
|
|||||||
package cli
|
|
||||||
|
|
||||||
import (
|
|
||||||
"bytes"
|
|
||||||
"errors"
|
|
||||||
"io"
|
|
||||||
"os"
|
|
||||||
"runtime"
|
|
||||||
"strings"
|
|
||||||
"testing"
|
|
||||||
"time"
|
|
||||||
|
|
||||||
"termtap.dev/internal/app"
|
|
||||||
"termtap.dev/internal/model"
|
|
||||||
"termtap.dev/internal/tui"
|
|
||||||
)
|
|
||||||
|
|
||||||
func TestParseCommand(t *testing.T) {
|
|
||||||
tests := []struct {
|
|
||||||
name string
|
|
||||||
args []string
|
|
||||||
ok bool
|
|
||||||
nameWant string
|
|
||||||
argsWant []string
|
|
||||||
}{
|
|
||||||
{name: "too few args", args: []string{"tap"}, ok: false},
|
|
||||||
{name: "missing run token", args: []string{"tap", "oops", "--", "echo"}, ok: false},
|
|
||||||
{name: "missing separator", args: []string{"tap", "run", "echo"}, ok: false},
|
|
||||||
{name: "single command", args: []string{"tap", "run", "--", "echo"}, ok: true, nameWant: "echo", argsWant: []string{}},
|
|
||||||
{name: "command with args", args: []string{"tap", "run", "--", "curl", "-s", "https://example.com"}, ok: true, nameWant: "curl", argsWant: []string{"-s", "https://example.com"}},
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, tt := range tests {
|
|
||||||
tt := tt
|
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
|
||||||
cmd, ok := parseCommand(tt.args)
|
|
||||||
if ok != tt.ok {
|
|
||||||
t.Fatalf("ok = %v, want %v", ok, tt.ok)
|
|
||||||
}
|
|
||||||
if !tt.ok {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
if cmd.Name != tt.nameWant {
|
|
||||||
t.Fatalf("cmd.Name = %q, want %q", cmd.Name, tt.nameWant)
|
|
||||||
}
|
|
||||||
if strings.Join(cmd.Args, "|") != strings.Join(tt.argsWant, "|") {
|
|
||||||
t.Fatalf("cmd.Args = %#v, want %#v", cmd.Args, tt.argsWant)
|
|
||||||
}
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestDisplayHelpWritesToStderr(t *testing.T) {
|
|
||||||
_, stderr := captureOutput(t, func() {
|
|
||||||
displayHelp()
|
|
||||||
})
|
|
||||||
|
|
||||||
if !strings.Contains(stderr, "tap cert") || !strings.Contains(stderr, "tap run --") {
|
|
||||||
t.Fatalf("stderr missing usage text: %q", stderr)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestRun_InvalidCommandShowsHelp(t *testing.T) {
|
|
||||||
_, stderr := captureOutput(t, func() {
|
|
||||||
Run([]string{"tap", "wat"})
|
|
||||||
})
|
|
||||||
|
|
||||||
if !strings.Contains(stderr, "usage:") {
|
|
||||||
t.Fatalf("stderr missing usage output: %q", stderr)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestRun_RoutesCertCommand(t *testing.T) {
|
|
||||||
configRoot := t.TempDir()
|
|
||||||
t.Setenv("XDG_CONFIG_HOME", configRoot)
|
|
||||||
|
|
||||||
stdout, _ := captureOutput(t, func() {
|
|
||||||
Run([]string{"tap", "cert"})
|
|
||||||
})
|
|
||||||
|
|
||||||
if !strings.Contains(stdout, "Certificate path:") {
|
|
||||||
t.Fatalf("stdout missing certificate path output: %q", stdout)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestRun_StartSessionFailureCallsFatalExit(t *testing.T) {
|
|
||||||
restore := installRunSeams(t)
|
|
||||||
defer restore()
|
|
||||||
|
|
||||||
startSessionFn = func(model.Command, string) (*app.Session, error) {
|
|
||||||
return nil, errors.New("boom")
|
|
||||||
}
|
|
||||||
called := installFatalSpy(t)
|
|
||||||
|
|
||||||
Run([]string{"tap", "run", "--", "definitely-not-a-real-command"})
|
|
||||||
if !*called {
|
|
||||||
t.Fatal("expected fatalExit to be called on StartSession failure")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestRun_TUIFailureCallsFatalExit(t *testing.T) {
|
|
||||||
restore := installRunSeams(t)
|
|
||||||
defer restore()
|
|
||||||
|
|
||||||
startSessionFn = func(model.Command, string) (*app.Session, error) {
|
|
||||||
return &app.Session{Events: make(chan model.Event)}, nil
|
|
||||||
}
|
|
||||||
runTUIFn = func(<-chan model.Event, tui.Controls) error {
|
|
||||||
return errors.New("tui failed")
|
|
||||||
}
|
|
||||||
called := installFatalSpy(t)
|
|
||||||
|
|
||||||
Run([]string{"tap", "run", "--", "echo"})
|
|
||||||
if !*called {
|
|
||||||
t.Fatal("expected fatalExit to be called on tui failure")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestRun_SuccessPathDoesNotCallFatal(t *testing.T) {
|
|
||||||
restore := installRunSeams(t)
|
|
||||||
defer restore()
|
|
||||||
|
|
||||||
startSessionFn = func(model.Command, string) (*app.Session, error) {
|
|
||||||
return &app.Session{Events: make(chan model.Event)}, nil
|
|
||||||
}
|
|
||||||
runTUIFn = func(<-chan model.Event, tui.Controls) error {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
called := installFatalSpy(t)
|
|
||||||
|
|
||||||
Run([]string{"tap", "run", "--", "echo"})
|
|
||||||
if *called {
|
|
||||||
t.Fatal("fatalExit should not be called on success path")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestRunCert_EnsureCAFailureCallsFatalExit(t *testing.T) {
|
|
||||||
t.Setenv("XDG_CONFIG_HOME", "")
|
|
||||||
t.Setenv("HOME", "")
|
|
||||||
called := installFatalSpy(t)
|
|
||||||
|
|
||||||
runCert()
|
|
||||||
if !*called {
|
|
||||||
t.Fatal("expected fatalExit to be called when EnsureCertificateAuthority fails")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestRunCertOutputContract(t *testing.T) {
|
|
||||||
configRoot := t.TempDir()
|
|
||||||
t.Setenv("XDG_CONFIG_HOME", configRoot)
|
|
||||||
|
|
||||||
stdout, _ := captureOutput(t, func() {
|
|
||||||
runCert()
|
|
||||||
})
|
|
||||||
|
|
||||||
if !strings.Contains(stdout, "Certificate path:") {
|
|
||||||
t.Fatalf("stdout missing certificate path line: %q", stdout)
|
|
||||||
}
|
|
||||||
if !strings.Contains(stdout, "local HTTPS interception CA") {
|
|
||||||
t.Fatalf("stdout missing CA create/existing line: %q", stdout)
|
|
||||||
}
|
|
||||||
if !strings.Contains(stdout, "System trust store:") && !strings.Contains(stdout, "System trust check failed:") {
|
|
||||||
t.Fatalf("stdout missing trust check line: %q", stdout)
|
|
||||||
}
|
|
||||||
|
|
||||||
if runtime.GOOS == "linux" {
|
|
||||||
if !strings.Contains(stdout, "Trust instructions (Linux):") {
|
|
||||||
t.Fatalf("stdout missing linux trust instructions: %q", stdout)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestRunCert_CreatedThenExistingMessage(t *testing.T) {
|
|
||||||
configRoot := t.TempDir()
|
|
||||||
t.Setenv("XDG_CONFIG_HOME", configRoot)
|
|
||||||
|
|
||||||
firstOut, _ := captureOutput(t, func() {
|
|
||||||
runCert()
|
|
||||||
})
|
|
||||||
if !strings.Contains(firstOut, "Created a new local HTTPS interception CA.") {
|
|
||||||
t.Fatalf("first run should indicate created CA, got: %q", firstOut)
|
|
||||||
}
|
|
||||||
|
|
||||||
secondOut, _ := captureOutput(t, func() {
|
|
||||||
runCert()
|
|
||||||
})
|
|
||||||
if !strings.Contains(secondOut, "Using existing local HTTPS interception CA.") {
|
|
||||||
t.Fatalf("second run should indicate existing CA, got: %q", secondOut)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func captureOutput(t *testing.T, fn func()) (stdout string, stderr string) {
|
|
||||||
t.Helper()
|
|
||||||
|
|
||||||
origStdoutWriter := stdoutWriter
|
|
||||||
origStderrWriter := stderrWriter
|
|
||||||
t.Cleanup(func() {
|
|
||||||
stdoutWriter = origStdoutWriter
|
|
||||||
stderrWriter = origStderrWriter
|
|
||||||
})
|
|
||||||
|
|
||||||
origStdout := os.Stdout
|
|
||||||
origStderr := os.Stderr
|
|
||||||
|
|
||||||
outR, outW, err := os.Pipe()
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("stdout pipe error: %v", err)
|
|
||||||
}
|
|
||||||
errR, errW, err := os.Pipe()
|
|
||||||
if err != nil {
|
|
||||||
_ = outR.Close()
|
|
||||||
_ = outW.Close()
|
|
||||||
t.Fatalf("stderr pipe error: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
os.Stdout = outW
|
|
||||||
os.Stderr = errW
|
|
||||||
stdoutWriter = outW
|
|
||||||
stderrWriter = errW
|
|
||||||
|
|
||||||
outCh := make(chan string, 1)
|
|
||||||
errCh := make(chan string, 1)
|
|
||||||
go func() {
|
|
||||||
var buf bytes.Buffer
|
|
||||||
_, _ = io.Copy(&buf, outR)
|
|
||||||
outCh <- buf.String()
|
|
||||||
}()
|
|
||||||
go func() {
|
|
||||||
var buf bytes.Buffer
|
|
||||||
_, _ = io.Copy(&buf, errR)
|
|
||||||
errCh <- buf.String()
|
|
||||||
}()
|
|
||||||
|
|
||||||
fn()
|
|
||||||
|
|
||||||
_ = outW.Close()
|
|
||||||
_ = errW.Close()
|
|
||||||
stdoutWriter = origStdoutWriter
|
|
||||||
stderrWriter = origStderrWriter
|
|
||||||
os.Stdout = origStdout
|
|
||||||
os.Stderr = origStderr
|
|
||||||
|
|
||||||
select {
|
|
||||||
case stdout = <-outCh:
|
|
||||||
case <-time.After(2 * time.Second):
|
|
||||||
t.Fatal("timeout waiting for stdout capture")
|
|
||||||
}
|
|
||||||
select {
|
|
||||||
case stderr = <-errCh:
|
|
||||||
case <-time.After(2 * time.Second):
|
|
||||||
t.Fatal("timeout waiting for stderr capture")
|
|
||||||
}
|
|
||||||
_ = outR.Close()
|
|
||||||
_ = errR.Close()
|
|
||||||
|
|
||||||
return stdout, stderr
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestDisplayHelp_UsesInjectedStderrWriter(t *testing.T) {
|
|
||||||
var buf bytes.Buffer
|
|
||||||
orig := stderrWriter
|
|
||||||
t.Cleanup(func() { stderrWriter = orig })
|
|
||||||
stderrWriter = &buf
|
|
||||||
|
|
||||||
displayHelp()
|
|
||||||
|
|
||||||
if got := buf.String(); !strings.Contains(got, "usage:") {
|
|
||||||
t.Fatalf("help output missing usage, got: %q", got)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestRunCert_UsesInjectedStdoutWriter(t *testing.T) {
|
|
||||||
configRoot := t.TempDir()
|
|
||||||
t.Setenv("XDG_CONFIG_HOME", configRoot)
|
|
||||||
|
|
||||||
var buf bytes.Buffer
|
|
||||||
orig := stdoutWriter
|
|
||||||
t.Cleanup(func() { stdoutWriter = orig })
|
|
||||||
stdoutWriter = &buf
|
|
||||||
|
|
||||||
runCert()
|
|
||||||
|
|
||||||
if got := buf.String(); !strings.Contains(got, "Certificate path:") {
|
|
||||||
t.Fatalf("cert output missing path line, got: %q", got)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func installRunSeams(t *testing.T) func() {
|
|
||||||
t.Helper()
|
|
||||||
|
|
||||||
origStartSession := startSessionFn
|
|
||||||
origRunTUI := runTUIFn
|
|
||||||
return func() {
|
|
||||||
startSessionFn = origStartSession
|
|
||||||
runTUIFn = origRunTUI
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func installFatalSpy(t *testing.T) *bool {
|
|
||||||
t.Helper()
|
|
||||||
|
|
||||||
origFatal := fatalExit
|
|
||||||
called := false
|
|
||||||
fatalExit = func(v ...any) {
|
|
||||||
called = true
|
|
||||||
}
|
|
||||||
t.Cleanup(func() {
|
|
||||||
fatalExit = origFatal
|
|
||||||
})
|
|
||||||
|
|
||||||
return &called
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestStdioRefWrite(t *testing.T) {
|
|
||||||
t.Run("writes to stdout", func(t *testing.T) {
|
|
||||||
assertStdioRefWrite(t, false, "hello")
|
|
||||||
})
|
|
||||||
|
|
||||||
t.Run("writes to stderr", func(t *testing.T) {
|
|
||||||
assertStdioRefWrite(t, true, "boom")
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
func assertStdioRefWrite(t *testing.T, isErr bool, payload string) {
|
|
||||||
t.Helper()
|
|
||||||
|
|
||||||
r, w, err := os.Pipe()
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("pipe error: %v", err)
|
|
||||||
}
|
|
||||||
defer func() { _ = r.Close() }()
|
|
||||||
|
|
||||||
if isErr {
|
|
||||||
orig := os.Stderr
|
|
||||||
os.Stderr = w
|
|
||||||
defer func() { os.Stderr = orig }()
|
|
||||||
} else {
|
|
||||||
orig := os.Stdout
|
|
||||||
os.Stdout = w
|
|
||||||
defer func() { os.Stdout = orig }()
|
|
||||||
}
|
|
||||||
|
|
||||||
if _, err := (stdioRef{isErr: isErr}).Write([]byte(payload)); err != nil {
|
|
||||||
_ = w.Close()
|
|
||||||
t.Fatalf("stdioRef write error: %v", err)
|
|
||||||
}
|
|
||||||
_ = w.Close()
|
|
||||||
|
|
||||||
data, err := io.ReadAll(r)
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("ReadAll(pipe) error: %v", err)
|
|
||||||
}
|
|
||||||
if got := string(data); got != payload {
|
|
||||||
t.Fatalf("pipe write = %q, want %q", got, payload)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
@ -1,219 +0,0 @@
|
|||||||
package process
|
|
||||||
|
|
||||||
import (
|
|
||||||
"os"
|
|
||||||
"os/exec"
|
|
||||||
"reflect"
|
|
||||||
"strings"
|
|
||||||
"testing"
|
|
||||||
"time"
|
|
||||||
|
|
||||||
"termtap.dev/internal/model"
|
|
||||||
)
|
|
||||||
|
|
||||||
func TestCommandString(t *testing.T) {
|
|
||||||
t.Parallel()
|
|
||||||
|
|
||||||
tests := []struct {
|
|
||||||
name string
|
|
||||||
cmd model.Command
|
|
||||||
want string
|
|
||||||
}{
|
|
||||||
{
|
|
||||||
name: "empty args",
|
|
||||||
cmd: model.Command{Name: "go", Args: []string{}},
|
|
||||||
want: "go ",
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "multiple args",
|
|
||||||
cmd: model.Command{Name: "curl", Args: []string{"-s", "https://example.com"}},
|
|
||||||
want: "curl -s https://example.com",
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, tt := range tests {
|
|
||||||
tt := tt
|
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
|
||||||
t.Parallel()
|
|
||||||
|
|
||||||
if got := CommandString(tt.cmd); got != tt.want {
|
|
||||||
t.Fatalf("CommandString() = %q, want %q", got, tt.want)
|
|
||||||
}
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestInjectEnv(t *testing.T) {
|
|
||||||
t.Parallel()
|
|
||||||
|
|
||||||
cmd := exec.Command("sh", "-c", "true")
|
|
||||||
injectEnv(cmd, "127.0.0.1:8080")
|
|
||||||
|
|
||||||
mustContain := []string{
|
|
||||||
"HTTP_PROXY=http://127.0.0.1:8080",
|
|
||||||
"http_proxy=http://127.0.0.1:8080",
|
|
||||||
"HTTPS_PROXY=http://127.0.0.1:8080",
|
|
||||||
"https_proxy=http://127.0.0.1:8080",
|
|
||||||
"NO_PROXY=",
|
|
||||||
"no_proxy=",
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, kv := range mustContain {
|
|
||||||
if !containsEnvEntry(cmd.Env, kv) {
|
|
||||||
t.Fatalf("injectEnv() missing env entry %q", kv)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestReadPipe(t *testing.T) {
|
|
||||||
t.Parallel()
|
|
||||||
|
|
||||||
t.Run("stdout lines emit stdout events", func(t *testing.T) {
|
|
||||||
t.Parallel()
|
|
||||||
|
|
||||||
ch := make(chan model.Event, 4)
|
|
||||||
input := strings.NewReader("line1\nline2\n")
|
|
||||||
|
|
||||||
readPipe(input, model.EventTypeProcessStdout, ch)
|
|
||||||
|
|
||||||
events := drainEvents(t, ch, 2, time.Second)
|
|
||||||
if events[0].Type != model.EventTypeProcessStdout || events[0].Body != "line1" {
|
|
||||||
t.Fatalf("event[0] = %#v, want stdout line1", events[0])
|
|
||||||
}
|
|
||||||
if events[1].Type != model.EventTypeProcessStdout || events[1].Body != "line2" {
|
|
||||||
t.Fatalf("event[1] = %#v, want stdout line2", events[1])
|
|
||||||
}
|
|
||||||
})
|
|
||||||
|
|
||||||
t.Run("stderr lines emit stderr events", func(t *testing.T) {
|
|
||||||
t.Parallel()
|
|
||||||
|
|
||||||
ch := make(chan model.Event, 4)
|
|
||||||
input := strings.NewReader("err1\nerr2\n")
|
|
||||||
|
|
||||||
readPipe(input, model.EventTypeProcessStderr, ch)
|
|
||||||
|
|
||||||
events := drainEvents(t, ch, 2, time.Second)
|
|
||||||
if events[0].Type != model.EventTypeProcessStderr || events[0].Body != "err1" {
|
|
||||||
t.Fatalf("event[0] = %#v, want stderr err1", events[0])
|
|
||||||
}
|
|
||||||
if events[1].Type != model.EventTypeProcessStderr || events[1].Body != "err2" {
|
|
||||||
t.Fatalf("event[1] = %#v, want stderr err2", events[1])
|
|
||||||
}
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestUpdateStatus(t *testing.T) {
|
|
||||||
t.Parallel()
|
|
||||||
|
|
||||||
t.Run("nil process is no-op", func(t *testing.T) {
|
|
||||||
t.Parallel()
|
|
||||||
UpdateStatus(nil, true, make(chan model.Event, 1))
|
|
||||||
})
|
|
||||||
|
|
||||||
t.Run("state unchanged is no-op", func(t *testing.T) {
|
|
||||||
t.Parallel()
|
|
||||||
|
|
||||||
ch := make(chan model.Event, 1)
|
|
||||||
proc := &model.Process{Running: true}
|
|
||||||
UpdateStatus(proc, true, ch)
|
|
||||||
|
|
||||||
select {
|
|
||||||
case ev := <-ch:
|
|
||||||
t.Fatalf("unexpected event for unchanged state: %#v", ev)
|
|
||||||
default:
|
|
||||||
}
|
|
||||||
})
|
|
||||||
|
|
||||||
t.Run("emits started and stopped events with pid", func(t *testing.T) {
|
|
||||||
t.Parallel()
|
|
||||||
|
|
||||||
ch := make(chan model.Event, 2)
|
|
||||||
proc := &model.Process{
|
|
||||||
Exec: &exec.Cmd{Process: &os.Process{Pid: 4321}},
|
|
||||||
}
|
|
||||||
|
|
||||||
UpdateStatus(proc, true, ch)
|
|
||||||
events := drainEvents(t, ch, 1, time.Second)
|
|
||||||
started := events[0]
|
|
||||||
if started.Type != model.EventTypeProcessStarted {
|
|
||||||
t.Fatalf("started type = %s, want %s", started.Type, model.EventTypeProcessStarted)
|
|
||||||
}
|
|
||||||
if started.PID != 4321 {
|
|
||||||
t.Fatalf("started PID = %d, want %d", started.PID, 4321)
|
|
||||||
}
|
|
||||||
if !proc.Running {
|
|
||||||
t.Fatal("proc.Running = false, want true")
|
|
||||||
}
|
|
||||||
|
|
||||||
UpdateStatus(proc, false, ch)
|
|
||||||
events = drainEvents(t, ch, 1, time.Second)
|
|
||||||
stopped := events[0]
|
|
||||||
if stopped.Type != model.EventTypeProcessExited {
|
|
||||||
t.Fatalf("stopped type = %s, want %s", stopped.Type, model.EventTypeProcessExited)
|
|
||||||
}
|
|
||||||
if stopped.PID != 4321 {
|
|
||||||
t.Fatalf("stopped PID = %d, want %d", stopped.PID, 4321)
|
|
||||||
}
|
|
||||||
if proc.Running {
|
|
||||||
t.Fatal("proc.Running = true, want false")
|
|
||||||
}
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestNewProcess(t *testing.T) {
|
|
||||||
t.Parallel()
|
|
||||||
|
|
||||||
ch := make(chan model.Event, 8)
|
|
||||||
cmd := model.Command{Name: "sh", Args: []string{"-c", "printf test"}}
|
|
||||||
|
|
||||||
proc := NewProcess(cmd, "127.0.0.1:8080", ch)
|
|
||||||
if proc == nil {
|
|
||||||
t.Fatal("NewProcess() returned nil")
|
|
||||||
}
|
|
||||||
if !reflect.DeepEqual(proc.Command, cmd) {
|
|
||||||
t.Fatalf("process command = %#v, want %#v", proc.Command, cmd)
|
|
||||||
}
|
|
||||||
if proc.Exec == nil {
|
|
||||||
t.Fatal("process Exec is nil")
|
|
||||||
}
|
|
||||||
if proc.Running {
|
|
||||||
t.Fatal("new process should not be running")
|
|
||||||
}
|
|
||||||
if proc.Done == nil {
|
|
||||||
t.Fatal("Done channel is nil")
|
|
||||||
}
|
|
||||||
|
|
||||||
if got, want := proc.Exec.Args[0], "sh"; got != want {
|
|
||||||
t.Fatalf("Exec.Args[0] = %q, want %q", got, want)
|
|
||||||
}
|
|
||||||
if !containsEnvEntry(proc.Exec.Env, "HTTP_PROXY=http://127.0.0.1:8080") {
|
|
||||||
t.Fatal("process env missing injected HTTP_PROXY")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func containsEnvEntry(env []string, want string) bool {
|
|
||||||
for _, entry := range env {
|
|
||||||
if entry == want {
|
|
||||||
return true
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
|
|
||||||
func drainEvents(t *testing.T, ch <-chan model.Event, n int, timeout time.Duration) []model.Event {
|
|
||||||
t.Helper()
|
|
||||||
|
|
||||||
events := make([]model.Event, 0, n)
|
|
||||||
deadline := time.After(timeout)
|
|
||||||
for len(events) < n {
|
|
||||||
select {
|
|
||||||
case ev := <-ch:
|
|
||||||
events = append(events, ev)
|
|
||||||
case <-deadline:
|
|
||||||
t.Fatalf("timeout waiting for %d events; got %d", n, len(events))
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return events
|
|
||||||
}
|
|
||||||
@ -1,147 +0,0 @@
|
|||||||
//go:build unix
|
|
||||||
|
|
||||||
package process
|
|
||||||
|
|
||||||
import (
|
|
||||||
"errors"
|
|
||||||
"os"
|
|
||||||
"os/exec"
|
|
||||||
"syscall"
|
|
||||||
"testing"
|
|
||||||
"time"
|
|
||||||
)
|
|
||||||
|
|
||||||
type customSignal struct{}
|
|
||||||
|
|
||||||
func (customSignal) String() string { return "custom" }
|
|
||||||
func (customSignal) Signal() {}
|
|
||||||
|
|
||||||
// NOTE: Run these tests with -race in CI for signal/process safety.
|
|
||||||
|
|
||||||
func TestConfigureProcessForSignals(t *testing.T) {
|
|
||||||
t.Parallel()
|
|
||||||
|
|
||||||
cmd := exec.Command("sh", "-c", "sleep 0.1")
|
|
||||||
configureProcessForSignals(cmd)
|
|
||||||
|
|
||||||
if cmd.SysProcAttr == nil {
|
|
||||||
t.Fatal("SysProcAttr is nil")
|
|
||||||
}
|
|
||||||
if !cmd.SysProcAttr.Setpgid {
|
|
||||||
t.Fatal("Setpgid = false, want true")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestSignalProcess_NilSafe(t *testing.T) {
|
|
||||||
t.Parallel()
|
|
||||||
|
|
||||||
if err := SignalProcess(nil, syscall.SIGTERM); err != nil {
|
|
||||||
t.Fatalf("SignalProcess(nil) error = %v, want nil", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
cmd := &exec.Cmd{}
|
|
||||||
if err := SignalProcess(cmd, syscall.SIGTERM); err != nil {
|
|
||||||
t.Fatalf("SignalProcess(cmd without process) error = %v, want nil", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
cmd.Process = &os.Process{Pid: 0}
|
|
||||||
if err := SignalProcess(cmd, syscall.SIGTERM); err != nil {
|
|
||||||
t.Fatalf("SignalProcess(pid<=0) error = %v, want nil", err)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestSignalProcess_ESRCHIsTreatedAsSuccess(t *testing.T) {
|
|
||||||
t.Parallel()
|
|
||||||
|
|
||||||
cmd := &exec.Cmd{Process: &os.Process{Pid: 999999}}
|
|
||||||
if err := SignalProcess(cmd, syscall.SIGTERM); err != nil {
|
|
||||||
t.Fatalf("SignalProcess() error = %v, want nil when process group not found", err)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestSignalProcess_UsesFallbackOnKillError(t *testing.T) {
|
|
||||||
t.Parallel()
|
|
||||||
|
|
||||||
cmd := exec.Command("sh", "-c", "sleep 5")
|
|
||||||
configureProcessForSignals(cmd)
|
|
||||||
if err := cmd.Start(); err != nil {
|
|
||||||
t.Fatalf("start command error = %v", err)
|
|
||||||
}
|
|
||||||
t.Cleanup(func() {
|
|
||||||
_ = cmd.Process.Kill()
|
|
||||||
_, _ = cmd.Process.Wait()
|
|
||||||
})
|
|
||||||
|
|
||||||
// Invalid signal causes syscall.Kill to fail with EINVAL, then fallback to cmd.Process.Signal.
|
|
||||||
err := SignalProcess(cmd, syscall.Signal(9999))
|
|
||||||
if err == nil {
|
|
||||||
t.Fatal("SignalProcess() error = nil, want non-nil for invalid signal")
|
|
||||||
}
|
|
||||||
if !(errors.Is(err, syscall.EINVAL) || errors.Is(err, os.ErrProcessDone)) {
|
|
||||||
// OS/process timing can vary; ensure we at least failed predictably.
|
|
||||||
t.Fatalf("SignalProcess() unexpected error: %v", err)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestSignalProcess_NonSyscallSignalUsesProcessSignal(t *testing.T) {
|
|
||||||
t.Parallel()
|
|
||||||
|
|
||||||
cmd := exec.Command("sh", "-c", "sleep 1")
|
|
||||||
configureProcessForSignals(cmd)
|
|
||||||
if err := cmd.Start(); err != nil {
|
|
||||||
t.Fatalf("start command error = %v", err)
|
|
||||||
}
|
|
||||||
t.Cleanup(func() {
|
|
||||||
_ = cmd.Process.Kill()
|
|
||||||
_, _ = cmd.Process.Wait()
|
|
||||||
})
|
|
||||||
|
|
||||||
err := SignalProcess(cmd, customSignal{})
|
|
||||||
if err == nil {
|
|
||||||
t.Fatal("SignalProcess(custom signal) error = nil, want non-nil")
|
|
||||||
}
|
|
||||||
if errors.Is(err, os.ErrProcessDone) {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
if msg := err.Error(); msg == "" {
|
|
||||||
t.Fatalf("unexpected empty error for custom signal: %v", err)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestProcessAlive(t *testing.T) {
|
|
||||||
t.Parallel()
|
|
||||||
|
|
||||||
if ProcessAlive(nil) {
|
|
||||||
t.Fatal("ProcessAlive(nil) = true, want false")
|
|
||||||
}
|
|
||||||
if ProcessAlive(&exec.Cmd{}) {
|
|
||||||
t.Fatal("ProcessAlive(cmd without process) = true, want false")
|
|
||||||
}
|
|
||||||
|
|
||||||
cmd := exec.Command("sh", "-c", "sleep 0.2")
|
|
||||||
configureProcessForSignals(cmd)
|
|
||||||
if err := cmd.Start(); err != nil {
|
|
||||||
t.Fatalf("start command error = %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
if !ProcessAlive(cmd) {
|
|
||||||
_ = cmd.Process.Kill()
|
|
||||||
_, _ = cmd.Process.Wait()
|
|
||||||
t.Fatal("ProcessAlive(running) = false, want true")
|
|
||||||
}
|
|
||||||
|
|
||||||
_ = cmd.Process.Kill()
|
|
||||||
_, _ = cmd.Process.Wait()
|
|
||||||
|
|
||||||
deadline := time.After(time.Second)
|
|
||||||
for {
|
|
||||||
if !ProcessAlive(cmd) {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
select {
|
|
||||||
case <-deadline:
|
|
||||||
t.Fatal("ProcessAlive(exited) stayed true")
|
|
||||||
default:
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
@ -1,242 +0,0 @@
|
|||||||
package proxy
|
|
||||||
|
|
||||||
import (
|
|
||||||
"crypto/tls"
|
|
||||||
"encoding/pem"
|
|
||||||
"os"
|
|
||||||
"path/filepath"
|
|
||||||
"strings"
|
|
||||||
"testing"
|
|
||||||
)
|
|
||||||
|
|
||||||
func TestLoadOrCreateCertificateAuthority_RecreatesWhenKeyMissing(t *testing.T) {
|
|
||||||
configRoot := t.TempDir()
|
|
||||||
t.Setenv("XDG_CONFIG_HOME", configRoot)
|
|
||||||
|
|
||||||
baseDir := filepath.Join(configRoot, caDirName)
|
|
||||||
if err := os.MkdirAll(baseDir, 0o700); err != nil {
|
|
||||||
t.Fatalf("MkdirAll() error = %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
certPath := filepath.Join(baseDir, caCertName)
|
|
||||||
if err := os.WriteFile(certPath, []byte("stale-cert"), 0o600); err != nil {
|
|
||||||
t.Fatalf("WriteFile(cert) error = %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
ca, err := loadOrCreateCertificateAuthority()
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("loadOrCreateCertificateAuthority() error = %v", err)
|
|
||||||
}
|
|
||||||
if !ca.WasCreated() {
|
|
||||||
t.Fatal("WasCreated = false, want true when key is missing")
|
|
||||||
}
|
|
||||||
if _, err := os.Stat(filepath.Join(baseDir, caKeyName)); err != nil {
|
|
||||||
t.Fatalf("expected key file to be created, stat error = %v", err)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestLoadOrCreateCertificateAuthority_LoadErrorOnCorruptFiles(t *testing.T) {
|
|
||||||
configRoot := t.TempDir()
|
|
||||||
t.Setenv("XDG_CONFIG_HOME", configRoot)
|
|
||||||
|
|
||||||
baseDir := filepath.Join(configRoot, caDirName)
|
|
||||||
if err := os.MkdirAll(baseDir, 0o700); err != nil {
|
|
||||||
t.Fatalf("MkdirAll() error = %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
certPath := filepath.Join(baseDir, caCertName)
|
|
||||||
keyPath := filepath.Join(baseDir, caKeyName)
|
|
||||||
if err := os.WriteFile(certPath, []byte("not-a-pem"), 0o600); err != nil {
|
|
||||||
t.Fatalf("WriteFile(cert) error = %v", err)
|
|
||||||
}
|
|
||||||
if err := os.WriteFile(keyPath, []byte("not-a-pem"), 0o600); err != nil {
|
|
||||||
t.Fatalf("WriteFile(key) error = %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
_, err := loadOrCreateCertificateAuthority()
|
|
||||||
if err == nil {
|
|
||||||
t.Fatal("loadOrCreateCertificateAuthority() error = nil, want non-nil")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestCertificateAuthorityLoad_ErrorPaths(t *testing.T) {
|
|
||||||
t.Parallel()
|
|
||||||
|
|
||||||
tests := []struct {
|
|
||||||
name string
|
|
||||||
certBytes []byte
|
|
||||||
keyBytes []byte
|
|
||||||
wantPart string
|
|
||||||
}{
|
|
||||||
{
|
|
||||||
name: "invalid cert pem",
|
|
||||||
certBytes: []byte("bad-cert"),
|
|
||||||
keyBytes: []byte("bad-key"),
|
|
||||||
wantPart: "decode ca cert pem",
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "parse cert fails",
|
|
||||||
certBytes: pem.EncodeToMemory(&pem.Block{Type: "CERTIFICATE", Bytes: []byte("bogus")}),
|
|
||||||
keyBytes: []byte("bad-key"),
|
|
||||||
wantPart: "parse ca cert",
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "invalid key pem",
|
|
||||||
certBytes: pem.EncodeToMemory(&pem.Block{Type: "CERTIFICATE", Bytes: newTestCA(t).cert.Raw}),
|
|
||||||
keyBytes: []byte("bad-key"),
|
|
||||||
wantPart: "decode ca key pem",
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "parse key fails",
|
|
||||||
certBytes: pem.EncodeToMemory(&pem.Block{Type: "CERTIFICATE", Bytes: newTestCA(t).cert.Raw}),
|
|
||||||
keyBytes: pem.EncodeToMemory(&pem.Block{Type: "EC PRIVATE KEY", Bytes: []byte("bogus")}),
|
|
||||||
wantPart: "parse ca key",
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, tt := range tests {
|
|
||||||
tt := tt
|
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
|
||||||
t.Parallel()
|
|
||||||
|
|
||||||
dir := t.TempDir()
|
|
||||||
ca := &CertificateAuthority{
|
|
||||||
certPath: filepath.Join(dir, caCertName),
|
|
||||||
keyPath: filepath.Join(dir, caKeyName),
|
|
||||||
leafCert: make(map[string]*tls.Certificate),
|
|
||||||
}
|
|
||||||
|
|
||||||
if err := os.WriteFile(ca.certPath, tt.certBytes, 0o600); err != nil {
|
|
||||||
t.Fatalf("write cert file error = %v", err)
|
|
||||||
}
|
|
||||||
if err := os.WriteFile(ca.keyPath, tt.keyBytes, 0o600); err != nil {
|
|
||||||
t.Fatalf("write key file error = %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
err := ca.load()
|
|
||||||
if err == nil {
|
|
||||||
t.Fatal("load() error = nil, want non-nil")
|
|
||||||
}
|
|
||||||
if !strings.Contains(err.Error(), tt.wantPart) {
|
|
||||||
t.Fatalf("load() error = %q, want contains %q", err.Error(), tt.wantPart)
|
|
||||||
}
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestCertificateAuthorityCreate_ErrorWhenWritePathInvalid(t *testing.T) {
|
|
||||||
t.Parallel()
|
|
||||||
|
|
||||||
ca := &CertificateAuthority{
|
|
||||||
certPath: filepath.Join("/nope", "missing", "ca-cert.pem"),
|
|
||||||
keyPath: filepath.Join("/nope", "missing", "ca-key.pem"),
|
|
||||||
leafCert: make(map[string]*tls.Certificate),
|
|
||||||
}
|
|
||||||
|
|
||||||
err := ca.create()
|
|
||||||
if err == nil {
|
|
||||||
t.Fatal("create() error = nil, want non-nil")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestCertificateAuthorityCreate_WriteErrorPaths(t *testing.T) {
|
|
||||||
t.Parallel()
|
|
||||||
|
|
||||||
t.Run("write ca cert wraps error", func(t *testing.T) {
|
|
||||||
t.Parallel()
|
|
||||||
|
|
||||||
dir := t.TempDir()
|
|
||||||
badCertPath := filepath.Join(dir, "cert-as-dir")
|
|
||||||
if err := os.MkdirAll(badCertPath, 0o700); err != nil {
|
|
||||||
t.Fatalf("MkdirAll(cert dir) error = %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
ca := &CertificateAuthority{
|
|
||||||
certPath: badCertPath,
|
|
||||||
keyPath: filepath.Join(dir, "ca-key.pem"),
|
|
||||||
leafCert: make(map[string]*tls.Certificate),
|
|
||||||
}
|
|
||||||
|
|
||||||
err := ca.create()
|
|
||||||
if err == nil {
|
|
||||||
t.Fatal("create() error = nil, want non-nil")
|
|
||||||
}
|
|
||||||
if !strings.Contains(err.Error(), "write ca cert") {
|
|
||||||
t.Fatalf("create() error = %q, want contains %q", err.Error(), "write ca cert")
|
|
||||||
}
|
|
||||||
})
|
|
||||||
|
|
||||||
t.Run("write ca key wraps error", func(t *testing.T) {
|
|
||||||
t.Parallel()
|
|
||||||
|
|
||||||
dir := t.TempDir()
|
|
||||||
badKeyPath := filepath.Join(dir, "key-as-dir")
|
|
||||||
if err := os.MkdirAll(badKeyPath, 0o700); err != nil {
|
|
||||||
t.Fatalf("MkdirAll(key dir) error = %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
ca := &CertificateAuthority{
|
|
||||||
certPath: filepath.Join(dir, "ca-cert.pem"),
|
|
||||||
keyPath: badKeyPath,
|
|
||||||
leafCert: make(map[string]*tls.Certificate),
|
|
||||||
}
|
|
||||||
|
|
||||||
err := ca.create()
|
|
||||||
if err == nil {
|
|
||||||
t.Fatal("create() error = nil, want non-nil")
|
|
||||||
}
|
|
||||||
if !strings.Contains(err.Error(), "write ca key") {
|
|
||||||
t.Fatalf("create() error = %q, want contains %q", err.Error(), "write ca key")
|
|
||||||
}
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestCertificateAuthorityLoad_ReadErrorPaths(t *testing.T) {
|
|
||||||
t.Parallel()
|
|
||||||
|
|
||||||
t.Run("read cert failure", func(t *testing.T) {
|
|
||||||
t.Parallel()
|
|
||||||
|
|
||||||
dir := t.TempDir()
|
|
||||||
ca := &CertificateAuthority{
|
|
||||||
certPath: filepath.Join(dir, "missing-cert.pem"),
|
|
||||||
keyPath: filepath.Join(dir, "missing-key.pem"),
|
|
||||||
leafCert: make(map[string]*tls.Certificate),
|
|
||||||
}
|
|
||||||
|
|
||||||
err := ca.load()
|
|
||||||
if err == nil {
|
|
||||||
t.Fatal("load() error = nil, want non-nil")
|
|
||||||
}
|
|
||||||
if !strings.Contains(err.Error(), "read ca cert") {
|
|
||||||
t.Fatalf("load() error = %q, want contains %q", err.Error(), "read ca cert")
|
|
||||||
}
|
|
||||||
})
|
|
||||||
|
|
||||||
t.Run("read key failure", func(t *testing.T) {
|
|
||||||
t.Parallel()
|
|
||||||
|
|
||||||
dir := t.TempDir()
|
|
||||||
goodCA := newTestCA(t)
|
|
||||||
certPEM := pem.EncodeToMemory(&pem.Block{Type: "CERTIFICATE", Bytes: goodCA.cert.Raw})
|
|
||||||
|
|
||||||
certPath := filepath.Join(dir, "cert.pem")
|
|
||||||
if err := os.WriteFile(certPath, certPEM, 0o600); err != nil {
|
|
||||||
t.Fatalf("WriteFile(cert) error = %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
ca := &CertificateAuthority{
|
|
||||||
certPath: certPath,
|
|
||||||
keyPath: filepath.Join(dir, "missing-key.pem"),
|
|
||||||
leafCert: make(map[string]*tls.Certificate),
|
|
||||||
}
|
|
||||||
|
|
||||||
err := ca.load()
|
|
||||||
if err == nil {
|
|
||||||
t.Fatal("load() error = nil, want non-nil")
|
|
||||||
}
|
|
||||||
if !strings.Contains(err.Error(), "read ca key") {
|
|
||||||
t.Fatalf("load() error = %q, want contains %q", err.Error(), "read ca key")
|
|
||||||
}
|
|
||||||
})
|
|
||||||
}
|
|
||||||
@ -1,45 +0,0 @@
|
|||||||
package proxy
|
|
||||||
|
|
||||||
import (
|
|
||||||
"os"
|
|
||||||
"testing"
|
|
||||||
)
|
|
||||||
|
|
||||||
func TestLoadOrCreateCertificateAuthority_CreateThenLoad(t *testing.T) {
|
|
||||||
configRoot := t.TempDir()
|
|
||||||
t.Setenv("XDG_CONFIG_HOME", configRoot)
|
|
||||||
|
|
||||||
ca1, err := loadOrCreateCertificateAuthority()
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("first loadOrCreateCertificateAuthority() error = %v", err)
|
|
||||||
}
|
|
||||||
if ca1 == nil {
|
|
||||||
t.Fatal("first CA is nil")
|
|
||||||
}
|
|
||||||
if !ca1.WasCreated() {
|
|
||||||
t.Fatal("first CA should report WasCreated=true")
|
|
||||||
}
|
|
||||||
if ca1.CertPath() == "" {
|
|
||||||
t.Fatal("first CA CertPath is empty")
|
|
||||||
}
|
|
||||||
if _, err := os.Stat(ca1.CertPath()); err != nil {
|
|
||||||
t.Fatalf("first CA cert file missing: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
ca2, err := loadOrCreateCertificateAuthority()
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("second loadOrCreateCertificateAuthority() error = %v", err)
|
|
||||||
}
|
|
||||||
if ca2 == nil {
|
|
||||||
t.Fatal("second CA is nil")
|
|
||||||
}
|
|
||||||
if ca2.WasCreated() {
|
|
||||||
t.Fatal("second CA should report WasCreated=false (loaded existing)")
|
|
||||||
}
|
|
||||||
if ca2.CertPath() != ca1.CertPath() {
|
|
||||||
t.Fatalf("cert path mismatch: first=%q second=%q", ca1.CertPath(), ca2.CertPath())
|
|
||||||
}
|
|
||||||
if ca2.cert == nil || ca2.key == nil {
|
|
||||||
t.Fatal("loaded CA should include cert and key")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
@ -1,266 +0,0 @@
|
|||||||
package proxy
|
|
||||||
|
|
||||||
import (
|
|
||||||
"errors"
|
|
||||||
"math/big"
|
|
||||||
"os"
|
|
||||||
"path/filepath"
|
|
||||||
"testing"
|
|
||||||
)
|
|
||||||
|
|
||||||
func TestNormalizeCertHost(t *testing.T) {
|
|
||||||
t.Parallel()
|
|
||||||
|
|
||||||
tests := []struct {
|
|
||||||
name string
|
|
||||||
in string
|
|
||||||
want string
|
|
||||||
}{
|
|
||||||
{name: "host and port", in: "example.com:443", want: "example.com"},
|
|
||||||
{name: "plain host", in: "example.com", want: "example.com"},
|
|
||||||
{name: "whitespace trims", in: " example.com:8443 ", want: "example.com"},
|
|
||||||
{name: "empty", in: " ", want: ""},
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, tt := range tests {
|
|
||||||
tt := tt
|
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
|
||||||
t.Parallel()
|
|
||||||
if got := normalizeCertHost(tt.in); got != tt.want {
|
|
||||||
t.Fatalf("normalizeCertHost(%q) = %q, want %q", tt.in, got, tt.want)
|
|
||||||
}
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestRandSerialNumber(t *testing.T) {
|
|
||||||
t.Parallel()
|
|
||||||
|
|
||||||
serial, err := randSerialNumber()
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("randSerialNumber() error = %v", err)
|
|
||||||
}
|
|
||||||
if serial == nil {
|
|
||||||
t.Fatal("serial is nil")
|
|
||||||
}
|
|
||||||
if serial.Sign() < 0 {
|
|
||||||
t.Fatalf("serial must be non-negative, got %v", serial)
|
|
||||||
}
|
|
||||||
|
|
||||||
limit := new(big.Int).Lsh(big.NewInt(1), 128)
|
|
||||||
if serial.Cmp(limit) >= 0 {
|
|
||||||
t.Fatalf("serial must be < 2^128, got %v", serial)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestWriteFileAtomically(t *testing.T) {
|
|
||||||
t.Parallel()
|
|
||||||
|
|
||||||
dir := t.TempDir()
|
|
||||||
path := filepath.Join(dir, "cert.pem")
|
|
||||||
|
|
||||||
if err := writeFileAtomically(path, []byte("first"), 0o600); err != nil {
|
|
||||||
t.Fatalf("first writeFileAtomically() error = %v", err)
|
|
||||||
}
|
|
||||||
if err := writeFileAtomically(path, []byte("second"), 0o600); err != nil {
|
|
||||||
t.Fatalf("second writeFileAtomically() error = %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
data, err := os.ReadFile(path)
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("ReadFile() error = %v", err)
|
|
||||||
}
|
|
||||||
if got, want := string(data), "second"; got != want {
|
|
||||||
t.Fatalf("file contents = %q, want %q", got, want)
|
|
||||||
}
|
|
||||||
|
|
||||||
info, err := os.Stat(path)
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("Stat() error = %v", err)
|
|
||||||
}
|
|
||||||
if got := info.Mode().Perm(); got != 0o600 {
|
|
||||||
t.Fatalf("file permissions = %#o, want %#o", got, 0o600)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestCertificateAuthority_Basics(t *testing.T) {
|
|
||||||
t.Parallel()
|
|
||||||
|
|
||||||
var nilCA *CertificateAuthority
|
|
||||||
if got := nilCA.CertPath(); got != "" {
|
|
||||||
t.Fatalf("nil CertPath() = %q, want empty", got)
|
|
||||||
}
|
|
||||||
if got := nilCA.WasCreated(); got {
|
|
||||||
t.Fatalf("nil WasCreated() = %v, want false", got)
|
|
||||||
}
|
|
||||||
|
|
||||||
ca := newTestCA(t)
|
|
||||||
ca.certPath = "/tmp/test-ca.pem"
|
|
||||||
ca.wasCreated = true
|
|
||||||
if got, want := ca.CertPath(), "/tmp/test-ca.pem"; got != want {
|
|
||||||
t.Fatalf("CertPath() = %q, want %q", got, want)
|
|
||||||
}
|
|
||||||
if !ca.WasCreated() {
|
|
||||||
t.Fatal("WasCreated() = false, want true")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestCertificateForHost(t *testing.T) {
|
|
||||||
t.Parallel()
|
|
||||||
|
|
||||||
ca := newTestCA(t)
|
|
||||||
|
|
||||||
t.Run("empty host returns error", func(t *testing.T) {
|
|
||||||
t.Parallel()
|
|
||||||
cert, err := ca.CertificateForHost(" ")
|
|
||||||
if err == nil {
|
|
||||||
t.Fatal("CertificateForHost() error = nil, want non-nil")
|
|
||||||
}
|
|
||||||
if cert != nil {
|
|
||||||
t.Fatalf("cert = %#v, want nil", cert)
|
|
||||||
}
|
|
||||||
})
|
|
||||||
|
|
||||||
t.Run("cache hit returns same pointer", func(t *testing.T) {
|
|
||||||
t.Parallel()
|
|
||||||
|
|
||||||
c1, err := ca.CertificateForHost("example.com:443")
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("first CertificateForHost() error = %v", err)
|
|
||||||
}
|
|
||||||
c2, err := ca.CertificateForHost("example.com")
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("second CertificateForHost() error = %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
if c1 != c2 {
|
|
||||||
t.Fatal("expected same certificate pointer from cache")
|
|
||||||
}
|
|
||||||
})
|
|
||||||
|
|
||||||
t.Run("ip and dns SAN selection", func(t *testing.T) {
|
|
||||||
t.Parallel()
|
|
||||||
|
|
||||||
ipCert, err := ca.CertificateForHost("127.0.0.1:443")
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("ip CertificateForHost() error = %v", err)
|
|
||||||
}
|
|
||||||
if ipCert.Leaf == nil {
|
|
||||||
t.Fatal("ip cert leaf is nil")
|
|
||||||
}
|
|
||||||
if len(ipCert.Leaf.IPAddresses) == 0 {
|
|
||||||
t.Fatal("ip cert should contain IP SAN")
|
|
||||||
}
|
|
||||||
if len(ipCert.Leaf.DNSNames) != 0 {
|
|
||||||
t.Fatalf("ip cert DNSNames = %v, want empty", ipCert.Leaf.DNSNames)
|
|
||||||
}
|
|
||||||
|
|
||||||
dnsCert, err := ca.CertificateForHost("service.local")
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("dns CertificateForHost() error = %v", err)
|
|
||||||
}
|
|
||||||
if dnsCert.Leaf == nil {
|
|
||||||
t.Fatal("dns cert leaf is nil")
|
|
||||||
}
|
|
||||||
if len(dnsCert.Leaf.DNSNames) == 0 {
|
|
||||||
t.Fatal("dns cert should contain DNS SAN")
|
|
||||||
}
|
|
||||||
})
|
|
||||||
|
|
||||||
t.Run("evicts oldest entry over maxLeafCerts", func(t *testing.T) {
|
|
||||||
t.Parallel()
|
|
||||||
|
|
||||||
ca2 := newTestCA(t)
|
|
||||||
for i := 0; i < maxLeafCerts+1; i++ {
|
|
||||||
host := filepath.Base(filepath.Join("h", big.NewInt(int64(i)).String()+".example"))
|
|
||||||
if _, err := ca2.CertificateForHost(host); err != nil {
|
|
||||||
t.Fatalf("CertificateForHost(%q) error = %v", host, err)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if len(ca2.leafOrder) != maxLeafCerts {
|
|
||||||
t.Fatalf("leafOrder len = %d, want %d", len(ca2.leafOrder), maxLeafCerts)
|
|
||||||
}
|
|
||||||
if _, ok := ca2.leafCert["0.example"]; ok {
|
|
||||||
t.Fatal("expected oldest cert to be evicted")
|
|
||||||
}
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestIsTrustedBySystem(t *testing.T) {
|
|
||||||
t.Parallel()
|
|
||||||
|
|
||||||
var nilCA *CertificateAuthority
|
|
||||||
_, err := nilCA.IsTrustedBySystem()
|
|
||||||
if err == nil {
|
|
||||||
t.Fatal("nil IsTrustedBySystem() error = nil, want non-nil")
|
|
||||||
}
|
|
||||||
|
|
||||||
ca := &CertificateAuthority{}
|
|
||||||
_, err = ca.IsTrustedBySystem()
|
|
||||||
if err == nil {
|
|
||||||
t.Fatal("missing-cert IsTrustedBySystem() error = nil, want non-nil")
|
|
||||||
}
|
|
||||||
|
|
||||||
t.Run("untrusted generated CA returns false without error", func(t *testing.T) {
|
|
||||||
t.Parallel()
|
|
||||||
ca := newTestCA(t)
|
|
||||||
|
|
||||||
trusted, err := ca.IsTrustedBySystem()
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("IsTrustedBySystem() error = %v, want nil for unknown authority", err)
|
|
||||||
}
|
|
||||||
if trusted {
|
|
||||||
t.Fatal("trusted = true, want false for generated test CA")
|
|
||||||
}
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestEnsureCertificateAuthority(t *testing.T) {
|
|
||||||
configRoot := t.TempDir()
|
|
||||||
t.Setenv("XDG_CONFIG_HOME", configRoot)
|
|
||||||
|
|
||||||
ca, err := EnsureCertificateAuthority()
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("EnsureCertificateAuthority() error = %v", err)
|
|
||||||
}
|
|
||||||
if ca == nil {
|
|
||||||
t.Fatal("EnsureCertificateAuthority() returned nil CA")
|
|
||||||
}
|
|
||||||
if ca.CertPath() == "" {
|
|
||||||
t.Fatal("EnsureCertificateAuthority() returned empty cert path")
|
|
||||||
}
|
|
||||||
if _, statErr := os.Stat(ca.CertPath()); statErr != nil {
|
|
||||||
t.Fatalf("expected cert on disk, stat error = %v", statErr)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestWriteFileAtomically_ErrorPath(t *testing.T) {
|
|
||||||
t.Parallel()
|
|
||||||
|
|
||||||
err := writeFileAtomically(filepath.Join("/nope", "bad", "path.pem"), []byte("x"), 0o600)
|
|
||||||
if err == nil {
|
|
||||||
t.Fatal("writeFileAtomically() error = nil, want non-nil")
|
|
||||||
}
|
|
||||||
if errors.Is(err, os.ErrNotExist) {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
// Accept platform-dependent fs errors as long as function fails.
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestWriteFileAtomically_RenameErrorWhenTargetIsDirectory(t *testing.T) {
|
|
||||||
t.Parallel()
|
|
||||||
|
|
||||||
dir := t.TempDir()
|
|
||||||
targetDir := filepath.Join(dir, "target-as-dir")
|
|
||||||
if err := os.MkdirAll(targetDir, 0o700); err != nil {
|
|
||||||
t.Fatalf("MkdirAll(targetDir) error = %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
err := writeFileAtomically(targetDir, []byte("x"), 0o600)
|
|
||||||
if err == nil {
|
|
||||||
t.Fatal("writeFileAtomically() error = nil, want non-nil")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// TODO: Add deterministic tests for loadOrCreateCertificateAuthority trust-store interactions.
|
|
||||||
@ -1,29 +0,0 @@
|
|||||||
package proxy
|
|
||||||
|
|
||||||
import (
|
|
||||||
"encoding/pem"
|
|
||||||
"os"
|
|
||||||
"path/filepath"
|
|
||||||
"testing"
|
|
||||||
)
|
|
||||||
|
|
||||||
func TestIsTrustedBySystem_TrustedViaCertEnv(t *testing.T) {
|
|
||||||
ca := newTestCA(t)
|
|
||||||
|
|
||||||
rootFile := filepath.Join(t.TempDir(), "roots.pem")
|
|
||||||
pemBytes := pem.EncodeToMemory(&pem.Block{Type: "CERTIFICATE", Bytes: ca.cert.Raw})
|
|
||||||
if err := os.WriteFile(rootFile, pemBytes, 0o600); err != nil {
|
|
||||||
t.Fatalf("WriteFile(root cert) error = %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
t.Setenv("SSL_CERT_FILE", rootFile)
|
|
||||||
t.Setenv("SSL_CERT_DIR", t.TempDir())
|
|
||||||
|
|
||||||
trusted, err := ca.IsTrustedBySystem()
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("IsTrustedBySystem() error = %v", err)
|
|
||||||
}
|
|
||||||
if !trusted {
|
|
||||||
t.Fatal("trusted = false, want true when CA is in configured cert file")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
@ -1,330 +0,0 @@
|
|||||||
package proxy
|
|
||||||
|
|
||||||
import (
|
|
||||||
"bufio"
|
|
||||||
"bytes"
|
|
||||||
"context"
|
|
||||||
"fmt"
|
|
||||||
"io"
|
|
||||||
"net"
|
|
||||||
"net/http"
|
|
||||||
"net/http/httptest"
|
|
||||||
"strings"
|
|
||||||
"testing"
|
|
||||||
"time"
|
|
||||||
|
|
||||||
"termtap.dev/internal/model"
|
|
||||||
)
|
|
||||||
|
|
||||||
type failingResponseWriter struct {
|
|
||||||
header http.Header
|
|
||||||
code int
|
|
||||||
}
|
|
||||||
|
|
||||||
func (w *failingResponseWriter) Header() http.Header {
|
|
||||||
if w.header == nil {
|
|
||||||
w.header = make(http.Header)
|
|
||||||
}
|
|
||||||
return w.header
|
|
||||||
}
|
|
||||||
|
|
||||||
func (w *failingResponseWriter) WriteHeader(statusCode int) {
|
|
||||||
w.code = statusCode
|
|
||||||
}
|
|
||||||
|
|
||||||
func (w *failingResponseWriter) Write(_ []byte) (int, error) {
|
|
||||||
return 0, io.ErrClosedPipe
|
|
||||||
}
|
|
||||||
|
|
||||||
type hijackFailWriter struct {
|
|
||||||
header http.Header
|
|
||||||
code int
|
|
||||||
}
|
|
||||||
|
|
||||||
func (w *hijackFailWriter) Header() http.Header {
|
|
||||||
if w.header == nil {
|
|
||||||
w.header = make(http.Header)
|
|
||||||
}
|
|
||||||
return w.header
|
|
||||||
}
|
|
||||||
|
|
||||||
func (w *hijackFailWriter) WriteHeader(statusCode int) {
|
|
||||||
w.code = statusCode
|
|
||||||
}
|
|
||||||
|
|
||||||
func (w *hijackFailWriter) Write(p []byte) (int, error) {
|
|
||||||
return len(p), nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (w *hijackFailWriter) Hijack() (net.Conn, *bufio.ReadWriter, error) {
|
|
||||||
return nil, nil, fmt.Errorf("hijack failed")
|
|
||||||
}
|
|
||||||
|
|
||||||
type dummyConn struct{}
|
|
||||||
|
|
||||||
func (dummyConn) Read(_ []byte) (int, error) { return 0, io.EOF }
|
|
||||||
func (dummyConn) Write(p []byte) (int, error) { return len(p), nil }
|
|
||||||
func (dummyConn) Close() error { return nil }
|
|
||||||
func (dummyConn) LocalAddr() net.Addr { return &net.TCPAddr{} }
|
|
||||||
func (dummyConn) RemoteAddr() net.Addr { return &net.TCPAddr{} }
|
|
||||||
func (dummyConn) SetDeadline(_ time.Time) error { return nil }
|
|
||||||
func (dummyConn) SetReadDeadline(_ time.Time) error { return nil }
|
|
||||||
func (dummyConn) SetWriteDeadline(_ time.Time) error { return nil }
|
|
||||||
|
|
||||||
type writeConnectFailWriter struct {
|
|
||||||
header http.Header
|
|
||||||
code int
|
|
||||||
}
|
|
||||||
|
|
||||||
func (w *writeConnectFailWriter) Header() http.Header {
|
|
||||||
if w.header == nil {
|
|
||||||
w.header = make(http.Header)
|
|
||||||
}
|
|
||||||
return w.header
|
|
||||||
}
|
|
||||||
|
|
||||||
func (w *writeConnectFailWriter) WriteHeader(statusCode int) {
|
|
||||||
w.code = statusCode
|
|
||||||
}
|
|
||||||
|
|
||||||
func (w *writeConnectFailWriter) Write(p []byte) (int, error) {
|
|
||||||
return len(p), nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (w *writeConnectFailWriter) Hijack() (net.Conn, *bufio.ReadWriter, error) {
|
|
||||||
rw := bufio.NewReadWriter(
|
|
||||||
bufio.NewReader(strings.NewReader("")),
|
|
||||||
bufio.NewWriter(errWriter{}),
|
|
||||||
)
|
|
||||||
return dummyConn{}, rw, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestProxyHandler_NonConnectRejectsNonAbsoluteURL(t *testing.T) {
|
|
||||||
t.Parallel()
|
|
||||||
|
|
||||||
ch := make(chan model.Event, 8)
|
|
||||||
ps := &model.ProxyServer{Conns: make(map[net.Conn]struct{})}
|
|
||||||
h := proxyHandler(ch, nil, ps)
|
|
||||||
|
|
||||||
req := httptest.NewRequest(http.MethodGet, "/not-proxy-form", nil)
|
|
||||||
w := httptest.NewRecorder()
|
|
||||||
|
|
||||||
h.ServeHTTP(w, req)
|
|
||||||
|
|
||||||
if w.Code != http.StatusBadRequest {
|
|
||||||
t.Fatalf("status = %d, want %d", w.Code, http.StatusBadRequest)
|
|
||||||
}
|
|
||||||
|
|
||||||
events := drainEvents(t, ch, 1, time.Second)
|
|
||||||
if !hasEventType(events, model.EventTypeWarn) {
|
|
||||||
t.Fatalf("expected %s event, got %#v", model.EventTypeWarn, events)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestProxyHandler_NonConnectSuccess(t *testing.T) {
|
|
||||||
t.Parallel()
|
|
||||||
|
|
||||||
upstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
||||||
w.Header().Set("X-Upstream", "yes")
|
|
||||||
w.WriteHeader(http.StatusAccepted)
|
|
||||||
_, _ = w.Write([]byte("pong"))
|
|
||||||
}))
|
|
||||||
t.Cleanup(upstream.Close)
|
|
||||||
|
|
||||||
ch := make(chan model.Event, 8)
|
|
||||||
ps := &model.ProxyServer{Conns: make(map[net.Conn]struct{})}
|
|
||||||
h := proxyHandler(ch, nil, ps)
|
|
||||||
|
|
||||||
req := httptest.NewRequest(http.MethodGet, upstream.URL+"/ping", nil)
|
|
||||||
w := httptest.NewRecorder()
|
|
||||||
|
|
||||||
h.ServeHTTP(w, req)
|
|
||||||
|
|
||||||
if w.Code != http.StatusAccepted {
|
|
||||||
t.Fatalf("status = %d, want %d", w.Code, http.StatusAccepted)
|
|
||||||
}
|
|
||||||
if got, want := w.Body.String(), "pong"; got != want {
|
|
||||||
t.Fatalf("body = %q, want %q", got, want)
|
|
||||||
}
|
|
||||||
if got := w.Header().Get("X-Upstream"); got != "yes" {
|
|
||||||
t.Fatalf("X-Upstream header = %q, want yes", got)
|
|
||||||
}
|
|
||||||
|
|
||||||
events := drainEvents(t, ch, 2, time.Second)
|
|
||||||
if !hasEventType(events, model.EventTypeRequestStarted) {
|
|
||||||
t.Fatalf("expected %s event", model.EventTypeRequestStarted)
|
|
||||||
}
|
|
||||||
if !hasEventType(events, model.EventTypeRequestFinished) {
|
|
||||||
t.Fatalf("expected %s event", model.EventTypeRequestFinished)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestProxyHandler_NonConnectUpstreamError(t *testing.T) {
|
|
||||||
t.Parallel()
|
|
||||||
|
|
||||||
ch := make(chan model.Event, 8)
|
|
||||||
ps := &model.ProxyServer{Conns: make(map[net.Conn]struct{})}
|
|
||||||
h := proxyHandler(ch, nil, ps)
|
|
||||||
|
|
||||||
req := httptest.NewRequest(http.MethodGet, "http://127.0.0.1:1/fail", nil)
|
|
||||||
w := httptest.NewRecorder()
|
|
||||||
|
|
||||||
h.ServeHTTP(w, req)
|
|
||||||
|
|
||||||
if w.Code != http.StatusBadGateway {
|
|
||||||
t.Fatalf("status = %d, want %d", w.Code, http.StatusBadGateway)
|
|
||||||
}
|
|
||||||
|
|
||||||
events := drainEvents(t, ch, 2, time.Second)
|
|
||||||
if !hasEventType(events, model.EventTypeRequestStarted) {
|
|
||||||
t.Fatalf("expected %s event", model.EventTypeRequestStarted)
|
|
||||||
}
|
|
||||||
if !hasEventType(events, model.EventTypeRequestFailed) {
|
|
||||||
t.Fatalf("expected %s event", model.EventTypeRequestFailed)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestProxyHandler_NonConnectWriteFailureEmitsFailedEvent(t *testing.T) {
|
|
||||||
t.Parallel()
|
|
||||||
|
|
||||||
upstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
||||||
w.WriteHeader(http.StatusOK)
|
|
||||||
_, _ = io.Copy(w, bytes.NewBufferString("response-body"))
|
|
||||||
}))
|
|
||||||
t.Cleanup(upstream.Close)
|
|
||||||
|
|
||||||
ch := make(chan model.Event, 8)
|
|
||||||
ps := &model.ProxyServer{Conns: make(map[net.Conn]struct{})}
|
|
||||||
h := proxyHandler(ch, nil, ps)
|
|
||||||
|
|
||||||
req := httptest.NewRequest(http.MethodGet, upstream.URL+"/copy-fail", nil)
|
|
||||||
w := &failingResponseWriter{}
|
|
||||||
|
|
||||||
h.ServeHTTP(w, req)
|
|
||||||
|
|
||||||
events := drainEvents(t, ch, 2, time.Second)
|
|
||||||
if !hasEventType(events, model.EventTypeRequestStarted) {
|
|
||||||
t.Fatalf("expected %s event", model.EventTypeRequestStarted)
|
|
||||||
}
|
|
||||||
if !hasEventType(events, model.EventTypeRequestFailed) {
|
|
||||||
t.Fatalf("expected %s event", model.EventTypeRequestFailed)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestHandleConnect_EarlyFailures(t *testing.T) {
|
|
||||||
t.Parallel()
|
|
||||||
|
|
||||||
transport := &mockTransport{fn: func(req *http.Request) (*http.Response, error) {
|
|
||||||
return nil, context.Canceled
|
|
||||||
}}
|
|
||||||
|
|
||||||
tests := []struct {
|
|
||||||
name string
|
|
||||||
writer http.ResponseWriter
|
|
||||||
req *http.Request
|
|
||||||
ca *CertificateAuthority
|
|
||||||
wantCode int
|
|
||||||
wantBody string
|
|
||||||
wantStat int
|
|
||||||
}{
|
|
||||||
{
|
|
||||||
name: "nil CA returns 502 and failed event",
|
|
||||||
writer: httptest.NewRecorder(),
|
|
||||||
req: httptest.NewRequest(http.MethodConnect, "http://example.com:443", nil),
|
|
||||||
ca: nil,
|
|
||||||
wantCode: http.StatusBadGateway,
|
|
||||||
wantBody: "certificate authority is unavailable",
|
|
||||||
wantStat: http.StatusBadGateway,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "nil CA with host without port still fails predictably",
|
|
||||||
writer: httptest.NewRecorder(),
|
|
||||||
req: &http.Request{Method: http.MethodConnect, Host: "example.com"},
|
|
||||||
ca: nil,
|
|
||||||
wantCode: http.StatusBadGateway,
|
|
||||||
wantBody: "certificate authority is unavailable",
|
|
||||||
wantStat: http.StatusBadGateway,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "cert mint failure returns 502 and failed event",
|
|
||||||
writer: httptest.NewRecorder(),
|
|
||||||
req: &http.Request{Method: http.MethodConnect, Host: ""},
|
|
||||||
ca: &CertificateAuthority{},
|
|
||||||
wantCode: http.StatusBadGateway,
|
|
||||||
wantBody: "failed to mint interception certificate",
|
|
||||||
wantStat: http.StatusBadGateway,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "non-hijacker writer returns 500 and failed event",
|
|
||||||
writer: httptest.NewRecorder(),
|
|
||||||
req: httptest.NewRequest(http.MethodConnect, "http://example.com:443", nil),
|
|
||||||
ca: newTestCA(t),
|
|
||||||
wantCode: http.StatusInternalServerError,
|
|
||||||
wantBody: "hijack is unavailable",
|
|
||||||
wantStat: http.StatusInternalServerError,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "hijack failure returns 500 and failed event",
|
|
||||||
writer: &hijackFailWriter{},
|
|
||||||
req: httptest.NewRequest(http.MethodConnect, "http://example.com:443", nil),
|
|
||||||
ca: newTestCA(t),
|
|
||||||
wantCode: http.StatusInternalServerError,
|
|
||||||
wantBody: "CONNECT hijack failed",
|
|
||||||
wantStat: http.StatusInternalServerError,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "write connect established failure emits failed event",
|
|
||||||
writer: &writeConnectFailWriter{},
|
|
||||||
req: httptest.NewRequest(http.MethodConnect, "http://example.com:443", nil),
|
|
||||||
ca: newTestCA(t),
|
|
||||||
wantCode: 0,
|
|
||||||
wantBody: "CONNECT setup failed",
|
|
||||||
wantStat: http.StatusBadGateway,
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, tt := range tests {
|
|
||||||
tt := tt
|
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
|
||||||
t.Parallel()
|
|
||||||
|
|
||||||
ch := make(chan model.Event, 8)
|
|
||||||
ps := &model.ProxyServer{Conns: make(map[net.Conn]struct{})}
|
|
||||||
|
|
||||||
handleConnect(tt.writer, tt.req, ch, transport, tt.ca, ps)
|
|
||||||
|
|
||||||
events := drainEvents(t, ch, 2, time.Second)
|
|
||||||
if events[0].Type != model.EventTypeRequestStarted {
|
|
||||||
t.Fatalf("expected %s event", model.EventTypeRequestStarted)
|
|
||||||
}
|
|
||||||
if events[1].Type != model.EventTypeRequestFailed {
|
|
||||||
t.Fatalf("expected %s event", model.EventTypeRequestFailed)
|
|
||||||
}
|
|
||||||
if events[1].Request.Method != http.MethodConnect {
|
|
||||||
t.Fatalf("failed event request method = %q, want CONNECT", events[1].Request.Method)
|
|
||||||
}
|
|
||||||
if events[1].Request.Status != tt.wantStat {
|
|
||||||
t.Fatalf("failed event request status = %d, want %d", events[1].Request.Status, tt.wantStat)
|
|
||||||
}
|
|
||||||
if !events[1].Request.Failed || events[1].Request.Pending {
|
|
||||||
t.Fatalf("failed event request flags = pending:%v failed:%v, want pending:false failed:true", events[1].Request.Pending, events[1].Request.Failed)
|
|
||||||
}
|
|
||||||
if !strings.Contains(events[1].Body, tt.wantBody) {
|
|
||||||
t.Fatalf("failed event body %q does not contain %q", events[1].Body, tt.wantBody)
|
|
||||||
}
|
|
||||||
|
|
||||||
if recorder, ok := tt.writer.(*httptest.ResponseRecorder); ok {
|
|
||||||
if recorder.Code != tt.wantCode {
|
|
||||||
t.Fatalf("status = %d, want %d", recorder.Code, tt.wantCode)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
if w, ok := tt.writer.(*hijackFailWriter); ok {
|
|
||||||
if w.code != tt.wantCode {
|
|
||||||
t.Fatalf("status = %d, want %d", w.code, tt.wantCode)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// TODO: Add full TLS MITM loop integration test.
|
|
||||||
@ -1,157 +0,0 @@
|
|||||||
package proxy
|
|
||||||
|
|
||||||
import (
|
|
||||||
"net/http"
|
|
||||||
"reflect"
|
|
||||||
"testing"
|
|
||||||
)
|
|
||||||
|
|
||||||
func TestStripHopByHopHeaders(t *testing.T) {
|
|
||||||
t.Parallel()
|
|
||||||
|
|
||||||
tests := []struct {
|
|
||||||
name string
|
|
||||||
in http.Header
|
|
||||||
want http.Header
|
|
||||||
}{
|
|
||||||
{
|
|
||||||
name: "removes static hop-by-hop headers",
|
|
||||||
in: http.Header{
|
|
||||||
"Connection": {"keep-alive"},
|
|
||||||
"Keep-Alive": {"timeout=5"},
|
|
||||||
"Proxy-Connection": {"keep-alive"},
|
|
||||||
"Transfer-Encoding": {"chunked"},
|
|
||||||
"Upgrade": {"websocket"},
|
|
||||||
"X-Custom": {"ok"},
|
|
||||||
},
|
|
||||||
want: http.Header{
|
|
||||||
"X-Custom": {"ok"},
|
|
||||||
},
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "removes headers listed in connection with spaces and commas",
|
|
||||||
in: http.Header{
|
|
||||||
"Connection": {" keep-alive, X-Foo ,X-Bar"},
|
|
||||||
"Keep-Alive": {"timeout=5"},
|
|
||||||
"X-Foo": {"foo"},
|
|
||||||
"X-Bar": {"bar"},
|
|
||||||
"Content-Type": {"application/json"},
|
|
||||||
"Content-Length": {"12"},
|
|
||||||
},
|
|
||||||
want: http.Header{
|
|
||||||
"Content-Type": {"application/json"},
|
|
||||||
"Content-Length": {"12"},
|
|
||||||
},
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "keeps unrelated headers when no connection header",
|
|
||||||
in: http.Header{
|
|
||||||
"Accept": {"*/*"},
|
|
||||||
"Authorization": {"Bearer abc"},
|
|
||||||
},
|
|
||||||
want: http.Header{
|
|
||||||
"Accept": {"*/*"},
|
|
||||||
"Authorization": {"Bearer abc"},
|
|
||||||
},
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, tt := range tests {
|
|
||||||
tt := tt
|
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
|
||||||
t.Parallel()
|
|
||||||
|
|
||||||
stripHopByHopHeaders(tt.in)
|
|
||||||
|
|
||||||
if !reflect.DeepEqual(tt.in, tt.want) {
|
|
||||||
t.Fatalf("stripHopByHopHeaders() = %#v, want %#v", tt.in, tt.want)
|
|
||||||
}
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestStripHopByHopHeaders_NilHeader(t *testing.T) {
|
|
||||||
t.Parallel()
|
|
||||||
|
|
||||||
stripHopByHopHeaders(nil)
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestRedactHeaders(t *testing.T) {
|
|
||||||
t.Parallel()
|
|
||||||
|
|
||||||
tests := []struct {
|
|
||||||
name string
|
|
||||||
input http.Header
|
|
||||||
want http.Header
|
|
||||||
}{
|
|
||||||
{
|
|
||||||
name: "redacts sensitive canonical headers",
|
|
||||||
input: http.Header{
|
|
||||||
"Authorization": {"Bearer token123"},
|
|
||||||
"Cookie": {"session=abc"},
|
|
||||||
"Proxy-Authorization": {"Basic abc"},
|
|
||||||
"Set-Cookie": {"a=1"},
|
|
||||||
"X-Api-Key": {"secret"},
|
|
||||||
},
|
|
||||||
want: http.Header{
|
|
||||||
"Authorization": {"[REDACTED]"},
|
|
||||||
"Cookie": {"[REDACTED]"},
|
|
||||||
"Proxy-Authorization": {"[REDACTED]"},
|
|
||||||
"Set-Cookie": {"[REDACTED]"},
|
|
||||||
"X-Api-Key": {"[REDACTED]"},
|
|
||||||
},
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "leaves non-sensitive headers untouched",
|
|
||||||
input: http.Header{
|
|
||||||
"Content-Type": {"application/json"},
|
|
||||||
"X-Trace-ID": {"trace-1"},
|
|
||||||
},
|
|
||||||
want: http.Header{
|
|
||||||
"Content-Type": {"application/json"},
|
|
||||||
"X-Trace-ID": {"trace-1"},
|
|
||||||
},
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, tt := range tests {
|
|
||||||
tt := tt
|
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
|
||||||
t.Parallel()
|
|
||||||
|
|
||||||
got := redactHeaders(tt.input)
|
|
||||||
if !reflect.DeepEqual(got, tt.want) {
|
|
||||||
t.Fatalf("redactHeaders() = %#v, want %#v", got, tt.want)
|
|
||||||
}
|
|
||||||
|
|
||||||
got.Set("Content-Type", "modified")
|
|
||||||
if reflect.DeepEqual(tt.input, got) {
|
|
||||||
t.Fatal("redactHeaders() appears to mutate input or return aliased map")
|
|
||||||
}
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestCopyHeaders(t *testing.T) {
|
|
||||||
t.Parallel()
|
|
||||||
|
|
||||||
src := http.Header{
|
|
||||||
"X-Multi": {"a", "b", "c"},
|
|
||||||
"Content-Type": {"application/json"},
|
|
||||||
}
|
|
||||||
dest := http.Header{
|
|
||||||
"Existing": {"keep"},
|
|
||||||
}
|
|
||||||
|
|
||||||
copyHeaders(src, dest)
|
|
||||||
|
|
||||||
want := http.Header{
|
|
||||||
"Existing": {"keep"},
|
|
||||||
"X-Multi": {"a", "b", "c"},
|
|
||||||
"Content-Type": {"application/json"},
|
|
||||||
}
|
|
||||||
|
|
||||||
if !reflect.DeepEqual(dest, want) {
|
|
||||||
t.Fatalf("copyHeaders() dest = %#v, want %#v", dest, want)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
@ -1,129 +0,0 @@
|
|||||||
package proxy
|
|
||||||
|
|
||||||
import (
|
|
||||||
"io"
|
|
||||||
"net/http"
|
|
||||||
"net/http/httptest"
|
|
||||||
"net/url"
|
|
||||||
"testing"
|
|
||||||
"time"
|
|
||||||
|
|
||||||
"termtap.dev/internal/model"
|
|
||||||
)
|
|
||||||
|
|
||||||
func TestHTTPProxyE2E_WithClientProxyConfig(t *testing.T) {
|
|
||||||
configRoot := t.TempDir()
|
|
||||||
t.Setenv("XDG_CONFIG_HOME", configRoot)
|
|
||||||
|
|
||||||
upstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
||||||
w.WriteHeader(http.StatusOK)
|
|
||||||
_, _ = w.Write([]byte("pong"))
|
|
||||||
}))
|
|
||||||
t.Cleanup(upstream.Close)
|
|
||||||
|
|
||||||
ch := make(chan model.Event, 64)
|
|
||||||
ps, err := NewProxyServer("127.0.0.1:0", ch)
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("NewProxyServer() error = %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
serveDone := make(chan error, 1)
|
|
||||||
go func() {
|
|
||||||
serveDone <- ps.Server.Serve(*ps.Listener)
|
|
||||||
}()
|
|
||||||
|
|
||||||
t.Cleanup(func() {
|
|
||||||
Destroy(ps, ch)
|
|
||||||
select {
|
|
||||||
case <-serveDone:
|
|
||||||
case <-time.After(2 * time.Second):
|
|
||||||
t.Fatal("timeout waiting for proxy server to stop")
|
|
||||||
}
|
|
||||||
})
|
|
||||||
|
|
||||||
proxyURL, err := url.Parse(ps.Url)
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("url.Parse(proxy) error = %v", err)
|
|
||||||
}
|
|
||||||
client := &http.Client{
|
|
||||||
Transport: &http.Transport{Proxy: http.ProxyURL(proxyURL)},
|
|
||||||
Timeout: 3 * time.Second,
|
|
||||||
}
|
|
||||||
|
|
||||||
resp, err := client.Get(upstream.URL + "/ping")
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("client.Get() error = %v", err)
|
|
||||||
}
|
|
||||||
defer resp.Body.Close()
|
|
||||||
|
|
||||||
body, err := io.ReadAll(resp.Body)
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("ReadAll(response) error = %v", err)
|
|
||||||
}
|
|
||||||
if resp.StatusCode != http.StatusOK {
|
|
||||||
t.Fatalf("status = %d, want %d", resp.StatusCode, http.StatusOK)
|
|
||||||
}
|
|
||||||
if got, want := string(body), "pong"; got != want {
|
|
||||||
t.Fatalf("body = %q, want %q", got, want)
|
|
||||||
}
|
|
||||||
|
|
||||||
events := drainEvents(t, ch, 2, 2*time.Second)
|
|
||||||
if !hasEventType(events, model.EventTypeRequestStarted) {
|
|
||||||
t.Fatalf("missing %s event", model.EventTypeRequestStarted)
|
|
||||||
}
|
|
||||||
if !hasEventType(events, model.EventTypeRequestFinished) {
|
|
||||||
t.Fatalf("missing %s event", model.EventTypeRequestFinished)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestHTTPProxyE2E_UpstreamFailureEmitsRequestFailed(t *testing.T) {
|
|
||||||
configRoot := t.TempDir()
|
|
||||||
t.Setenv("XDG_CONFIG_HOME", configRoot)
|
|
||||||
|
|
||||||
ch := make(chan model.Event, 64)
|
|
||||||
ps, err := NewProxyServer("127.0.0.1:0", ch)
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("NewProxyServer() error = %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
serveDone := make(chan error, 1)
|
|
||||||
go func() {
|
|
||||||
serveDone <- ps.Server.Serve(*ps.Listener)
|
|
||||||
}()
|
|
||||||
|
|
||||||
t.Cleanup(func() {
|
|
||||||
Destroy(ps, ch)
|
|
||||||
select {
|
|
||||||
case <-serveDone:
|
|
||||||
case <-time.After(2 * time.Second):
|
|
||||||
t.Fatal("timeout waiting for proxy server to stop")
|
|
||||||
}
|
|
||||||
})
|
|
||||||
|
|
||||||
proxyURL, err := url.Parse(ps.Url)
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("url.Parse(proxy) error = %v", err)
|
|
||||||
}
|
|
||||||
client := &http.Client{
|
|
||||||
Transport: &http.Transport{Proxy: http.ProxyURL(proxyURL)},
|
|
||||||
Timeout: 3 * time.Second,
|
|
||||||
}
|
|
||||||
|
|
||||||
resp, reqErr := client.Get("http://127.0.0.1:1/unreachable")
|
|
||||||
if reqErr != nil {
|
|
||||||
t.Fatalf("client.Get() error = %v; proxy should reply with mapped HTTP status", reqErr)
|
|
||||||
}
|
|
||||||
defer resp.Body.Close()
|
|
||||||
|
|
||||||
if resp.StatusCode != http.StatusBadGateway {
|
|
||||||
t.Fatalf("status = %d, want %d", resp.StatusCode, http.StatusBadGateway)
|
|
||||||
}
|
|
||||||
|
|
||||||
events := drainEvents(t, ch, 2, 3*time.Second)
|
|
||||||
if !hasEventType(events, model.EventTypeRequestStarted) {
|
|
||||||
t.Fatalf("missing %s event", model.EventTypeRequestStarted)
|
|
||||||
}
|
|
||||||
if !hasEventType(events, model.EventTypeRequestFailed) {
|
|
||||||
t.Fatalf("missing %s event", model.EventTypeRequestFailed)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
@ -1,301 +0,0 @@
|
|||||||
package proxy
|
|
||||||
|
|
||||||
import (
|
|
||||||
"bufio"
|
|
||||||
"crypto/tls"
|
|
||||||
"crypto/x509"
|
|
||||||
"errors"
|
|
||||||
"io"
|
|
||||||
"net"
|
|
||||||
"net/http"
|
|
||||||
"strings"
|
|
||||||
"testing"
|
|
||||||
"time"
|
|
||||||
|
|
||||||
"termtap.dev/internal/model"
|
|
||||||
)
|
|
||||||
|
|
||||||
type hijackPipeWriter struct {
|
|
||||||
conn net.Conn
|
|
||||||
|
|
||||||
header http.Header
|
|
||||||
code int
|
|
||||||
}
|
|
||||||
|
|
||||||
func (w *hijackPipeWriter) Header() http.Header {
|
|
||||||
if w.header == nil {
|
|
||||||
w.header = make(http.Header)
|
|
||||||
}
|
|
||||||
return w.header
|
|
||||||
}
|
|
||||||
|
|
||||||
func (w *hijackPipeWriter) WriteHeader(statusCode int) {
|
|
||||||
w.code = statusCode
|
|
||||||
}
|
|
||||||
|
|
||||||
func (w *hijackPipeWriter) Write(p []byte) (int, error) {
|
|
||||||
return len(p), nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (w *hijackPipeWriter) Hijack() (net.Conn, *bufio.ReadWriter, error) {
|
|
||||||
return w.conn, nil, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func startMITMConnect(t *testing.T, transport http.RoundTripper) (net.Conn, *tls.Conn, chan model.Event, chan struct{}) {
|
|
||||||
t.Helper()
|
|
||||||
|
|
||||||
ca := newTestCA(t)
|
|
||||||
clientConn, serverConn := net.Pipe()
|
|
||||||
|
|
||||||
writer := &hijackPipeWriter{conn: serverConn}
|
|
||||||
req, err := http.NewRequest(http.MethodConnect, "http://example.com:443", nil)
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("NewRequest(CONNECT) error = %v", err)
|
|
||||||
}
|
|
||||||
req.Host = "example.com:443"
|
|
||||||
|
|
||||||
ch := make(chan model.Event, 16)
|
|
||||||
ps := &model.ProxyServer{Conns: make(map[net.Conn]struct{})}
|
|
||||||
|
|
||||||
handleDone := make(chan struct{})
|
|
||||||
go func() {
|
|
||||||
handleConnect(writer, req, ch, transport, ca, ps)
|
|
||||||
close(handleDone)
|
|
||||||
}()
|
|
||||||
|
|
||||||
reader := bufio.NewReader(clientConn)
|
|
||||||
connectResp, err := http.ReadResponse(reader, &http.Request{Method: http.MethodConnect})
|
|
||||||
if err != nil {
|
|
||||||
_ = clientConn.Close()
|
|
||||||
_ = serverConn.Close()
|
|
||||||
t.Fatalf("ReadResponse(CONNECT established) error = %v", err)
|
|
||||||
}
|
|
||||||
if connectResp.StatusCode != http.StatusOK {
|
|
||||||
_ = clientConn.Close()
|
|
||||||
_ = serverConn.Close()
|
|
||||||
t.Fatalf("CONNECT established status = %d, want %d", connectResp.StatusCode, http.StatusOK)
|
|
||||||
}
|
|
||||||
|
|
||||||
pool := x509.NewCertPool()
|
|
||||||
pool.AddCert(ca.cert)
|
|
||||||
tlsClient := tls.Client(clientConn, &tls.Config{
|
|
||||||
ServerName: "example.com",
|
|
||||||
RootCAs: pool,
|
|
||||||
MinVersion: tls.VersionTLS12,
|
|
||||||
})
|
|
||||||
|
|
||||||
if err := tlsClient.Handshake(); err != nil {
|
|
||||||
_ = tlsClient.Close()
|
|
||||||
_ = serverConn.Close()
|
|
||||||
t.Fatalf("tls handshake error = %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
t.Cleanup(func() {
|
|
||||||
_ = tlsClient.Close()
|
|
||||||
_ = serverConn.Close()
|
|
||||||
})
|
|
||||||
|
|
||||||
return clientConn, tlsClient, ch, handleDone
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestHTTPSE2E_MITMHandleConnectFlow(t *testing.T) {
|
|
||||||
transport := &mockTransport{fn: func(r *http.Request) (*http.Response, error) {
|
|
||||||
if r.URL.Scheme != "https" {
|
|
||||||
t.Fatalf("inner request scheme = %q, want https", r.URL.Scheme)
|
|
||||||
}
|
|
||||||
if r.URL.Host == "" {
|
|
||||||
t.Fatal("inner request host should be populated")
|
|
||||||
}
|
|
||||||
|
|
||||||
return &http.Response{
|
|
||||||
StatusCode: http.StatusCreated,
|
|
||||||
ProtoMajor: 1,
|
|
||||||
ProtoMinor: 1,
|
|
||||||
Header: http.Header{"Content-Type": {"text/plain"}},
|
|
||||||
Body: io.NopCloser(strings.NewReader("mitm-ok")),
|
|
||||||
}, nil
|
|
||||||
}}
|
|
||||||
clientConn, tlsClient, ch, handleDone := startMITMConnect(t, transport)
|
|
||||||
|
|
||||||
// 3) Send decrypted inner request over tunnel and read MITM response
|
|
||||||
innerReq, err := http.NewRequest(http.MethodGet, "https://example.com/inside", nil)
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("NewRequest(inner) error = %v", err)
|
|
||||||
}
|
|
||||||
innerReq.Host = "example.com"
|
|
||||||
innerReq.Close = true
|
|
||||||
if err := innerReq.Write(tlsClient); err != nil {
|
|
||||||
t.Fatalf("innerReq.Write() error = %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
innerResp, err := http.ReadResponse(bufio.NewReader(tlsClient), innerReq)
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("ReadResponse(inner) error = %v", err)
|
|
||||||
}
|
|
||||||
defer innerResp.Body.Close()
|
|
||||||
|
|
||||||
body, err := io.ReadAll(innerResp.Body)
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("ReadAll(inner body) error = %v", err)
|
|
||||||
}
|
|
||||||
if innerResp.StatusCode != http.StatusCreated {
|
|
||||||
t.Fatalf("inner status = %d, want %d", innerResp.StatusCode, http.StatusCreated)
|
|
||||||
}
|
|
||||||
if got, want := string(body), "mitm-ok"; got != want {
|
|
||||||
t.Fatalf("inner body = %q, want %q", got, want)
|
|
||||||
}
|
|
||||||
|
|
||||||
_ = tlsClient.Close()
|
|
||||||
_ = clientConn.Close()
|
|
||||||
|
|
||||||
select {
|
|
||||||
case <-handleDone:
|
|
||||||
case <-time.After(3 * time.Second):
|
|
||||||
t.Fatal("timeout waiting for handleConnect to return")
|
|
||||||
}
|
|
||||||
|
|
||||||
events := drainEvents(t, ch, 4, 2*time.Second)
|
|
||||||
if events[0].Type != model.EventTypeRequestStarted {
|
|
||||||
t.Fatalf("event[0] = %s, want %s", events[0].Type, model.EventTypeRequestStarted)
|
|
||||||
}
|
|
||||||
if events[1].Type != model.EventTypeRequestStarted {
|
|
||||||
t.Fatalf("event[1] = %s, want %s", events[1].Type, model.EventTypeRequestStarted)
|
|
||||||
}
|
|
||||||
if events[2].Type != model.EventTypeRequestFinished {
|
|
||||||
t.Fatalf("event[2] = %s, want %s", events[2].Type, model.EventTypeRequestFinished)
|
|
||||||
}
|
|
||||||
if events[3].Type != model.EventTypeRequestFinished {
|
|
||||||
t.Fatalf("event[3] = %s, want %s", events[3].Type, model.EventTypeRequestFinished)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestHTTPSE2E_MITMUpstreamErrorReturnsHTTPErrorInsideTunnel(t *testing.T) {
|
|
||||||
transport := &mockTransport{fn: func(*http.Request) (*http.Response, error) {
|
|
||||||
return nil, errors.New("upstream exploded")
|
|
||||||
}}
|
|
||||||
|
|
||||||
clientConn, tlsClient, ch, handleDone := startMITMConnect(t, transport)
|
|
||||||
|
|
||||||
innerReq, err := http.NewRequest(http.MethodGet, "https://example.com/fail", nil)
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("NewRequest(inner) error = %v", err)
|
|
||||||
}
|
|
||||||
innerReq.Host = "example.com"
|
|
||||||
if err := innerReq.Write(tlsClient); err != nil {
|
|
||||||
t.Fatalf("innerReq.Write() error = %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
innerResp, err := http.ReadResponse(bufio.NewReader(tlsClient), innerReq)
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("ReadResponse(inner) error = %v", err)
|
|
||||||
}
|
|
||||||
defer innerResp.Body.Close()
|
|
||||||
|
|
||||||
if innerResp.StatusCode != http.StatusBadGateway {
|
|
||||||
t.Fatalf("inner status = %d, want %d", innerResp.StatusCode, http.StatusBadGateway)
|
|
||||||
}
|
|
||||||
|
|
||||||
_ = tlsClient.Close()
|
|
||||||
_ = clientConn.Close()
|
|
||||||
|
|
||||||
select {
|
|
||||||
case <-handleDone:
|
|
||||||
case <-time.After(3 * time.Second):
|
|
||||||
t.Fatal("timeout waiting for handleConnect to return")
|
|
||||||
}
|
|
||||||
|
|
||||||
events := drainEvents(t, ch, 4, 2*time.Second)
|
|
||||||
if events[0].Type != model.EventTypeRequestStarted ||
|
|
||||||
events[1].Type != model.EventTypeRequestStarted ||
|
|
||||||
events[2].Type != model.EventTypeRequestFailed ||
|
|
||||||
events[3].Type != model.EventTypeRequestFailed {
|
|
||||||
t.Fatalf("unexpected event sequence: %#v", events)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestHTTPSE2E_MITMHandshakeFailureEmitsConnectFailed(t *testing.T) {
|
|
||||||
ca := newTestCA(t)
|
|
||||||
|
|
||||||
clientConn, serverConn := net.Pipe()
|
|
||||||
t.Cleanup(func() {
|
|
||||||
_ = clientConn.Close()
|
|
||||||
_ = serverConn.Close()
|
|
||||||
})
|
|
||||||
|
|
||||||
writer := &hijackPipeWriter{conn: serverConn}
|
|
||||||
req, err := http.NewRequest(http.MethodConnect, "http://example.com:443", nil)
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("NewRequest(CONNECT) error = %v", err)
|
|
||||||
}
|
|
||||||
req.Host = "example.com:443"
|
|
||||||
|
|
||||||
ch := make(chan model.Event, 16)
|
|
||||||
ps := &model.ProxyServer{Conns: make(map[net.Conn]struct{})}
|
|
||||||
transport := &mockTransport{fn: func(*http.Request) (*http.Response, error) {
|
|
||||||
return nil, errors.New("should not be called")
|
|
||||||
}}
|
|
||||||
|
|
||||||
handleDone := make(chan struct{})
|
|
||||||
go func() {
|
|
||||||
handleConnect(writer, req, ch, transport, ca, ps)
|
|
||||||
close(handleDone)
|
|
||||||
}()
|
|
||||||
|
|
||||||
reader := bufio.NewReader(clientConn)
|
|
||||||
connectResp, err := http.ReadResponse(reader, &http.Request{Method: http.MethodConnect})
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("ReadResponse(CONNECT established) error = %v", err)
|
|
||||||
}
|
|
||||||
if connectResp.StatusCode != http.StatusOK {
|
|
||||||
t.Fatalf("CONNECT established status = %d, want %d", connectResp.StatusCode, http.StatusOK)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Write plaintext (not TLS handshake) to force handshake failure.
|
|
||||||
if _, err := clientConn.Write([]byte("not tls handshake")); err != nil {
|
|
||||||
t.Fatalf("client write error = %v", err)
|
|
||||||
}
|
|
||||||
_ = clientConn.Close()
|
|
||||||
|
|
||||||
select {
|
|
||||||
case <-handleDone:
|
|
||||||
case <-time.After(3 * time.Second):
|
|
||||||
t.Fatal("timeout waiting for handleConnect to return on handshake failure")
|
|
||||||
}
|
|
||||||
|
|
||||||
events := drainEvents(t, ch, 2, 2*time.Second)
|
|
||||||
if events[0].Type != model.EventTypeRequestStarted || events[1].Type != model.EventTypeRequestFailed {
|
|
||||||
t.Fatalf("unexpected event sequence: %#v", events)
|
|
||||||
}
|
|
||||||
if !strings.Contains(events[1].Body, "TLS handshake with client failed") {
|
|
||||||
t.Fatalf("failed event body = %q, want handshake failure details", events[1].Body)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestHTTPSE2E_MITMDecryptedReadFailureEmitsConnectFailed(t *testing.T) {
|
|
||||||
transport := &mockTransport{fn: func(*http.Request) (*http.Response, error) {
|
|
||||||
return nil, errors.New("should not be called")
|
|
||||||
}}
|
|
||||||
|
|
||||||
clientConn, tlsClient, ch, handleDone := startMITMConnect(t, transport)
|
|
||||||
|
|
||||||
// Send malformed decrypted payload so ReadRequest fails (non-EOF branch).
|
|
||||||
if _, err := tlsClient.Write([]byte("BAD REQUEST\r\n\r\n")); err != nil {
|
|
||||||
t.Fatalf("tlsClient.Write() error = %v", err)
|
|
||||||
}
|
|
||||||
_ = tlsClient.Close()
|
|
||||||
_ = clientConn.Close()
|
|
||||||
|
|
||||||
select {
|
|
||||||
case <-handleDone:
|
|
||||||
case <-time.After(3 * time.Second):
|
|
||||||
t.Fatal("timeout waiting for handleConnect to return on decrypted read failure")
|
|
||||||
}
|
|
||||||
|
|
||||||
events := drainEvents(t, ch, 2, 2*time.Second)
|
|
||||||
if events[0].Type != model.EventTypeRequestStarted || events[1].Type != model.EventTypeRequestFailed {
|
|
||||||
t.Fatalf("unexpected event sequence: %#v", events)
|
|
||||||
}
|
|
||||||
if !strings.Contains(events[1].Body, "failed to read decrypted HTTPS request") {
|
|
||||||
t.Fatalf("failed event body = %q, want read failure details", events[1].Body)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
@ -1,142 +0,0 @@
|
|||||||
package proxy
|
|
||||||
|
|
||||||
import (
|
|
||||||
"bufio"
|
|
||||||
"io"
|
|
||||||
"net"
|
|
||||||
"strings"
|
|
||||||
"testing"
|
|
||||||
)
|
|
||||||
|
|
||||||
func TestNewBodyPreview(t *testing.T) {
|
|
||||||
t.Parallel()
|
|
||||||
|
|
||||||
tests := []struct {
|
|
||||||
name string
|
|
||||||
contentType string
|
|
||||||
wantEnabled bool
|
|
||||||
}{
|
|
||||||
{name: "text content enabled", contentType: "text/plain", wantEnabled: true},
|
|
||||||
{name: "json content enabled", contentType: "application/json", wantEnabled: true},
|
|
||||||
{name: "binary content disabled", contentType: "application/octet-stream", wantEnabled: false},
|
|
||||||
{name: "empty content type disabled", contentType: "", wantEnabled: false},
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, tt := range tests {
|
|
||||||
tt := tt
|
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
|
||||||
t.Parallel()
|
|
||||||
|
|
||||||
p := newBodyPreview(tt.contentType)
|
|
||||||
if p.enabled != tt.wantEnabled {
|
|
||||||
t.Fatalf("newBodyPreview(%q).enabled = %v, want %v", tt.contentType, p.enabled, tt.wantEnabled)
|
|
||||||
}
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestBodyPreviewWriteAndPreview(t *testing.T) {
|
|
||||||
t.Parallel()
|
|
||||||
|
|
||||||
t.Run("nil receiver is safe", func(t *testing.T) {
|
|
||||||
t.Parallel()
|
|
||||||
var p *bodyPreview
|
|
||||||
p.Write([]byte("abc"))
|
|
||||||
})
|
|
||||||
|
|
||||||
t.Run("disabled preview ignores data", func(t *testing.T) {
|
|
||||||
t.Parallel()
|
|
||||||
p := &bodyPreview{enabled: false}
|
|
||||||
p.Write([]byte("abc"))
|
|
||||||
if got := string(p.Preview()); got != "" {
|
|
||||||
t.Fatalf("Preview() = %q, want empty", got)
|
|
||||||
}
|
|
||||||
})
|
|
||||||
|
|
||||||
t.Run("empty write does nothing", func(t *testing.T) {
|
|
||||||
t.Parallel()
|
|
||||||
p := &bodyPreview{enabled: true}
|
|
||||||
p.Write(nil)
|
|
||||||
if got := string(p.Preview()); got != "" {
|
|
||||||
t.Fatalf("Preview() = %q, want empty", got)
|
|
||||||
}
|
|
||||||
})
|
|
||||||
|
|
||||||
t.Run("escapes newlines", func(t *testing.T) {
|
|
||||||
t.Parallel()
|
|
||||||
p := &bodyPreview{enabled: true}
|
|
||||||
p.Write([]byte("a\nb"))
|
|
||||||
if got, want := string(p.Preview()), `a\nb`; got != want {
|
|
||||||
t.Fatalf("Preview() = %q, want %q", got, want)
|
|
||||||
}
|
|
||||||
})
|
|
||||||
|
|
||||||
t.Run("truncates at max preview bytes and appends ellipsis", func(t *testing.T) {
|
|
||||||
t.Parallel()
|
|
||||||
p := &bodyPreview{enabled: true}
|
|
||||||
p.Write([]byte(strings.Repeat("a", maxPreviewBytes)))
|
|
||||||
p.Write([]byte("b"))
|
|
||||||
|
|
||||||
got := string(p.Preview())
|
|
||||||
if !strings.HasSuffix(got, "...") {
|
|
||||||
t.Fatalf("Preview() must end with ellipsis when truncated: %q", got[len(got)-10:])
|
|
||||||
}
|
|
||||||
if len(got) != maxPreviewBytes+3 {
|
|
||||||
t.Fatalf("len(Preview()) = %d, want %d", len(got), maxPreviewBytes+3)
|
|
||||||
}
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestWrapBufferedConn(t *testing.T) {
|
|
||||||
t.Parallel()
|
|
||||||
|
|
||||||
client, server := net.Pipe()
|
|
||||||
t.Cleanup(func() {
|
|
||||||
_ = client.Close()
|
|
||||||
_ = server.Close()
|
|
||||||
})
|
|
||||||
|
|
||||||
t.Run("returns original conn when readWriter nil", func(t *testing.T) {
|
|
||||||
t.Parallel()
|
|
||||||
got := wrapBufferedConn(client, nil)
|
|
||||||
if got != client {
|
|
||||||
t.Fatal("wrapBufferedConn should return original conn when readWriter is nil")
|
|
||||||
}
|
|
||||||
})
|
|
||||||
|
|
||||||
t.Run("read uses buffered readWriter", func(t *testing.T) {
|
|
||||||
t.Parallel()
|
|
||||||
rw := bufio.NewReadWriter(bufio.NewReader(strings.NewReader("xyz")), bufio.NewWriter(io.Discard))
|
|
||||||
got := wrapBufferedConn(client, rw)
|
|
||||||
|
|
||||||
buf := make([]byte, 3)
|
|
||||||
n, err := got.Read(buf)
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("Read() error = %v", err)
|
|
||||||
}
|
|
||||||
if n != 3 || string(buf) != "xyz" {
|
|
||||||
t.Fatalf("Read() = (%d, %q), want (3, %q)", n, string(buf), "xyz")
|
|
||||||
}
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestPreviewReadCloserRead(t *testing.T) {
|
|
||||||
t.Parallel()
|
|
||||||
|
|
||||||
preview := newBodyPreview("text/plain")
|
|
||||||
rc := &previewReadCloser{
|
|
||||||
ReadCloser: io.NopCloser(strings.NewReader("hello\nworld")),
|
|
||||||
preview: preview,
|
|
||||||
}
|
|
||||||
|
|
||||||
data, err := io.ReadAll(rc)
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("ReadAll() error = %v", err)
|
|
||||||
}
|
|
||||||
if got, want := string(data), "hello\nworld"; got != want {
|
|
||||||
t.Fatalf("read content = %q, want %q", got, want)
|
|
||||||
}
|
|
||||||
if got, want := string(preview.Preview()), `hello\nworld`; got != want {
|
|
||||||
t.Fatalf("preview = %q, want %q", got, want)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
@ -1,275 +0,0 @@
|
|||||||
package proxy
|
|
||||||
|
|
||||||
import (
|
|
||||||
"errors"
|
|
||||||
"io"
|
|
||||||
"net/http"
|
|
||||||
"net/url"
|
|
||||||
"strings"
|
|
||||||
"testing"
|
|
||||||
"time"
|
|
||||||
|
|
||||||
"github.com/google/uuid"
|
|
||||||
"termtap.dev/internal/model"
|
|
||||||
)
|
|
||||||
|
|
||||||
type mockTransport struct {
|
|
||||||
fn func(*http.Request) (*http.Response, error)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (m *mockTransport) RoundTrip(req *http.Request) (*http.Response, error) {
|
|
||||||
return m.fn(req)
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestRoundTripCapturedRequest_Success(t *testing.T) {
|
|
||||||
t.Parallel()
|
|
||||||
|
|
||||||
ch := make(chan model.Event, 8)
|
|
||||||
|
|
||||||
reqURL, err := url.Parse("http://example.com/path?q=1")
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("url.Parse() error = %v", err)
|
|
||||||
}
|
|
||||||
req := &http.Request{
|
|
||||||
Method: http.MethodPost,
|
|
||||||
URL: reqURL,
|
|
||||||
Host: "example.com",
|
|
||||||
Header: http.Header{
|
|
||||||
"Connection": {"X-Hop"},
|
|
||||||
"X-Hop": {"drop"},
|
|
||||||
"Authorization": {"Bearer token"},
|
|
||||||
"Content-Type": {"text/plain"},
|
|
||||||
},
|
|
||||||
Body: io.NopCloser(strings.NewReader("req\nbody")),
|
|
||||||
}
|
|
||||||
|
|
||||||
transport := &mockTransport{fn: func(outReq *http.Request) (*http.Response, error) {
|
|
||||||
if got := outReq.Header.Get("Connection"); got != "" {
|
|
||||||
t.Fatalf("Connection header should be stripped, got %q", got)
|
|
||||||
}
|
|
||||||
if got := outReq.Header.Get("X-Hop"); got != "" {
|
|
||||||
t.Fatalf("header listed in Connection should be stripped, got %q", got)
|
|
||||||
}
|
|
||||||
|
|
||||||
data, readErr := io.ReadAll(outReq.Body)
|
|
||||||
if readErr != nil {
|
|
||||||
t.Fatalf("ReadAll(outReq.Body) error = %v", readErr)
|
|
||||||
}
|
|
||||||
if got, want := string(data), "req\nbody"; got != want {
|
|
||||||
t.Fatalf("request body = %q, want %q", got, want)
|
|
||||||
}
|
|
||||||
|
|
||||||
return &http.Response{
|
|
||||||
StatusCode: http.StatusCreated,
|
|
||||||
Header: http.Header{
|
|
||||||
"Set-Cookie": {"session=top-secret"},
|
|
||||||
"Connection": {"close"},
|
|
||||||
"Content-Type": {"text/plain"},
|
|
||||||
},
|
|
||||||
Body: io.NopCloser(strings.NewReader("resp\nbody")),
|
|
||||||
}, nil
|
|
||||||
}}
|
|
||||||
|
|
||||||
resp, captured, responsePreview, err := roundTripCapturedRequest(req, transport, ch, "", false)
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("roundTripCapturedRequest() error = %v", err)
|
|
||||||
}
|
|
||||||
if resp == nil {
|
|
||||||
t.Fatal("roundTripCapturedRequest() returned nil response")
|
|
||||||
}
|
|
||||||
if responsePreview == nil {
|
|
||||||
t.Fatal("roundTripCapturedRequest() returned nil response preview")
|
|
||||||
}
|
|
||||||
|
|
||||||
if _, readErr := io.ReadAll(resp.Body); readErr != nil {
|
|
||||||
t.Fatalf("ReadAll(resp.Body) error = %v", readErr)
|
|
||||||
}
|
|
||||||
_ = resp.Body.Close()
|
|
||||||
|
|
||||||
if got, want := string(captured.RequestData), `req\nbody`; got != want {
|
|
||||||
t.Fatalf("captured.RequestData = %q, want %q", got, want)
|
|
||||||
}
|
|
||||||
if got, want := string(responsePreview.Preview()), `resp\nbody`; got != want {
|
|
||||||
t.Fatalf("responsePreview = %q, want %q", got, want)
|
|
||||||
}
|
|
||||||
if got := captured.RequestHeaders.Get("Authorization"); got != "[REDACTED]" {
|
|
||||||
t.Fatalf("Authorization header = %q, want [REDACTED]", got)
|
|
||||||
}
|
|
||||||
if got := captured.ResponseHeaders.Get("Set-Cookie"); got != "[REDACTED]" {
|
|
||||||
t.Fatalf("Set-Cookie header = %q, want [REDACTED]", got)
|
|
||||||
}
|
|
||||||
if got := captured.ResponseHeaders.Get("Connection"); got != "" {
|
|
||||||
t.Fatalf("Connection should be stripped from response headers, got %q", got)
|
|
||||||
}
|
|
||||||
|
|
||||||
events := drainEvents(t, ch, 1, time.Second)
|
|
||||||
if events[0].Type != model.EventTypeRequestStarted {
|
|
||||||
t.Fatalf("event type = %s, want %s", events[0].Type, model.EventTypeRequestStarted)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestRoundTripCapturedRequest_ErrorPath(t *testing.T) {
|
|
||||||
t.Parallel()
|
|
||||||
|
|
||||||
ch := make(chan model.Event, 8)
|
|
||||||
|
|
||||||
reqURL, err := url.Parse("http://example.com/fail")
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("url.Parse() error = %v", err)
|
|
||||||
}
|
|
||||||
req := &http.Request{
|
|
||||||
Method: http.MethodPost,
|
|
||||||
URL: reqURL,
|
|
||||||
Host: "example.com",
|
|
||||||
Header: http.Header{"Content-Type": {"text/plain"}},
|
|
||||||
Body: io.NopCloser(strings.NewReader("boom\nbody")),
|
|
||||||
}
|
|
||||||
|
|
||||||
wantErr := errors.New("upstream failed")
|
|
||||||
transport := &mockTransport{fn: func(outReq *http.Request) (*http.Response, error) {
|
|
||||||
_, _ = io.ReadAll(outReq.Body)
|
|
||||||
return nil, wantErr
|
|
||||||
}}
|
|
||||||
|
|
||||||
resp, captured, responsePreview, gotErr := roundTripCapturedRequest(req, transport, ch, "", false)
|
|
||||||
if !errors.Is(gotErr, wantErr) {
|
|
||||||
t.Fatalf("error = %v, want %v", gotErr, wantErr)
|
|
||||||
}
|
|
||||||
if resp != nil {
|
|
||||||
t.Fatalf("response = %#v, want nil", resp)
|
|
||||||
}
|
|
||||||
if responsePreview != nil {
|
|
||||||
t.Fatalf("responsePreview = %#v, want nil", responsePreview)
|
|
||||||
}
|
|
||||||
if got, want := string(captured.RequestData), `boom\nbody`; got != want {
|
|
||||||
t.Fatalf("captured.RequestData = %q, want %q", got, want)
|
|
||||||
}
|
|
||||||
|
|
||||||
events := drainEvents(t, ch, 1, time.Second)
|
|
||||||
if events[0].Type != model.EventTypeRequestStarted {
|
|
||||||
t.Fatalf("event type = %s, want %s", events[0].Type, model.EventTypeRequestStarted)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestRoundTripCapturedRequest_InterceptedTLSDefaults(t *testing.T) {
|
|
||||||
t.Parallel()
|
|
||||||
|
|
||||||
ch := make(chan model.Event, 8)
|
|
||||||
|
|
||||||
u, err := url.Parse("/secure?p=1")
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("url.Parse() error = %v", err)
|
|
||||||
}
|
|
||||||
req := &http.Request{
|
|
||||||
Method: http.MethodGet,
|
|
||||||
URL: u,
|
|
||||||
Header: http.Header{},
|
|
||||||
}
|
|
||||||
|
|
||||||
const defaultHost = "api.example.com:443"
|
|
||||||
transport := &mockTransport{fn: func(outReq *http.Request) (*http.Response, error) {
|
|
||||||
if got := outReq.URL.Scheme; got != "https" {
|
|
||||||
t.Fatalf("URL.Scheme = %q, want https", got)
|
|
||||||
}
|
|
||||||
if got := outReq.URL.Host; got != defaultHost {
|
|
||||||
t.Fatalf("URL.Host = %q, want %q", got, defaultHost)
|
|
||||||
}
|
|
||||||
if got := outReq.Host; got != defaultHost {
|
|
||||||
t.Fatalf("Host = %q, want %q", got, defaultHost)
|
|
||||||
}
|
|
||||||
|
|
||||||
return &http.Response{
|
|
||||||
StatusCode: http.StatusNoContent,
|
|
||||||
Header: http.Header{"Content-Type": {"text/plain"}},
|
|
||||||
Body: io.NopCloser(strings.NewReader("")),
|
|
||||||
}, nil
|
|
||||||
}}
|
|
||||||
|
|
||||||
_, _, _, gotErr := roundTripCapturedRequest(req, transport, ch, defaultHost, true)
|
|
||||||
if gotErr != nil {
|
|
||||||
t.Fatalf("roundTripCapturedRequest() error = %v", gotErr)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestNewConnectRequest(t *testing.T) {
|
|
||||||
t.Parallel()
|
|
||||||
|
|
||||||
now := time.Now().Add(-time.Second)
|
|
||||||
req := &http.Request{Method: http.MethodConnect, Host: "example.com:443"}
|
|
||||||
|
|
||||||
got := newConnectRequest(req, now)
|
|
||||||
|
|
||||||
if got.Method != http.MethodConnect {
|
|
||||||
t.Fatalf("Method = %q, want CONNECT", got.Method)
|
|
||||||
}
|
|
||||||
if got.Host != req.Host || got.URL != req.Host || got.RawURL != req.Host {
|
|
||||||
t.Fatalf("connect request host/url/raw mismatch: %#v", got)
|
|
||||||
}
|
|
||||||
if !got.Pending || got.Failed {
|
|
||||||
t.Fatalf("Pending/Failed = (%v,%v), want (true,false)", got.Pending, got.Failed)
|
|
||||||
}
|
|
||||||
if got.Status != -1 {
|
|
||||||
t.Fatalf("Status = %d, want -1", got.Status)
|
|
||||||
}
|
|
||||||
if got.StartTime != now {
|
|
||||||
t.Fatalf("StartTime = %v, want %v", got.StartTime, now)
|
|
||||||
}
|
|
||||||
if got.ID == uuid.Nil {
|
|
||||||
t.Fatal("ID must be non-zero UUID")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestStartFinishFailRequestEvents(t *testing.T) {
|
|
||||||
t.Parallel()
|
|
||||||
|
|
||||||
ch := make(chan model.Event, 4)
|
|
||||||
req := model.Request{
|
|
||||||
ID: uuid.New(),
|
|
||||||
Method: http.MethodGet,
|
|
||||||
RawURL: "http://example.com/a",
|
|
||||||
StartTime: time.Now().Add(-3 * time.Millisecond),
|
|
||||||
Pending: true,
|
|
||||||
}
|
|
||||||
|
|
||||||
startRequest(ch, req)
|
|
||||||
events := drainEvents(t, ch, 1, time.Second)
|
|
||||||
startEv := events[0]
|
|
||||||
if startEv.Type != model.EventTypeRequestStarted {
|
|
||||||
t.Fatalf("start event type = %s, want %s", startEv.Type, model.EventTypeRequestStarted)
|
|
||||||
}
|
|
||||||
if startEv.Request.Pending != true {
|
|
||||||
t.Fatalf("start request pending = %v, want true", startEv.Request.Pending)
|
|
||||||
}
|
|
||||||
|
|
||||||
finishRequest(ch, req, http.StatusOK)
|
|
||||||
events = drainEvents(t, ch, 1, time.Second)
|
|
||||||
finishEv := events[0]
|
|
||||||
if finishEv.Type != model.EventTypeRequestFinished {
|
|
||||||
t.Fatalf("finish event type = %s, want %s", finishEv.Type, model.EventTypeRequestFinished)
|
|
||||||
}
|
|
||||||
if finishEv.Request.Pending {
|
|
||||||
t.Fatal("finished request should not be pending")
|
|
||||||
}
|
|
||||||
if finishEv.Request.Failed {
|
|
||||||
t.Fatal("finished request should not be failed")
|
|
||||||
}
|
|
||||||
if finishEv.Request.Status != http.StatusOK {
|
|
||||||
t.Fatalf("finished status = %d, want %d", finishEv.Request.Status, http.StatusOK)
|
|
||||||
}
|
|
||||||
|
|
||||||
failRequest(ch, req, http.StatusBadGateway, "upstream error")
|
|
||||||
events = drainEvents(t, ch, 1, time.Second)
|
|
||||||
failEv := events[0]
|
|
||||||
if failEv.Type != model.EventTypeRequestFailed {
|
|
||||||
t.Fatalf("fail event type = %s, want %s", failEv.Type, model.EventTypeRequestFailed)
|
|
||||||
}
|
|
||||||
if failEv.Request.Pending {
|
|
||||||
t.Fatal("failed request should not be pending")
|
|
||||||
}
|
|
||||||
if !failEv.Request.Failed {
|
|
||||||
t.Fatal("failed request should be marked failed")
|
|
||||||
}
|
|
||||||
if failEv.Request.Status != http.StatusBadGateway {
|
|
||||||
t.Fatalf("failed status = %d, want %d", failEv.Request.Status, http.StatusBadGateway)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
@ -1,138 +0,0 @@
|
|||||||
package proxy
|
|
||||||
|
|
||||||
import (
|
|
||||||
"bufio"
|
|
||||||
"bytes"
|
|
||||||
"io"
|
|
||||||
"net"
|
|
||||||
"strings"
|
|
||||||
"testing"
|
|
||||||
"time"
|
|
||||||
)
|
|
||||||
|
|
||||||
type captureConn struct {
|
|
||||||
bytes.Buffer
|
|
||||||
}
|
|
||||||
|
|
||||||
func (c *captureConn) Read(_ []byte) (int, error) { return 0, io.EOF }
|
|
||||||
func (c *captureConn) Close() error { return nil }
|
|
||||||
func (c *captureConn) LocalAddr() net.Addr { return &net.TCPAddr{} }
|
|
||||||
func (c *captureConn) RemoteAddr() net.Addr { return &net.TCPAddr{} }
|
|
||||||
func (c *captureConn) SetDeadline(_ time.Time) error { return nil }
|
|
||||||
func (c *captureConn) SetReadDeadline(_ time.Time) error { return nil }
|
|
||||||
func (c *captureConn) SetWriteDeadline(_ time.Time) error { return nil }
|
|
||||||
|
|
||||||
type failWriteConn struct{}
|
|
||||||
|
|
||||||
func (failWriteConn) Read(_ []byte) (int, error) { return 0, io.EOF }
|
|
||||||
func (failWriteConn) Write(_ []byte) (int, error) { return 0, io.ErrClosedPipe }
|
|
||||||
func (failWriteConn) Close() error { return nil }
|
|
||||||
func (failWriteConn) LocalAddr() net.Addr { return &net.TCPAddr{} }
|
|
||||||
func (failWriteConn) RemoteAddr() net.Addr { return &net.TCPAddr{} }
|
|
||||||
func (failWriteConn) SetDeadline(_ time.Time) error { return nil }
|
|
||||||
func (failWriteConn) SetReadDeadline(_ time.Time) error { return nil }
|
|
||||||
func (failWriteConn) SetWriteDeadline(_ time.Time) error { return nil }
|
|
||||||
|
|
||||||
type trackingBody struct {
|
|
||||||
data *bytes.Reader
|
|
||||||
readN int
|
|
||||||
closed bool
|
|
||||||
}
|
|
||||||
|
|
||||||
func (b *trackingBody) Read(p []byte) (int, error) {
|
|
||||||
n, err := b.data.Read(p)
|
|
||||||
b.readN += n
|
|
||||||
return n, err
|
|
||||||
}
|
|
||||||
|
|
||||||
func (b *trackingBody) Close() error {
|
|
||||||
b.closed = true
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestWriteConnectEstablished(t *testing.T) {
|
|
||||||
t.Parallel()
|
|
||||||
|
|
||||||
t.Run("writes directly to raw conn", func(t *testing.T) {
|
|
||||||
t.Parallel()
|
|
||||||
|
|
||||||
conn := &captureConn{}
|
|
||||||
if err := writeConnectEstablished(conn, nil); err != nil {
|
|
||||||
t.Fatalf("writeConnectEstablished() error = %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
if got, want := conn.String(), "HTTP/1.1 200 Connection Established\r\n\r\n"; got != want {
|
|
||||||
t.Fatalf("raw write = %q, want %q", got, want)
|
|
||||||
}
|
|
||||||
})
|
|
||||||
|
|
||||||
t.Run("writes and flushes with buffered readWriter", func(t *testing.T) {
|
|
||||||
t.Parallel()
|
|
||||||
|
|
||||||
conn := &captureConn{}
|
|
||||||
rw := bufio.NewReadWriter(bufio.NewReader(strings.NewReader("")), bufio.NewWriter(conn))
|
|
||||||
if err := writeConnectEstablished(conn, rw); err != nil {
|
|
||||||
t.Fatalf("writeConnectEstablished() error = %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
if got, want := conn.String(), "HTTP/1.1 200 Connection Established\r\n\r\n"; got != want {
|
|
||||||
t.Fatalf("buffered write = %q, want %q", got, want)
|
|
||||||
}
|
|
||||||
})
|
|
||||||
|
|
||||||
t.Run("returns flush error", func(t *testing.T) {
|
|
||||||
t.Parallel()
|
|
||||||
|
|
||||||
rw := bufio.NewReadWriter(bufio.NewReader(strings.NewReader("")), bufio.NewWriter(errWriter{}))
|
|
||||||
err := writeConnectEstablished(&captureConn{}, rw)
|
|
||||||
if err == nil {
|
|
||||||
t.Fatal("writeConnectEstablished() error = nil, want non-nil")
|
|
||||||
}
|
|
||||||
})
|
|
||||||
|
|
||||||
t.Run("returns buffered write error when writer already failed", func(t *testing.T) {
|
|
||||||
t.Parallel()
|
|
||||||
|
|
||||||
bw := bufio.NewWriter(errWriter{})
|
|
||||||
_ = bw.Flush() // set sticky error to force WriteString error path
|
|
||||||
rw := bufio.NewReadWriter(bufio.NewReader(strings.NewReader("")), bw)
|
|
||||||
|
|
||||||
err := writeConnectEstablished(&captureConn{}, rw)
|
|
||||||
if err == nil {
|
|
||||||
t.Fatal("writeConnectEstablished() error = nil, want non-nil")
|
|
||||||
}
|
|
||||||
})
|
|
||||||
|
|
||||||
t.Run("returns raw conn write error", func(t *testing.T) {
|
|
||||||
t.Parallel()
|
|
||||||
err := writeConnectEstablished(failWriteConn{}, nil)
|
|
||||||
if err == nil {
|
|
||||||
t.Fatal("writeConnectEstablished() error = nil, want non-nil")
|
|
||||||
}
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestDiscardAndCloseBody(t *testing.T) {
|
|
||||||
t.Parallel()
|
|
||||||
|
|
||||||
t.Run("nil body is safe", func(t *testing.T) {
|
|
||||||
t.Parallel()
|
|
||||||
discardAndCloseBody(nil)
|
|
||||||
})
|
|
||||||
|
|
||||||
t.Run("closes body and discards at most limit", func(t *testing.T) {
|
|
||||||
t.Parallel()
|
|
||||||
|
|
||||||
payload := bytes.Repeat([]byte("x"), maxDiscardBodyBytes+128)
|
|
||||||
body := &trackingBody{data: bytes.NewReader(payload)}
|
|
||||||
|
|
||||||
discardAndCloseBody(body)
|
|
||||||
|
|
||||||
if !body.closed {
|
|
||||||
t.Fatal("body was not closed")
|
|
||||||
}
|
|
||||||
if body.readN != maxDiscardBodyBytes {
|
|
||||||
t.Fatalf("bytes read = %d, want %d", body.readN, maxDiscardBodyBytes)
|
|
||||||
}
|
|
||||||
})
|
|
||||||
}
|
|
||||||
@ -1,235 +0,0 @@
|
|||||||
package proxy
|
|
||||||
|
|
||||||
import (
|
|
||||||
"context"
|
|
||||||
"net"
|
|
||||||
"net/http"
|
|
||||||
"os"
|
|
||||||
"path/filepath"
|
|
||||||
"strings"
|
|
||||||
"sync"
|
|
||||||
"testing"
|
|
||||||
"time"
|
|
||||||
|
|
||||||
"termtap.dev/internal/model"
|
|
||||||
)
|
|
||||||
|
|
||||||
// NOTE: Run these tests with -race; they cover concurrent state.
|
|
||||||
|
|
||||||
type closeTrackingConn struct {
|
|
||||||
closed bool
|
|
||||||
mu sync.Mutex
|
|
||||||
}
|
|
||||||
|
|
||||||
func (c *closeTrackingConn) Read(_ []byte) (int, error) { return 0, net.ErrClosed }
|
|
||||||
func (c *closeTrackingConn) Write(b []byte) (int, error) { return len(b), nil }
|
|
||||||
func (c *closeTrackingConn) Close() error { c.mu.Lock(); c.closed = true; c.mu.Unlock(); return nil }
|
|
||||||
func (c *closeTrackingConn) LocalAddr() net.Addr { return &net.TCPAddr{} }
|
|
||||||
func (c *closeTrackingConn) RemoteAddr() net.Addr { return &net.TCPAddr{} }
|
|
||||||
func (c *closeTrackingConn) SetDeadline(_ time.Time) error { return nil }
|
|
||||||
func (c *closeTrackingConn) SetReadDeadline(_ time.Time) error { return nil }
|
|
||||||
func (c *closeTrackingConn) SetWriteDeadline(_ time.Time) error { return nil }
|
|
||||||
|
|
||||||
func (c *closeTrackingConn) isClosed() bool {
|
|
||||||
c.mu.Lock()
|
|
||||||
defer c.mu.Unlock()
|
|
||||||
return c.closed
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestNewProxyServer_Success(t *testing.T) {
|
|
||||||
configRoot := t.TempDir()
|
|
||||||
t.Setenv("XDG_CONFIG_HOME", configRoot)
|
|
||||||
|
|
||||||
ch := make(chan model.Event, 16)
|
|
||||||
ps, err := NewProxyServer("127.0.0.1:0", ch)
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("NewProxyServer() error = %v", err)
|
|
||||||
}
|
|
||||||
t.Cleanup(func() { Destroy(ps, ch) })
|
|
||||||
|
|
||||||
if ps.Listener == nil || *ps.Listener == nil {
|
|
||||||
t.Fatal("Listener is nil")
|
|
||||||
}
|
|
||||||
if ps.Server == nil {
|
|
||||||
t.Fatal("Server is nil")
|
|
||||||
}
|
|
||||||
if ps.Server.Handler == nil {
|
|
||||||
t.Fatal("Server.Handler is nil")
|
|
||||||
}
|
|
||||||
if !strings.HasPrefix(ps.Url, "http://") {
|
|
||||||
t.Fatalf("Url = %q, want http:// prefix", ps.Url)
|
|
||||||
}
|
|
||||||
if !ps.CAReady {
|
|
||||||
t.Fatal("CAReady = false, want true")
|
|
||||||
}
|
|
||||||
if ps.CACertPath == "" {
|
|
||||||
t.Fatal("CACertPath should not be empty")
|
|
||||||
}
|
|
||||||
if got, want := filepath.Dir(ps.CACertPath), filepath.Join(configRoot, caDirName); got != want {
|
|
||||||
t.Fatalf("CACertPath dir = %q, want %q", got, want)
|
|
||||||
}
|
|
||||||
if _, statErr := os.Stat(ps.CACertPath); statErr != nil {
|
|
||||||
t.Fatalf("expected CA cert on disk, stat error = %v", statErr)
|
|
||||||
}
|
|
||||||
if ps.Conns == nil {
|
|
||||||
t.Fatal("Conns map is nil")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestNewProxyServer_CACreatedFlagAcrossRuns(t *testing.T) {
|
|
||||||
configRoot := t.TempDir()
|
|
||||||
t.Setenv("XDG_CONFIG_HOME", configRoot)
|
|
||||||
|
|
||||||
ch := make(chan model.Event, 16)
|
|
||||||
ps1, err := NewProxyServer("127.0.0.1:0", ch)
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("first NewProxyServer() error = %v", err)
|
|
||||||
}
|
|
||||||
Destroy(ps1, ch)
|
|
||||||
|
|
||||||
ps2, err := NewProxyServer("127.0.0.1:0", ch)
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("second NewProxyServer() error = %v", err)
|
|
||||||
}
|
|
||||||
t.Cleanup(func() { Destroy(ps2, ch) })
|
|
||||||
|
|
||||||
if !ps1.CACreated {
|
|
||||||
t.Fatal("first NewProxyServer should report CACreated=true")
|
|
||||||
}
|
|
||||||
if ps2.CACreated {
|
|
||||||
t.Fatal("second NewProxyServer should report CACreated=false when CA already exists")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestNewProxyServer_ErrorWhenCASetupFails(t *testing.T) {
|
|
||||||
t.Setenv("XDG_CONFIG_HOME", "")
|
|
||||||
t.Setenv("HOME", "")
|
|
||||||
|
|
||||||
ch := make(chan model.Event, 8)
|
|
||||||
ps, err := NewProxyServer("127.0.0.1:0", ch)
|
|
||||||
if err == nil {
|
|
||||||
Destroy(ps, ch)
|
|
||||||
t.Fatal("NewProxyServer() error = nil, want non-nil")
|
|
||||||
}
|
|
||||||
if ps != nil {
|
|
||||||
t.Fatalf("proxy server = %#v, want nil", ps)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestNewProxyServer_ErrorWhenListenFails(t *testing.T) {
|
|
||||||
configRoot := t.TempDir()
|
|
||||||
t.Setenv("XDG_CONFIG_HOME", configRoot)
|
|
||||||
|
|
||||||
occupied, err := net.Listen("tcp", "127.0.0.1:0")
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("net.Listen() error = %v", err)
|
|
||||||
}
|
|
||||||
t.Cleanup(func() { _ = occupied.Close() })
|
|
||||||
|
|
||||||
addr := occupied.Addr().String()
|
|
||||||
ch := make(chan model.Event, 8)
|
|
||||||
ps, gotErr := NewProxyServer(addr, ch)
|
|
||||||
if gotErr == nil {
|
|
||||||
Destroy(ps, ch)
|
|
||||||
t.Fatal("NewProxyServer() error = nil, want non-nil")
|
|
||||||
}
|
|
||||||
if ps != nil {
|
|
||||||
t.Fatalf("proxy server = %#v, want nil", ps)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestDestroy(t *testing.T) {
|
|
||||||
t.Parallel()
|
|
||||||
|
|
||||||
t.Run("nil-safe", func(t *testing.T) {
|
|
||||||
t.Parallel()
|
|
||||||
Destroy(nil, make(chan model.Event, 1))
|
|
||||||
})
|
|
||||||
|
|
||||||
t.Run("emits ProxyStopped when server exists", func(t *testing.T) {
|
|
||||||
t.Parallel()
|
|
||||||
|
|
||||||
ch := make(chan model.Event, 2)
|
|
||||||
ps := &model.ProxyServer{Server: &http.Server{}, Conns: make(map[net.Conn]struct{})}
|
|
||||||
|
|
||||||
Destroy(ps, ch)
|
|
||||||
|
|
||||||
select {
|
|
||||||
case ev := <-ch:
|
|
||||||
if ev.Type != model.EventTypeProxyStopped {
|
|
||||||
t.Fatalf("event type = %s, want %s", ev.Type, model.EventTypeProxyStopped)
|
|
||||||
}
|
|
||||||
case <-time.After(time.Second):
|
|
||||||
t.Fatal("timeout waiting for ProxyStopped event")
|
|
||||||
}
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestConnectionTrackingHelpers(t *testing.T) {
|
|
||||||
t.Parallel()
|
|
||||||
|
|
||||||
t.Run("track and untrack are nil-safe", func(t *testing.T) {
|
|
||||||
t.Parallel()
|
|
||||||
trackConnection(nil, nil)
|
|
||||||
untrackConnection(nil, nil)
|
|
||||||
closeTrackedConnections(nil)
|
|
||||||
})
|
|
||||||
|
|
||||||
t.Run("track/untrack mutate connection set", func(t *testing.T) {
|
|
||||||
t.Parallel()
|
|
||||||
|
|
||||||
ps := &model.ProxyServer{Conns: make(map[net.Conn]struct{})}
|
|
||||||
c := &closeTrackingConn{}
|
|
||||||
|
|
||||||
trackConnection(ps, c)
|
|
||||||
if _, ok := ps.Conns[c]; !ok {
|
|
||||||
t.Fatal("connection not tracked")
|
|
||||||
}
|
|
||||||
|
|
||||||
untrackConnection(ps, c)
|
|
||||||
if _, ok := ps.Conns[c]; ok {
|
|
||||||
t.Fatal("connection still tracked after untrack")
|
|
||||||
}
|
|
||||||
})
|
|
||||||
|
|
||||||
t.Run("closeTrackedConnections closes all tracked conns", func(t *testing.T) {
|
|
||||||
t.Parallel()
|
|
||||||
|
|
||||||
ps := &model.ProxyServer{Conns: make(map[net.Conn]struct{})}
|
|
||||||
c1 := &closeTrackingConn{}
|
|
||||||
c2 := &closeTrackingConn{}
|
|
||||||
ps.Conns[c1] = struct{}{}
|
|
||||||
ps.Conns[c2] = struct{}{}
|
|
||||||
|
|
||||||
closeTrackedConnections(ps)
|
|
||||||
|
|
||||||
if !c1.isClosed() || !c2.isClosed() {
|
|
||||||
t.Fatalf("expected all tracked conns closed, got c1=%v c2=%v", c1.isClosed(), c2.isClosed())
|
|
||||||
}
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestDestroy_ShutdownsServerContext(t *testing.T) {
|
|
||||||
t.Parallel()
|
|
||||||
|
|
||||||
// Basic smoke test: ensure a real server can be shut down via Destroy.
|
|
||||||
ln, err := net.Listen("tcp", "127.0.0.1:0")
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("listen error = %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
srv := &http.Server{Handler: http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { _, _ = w.Write([]byte("ok")) })}
|
|
||||||
go func() {
|
|
||||||
_ = srv.Serve(ln)
|
|
||||||
}()
|
|
||||||
|
|
||||||
ch := make(chan model.Event, 2)
|
|
||||||
ps := &model.ProxyServer{Server: srv, Conns: make(map[net.Conn]struct{})}
|
|
||||||
Destroy(ps, ch)
|
|
||||||
|
|
||||||
ctx, cancel := context.WithTimeout(context.Background(), time.Second)
|
|
||||||
defer cancel()
|
|
||||||
if err := srv.Shutdown(ctx); err != nil && err != http.ErrServerClosed {
|
|
||||||
t.Fatalf("server should be closed after Destroy, got shutdown error %v", err)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
@ -1,77 +0,0 @@
|
|||||||
package proxy
|
|
||||||
|
|
||||||
import (
|
|
||||||
"crypto/ecdsa"
|
|
||||||
"crypto/elliptic"
|
|
||||||
"crypto/rand"
|
|
||||||
"crypto/tls"
|
|
||||||
"crypto/x509"
|
|
||||||
"crypto/x509/pkix"
|
|
||||||
"math/big"
|
|
||||||
"testing"
|
|
||||||
"time"
|
|
||||||
|
|
||||||
"termtap.dev/internal/model"
|
|
||||||
)
|
|
||||||
|
|
||||||
func drainEvents(t *testing.T, ch <-chan model.Event, n int, timeout time.Duration) []model.Event {
|
|
||||||
t.Helper()
|
|
||||||
|
|
||||||
events := make([]model.Event, 0, n)
|
|
||||||
deadline := time.After(timeout)
|
|
||||||
for len(events) < n {
|
|
||||||
select {
|
|
||||||
case ev := <-ch:
|
|
||||||
events = append(events, ev)
|
|
||||||
case <-deadline:
|
|
||||||
t.Fatalf("timeout waiting for %d events, got %d", n, len(events))
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return events
|
|
||||||
}
|
|
||||||
|
|
||||||
func hasEventType(events []model.Event, typ model.EventType) bool {
|
|
||||||
for _, ev := range events {
|
|
||||||
if ev.Type == typ {
|
|
||||||
return true
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
|
|
||||||
func newTestCA(t *testing.T) *CertificateAuthority {
|
|
||||||
t.Helper()
|
|
||||||
|
|
||||||
key, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("GenerateKey() error = %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
now := time.Now()
|
|
||||||
tmpl := &x509.Certificate{
|
|
||||||
SerialNumber: big.NewInt(1),
|
|
||||||
Subject: pkix.Name{CommonName: "test-ca"},
|
|
||||||
NotBefore: now.Add(-time.Minute),
|
|
||||||
NotAfter: now.Add(time.Hour),
|
|
||||||
KeyUsage: x509.KeyUsageCertSign | x509.KeyUsageDigitalSignature,
|
|
||||||
IsCA: true,
|
|
||||||
BasicConstraintsValid: true,
|
|
||||||
}
|
|
||||||
|
|
||||||
der, err := x509.CreateCertificate(rand.Reader, tmpl, tmpl, &key.PublicKey, key)
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("CreateCertificate() error = %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
cert, err := x509.ParseCertificate(der)
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("ParseCertificate() error = %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
return &CertificateAuthority{
|
|
||||||
cert: cert,
|
|
||||||
key: key,
|
|
||||||
leafCert: make(map[string]*tls.Certificate),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
@ -1,256 +0,0 @@
|
|||||||
package proxy
|
|
||||||
|
|
||||||
import (
|
|
||||||
"bufio"
|
|
||||||
"context"
|
|
||||||
"errors"
|
|
||||||
"io"
|
|
||||||
"net/http"
|
|
||||||
"net/url"
|
|
||||||
"strconv"
|
|
||||||
"strings"
|
|
||||||
"testing"
|
|
||||||
|
|
||||||
"github.com/google/uuid"
|
|
||||||
)
|
|
||||||
|
|
||||||
type timeoutErr struct{}
|
|
||||||
|
|
||||||
func (timeoutErr) Error() string { return "timeout" }
|
|
||||||
func (timeoutErr) Timeout() bool { return true }
|
|
||||||
func (timeoutErr) Temporary() bool { return true }
|
|
||||||
|
|
||||||
type errWriter struct{}
|
|
||||||
|
|
||||||
func (errWriter) Write(_ []byte) (int, error) {
|
|
||||||
return 0, errors.New("write failed")
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestCanDisplayContent(t *testing.T) {
|
|
||||||
t.Parallel()
|
|
||||||
|
|
||||||
tests := []struct {
|
|
||||||
name string
|
|
||||||
contentType string
|
|
||||||
want bool
|
|
||||||
}{
|
|
||||||
{name: "empty content type", contentType: "", want: false},
|
|
||||||
{name: "text type", contentType: "text/plain", want: true},
|
|
||||||
{name: "json type", contentType: "application/json", want: true},
|
|
||||||
{name: "xml suffix", contentType: "application/problem+xml", want: true},
|
|
||||||
{name: "unknown binary", contentType: "application/octet-stream", want: false},
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, tt := range tests {
|
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
|
||||||
t.Parallel()
|
|
||||||
if got := canDisplayContent(tt.contentType); got != tt.want {
|
|
||||||
t.Fatalf("canDisplayContent(%q) = %v, want %v", tt.contentType, got, tt.want)
|
|
||||||
}
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestFormatHeaders(t *testing.T) {
|
|
||||||
t.Parallel()
|
|
||||||
|
|
||||||
tests := []struct {
|
|
||||||
name string
|
|
||||||
headers http.Header
|
|
||||||
want string
|
|
||||||
}{
|
|
||||||
{
|
|
||||||
name: "empty returns none",
|
|
||||||
headers: http.Header{},
|
|
||||||
want: "<none>",
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "sorts keys stably",
|
|
||||||
headers: http.Header{
|
|
||||||
"B-Key": {"b1"},
|
|
||||||
"A-Key": {"a1", "a2"},
|
|
||||||
},
|
|
||||||
want: `A-Key="a1,a2", B-Key="b1"`,
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, tt := range tests {
|
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
|
||||||
t.Parallel()
|
|
||||||
if got := formatHeaders(tt.headers); got != tt.want {
|
|
||||||
t.Fatalf("formatHeaders() = %q, want %q", got, tt.want)
|
|
||||||
}
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestGetEndOfUUID(t *testing.T) {
|
|
||||||
t.Parallel()
|
|
||||||
|
|
||||||
id := uuid.MustParse("123e4567-e89b-12d3-a456-426614174000")
|
|
||||||
if got, want := getEndOfUUID(id), "426614174000"; got != want {
|
|
||||||
t.Fatalf("getEndOfUUID() = %q, want %q", got, want)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestStatusFromUpstreamError(t *testing.T) {
|
|
||||||
t.Parallel()
|
|
||||||
|
|
||||||
newReq := func(ctx context.Context) *http.Request {
|
|
||||||
reqURL, err := url.Parse("http://example.com")
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("url parse failed: %v", err)
|
|
||||||
}
|
|
||||||
return (&http.Request{Method: http.MethodGet, URL: reqURL}).WithContext(ctx)
|
|
||||||
}
|
|
||||||
|
|
||||||
tests := []struct {
|
|
||||||
name string
|
|
||||||
req *http.Request
|
|
||||||
resp *http.Response
|
|
||||||
err error
|
|
||||||
want int
|
|
||||||
}{
|
|
||||||
{
|
|
||||||
name: "prefers upstream response status",
|
|
||||||
req: newReq(context.Background()),
|
|
||||||
resp: &http.Response{StatusCode: http.StatusTeapot},
|
|
||||||
err: errors.New("ignored"),
|
|
||||||
want: http.StatusTeapot,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "context canceled maps to bad gateway",
|
|
||||||
req: func() *http.Request {
|
|
||||||
ctx, cancel := context.WithCancel(context.Background())
|
|
||||||
cancel()
|
|
||||||
return newReq(ctx)
|
|
||||||
}(),
|
|
||||||
err: context.Canceled,
|
|
||||||
want: http.StatusBadGateway,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "deadline exceeded maps to gateway timeout",
|
|
||||||
req: newReq(context.Background()),
|
|
||||||
err: context.DeadlineExceeded,
|
|
||||||
want: http.StatusGatewayTimeout,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "net timeout maps to gateway timeout",
|
|
||||||
req: newReq(context.Background()),
|
|
||||||
err: timeoutErr{},
|
|
||||||
want: http.StatusGatewayTimeout,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "default maps to bad gateway",
|
|
||||||
req: newReq(context.Background()),
|
|
||||||
err: errors.New("dial failed"),
|
|
||||||
want: http.StatusBadGateway,
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, tt := range tests {
|
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
|
||||||
t.Parallel()
|
|
||||||
if got := statusFromUpstreamError(tt.req, tt.resp, tt.err); got != tt.want {
|
|
||||||
t.Fatalf("statusFromUpstreamError() = %d, want %d", got, tt.want)
|
|
||||||
}
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestNewUpstreamTransport(t *testing.T) {
|
|
||||||
t.Parallel()
|
|
||||||
|
|
||||||
got := newUpstreamTransport()
|
|
||||||
transport, ok := got.(*http.Transport)
|
|
||||||
if !ok {
|
|
||||||
t.Fatalf("newUpstreamTransport() type = %T, want *http.Transport", got)
|
|
||||||
}
|
|
||||||
|
|
||||||
if transport.Proxy != nil {
|
|
||||||
t.Fatal("newUpstreamTransport() Proxy must be nil")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestWritePlainHTTPError(t *testing.T) {
|
|
||||||
t.Parallel()
|
|
||||||
|
|
||||||
tests := []struct {
|
|
||||||
name string
|
|
||||||
status int
|
|
||||||
writer *bufio.Writer
|
|
||||||
wantErr bool
|
|
||||||
wantStatus int
|
|
||||||
}{
|
|
||||||
{
|
|
||||||
name: "writes valid response and flushes",
|
|
||||||
status: http.StatusBadGateway,
|
|
||||||
writer: bufio.NewWriter(&strings.Builder{}),
|
|
||||||
wantErr: false,
|
|
||||||
wantStatus: http.StatusBadGateway,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "returns write error",
|
|
||||||
status: http.StatusBadGateway,
|
|
||||||
writer: bufio.NewWriter(errWriter{}),
|
|
||||||
wantErr: true,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "returns response write error when writer already failed",
|
|
||||||
status: http.StatusBadGateway,
|
|
||||||
writer: bufio.NewWriter(errWriter{}),
|
|
||||||
wantErr: true,
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, tt := range tests {
|
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
|
||||||
t.Parallel()
|
|
||||||
|
|
||||||
var sb strings.Builder
|
|
||||||
w := tt.writer
|
|
||||||
if !tt.wantErr {
|
|
||||||
w = bufio.NewWriter(&sb)
|
|
||||||
}
|
|
||||||
|
|
||||||
err := writePlainHTTPError(w, tt.status)
|
|
||||||
if tt.name == "returns response write error when writer already failed" {
|
|
||||||
_ = w.Flush() // set sticky writer error so resp.Write fails immediately
|
|
||||||
err = writePlainHTTPError(w, tt.status)
|
|
||||||
}
|
|
||||||
if (err != nil) != tt.wantErr {
|
|
||||||
t.Fatalf("writePlainHTTPError() error = %v, wantErr %v", err, tt.wantErr)
|
|
||||||
}
|
|
||||||
if tt.wantErr {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
resp, readErr := http.ReadResponse(bufio.NewReader(strings.NewReader(sb.String())), &http.Request{Method: http.MethodGet})
|
|
||||||
if readErr != nil {
|
|
||||||
t.Fatalf("ReadResponse() error = %v", readErr)
|
|
||||||
}
|
|
||||||
defer resp.Body.Close()
|
|
||||||
|
|
||||||
if resp.StatusCode != tt.wantStatus {
|
|
||||||
t.Fatalf("status = %d, want %d", resp.StatusCode, tt.wantStatus)
|
|
||||||
}
|
|
||||||
|
|
||||||
if gotCT := resp.Header.Get("Content-Type"); gotCT != "text/plain; charset=utf-8" {
|
|
||||||
t.Fatalf("Content-Type = %q, want %q", gotCT, "text/plain; charset=utf-8")
|
|
||||||
}
|
|
||||||
|
|
||||||
wantBody := http.StatusText(tt.status)
|
|
||||||
if gotCL := resp.Header.Get("Content-Length"); gotCL != strconv.Itoa(len(wantBody)) {
|
|
||||||
t.Fatalf("Content-Length = %q, want %q", gotCL, strconv.Itoa(len(wantBody)))
|
|
||||||
}
|
|
||||||
|
|
||||||
body, bodyErr := io.ReadAll(resp.Body)
|
|
||||||
if bodyErr != nil {
|
|
||||||
t.Fatalf("ReadAll(body) error = %v", bodyErr)
|
|
||||||
}
|
|
||||||
if string(body) != wantBody {
|
|
||||||
t.Fatalf("body = %q, want %q", string(body), wantBody)
|
|
||||||
}
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
@ -14,26 +14,12 @@ const (
|
|||||||
maxRequests = 256
|
maxRequests = 256
|
||||||
)
|
)
|
||||||
|
|
||||||
const (
|
|
||||||
focusPaneRequests = iota
|
|
||||||
focusPaneDetails
|
|
||||||
focusPaneEvents
|
|
||||||
focusPaneStd
|
|
||||||
)
|
|
||||||
|
|
||||||
type Model struct {
|
type Model struct {
|
||||||
channel <-chan model.Event
|
channel <-chan model.Event
|
||||||
controls Controls
|
controls Controls
|
||||||
|
|
||||||
events []model.Event
|
events []model.Event
|
||||||
requests []model.Request
|
requests []model.Request
|
||||||
requestCursor int
|
|
||||||
requestScroll int
|
|
||||||
detailsTab int
|
|
||||||
detailsScroll int
|
|
||||||
eventsScroll int
|
|
||||||
stdScroll int
|
|
||||||
focusedPane int
|
|
||||||
|
|
||||||
width int
|
width int
|
||||||
height int
|
height int
|
||||||
@ -57,13 +43,6 @@ func NewModel(ch <-chan model.Event, controls Controls) Model {
|
|||||||
controls: controls,
|
controls: controls,
|
||||||
events: make([]model.Event, 0, maxEvents),
|
events: make([]model.Event, 0, maxEvents),
|
||||||
requests: make([]model.Request, 0, maxRequests),
|
requests: make([]model.Request, 0, maxRequests),
|
||||||
requestCursor: 0,
|
|
||||||
requestScroll: 0,
|
|
||||||
detailsTab: detailsTabOverview,
|
|
||||||
detailsScroll: 0,
|
|
||||||
eventsScroll: 0,
|
|
||||||
stdScroll: 0,
|
|
||||||
focusedPane: focusPaneRequests,
|
|
||||||
width: 0,
|
width: 0,
|
||||||
height: 0,
|
height: 0,
|
||||||
showEvents: false,
|
showEvents: false,
|
||||||
|
|||||||
@ -1,361 +0,0 @@
|
|||||||
package tui
|
|
||||||
|
|
||||||
import (
|
|
||||||
"errors"
|
|
||||||
"net/http"
|
|
||||||
"testing"
|
|
||||||
"time"
|
|
||||||
|
|
||||||
tea "github.com/charmbracelet/bubbletea"
|
|
||||||
"github.com/google/uuid"
|
|
||||||
"termtap.dev/internal/model"
|
|
||||||
)
|
|
||||||
|
|
||||||
func TestNewModelDefaults(t *testing.T) {
|
|
||||||
t.Parallel()
|
|
||||||
|
|
||||||
ch := make(chan model.Event)
|
|
||||||
m := NewModel(ch, Controls{})
|
|
||||||
|
|
||||||
if m.channel != ch {
|
|
||||||
t.Fatal("channel not set")
|
|
||||||
}
|
|
||||||
if len(m.events) != 0 || len(m.requests) != 0 {
|
|
||||||
t.Fatal("events/requests should initialize empty")
|
|
||||||
}
|
|
||||||
if m.width != 0 || m.height != 0 {
|
|
||||||
t.Fatal("width/height should initialize zero")
|
|
||||||
}
|
|
||||||
if m.showEvents || m.showStd || m.showSearch || m.restarting {
|
|
||||||
t.Fatal("toggle flags should initialize false")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestInitBatchesEventAndTick(t *testing.T) {
|
|
||||||
t.Parallel()
|
|
||||||
|
|
||||||
ch := make(chan model.Event)
|
|
||||||
m := NewModel(ch, Controls{})
|
|
||||||
|
|
||||||
cmd := m.Init()
|
|
||||||
if cmd == nil {
|
|
||||||
t.Fatal("Init() returned nil cmd")
|
|
||||||
}
|
|
||||||
if _, ok := cmd().(tea.BatchMsg); !ok {
|
|
||||||
t.Fatalf("Init cmd message type = %T, want tea.BatchMsg", cmd())
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestWaitForEvent(t *testing.T) {
|
|
||||||
t.Parallel()
|
|
||||||
|
|
||||||
t.Run("returns EventMsg when channel has value", func(t *testing.T) {
|
|
||||||
t.Parallel()
|
|
||||||
|
|
||||||
ch := make(chan model.Event, 1)
|
|
||||||
ch <- model.Event{Type: model.EventTypeWarn, Body: "hello"}
|
|
||||||
|
|
||||||
msg := waitForEvent(ch)()
|
|
||||||
ev, ok := msg.(EventMsg)
|
|
||||||
if !ok {
|
|
||||||
t.Fatalf("msg type = %T, want EventMsg", msg)
|
|
||||||
}
|
|
||||||
if ev.value.Body != "hello" {
|
|
||||||
t.Fatalf("event body = %q, want %q", ev.value.Body, "hello")
|
|
||||||
}
|
|
||||||
})
|
|
||||||
|
|
||||||
t.Run("returns ErrMsg when channel closed", func(t *testing.T) {
|
|
||||||
t.Parallel()
|
|
||||||
|
|
||||||
ch := make(chan model.Event)
|
|
||||||
close(ch)
|
|
||||||
|
|
||||||
msg := waitForEvent(ch)()
|
|
||||||
if _, ok := msg.(ErrMsg); !ok {
|
|
||||||
t.Fatalf("msg type = %T, want ErrMsg", msg)
|
|
||||||
}
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestMessagesCommands(t *testing.T) {
|
|
||||||
t.Parallel()
|
|
||||||
|
|
||||||
t.Run("restartCmd nil restart returns nil", func(t *testing.T) {
|
|
||||||
t.Parallel()
|
|
||||||
if cmd := restartCmd(nil); cmd != nil {
|
|
||||||
t.Fatal("restartCmd(nil) should return nil")
|
|
||||||
}
|
|
||||||
})
|
|
||||||
|
|
||||||
t.Run("restartCmd wraps restart result", func(t *testing.T) {
|
|
||||||
t.Parallel()
|
|
||||||
wantErr := errors.New("boom")
|
|
||||||
msg := restartCmd(func() error { return wantErr })()
|
|
||||||
rm, ok := msg.(RestartResultMsg)
|
|
||||||
if !ok {
|
|
||||||
t.Fatalf("msg type = %T, want RestartResultMsg", msg)
|
|
||||||
}
|
|
||||||
if !errors.Is(rm.err, wantErr) {
|
|
||||||
t.Fatalf("restart result error = %v, want %v", rm.err, wantErr)
|
|
||||||
}
|
|
||||||
})
|
|
||||||
|
|
||||||
t.Run("tickCmd emits TickMsg", func(t *testing.T) {
|
|
||||||
t.Parallel()
|
|
||||||
cmd := tickCmd()
|
|
||||||
if cmd == nil {
|
|
||||||
t.Fatal("tickCmd returned nil")
|
|
||||||
}
|
|
||||||
|
|
||||||
msgCh := make(chan tea.Msg, 1)
|
|
||||||
go func() { msgCh <- cmd() }()
|
|
||||||
|
|
||||||
select {
|
|
||||||
case msg := <-msgCh:
|
|
||||||
if _, ok := msg.(TickMsg); !ok {
|
|
||||||
t.Fatalf("msg type = %T, want TickMsg", msg)
|
|
||||||
}
|
|
||||||
case <-time.After(200 * time.Millisecond):
|
|
||||||
t.Fatal("timeout waiting for tick message")
|
|
||||||
}
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestUpdate(t *testing.T) {
|
|
||||||
t.Parallel()
|
|
||||||
|
|
||||||
t.Run("window size updates dimensions", func(t *testing.T) {
|
|
||||||
t.Parallel()
|
|
||||||
m := NewModel(make(chan model.Event), Controls{})
|
|
||||||
next, _ := m.Update(tea.WindowSizeMsg{Width: 120, Height: 40})
|
|
||||||
got := next.(Model)
|
|
||||||
if got.width != 120 || got.height != 40 {
|
|
||||||
t.Fatalf("dimensions = (%d,%d), want (120,40)", got.width, got.height)
|
|
||||||
}
|
|
||||||
})
|
|
||||||
|
|
||||||
t.Run("tick updates now and reschedules only with pending requests", func(t *testing.T) {
|
|
||||||
t.Parallel()
|
|
||||||
now := time.Now()
|
|
||||||
|
|
||||||
m1 := NewModel(make(chan model.Event), Controls{})
|
|
||||||
next1, cmd1 := m1.Update(TickMsg{Now: now})
|
|
||||||
got1 := next1.(Model)
|
|
||||||
if !got1.now.Equal(now) {
|
|
||||||
t.Fatal("tick should update now")
|
|
||||||
}
|
|
||||||
if cmd1 != nil {
|
|
||||||
t.Fatal("tick without pending requests should not reschedule")
|
|
||||||
}
|
|
||||||
|
|
||||||
m2 := NewModel(make(chan model.Event), Controls{})
|
|
||||||
m2.requests = append(m2.requests, model.Request{ID: uuid.New(), Pending: true})
|
|
||||||
_, cmd2 := m2.Update(TickMsg{Now: now})
|
|
||||||
if cmd2 == nil {
|
|
||||||
t.Fatal("tick with pending requests should reschedule")
|
|
||||||
}
|
|
||||||
})
|
|
||||||
|
|
||||||
t.Run("key handling toggles and quit", func(t *testing.T) {
|
|
||||||
t.Parallel()
|
|
||||||
|
|
||||||
m := NewModel(make(chan model.Event), Controls{})
|
|
||||||
next, quitCmd := m.Update(tea.KeyMsg{Type: tea.KeyRunes, Runes: []rune("q")})
|
|
||||||
_ = next
|
|
||||||
if quitCmd == nil {
|
|
||||||
t.Fatal("q should return quit cmd")
|
|
||||||
}
|
|
||||||
|
|
||||||
nextCtrlC, quitCtrlC := m.Update(tea.KeyMsg{Type: tea.KeyCtrlC})
|
|
||||||
_ = nextCtrlC
|
|
||||||
if quitCtrlC == nil {
|
|
||||||
t.Fatal("ctrl+c should return quit cmd")
|
|
||||||
}
|
|
||||||
|
|
||||||
next2, _ := m.Update(tea.KeyMsg{Type: tea.KeyRunes, Runes: []rune("e")})
|
|
||||||
if !next2.(Model).showEvents {
|
|
||||||
t.Fatal("e should toggle showEvents")
|
|
||||||
}
|
|
||||||
|
|
||||||
next3, _ := next2.(Model).Update(tea.KeyMsg{Type: tea.KeyRunes, Runes: []rune("o")})
|
|
||||||
if !next3.(Model).showStd {
|
|
||||||
t.Fatal("o should toggle showStd")
|
|
||||||
}
|
|
||||||
|
|
||||||
next4, _ := next3.(Model).Update(tea.KeyMsg{Type: tea.KeyRunes, Runes: []rune("/")})
|
|
||||||
if !next4.(Model).showSearch {
|
|
||||||
t.Fatal("/ should enable search")
|
|
||||||
}
|
|
||||||
|
|
||||||
next5, _ := next4.(Model).Update(tea.KeyMsg{Type: tea.KeyEsc})
|
|
||||||
if next5.(Model).showSearch {
|
|
||||||
t.Fatal("esc should disable search")
|
|
||||||
}
|
|
||||||
|
|
||||||
next6, cmd6 := next5.(Model).Update(tea.KeyMsg{Type: tea.KeyRunes, Runes: []rune("x")})
|
|
||||||
if cmd6 != nil {
|
|
||||||
t.Fatal("unknown key should not return command")
|
|
||||||
}
|
|
||||||
if next6.(Model).showEvents != next5.(Model).showEvents ||
|
|
||||||
next6.(Model).showStd != next5.(Model).showStd ||
|
|
||||||
next6.(Model).showSearch != next5.(Model).showSearch {
|
|
||||||
t.Fatal("unknown key should not alter toggle state")
|
|
||||||
}
|
|
||||||
})
|
|
||||||
|
|
||||||
t.Run("restart key guarded by state and control fn", func(t *testing.T) {
|
|
||||||
t.Parallel()
|
|
||||||
|
|
||||||
m := NewModel(make(chan model.Event), Controls{})
|
|
||||||
next, cmd := m.Update(tea.KeyMsg{Type: tea.KeyCtrlR})
|
|
||||||
if cmd != nil {
|
|
||||||
t.Fatal("ctrl+r with nil restart control should return nil cmd")
|
|
||||||
}
|
|
||||||
if next.(Model).restarting {
|
|
||||||
t.Fatal("restarting should remain false when restart control missing")
|
|
||||||
}
|
|
||||||
|
|
||||||
m2 := NewModel(make(chan model.Event), Controls{Restart: func() error { return nil }})
|
|
||||||
next2, cmd2 := m2.Update(tea.KeyMsg{Type: tea.KeyCtrlR})
|
|
||||||
if cmd2 == nil {
|
|
||||||
t.Fatal("ctrl+r with restart control should return cmd")
|
|
||||||
}
|
|
||||||
if !next2.(Model).restarting {
|
|
||||||
t.Fatal("restarting should be true after ctrl+r")
|
|
||||||
}
|
|
||||||
|
|
||||||
next3, cmd3 := next2.(Model).Update(tea.KeyMsg{Type: tea.KeyCtrlR})
|
|
||||||
if cmd3 != nil {
|
|
||||||
t.Fatal("ctrl+r while restarting should return nil cmd")
|
|
||||||
}
|
|
||||||
if !next3.(Model).restarting {
|
|
||||||
t.Fatal("restarting should stay true while guarded")
|
|
||||||
}
|
|
||||||
})
|
|
||||||
|
|
||||||
t.Run("ErrMsg pushes warning event", func(t *testing.T) {
|
|
||||||
t.Parallel()
|
|
||||||
m := NewModel(make(chan model.Event), Controls{})
|
|
||||||
next, _ := m.Update(ErrMsg{err: errors.New("closed")})
|
|
||||||
got := next.(Model)
|
|
||||||
if len(got.events) != 1 {
|
|
||||||
t.Fatalf("event len = %d, want 1", len(got.events))
|
|
||||||
}
|
|
||||||
if got.events[0].Type != model.EventTypeWarn {
|
|
||||||
t.Fatalf("event type = %s, want %s", got.events[0].Type, model.EventTypeWarn)
|
|
||||||
}
|
|
||||||
})
|
|
||||||
|
|
||||||
t.Run("RestartResultMsg clears restarting and warns on error", func(t *testing.T) {
|
|
||||||
t.Parallel()
|
|
||||||
m := NewModel(make(chan model.Event), Controls{})
|
|
||||||
m.restarting = true
|
|
||||||
|
|
||||||
next1, _ := m.Update(RestartResultMsg{err: nil})
|
|
||||||
if next1.(Model).restarting {
|
|
||||||
t.Fatal("restarting should clear on restart result")
|
|
||||||
}
|
|
||||||
|
|
||||||
next2, _ := m.Update(RestartResultMsg{err: errors.New("fail")})
|
|
||||||
got2 := next2.(Model)
|
|
||||||
if len(got2.events) != 1 || got2.events[0].Type != model.EventTypeWarn {
|
|
||||||
t.Fatalf("expected warn event on restart error, got %#v", got2.events)
|
|
||||||
}
|
|
||||||
})
|
|
||||||
|
|
||||||
t.Run("EventMsg updates events and requests", func(t *testing.T) {
|
|
||||||
t.Parallel()
|
|
||||||
ch := make(chan model.Event, 2)
|
|
||||||
m := NewModel(ch, Controls{})
|
|
||||||
|
|
||||||
reqID := uuid.New()
|
|
||||||
startEv := EventMsg{value: model.Event{Type: model.EventTypeRequestStarted, Request: model.Request{ID: reqID, Method: http.MethodGet, Pending: true}}}
|
|
||||||
next1, cmd1 := m.Update(startEv)
|
|
||||||
if cmd1 == nil {
|
|
||||||
t.Fatal("EventMsg should return wait/tick cmd")
|
|
||||||
}
|
|
||||||
got1 := next1.(Model)
|
|
||||||
if len(got1.events) != 1 || len(got1.requests) != 1 {
|
|
||||||
t.Fatalf("expected one event and one request, got events=%d requests=%d", len(got1.events), len(got1.requests))
|
|
||||||
}
|
|
||||||
|
|
||||||
finishReq := got1.requests[0]
|
|
||||||
finishReq.Pending = false
|
|
||||||
finishReq.Status = 200
|
|
||||||
finishEv := EventMsg{value: model.Event{Type: model.EventTypeRequestFinished, Request: finishReq}}
|
|
||||||
next2, cmd2 := got1.Update(finishEv)
|
|
||||||
got2 := next2.(Model)
|
|
||||||
if got2.requests[0].Pending {
|
|
||||||
t.Fatal("request should be updated to non-pending")
|
|
||||||
}
|
|
||||||
if got2.requests[0].Status != 200 {
|
|
||||||
t.Fatalf("request status = %d, want 200", got2.requests[0].Status)
|
|
||||||
}
|
|
||||||
|
|
||||||
if cmd2 == nil {
|
|
||||||
t.Fatal("expected waitForEvent cmd after finished request")
|
|
||||||
}
|
|
||||||
ch <- model.Event{Type: model.EventTypeWarn, Body: "next"}
|
|
||||||
msg := cmd2()
|
|
||||||
if _, ok := msg.(EventMsg); !ok {
|
|
||||||
t.Fatalf("cmd2 message type = %T, want EventMsg", msg)
|
|
||||||
}
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestModelHelpers(t *testing.T) {
|
|
||||||
t.Parallel()
|
|
||||||
|
|
||||||
t.Run("pushEvent trims to maxEvents", func(t *testing.T) {
|
|
||||||
t.Parallel()
|
|
||||||
m := NewModel(make(chan model.Event), Controls{})
|
|
||||||
for i := 0; i < maxEvents+5; i++ {
|
|
||||||
m.pushEvent(model.Event{Body: "x"})
|
|
||||||
}
|
|
||||||
if len(m.events) != maxEvents {
|
|
||||||
t.Fatalf("events len = %d, want %d", len(m.events), maxEvents)
|
|
||||||
}
|
|
||||||
})
|
|
||||||
|
|
||||||
t.Run("createRequest ignores CONNECT and trims", func(t *testing.T) {
|
|
||||||
t.Parallel()
|
|
||||||
m := NewModel(make(chan model.Event), Controls{})
|
|
||||||
m.createRequest(model.Request{Method: http.MethodConnect})
|
|
||||||
if len(m.requests) != 0 {
|
|
||||||
t.Fatal("CONNECT request should be ignored")
|
|
||||||
}
|
|
||||||
|
|
||||||
for i := 0; i < maxRequests+3; i++ {
|
|
||||||
m.createRequest(model.Request{ID: uuid.New(), Method: http.MethodGet})
|
|
||||||
}
|
|
||||||
if len(m.requests) != maxRequests {
|
|
||||||
t.Fatalf("requests len = %d, want %d", len(m.requests), maxRequests)
|
|
||||||
}
|
|
||||||
})
|
|
||||||
|
|
||||||
t.Run("updateRequest updates only matching request", func(t *testing.T) {
|
|
||||||
t.Parallel()
|
|
||||||
id1 := uuid.New()
|
|
||||||
id2 := uuid.New()
|
|
||||||
m := NewModel(make(chan model.Event), Controls{})
|
|
||||||
m.requests = []model.Request{{ID: id1, Status: 100}, {ID: id2, Status: 101}}
|
|
||||||
|
|
||||||
m.updateRequest(model.Request{ID: id2, Status: 202})
|
|
||||||
if m.requests[0].Status != 100 || m.requests[1].Status != 202 {
|
|
||||||
t.Fatalf("unexpected statuses after update: %#v", m.requests)
|
|
||||||
}
|
|
||||||
})
|
|
||||||
|
|
||||||
t.Run("hasPendingRequests true and false", func(t *testing.T) {
|
|
||||||
t.Parallel()
|
|
||||||
m := NewModel(make(chan model.Event), Controls{})
|
|
||||||
if m.hasPendingRequests() {
|
|
||||||
t.Fatal("empty model should not have pending requests")
|
|
||||||
}
|
|
||||||
m.requests = []model.Request{{ID: uuid.New(), Pending: false}, {ID: uuid.New(), Pending: true}}
|
|
||||||
if !m.hasPendingRequests() {
|
|
||||||
t.Fatal("expected pending requests to be true")
|
|
||||||
}
|
|
||||||
})
|
|
||||||
}
|
|
||||||
@ -1,31 +1,16 @@
|
|||||||
package tui
|
package tui
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"bytes"
|
|
||||||
"encoding/hex"
|
|
||||||
"encoding/json"
|
|
||||||
"fmt"
|
"fmt"
|
||||||
"sort"
|
|
||||||
"strings"
|
"strings"
|
||||||
"time"
|
"time"
|
||||||
"unicode/utf8"
|
|
||||||
|
|
||||||
"github.com/charmbracelet/lipgloss"
|
"github.com/charmbracelet/lipgloss"
|
||||||
"github.com/charmbracelet/x/ansi"
|
|
||||||
"termtap.dev/internal/model"
|
"termtap.dev/internal/model"
|
||||||
)
|
)
|
||||||
|
|
||||||
// TODO: LOTS OF THIS SUCKS BUT IT WORKS
|
// TODO: LOTS OF THIS SUCKS BUT IT WORKS
|
||||||
|
|
||||||
const (
|
|
||||||
detailsTabOverview = iota
|
|
||||||
detailsTabRequest
|
|
||||||
detailsTabResponse
|
|
||||||
detailsTabHeaders
|
|
||||||
)
|
|
||||||
|
|
||||||
var detailsTabNames = []string{"Overview", "Request", "Response", "Headers"}
|
|
||||||
|
|
||||||
func (m Model) renderStatusBar(w int) string {
|
func (m Model) renderStatusBar(w int) string {
|
||||||
var errCount int
|
var errCount int
|
||||||
var msSum int64
|
var msSum int64
|
||||||
@ -38,15 +23,7 @@ func (m Model) renderStatusBar(w int) string {
|
|||||||
|
|
||||||
avg := int(msSum) / max(1, len(m.requests))
|
avg := int(msSum) / max(1, len(m.requests))
|
||||||
left := fmt.Sprintf(" tap %3d reqs | %d err | avg %dms", len(m.requests), errCount, avg)
|
left := fmt.Sprintf(" tap %3d reqs | %d err | avg %dms", len(m.requests), errCount, avg)
|
||||||
logState := "logs off"
|
right := "j/k nav / search tab panel e events o output ^r restart q quit "
|
||||||
if m.showEvents && m.showStd {
|
|
||||||
logState = "events+output"
|
|
||||||
} else if m.showEvents {
|
|
||||||
logState = "events"
|
|
||||||
} else if m.showStd {
|
|
||||||
logState = "output"
|
|
||||||
}
|
|
||||||
right := " " + logState + " "
|
|
||||||
|
|
||||||
spaceSize := max(w-(len(left)+len(right)), 0)
|
spaceSize := max(w-(len(left)+len(right)), 0)
|
||||||
space := strings.Repeat(" ", spaceSize)
|
space := strings.Repeat(" ", spaceSize)
|
||||||
@ -54,91 +31,6 @@ func (m Model) renderStatusBar(w int) string {
|
|||||||
return m.theme.Header.Render(left + space + right)
|
return m.theme.Header.Render(left + space + right)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m Model) renderBottomStatusBar(w int) string {
|
|
||||||
if w <= 0 {
|
|
||||||
return ""
|
|
||||||
}
|
|
||||||
|
|
||||||
modeLabel := "REQ"
|
|
||||||
modeColor := blue
|
|
||||||
switch m.focusedPane {
|
|
||||||
case focusPaneDetails:
|
|
||||||
modeLabel = "DETAIL"
|
|
||||||
modeColor = blue
|
|
||||||
case focusPaneEvents:
|
|
||||||
modeLabel = "EVENT"
|
|
||||||
modeColor = green
|
|
||||||
case focusPaneStd:
|
|
||||||
modeLabel = "OUT"
|
|
||||||
modeColor = orange
|
|
||||||
}
|
|
||||||
|
|
||||||
modeStyle := lipgloss.NewStyle().Foreground(background).Background(modeColor).Bold(true)
|
|
||||||
left := modeStyle.Render(" " + modeLabel + " ")
|
|
||||||
if m.restarting {
|
|
||||||
left += m.theme.Text.Render(" ") + m.theme.EventWarn.Render(" RESTARTING ")
|
|
||||||
}
|
|
||||||
|
|
||||||
right := m.theme.TextMuted.Render(" " + m.bottomStatusRight() + " ")
|
|
||||||
if lipgloss.Width(right) >= w {
|
|
||||||
return clampRendered(right, w)
|
|
||||||
}
|
|
||||||
|
|
||||||
maxLeft := max(0, w-lipgloss.Width(right)-1)
|
|
||||||
left = clampRendered(left, maxLeft)
|
|
||||||
spaceSize := max(1, w-lipgloss.Width(left)-lipgloss.Width(right))
|
|
||||||
space := m.theme.Text.Render(strings.Repeat(" ", spaceSize))
|
|
||||||
|
|
||||||
return clampRendered(left+space+right, w)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (m Model) bottomStatusRight() string {
|
|
||||||
selected, total := m.requestSelectionStats()
|
|
||||||
|
|
||||||
switch m.focusedPane {
|
|
||||||
case focusPaneRequests:
|
|
||||||
return fmt.Sprintf("req %d/%d", selected, total)
|
|
||||||
|
|
||||||
case focusPaneDetails:
|
|
||||||
tab := "overview"
|
|
||||||
if m.detailsTab >= 0 && m.detailsTab < len(detailsTabNames) {
|
|
||||||
tab = strings.ToLower(detailsTabNames[m.detailsTab])
|
|
||||||
}
|
|
||||||
linesTotal := m.detailsContentLineCount(m.detailsPaneWidth())
|
|
||||||
linesPos := 0
|
|
||||||
if linesTotal > 0 {
|
|
||||||
linesPos = min(linesTotal, m.detailsScroll+1)
|
|
||||||
}
|
|
||||||
return fmt.Sprintf("req %d/%d | tab %s | %d/%d", selected, total, tab, linesPos, linesTotal)
|
|
||||||
|
|
||||||
case focusPaneEvents:
|
|
||||||
count := m.eventCount()
|
|
||||||
if m.eventsScroll == 0 {
|
|
||||||
return fmt.Sprintf("events %d | LIVE", count)
|
|
||||||
}
|
|
||||||
return fmt.Sprintf("events %d | PAUSED", count)
|
|
||||||
|
|
||||||
case focusPaneStd:
|
|
||||||
count := m.stdLogCount()
|
|
||||||
if m.stdScroll == 0 {
|
|
||||||
return fmt.Sprintf("lines %d | LIVE", count)
|
|
||||||
}
|
|
||||||
return fmt.Sprintf("lines %d | PAUSED", count)
|
|
||||||
}
|
|
||||||
|
|
||||||
return ""
|
|
||||||
}
|
|
||||||
|
|
||||||
func (m Model) requestSelectionStats() (selected int, total int) {
|
|
||||||
total = len(m.requests)
|
|
||||||
if total == 0 {
|
|
||||||
return 0, 0
|
|
||||||
}
|
|
||||||
|
|
||||||
selected = min(total, max(1, m.requestCursor+1))
|
|
||||||
return selected, total
|
|
||||||
}
|
|
||||||
|
|
||||||
// TODO: Implement
|
// TODO: Implement
|
||||||
func (m Model) renderSearchPane(w, h int) []string {
|
func (m Model) renderSearchPane(w, h int) []string {
|
||||||
lines := make([]string, h)
|
lines := make([]string, h)
|
||||||
@ -151,79 +43,39 @@ func (m Model) renderSearchPane(w, h int) []string {
|
|||||||
func (m Model) renderRequestPane(w, h int) []string {
|
func (m Model) renderRequestPane(w, h int) []string {
|
||||||
var lines []string
|
var lines []string
|
||||||
|
|
||||||
titleStyle := m.theme.TextMuted
|
|
||||||
if m.focusedPane == focusPaneRequests {
|
|
||||||
titleStyle = m.theme.EventHeader
|
|
||||||
}
|
|
||||||
title := titleStyle.Render(padToWidth("[1] REQUESTS", w))
|
|
||||||
lines = append(lines, title)
|
|
||||||
|
|
||||||
// Render header
|
// Render header
|
||||||
headerLeft := fmt.Sprintf(" %-7s %-24s %s", "METHOD", "HOST", "PATH")
|
headerLeft := fmt.Sprintf(" %-7s %-24s %s", "METHOD", "HOST", "PATH")
|
||||||
headerRight := fmt.Sprintf("%4s %8s ", "CODE", "TIME")
|
headerRight := fmt.Sprintf("%4s %8s ", "CODE", "TIME")
|
||||||
headerSpace := strings.Repeat(" ", max(0, w-len(headerLeft+headerRight)))
|
headerSpace := strings.Repeat(" ", max(0, w-len(headerLeft+headerRight)))
|
||||||
header := m.theme.TextMuted.Render(headerLeft + headerSpace + headerRight)
|
header := headerLeft + headerSpace + headerRight
|
||||||
lines = append(lines, header)
|
lines = append(lines, header)
|
||||||
|
|
||||||
bodyLines := make([]string, 0, len(m.requests))
|
for i := len(m.requests) - 1; i >= 0; i-- {
|
||||||
for i, row := len(m.requests)-1, 0; i >= 0; i, row = i-1, row+1 {
|
|
||||||
req := m.requests[i]
|
req := m.requests[i]
|
||||||
duration := req.Duration
|
|
||||||
|
// Some formatting magic here maybe
|
||||||
|
left := fmt.Sprintf(
|
||||||
|
" %-7s %-24s %s",
|
||||||
|
strings.ToUpper(req.Method),
|
||||||
|
truncate(req.Host, 24),
|
||||||
|
req.URL,
|
||||||
|
)
|
||||||
|
right := fmt.Sprintf(
|
||||||
|
"%4d %8s ",
|
||||||
|
req.Status,
|
||||||
|
formatDuration(req.Duration),
|
||||||
|
)
|
||||||
if req.Pending && !req.StartTime.IsZero() {
|
if req.Pending && !req.StartTime.IsZero() {
|
||||||
duration = time.Since(req.StartTime)
|
right = fmt.Sprintf(
|
||||||
|
"%4s %8s ",
|
||||||
|
"",
|
||||||
|
formatDuration(time.Since(req.StartTime)),
|
||||||
|
)
|
||||||
}
|
}
|
||||||
|
space := strings.Repeat(" ", max(0, w-len(left+right)))
|
||||||
|
|
||||||
statusStyle := lipgloss.NewStyle().Foreground(green).Background(background).Bold(true)
|
line := left + space + right
|
||||||
if req.Failed || req.Status >= 500 {
|
lines = append(lines, line)
|
||||||
statusStyle = lipgloss.NewStyle().Foreground(red).Background(background).Bold(true)
|
|
||||||
} else if req.Status >= 400 {
|
|
||||||
statusStyle = lipgloss.NewStyle().Foreground(orange).Background(background).Bold(true)
|
|
||||||
}
|
|
||||||
|
|
||||||
latencyStyle := m.theme.Text
|
|
||||||
if duration >= 2*time.Second {
|
|
||||||
latencyStyle = lipgloss.NewStyle().Foreground(orange).Background(background).Bold(true)
|
|
||||||
}
|
|
||||||
|
|
||||||
methodCol := statusStyle.Render(fmt.Sprintf("%-7s", truncate(strings.ToUpper(req.Method), 7)))
|
|
||||||
hostCol := m.theme.Text.Render(fmt.Sprintf("%-24s", truncate(req.Host, 24)))
|
|
||||||
pathCol := m.theme.Text.Render(req.URL)
|
|
||||||
|
|
||||||
statusText := ""
|
|
||||||
if !req.Pending && req.Status > 0 {
|
|
||||||
statusText = fmt.Sprintf("%d", req.Status)
|
|
||||||
}
|
|
||||||
statusCol := statusStyle.Render(fmt.Sprintf("%4s", statusText))
|
|
||||||
timeCol := latencyStyle.Render(fmt.Sprintf("%8s", formatDuration(duration)))
|
|
||||||
|
|
||||||
sep := m.theme.Text.Render(" ")
|
|
||||||
left := sep + methodCol + sep + hostCol + sep + pathCol
|
|
||||||
right := statusCol + sep + timeCol + sep
|
|
||||||
space := strings.Repeat(" ", max(0, w-lipgloss.Width(left+right)))
|
|
||||||
|
|
||||||
line := left + m.theme.Text.Render(space) + right
|
|
||||||
line = clampRendered(line, w)
|
|
||||||
if row == m.requestCursor {
|
|
||||||
line = m.theme.RequestSelected.Render(ansi.Strip(line))
|
|
||||||
}
|
|
||||||
bodyLines = append(bodyLines, line)
|
|
||||||
}
|
|
||||||
|
|
||||||
bodyHeight := max(0, h-len(lines))
|
|
||||||
scroll := m.requestScroll
|
|
||||||
maxScroll := max(0, len(bodyLines)-bodyHeight)
|
|
||||||
if scroll < 0 {
|
|
||||||
scroll = 0
|
|
||||||
}
|
|
||||||
if scroll > maxScroll {
|
|
||||||
scroll = maxScroll
|
|
||||||
}
|
|
||||||
|
|
||||||
if bodyHeight > 0 {
|
|
||||||
end := min(len(bodyLines), scroll+bodyHeight)
|
|
||||||
if scroll < end {
|
|
||||||
lines = append(lines, bodyLines[scroll:end]...)
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Cleanup
|
// Cleanup
|
||||||
@ -240,396 +92,16 @@ func (m Model) renderRequestPane(w, h int) []string {
|
|||||||
return lines
|
return lines
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// TODO: Implement
|
||||||
func (m Model) renderDetailsPane(w, h int) []string {
|
func (m Model) renderDetailsPane(w, h int) []string {
|
||||||
if h <= 0 {
|
lines := make([]string, h)
|
||||||
return nil
|
for y := range lines {
|
||||||
}
|
lines[y] = m.theme.Text.Render(strings.Repeat(" ", w))
|
||||||
|
|
||||||
formatLine := func(content string) string {
|
|
||||||
line := truncate(content, w)
|
|
||||||
if len(line) < w {
|
|
||||||
line += strings.Repeat(" ", w-len(line))
|
|
||||||
}
|
|
||||||
return m.theme.Text.Render(line)
|
|
||||||
}
|
|
||||||
|
|
||||||
formatMutedLine := func(content string) string {
|
|
||||||
line := truncate(content, w)
|
|
||||||
if len(line) < w {
|
|
||||||
line += strings.Repeat(" ", w-len(line))
|
|
||||||
}
|
|
||||||
return m.theme.TextMuted.Render(line)
|
|
||||||
}
|
|
||||||
|
|
||||||
formatMutedItalicLine := func(content string) string {
|
|
||||||
line := truncate(content, w)
|
|
||||||
if len(line) < w {
|
|
||||||
line += strings.Repeat(" ", w-len(line))
|
|
||||||
}
|
|
||||||
return lipgloss.NewStyle().
|
|
||||||
Foreground(textMuted).
|
|
||||||
Background(background).
|
|
||||||
Italic(true).
|
|
||||||
Render(line)
|
|
||||||
}
|
|
||||||
|
|
||||||
renderTabRow := func() string {
|
|
||||||
if w <= 0 {
|
|
||||||
return ""
|
|
||||||
}
|
|
||||||
|
|
||||||
activeTabStyle := lipgloss.NewStyle().
|
|
||||||
Foreground(background).
|
|
||||||
Background(blue).
|
|
||||||
Bold(true)
|
|
||||||
|
|
||||||
var parts []string
|
|
||||||
for i, name := range detailsTabNames {
|
|
||||||
label := " " + name + " "
|
|
||||||
if i == m.detailsTab {
|
|
||||||
label = " [" + name + "] "
|
|
||||||
parts = append(parts, activeTabStyle.Render(label))
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
parts = append(parts, m.theme.TextMuted.Render(label))
|
|
||||||
}
|
|
||||||
|
|
||||||
sep := m.theme.Text.Render(" ")
|
|
||||||
line := strings.Join(parts, sep)
|
|
||||||
if lipgloss.Width(line) < w {
|
|
||||||
pad := strings.Repeat(" ", w-lipgloss.Width(line))
|
|
||||||
line += m.theme.Text.Render(pad)
|
|
||||||
}
|
|
||||||
|
|
||||||
return clampRendered(line, w)
|
|
||||||
}
|
|
||||||
|
|
||||||
lines := make([]string, 0, h)
|
|
||||||
detailsTitleStyle := m.theme.TextMuted
|
|
||||||
if m.focusedPane == focusPaneDetails {
|
|
||||||
detailsTitleStyle = m.theme.EventHeader
|
|
||||||
}
|
|
||||||
lines = append(lines, detailsTitleStyle.Render(padToWidth("[2] DETAIL", w)))
|
|
||||||
if len(lines) >= h {
|
|
||||||
return lines[:h]
|
|
||||||
}
|
|
||||||
lines = append(lines, renderTabRow())
|
|
||||||
if len(lines) >= h {
|
|
||||||
return lines[:h]
|
|
||||||
}
|
|
||||||
|
|
||||||
contentLines := m.detailsContentLines(w, formatLine, formatMutedLine, formatMutedItalicLine)
|
|
||||||
contentHeight := max(0, h-len(lines))
|
|
||||||
scroll := m.detailsScroll
|
|
||||||
maxScroll := max(0, len(contentLines)-contentHeight)
|
|
||||||
if scroll < 0 {
|
|
||||||
scroll = 0
|
|
||||||
}
|
|
||||||
if scroll > maxScroll {
|
|
||||||
scroll = maxScroll
|
|
||||||
}
|
|
||||||
|
|
||||||
if contentHeight > 0 {
|
|
||||||
end := min(len(contentLines), scroll+contentHeight)
|
|
||||||
if scroll < end {
|
|
||||||
lines = append(lines, contentLines[scroll:end]...)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
for len(lines) < h {
|
|
||||||
lines = append(lines, formatLine(""))
|
|
||||||
}
|
|
||||||
|
|
||||||
if len(lines) > h {
|
|
||||||
return lines[:h]
|
|
||||||
}
|
}
|
||||||
|
|
||||||
return lines
|
return lines
|
||||||
}
|
}
|
||||||
|
|
||||||
func padToWidth(s string, w int) string {
|
|
||||||
if w <= 0 {
|
|
||||||
return ""
|
|
||||||
}
|
|
||||||
|
|
||||||
s = truncate(s, w)
|
|
||||||
if len(s) < w {
|
|
||||||
s += strings.Repeat(" ", w-len(s))
|
|
||||||
}
|
|
||||||
|
|
||||||
return s
|
|
||||||
}
|
|
||||||
|
|
||||||
func (m Model) detailsContentLines(
|
|
||||||
w int,
|
|
||||||
formatLine func(string) string,
|
|
||||||
formatMutedLine func(string) string,
|
|
||||||
formatMutedItalicLine func(string) string,
|
|
||||||
) []string {
|
|
||||||
selectedReq, ok := m.selectedRequest()
|
|
||||||
if !ok {
|
|
||||||
return []string{formatLine(" No requests yet. Use j/k once requests arrive.")}
|
|
||||||
}
|
|
||||||
|
|
||||||
contentLines := make([]string, 0, 64)
|
|
||||||
|
|
||||||
switch m.detailsTab {
|
|
||||||
case detailsTabOverview:
|
|
||||||
duration := selectedReq.Duration
|
|
||||||
if selectedReq.Pending && !selectedReq.StartTime.IsZero() {
|
|
||||||
duration = time.Since(selectedReq.StartTime)
|
|
||||||
}
|
|
||||||
|
|
||||||
renderKVLine := func(key, value string, valueStyle lipgloss.Style) string {
|
|
||||||
keyPart := m.theme.TextMuted.Render(fmt.Sprintf(" %-8s", key))
|
|
||||||
valuePart := valueStyle.Render(value)
|
|
||||||
sep := m.theme.Text.Render(" ")
|
|
||||||
line := keyPart + sep + valuePart
|
|
||||||
if lipgloss.Width(line) < w {
|
|
||||||
line += m.theme.Text.Render(strings.Repeat(" ", w-lipgloss.Width(line)))
|
|
||||||
}
|
|
||||||
return clampRendered(line, w)
|
|
||||||
}
|
|
||||||
|
|
||||||
statusText := "-"
|
|
||||||
statusStyle := m.theme.TextMuted
|
|
||||||
if selectedReq.Status > 0 {
|
|
||||||
statusText = fmt.Sprintf("%d", selectedReq.Status)
|
|
||||||
statusStyle = m.theme.EventSuccess
|
|
||||||
if selectedReq.Status >= 400 {
|
|
||||||
statusStyle = m.theme.EventWarn
|
|
||||||
}
|
|
||||||
if selectedReq.Status >= 500 {
|
|
||||||
statusStyle = m.theme.EventError
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
timeText := "-"
|
|
||||||
if !selectedReq.StartTime.IsZero() {
|
|
||||||
timeText = selectedReq.StartTime.Format("3:04:05 PM")
|
|
||||||
}
|
|
||||||
|
|
||||||
contentLines = append(contentLines, formatLine(""))
|
|
||||||
contentLines = append(contentLines, renderKVLine("Method", strings.ToUpper(selectedReq.Method), m.theme.Text))
|
|
||||||
contentLines = append(contentLines, renderKVLine("URL", selectedReq.RawURL, m.theme.TextMuted))
|
|
||||||
queryText := selectedReq.QueryString
|
|
||||||
if queryText == "" {
|
|
||||||
queryText = "-"
|
|
||||||
}
|
|
||||||
contentLines = append(contentLines, renderKVLine("Query", queryText, m.theme.TextMuted))
|
|
||||||
contentLines = append(contentLines, renderKVLine("Status", statusText, statusStyle))
|
|
||||||
contentLines = append(contentLines, renderKVLine("Latency", formatDuration(duration), m.theme.Text))
|
|
||||||
contentLines = append(contentLines, renderKVLine("Time", timeText, m.theme.Text))
|
|
||||||
|
|
||||||
contentLines = append(contentLines, formatLine(""))
|
|
||||||
contentLines = append(contentLines, formatMutedLine(" Timing"))
|
|
||||||
|
|
||||||
barValue := formatDuration(duration)
|
|
||||||
barPrefix := " "
|
|
||||||
barSuffix := " " + barValue
|
|
||||||
maxBarWidth := max(1, w/2)
|
|
||||||
barWidth := min(maxBarWidth, max(0, w-len(barPrefix)-len(barSuffix)))
|
|
||||||
if barWidth == 0 {
|
|
||||||
contentLines = append(contentLines, formatLine(" "+barValue))
|
|
||||||
break
|
|
||||||
}
|
|
||||||
|
|
||||||
const maxDurationForScale = 2 * time.Second
|
|
||||||
ratio := float64(duration) / float64(maxDurationForScale)
|
|
||||||
if ratio < 0 {
|
|
||||||
ratio = 0
|
|
||||||
}
|
|
||||||
if ratio > 1 {
|
|
||||||
ratio = 1
|
|
||||||
}
|
|
||||||
|
|
||||||
filled := int(ratio * float64(barWidth))
|
|
||||||
if duration > 0 && filled == 0 {
|
|
||||||
filled = 1
|
|
||||||
}
|
|
||||||
if filled > barWidth {
|
|
||||||
filled = barWidth
|
|
||||||
}
|
|
||||||
empty := max(0, barWidth-filled)
|
|
||||||
|
|
||||||
filledPart := lipgloss.NewStyle().Foreground(blue).Render(strings.Repeat("█", filled))
|
|
||||||
emptyPart := lipgloss.NewStyle().Foreground(cyan).Render(strings.Repeat("░", empty))
|
|
||||||
barLine := barPrefix + filledPart + emptyPart + m.theme.TextMuted.Render(barSuffix)
|
|
||||||
contentLines = append(contentLines, clampRendered(barLine, w))
|
|
||||||
|
|
||||||
case detailsTabRequest:
|
|
||||||
contentLines = append(contentLines, formatLine(""))
|
|
||||||
contentLines = append(contentLines, formatMutedLine(" -- Request Body --"))
|
|
||||||
contentLines = append(contentLines, formatLine(""))
|
|
||||||
if len(selectedReq.RequestData) == 0 {
|
|
||||||
contentLines = append(contentLines, formatMutedItalicLine(" empty"))
|
|
||||||
break
|
|
||||||
}
|
|
||||||
for _, line := range formatBodyLines(prettyBody(selectedReq.RequestData, selectedReq.RequestHeaders), w) {
|
|
||||||
contentLines = append(contentLines, formatLine(" "+line))
|
|
||||||
}
|
|
||||||
|
|
||||||
case detailsTabResponse:
|
|
||||||
contentLines = append(contentLines, formatLine(""))
|
|
||||||
contentLines = append(contentLines, formatMutedLine(" -- Response Body --"))
|
|
||||||
contentLines = append(contentLines, formatLine(""))
|
|
||||||
if len(selectedReq.ResponseData) == 0 {
|
|
||||||
contentLines = append(contentLines, formatMutedItalicLine(" empty"))
|
|
||||||
break
|
|
||||||
}
|
|
||||||
for _, line := range formatBodyLines(prettyBody(selectedReq.ResponseData, selectedReq.ResponseHeaders), w) {
|
|
||||||
contentLines = append(contentLines, formatLine(" "+line))
|
|
||||||
}
|
|
||||||
|
|
||||||
case detailsTabHeaders:
|
|
||||||
contentLines = append(contentLines, formatLine(""))
|
|
||||||
renderHeaderLine := func(key, value string) string {
|
|
||||||
left := m.theme.HeaderKey.Render(" " + key + ": ")
|
|
||||||
right := m.theme.TextMuted.Render(value)
|
|
||||||
line := left + right
|
|
||||||
if lipgloss.Width(line) < w {
|
|
||||||
line += m.theme.Text.Render(strings.Repeat(" ", w-lipgloss.Width(line)))
|
|
||||||
}
|
|
||||||
return clampRendered(line, w)
|
|
||||||
}
|
|
||||||
|
|
||||||
appendHeaders := func(title string, headers map[string][]string) {
|
|
||||||
contentLines = append(contentLines, formatMutedLine(" -- "+title+" --"))
|
|
||||||
contentLines = append(contentLines, formatLine(""))
|
|
||||||
|
|
||||||
if len(headers) == 0 {
|
|
||||||
contentLines = append(contentLines, formatMutedItalicLine(" empty"))
|
|
||||||
contentLines = append(contentLines, formatLine(""))
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
keys := make([]string, 0, len(headers))
|
|
||||||
for key := range headers {
|
|
||||||
keys = append(keys, key)
|
|
||||||
}
|
|
||||||
sort.Strings(keys)
|
|
||||||
|
|
||||||
for _, key := range keys {
|
|
||||||
contentLines = append(contentLines, renderHeaderLine(key, strings.Join(headers[key], ", ")))
|
|
||||||
}
|
|
||||||
|
|
||||||
contentLines = append(contentLines, formatLine(""))
|
|
||||||
}
|
|
||||||
|
|
||||||
appendHeaders("Request Headers", selectedReq.RequestHeaders)
|
|
||||||
appendHeaders("Response Headers", selectedReq.ResponseHeaders)
|
|
||||||
|
|
||||||
default:
|
|
||||||
contentLines = append(contentLines, formatLine(" Unknown details tab"))
|
|
||||||
}
|
|
||||||
|
|
||||||
return contentLines
|
|
||||||
}
|
|
||||||
|
|
||||||
func (m Model) detailsContentLineCount(w int) int {
|
|
||||||
plain := func(s string) string { return s }
|
|
||||||
return len(m.detailsContentLines(w, plain, plain, plain))
|
|
||||||
}
|
|
||||||
|
|
||||||
func (m Model) selectedRequest() (model.Request, bool) {
|
|
||||||
if len(m.requests) == 0 {
|
|
||||||
return model.Request{}, false
|
|
||||||
}
|
|
||||||
|
|
||||||
cursor := m.requestCursor
|
|
||||||
if cursor < 0 {
|
|
||||||
cursor = 0
|
|
||||||
}
|
|
||||||
if cursor >= len(m.requests) {
|
|
||||||
cursor = len(m.requests) - 1
|
|
||||||
}
|
|
||||||
|
|
||||||
idx := len(m.requests) - cursor - 1
|
|
||||||
if idx < 0 || idx >= len(m.requests) {
|
|
||||||
return model.Request{}, false
|
|
||||||
}
|
|
||||||
|
|
||||||
return m.requests[idx], true
|
|
||||||
}
|
|
||||||
|
|
||||||
func formatBodyLines(body []byte, width int) []string {
|
|
||||||
if len(body) == 0 {
|
|
||||||
return []string{"(empty)"}
|
|
||||||
}
|
|
||||||
|
|
||||||
text := string(body)
|
|
||||||
if !utf8.Valid(body) {
|
|
||||||
previewSize := min(len(body), 64)
|
|
||||||
hexPreview := hex.EncodeToString(body[:previewSize])
|
|
||||||
suffix := ""
|
|
||||||
if previewSize < len(body) {
|
|
||||||
suffix = "..."
|
|
||||||
}
|
|
||||||
text = fmt.Sprintf("(binary payload, %d bytes, hex preview: %s%s)", len(body), hexPreview, suffix)
|
|
||||||
}
|
|
||||||
|
|
||||||
src := strings.ReplaceAll(text, "\t", " ")
|
|
||||||
rawLines := strings.Split(src, "\n")
|
|
||||||
if width <= 4 {
|
|
||||||
return rawLines
|
|
||||||
}
|
|
||||||
|
|
||||||
maxWidth := max(1, width-2)
|
|
||||||
out := make([]string, 0, len(rawLines))
|
|
||||||
for _, line := range rawLines {
|
|
||||||
if line == "" {
|
|
||||||
out = append(out, "")
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
|
|
||||||
for len(line) > maxWidth {
|
|
||||||
out = append(out, line[:maxWidth])
|
|
||||||
line = line[maxWidth:]
|
|
||||||
}
|
|
||||||
out = append(out, line)
|
|
||||||
}
|
|
||||||
|
|
||||||
return out
|
|
||||||
}
|
|
||||||
|
|
||||||
func prettyBody(body []byte, headers map[string][]string) []byte {
|
|
||||||
if len(body) == 0 {
|
|
||||||
return body
|
|
||||||
}
|
|
||||||
|
|
||||||
if !looksLikeJSON(body, headers) {
|
|
||||||
return body
|
|
||||||
}
|
|
||||||
|
|
||||||
var out bytes.Buffer
|
|
||||||
if err := json.Indent(&out, body, "", " "); err != nil {
|
|
||||||
return body
|
|
||||||
}
|
|
||||||
|
|
||||||
return out.Bytes()
|
|
||||||
}
|
|
||||||
|
|
||||||
func looksLikeJSON(body []byte, headers map[string][]string) bool {
|
|
||||||
if json.Valid(body) {
|
|
||||||
return true
|
|
||||||
}
|
|
||||||
|
|
||||||
for key, values := range headers {
|
|
||||||
if !strings.EqualFold(key, "Content-Type") {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, value := range values {
|
|
||||||
contentType := strings.ToLower(value)
|
|
||||||
if strings.Contains(contentType, "application/json") || strings.Contains(contentType, "+json") {
|
|
||||||
return true
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
|
|
||||||
// TODO: This can be done better
|
// TODO: This can be done better
|
||||||
// TODO: Should h be max or defined?
|
// TODO: Should h be max or defined?
|
||||||
func (m Model) renderEventsPane(w, h int) []string {
|
func (m Model) renderEventsPane(w, h int) []string {
|
||||||
@ -642,30 +114,18 @@ func (m Model) renderEventsPane(w, h int) []string {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
left := fmt.Sprintf("[3] EVENT LOG - %d EVENTS", len(events))
|
displayCount := max(h-1, 0)
|
||||||
right := "E: TOGGLE"
|
|
||||||
headerStyle := m.theme.TextMuted
|
if displayCount < len(events) {
|
||||||
if m.focusedPane == focusPaneEvents {
|
events = events[len(events)-displayCount:]
|
||||||
headerStyle = m.theme.EventPaneHeader
|
|
||||||
}
|
}
|
||||||
status := headerStyle.Render(left + strings.Repeat(" ", max(0, w-len(left+right))) + right)
|
|
||||||
|
left := fmt.Sprintf("EVENT LOG - %d EVENTS", len(events))
|
||||||
|
right := "E: TOGGLE"
|
||||||
|
status := m.theme.EventHeader.Render(left + strings.Repeat(" ", w-len(left+right)) + right)
|
||||||
lines := []string{status}
|
lines := []string{status}
|
||||||
|
|
||||||
bodyHeight := max(0, h-1)
|
for _, event := range events {
|
||||||
maxScroll := max(0, len(events)-bodyHeight)
|
|
||||||
scroll := m.eventsScroll
|
|
||||||
if scroll < 0 {
|
|
||||||
scroll = 0
|
|
||||||
}
|
|
||||||
if scroll > maxScroll {
|
|
||||||
scroll = maxScroll
|
|
||||||
}
|
|
||||||
|
|
||||||
start := max(0, len(events)-bodyHeight-scroll)
|
|
||||||
end := min(len(events), start+bodyHeight)
|
|
||||||
visible := events[start:end]
|
|
||||||
|
|
||||||
for _, event := range visible {
|
|
||||||
var (
|
var (
|
||||||
eTime string = m.theme.TextMuted.Render(event.Time.Format("15:04:05") + " ")
|
eTime string = m.theme.TextMuted.Render(event.Time.Format("15:04:05") + " ")
|
||||||
eType string = getEventColor(m.theme, event.Type).Render(fmt.Sprintf("%-17s ", event.Type))
|
eType string = getEventColor(m.theme, event.Type).Render(fmt.Sprintf("%-17s ", event.Type))
|
||||||
@ -705,16 +165,6 @@ func (m Model) renderEventsPane(w, h int) []string {
|
|||||||
return lines
|
return lines
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m Model) eventCount() int {
|
|
||||||
count := 0
|
|
||||||
for _, ev := range m.events {
|
|
||||||
if ev.Type != model.EventTypeProcessStderr && ev.Type != model.EventTypeProcessStdout {
|
|
||||||
count++
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return count
|
|
||||||
}
|
|
||||||
|
|
||||||
// TODO: Should h be max or defined?
|
// TODO: Should h be max or defined?
|
||||||
func (m Model) renderStdPane(w, h int) []string {
|
func (m Model) renderStdPane(w, h int) []string {
|
||||||
// Only the stdout or stderr logs
|
// Only the stdout or stderr logs
|
||||||
@ -726,30 +176,18 @@ func (m Model) renderStdPane(w, h int) []string {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
left := fmt.Sprintf("[4] STDOUT/STDERR LOG - %d LINES", len(logs))
|
displayCount := max(h-1, 0)
|
||||||
right := "O: TOGGLE"
|
|
||||||
headerStyle := m.theme.TextMuted
|
if displayCount < len(logs) {
|
||||||
if m.focusedPane == focusPaneStd {
|
logs = logs[len(logs)-displayCount:]
|
||||||
headerStyle = m.theme.StdHeader
|
|
||||||
}
|
}
|
||||||
status := headerStyle.Render(left + strings.Repeat(" ", max(0, w-len(left+right))) + right)
|
|
||||||
|
left := fmt.Sprintf("STDOUT/STDERR LOG - %d LINES", len(logs))
|
||||||
|
right := "O: TOGGLE"
|
||||||
|
status := m.theme.StdHeader.Render(left + strings.Repeat(" ", w-len(left+right)) + right)
|
||||||
lines := []string{status}
|
lines := []string{status}
|
||||||
|
|
||||||
bodyHeight := max(0, h-1)
|
for _, log := range logs {
|
||||||
maxScroll := max(0, len(logs)-bodyHeight)
|
|
||||||
scroll := m.stdScroll
|
|
||||||
if scroll < 0 {
|
|
||||||
scroll = 0
|
|
||||||
}
|
|
||||||
if scroll > maxScroll {
|
|
||||||
scroll = maxScroll
|
|
||||||
}
|
|
||||||
|
|
||||||
start := max(0, len(logs)-bodyHeight-scroll)
|
|
||||||
end := min(len(logs), start+bodyHeight)
|
|
||||||
visible := logs[start:end]
|
|
||||||
|
|
||||||
for _, log := range visible {
|
|
||||||
var (
|
var (
|
||||||
tag string
|
tag string
|
||||||
body string
|
body string
|
||||||
@ -787,13 +225,3 @@ func (m Model) renderStdPane(w, h int) []string {
|
|||||||
|
|
||||||
return lines
|
return lines
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m Model) stdLogCount() int {
|
|
||||||
count := 0
|
|
||||||
for _, ev := range m.events {
|
|
||||||
if ev.Type == model.EventTypeProcessStderr || ev.Type == model.EventTypeProcessStdout {
|
|
||||||
count++
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return count
|
|
||||||
}
|
|
||||||
|
|||||||
@ -4,13 +4,13 @@ import "strings"
|
|||||||
|
|
||||||
func (m Model) renderAppPane() string {
|
func (m Model) renderAppPane() string {
|
||||||
// Constant height offset
|
// Constant height offset
|
||||||
constHeightOffset := 2
|
constHeightOffset := 1
|
||||||
|
|
||||||
var (
|
var (
|
||||||
searchW int = max(0, m.width)
|
searchW int = max(0, m.width)
|
||||||
searchH int = 1
|
searchH int = 1
|
||||||
|
|
||||||
reqW int = max(0, int(float64(m.width)*0.5))
|
reqW int = max(0, int(float64(m.width)*0.55))
|
||||||
detW int = max(0, m.width-reqW)
|
detW int = max(0, m.width-reqW)
|
||||||
|
|
||||||
reqH int = max(0, m.height-constHeightOffset)
|
reqH int = max(0, m.height-constHeightOffset)
|
||||||
@ -68,9 +68,6 @@ func (m Model) renderAppPane() string {
|
|||||||
screen = append(screen, stdPane...)
|
screen = append(screen, stdPane...)
|
||||||
}
|
}
|
||||||
|
|
||||||
statusBottom := m.renderBottomStatusBar(m.width)
|
|
||||||
screen = append(screen, statusBottom)
|
|
||||||
|
|
||||||
if len(screen) != m.height {
|
if len(screen) != m.height {
|
||||||
return "height of screen does not match terminal height"
|
return "height of screen does not match terminal height"
|
||||||
}
|
}
|
||||||
|
|||||||
@ -11,15 +11,12 @@ type Theme struct {
|
|||||||
|
|
||||||
Header lipgloss.Style
|
Header lipgloss.Style
|
||||||
EventHeader lipgloss.Style
|
EventHeader lipgloss.Style
|
||||||
EventPaneHeader lipgloss.Style
|
|
||||||
StdHeader lipgloss.Style
|
StdHeader lipgloss.Style
|
||||||
|
|
||||||
Text lipgloss.Style
|
Text lipgloss.Style
|
||||||
TextMuted lipgloss.Style
|
TextMuted lipgloss.Style
|
||||||
TextError lipgloss.Style
|
TextError lipgloss.Style
|
||||||
TextMutedError lipgloss.Style
|
TextMutedError lipgloss.Style
|
||||||
RequestSelected lipgloss.Style
|
|
||||||
HeaderKey lipgloss.Style
|
|
||||||
|
|
||||||
EventDefault lipgloss.Style
|
EventDefault lipgloss.Style
|
||||||
EventSession lipgloss.Style
|
EventSession lipgloss.Style
|
||||||
@ -59,10 +56,6 @@ func newTheme() Theme {
|
|||||||
Bold(true).
|
Bold(true).
|
||||||
Foreground(background).
|
Foreground(background).
|
||||||
Background(blue),
|
Background(blue),
|
||||||
EventPaneHeader: lipgloss.NewStyle().
|
|
||||||
Bold(true).
|
|
||||||
Foreground(background).
|
|
||||||
Background(green),
|
|
||||||
StdHeader: lipgloss.NewStyle().
|
StdHeader: lipgloss.NewStyle().
|
||||||
Bold(true).
|
Bold(true).
|
||||||
Foreground(background).
|
Foreground(background).
|
||||||
@ -80,13 +73,6 @@ func newTheme() Theme {
|
|||||||
TextMutedError: lipgloss.NewStyle().
|
TextMutedError: lipgloss.NewStyle().
|
||||||
Foreground(textMuted).
|
Foreground(textMuted).
|
||||||
Background(backgroundError),
|
Background(backgroundError),
|
||||||
RequestSelected: lipgloss.NewStyle().
|
|
||||||
Foreground(background).
|
|
||||||
Background(blue).
|
|
||||||
Bold(true),
|
|
||||||
HeaderKey: lipgloss.NewStyle().
|
|
||||||
Foreground(cyan).
|
|
||||||
Background(background),
|
|
||||||
|
|
||||||
EventDefault: lipgloss.NewStyle().
|
EventDefault: lipgloss.NewStyle().
|
||||||
Foreground(text).
|
Foreground(text).
|
||||||
|
|||||||
@ -13,7 +13,6 @@ func (m Model) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
|
|||||||
case tea.WindowSizeMsg:
|
case tea.WindowSizeMsg:
|
||||||
m.width = msg.Width
|
m.width = msg.Width
|
||||||
m.height = msg.Height
|
m.height = msg.Height
|
||||||
m.clampPaneScrolls()
|
|
||||||
return m, nil
|
return m, nil
|
||||||
|
|
||||||
case TickMsg:
|
case TickMsg:
|
||||||
@ -28,22 +27,6 @@ func (m Model) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
|
|||||||
switch msg.String() {
|
switch msg.String() {
|
||||||
case "ctrl+c", "q":
|
case "ctrl+c", "q":
|
||||||
return m, tea.Quit
|
return m, tea.Quit
|
||||||
case "j", "down":
|
|
||||||
m.moveFocusedVertical(1)
|
|
||||||
case "k", "up":
|
|
||||||
m.moveFocusedVertical(-1)
|
|
||||||
case "tab":
|
|
||||||
m.moveDetailsTab(1)
|
|
||||||
case "shift+tab", "backtab":
|
|
||||||
m.moveDetailsTab(-1)
|
|
||||||
case "1":
|
|
||||||
m.setFocusedPane(focusPaneRequests)
|
|
||||||
case "2":
|
|
||||||
m.setFocusedPane(focusPaneDetails)
|
|
||||||
case "3":
|
|
||||||
m.setFocusedPane(focusPaneEvents)
|
|
||||||
case "4":
|
|
||||||
m.setFocusedPane(focusPaneStd)
|
|
||||||
case tea.KeyCtrlR.String():
|
case tea.KeyCtrlR.String():
|
||||||
if m.restarting {
|
if m.restarting {
|
||||||
return m, nil
|
return m, nil
|
||||||
@ -55,18 +38,12 @@ func (m Model) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
|
|||||||
return m, restartCmd(m.controls.Restart)
|
return m, restartCmd(m.controls.Restart)
|
||||||
case "e":
|
case "e":
|
||||||
m.showEvents = !m.showEvents
|
m.showEvents = !m.showEvents
|
||||||
m.ensureFocusedPaneVisible()
|
|
||||||
m.clampPaneScrolls()
|
|
||||||
case "o":
|
case "o":
|
||||||
m.showStd = !m.showStd
|
m.showStd = !m.showStd
|
||||||
m.ensureFocusedPaneVisible()
|
|
||||||
m.clampPaneScrolls()
|
|
||||||
case "/":
|
case "/":
|
||||||
m.showSearch = true
|
m.showSearch = true
|
||||||
m.clampPaneScrolls()
|
|
||||||
case "esc":
|
case "esc":
|
||||||
m.showSearch = false
|
m.showSearch = false
|
||||||
m.clampPaneScrolls()
|
|
||||||
}
|
}
|
||||||
return m, nil
|
return m, nil
|
||||||
|
|
||||||
@ -92,7 +69,6 @@ func (m Model) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
|
|||||||
case EventMsg:
|
case EventMsg:
|
||||||
m.pushEvent(msg.value)
|
m.pushEvent(msg.value)
|
||||||
m.applyMessage(msg.value)
|
m.applyMessage(msg.value)
|
||||||
m.clampPaneScrolls()
|
|
||||||
if m.hasPendingRequests() {
|
if m.hasPendingRequests() {
|
||||||
return m, tea.Batch(waitForEvent(m.channel), tickCmd())
|
return m, tea.Batch(waitForEvent(m.channel), tickCmd())
|
||||||
}
|
}
|
||||||
@ -123,10 +99,6 @@ func (m *Model) createRequest(req model.Request) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
if len(m.requests) > 0 && m.requestCursor > 0 {
|
|
||||||
m.requestCursor++
|
|
||||||
}
|
|
||||||
|
|
||||||
m.requests = append(m.requests, req)
|
m.requests = append(m.requests, req)
|
||||||
|
|
||||||
// If we passed the max, delete the first one
|
// If we passed the max, delete the first one
|
||||||
@ -134,8 +106,6 @@ func (m *Model) createRequest(req model.Request) {
|
|||||||
if len(m.requests) > maxRequests {
|
if len(m.requests) > maxRequests {
|
||||||
m.requests = m.requests[1:]
|
m.requests = m.requests[1:]
|
||||||
}
|
}
|
||||||
|
|
||||||
m.clampRequestCursor()
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *Model) updateRequest(req model.Request) {
|
func (m *Model) updateRequest(req model.Request) {
|
||||||
@ -149,248 +119,6 @@ func (m *Model) updateRequest(req model.Request) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *Model) moveRequestCursor(delta int) {
|
|
||||||
if len(m.requests) == 0 {
|
|
||||||
m.requestCursor = 0
|
|
||||||
m.requestScroll = 0
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
m.requestCursor += delta
|
|
||||||
m.clampRequestCursor()
|
|
||||||
m.ensureRequestCursorVisible()
|
|
||||||
m.detailsScroll = 0
|
|
||||||
}
|
|
||||||
|
|
||||||
func (m *Model) clampRequestCursor() {
|
|
||||||
if len(m.requests) == 0 {
|
|
||||||
m.requestCursor = 0
|
|
||||||
m.requestScroll = 0
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
if m.requestCursor < 0 {
|
|
||||||
m.requestCursor = 0
|
|
||||||
}
|
|
||||||
|
|
||||||
if m.requestCursor >= len(m.requests) {
|
|
||||||
m.requestCursor = len(m.requests) - 1
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (m *Model) ensureRequestCursorVisible() {
|
|
||||||
viewHeight := m.requestBodyHeight()
|
|
||||||
if viewHeight <= 0 {
|
|
||||||
m.requestScroll = 0
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
maxScroll := max(0, len(m.requests)-viewHeight)
|
|
||||||
if m.requestScroll < 0 {
|
|
||||||
m.requestScroll = 0
|
|
||||||
}
|
|
||||||
if m.requestScroll > maxScroll {
|
|
||||||
m.requestScroll = maxScroll
|
|
||||||
}
|
|
||||||
|
|
||||||
if m.requestCursor < m.requestScroll {
|
|
||||||
m.requestScroll = m.requestCursor
|
|
||||||
}
|
|
||||||
if m.requestCursor >= m.requestScroll+viewHeight {
|
|
||||||
m.requestScroll = m.requestCursor - viewHeight + 1
|
|
||||||
}
|
|
||||||
|
|
||||||
if m.requestScroll < 0 {
|
|
||||||
m.requestScroll = 0
|
|
||||||
}
|
|
||||||
if m.requestScroll > maxScroll {
|
|
||||||
m.requestScroll = maxScroll
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (m *Model) moveDetailsTab(delta int) {
|
|
||||||
count := len(detailsTabNames)
|
|
||||||
if count == 0 {
|
|
||||||
m.detailsTab = 0
|
|
||||||
m.detailsScroll = 0
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
m.detailsTab = (m.detailsTab + delta) % count
|
|
||||||
if m.detailsTab < 0 {
|
|
||||||
m.detailsTab += count
|
|
||||||
}
|
|
||||||
m.detailsScroll = 0
|
|
||||||
m.clampPaneScrolls()
|
|
||||||
}
|
|
||||||
|
|
||||||
func (m *Model) moveFocusedVertical(delta int) {
|
|
||||||
switch m.focusedPane {
|
|
||||||
case focusPaneRequests:
|
|
||||||
m.moveRequestCursor(delta)
|
|
||||||
case focusPaneDetails:
|
|
||||||
m.detailsScroll += delta
|
|
||||||
case focusPaneEvents:
|
|
||||||
if delta > 0 {
|
|
||||||
m.eventsScroll = max(0, m.eventsScroll-delta)
|
|
||||||
} else {
|
|
||||||
m.eventsScroll += -delta
|
|
||||||
}
|
|
||||||
case focusPaneStd:
|
|
||||||
if delta > 0 {
|
|
||||||
m.stdScroll = max(0, m.stdScroll-delta)
|
|
||||||
} else {
|
|
||||||
m.stdScroll += -delta
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
m.clampPaneScrolls()
|
|
||||||
}
|
|
||||||
|
|
||||||
func (m *Model) setFocusedPane(pane int) {
|
|
||||||
if !m.canFocusPane(pane) {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
m.focusedPane = pane
|
|
||||||
}
|
|
||||||
|
|
||||||
func (m *Model) canFocusPane(pane int) bool {
|
|
||||||
switch pane {
|
|
||||||
case focusPaneRequests, focusPaneDetails:
|
|
||||||
return true
|
|
||||||
case focusPaneEvents:
|
|
||||||
return m.showEvents
|
|
||||||
case focusPaneStd:
|
|
||||||
return m.showStd
|
|
||||||
default:
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (m *Model) ensureFocusedPaneVisible() {
|
|
||||||
if m.canFocusPane(m.focusedPane) {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
if m.canFocusPane(focusPaneDetails) {
|
|
||||||
m.focusedPane = focusPaneDetails
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
m.focusedPane = focusPaneRequests
|
|
||||||
}
|
|
||||||
|
|
||||||
func (m *Model) clampPaneScrolls() {
|
|
||||||
if m.requestScroll < 0 {
|
|
||||||
m.requestScroll = 0
|
|
||||||
}
|
|
||||||
if m.detailsScroll < 0 {
|
|
||||||
m.detailsScroll = 0
|
|
||||||
}
|
|
||||||
if m.eventsScroll < 0 {
|
|
||||||
m.eventsScroll = 0
|
|
||||||
}
|
|
||||||
if m.stdScroll < 0 {
|
|
||||||
m.stdScroll = 0
|
|
||||||
}
|
|
||||||
|
|
||||||
m.ensureRequestCursorVisible()
|
|
||||||
|
|
||||||
maxDetails := m.maxDetailsScroll()
|
|
||||||
if m.detailsScroll > maxDetails {
|
|
||||||
m.detailsScroll = maxDetails
|
|
||||||
}
|
|
||||||
|
|
||||||
maxEvents := m.maxEventsScroll()
|
|
||||||
if m.eventsScroll > maxEvents {
|
|
||||||
m.eventsScroll = maxEvents
|
|
||||||
}
|
|
||||||
|
|
||||||
maxStd := m.maxStdScroll()
|
|
||||||
if m.stdScroll > maxStd {
|
|
||||||
m.stdScroll = maxStd
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (m Model) panelHeights() (requestHeight, detailsHeight, eventsHeight, stdHeight int) {
|
|
||||||
const constHeightOffset = 1
|
|
||||||
|
|
||||||
requestHeight = max(0, m.height-constHeightOffset)
|
|
||||||
detailsHeight = max(0, m.height-constHeightOffset)
|
|
||||||
eventsHeight = max(0, int(float64(m.height)*0.2))
|
|
||||||
stdHeight = max(0, int(float64(m.height)*0.2))
|
|
||||||
|
|
||||||
if m.showSearch {
|
|
||||||
requestHeight = max(0, requestHeight-1)
|
|
||||||
detailsHeight = max(0, detailsHeight-1)
|
|
||||||
}
|
|
||||||
|
|
||||||
if m.showEvents {
|
|
||||||
requestHeight = max(0, requestHeight-eventsHeight)
|
|
||||||
detailsHeight = max(0, detailsHeight-eventsHeight)
|
|
||||||
}
|
|
||||||
|
|
||||||
if m.showStd {
|
|
||||||
requestHeight = max(0, requestHeight-stdHeight)
|
|
||||||
detailsHeight = max(0, detailsHeight-stdHeight)
|
|
||||||
}
|
|
||||||
|
|
||||||
return requestHeight, detailsHeight, eventsHeight, stdHeight
|
|
||||||
}
|
|
||||||
|
|
||||||
func (m Model) requestBodyHeight() int {
|
|
||||||
requestHeight, _, _, _ := m.panelHeights()
|
|
||||||
return max(0, requestHeight-2)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (m Model) detailsBodyHeight() int {
|
|
||||||
_, detailsHeight, _, _ := m.panelHeights()
|
|
||||||
return max(0, detailsHeight-2)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (m Model) detailsPaneWidth() int {
|
|
||||||
requestWidth := max(0, int(float64(m.width)*0.5))
|
|
||||||
return max(0, m.width-requestWidth)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (m Model) maxDetailsScroll() int {
|
|
||||||
bodyHeight := m.detailsBodyHeight()
|
|
||||||
if bodyHeight <= 0 {
|
|
||||||
return 0
|
|
||||||
}
|
|
||||||
|
|
||||||
total := m.detailsContentLineCount(m.detailsPaneWidth())
|
|
||||||
return max(0, total-bodyHeight)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (m Model) maxEventsScroll() int {
|
|
||||||
if !m.showEvents {
|
|
||||||
return 0
|
|
||||||
}
|
|
||||||
|
|
||||||
_, _, eventsHeight, _ := m.panelHeights()
|
|
||||||
bodyHeight := max(0, eventsHeight-1)
|
|
||||||
if bodyHeight <= 0 {
|
|
||||||
return 0
|
|
||||||
}
|
|
||||||
|
|
||||||
return max(0, m.eventCount()-bodyHeight)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (m Model) maxStdScroll() int {
|
|
||||||
if !m.showStd {
|
|
||||||
return 0
|
|
||||||
}
|
|
||||||
|
|
||||||
_, _, _, stdHeight := m.panelHeights()
|
|
||||||
bodyHeight := max(0, stdHeight-1)
|
|
||||||
if bodyHeight <= 0 {
|
|
||||||
return 0
|
|
||||||
}
|
|
||||||
|
|
||||||
return max(0, m.stdLogCount()-bodyHeight)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (m Model) hasPendingRequests() bool {
|
func (m Model) hasPendingRequests() bool {
|
||||||
// Traverse backward to be a bit more efficient, the most recent requests are more
|
// Traverse backward to be a bit more efficient, the most recent requests are more
|
||||||
// like to be pending.
|
// like to be pending.
|
||||||
|
|||||||
@ -1,245 +0,0 @@
|
|||||||
package tui
|
|
||||||
|
|
||||||
import (
|
|
||||||
"strings"
|
|
||||||
"testing"
|
|
||||||
"time"
|
|
||||||
|
|
||||||
"github.com/google/uuid"
|
|
||||||
"termtap.dev/internal/model"
|
|
||||||
)
|
|
||||||
|
|
||||||
func TestFormatDuration(t *testing.T) {
|
|
||||||
t.Parallel()
|
|
||||||
|
|
||||||
tests := []struct {
|
|
||||||
name string
|
|
||||||
in time.Duration
|
|
||||||
want string
|
|
||||||
}{
|
|
||||||
{name: "pending zero", in: 0, want: "PENDING"},
|
|
||||||
{name: "microseconds", in: 750 * time.Microsecond, want: "750us"},
|
|
||||||
{name: "milliseconds", in: 20 * time.Millisecond, want: "20ms"},
|
|
||||||
{name: "seconds", in: 11 * time.Second, want: "11.00s"},
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, tt := range tests {
|
|
||||||
tt := tt
|
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
|
||||||
t.Parallel()
|
|
||||||
if got := formatDuration(tt.in); got != tt.want {
|
|
||||||
t.Fatalf("formatDuration(%v) = %q, want %q", tt.in, got, tt.want)
|
|
||||||
}
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestTruncate(t *testing.T) {
|
|
||||||
t.Parallel()
|
|
||||||
|
|
||||||
tests := []struct {
|
|
||||||
name string
|
|
||||||
s string
|
|
||||||
max int
|
|
||||||
want string
|
|
||||||
}{
|
|
||||||
{name: "short unchanged", s: "abc", max: 3, want: "abc"},
|
|
||||||
{name: "max small no ellipsis", s: "abcdef", max: 3, want: "abc"},
|
|
||||||
{name: "ellipsis", s: "abcdef", max: 5, want: "ab..."},
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, tt := range tests {
|
|
||||||
tt := tt
|
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
|
||||||
t.Parallel()
|
|
||||||
if got := truncate(tt.s, tt.max); got != tt.want {
|
|
||||||
t.Fatalf("truncate(%q,%d) = %q, want %q", tt.s, tt.max, got, tt.want)
|
|
||||||
}
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestClampRendered(t *testing.T) {
|
|
||||||
t.Parallel()
|
|
||||||
|
|
||||||
if got := clampRendered("abcdef", 0); got != "" {
|
|
||||||
t.Fatalf("clampRendered max=0 = %q, want empty", got)
|
|
||||||
}
|
|
||||||
if got := clampRendered("abc", 10); got != "abc" {
|
|
||||||
t.Fatalf("clampRendered no truncation = %q, want %q", got, "abc")
|
|
||||||
}
|
|
||||||
if got := clampRendered("abcdef", 4); !strings.Contains(got, "...") {
|
|
||||||
t.Fatalf("clampRendered truncation should include ellipsis, got %q", got)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestGetEventColor(t *testing.T) {
|
|
||||||
t.Parallel()
|
|
||||||
|
|
||||||
theme := newTheme()
|
|
||||||
tests := []struct {
|
|
||||||
name string
|
|
||||||
typ model.EventType
|
|
||||||
want string
|
|
||||||
}{
|
|
||||||
{name: "session", typ: model.EventTypeSessionStarted, want: theme.EventSession.Render("x")},
|
|
||||||
{name: "proxy", typ: model.EventTypeProxyStarted, want: theme.EventProxy.Render("x")},
|
|
||||||
{name: "request in flight", typ: model.EventTypeRequestStarted, want: theme.EventRequestInFlight.Render("x")},
|
|
||||||
{name: "request success", typ: model.EventTypeRequestFinished, want: theme.EventSuccess.Render("x")},
|
|
||||||
{name: "warn", typ: model.EventTypeWarn, want: theme.EventWarn.Render("x")},
|
|
||||||
{name: "error", typ: model.EventTypeRequestFailed, want: theme.EventError.Render("x")},
|
|
||||||
{name: "fatal", typ: model.EventTypeFatal, want: theme.EventFatal.Render("x")},
|
|
||||||
{name: "default", typ: model.EventType("UnknownType"), want: theme.EventDefault.Render("x")},
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, tt := range tests {
|
|
||||||
tt := tt
|
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
|
||||||
t.Parallel()
|
|
||||||
got := getEventColor(theme, tt.typ).Render("x")
|
|
||||||
if got != tt.want {
|
|
||||||
t.Fatalf("unexpected style for %s", tt.typ)
|
|
||||||
}
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestViewAndPaneStructure(t *testing.T) {
|
|
||||||
t.Parallel()
|
|
||||||
|
|
||||||
m := NewModel(make(chan model.Event), Controls{})
|
|
||||||
|
|
||||||
t.Run("view returns raw pane when unset size", func(t *testing.T) {
|
|
||||||
t.Parallel()
|
|
||||||
got := m.View()
|
|
||||||
if got != m.renderAppPane() {
|
|
||||||
t.Fatal("View should return raw pane when width/height are unset")
|
|
||||||
}
|
|
||||||
})
|
|
||||||
|
|
||||||
t.Run("renderAppPane line count matches height", func(t *testing.T) {
|
|
||||||
t.Parallel()
|
|
||||||
m2 := m
|
|
||||||
m2.width = 80
|
|
||||||
m2.height = 12
|
|
||||||
got := m2.renderAppPane()
|
|
||||||
if got == "height of request and details did not match" || got == "height of screen does not match terminal height" {
|
|
||||||
t.Fatalf("unexpected renderAppPane invariant error: %q", got)
|
|
||||||
}
|
|
||||||
if lines := strings.Count(got, "\n") + 1; lines != m2.height {
|
|
||||||
t.Fatalf("line count = %d, want %d", lines, m2.height)
|
|
||||||
}
|
|
||||||
})
|
|
||||||
|
|
||||||
t.Run("renderAppPane supports toggles", func(t *testing.T) {
|
|
||||||
t.Parallel()
|
|
||||||
m2 := m
|
|
||||||
m2.width = 90
|
|
||||||
m2.height = 14
|
|
||||||
m2.showEvents = true
|
|
||||||
m2.showStd = true
|
|
||||||
m2.showSearch = true
|
|
||||||
got := m2.renderAppPane()
|
|
||||||
if got == "height of request and details did not match" || got == "height of screen does not match terminal height" {
|
|
||||||
t.Fatalf("unexpected renderAppPane invariant error with toggles: %q", got)
|
|
||||||
}
|
|
||||||
})
|
|
||||||
|
|
||||||
t.Run("view applies configured terminal height", func(t *testing.T) {
|
|
||||||
t.Parallel()
|
|
||||||
m2 := m
|
|
||||||
m2.width = 70
|
|
||||||
m2.height = 10
|
|
||||||
|
|
||||||
got := m2.View()
|
|
||||||
if lines := strings.Count(got, "\n") + 1; lines < m2.height {
|
|
||||||
t.Fatalf("View line count = %d, want at least %d", lines, m2.height)
|
|
||||||
}
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestPaneRenderersAndStatusBar(t *testing.T) {
|
|
||||||
t.Parallel()
|
|
||||||
|
|
||||||
m := NewModel(make(chan model.Event), Controls{})
|
|
||||||
m.width = 100
|
|
||||||
m.height = 12
|
|
||||||
m.requests = []model.Request{
|
|
||||||
{ID: uuid.New(), Method: "GET", Host: "a", URL: "/a", Status: 200, Duration: 5 * time.Millisecond},
|
|
||||||
{ID: uuid.New(), Method: "POST", Host: "b", URL: "/b", Status: 500, Duration: 10 * time.Millisecond, Failed: true},
|
|
||||||
}
|
|
||||||
m.events = []model.Event{
|
|
||||||
{Type: model.EventTypeWarn, Body: "warn"},
|
|
||||||
{Type: model.EventTypeProcessStdout, Body: "out"},
|
|
||||||
{Type: model.EventTypeProcessStderr, Body: "err"},
|
|
||||||
}
|
|
||||||
|
|
||||||
status := m.renderStatusBar(100)
|
|
||||||
if !strings.Contains(status, "2 reqs") || !strings.Contains(status, "1 err") {
|
|
||||||
t.Fatalf("status bar missing expected counters: %q", status)
|
|
||||||
}
|
|
||||||
|
|
||||||
search := m.renderSearchPane(20, 3)
|
|
||||||
if len(search) != 3 {
|
|
||||||
t.Fatalf("search pane len = %d, want 3", len(search))
|
|
||||||
}
|
|
||||||
for i, line := range search {
|
|
||||||
if len(line) != 20 {
|
|
||||||
t.Fatalf("search pane line %d len = %d, want %d", i, len(line), 20)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
requests := m.renderRequestPane(50, 4)
|
|
||||||
if len(requests) != 4 {
|
|
||||||
t.Fatalf("request pane len = %d, want 4", len(requests))
|
|
||||||
}
|
|
||||||
|
|
||||||
details := m.renderDetailsPane(30, 4)
|
|
||||||
if len(details) != 4 {
|
|
||||||
t.Fatalf("details pane len = %d, want 4", len(details))
|
|
||||||
}
|
|
||||||
|
|
||||||
events := m.renderEventsPane(60, 4)
|
|
||||||
if len(events) != 4 {
|
|
||||||
t.Fatalf("events pane len = %d, want 4", len(events))
|
|
||||||
}
|
|
||||||
if strings.Contains(strings.Join(events, "\n"), "out") || strings.Contains(strings.Join(events, "\n"), "err") {
|
|
||||||
t.Fatal("events pane should filter stdout/stderr events")
|
|
||||||
}
|
|
||||||
|
|
||||||
std := m.renderStdPane(60, 4)
|
|
||||||
if len(std) != 4 {
|
|
||||||
t.Fatalf("std pane len = %d, want 4", len(std))
|
|
||||||
}
|
|
||||||
joined := strings.Join(std, "\n")
|
|
||||||
if !strings.Contains(joined, "out") || !strings.Contains(joined, "err") {
|
|
||||||
t.Fatal("std pane should include stdout/stderr logs")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestRenderEventsPane_ErrorAndPIDBranches(t *testing.T) {
|
|
||||||
t.Parallel()
|
|
||||||
|
|
||||||
m := NewModel(make(chan model.Event), Controls{})
|
|
||||||
m.events = []model.Event{
|
|
||||||
{Type: model.EventTypeWarn, Body: "old"},
|
|
||||||
{Type: model.EventTypeRequestFailed, Body: "failed body", PID: 123, Time: time.Now()},
|
|
||||||
{Type: model.EventTypeFatal, Body: "fatal body", Time: time.Now()},
|
|
||||||
}
|
|
||||||
|
|
||||||
lines := m.renderEventsPane(60, 3)
|
|
||||||
if len(lines) != 3 {
|
|
||||||
t.Fatalf("events pane len = %d, want 3", len(lines))
|
|
||||||
}
|
|
||||||
|
|
||||||
joined := strings.Join(lines, "\n")
|
|
||||||
if !strings.Contains(joined, "123") {
|
|
||||||
t.Fatalf("expected PID to appear in events pane, got: %q", joined)
|
|
||||||
}
|
|
||||||
if !strings.Contains(joined, "failed body") {
|
|
||||||
t.Fatalf("expected failed body to appear in events pane, got: %q", joined)
|
|
||||||
}
|
|
||||||
if !strings.Contains(joined, "fatal body") {
|
|
||||||
t.Fatalf("expected fatal body to appear in events pane, got: %q", joined)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
Loading…
x
Reference in New Issue
Block a user