279 lines
8.0 KiB
Go
279 lines
8.0 KiB
Go
package proxy
|
|
|
|
import (
|
|
"errors"
|
|
"io"
|
|
"net/http"
|
|
"net/url"
|
|
"strings"
|
|
"testing"
|
|
"time"
|
|
|
|
"github.com/google/uuid"
|
|
"termtap.dev/internal/model"
|
|
)
|
|
|
|
// mockTransport is a stub http.RoundTripper for tests. Every RoundTrip call is
// handed straight to fn, so each test decides how the "upstream" answers and
// can inspect the outbound request it receives.
type mockTransport struct {
	// fn is invoked once per RoundTrip call with the outgoing request.
	fn func(*http.Request) (*http.Response, error)
}

// Compile-time check that *mockTransport satisfies http.RoundTripper.
var _ http.RoundTripper = (*mockTransport)(nil)

// RoundTrip implements http.RoundTripper by delegating to the configured fn and
// returning its response and error unchanged.
func (mt *mockTransport) RoundTrip(r *http.Request) (*http.Response, error) {
	handler := mt.fn
	return handler(r)
}
|
|
|
|
func TestRoundTripCapturedRequest_Success(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
ch := make(chan model.Event, 8)
|
|
|
|
reqURL, err := url.Parse("http://example.com/path?q=1")
|
|
if err != nil {
|
|
t.Fatalf("url.Parse() error = %v", err)
|
|
}
|
|
req := &http.Request{
|
|
Method: http.MethodPost,
|
|
URL: reqURL,
|
|
Host: "example.com",
|
|
Header: http.Header{
|
|
"Connection": {"X-Hop"},
|
|
"X-Hop": {"drop"},
|
|
"Authorization": {"Bearer token"},
|
|
"Content-Type": {"text/plain"},
|
|
},
|
|
Body: io.NopCloser(strings.NewReader("req\nbody")),
|
|
}
|
|
|
|
transport := &mockTransport{fn: func(outReq *http.Request) (*http.Response, error) {
|
|
if got := outReq.Header.Get("Connection"); got != "" {
|
|
t.Fatalf("Connection header should be stripped, got %q", got)
|
|
}
|
|
if got := outReq.Header.Get("X-Hop"); got != "" {
|
|
t.Fatalf("header listed in Connection should be stripped, got %q", got)
|
|
}
|
|
|
|
data, readErr := io.ReadAll(outReq.Body)
|
|
if readErr != nil {
|
|
t.Fatalf("ReadAll(outReq.Body) error = %v", readErr)
|
|
}
|
|
if got, want := string(data), "req\nbody"; got != want {
|
|
t.Fatalf("request body = %q, want %q", got, want)
|
|
}
|
|
|
|
return &http.Response{
|
|
StatusCode: http.StatusCreated,
|
|
Header: http.Header{
|
|
"Set-Cookie": {"session=top-secret"},
|
|
"Connection": {"close"},
|
|
"Content-Type": {"text/plain"},
|
|
},
|
|
Body: io.NopCloser(strings.NewReader("resp\nbody")),
|
|
}, nil
|
|
}}
|
|
|
|
resp, captured, responsePreview, err := roundTripCapturedRequest(req, transport, ch, "", false)
|
|
if err != nil {
|
|
t.Fatalf("roundTripCapturedRequest() error = %v", err)
|
|
}
|
|
if resp == nil {
|
|
t.Fatal("roundTripCapturedRequest() returned nil response")
|
|
}
|
|
if responsePreview == nil {
|
|
t.Fatal("roundTripCapturedRequest() returned nil response preview")
|
|
}
|
|
|
|
if _, readErr := io.ReadAll(resp.Body); readErr != nil {
|
|
t.Fatalf("ReadAll(resp.Body) error = %v", readErr)
|
|
}
|
|
_ = resp.Body.Close()
|
|
|
|
if got, want := string(captured.RequestData), `req\nbody`; got != want {
|
|
t.Fatalf("captured.RequestData = %q, want %q", got, want)
|
|
}
|
|
if got, want := string(responsePreview.Preview()), `resp\nbody`; got != want {
|
|
t.Fatalf("responsePreview = %q, want %q", got, want)
|
|
}
|
|
if got := captured.RequestHeaders.Get("Authorization"); got != "[REDACTED]" {
|
|
t.Fatalf("Authorization header = %q, want [REDACTED]", got)
|
|
}
|
|
if got := captured.RequestHeaders.Get("Host"); got != "example.com" {
|
|
t.Fatalf("Host header = %q, want example.com", got)
|
|
}
|
|
if got := captured.ResponseHeaders.Get("Set-Cookie"); got != "[REDACTED]" {
|
|
t.Fatalf("Set-Cookie header = %q, want [REDACTED]", got)
|
|
}
|
|
if got := captured.ResponseHeaders.Get("Connection"); got != "close" {
|
|
t.Fatalf("captured Connection header = %q, want close", got)
|
|
}
|
|
|
|
events := drainEvents(t, ch, 1, time.Second)
|
|
if events[0].Type != model.EventTypeRequestStarted {
|
|
t.Fatalf("event type = %s, want %s", events[0].Type, model.EventTypeRequestStarted)
|
|
}
|
|
}
|
|
|
|
func TestRoundTripCapturedRequest_ErrorPath(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
ch := make(chan model.Event, 8)
|
|
|
|
reqURL, err := url.Parse("http://example.com/fail")
|
|
if err != nil {
|
|
t.Fatalf("url.Parse() error = %v", err)
|
|
}
|
|
req := &http.Request{
|
|
Method: http.MethodPost,
|
|
URL: reqURL,
|
|
Host: "example.com",
|
|
Header: http.Header{"Content-Type": {"text/plain"}},
|
|
Body: io.NopCloser(strings.NewReader("boom\nbody")),
|
|
}
|
|
|
|
wantErr := errors.New("upstream failed")
|
|
transport := &mockTransport{fn: func(outReq *http.Request) (*http.Response, error) {
|
|
_, _ = io.ReadAll(outReq.Body)
|
|
return nil, wantErr
|
|
}}
|
|
|
|
resp, captured, responsePreview, gotErr := roundTripCapturedRequest(req, transport, ch, "", false)
|
|
if !errors.Is(gotErr, wantErr) {
|
|
t.Fatalf("error = %v, want %v", gotErr, wantErr)
|
|
}
|
|
if resp != nil {
|
|
t.Fatalf("response = %#v, want nil", resp)
|
|
}
|
|
if responsePreview != nil {
|
|
t.Fatalf("responsePreview = %#v, want nil", responsePreview)
|
|
}
|
|
if got, want := string(captured.RequestData), `boom\nbody`; got != want {
|
|
t.Fatalf("captured.RequestData = %q, want %q", got, want)
|
|
}
|
|
|
|
events := drainEvents(t, ch, 1, time.Second)
|
|
if events[0].Type != model.EventTypeRequestStarted {
|
|
t.Fatalf("event type = %s, want %s", events[0].Type, model.EventTypeRequestStarted)
|
|
}
|
|
}
|
|
|
|
func TestRoundTripCapturedRequest_InterceptedTLSDefaults(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
ch := make(chan model.Event, 8)
|
|
|
|
u, err := url.Parse("/secure?p=1")
|
|
if err != nil {
|
|
t.Fatalf("url.Parse() error = %v", err)
|
|
}
|
|
req := &http.Request{
|
|
Method: http.MethodGet,
|
|
URL: u,
|
|
Header: http.Header{},
|
|
}
|
|
|
|
const defaultHost = "api.example.com:443"
|
|
transport := &mockTransport{fn: func(outReq *http.Request) (*http.Response, error) {
|
|
if got := outReq.URL.Scheme; got != "https" {
|
|
t.Fatalf("URL.Scheme = %q, want https", got)
|
|
}
|
|
if got := outReq.URL.Host; got != defaultHost {
|
|
t.Fatalf("URL.Host = %q, want %q", got, defaultHost)
|
|
}
|
|
if got := outReq.Host; got != defaultHost {
|
|
t.Fatalf("Host = %q, want %q", got, defaultHost)
|
|
}
|
|
|
|
return &http.Response{
|
|
StatusCode: http.StatusNoContent,
|
|
Header: http.Header{"Content-Type": {"text/plain"}},
|
|
Body: io.NopCloser(strings.NewReader("")),
|
|
}, nil
|
|
}}
|
|
|
|
_, _, _, gotErr := roundTripCapturedRequest(req, transport, ch, defaultHost, true)
|
|
if gotErr != nil {
|
|
t.Fatalf("roundTripCapturedRequest() error = %v", gotErr)
|
|
}
|
|
}
|
|
|
|
func TestNewConnectRequest(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
now := time.Now().Add(-time.Second)
|
|
req := &http.Request{Method: http.MethodConnect, Host: "example.com:443"}
|
|
|
|
got := newConnectRequest(req, now)
|
|
|
|
if got.Method != http.MethodConnect {
|
|
t.Fatalf("Method = %q, want CONNECT", got.Method)
|
|
}
|
|
if got.Host != req.Host || got.URL != req.Host || got.RawURL != req.Host {
|
|
t.Fatalf("connect request host/url/raw mismatch: %#v", got)
|
|
}
|
|
if !got.Pending || got.Failed {
|
|
t.Fatalf("Pending/Failed = (%v,%v), want (true,false)", got.Pending, got.Failed)
|
|
}
|
|
if got.Status != -1 {
|
|
t.Fatalf("Status = %d, want -1", got.Status)
|
|
}
|
|
if got.StartTime != now {
|
|
t.Fatalf("StartTime = %v, want %v", got.StartTime, now)
|
|
}
|
|
if got.ID == uuid.Nil {
|
|
t.Fatal("ID must be non-zero UUID")
|
|
}
|
|
}
|
|
|
|
func TestStartFinishFailRequestEvents(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
ch := make(chan model.Event, 4)
|
|
req := model.Request{
|
|
ID: uuid.New(),
|
|
Method: http.MethodGet,
|
|
RawURL: "http://example.com/a",
|
|
StartTime: time.Now().Add(-3 * time.Millisecond),
|
|
Pending: true,
|
|
}
|
|
|
|
startRequest(ch, req)
|
|
events := drainEvents(t, ch, 1, time.Second)
|
|
startEv := events[0]
|
|
if startEv.Type != model.EventTypeRequestStarted {
|
|
t.Fatalf("start event type = %s, want %s", startEv.Type, model.EventTypeRequestStarted)
|
|
}
|
|
if startEv.Request.Pending != true {
|
|
t.Fatalf("start request pending = %v, want true", startEv.Request.Pending)
|
|
}
|
|
|
|
finishRequest(ch, req, http.StatusOK)
|
|
events = drainEvents(t, ch, 1, time.Second)
|
|
finishEv := events[0]
|
|
if finishEv.Type != model.EventTypeRequestFinished {
|
|
t.Fatalf("finish event type = %s, want %s", finishEv.Type, model.EventTypeRequestFinished)
|
|
}
|
|
if finishEv.Request.Pending {
|
|
t.Fatal("finished request should not be pending")
|
|
}
|
|
if finishEv.Request.Failed {
|
|
t.Fatal("finished request should not be failed")
|
|
}
|
|
if finishEv.Request.Status != http.StatusOK {
|
|
t.Fatalf("finished status = %d, want %d", finishEv.Request.Status, http.StatusOK)
|
|
}
|
|
|
|
failRequest(ch, req, http.StatusBadGateway, "upstream error")
|
|
events = drainEvents(t, ch, 1, time.Second)
|
|
failEv := events[0]
|
|
if failEv.Type != model.EventTypeRequestFailed {
|
|
t.Fatalf("fail event type = %s, want %s", failEv.Type, model.EventTypeRequestFailed)
|
|
}
|
|
if failEv.Request.Pending {
|
|
t.Fatal("failed request should not be pending")
|
|
}
|
|
if !failEv.Request.Failed {
|
|
t.Fatal("failed request should be marked failed")
|
|
}
|
|
if failEv.Request.Status != http.StatusBadGateway {
|
|
t.Fatalf("failed status = %d, want %d", failEv.Request.Status, http.StatusBadGateway)
|
|
}
|
|
}
|