termtap/internal/proxy/integration_https_mitm_test.go
Hayden Hargreaves 002773e77f test: AI generated all of these tests
Just for the MVP of course. Need to validate the idea.
2026-04-23 19:47:04 -07:00

302 lines
8.6 KiB
Go

package proxy
import (
"bufio"
"crypto/tls"
"crypto/x509"
"errors"
"io"
"net"
"net/http"
"strings"
"testing"
"time"
"termtap.dev/internal/model"
)
// hijackPipeWriter is a minimal http.ResponseWriter + http.Hijacker test
// double. Hijack hands out conn (one end of a net.Pipe in these tests) so the
// test can drive the other end as the client.
//
// Headers and status codes set before hijacking are recorded; body bytes
// passed to Write are discarded.
type hijackPipeWriter struct {
	conn   net.Conn    // connection returned by Hijack
	header http.Header // lazily allocated by Header
	code   int         // last status passed to WriteHeader
}

// Compile-time checks that the double implements the interfaces it stands in for.
var (
	_ http.ResponseWriter = (*hijackPipeWriter)(nil)
	_ http.Hijacker       = (*hijackPipeWriter)(nil)
)

// Header lazily allocates the header map and returns it; later calls return
// the same map, so mutations persist.
func (w *hijackPipeWriter) Header() http.Header {
	if w.header == nil {
		w.header = make(http.Header)
	}
	return w.header
}

// WriteHeader records statusCode; nothing is written to conn.
func (w *hijackPipeWriter) WriteHeader(statusCode int) {
	w.code = statusCode
}

// Write reports that all of p was written without recording or forwarding it.
func (w *hijackPipeWriter) Write(p []byte) (int, error) {
	return len(p), nil
}

