diff --git a/.opencode/agents/review.md b/.opencode/agents/review.md index 02a64af..04321fb 100644 --- a/.opencode/agents/review.md +++ b/.opencode/agents/review.md @@ -11,6 +11,7 @@ permission: "git log*": allow "git *": allow "grep *": allow + "go *": allow webfetch: deny color: "#e01da6" --- diff --git a/doc/event-pressure-notes.md b/doc/event-pressure-notes.md new file mode 100644 index 0000000..02585b4 --- /dev/null +++ b/doc/event-pressure-notes.md @@ -0,0 +1,42 @@ +# Event Pressure Notes + +This is a quick note on potential event-channel pressure in the current proxy architecture. + +## Why this matters + +Proxy request handling currently emits events synchronously into a shared channel. +If producers are faster than the consumer (TUI/event loop), the channel can fill and block producers. +When that happens, request handling can stall even if upstream/downstream network paths are healthy. + +## Where pressure comes from + +- Every request can produce multiple lifecycle events (`started`, `finished`, `failed`, warnings). +- CONNECT + MITM flow can emit both tunnel-level and inner-request events. +- Bursty traffic (many small requests, retries, connection churn) amplifies event rate quickly. + +## User-visible symptoms + +- Request latency spikes that do not match upstream timings. +- Intermittent pauses during high traffic. +- Shutdown/restart feeling delayed when many events are in flight. + +## Current risk profile + +- Channel buffer size helps absorb bursts, but only up to a point. +- Backpressure is currently coupled to request path execution, so stalls propagate into proxy behavior. + +## Mitigation options + +1. Introduce non-blocking event enqueue for low-priority events. + - Keep critical events blocking (fatal/start/stop), but drop or coalesce high-volume request events under load. +2. Add an internal event relay. + - Proxy handlers write to a local buffered queue; a dedicated goroutine forwards to the main channel. +3. Coalesce repetitive events. + - Aggregate similar warnings or per-interval request counters instead of per-request chatter. +4. Add lightweight metrics. + - Track dropped/coalesced events and queue depth so pressure is visible during development. + +## Practical near-term suggestion + +Start with a small event relay + drop policy for non-critical request events when queue depth is high. +This contains proxy-path stalls without changing the external event model too much. diff --git a/internal/app/proxy.go b/internal/app/proxy.go index 51674fc..7da3ce8 100644 --- a/internal/app/proxy.go +++ b/internal/app/proxy.go @@ -20,6 +20,20 @@ func StartProxy(ps *model.ProxyServer, ch chan<- model.Event) { Body: fmt.Sprintf("proxy server started on %s", (*ps.Listener).Addr().String()), } + if ps.CAReady && !ps.CATrusted { + body := fmt.Sprintf("HTTPS interception CA available at %s; trust this certificate to inspect HTTPS traffic", ps.CACertPath) + eventType := model.EventTypeWarn + if ps.CACreated { + body = fmt.Sprintf("generated HTTPS interception CA at %s; trust this certificate to inspect HTTPS traffic", ps.CACertPath) + } + + ch <- model.Event{ + Time: time.Now().Local(), + Type: eventType, + Body: body, + } + } + if err := ps.Server.Serve(*ps.Listener); err != nil { if errors.Is(err, http.ErrServerClosed) { return diff --git a/internal/cli/run.go b/internal/cli/run.go index 8fef507..0f6fcbd 100644 --- a/internal/cli/run.go +++ b/internal/cli/run.go @@ -4,9 +4,12 @@ import ( "fmt" "log" "os" + "runtime" + "strconv" "termtap.dev/internal/app" "termtap.dev/internal/model" + "termtap.dev/internal/proxy" "termtap.dev/internal/tui" ) @@ -14,6 +17,11 @@ import ( const proxy_addr = "127.0.0.1:8080" func Run(args []string) { + if len(args) >= 2 && args[1] == "cert" { + runCert() + return + } + cmd, ok := parseCommand(args) if !ok { displayHelp() @@ -55,8 +63,53 @@ func parseCommand(args []string) (model.Command, bool) { func displayHelp() { helpText := ` usage: + tap cert tap run -- [args...] ` fmt.Fprintln(os.Stderr, helpText) } + +func runCert() { + ca, err := proxy.EnsureCertificateAuthority() + if err != nil { + log.Fatalln(err) + } + + certPath := ca.CertPath() + quotedCertPath := strconv.Quote(certPath) + fmt.Printf("Certificate path: %s\n", certPath) + if ca.WasCreated() { + fmt.Println("Created a new local HTTPS interception CA.") + } else { + fmt.Println("Using existing local HTTPS interception CA.") + } + + trusted, err := ca.IsTrustedBySystem() + if err != nil { + fmt.Printf("System trust check failed: %v\n", err) + } else if trusted { + fmt.Println("System trust store: trusted") + } else { + fmt.Println("System trust store: not trusted") + } + + if runtime.GOOS != "linux" { + fmt.Println("Install this certificate into your OS or client trust store to inspect HTTPS traffic.") + return + } + + fmt.Println() + fmt.Println("Trust instructions (Linux):") + fmt.Println("Debian/Ubuntu:") + fmt.Printf(" sudo cp %s /usr/local/share/ca-certificates/termtap.crt\n", quotedCertPath) + fmt.Println(" sudo update-ca-certificates") + fmt.Println("Fedora/RHEL/CentOS:") + fmt.Printf(" sudo cp %s /etc/pki/ca-trust/source/anchors/termtap.crt\n", quotedCertPath) + fmt.Println(" sudo update-ca-trust") + fmt.Println("Arch:") + fmt.Printf(" sudo trust anchor %s\n", quotedCertPath) + fmt.Println() + fmt.Println("Quick curl test:") + fmt.Printf(" curl --proxy http://%s --cacert %s https://example.com\n", proxy_addr, quotedCertPath) +} diff --git a/internal/model/proxy.go b/internal/model/proxy.go index ebb1e4e..974796f 100644 --- a/internal/model/proxy.go +++ b/internal/model/proxy.go @@ -4,6 +4,7 @@ import ( "net" "net/http" "net/url" + "sync" "time" "github.com/google/uuid" @@ -13,6 +14,14 @@ type ProxyServer struct { Listener *net.Listener Server *http.Server Url string + + CACertPath string + CAReady bool + CACreated bool + CATrusted bool + + ConnMu sync.Mutex + Conns map[net.Conn]struct{} } type Request struct { diff --git a/internal/process/runner.go b/internal/process/runner.go index 18b7ad7..95e400f 100644 --- a/internal/process/runner.go +++ b/internal/process/runner.go @@ -59,7 +59,7 @@ func injectEnv(proc *exec.Cmd, addr string) { injected := []string{ "HTTP_PROXY=" + proxyAddr, "http_proxy=" + proxyAddr, - "HTTPS_PROXY=" + proxyAddr, // TODO: HTTP NOT SUPPORTED + "HTTPS_PROXY=" + proxyAddr, "https_proxy=" + proxyAddr, // "ALL_PROXY=" + proxyAddr, // "all_proxy=" + proxyAddr, diff --git a/internal/proxy/buffer.go b/internal/proxy/buffer.go new file mode 100644 index 0000000..c713953 --- /dev/null +++ b/internal/proxy/buffer.go @@ -0,0 +1,37 @@ +package proxy + +import ( + "bufio" + "io" + "net" +) + +type bufferedConn struct { + net.Conn + reader io.Reader +} + +func (c *bufferedConn) Read(p []byte) (int, error) { + return c.reader.Read(p) +} + +func wrapBufferedConn(conn net.Conn, readWriter *bufio.ReadWriter) net.Conn { + if readWriter == nil { + return conn + } + + return &bufferedConn{Conn: conn, reader: readWriter} +} + +type previewReadCloser struct { + io.ReadCloser + preview *bodyPreview +} + +func (p *previewReadCloser) Read(data []byte) (int, error) { + n, err := p.ReadCloser.Read(data) + if n > 0 { + p.preview.Write(data[:n]) + } + return n, err +} diff --git a/internal/proxy/certs.go b/internal/proxy/certs.go new file mode 100644 index 0000000..524074a --- /dev/null +++ b/internal/proxy/certs.go @@ -0,0 +1,335 @@ +package proxy + +import ( + "crypto/ecdsa" + "crypto/elliptic" + "crypto/rand" + "crypto/tls" + "crypto/x509" + "crypto/x509/pkix" + "encoding/pem" + "errors" + "fmt" + "math/big" + "net" + "os" + "path/filepath" + "strings" + "sync" + "time" +) + +const ( + caDirName = "termtap" + caCertName = "mitm-ca-cert.pem" + caKeyName = "mitm-ca-key.pem" + caValidFor = 10 * 365 * 24 * time.Hour + leafValidFor = 7 * 24 * time.Hour + maxLeafCerts = 256 +) + +type CertificateAuthority struct { + cert *x509.Certificate + key *ecdsa.PrivateKey + certPath string + keyPath string + wasCreated bool + + mu sync.Mutex + leafCert map[string]*tls.Certificate + leafOrder []string +} + +func loadOrCreateCertificateAuthority() (*CertificateAuthority, error) { + configDir, err := os.UserConfigDir() + if err != nil { + return nil, fmt.Errorf("resolve user config dir: %w", err) + } + + baseDir := filepath.Join(configDir, caDirName) + if err := os.MkdirAll(baseDir, 0o700); err != nil { + return nil, fmt.Errorf("create cert dir: %w", err) + } + + ca := &CertificateAuthority{ + certPath: filepath.Join(baseDir, caCertName), + keyPath: filepath.Join(baseDir, caKeyName), + leafCert: make(map[string]*tls.Certificate), + } + + if _, err := os.Stat(ca.certPath); err == nil { + if _, err := os.Stat(ca.keyPath); err == nil { + if err := ca.load(); err != nil { + return nil, err + } + return ca, nil + } + } + + if err := ca.create(); err != nil { + return nil, err + } + + ca.wasCreated = true + return ca, nil +} + +func (ca *CertificateAuthority) load() error { + certPEM, err := os.ReadFile(ca.certPath) + if err != nil { + return fmt.Errorf("read ca cert: %w", err) + } + + keyPEM, err := os.ReadFile(ca.keyPath) + if err != nil { + return fmt.Errorf("read ca key: %w", err) + } + + certBlock, _ := pem.Decode(certPEM) + if certBlock == nil { + return fmt.Errorf("decode ca cert pem") + } + + cert, err := x509.ParseCertificate(certBlock.Bytes) + if err != nil { + return fmt.Errorf("parse ca cert: %w", err) + } + + keyBlock, _ := pem.Decode(keyPEM) + if keyBlock == nil { + return fmt.Errorf("decode ca key pem") + } + + key, err := x509.ParseECPrivateKey(keyBlock.Bytes) + if err != nil { + return fmt.Errorf("parse ca key: %w", err) + } + + ca.cert = cert + ca.key = key + return nil +} + +func (ca *CertificateAuthority) create() error { + key, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader) + if err != nil { + return fmt.Errorf("generate ca key: %w", err) + } + + serial, err := randSerialNumber() + if err != nil { + return err + } + + now := time.Now() + tmpl := &x509.Certificate{ + SerialNumber: serial, + Subject: pkix.Name{ + CommonName: "termtap Local MITM CA", + Organization: []string{"termtap"}, + }, + NotBefore: now.Add(-1 * time.Hour), + NotAfter: now.Add(caValidFor), + KeyUsage: x509.KeyUsageDigitalSignature | x509.KeyUsageKeyEncipherment | x509.KeyUsageCertSign | x509.KeyUsageCRLSign, + ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth}, + BasicConstraintsValid: true, + IsCA: true, + MaxPathLen: 1, + } + + der, err := x509.CreateCertificate(rand.Reader, tmpl, tmpl, &key.PublicKey, key) + if err != nil { + return fmt.Errorf("create ca cert: %w", err) + } + + certPEM := pem.EncodeToMemory(&pem.Block{Type: "CERTIFICATE", Bytes: der}) + keyDER, err := x509.MarshalECPrivateKey(key) + if err != nil { + return fmt.Errorf("marshal ca key: %w", err) + } + keyPEM := pem.EncodeToMemory(&pem.Block{Type: "EC PRIVATE KEY", Bytes: keyDER}) + + if err := writeFileAtomically(ca.certPath, certPEM, 0o600); err != nil { + return fmt.Errorf("write ca cert: %w", err) + } + if err := writeFileAtomically(ca.keyPath, keyPEM, 0o600); err != nil { + return fmt.Errorf("write ca key: %w", err) + } + + cert, err := x509.ParseCertificate(der) + if err != nil { + return fmt.Errorf("parse created ca cert: %w", err) + } + + ca.cert = cert + ca.key = key + return nil +} + +func (ca *CertificateAuthority) CertificateForHost(host string) (*tls.Certificate, error) { + host = normalizeCertHost(host) + if host == "" { + return nil, fmt.Errorf("empty host for certificate") + } + + ca.mu.Lock() + defer ca.mu.Unlock() + + if cert, ok := ca.leafCert[host]; ok { + return cert, nil + } + + serial, err := randSerialNumber() + if err != nil { + return nil, err + } + + leafKey, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader) + if err != nil { + return nil, fmt.Errorf("generate leaf key: %w", err) + } + + now := time.Now() + tmpl := &x509.Certificate{ + SerialNumber: serial, + Subject: pkix.Name{ + CommonName: host, + }, + NotBefore: now.Add(-1 * time.Hour), + NotAfter: now.Add(leafValidFor), + KeyUsage: x509.KeyUsageDigitalSignature | x509.KeyUsageKeyEncipherment, + ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth}, + BasicConstraintsValid: true, + } + + if ip := net.ParseIP(host); ip != nil { + tmpl.IPAddresses = []net.IP{ip} + } else { + tmpl.DNSNames = []string{host} + } + + der, err := x509.CreateCertificate(rand.Reader, tmpl, ca.cert, &leafKey.PublicKey, ca.key) + if err != nil { + return nil, fmt.Errorf("create leaf cert: %w", err) + } + + tlsCert := &tls.Certificate{ + Certificate: [][]byte{der, ca.cert.Raw}, + PrivateKey: leafKey, + } + leafParsed, err := x509.ParseCertificate(der) + if err != nil { + return nil, fmt.Errorf("parse leaf cert: %w", err) + } + tlsCert.Leaf = leafParsed + + ca.leafCert[host] = tlsCert + ca.leafOrder = append(ca.leafOrder, host) + if len(ca.leafOrder) > maxLeafCerts { + evicted := ca.leafOrder[0] + ca.leafOrder = ca.leafOrder[1:] + delete(ca.leafCert, evicted) + } + + return tlsCert, nil +} + +func (ca *CertificateAuthority) CertPath() string { + if ca == nil { + return "" + } + return ca.certPath +} + +func (ca *CertificateAuthority) WasCreated() bool { + if ca == nil { + return false + } + return ca.wasCreated +} + +func (ca *CertificateAuthority) IsTrustedBySystem() (bool, error) { + if ca == nil || ca.cert == nil { + return false, fmt.Errorf("certificate authority is unavailable") + } + + roots, err := x509.SystemCertPool() + if err != nil { + return false, fmt.Errorf("load system cert pool: %w", err) + } + if roots == nil { + return false, nil + } + + _, err = ca.cert.Verify(x509.VerifyOptions{Roots: roots}) + if err == nil { + return true, nil + } + + if _, ok := errors.AsType[x509.UnknownAuthorityError](err); ok { + return false, nil + } + + return false, err +} + +func EnsureCertificateAuthority() (*CertificateAuthority, error) { + return loadOrCreateCertificateAuthority() +} + +func randSerialNumber() (*big.Int, error) { + limit := new(big.Int).Lsh(big.NewInt(1), 128) + serial, err := rand.Int(rand.Reader, limit) + if err != nil { + return nil, fmt.Errorf("generate serial number: %w", err) + } + return serial, nil +} + +func normalizeCertHost(hostport string) string { + host := strings.TrimSpace(hostport) + if host == "" { + return "" + } + + if parsedHost, _, err := net.SplitHostPort(host); err == nil { + return parsedHost + } + + return host +} + +func writeFileAtomically(path string, data []byte, perm os.FileMode) error { + dir := filepath.Dir(path) + tmpFile, err := os.CreateTemp(dir, ".termtap-tmp-*") + if err != nil { + return err + } + + tmpPath := tmpFile.Name() + cleanup := true + defer func() { + if cleanup { + _ = os.Remove(tmpPath) + } + }() + + if _, err := tmpFile.Write(data); err != nil { + _ = tmpFile.Close() + return err + } + if err := tmpFile.Chmod(perm); err != nil { + _ = tmpFile.Close() + return err + } + if err := tmpFile.Close(); err != nil { + return err + } + + if err := os.Rename(tmpPath, path); err != nil { + return err + } + + cleanup = false + return nil +} diff --git a/internal/proxy/handler.go b/internal/proxy/handler.go deleted file mode 100644 index 2100eeb..0000000 --- a/internal/proxy/handler.go +++ /dev/null @@ -1,226 +0,0 @@ -package proxy - -import ( - "bytes" - "context" - "errors" - "fmt" - "io" - "net" - "net/http" - "sort" - "strings" - "time" - - "github.com/google/uuid" - "termtap.dev/internal/model" -) - -// NOTE: Much of this code is AI generated, and is not expected to make it into production - -const maxPreviewBytes = 1024 - -func proxyHandler(ch chan<- model.Event) http.Handler { - transport := http.DefaultTransport - - // TODO: This should be wired into the main channel, but that will require a model package - return http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { - if req.Method == http.MethodConnect { - http.Error(w, "CONNECT is not supported yet", http.StatusNotImplemented) - ch <- model.Event{ - Time: time.Now().Local(), - Type: model.EventTypeWarn, - Body: fmt.Sprintf("CONNECT is not supported: %s", req.Host), - } - return - } - - if req.URL.Scheme == "" || req.URL.Host == "" { - http.Error(w, "request must use absolute-form URLs through the proxy", http.StatusBadRequest) - ch <- model.Event{ - Time: time.Now().Local(), - Type: model.EventTypeWarn, - Body: fmt.Sprintf("rejected non-proxy request %s %s", req.Method, req.URL.String()), - } - return - } - - start := time.Now() - - request := model.Request{ - ID: uuid.New(), - ResponseData: []byte{}, - RequestData: []byte{}, - URL: "", - Status: -1, - Method: "", - Duration: 0, - Pending: true, - Failed: false, - StartTime: start, - } - - requestPreview, err := readAndRestoreBody(&req.Body) - if err != nil { - ch <- model.Event{ - Time: time.Now().Local(), - Type: model.EventTypeWarn, - Body: fmt.Sprintf("(%s) failed to read request body: %v", getEndOfUUID(request.ID), err), - Request: request, - } - } else { - request.RequestData = []byte(requestPreview) - } - - outReq := req.Clone(req.Context()) - outReq.RequestURI = "" - - request.URL = outReq.URL.Path - request.QueryString = outReq.URL.RawQuery - request.QueryMap = outReq.URL.Query() - request.Host = outReq.Host - request.Method = outReq.Method - request.RequestHeaders = outReq.Header - request.RawURL = outReq.URL.String() - - ch <- model.Event{ - Time: time.Now().Local(), - Type: model.EventTypeRequestStarted, - Body: fmt.Sprintf("(%s) %s %s", getEndOfUUID(request.ID), request.Method, request.RawURL), - Request: request, - } - - resp, err := transport.RoundTrip(outReq) - if err != nil { - status := statusFromUpstreamError(req, resp, err) - - http.Error(w, http.StatusText(status), status) - request.Pending = false - request.Failed = true - request.Duration = time.Since(start).Round(time.Microsecond) - request.Status = status - - ch <- model.Event{ - Time: time.Now().Local(), - Type: model.EventTypeRequestFailed, - Body: fmt.Sprintf("(%s) upstream error: %v", getEndOfUUID(request.ID), err), - Request: request, - } - return - } - defer resp.Body.Close() - - responsePreview, err := readAndRestoreBody(&resp.Body) - if err != nil { - ch <- model.Event{ - Time: time.Now().Local(), - Type: model.EventTypeWarn, - Body: fmt.Sprintf("(%s) failed to read response body: %v", getEndOfUUID(request.ID), err), - Request: request, - } - } else { - request.ResponseData = []byte(responsePreview) - } - - copyHeader(w.Header(), resp.Header) - w.WriteHeader(resp.StatusCode) - if _, err := io.Copy(w, resp.Body); err != nil { - request.Pending = false - request.Failed = true - request.Duration = time.Since(start).Round(time.Microsecond) - request.Status = resp.StatusCode - - ch <- model.Event{ - Time: time.Now().Local(), - Type: model.EventTypeRequestFailed, - Body: fmt.Sprintf("(%s) failed to write response body: %v", getEndOfUUID(request.ID), err), - } - return - } - - request.Duration = time.Since(start).Round(time.Microsecond) - request.Status = resp.StatusCode - request.ResponseHeaders = resp.Header - request.Pending = false - - ch <- model.Event{ - Time: time.Now().Local(), - Type: model.EventTypeRequestFinished, - Body: fmt.Sprintf("(%s) %s %s %d %dms", getEndOfUUID(request.ID), request.Method, request.RawURL, request.Status, request.Duration.Milliseconds()), - Request: request, - } - }) -} - -func copyHeader(dst, src http.Header) { - for key, values := range src { - for _, value := range values { - dst.Add(key, value) - } - } -} - -func readAndRestoreBody(body *io.ReadCloser) (string, error) { - if body == nil || *body == nil { - return "", nil - } - - payload, err := io.ReadAll(*body) - if err != nil { - return "", err - } - - *body = io.NopCloser(bytes.NewReader(payload)) - - preview := payload - if len(preview) > maxPreviewBytes { - preview = preview[:maxPreviewBytes] - } - - text := strings.ReplaceAll(string(preview), "\n", "\\n") - if len(payload) > maxPreviewBytes { - text += "..." - } - - return text, nil -} - -func formatHeaders(headers http.Header) string { - if len(headers) == 0 { - return "" - } - - parts := make([]string, 0, len(headers)) - for key, values := range headers { - parts = append(parts, fmt.Sprintf("%s=%q", key, strings.Join(values, ","))) - } - sort.Strings(parts) - - return strings.Join(parts, ", ") -} - -func getEndOfUUID(id uuid.UUID) string { - return id.String()[24:] -} - -// BUG: Not sure if this actually works, seems to favor the 502 -func statusFromUpstreamError(req *http.Request, resp *http.Response, err error) int { - if resp != nil { - return resp.StatusCode - } - - if errors.Is(req.Context().Err(), context.Canceled) { - return http.StatusBadGateway - } - - if errors.Is(err, context.DeadlineExceeded) { - return http.StatusGatewayTimeout - } - - var netErr net.Error - if errors.As(err, &netErr) && netErr.Timeout() { - return http.StatusGatewayTimeout - } - - return http.StatusBadGateway -} diff --git a/internal/proxy/handlers.go b/internal/proxy/handlers.go new file mode 100644 index 0000000..a85ceeb --- /dev/null +++ b/internal/proxy/handlers.go @@ -0,0 +1,183 @@ +package proxy + +import ( + "bufio" + "crypto/tls" + "errors" + "fmt" + "io" + "net" + "net/http" + "strings" + "time" + + "termtap.dev/internal/model" +) + +const connectIdleTimeout = 30 * time.Second + +func proxyHandler(ch chan<- model.Event, ca *CertificateAuthority, ps *model.ProxyServer) http.Handler { + transport := newUpstreamTransport() + + return http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { + if req.Method == http.MethodConnect { + handleConnect(w, req, ch, transport, ca, ps) + return + } + + if req.URL.Scheme == "" || req.URL.Host == "" { + http.Error(w, "request must use absolute-form URLs through the proxy", http.StatusBadRequest) + ch <- model.Event{ + Time: time.Now().Local(), + Type: model.EventTypeWarn, + Body: fmt.Sprintf("rejected non-proxy request %s %s", req.Method, req.URL.String()), + } + return + } + + resp, request, responsePreview, err := roundTripCapturedRequest(req, transport, ch, "", false) + if err != nil { + status := statusFromUpstreamError(req, resp, err) + + http.Error(w, http.StatusText(status), status) + failRequest(ch, request, status, fmt.Sprintf("upstream error: %v", err)) + return + } + defer resp.Body.Close() + + copyHeaders(resp.Header, w.Header()) + w.WriteHeader(resp.StatusCode) + if _, err := io.Copy(w, resp.Body); err != nil { + request.ResponseData = responsePreview.Preview() + failRequest(ch, request, resp.StatusCode, fmt.Sprintf("failed to write response body: %v", err)) + return + } + + request.ResponseData = responsePreview.Preview() + finishRequest(ch, request, resp.StatusCode) + }) +} + +func handleConnect(w http.ResponseWriter, req *http.Request, ch chan<- model.Event, transport http.RoundTripper, ca *CertificateAuthority, ps *model.ProxyServer) { + start := time.Now() + + request := newConnectRequest(req, start) + startRequest(ch, request) + + target := req.Host + if !strings.Contains(target, ":") { + target = net.JoinHostPort(target, "443") + } + + if ca == nil { + http.Error(w, "HTTPS interception unavailable", http.StatusBadGateway) + failRequest(ch, request, http.StatusBadGateway, "HTTPS interception certificate authority is unavailable") + return + } + + leafCert, err := ca.CertificateForHost(target) + if err != nil { + http.Error(w, "failed to prepare interception certificate", http.StatusBadGateway) + failRequest(ch, request, http.StatusBadGateway, fmt.Sprintf("failed to mint interception certificate for %s: %v", target, err)) + return + } + + hijacker, ok := w.(http.Hijacker) + if !ok { + http.Error(w, "proxy does not support hijacking", http.StatusInternalServerError) + failRequest(ch, request, http.StatusInternalServerError, "CONNECT hijack is unavailable") + return + } + + clientConn, readWriter, err := hijacker.Hijack() + if err != nil { + http.Error(w, "failed to hijack connection", http.StatusInternalServerError) + failRequest(ch, request, http.StatusInternalServerError, fmt.Sprintf("CONNECT hijack failed: %v", err)) + return + } + trackConnection(ps, clientConn) + defer func() { + untrackConnection(ps, clientConn) + _ = clientConn.Close() + }() + + if err := writeConnectEstablished(clientConn, readWriter); err != nil { + failRequest(ch, request, http.StatusBadGateway, fmt.Sprintf("CONNECT setup failed: %v", err)) + return + } + + mitmConn := wrapBufferedConn(clientConn, readWriter) + tlsConn := tls.Server(mitmConn, &tls.Config{ + Certificates: []tls.Certificate{*leafCert}, + MinVersion: tls.VersionTLS12, + }) + defer tlsConn.Close() + + _ = clientConn.SetDeadline(time.Now().Add(connectIdleTimeout)) + if err := tlsConn.Handshake(); err != nil { + failRequest(ch, request, http.StatusBadGateway, fmt.Sprintf("TLS handshake with client failed: %v", err)) + return + } + _ = clientConn.SetDeadline(time.Time{}) + + reader := bufio.NewReader(tlsConn) + writer := bufio.NewWriter(tlsConn) + + for { + _ = clientConn.SetReadDeadline(time.Now().Add(connectIdleTimeout)) + innerReq, err := http.ReadRequest(reader) + if err != nil { + if errors.Is(err, io.EOF) || errors.Is(err, net.ErrClosed) { + finishRequest(ch, request, http.StatusOK) + return + } + failRequest(ch, request, http.StatusBadGateway, fmt.Sprintf("failed to read decrypted HTTPS request: %v", err)) + return + } + _ = clientConn.SetReadDeadline(time.Time{}) + + resp, captured, responsePreview, err := roundTripCapturedRequest(innerReq, transport, ch, target, true) + if err != nil { + discardAndCloseBody(innerReq.Body) + status := statusFromUpstreamError(innerReq, resp, err) + _ = clientConn.SetWriteDeadline(time.Now().Add(connectIdleTimeout)) + if writeErr := writePlainHTTPError(writer, status); writeErr != nil { + failRequest(ch, captured, status, fmt.Sprintf("upstream error: %v", err)) + failRequest(ch, request, http.StatusBadGateway, fmt.Sprintf("failed to write HTTPS error response: %v", writeErr)) + return + } + _ = clientConn.SetWriteDeadline(time.Time{}) + failRequest(ch, captured, status, fmt.Sprintf("upstream error: %v", err)) + failRequest(ch, request, status, fmt.Sprintf("closing CONNECT tunnel after upstream error: %v", err)) + return + } + + _ = clientConn.SetWriteDeadline(time.Now().Add(connectIdleTimeout)) + if err := resp.Write(writer); err != nil { + resp.Body.Close() + captured.ResponseData = responsePreview.Preview() + failRequest(ch, captured, resp.StatusCode, fmt.Sprintf("failed to write HTTPS response: %v", err)) + failRequest(ch, request, http.StatusBadGateway, fmt.Sprintf("failed to write HTTPS response: %v", err)) + return + } + + if err := writer.Flush(); err != nil { + _ = clientConn.SetWriteDeadline(time.Time{}) + resp.Body.Close() + captured.ResponseData = responsePreview.Preview() + failRequest(ch, captured, resp.StatusCode, fmt.Sprintf("failed to flush HTTPS response: %v", err)) + failRequest(ch, request, http.StatusBadGateway, fmt.Sprintf("failed to flush HTTPS response: %v", err)) + return + } + _ = clientConn.SetWriteDeadline(time.Time{}) + + captured.ResponseData = responsePreview.Preview() + finishRequest(ch, captured, resp.StatusCode) + shouldClose := innerReq.Close || resp.Close + resp.Body.Close() + if shouldClose { + finishRequest(ch, request, http.StatusOK) + return + } + } +} diff --git a/internal/proxy/headers.go b/internal/proxy/headers.go new file mode 100644 index 0000000..5a23a07 --- /dev/null +++ b/internal/proxy/headers.go @@ -0,0 +1,67 @@ +package proxy + +import ( + "net/http" + "strings" +) + +var sensitiveHeaders = map[string]struct{}{ + "Authorization": {}, + "Cookie": {}, + "Proxy-Authorization": {}, + "Set-Cookie": {}, + "X-Api-Key": {}, +} + +var hopByHopHeaders = []string{ + "Connection", + "Keep-Alive", + "Proxy-Authenticate", + "Proxy-Authorization", + "Proxy-Connection", + "Te", + "Trailer", + "Transfer-Encoding", + "Upgrade", +} + +// Remove headers that are only required for client<->proxy and proxy<->server communication. +// Otherwise known as hop-by-hop headers. We do not want to show these to users since they are +// used only for internal functioning for the proxy server. +func stripHopByHopHeaders(headers http.Header) { + if headers == nil { + return + } + + connectionValues := append([]string(nil), headers.Values("Connection")...) + for _, key := range hopByHopHeaders { + headers.Del(key) + } + + for _, value := range connectionValues { + for key := range strings.SplitSeq(value, ",") { + headers.Del(strings.TrimSpace(key)) + } + } +} + +// Return a new set of headers that has sensitive headers redacted. +// +// TODO: Maybe use '***' length of header? +func redactHeaders(headers http.Header) http.Header { + clone := headers.Clone() + for key := range clone { + if _, ok := sensitiveHeaders[http.CanonicalHeaderKey(key)]; ok { + clone.Set(key, "[REDACTED]") + } + } + return clone +} + +func copyHeaders(src, dest http.Header) { + for key, values := range src { + for _, value := range values { + dest.Add(key, value) + } + } +} diff --git a/internal/proxy/preview.go b/internal/proxy/preview.go new file mode 100644 index 0000000..20d4ae2 --- /dev/null +++ b/internal/proxy/preview.go @@ -0,0 +1,50 @@ +package proxy + +import ( + "bytes" + "strings" +) + +const maxPreviewBytes = 1024 * 64 // 64 kb (maybe we want 256kb) + +type bodyPreview struct { + enabled bool + truncated bool + buf bytes.Buffer +} + +func newBodyPreview(contentType string) *bodyPreview { + return &bodyPreview{enabled: canDisplayContent(contentType)} +} + +func (p *bodyPreview) Write(data []byte) { + if p == nil || !p.enabled || len(data) == 0 { + return + } + + remaining := maxPreviewBytes - p.buf.Len() + if remaining <= 0 { + p.truncated = true + return + } + + if len(data) > remaining { + data = data[:remaining] + p.truncated = true + } + + _, _ = p.buf.Write(data) +} + +func (p *bodyPreview) Preview() []byte { + if p == nil || !p.enabled || p.buf.Len() == 0 { + return []byte{} + } + + text := strings.ReplaceAll(p.buf.String(), "\n", "\\n") + if p.truncated { + text += "..." + } + + return []byte(text) +} diff --git a/internal/proxy/requests.go b/internal/proxy/requests.go new file mode 100644 index 0000000..e21f0c5 --- /dev/null +++ b/internal/proxy/requests.go @@ -0,0 +1,128 @@ +package proxy + +import ( + "fmt" + "net/http" + "time" + + "github.com/google/uuid" + "termtap.dev/internal/model" +) + +func roundTripCapturedRequest(req *http.Request, transport http.RoundTripper, ch chan<- model.Event, defaultHost string, interceptedTLS bool) (*http.Response, model.Request, *bodyPreview, error) { + start := time.Now() + request := model.Request{ + ID: uuid.New(), + ResponseData: []byte{}, + RequestData: []byte{}, + URL: "", + Status: -1, + Method: "", + Duration: 0, + Pending: true, + Failed: false, + StartTime: start, + } + + outReq := req.Clone(req.Context()) + outReq.RequestURI = "" + if interceptedTLS { + if outReq.URL.Scheme == "" { + outReq.URL.Scheme = "https" + } + if outReq.URL.Host == "" { + outReq.URL.Host = defaultHost + } + if outReq.Host == "" { + outReq.Host = defaultHost + } + } + stripHopByHopHeaders(outReq.Header) + requestPreview := newBodyPreview(outReq.Header.Get("Content-Type")) + if outReq.Body != nil { + outReq.Body = &previewReadCloser{ReadCloser: outReq.Body, preview: requestPreview} + } + + request.URL = outReq.URL.Path + request.QueryString = outReq.URL.RawQuery + request.QueryMap = outReq.URL.Query() + request.Host = outReq.Host + request.Method = outReq.Method + request.RequestHeaders = redactHeaders(outReq.Header) + request.RawURL = outReq.URL.String() + if request.RawURL == "" { + request.RawURL = outReq.Host + outReq.URL.RequestURI() + } + + startRequest(ch, request) + + resp, err := transport.RoundTrip(outReq) + request.RequestData = requestPreview.Preview() + if err != nil { + return resp, request, nil, err + } + + stripHopByHopHeaders(resp.Header) + responsePreview := newBodyPreview(resp.Header.Get("Content-Type")) + if resp.Body != nil { + resp.Body = &previewReadCloser{ReadCloser: resp.Body, preview: responsePreview} + } + + request.ResponseHeaders = redactHeaders(resp.Header) + return resp, request, responsePreview, nil +} + +func newConnectRequest(req *http.Request, start time.Time) model.Request { + // CONNECT requests do not have as much data, which is why we use Host for most of the pieces + return model.Request{ + ID: uuid.New(), + ResponseData: []byte{}, + RequestData: []byte{}, + URL: req.Host, + RawURL: req.Host, + Host: req.Host, + Status: -1, + Method: req.Method, + Duration: 0, + Pending: true, + Failed: false, + StartTime: start, + } +} + +func finishRequest(ch chan<- model.Event, request model.Request, status int) { + request.Pending = false + request.Failed = false + request.Status = status + request.Duration = time.Since(request.StartTime).Round(time.Microsecond) + + ch <- model.Event{ + Time: time.Now().Local(), + Type: model.EventTypeRequestFinished, + Body: fmt.Sprintf("(%s) %s %s %d %dms", getEndOfUUID(request.ID), request.Method, request.RawURL, request.Status, request.Duration.Milliseconds()), + Request: request, + } +} + +func failRequest(ch chan<- model.Event, request model.Request, status int, body string) { + request.Pending = false + request.Failed = true + request.Status = status + request.Duration = time.Since(request.StartTime).Round(time.Microsecond) + + ch <- model.Event{ + Time: time.Now().Local(), + Type: model.EventTypeRequestFailed, + Body: fmt.Sprintf("(%s) %s", getEndOfUUID(request.ID), body), + Request: request, + } +} + +func startRequest(ch chan<- model.Event, request model.Request) { + ch <- model.Event{ + Time: time.Now().Local(), + Type: model.EventTypeRequestStarted, + Body: fmt.Sprintf("(%s) %s %s", getEndOfUUID(request.ID), request.Method, request.RawURL), + Request: request, + } +} diff --git a/internal/proxy/secure_utils.go b/internal/proxy/secure_utils.go new file mode 100644 index 0000000..fead3da --- /dev/null +++ b/internal/proxy/secure_utils.go @@ -0,0 +1,30 @@ +package proxy + +import ( + "bufio" + "io" + "net" +) + +const maxDiscardBodyBytes = 1 << 20 + +func writeConnectEstablished(conn net.Conn, readWriter *bufio.ReadWriter) error { + if readWriter != nil { + if _, err := readWriter.WriteString("HTTP/1.1 200 Connection Established\r\n\r\n"); err != nil { + return err + } + return readWriter.Flush() + } + + _, err := conn.Write([]byte("HTTP/1.1 200 Connection Established\r\n\r\n")) + return err +} + +func discardAndCloseBody(body io.ReadCloser) { + if body == nil { + return + } + + _, _ = io.Copy(io.Discard, io.LimitReader(body, maxDiscardBodyBytes)) + _ = body.Close() +} diff --git a/internal/proxy/server.go b/internal/proxy/server.go index 33b85e3..4508676 100644 --- a/internal/proxy/server.go +++ b/internal/proxy/server.go @@ -10,7 +10,22 @@ import ( "termtap.dev/internal/model" ) +const ( + proxyReadHeaderTimeout = 10 * time.Second + proxyIdleTimeout = 30 * time.Second +) + func NewProxyServer(addr string, ch chan<- model.Event) (*model.ProxyServer, error) { + ca, err := loadOrCreateCertificateAuthority() + if err != nil { + return nil, err + } + + trusted, err := ca.IsTrustedBySystem() + if err != nil { + trusted = false + } + listener, err := net.Listen("tcp", addr) if err != nil { return nil, err @@ -19,9 +34,18 @@ func NewProxyServer(addr string, ch chan<- model.Event) (*model.ProxyServer, err url := fmt.Sprintf("http://%s", listener.Addr().String()) ps := &model.ProxyServer{ - Listener: &listener, - Server: &http.Server{Handler: proxyHandler(ch)}, - Url: url, + Listener: &listener, + Url: url, + CACertPath: ca.CertPath(), + CAReady: true, + CACreated: ca.WasCreated(), + CATrusted: trusted, + Conns: make(map[net.Conn]struct{}), + } + ps.Server = &http.Server{ + Handler: proxyHandler(ch, ca, ps), + ReadHeaderTimeout: proxyReadHeaderTimeout, + IdleTimeout: proxyIdleTimeout, } return ps, nil @@ -33,11 +57,53 @@ func Destroy(ps *model.ProxyServer, ch chan<- model.Event) { defer cancel() if ps != nil && ps.Server != nil { + closeTrackedConnections(ps) _ = ps.Server.Shutdown(ctx) ch <- model.Event{ Time: time.Now().Local(), - Type: model.EventTypeProxyStarted, + Type: model.EventTypeProxyStopped, Body: "proxy server was destroyed", } } } + +func trackConnection(ps *model.ProxyServer, conn net.Conn) { + if ps == nil || conn == nil { + return + } + + ps.ConnMu.Lock() + defer ps.ConnMu.Unlock() + ps.Conns[conn] = struct{}{} +} + +func untrackConnection(ps *model.ProxyServer, conn net.Conn) { + if ps == nil || conn == nil { + return + } + + ps.ConnMu.Lock() + defer ps.ConnMu.Unlock() + delete(ps.Conns, conn) +} + +func closeTrackedConnections(ps *model.ProxyServer) { + if ps == nil { + return + } + + // Get all of the connections while claiming the mutex. + // Then close the mutex to allow access to the server object quicker. + // Then a loop can run to close the connections, without needing access + // to the server's mutex. + ps.ConnMu.Lock() + conns := make([]net.Conn, 0, len(ps.Conns)) + for conn := range ps.Conns { + conns = append(conns, conn) + } + ps.ConnMu.Unlock() + + for _, conn := range conns { + _ = conn.Close() + } +} diff --git a/internal/proxy/utils.go b/internal/proxy/utils.go new file mode 100644 index 0000000..376de02 --- /dev/null +++ b/internal/proxy/utils.go @@ -0,0 +1,111 @@ +package proxy + +import ( + "bufio" + "context" + "errors" + "fmt" + "io" + "net" + "net/http" + "sort" + "strings" + + "github.com/google/uuid" +) + +var validContentTypes = []string{ + "application/graphql", + "application/javascript", + "application/json", + "application/x-www-form-urlencoded", + "application/xml", + "+json", + "+xml", +} + +func canDisplayContent(contentType string) bool { + if contentType == "" { + return false + } + + contentType = strings.ToLower(contentType) + if strings.HasPrefix(contentType, "text/") { + return true + } + + for _, t := range validContentTypes { + if strings.Contains(contentType, t) { + return true + } + } + + return false +} + +// NOTE: Currently unused, will be reference for the future header rendering +func formatHeaders(headers http.Header) string { + if len(headers) == 0 { + return "" + } + + parts := make([]string, 0, len(headers)) + for key, values := range headers { + parts = append(parts, fmt.Sprintf("%s=%q", key, strings.Join(values, ","))) + } + sort.Strings(parts) + + return strings.Join(parts, ", ") +} + +func getEndOfUUID(id uuid.UUID) string { + return id.String()[24:] +} + +// BUG: Not sure if this actually works, seems to favor the 502 +func statusFromUpstreamError(req *http.Request, resp *http.Response, err error) int { + if resp != nil { + return resp.StatusCode + } + + if errors.Is(req.Context().Err(), context.Canceled) { + return http.StatusBadGateway + } + + if errors.Is(err, context.DeadlineExceeded) { + return http.StatusGatewayTimeout + } + + var netErr net.Error + if errors.As(err, &netErr) && netErr.Timeout() { + return http.StatusGatewayTimeout + } + + return http.StatusBadGateway +} + +func newUpstreamTransport() http.RoundTripper { + transport := http.DefaultTransport.(*http.Transport).Clone() + transport.Proxy = nil + return transport +} + +func writePlainHTTPError(w *bufio.Writer, status int) error { + resp := &http.Response{ + StatusCode: status, + Status: fmt.Sprintf("%d %s", status, http.StatusText(status)), + Proto: "HTTP/1.1", + ProtoMajor: 1, + ProtoMinor: 1, + Header: make(http.Header), + Body: io.NopCloser(strings.NewReader(http.StatusText(status))), + ContentLength: int64(len(http.StatusText(status))), + Close: false, + } + resp.Header.Set("Content-Type", "text/plain; charset=utf-8") + resp.Header.Set("Content-Length", fmt.Sprintf("%d", len(http.StatusText(status)))) + if err := resp.Write(w); err != nil { + return err + } + return w.Flush() +} diff --git a/internal/tui/panes.go b/internal/tui/panes.go index 3851419..3fa3c06 100644 --- a/internal/tui/panes.go +++ b/internal/tui/panes.go @@ -23,7 +23,7 @@ func (m Model) renderStatusBar(w int) string { avg := int(msSum) / max(1, len(m.requests)) left := fmt.Sprintf(" tap %3d reqs | %d err | avg %dms", len(m.requests), errCount, avg) - right := "j/k nav / search tab panel e events o output r/^r restart q quit " + right := "j/k nav / search tab panel e events o output ^r restart q quit " spaceSize := max(w-(len(left)+len(right)), 0) space := strings.Repeat(" ", spaceSize) @@ -57,7 +57,7 @@ func (m Model) renderRequestPane(w, h int) []string { left := fmt.Sprintf( " %-7s %-24s %s", strings.ToUpper(req.Method), - req.Host, + truncate(req.Host, 24), req.URL, ) right := fmt.Sprintf( @@ -98,6 +98,7 @@ func (m Model) renderDetailsPane(w, h int) []string { for y := range lines { lines[y] = m.theme.Text.Render(strings.Repeat(" ", w)) } + return lines } diff --git a/internal/tui/update.go b/internal/tui/update.go index ef57001..dd276b0 100644 --- a/internal/tui/update.go +++ b/internal/tui/update.go @@ -27,7 +27,7 @@ func (m Model) Update(msg tea.Msg) (tea.Model, tea.Cmd) { switch msg.String() { case "ctrl+c", "q": return m, tea.Quit - case "r", tea.KeyCtrlR.String(): + case tea.KeyCtrlR.String(): if m.restarting { return m, nil } @@ -95,6 +95,10 @@ func (m *Model) applyMessage(msg model.Event) { } func (m *Model) createRequest(req model.Request) { + if req.Method == "CONNECT" { + return + } + m.requests = append(m.requests, req) // If we passed the max, delete the first one