// termtap/internal/proxy/server_test.go
// Hayden Hargreaves 002773e77f test: AI generated all of these tests
// Just for the MVP of course. Need to validate the idea.
// 2026-04-23 19:47:04 -07:00
//
// 236 lines
// 6.1 KiB
// Go

package proxy
import (
"context"
"net"
"net/http"
"os"
"path/filepath"
"strings"
"sync"
"testing"
"time"
"termtap.dev/internal/model"
)
// NOTE: Run these tests with -race; they cover concurrent state.

// closeTrackingConn is a minimal net.Conn test double that records whether
// Close has been called.
//
// Behavior:
//   - Read always fails with net.ErrClosed.
//   - Write reports that every byte was written without storing anything.
//   - Address getters return zero TCP addresses.
//   - Deadline setters are no-ops.
//
// The closed flag is guarded by mu, so isClosed may be queried from a
// different goroutine than the one that calls Close.
type closeTrackingConn struct {
	mu     sync.Mutex
	closed bool
}

// Compile-time check that the fake satisfies net.Conn, so a signature drift
// fails here instead of at the first place the fake is used as a net.Conn.
var _ net.Conn = (*closeTrackingConn)(nil)

// Read never yields data; it reports the connection as closed.
func (c *closeTrackingConn) Read(_ []byte) (int, error) { return 0, net.ErrClosed }

// Write discards b and reports a full, successful write.
func (c *closeTrackingConn) Write(b []byte) (int, error) { return len(b), nil }

// Close marks the connection as closed. It never fails and may be called
// more than once.
func (c *closeTrackingConn) Close() error {
	c.mu.Lock()
	defer c.mu.Unlock()
	c.closed = true
	return nil
}

func (c *closeTrackingConn) LocalAddr() net.Addr                { return &net.TCPAddr{} }
func (c *closeTrackingConn) RemoteAddr() net.Addr               { return &net.TCPAddr{} }
func (c *closeTrackingConn) SetDeadline(_ time.Time) error      { return nil }
func (c *closeTrackingConn) SetReadDeadline(_ time.Time) error  { return nil }
func (c *closeTrackingConn) SetWriteDeadline(_ time.Time) error { return nil }