// Hijack returns the wrapped connection together with a buffered
// reader/writer over it. net/http's server always returns a non-nil
// *bufio.ReadWriter, so the double does too. Returning nil would make any code
// under test that touches the buffered pair panic.
func (w *hijackPipeWriter) Hijack() (net.Conn, *bufio.ReadWriter, error) {
	rw := bufio.NewReadWriter(bufio.NewReader(w.conn), bufio.NewWriter(w.conn))
	return w.conn, rw, nil
}
// startMITMConnect drives handleConnect through a successful CONNECT exchange
// and TLS handshake over an in-memory net.Pipe. It returns:
//   - clientConn: the raw client end of the pipe,
//   - tlsClient:  a TLS client over clientConn whose handshake has completed,
//     verified against the test CA for "example.com",
//   - ch:         the event channel passed to handleConnect,
//   - handleDone: closed once handleConnect returns.
//
// transport is passed to handleConnect unchanged. The pipe ends and the TLS
// client are closed via t.Cleanup. A deadline on clientConn bounds all
// client-side reads/writes, so a misbehaving handler fails the test instead
// of hanging it.
func startMITMConnect(t *testing.T, transport http.RoundTripper) (net.Conn, *tls.Conn, chan model.Event, chan struct{}) {
	t.Helper()

	// net.Pipe has no default timeout; without a deadline a stuck handler
	// blocks the client until go test's global -timeout fires.
	const ioDeadline = 10 * time.Second

	ca := newTestCA(t)
	clientConn, serverConn := net.Pipe()
	// Registered first so it runs last (t.Cleanup is LIFO). It also covers
	// every early t.Fatalf below.
	t.Cleanup(func() {
		_ = clientConn.Close()
		_ = serverConn.Close()
	})
	if err := clientConn.SetDeadline(time.Now().Add(ioDeadline)); err != nil {
		t.Fatalf("clientConn.SetDeadline() error = %v", err)
	}

	writer := &hijackPipeWriter{conn: serverConn}
	req, err := http.NewRequest(http.MethodConnect, "http://example.com:443", nil)
	if err != nil {
		t.Fatalf("NewRequest(CONNECT) error = %v", err)
	}
	req.Host = "example.com:443"

	ch := make(chan model.Event, 16)
	ps := &model.ProxyServer{Conns: make(map[net.Conn]struct{})}
	handleDone := make(chan struct{})
	go func() {
		handleConnect(writer, req, ch, transport, ca, ps)
		close(handleDone)
	}()

	reader := bufio.NewReader(clientConn)
	// The body is deliberately not closed. Without Content-Length on an
	// HTTP/1.0-style reply it would be read-until-EOF, and Close would block
	// draining the tunnel.
	connectResp, err := http.ReadResponse(reader, &http.Request{Method: http.MethodConnect})
	if err != nil {
		t.Fatalf("ReadResponse(CONNECT established) error = %v", err)
	}
	if connectResp.StatusCode != http.StatusOK {
		t.Fatalf("CONNECT established status = %d, want %d", connectResp.StatusCode, http.StatusOK)
	}
	// The TLS client below reads clientConn directly. Anything the bufio.Reader
	// already pulled past the CONNECT response would be silently lost.
	if n := reader.Buffered(); n != 0 {
		t.Fatalf("CONNECT response followed by %d unexpected buffered bytes", n)
	}

	pool := x509.NewCertPool()
	pool.AddCert(ca.cert)
	tlsClient := tls.Client(clientConn, &tls.Config{
		ServerName: "example.com",
		RootCAs:    pool,
		MinVersion: tls.VersionTLS12,
	})
	// Registered after the pipe cleanup so it runs first (sending close_notify
	// if the handshake completed) before the pipe is torn down.
	t.Cleanup(func() {
		_ = tlsClient.Close()
	})
	if err := tlsClient.Handshake(); err != nil {
		t.Fatalf("tls handshake error = %v", err)
	}
	return clientConn, tlsClient, ch, handleDone
}
// TestHTTPSE2E_MITMHandleConnectFlow drives one full MITM round trip: CONNECT,
// TLS handshake, one decrypted inner request forwarded to the mock upstream,
// and the upstream's 201 response relayed back through the tunnel. It then
// expects two "started" events followed by two "finished" events.
func TestHTTPSE2E_MITMHandleConnectFlow(t *testing.T) {
	transport := &mockTransport{fn: func(r *http.Request) (*http.Response, error) {
		// RoundTrip is invoked on handleConnect's goroutine, not the test
		// goroutine. t.Fatal there would only stop that goroutine and could
		// leave the test blocked on the tunnel. Record the failure with t.Error
		// and fail the round trip instead.
		if r.URL.Scheme != "https" {
			t.Errorf("inner request scheme = %q, want https", r.URL.Scheme)
			return nil, errors.New("unexpected inner request scheme")
		}
		if r.URL.Host == "" {
			t.Error("inner request host should be populated")
			return nil, errors.New("inner request host not populated")
		}
		return &http.Response{
			StatusCode: http.StatusCreated,
			ProtoMajor: 1,
			ProtoMinor: 1,
			Header:     http.Header{"Content-Type": {"text/plain"}},
			Body:       io.NopCloser(strings.NewReader("mitm-ok")),
		}, nil
	}}
	clientConn, tlsClient, ch, handleDone := startMITMConnect(t, transport)

	// Send a plain HTTP request through the established TLS tunnel and read
	// the proxy's response from the same tunnel.
	innerReq, err := http.NewRequest(http.MethodGet, "https://example.com/inside", nil)
	if err != nil {
		t.Fatalf("NewRequest(inner) error = %v", err)
	}
	innerReq.Host = "example.com"
	// Only one request is sent; "Connection: close" signals no more follow.
	innerReq.Close = true
	if err := innerReq.Write(tlsClient); err != nil {
		t.Fatalf("innerReq.Write() error = %v", err)
	}
	innerResp, err := http.ReadResponse(bufio.NewReader(tlsClient), innerReq)
	if err != nil {
		t.Fatalf("ReadResponse(inner) error = %v", err)
	}
	defer innerResp.Body.Close()
	body, err := io.ReadAll(innerResp.Body)
	if err != nil {
		t.Fatalf("ReadAll(inner body) error = %v", err)
	}
	if innerResp.StatusCode != http.StatusCreated {
		t.Fatalf("inner status = %d, want %d", innerResp.StatusCode, http.StatusCreated)
	}
	if got, want := string(body), "mitm-ok"; got != want {
		t.Fatalf("inner body = %q, want %q", got, want)
	}

	_ = tlsClient.Close()
	_ = clientConn.Close()
	select {
	case <-handleDone:
	case <-time.After(3 * time.Second):
		t.Fatal("timeout waiting for handleConnect to return")
	}

	events := drainEvents(t, ch, 4, 2*time.Second)
	if events[0].Type != model.EventTypeRequestStarted {
		t.Fatalf("event[0] = %s, want %s", events[0].Type, model.EventTypeRequestStarted)
	}
	if events[1].Type != model.EventTypeRequestStarted {
		t.Fatalf("event[1] = %s, want %s", events[1].Type, model.EventTypeRequestStarted)
	}
	if events[2].Type != model.EventTypeRequestFinished {
		t.Fatalf("event[2] = %s, want %s", events[2].Type, model.EventTypeRequestFinished)
	}
	if events[3].Type != model.EventTypeRequestFinished {
		t.Fatalf("event[3] = %s, want %s", events[3].Type, model.EventTypeRequestFinished)
	}
}
// TestHTTPSE2E_MITMUpstreamErrorReturnsHTTPErrorInsideTunnel checks that when
// the upstream transport fails, the proxy answers the decrypted request with
// 502 Bad Gateway inside the TLS tunnel. The emitted events must be two
// "started" followed by two "failed".
func TestHTTPSE2E_MITMUpstreamErrorReturnsHTTPErrorInsideTunnel(t *testing.T) {
	failingUpstream := &mockTransport{fn: func(*http.Request) (*http.Response, error) {
		return nil, errors.New("upstream exploded")
	}}
	rawConn, tunnel, events, done := startMITMConnect(t, failingUpstream)

	req, err := http.NewRequest(http.MethodGet, "https://example.com/fail", nil)
	if err != nil {
		t.Fatalf("NewRequest(inner) error = %v", err)
	}
	req.Host = "example.com"
	if err = req.Write(tunnel); err != nil {
		t.Fatalf("innerReq.Write() error = %v", err)
	}

	resp, err := http.ReadResponse(bufio.NewReader(tunnel), req)
	if err != nil {
		t.Fatalf("ReadResponse(inner) error = %v", err)
	}
	defer resp.Body.Close()
	if resp.StatusCode != http.StatusBadGateway {
		t.Fatalf("inner status = %d, want %d", resp.StatusCode, http.StatusBadGateway)
	}

	_ = tunnel.Close()
	_ = rawConn.Close()
	select {
	case <-done:
	case <-time.After(3 * time.Second):
		t.Fatal("timeout waiting for handleConnect to return")
	}

	got := drainEvents(t, events, 4, 2*time.Second)
	sequenceOK := got[0].Type == model.EventTypeRequestStarted &&
		got[1].Type == model.EventTypeRequestStarted &&
		got[2].Type == model.EventTypeRequestFailed &&
		got[3].Type == model.EventTypeRequestFailed
	if !sequenceOK {
		t.Fatalf("unexpected event sequence: %#v", got)
	}
}
// TestHTTPSE2E_MITMHandshakeFailureEmitsConnectFailed completes the CONNECT
// exchange, then writes plaintext instead of a TLS ClientHello. It checks that
// handleConnect returns and emits a "started"/"failed" event pair whose failure
// body mentions the client-side TLS handshake.
func TestHTTPSE2E_MITMHandshakeFailureEmitsConnectFailed(t *testing.T) {
	ca := newTestCA(t)
	clientConn, serverConn := net.Pipe()
	t.Cleanup(func() {
		_ = clientConn.Close()
		_ = serverConn.Close()
	})
	// net.Pipe has no built-in timeout: an unread Write or unanswered Read
	// blocks forever. Bound client-side I/O so a stuck handler fails fast.
	if err := clientConn.SetDeadline(time.Now().Add(10 * time.Second)); err != nil {
		t.Fatalf("clientConn.SetDeadline() error = %v", err)
	}

	writer := &hijackPipeWriter{conn: serverConn}
	req, err := http.NewRequest(http.MethodConnect, "http://example.com:443", nil)
	if err != nil {
		t.Fatalf("NewRequest(CONNECT) error = %v", err)
	}
	req.Host = "example.com:443"
	ch := make(chan model.Event, 16)
	ps := &model.ProxyServer{Conns: make(map[net.Conn]struct{})}
	transport := &mockTransport{fn: func(*http.Request) (*http.Response, error) {
		// Runs on handleConnect's goroutine, so report with t.Error (t.Fatal
		// must only be called from the test goroutine).
		t.Error("upstream transport called despite failed client handshake")
		return nil, errors.New("should not be called")
	}}

	handleDone := make(chan struct{})
	go func() {
		handleConnect(writer, req, ch, transport, ca, ps)
		close(handleDone)
	}()

	reader := bufio.NewReader(clientConn)
	connectResp, err := http.ReadResponse(reader, &http.Request{Method: http.MethodConnect})
	if err != nil {
		t.Fatalf("ReadResponse(CONNECT established) error = %v", err)
	}
	if connectResp.StatusCode != http.StatusOK {
		t.Fatalf("CONNECT established status = %d, want %d", connectResp.StatusCode, http.StatusOK)
	}

	// Write plaintext (not a TLS handshake) to force handshake failure.
	if _, err := clientConn.Write([]byte("not tls handshake")); err != nil {
		t.Fatalf("client write error = %v", err)
	}
	_ = clientConn.Close()
	select {
	case <-handleDone:
	case <-time.After(3 * time.Second):
		t.Fatal("timeout waiting for handleConnect to return on handshake failure")
	}

	events := drainEvents(t, ch, 2, 2*time.Second)
	if events[0].Type != model.EventTypeRequestStarted || events[1].Type != model.EventTypeRequestFailed {
		t.Fatalf("unexpected event sequence: %#v", events)
	}
	if !strings.Contains(events[1].Body, "TLS handshake with client failed") {
		t.Fatalf("failed event body = %q, want handshake failure details", events[1].Body)
	}
}
// TestHTTPSE2E_MITMDecryptedReadFailureEmitsConnectFailed sends an unparseable
// request line through an established tunnel, so reading the decrypted request
// fails with something other than EOF. It checks that a "started"/"failed"
// event pair is emitted carrying the read-failure details.
func TestHTTPSE2E_MITMDecryptedReadFailureEmitsConnectFailed(t *testing.T) {
	unusedUpstream := &mockTransport{fn: func(*http.Request) (*http.Response, error) {
		return nil, errors.New("should not be called")
	}}
	rawConn, tunnel, events, done := startMITMConnect(t, unusedUpstream)

	// "BAD REQUEST" is not a valid "METHOD URI PROTO" request line, which
	// exercises the non-EOF ReadRequest error branch.
	malformed := []byte("BAD REQUEST\r\n\r\n")
	if _, err := tunnel.Write(malformed); err != nil {
		t.Fatalf("tlsClient.Write() error = %v", err)
	}
	_ = tunnel.Close()
	_ = rawConn.Close()

	select {
	case <-done:
	case <-time.After(3 * time.Second):
		t.Fatal("timeout waiting for handleConnect to return on decrypted read failure")
	}

	got := drainEvents(t, events, 2, 2*time.Second)
	sequenceOK := got[0].Type == model.EventTypeRequestStarted &&
		got[1].Type == model.EventTypeRequestFailed
	if !sequenceOK {
		t.Fatalf("unexpected event sequence: %#v", got)
	}
	if failBody := got[1].Body; !strings.Contains(failBody, "failed to read decrypted HTTPS request") {
		t.Fatalf("failed event body = %q, want read failure details", failBody)
	}
}