fix: added a bit more headers

Still not perfect, but they're better.
This commit is contained in:
Hayden Hargreaves 2026-04-23 20:27:32 -07:00
parent 51d526c2fe
commit 30ead5d22c
4 changed files with 116 additions and 6 deletions

View File

@ -1,6 +1,7 @@
package proxy
import (
"fmt"
"net/http"
"strings"
)
@ -26,8 +27,7 @@ var hopByHopHeaders = []string{
}
// Remove headers that are only required for client<->proxy and proxy<->server communication.
// Otherwise known as hop-by-hop headers. We do not want to show these to users since they are
// used only for internal functioning for the proxy server.
// Otherwise known as hop-by-hop headers.
func stripHopByHopHeaders(headers http.Header) {
if headers == nil {
return
@ -45,6 +45,50 @@ func stripHopByHopHeaders(headers http.Header) {
}
}
func captureRequestHeaders(req *http.Request) http.Header {
if req == nil {
return http.Header{}
}
headers := req.Header.Clone()
host := strings.TrimSpace(req.Host)
if host == "" && req.URL != nil {
host = strings.TrimSpace(req.URL.Host)
}
if host != "" {
headers.Set("Host", host)
}
if req.ContentLength > 0 && headers.Get("Content-Length") == "" {
headers.Set("Content-Length", fmt.Sprintf("%d", req.ContentLength))
}
if len(req.TransferEncoding) > 0 && headers.Get("Transfer-Encoding") == "" {
headers.Set("Transfer-Encoding", strings.Join(req.TransferEncoding, ", "))
}
return headers
}
func captureResponseHeaders(resp *http.Response) http.Header {
if resp == nil {
return http.Header{}
}
headers := resp.Header.Clone()
if resp.ContentLength > 0 && headers.Get("Content-Length") == "" {
headers.Set("Content-Length", fmt.Sprintf("%d", resp.ContentLength))
}
if len(resp.TransferEncoding) > 0 && headers.Get("Transfer-Encoding") == "" {
headers.Set("Transfer-Encoding", strings.Join(resp.TransferEncoding, ", "))
}
return headers
}
// Return a new set of headers that has sensitive headers redacted.
//
// TODO: Maybe use '***' length of header?

View File

@ -2,6 +2,7 @@ package proxy
import (
"net/http"
"net/url"
"reflect"
"testing"
)
@ -155,3 +156,63 @@ func TestCopyHeaders(t *testing.T) {
t.Fatalf("copyHeaders() dest = %#v, want %#v", dest, want)
}
}
func TestCaptureRequestHeaders(t *testing.T) {
t.Parallel()
u, err := url.Parse("https://url-host.example/path")
if err != nil {
t.Fatalf("url.Parse() error = %v", err)
}
req := &http.Request{
URL: u,
Host: "host-header.example",
ContentLength: 42,
TransferEncoding: []string{"chunked"},
Header: http.Header{
"Content-Type": {"application/json"},
},
}
got := captureRequestHeaders(req)
if got.Get("Host") != "host-header.example" {
t.Fatalf("Host = %q, want host-header.example", got.Get("Host"))
}
if got.Get("Content-Length") != "42" {
t.Fatalf("Content-Length = %q, want 42", got.Get("Content-Length"))
}
if got.Get("Transfer-Encoding") != "chunked" {
t.Fatalf("Transfer-Encoding = %q, want chunked", got.Get("Transfer-Encoding"))
}
got.Set("Content-Type", "modified")
if req.Header.Get("Content-Type") == "modified" {
t.Fatal("captureRequestHeaders() should clone header map")
}
}
func TestCaptureResponseHeaders(t *testing.T) {
t.Parallel()
resp := &http.Response{
Header: http.Header{"Content-Type": {"application/json"}},
ContentLength: 128,
TransferEncoding: []string{"chunked"},
}
got := captureResponseHeaders(resp)
if got.Get("Content-Length") != "128" {
t.Fatalf("Content-Length = %q, want 128", got.Get("Content-Length"))
}
if got.Get("Transfer-Encoding") != "chunked" {
t.Fatalf("Transfer-Encoding = %q, want chunked", got.Get("Transfer-Encoding"))
}
got.Set("Content-Type", "modified")
if resp.Header.Get("Content-Type") == "modified" {
t.Fatal("captureResponseHeaders() should clone header map")
}
}

View File

@ -37,6 +37,7 @@ func roundTripCapturedRequest(req *http.Request, transport http.RoundTripper, ch
outReq.Host = defaultHost
}
}
capturedRequestHeaders := captureRequestHeaders(outReq)
stripHopByHopHeaders(outReq.Header)
requestPreview := newBodyPreview(outReq.Header.Get("Content-Type"))
if outReq.Body != nil {
@ -48,7 +49,7 @@ func roundTripCapturedRequest(req *http.Request, transport http.RoundTripper, ch
request.QueryMap = outReq.URL.Query()
request.Host = outReq.Host
request.Method = outReq.Method
request.RequestHeaders = redactHeaders(outReq.Header)
request.RequestHeaders = redactHeaders(capturedRequestHeaders)
request.RawURL = outReq.URL.String()
if request.RawURL == "" {
request.RawURL = outReq.Host + outReq.URL.RequestURI()
@ -62,13 +63,14 @@ func roundTripCapturedRequest(req *http.Request, transport http.RoundTripper, ch
return resp, request, nil, err
}
capturedResponseHeaders := captureResponseHeaders(resp)
stripHopByHopHeaders(resp.Header)
responsePreview := newBodyPreview(resp.Header.Get("Content-Type"))
if resp.Body != nil {
resp.Body = &previewReadCloser{ReadCloser: resp.Body, preview: responsePreview}
}
request.ResponseHeaders = redactHeaders(resp.Header)
request.ResponseHeaders = redactHeaders(capturedResponseHeaders)
return resp, request, responsePreview, nil
}

View File

@ -95,11 +95,14 @@ func TestRoundTripCapturedRequest_Success(t *testing.T) {
if got := captured.RequestHeaders.Get("Authorization"); got != "[REDACTED]" {
t.Fatalf("Authorization header = %q, want [REDACTED]", got)
}
if got := captured.RequestHeaders.Get("Host"); got != "example.com" {
t.Fatalf("Host header = %q, want example.com", got)
}
if got := captured.ResponseHeaders.Get("Set-Cookie"); got != "[REDACTED]" {
t.Fatalf("Set-Cookie header = %q, want [REDACTED]", got)
}
if got := captured.ResponseHeaders.Get("Connection"); got != "" {
t.Fatalf("Connection should be stripped from response headers, got %q", got)
if got := captured.ResponseHeaders.Get("Connection"); got != "close" {
t.Fatalf("captured Connection header = %q, want close", got)
}
events := drainEvents(t, ch, 1, time.Second)