// isClosed reports whether Close has been called at least once.
func (c *closeTrackingConn) isClosed() bool {
	c.mu.Lock()
	defer c.mu.Unlock()
	return c.closed
}
// TestNewProxyServer_Success builds a proxy against a throwaway config root
// and checks that the listener, HTTP server, URL, CA material and connection
// set are all populated.
func TestNewProxyServer_Success(t *testing.T) {
	root := t.TempDir()
	t.Setenv("XDG_CONFIG_HOME", root)

	events := make(chan model.Event, 16)
	server, err := NewProxyServer("127.0.0.1:0", events)
	if err != nil {
		t.Fatalf("NewProxyServer() error = %v", err)
	}
	t.Cleanup(func() { Destroy(server, events) })

	// Networking pieces; each case is only reached when the earlier ones
	// passed, so the nested field accesses are nil-safe.
	switch {
	case server.Listener == nil || *server.Listener == nil:
		t.Fatal("Listener is nil")
	case server.Server == nil:
		t.Fatal("Server is nil")
	case server.Server.Handler == nil:
		t.Fatal("Server.Handler is nil")
	}
	if proxyURL := server.Url; !strings.HasPrefix(proxyURL, "http://") {
		t.Fatalf("Url = %q, want http:// prefix", proxyURL)
	}

	// CA material: ready, non-empty path, located under the temp config root,
	// and actually present on disk.
	if !server.CAReady {
		t.Fatal("CAReady = false, want true")
	}
	certPath := server.CACertPath
	if certPath == "" {
		t.Fatal("CACertPath should not be empty")
	}
	wantDir := filepath.Join(root, caDirName)
	if gotDir := filepath.Dir(certPath); gotDir != wantDir {
		t.Fatalf("CACertPath dir = %q, want %q", gotDir, wantDir)
	}
	if _, statErr := os.Stat(certPath); statErr != nil {
		t.Fatalf("expected CA cert on disk, stat error = %v", statErr)
	}

	if server.Conns == nil {
		t.Fatal("Conns map is nil")
	}
}
// TestNewProxyServer_CACreatedFlagAcrossRuns creates two proxies against the
// same config root, one after the other. Only the first should report that it
// generated the CA; the second should reuse what the first left behind.
func TestNewProxyServer_CACreatedFlagAcrossRuns(t *testing.T) {
	t.Setenv("XDG_CONFIG_HOME", t.TempDir())
	events := make(chan model.Event, 16)

	first, err := NewProxyServer("127.0.0.1:0", events)
	if err != nil {
		t.Fatalf("first NewProxyServer() error = %v", err)
	}
	// Tear the first instance down before starting the second; the
	// assertions below only read fields already captured on it.
	Destroy(first, events)

	second, err := NewProxyServer("127.0.0.1:0", events)
	if err != nil {
		t.Fatalf("second NewProxyServer() error = %v", err)
	}
	t.Cleanup(func() { Destroy(second, events) })

	switch {
	case !first.CACreated:
		t.Fatal("first NewProxyServer should report CACreated=true")
	case second.CACreated:
		t.Fatal("second NewProxyServer should report CACreated=false when CA already exists")
	}
}
// TestNewProxyServer_ErrorWhenCASetupFails clears both XDG_CONFIG_HOME and
// HOME. This presumably leaves no directory in which to store the CA, and
// expects construction to fail with a nil server.
func TestNewProxyServer_ErrorWhenCASetupFails(t *testing.T) {
	for _, name := range []string{"XDG_CONFIG_HOME", "HOME"} {
		t.Setenv(name, "")
	}

	events := make(chan model.Event, 8)
	server, err := NewProxyServer("127.0.0.1:0", events)
	if err == nil {
		// Release whatever was unexpectedly built before failing.
		Destroy(server, events)
		t.Fatal("NewProxyServer() error = nil, want non-nil")
	}
	if server != nil {
		t.Fatalf("proxy server = %#v, want nil", server)
	}
}
// TestNewProxyServer_ErrorWhenListenFails occupies a local port first, then
// asks the proxy to bind the same address. It expects an error and a nil
// server.
func TestNewProxyServer_ErrorWhenListenFails(t *testing.T) {
	t.Setenv("XDG_CONFIG_HOME", t.TempDir())

	// Hold the port open for the rest of the test so the proxy cannot bind it.
	blocker, err := net.Listen("tcp", "127.0.0.1:0")
	if err != nil {
		t.Fatalf("net.Listen() error = %v", err)
	}
	t.Cleanup(func() { _ = blocker.Close() })

	events := make(chan model.Event, 8)
	server, err := NewProxyServer(blocker.Addr().String(), events)
	if err == nil {
		// Release whatever was unexpectedly built before failing.
		Destroy(server, events)
		t.Fatal("NewProxyServer() error = nil, want non-nil")
	}
	if server != nil {
		t.Fatalf("proxy server = %#v, want nil", server)
	}
}
// TestDestroy covers two cases:
//   - Destroy accepts a nil server.
//   - Destroy publishes a ProxyStopped event when a server is present.
func TestDestroy(t *testing.T) {
	t.Parallel()

	subtests := []struct {
		name string
		run  func(t *testing.T)
	}{
		{
			name: "nil-safe",
			run: func(t *testing.T) {
				// Passes as long as this call returns without panicking.
				Destroy(nil, make(chan model.Event, 1))
			},
		},
		{
			name: "emits ProxyStopped when server exists",
			run: func(t *testing.T) {
				events := make(chan model.Event, 2)
				server := &model.ProxyServer{Server: &http.Server{}, Conns: make(map[net.Conn]struct{})}
				Destroy(server, events)

				deadline := time.NewTimer(time.Second)
				defer deadline.Stop()
				select {
				case ev := <-events:
					if ev.Type != model.EventTypeProxyStopped {
						t.Fatalf("event type = %s, want %s", ev.Type, model.EventTypeProxyStopped)
					}
				case <-deadline.C:
					t.Fatal("timeout waiting for ProxyStopped event")
				}
			},
		},
	}

	for _, st := range subtests {
		body := st.run // per-iteration copy for pre-Go-1.22 loop semantics
		t.Run(st.name, func(t *testing.T) {
			t.Parallel()
			body(t)
		})
	}
}
func TestConnectionTrackingHelpers(t *testing.T) {
t.Parallel()
t.Run("track and untrack are nil-safe", func(t *testing.T) {
t.Parallel()
trackConnection(nil, nil)
untrackConnection(nil, nil)
closeTrackedConnections(nil)
})
t.Run("track/untrack mutate connection set", func(t *testing.T) {
t.Parallel()
ps := &model.ProxyServer{Conns: make(map[net.Conn]struct{})}
c := &closeTrackingConn{}
trackConnection(ps, c)
if _, ok := ps.Conns[c]; !ok {
t.Fatal("connection not tracked")
}
untrackConnection(ps, c)
if _, ok := ps.Conns[c]; ok {
t.Fatal("connection still tracked after untrack")
}
})
t.Run("closeTrackedConnections closes all tracked conns", func(t *testing.T) {
t.Parallel()
ps := &model.ProxyServer{Conns: make(map[net.Conn]struct{})}
c1 := &closeTrackingConn{}
c2 := &closeTrackingConn{}
ps.Conns[c1] = struct{}{}
ps.Conns[c2] = struct{}{}
closeTrackedConnections(ps)
if !c1.isClosed() || !c2.isClosed() {
t.Fatalf("expected all tracked conns closed, got c1=%v c2=%v", c1.isClosed(), c2.isClosed())
}
})
}
// TestDestroy_ShutdownsServerContext checks that Destroy stops a live
// *http.Server stored on the ProxyServer.
//
// It watches the Serve goroutine directly. Serve returns
// http.ErrServerClosed once the server has been shut down or closed, so
// seeing that return after Destroy shows that Destroy itself stopped the
// server.
func TestDestroy_ShutdownsServerContext(t *testing.T) {
	t.Parallel()
	ln, err := net.Listen("tcp", "127.0.0.1:0")
	if err != nil {
		t.Fatalf("listen error = %v", err)
	}
	srv := &http.Server{Handler: http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { _, _ = w.Write([]byte("ok")) })}

	// Buffered so the Serve goroutine can always deliver its result and exit,
	// even if the test has already failed and stopped receiving.
	serveErr := make(chan error, 1)
	go func() { serveErr <- srv.Serve(ln) }()

	// Safety net: if Destroy leaves the server running, shut it down when the
	// test ends so the Serve goroutine and listener do not leak. If Destroy
	// already stopped the server, this call has nothing left to do.
	t.Cleanup(func() {
		ctx, cancel := context.WithTimeout(context.Background(), time.Second)
		defer cancel()
		_ = srv.Shutdown(ctx)
	})

	ch := make(chan model.Event, 2)
	ps := &model.ProxyServer{Server: srv, Conns: make(map[net.Conn]struct{})}
	Destroy(ps, ch)

	select {
	case err := <-serveErr:
		if err != http.ErrServerClosed {
			t.Fatalf("Serve() returned %v after Destroy, want %v", err, http.ErrServerClosed)
		}
	case <-time.After(time.Second):
		t.Fatal("server still serving 1s after Destroy; want it shut down")
	}
}