From 30ead5d22ce6af07bed3ffbf4105f3f0422623bf Mon Sep 17 00:00:00 2001 From: Hayden Hargreaves Date: Thu, 23 Apr 2026 20:27:32 -0700 Subject: [PATCH] fix: added a bit more headers Still not perfect, but they're better. --- internal/proxy/headers.go | 48 ++++++++++++++++++++++++-- internal/proxy/headers_test.go | 61 +++++++++++++++++++++++++++++++++ internal/proxy/requests.go | 6 ++-- internal/proxy/requests_test.go | 7 ++-- 4 files changed, 116 insertions(+), 6 deletions(-) diff --git a/internal/proxy/headers.go b/internal/proxy/headers.go index 5a23a07..781925c 100644 --- a/internal/proxy/headers.go +++ b/internal/proxy/headers.go @@ -1,6 +1,7 @@ package proxy import ( + "fmt" "net/http" "strings" ) @@ -26,8 +27,7 @@ var hopByHopHeaders = []string{ } // Remove headers that are only required for client<->proxy and proxy<->server communication. -// Otherwise known as hop-by-hop headers. We do not want to show these to users since they are -// used only for internal functioning for the proxy server. +// Otherwise known as hop-by-hop headers. func stripHopByHopHeaders(headers http.Header) { if headers == nil { return @@ -45,6 +45,50 @@ func stripHopByHopHeaders(headers http.Header) { } } +func captureRequestHeaders(req *http.Request) http.Header { + if req == nil { + return http.Header{} + } + + headers := req.Header.Clone() + + host := strings.TrimSpace(req.Host) + if host == "" && req.URL != nil { + host = strings.TrimSpace(req.URL.Host) + } + if host != "" { + headers.Set("Host", host) + } + + if req.ContentLength > 0 && headers.Get("Content-Length") == "" { + headers.Set("Content-Length", fmt.Sprintf("%d", req.ContentLength)) + } + + if len(req.TransferEncoding) > 0 && headers.Get("Transfer-Encoding") == "" { + headers.Set("Transfer-Encoding", strings.Join(req.TransferEncoding, ", ")) + } + + return headers +} + +func captureResponseHeaders(resp *http.Response) http.Header { + if resp == nil { + return http.Header{} + } + + headers := resp.Header.Clone() + + if resp.ContentLength > 0 && headers.Get("Content-Length") == "" { + headers.Set("Content-Length", fmt.Sprintf("%d", resp.ContentLength)) + } + + if len(resp.TransferEncoding) > 0 && headers.Get("Transfer-Encoding") == "" { + headers.Set("Transfer-Encoding", strings.Join(resp.TransferEncoding, ", ")) + } + + return headers +} + // Return a new set of headers that has sensitive headers redacted. // // TODO: Maybe use '***' length of header? diff --git a/internal/proxy/headers_test.go b/internal/proxy/headers_test.go index fd7a85f..bc72118 100644 --- a/internal/proxy/headers_test.go +++ b/internal/proxy/headers_test.go @@ -2,6 +2,7 @@ package proxy import ( "net/http" + "net/url" "reflect" "testing" ) @@ -155,3 +156,63 @@ func TestCopyHeaders(t *testing.T) { t.Fatalf("copyHeaders() dest = %#v, want %#v", dest, want) } } + +func TestCaptureRequestHeaders(t *testing.T) { + t.Parallel() + + u, err := url.Parse("https://url-host.example/path") + if err != nil { + t.Fatalf("url.Parse() error = %v", err) + } + + req := &http.Request{ + URL: u, + Host: "host-header.example", + ContentLength: 42, + TransferEncoding: []string{"chunked"}, + Header: http.Header{ + "Content-Type": {"application/json"}, + }, + } + + got := captureRequestHeaders(req) + + if got.Get("Host") != "host-header.example" { + t.Fatalf("Host = %q, want host-header.example", got.Get("Host")) + } + if got.Get("Content-Length") != "42" { + t.Fatalf("Content-Length = %q, want 42", got.Get("Content-Length")) + } + if got.Get("Transfer-Encoding") != "chunked" { + t.Fatalf("Transfer-Encoding = %q, want chunked", got.Get("Transfer-Encoding")) + } + + got.Set("Content-Type", "modified") + if req.Header.Get("Content-Type") == "modified" { + t.Fatal("captureRequestHeaders() should clone header map") + } +} + +func TestCaptureResponseHeaders(t *testing.T) { + t.Parallel() + + resp := &http.Response{ + Header: http.Header{"Content-Type": {"application/json"}}, + ContentLength: 128, + TransferEncoding: []string{"chunked"}, + } + + got := captureResponseHeaders(resp) + + if got.Get("Content-Length") != "128" { + t.Fatalf("Content-Length = %q, want 128", got.Get("Content-Length")) + } + if got.Get("Transfer-Encoding") != "chunked" { + t.Fatalf("Transfer-Encoding = %q, want chunked", got.Get("Transfer-Encoding")) + } + + got.Set("Content-Type", "modified") + if resp.Header.Get("Content-Type") == "modified" { + t.Fatal("captureResponseHeaders() should clone header map") + } +} diff --git a/internal/proxy/requests.go b/internal/proxy/requests.go index e21f0c5..2ea8d11 100644 --- a/internal/proxy/requests.go +++ b/internal/proxy/requests.go @@ -37,6 +37,7 @@ func roundTripCapturedRequest(req *http.Request, transport http.RoundTripper, ch outReq.Host = defaultHost } } + capturedRequestHeaders := captureRequestHeaders(outReq) stripHopByHopHeaders(outReq.Header) requestPreview := newBodyPreview(outReq.Header.Get("Content-Type")) if outReq.Body != nil { @@ -48,7 +49,7 @@ func roundTripCapturedRequest(req *http.Request, transport http.RoundTripper, ch request.QueryMap = outReq.URL.Query() request.Host = outReq.Host request.Method = outReq.Method - request.RequestHeaders = redactHeaders(outReq.Header) + request.RequestHeaders = redactHeaders(capturedRequestHeaders) request.RawURL = outReq.URL.String() if request.RawURL == "" { request.RawURL = outReq.Host + outReq.URL.RequestURI() @@ -62,13 +63,14 @@ func roundTripCapturedRequest(req *http.Request, transport http.RoundTripper, ch return resp, request, nil, err } + capturedResponseHeaders := captureResponseHeaders(resp) stripHopByHopHeaders(resp.Header) responsePreview := newBodyPreview(resp.Header.Get("Content-Type")) if resp.Body != nil { resp.Body = &previewReadCloser{ReadCloser: resp.Body, preview: responsePreview} } - request.ResponseHeaders = redactHeaders(resp.Header) + request.ResponseHeaders = redactHeaders(capturedResponseHeaders) return resp, request, responsePreview, nil } diff --git a/internal/proxy/requests_test.go b/internal/proxy/requests_test.go index 8b7c21b..57f0442 100644 --- a/internal/proxy/requests_test.go +++ b/internal/proxy/requests_test.go @@ -95,11 +95,14 @@ func TestRoundTripCapturedRequest_Success(t *testing.T) { if got := captured.RequestHeaders.Get("Authorization"); got != "[REDACTED]" { t.Fatalf("Authorization header = %q, want [REDACTED]", got) } + if got := captured.RequestHeaders.Get("Host"); got != "example.com" { + t.Fatalf("Host header = %q, want example.com", got) + } if got := captured.ResponseHeaders.Get("Set-Cookie"); got != "[REDACTED]" { t.Fatalf("Set-Cookie header = %q, want [REDACTED]", got) } - if got := captured.ResponseHeaders.Get("Connection"); got != "" { - t.Fatalf("Connection should be stripped from response headers, got %q", got) + if got := captured.ResponseHeaders.Get("Connection"); got != "close" { + t.Fatalf("captured Connection header = %q, want close", got) } events := drainEvents(t, ch, 1, time.Second)