Next step might actually be the TUI! Or maybe the raw proxy; it would be nice to be able to just run the proxy.
215 lines
5.2 KiB
Go
215 lines
5.2 KiB
Go
package proxy
|
|
|
|
import (
	"bytes"
	"context"
	"errors"
	"fmt"
	"io"
	"net"
	"net/http"
	"sort"
	"strings"
	"time"
	"unicode/utf8"

	"github.com/google/uuid"
	"termtap.dev/internal/model"
)
|
|
|
|
// NOTE: Much of this code is AI generated, and is not expected to make it into production

// maxPreviewBytes is the largest number of body bytes readAndRestoreBody keeps
// in a preview; longer bodies are cut at this size and marked with "...".
const maxPreviewBytes = 1 << 10
|
|
|
|
func proxyHandler(ch chan<- model.Message) http.Handler {
|
|
transport := http.DefaultTransport
|
|
|
|
// TODO: This should be wired into the main channel, but that will require a model package
|
|
return http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
|
|
if req.Method == http.MethodConnect {
|
|
http.Error(w, "CONNECT is not supported yet", http.StatusNotImplemented)
|
|
ch <- model.Message{
|
|
Type: model.MessageTypeWarn,
|
|
Body: fmt.Sprintf("CONNECT is not supported: %s", req.Host),
|
|
}
|
|
return
|
|
}
|
|
|
|
if req.URL.Scheme == "" || req.URL.Host == "" {
|
|
http.Error(w, "request must use absolute-form URLs through the proxy", http.StatusBadRequest)
|
|
ch <- model.Message{
|
|
Type: model.MessageTypeWarn,
|
|
Body: fmt.Sprintf("rejected non-proxy request %s %s", req.Method, req.URL.String()),
|
|
}
|
|
return
|
|
}
|
|
|
|
start := time.Now()
|
|
|
|
request := model.Request{
|
|
ID: uuid.New(),
|
|
ResponseData: []byte{},
|
|
RequestData: []byte{},
|
|
URL: "",
|
|
Status: -1,
|
|
Method: "",
|
|
Duration: 0,
|
|
Pending: true,
|
|
Failed: false,
|
|
StartTime: start,
|
|
}
|
|
|
|
requestPreview, err := readAndRestoreBody(&req.Body)
|
|
if err != nil {
|
|
ch <- model.Message{
|
|
Type: model.MessageTypeWarn,
|
|
Body: fmt.Sprintf("(%s) failed to read request body", request.ID),
|
|
Request: request,
|
|
}
|
|
} else {
|
|
request.RequestData = []byte(requestPreview)
|
|
}
|
|
|
|
outReq := req.Clone(req.Context())
|
|
outReq.RequestURI = ""
|
|
|
|
request.URL = outReq.URL.Path
|
|
request.QueryString = outReq.URL.RawQuery
|
|
request.QueryMap = outReq.URL.Query()
|
|
request.Host = outReq.Host
|
|
request.Method = outReq.Method
|
|
request.RequestHeaders = outReq.Header
|
|
request.RawURL = outReq.URL.String()
|
|
|
|
ch <- model.Message{
|
|
Type: model.MessageTypeRequestStarted,
|
|
Body: fmt.Sprintf("-> %+v", request),
|
|
Request: request,
|
|
}
|
|
|
|
resp, err := transport.RoundTrip(outReq)
|
|
if err != nil {
|
|
status := statusFromUpstreamError(req, resp, err)
|
|
|
|
http.Error(w, http.StatusText(status), status)
|
|
request.Pending = false
|
|
request.Failed = true
|
|
request.Duration = time.Since(start).Round(time.Microsecond)
|
|
request.Status = status
|
|
|
|
ch <- model.Message{
|
|
Type: model.MessageTypeRequestFailed,
|
|
Body: fmt.Sprintf("upstream error for %s %s: %v", outReq.Method, outReq.URL.String(), err),
|
|
Request: request,
|
|
}
|
|
return
|
|
}
|
|
defer resp.Body.Close()
|
|
|
|
responsePreview, err := readAndRestoreBody(&resp.Body)
|
|
if err != nil {
|
|
ch <- model.Message{
|
|
Type: model.MessageTypeWarn,
|
|
Body: fmt.Sprintf("(%s) failed to read response body", request.ID),
|
|
Request: request,
|
|
}
|
|
} else {
|
|
request.ResponseData = []byte(responsePreview)
|
|
}
|
|
|
|
copyHeader(w.Header(), resp.Header)
|
|
w.WriteHeader(resp.StatusCode)
|
|
if _, err := io.Copy(w, resp.Body); err != nil {
|
|
request.Pending = false
|
|
request.Failed = true
|
|
request.Duration = time.Since(start).Round(time.Microsecond)
|
|
request.Status = resp.StatusCode
|
|
|
|
ch <- model.Message{
|
|
Type: model.MessageTypeRequestFailed,
|
|
Body: fmt.Sprintf("write response body %s %s: %v", outReq.Method, outReq.URL.String(), err),
|
|
}
|
|
return
|
|
}
|
|
|
|
request.Duration = time.Since(start).Round(time.Microsecond)
|
|
request.Status = resp.StatusCode
|
|
request.ResponseHeaders = resp.Header
|
|
request.Pending = false
|
|
|
|
ch <- model.Message{
|
|
Type: model.MessageTypeRequestFinished,
|
|
Body: fmt.Sprintf("<- %+v %s", request, formatHeaders(resp.Request.Header)),
|
|
Request: request,
|
|
}
|
|
})
|
|
}
|
|
|
|
// copyHeader appends every value of every header in src onto dst. Values
// already present in dst are kept, and each name goes through dst.Add, which
// canonicalizes it.
func copyHeader(dst, src http.Header) {
	for name := range src {
		for _, v := range src[name] {
			dst.Add(name, v)
		}
	}
}
|
|
|
|
func readAndRestoreBody(body *io.ReadCloser) (string, error) {
|
|
if body == nil || *body == nil {
|
|
return "", nil
|
|
}
|
|
|
|
payload, err := io.ReadAll(*body)
|
|
if err != nil {
|
|
return "", err
|
|
}
|
|
|
|
*body = io.NopCloser(bytes.NewReader(payload))
|
|
|
|
preview := payload
|
|
if len(preview) > maxPreviewBytes {
|
|
preview = preview[:maxPreviewBytes]
|
|
}
|
|
|
|
text := strings.ReplaceAll(string(preview), "\n", "\\n")
|
|
if len(payload) > maxPreviewBytes {
|
|
text += "..."
|
|
}
|
|
|
|
return text, nil
|
|
}
|
|
|
|
// formatHeaders renders headers on one deterministic line for log messages.
// Each header becomes `Name="v1,v2"`: values are comma-joined and Go-quoted.
// The rendered entries are sorted as strings and joined with ", ". A nil or
// empty header set renders as "<none>".
func formatHeaders(headers http.Header) string {
	if len(headers) == 0 {
		return "<none>"
	}

	entries := make([]string, len(headers))
	i := 0
	for name, vals := range headers {
		entries[i] = fmt.Sprintf("%s=%q", name, strings.Join(vals, ","))
		i++
	}
	// Sort the rendered entries, not the names: "X-A=..." must precede "X=...".
	sort.Sort(sort.StringSlice(entries))

	return strings.Join(entries, ", ")
}
|
|
|
|
// statusFromUpstreamError picks the HTTP status reported when the upstream
// round trip fails. The first matching rule wins:
//
//   - a non-nil resp: its own StatusCode;
//   - the client's request context was canceled: 502 Bad Gateway;
//   - err is context.DeadlineExceeded, or a net.Error whose Timeout() is true:
//     504 Gateway Timeout;
//   - anything else: 502 Bad Gateway.
//
// NOTE(review): an earlier note suspected this "favors the 502". That follows
// from the rules above. Only timeouts map to 504, so every other transport
// failure (presumably refused connections, DNS or TLS errors) and every client
// cancellation reports 502. A client that hung up could arguably get a
// distinct code; that is a behavior decision, not a bug here.
//
// req is dereferenced whenever resp is nil, so it must be non-nil then.
func statusFromUpstreamError(req *http.Request, resp *http.Response, err error) int {
	var netErr net.Error

	switch {
	case resp != nil:
		return resp.StatusCode
	case errors.Is(req.Context().Err(), context.Canceled):
		return http.StatusBadGateway
	case errors.Is(err, context.DeadlineExceeded),
		errors.As(err, &netErr) && netErr.Timeout():
		return http.StatusGatewayTimeout
	default:
		return http.StatusBadGateway
	}
}
|