feat: FINALLY got HTTPS working :)
Going to work on UI now.
This commit is contained in:
parent
3e09987c2d
commit
365fb43eca
@ -11,6 +11,7 @@ permission:
|
||||
"git log*": allow
|
||||
"git *": allow
|
||||
"grep *": allow
|
||||
"go *": allow
|
||||
webfetch: deny
|
||||
color: "#e01da6"
|
||||
---
|
||||
|
||||
42
doc/event-pressure-notes.md
Normal file
42
doc/event-pressure-notes.md
Normal file
@ -0,0 +1,42 @@
|
||||
# Event Pressure Notes
|
||||
|
||||
This is a quick note on potential event-channel pressure in the current proxy architecture.
|
||||
|
||||
## Why this matters
|
||||
|
||||
Proxy request handling currently emits events synchronously into a shared channel.
|
||||
If producers are faster than the consumer (TUI/event loop), the channel can fill and block producers.
|
||||
When that happens, request handling can stall even if upstream/downstream network paths are healthy.
|
||||
|
||||
## Where pressure comes from
|
||||
|
||||
- Every request can produce multiple lifecycle events (`started`, `finished`, `failed`, warnings).
|
||||
- CONNECT + MITM flow can emit both tunnel-level and inner-request events.
|
||||
- Bursty traffic (many small requests, retries, connection churn) amplifies event rate quickly.
|
||||
|
||||
## User-visible symptoms
|
||||
|
||||
- Request latency spikes that do not match upstream timings.
|
||||
- Intermittent pauses during high traffic.
|
||||
- Shutdown/restart feeling delayed when many events are in flight.
|
||||
|
||||
## Current risk profile
|
||||
|
||||
- Channel buffer size helps absorb bursts, but only up to a point.
|
||||
- Backpressure is currently coupled to request path execution, so stalls propagate into proxy behavior.
|
||||
|
||||
## Mitigation options
|
||||
|
||||
1. Introduce non-blocking event enqueue for low-priority events.
|
||||
- Keep critical events blocking (fatal/start/stop), but drop or coalesce high-volume request events under load.
|
||||
2. Add an internal event relay.
|
||||
- Proxy handlers write to a local buffered queue; a dedicated goroutine forwards to the main channel.
|
||||
3. Coalesce repetitive events.
|
||||
- Aggregate similar warnings or per-interval request counters instead of per-request chatter.
|
||||
4. Add lightweight metrics.
|
||||
- Track dropped/coalesced events and queue depth so pressure is visible during development.
|
||||
|
||||
## Practical near-term suggestion
|
||||
|
||||
Start with a small event relay + drop policy for non-critical request events when queue depth is high.
|
||||
This contains proxy-path stalls without changing the external event model too much.
|
||||
@ -20,6 +20,20 @@ func StartProxy(ps *model.ProxyServer, ch chan<- model.Event) {
|
||||
Body: fmt.Sprintf("proxy server started on %s", (*ps.Listener).Addr().String()),
|
||||
}
|
||||
|
||||
if ps.CAReady && !ps.CATrusted {
|
||||
body := fmt.Sprintf("HTTPS interception CA available at %s; trust this certificate to inspect HTTPS traffic", ps.CACertPath)
|
||||
eventType := model.EventTypeWarn
|
||||
if ps.CACreated {
|
||||
body = fmt.Sprintf("generated HTTPS interception CA at %s; trust this certificate to inspect HTTPS traffic", ps.CACertPath)
|
||||
}
|
||||
|
||||
ch <- model.Event{
|
||||
Time: time.Now().Local(),
|
||||
Type: eventType,
|
||||
Body: body,
|
||||
}
|
||||
}
|
||||
|
||||
if err := ps.Server.Serve(*ps.Listener); err != nil {
|
||||
if errors.Is(err, http.ErrServerClosed) {
|
||||
return
|
||||
|
||||
@ -4,9 +4,12 @@ import (
|
||||
"fmt"
|
||||
"log"
|
||||
"os"
|
||||
"runtime"
|
||||
"strconv"
|
||||
|
||||
"termtap.dev/internal/app"
|
||||
"termtap.dev/internal/model"
|
||||
"termtap.dev/internal/proxy"
|
||||
"termtap.dev/internal/tui"
|
||||
)
|
||||
|
||||
@ -14,6 +17,11 @@ import (
|
||||
const proxy_addr = "127.0.0.1:8080"
|
||||
|
||||
func Run(args []string) {
|
||||
if len(args) >= 2 && args[1] == "cert" {
|
||||
runCert()
|
||||
return
|
||||
}
|
||||
|
||||
cmd, ok := parseCommand(args)
|
||||
if !ok {
|
||||
displayHelp()
|
||||
@ -55,8 +63,53 @@ func parseCommand(args []string) (model.Command, bool) {
|
||||
func displayHelp() {
|
||||
helpText := `
|
||||
usage:
|
||||
tap cert
|
||||
tap run -- <command> [args...]
|
||||
`
|
||||
|
||||
fmt.Fprintln(os.Stderr, helpText)
|
||||
}
|
||||
|
||||
func runCert() {
|
||||
ca, err := proxy.EnsureCertificateAuthority()
|
||||
if err != nil {
|
||||
log.Fatalln(err)
|
||||
}
|
||||
|
||||
certPath := ca.CertPath()
|
||||
quotedCertPath := strconv.Quote(certPath)
|
||||
fmt.Printf("Certificate path: %s\n", certPath)
|
||||
if ca.WasCreated() {
|
||||
fmt.Println("Created a new local HTTPS interception CA.")
|
||||
} else {
|
||||
fmt.Println("Using existing local HTTPS interception CA.")
|
||||
}
|
||||
|
||||
trusted, err := ca.IsTrustedBySystem()
|
||||
if err != nil {
|
||||
fmt.Printf("System trust check failed: %v\n", err)
|
||||
} else if trusted {
|
||||
fmt.Println("System trust store: trusted")
|
||||
} else {
|
||||
fmt.Println("System trust store: not trusted")
|
||||
}
|
||||
|
||||
if runtime.GOOS != "linux" {
|
||||
fmt.Println("Install this certificate into your OS or client trust store to inspect HTTPS traffic.")
|
||||
return
|
||||
}
|
||||
|
||||
fmt.Println()
|
||||
fmt.Println("Trust instructions (Linux):")
|
||||
fmt.Println("Debian/Ubuntu:")
|
||||
fmt.Printf(" sudo cp %s /usr/local/share/ca-certificates/termtap.crt\n", quotedCertPath)
|
||||
fmt.Println(" sudo update-ca-certificates")
|
||||
fmt.Println("Fedora/RHEL/CentOS:")
|
||||
fmt.Printf(" sudo cp %s /etc/pki/ca-trust/source/anchors/termtap.crt\n", quotedCertPath)
|
||||
fmt.Println(" sudo update-ca-trust")
|
||||
fmt.Println("Arch:")
|
||||
fmt.Printf(" sudo trust anchor %s\n", quotedCertPath)
|
||||
fmt.Println()
|
||||
fmt.Println("Quick curl test:")
|
||||
fmt.Printf(" curl --proxy http://%s --cacert %s https://example.com\n", proxy_addr, quotedCertPath)
|
||||
}
|
||||
|
||||
@ -4,6 +4,7 @@ import (
|
||||
"net"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/google/uuid"
|
||||
@ -13,6 +14,14 @@ type ProxyServer struct {
|
||||
Listener *net.Listener
|
||||
Server *http.Server
|
||||
Url string
|
||||
|
||||
CACertPath string
|
||||
CAReady bool
|
||||
CACreated bool
|
||||
CATrusted bool
|
||||
|
||||
ConnMu sync.Mutex
|
||||
Conns map[net.Conn]struct{}
|
||||
}
|
||||
|
||||
type Request struct {
|
||||
|
||||
@ -59,7 +59,7 @@ func injectEnv(proc *exec.Cmd, addr string) {
|
||||
injected := []string{
|
||||
"HTTP_PROXY=" + proxyAddr,
|
||||
"http_proxy=" + proxyAddr,
|
||||
"HTTPS_PROXY=" + proxyAddr, // TODO: HTTP NOT SUPPORTED
|
||||
"HTTPS_PROXY=" + proxyAddr,
|
||||
"https_proxy=" + proxyAddr,
|
||||
// "ALL_PROXY=" + proxyAddr,
|
||||
// "all_proxy=" + proxyAddr,
|
||||
|
||||
37
internal/proxy/buffer.go
Normal file
37
internal/proxy/buffer.go
Normal file
@ -0,0 +1,37 @@
|
||||
package proxy
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"io"
|
||||
"net"
|
||||
)
|
||||
|
||||
type bufferedConn struct {
|
||||
net.Conn
|
||||
reader io.Reader
|
||||
}
|
||||
|
||||
func (c *bufferedConn) Read(p []byte) (int, error) {
|
||||
return c.reader.Read(p)
|
||||
}
|
||||
|
||||
func wrapBufferedConn(conn net.Conn, readWriter *bufio.ReadWriter) net.Conn {
|
||||
if readWriter == nil {
|
||||
return conn
|
||||
}
|
||||
|
||||
return &bufferedConn{Conn: conn, reader: readWriter}
|
||||
}
|
||||
|
||||
type previewReadCloser struct {
|
||||
io.ReadCloser
|
||||
preview *bodyPreview
|
||||
}
|
||||
|
||||
func (p *previewReadCloser) Read(data []byte) (int, error) {
|
||||
n, err := p.ReadCloser.Read(data)
|
||||
if n > 0 {
|
||||
p.preview.Write(data[:n])
|
||||
}
|
||||
return n, err
|
||||
}
|
||||
335
internal/proxy/certs.go
Normal file
335
internal/proxy/certs.go
Normal file
@ -0,0 +1,335 @@
|
||||
package proxy
|
||||
|
||||
import (
|
||||
"crypto/ecdsa"
|
||||
"crypto/elliptic"
|
||||
"crypto/rand"
|
||||
"crypto/tls"
|
||||
"crypto/x509"
|
||||
"crypto/x509/pkix"
|
||||
"encoding/pem"
|
||||
"errors"
|
||||
"fmt"
|
||||
"math/big"
|
||||
"net"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
)
|
||||
|
||||
const (
|
||||
caDirName = "termtap"
|
||||
caCertName = "mitm-ca-cert.pem"
|
||||
caKeyName = "mitm-ca-key.pem"
|
||||
caValidFor = 10 * 365 * 24 * time.Hour
|
||||
leafValidFor = 7 * 24 * time.Hour
|
||||
maxLeafCerts = 256
|
||||
)
|
||||
|
||||
type CertificateAuthority struct {
|
||||
cert *x509.Certificate
|
||||
key *ecdsa.PrivateKey
|
||||
certPath string
|
||||
keyPath string
|
||||
wasCreated bool
|
||||
|
||||
mu sync.Mutex
|
||||
leafCert map[string]*tls.Certificate
|
||||
leafOrder []string
|
||||
}
|
||||
|
||||
func loadOrCreateCertificateAuthority() (*CertificateAuthority, error) {
|
||||
configDir, err := os.UserConfigDir()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("resolve user config dir: %w", err)
|
||||
}
|
||||
|
||||
baseDir := filepath.Join(configDir, caDirName)
|
||||
if err := os.MkdirAll(baseDir, 0o700); err != nil {
|
||||
return nil, fmt.Errorf("create cert dir: %w", err)
|
||||
}
|
||||
|
||||
ca := &CertificateAuthority{
|
||||
certPath: filepath.Join(baseDir, caCertName),
|
||||
keyPath: filepath.Join(baseDir, caKeyName),
|
||||
leafCert: make(map[string]*tls.Certificate),
|
||||
}
|
||||
|
||||
if _, err := os.Stat(ca.certPath); err == nil {
|
||||
if _, err := os.Stat(ca.keyPath); err == nil {
|
||||
if err := ca.load(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return ca, nil
|
||||
}
|
||||
}
|
||||
|
||||
if err := ca.create(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
ca.wasCreated = true
|
||||
return ca, nil
|
||||
}
|
||||
|
||||
func (ca *CertificateAuthority) load() error {
|
||||
certPEM, err := os.ReadFile(ca.certPath)
|
||||
if err != nil {
|
||||
return fmt.Errorf("read ca cert: %w", err)
|
||||
}
|
||||
|
||||
keyPEM, err := os.ReadFile(ca.keyPath)
|
||||
if err != nil {
|
||||
return fmt.Errorf("read ca key: %w", err)
|
||||
}
|
||||
|
||||
certBlock, _ := pem.Decode(certPEM)
|
||||
if certBlock == nil {
|
||||
return fmt.Errorf("decode ca cert pem")
|
||||
}
|
||||
|
||||
cert, err := x509.ParseCertificate(certBlock.Bytes)
|
||||
if err != nil {
|
||||
return fmt.Errorf("parse ca cert: %w", err)
|
||||
}
|
||||
|
||||
keyBlock, _ := pem.Decode(keyPEM)
|
||||
if keyBlock == nil {
|
||||
return fmt.Errorf("decode ca key pem")
|
||||
}
|
||||
|
||||
key, err := x509.ParseECPrivateKey(keyBlock.Bytes)
|
||||
if err != nil {
|
||||
return fmt.Errorf("parse ca key: %w", err)
|
||||
}
|
||||
|
||||
ca.cert = cert
|
||||
ca.key = key
|
||||
return nil
|
||||
}
|
||||
|
||||
func (ca *CertificateAuthority) create() error {
|
||||
key, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
|
||||
if err != nil {
|
||||
return fmt.Errorf("generate ca key: %w", err)
|
||||
}
|
||||
|
||||
serial, err := randSerialNumber()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
now := time.Now()
|
||||
tmpl := &x509.Certificate{
|
||||
SerialNumber: serial,
|
||||
Subject: pkix.Name{
|
||||
CommonName: "termtap Local MITM CA",
|
||||
Organization: []string{"termtap"},
|
||||
},
|
||||
NotBefore: now.Add(-1 * time.Hour),
|
||||
NotAfter: now.Add(caValidFor),
|
||||
KeyUsage: x509.KeyUsageDigitalSignature | x509.KeyUsageKeyEncipherment | x509.KeyUsageCertSign | x509.KeyUsageCRLSign,
|
||||
ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth},
|
||||
BasicConstraintsValid: true,
|
||||
IsCA: true,
|
||||
MaxPathLen: 1,
|
||||
}
|
||||
|
||||
der, err := x509.CreateCertificate(rand.Reader, tmpl, tmpl, &key.PublicKey, key)
|
||||
if err != nil {
|
||||
return fmt.Errorf("create ca cert: %w", err)
|
||||
}
|
||||
|
||||
certPEM := pem.EncodeToMemory(&pem.Block{Type: "CERTIFICATE", Bytes: der})
|
||||
keyDER, err := x509.MarshalECPrivateKey(key)
|
||||
if err != nil {
|
||||
return fmt.Errorf("marshal ca key: %w", err)
|
||||
}
|
||||
keyPEM := pem.EncodeToMemory(&pem.Block{Type: "EC PRIVATE KEY", Bytes: keyDER})
|
||||
|
||||
if err := writeFileAtomically(ca.certPath, certPEM, 0o600); err != nil {
|
||||
return fmt.Errorf("write ca cert: %w", err)
|
||||
}
|
||||
if err := writeFileAtomically(ca.keyPath, keyPEM, 0o600); err != nil {
|
||||
return fmt.Errorf("write ca key: %w", err)
|
||||
}
|
||||
|
||||
cert, err := x509.ParseCertificate(der)
|
||||
if err != nil {
|
||||
return fmt.Errorf("parse created ca cert: %w", err)
|
||||
}
|
||||
|
||||
ca.cert = cert
|
||||
ca.key = key
|
||||
return nil
|
||||
}
|
||||
|
||||
func (ca *CertificateAuthority) CertificateForHost(host string) (*tls.Certificate, error) {
|
||||
host = normalizeCertHost(host)
|
||||
if host == "" {
|
||||
return nil, fmt.Errorf("empty host for certificate")
|
||||
}
|
||||
|
||||
ca.mu.Lock()
|
||||
defer ca.mu.Unlock()
|
||||
|
||||
if cert, ok := ca.leafCert[host]; ok {
|
||||
return cert, nil
|
||||
}
|
||||
|
||||
serial, err := randSerialNumber()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
leafKey, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("generate leaf key: %w", err)
|
||||
}
|
||||
|
||||
now := time.Now()
|
||||
tmpl := &x509.Certificate{
|
||||
SerialNumber: serial,
|
||||
Subject: pkix.Name{
|
||||
CommonName: host,
|
||||
},
|
||||
NotBefore: now.Add(-1 * time.Hour),
|
||||
NotAfter: now.Add(leafValidFor),
|
||||
KeyUsage: x509.KeyUsageDigitalSignature | x509.KeyUsageKeyEncipherment,
|
||||
ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth},
|
||||
BasicConstraintsValid: true,
|
||||
}
|
||||
|
||||
if ip := net.ParseIP(host); ip != nil {
|
||||
tmpl.IPAddresses = []net.IP{ip}
|
||||
} else {
|
||||
tmpl.DNSNames = []string{host}
|
||||
}
|
||||
|
||||
der, err := x509.CreateCertificate(rand.Reader, tmpl, ca.cert, &leafKey.PublicKey, ca.key)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("create leaf cert: %w", err)
|
||||
}
|
||||
|
||||
tlsCert := &tls.Certificate{
|
||||
Certificate: [][]byte{der, ca.cert.Raw},
|
||||
PrivateKey: leafKey,
|
||||
}
|
||||
leafParsed, err := x509.ParseCertificate(der)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("parse leaf cert: %w", err)
|
||||
}
|
||||
tlsCert.Leaf = leafParsed
|
||||
|
||||
ca.leafCert[host] = tlsCert
|
||||
ca.leafOrder = append(ca.leafOrder, host)
|
||||
if len(ca.leafOrder) > maxLeafCerts {
|
||||
evicted := ca.leafOrder[0]
|
||||
ca.leafOrder = ca.leafOrder[1:]
|
||||
delete(ca.leafCert, evicted)
|
||||
}
|
||||
|
||||
return tlsCert, nil
|
||||
}
|
||||
|
||||
func (ca *CertificateAuthority) CertPath() string {
|
||||
if ca == nil {
|
||||
return ""
|
||||
}
|
||||
return ca.certPath
|
||||
}
|
||||
|
||||
func (ca *CertificateAuthority) WasCreated() bool {
|
||||
if ca == nil {
|
||||
return false
|
||||
}
|
||||
return ca.wasCreated
|
||||
}
|
||||
|
||||
func (ca *CertificateAuthority) IsTrustedBySystem() (bool, error) {
|
||||
if ca == nil || ca.cert == nil {
|
||||
return false, fmt.Errorf("certificate authority is unavailable")
|
||||
}
|
||||
|
||||
roots, err := x509.SystemCertPool()
|
||||
if err != nil {
|
||||
return false, fmt.Errorf("load system cert pool: %w", err)
|
||||
}
|
||||
if roots == nil {
|
||||
return false, nil
|
||||
}
|
||||
|
||||
_, err = ca.cert.Verify(x509.VerifyOptions{Roots: roots})
|
||||
if err == nil {
|
||||
return true, nil
|
||||
}
|
||||
|
||||
if _, ok := errors.AsType[x509.UnknownAuthorityError](err); ok {
|
||||
return false, nil
|
||||
}
|
||||
|
||||
return false, err
|
||||
}
|
||||
|
||||
func EnsureCertificateAuthority() (*CertificateAuthority, error) {
|
||||
return loadOrCreateCertificateAuthority()
|
||||
}
|
||||
|
||||
func randSerialNumber() (*big.Int, error) {
|
||||
limit := new(big.Int).Lsh(big.NewInt(1), 128)
|
||||
serial, err := rand.Int(rand.Reader, limit)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("generate serial number: %w", err)
|
||||
}
|
||||
return serial, nil
|
||||
}
|
||||
|
||||
func normalizeCertHost(hostport string) string {
|
||||
host := strings.TrimSpace(hostport)
|
||||
if host == "" {
|
||||
return ""
|
||||
}
|
||||
|
||||
if parsedHost, _, err := net.SplitHostPort(host); err == nil {
|
||||
return parsedHost
|
||||
}
|
||||
|
||||
return host
|
||||
}
|
||||
|
||||
func writeFileAtomically(path string, data []byte, perm os.FileMode) error {
|
||||
dir := filepath.Dir(path)
|
||||
tmpFile, err := os.CreateTemp(dir, ".termtap-tmp-*")
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
tmpPath := tmpFile.Name()
|
||||
cleanup := true
|
||||
defer func() {
|
||||
if cleanup {
|
||||
_ = os.Remove(tmpPath)
|
||||
}
|
||||
}()
|
||||
|
||||
if _, err := tmpFile.Write(data); err != nil {
|
||||
_ = tmpFile.Close()
|
||||
return err
|
||||
}
|
||||
if err := tmpFile.Chmod(perm); err != nil {
|
||||
_ = tmpFile.Close()
|
||||
return err
|
||||
}
|
||||
if err := tmpFile.Close(); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if err := os.Rename(tmpPath, path); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
cleanup = false
|
||||
return nil
|
||||
}
|
||||
@ -1,226 +0,0 @@
|
||||
package proxy
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
"net/http"
|
||||
"sort"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"termtap.dev/internal/model"
|
||||
)
|
||||
|
||||
// NOTE: Much of this code is AI generated, and is not expected to make it into production
|
||||
|
||||
const maxPreviewBytes = 1024
|
||||
|
||||
func proxyHandler(ch chan<- model.Event) http.Handler {
|
||||
transport := http.DefaultTransport
|
||||
|
||||
// TODO: This should be wired into the main channel, but that will require a model package
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
|
||||
if req.Method == http.MethodConnect {
|
||||
http.Error(w, "CONNECT is not supported yet", http.StatusNotImplemented)
|
||||
ch <- model.Event{
|
||||
Time: time.Now().Local(),
|
||||
Type: model.EventTypeWarn,
|
||||
Body: fmt.Sprintf("CONNECT is not supported: %s", req.Host),
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
if req.URL.Scheme == "" || req.URL.Host == "" {
|
||||
http.Error(w, "request must use absolute-form URLs through the proxy", http.StatusBadRequest)
|
||||
ch <- model.Event{
|
||||
Time: time.Now().Local(),
|
||||
Type: model.EventTypeWarn,
|
||||
Body: fmt.Sprintf("rejected non-proxy request %s %s", req.Method, req.URL.String()),
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
start := time.Now()
|
||||
|
||||
request := model.Request{
|
||||
ID: uuid.New(),
|
||||
ResponseData: []byte{},
|
||||
RequestData: []byte{},
|
||||
URL: "",
|
||||
Status: -1,
|
||||
Method: "",
|
||||
Duration: 0,
|
||||
Pending: true,
|
||||
Failed: false,
|
||||
StartTime: start,
|
||||
}
|
||||
|
||||
requestPreview, err := readAndRestoreBody(&req.Body)
|
||||
if err != nil {
|
||||
ch <- model.Event{
|
||||
Time: time.Now().Local(),
|
||||
Type: model.EventTypeWarn,
|
||||
Body: fmt.Sprintf("(%s) failed to read request body: %v", getEndOfUUID(request.ID), err),
|
||||
Request: request,
|
||||
}
|
||||
} else {
|
||||
request.RequestData = []byte(requestPreview)
|
||||
}
|
||||
|
||||
outReq := req.Clone(req.Context())
|
||||
outReq.RequestURI = ""
|
||||
|
||||
request.URL = outReq.URL.Path
|
||||
request.QueryString = outReq.URL.RawQuery
|
||||
request.QueryMap = outReq.URL.Query()
|
||||
request.Host = outReq.Host
|
||||
request.Method = outReq.Method
|
||||
request.RequestHeaders = outReq.Header
|
||||
request.RawURL = outReq.URL.String()
|
||||
|
||||
ch <- model.Event{
|
||||
Time: time.Now().Local(),
|
||||
Type: model.EventTypeRequestStarted,
|
||||
Body: fmt.Sprintf("(%s) %s %s", getEndOfUUID(request.ID), request.Method, request.RawURL),
|
||||
Request: request,
|
||||
}
|
||||
|
||||
resp, err := transport.RoundTrip(outReq)
|
||||
if err != nil {
|
||||
status := statusFromUpstreamError(req, resp, err)
|
||||
|
||||
http.Error(w, http.StatusText(status), status)
|
||||
request.Pending = false
|
||||
request.Failed = true
|
||||
request.Duration = time.Since(start).Round(time.Microsecond)
|
||||
request.Status = status
|
||||
|
||||
ch <- model.Event{
|
||||
Time: time.Now().Local(),
|
||||
Type: model.EventTypeRequestFailed,
|
||||
Body: fmt.Sprintf("(%s) upstream error: %v", getEndOfUUID(request.ID), err),
|
||||
Request: request,
|
||||
}
|
||||
return
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
responsePreview, err := readAndRestoreBody(&resp.Body)
|
||||
if err != nil {
|
||||
ch <- model.Event{
|
||||
Time: time.Now().Local(),
|
||||
Type: model.EventTypeWarn,
|
||||
Body: fmt.Sprintf("(%s) failed to read response body: %v", getEndOfUUID(request.ID), err),
|
||||
Request: request,
|
||||
}
|
||||
} else {
|
||||
request.ResponseData = []byte(responsePreview)
|
||||
}
|
||||
|
||||
copyHeader(w.Header(), resp.Header)
|
||||
w.WriteHeader(resp.StatusCode)
|
||||
if _, err := io.Copy(w, resp.Body); err != nil {
|
||||
request.Pending = false
|
||||
request.Failed = true
|
||||
request.Duration = time.Since(start).Round(time.Microsecond)
|
||||
request.Status = resp.StatusCode
|
||||
|
||||
ch <- model.Event{
|
||||
Time: time.Now().Local(),
|
||||
Type: model.EventTypeRequestFailed,
|
||||
Body: fmt.Sprintf("(%s) failed to write response body: %v", getEndOfUUID(request.ID), err),
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
request.Duration = time.Since(start).Round(time.Microsecond)
|
||||
request.Status = resp.StatusCode
|
||||
request.ResponseHeaders = resp.Header
|
||||
request.Pending = false
|
||||
|
||||
ch <- model.Event{
|
||||
Time: time.Now().Local(),
|
||||
Type: model.EventTypeRequestFinished,
|
||||
Body: fmt.Sprintf("(%s) %s %s %d %dms", getEndOfUUID(request.ID), request.Method, request.RawURL, request.Status, request.Duration.Milliseconds()),
|
||||
Request: request,
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func copyHeader(dst, src http.Header) {
|
||||
for key, values := range src {
|
||||
for _, value := range values {
|
||||
dst.Add(key, value)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func readAndRestoreBody(body *io.ReadCloser) (string, error) {
|
||||
if body == nil || *body == nil {
|
||||
return "", nil
|
||||
}
|
||||
|
||||
payload, err := io.ReadAll(*body)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
*body = io.NopCloser(bytes.NewReader(payload))
|
||||
|
||||
preview := payload
|
||||
if len(preview) > maxPreviewBytes {
|
||||
preview = preview[:maxPreviewBytes]
|
||||
}
|
||||
|
||||
text := strings.ReplaceAll(string(preview), "\n", "\\n")
|
||||
if len(payload) > maxPreviewBytes {
|
||||
text += "..."
|
||||
}
|
||||
|
||||
return text, nil
|
||||
}
|
||||
|
||||
func formatHeaders(headers http.Header) string {
|
||||
if len(headers) == 0 {
|
||||
return "<none>"
|
||||
}
|
||||
|
||||
parts := make([]string, 0, len(headers))
|
||||
for key, values := range headers {
|
||||
parts = append(parts, fmt.Sprintf("%s=%q", key, strings.Join(values, ",")))
|
||||
}
|
||||
sort.Strings(parts)
|
||||
|
||||
return strings.Join(parts, ", ")
|
||||
}
|
||||
|
||||
func getEndOfUUID(id uuid.UUID) string {
|
||||
return id.String()[24:]
|
||||
}
|
||||
|
||||
// BUG: Not sure if this actually works, seems to favor the 502
|
||||
func statusFromUpstreamError(req *http.Request, resp *http.Response, err error) int {
|
||||
if resp != nil {
|
||||
return resp.StatusCode
|
||||
}
|
||||
|
||||
if errors.Is(req.Context().Err(), context.Canceled) {
|
||||
return http.StatusBadGateway
|
||||
}
|
||||
|
||||
if errors.Is(err, context.DeadlineExceeded) {
|
||||
return http.StatusGatewayTimeout
|
||||
}
|
||||
|
||||
var netErr net.Error
|
||||
if errors.As(err, &netErr) && netErr.Timeout() {
|
||||
return http.StatusGatewayTimeout
|
||||
}
|
||||
|
||||
return http.StatusBadGateway
|
||||
}
|
||||
183
internal/proxy/handlers.go
Normal file
183
internal/proxy/handlers.go
Normal file
@ -0,0 +1,183 @@
|
||||
package proxy
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"crypto/tls"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
"net/http"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"termtap.dev/internal/model"
|
||||
)
|
||||
|
||||
const connectIdleTimeout = 30 * time.Second
|
||||
|
||||
func proxyHandler(ch chan<- model.Event, ca *CertificateAuthority, ps *model.ProxyServer) http.Handler {
|
||||
transport := newUpstreamTransport()
|
||||
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
|
||||
if req.Method == http.MethodConnect {
|
||||
handleConnect(w, req, ch, transport, ca, ps)
|
||||
return
|
||||
}
|
||||
|
||||
if req.URL.Scheme == "" || req.URL.Host == "" {
|
||||
http.Error(w, "request must use absolute-form URLs through the proxy", http.StatusBadRequest)
|
||||
ch <- model.Event{
|
||||
Time: time.Now().Local(),
|
||||
Type: model.EventTypeWarn,
|
||||
Body: fmt.Sprintf("rejected non-proxy request %s %s", req.Method, req.URL.String()),
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
resp, request, responsePreview, err := roundTripCapturedRequest(req, transport, ch, "", false)
|
||||
if err != nil {
|
||||
status := statusFromUpstreamError(req, resp, err)
|
||||
|
||||
http.Error(w, http.StatusText(status), status)
|
||||
failRequest(ch, request, status, fmt.Sprintf("upstream error: %v", err))
|
||||
return
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
copyHeaders(resp.Header, w.Header())
|
||||
w.WriteHeader(resp.StatusCode)
|
||||
if _, err := io.Copy(w, resp.Body); err != nil {
|
||||
request.ResponseData = responsePreview.Preview()
|
||||
failRequest(ch, request, resp.StatusCode, fmt.Sprintf("failed to write response body: %v", err))
|
||||
return
|
||||
}
|
||||
|
||||
request.ResponseData = responsePreview.Preview()
|
||||
finishRequest(ch, request, resp.StatusCode)
|
||||
})
|
||||
}
|
||||
|
||||
func handleConnect(w http.ResponseWriter, req *http.Request, ch chan<- model.Event, transport http.RoundTripper, ca *CertificateAuthority, ps *model.ProxyServer) {
|
||||
start := time.Now()
|
||||
|
||||
request := newConnectRequest(req, start)
|
||||
startRequest(ch, request)
|
||||
|
||||
target := req.Host
|
||||
if !strings.Contains(target, ":") {
|
||||
target = net.JoinHostPort(target, "443")
|
||||
}
|
||||
|
||||
if ca == nil {
|
||||
http.Error(w, "HTTPS interception unavailable", http.StatusBadGateway)
|
||||
failRequest(ch, request, http.StatusBadGateway, "HTTPS interception certificate authority is unavailable")
|
||||
return
|
||||
}
|
||||
|
||||
leafCert, err := ca.CertificateForHost(target)
|
||||
if err != nil {
|
||||
http.Error(w, "failed to prepare interception certificate", http.StatusBadGateway)
|
||||
failRequest(ch, request, http.StatusBadGateway, fmt.Sprintf("failed to mint interception certificate for %s: %v", target, err))
|
||||
return
|
||||
}
|
||||
|
||||
hijacker, ok := w.(http.Hijacker)
|
||||
if !ok {
|
||||
http.Error(w, "proxy does not support hijacking", http.StatusInternalServerError)
|
||||
failRequest(ch, request, http.StatusInternalServerError, "CONNECT hijack is unavailable")
|
||||
return
|
||||
}
|
||||
|
||||
clientConn, readWriter, err := hijacker.Hijack()
|
||||
if err != nil {
|
||||
http.Error(w, "failed to hijack connection", http.StatusInternalServerError)
|
||||
failRequest(ch, request, http.StatusInternalServerError, fmt.Sprintf("CONNECT hijack failed: %v", err))
|
||||
return
|
||||
}
|
||||
trackConnection(ps, clientConn)
|
||||
defer func() {
|
||||
untrackConnection(ps, clientConn)
|
||||
_ = clientConn.Close()
|
||||
}()
|
||||
|
||||
if err := writeConnectEstablished(clientConn, readWriter); err != nil {
|
||||
failRequest(ch, request, http.StatusBadGateway, fmt.Sprintf("CONNECT setup failed: %v", err))
|
||||
return
|
||||
}
|
||||
|
||||
mitmConn := wrapBufferedConn(clientConn, readWriter)
|
||||
tlsConn := tls.Server(mitmConn, &tls.Config{
|
||||
Certificates: []tls.Certificate{*leafCert},
|
||||
MinVersion: tls.VersionTLS12,
|
||||
})
|
||||
defer tlsConn.Close()
|
||||
|
||||
_ = clientConn.SetDeadline(time.Now().Add(connectIdleTimeout))
|
||||
if err := tlsConn.Handshake(); err != nil {
|
||||
failRequest(ch, request, http.StatusBadGateway, fmt.Sprintf("TLS handshake with client failed: %v", err))
|
||||
return
|
||||
}
|
||||
_ = clientConn.SetDeadline(time.Time{})
|
||||
|
||||
reader := bufio.NewReader(tlsConn)
|
||||
writer := bufio.NewWriter(tlsConn)
|
||||
|
||||
for {
|
||||
_ = clientConn.SetReadDeadline(time.Now().Add(connectIdleTimeout))
|
||||
innerReq, err := http.ReadRequest(reader)
|
||||
if err != nil {
|
||||
if errors.Is(err, io.EOF) || errors.Is(err, net.ErrClosed) {
|
||||
finishRequest(ch, request, http.StatusOK)
|
||||
return
|
||||
}
|
||||
failRequest(ch, request, http.StatusBadGateway, fmt.Sprintf("failed to read decrypted HTTPS request: %v", err))
|
||||
return
|
||||
}
|
||||
_ = clientConn.SetReadDeadline(time.Time{})
|
||||
|
||||
resp, captured, responsePreview, err := roundTripCapturedRequest(innerReq, transport, ch, target, true)
|
||||
if err != nil {
|
||||
discardAndCloseBody(innerReq.Body)
|
||||
status := statusFromUpstreamError(innerReq, resp, err)
|
||||
_ = clientConn.SetWriteDeadline(time.Now().Add(connectIdleTimeout))
|
||||
if writeErr := writePlainHTTPError(writer, status); writeErr != nil {
|
||||
failRequest(ch, captured, status, fmt.Sprintf("upstream error: %v", err))
|
||||
failRequest(ch, request, http.StatusBadGateway, fmt.Sprintf("failed to write HTTPS error response: %v", writeErr))
|
||||
return
|
||||
}
|
||||
_ = clientConn.SetWriteDeadline(time.Time{})
|
||||
failRequest(ch, captured, status, fmt.Sprintf("upstream error: %v", err))
|
||||
failRequest(ch, request, status, fmt.Sprintf("closing CONNECT tunnel after upstream error: %v", err))
|
||||
return
|
||||
}
|
||||
|
||||
_ = clientConn.SetWriteDeadline(time.Now().Add(connectIdleTimeout))
|
||||
if err := resp.Write(writer); err != nil {
|
||||
resp.Body.Close()
|
||||
captured.ResponseData = responsePreview.Preview()
|
||||
failRequest(ch, captured, resp.StatusCode, fmt.Sprintf("failed to write HTTPS response: %v", err))
|
||||
failRequest(ch, request, http.StatusBadGateway, fmt.Sprintf("failed to write HTTPS response: %v", err))
|
||||
return
|
||||
}
|
||||
|
||||
if err := writer.Flush(); err != nil {
|
||||
_ = clientConn.SetWriteDeadline(time.Time{})
|
||||
resp.Body.Close()
|
||||
captured.ResponseData = responsePreview.Preview()
|
||||
failRequest(ch, captured, resp.StatusCode, fmt.Sprintf("failed to flush HTTPS response: %v", err))
|
||||
failRequest(ch, request, http.StatusBadGateway, fmt.Sprintf("failed to flush HTTPS response: %v", err))
|
||||
return
|
||||
}
|
||||
_ = clientConn.SetWriteDeadline(time.Time{})
|
||||
|
||||
captured.ResponseData = responsePreview.Preview()
|
||||
finishRequest(ch, captured, resp.StatusCode)
|
||||
shouldClose := innerReq.Close || resp.Close
|
||||
resp.Body.Close()
|
||||
if shouldClose {
|
||||
finishRequest(ch, request, http.StatusOK)
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
67
internal/proxy/headers.go
Normal file
67
internal/proxy/headers.go
Normal file
@ -0,0 +1,67 @@
|
||||
package proxy
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"strings"
|
||||
)
|
||||
|
||||
var sensitiveHeaders = map[string]struct{}{
|
||||
"Authorization": {},
|
||||
"Cookie": {},
|
||||
"Proxy-Authorization": {},
|
||||
"Set-Cookie": {},
|
||||
"X-Api-Key": {},
|
||||
}
|
||||
|
||||
var hopByHopHeaders = []string{
|
||||
"Connection",
|
||||
"Keep-Alive",
|
||||
"Proxy-Authenticate",
|
||||
"Proxy-Authorization",
|
||||
"Proxy-Connection",
|
||||
"Te",
|
||||
"Trailer",
|
||||
"Transfer-Encoding",
|
||||
"Upgrade",
|
||||
}
|
||||
|
||||
// Remove headers that are only required for client<->proxy and proxy<->server communication.
|
||||
// Otherwise known as hop-by-hop headers. We do not want to show these to users since they are
|
||||
// used only for internal functioning for the proxy server.
|
||||
func stripHopByHopHeaders(headers http.Header) {
|
||||
if headers == nil {
|
||||
return
|
||||
}
|
||||
|
||||
connectionValues := append([]string(nil), headers.Values("Connection")...)
|
||||
for _, key := range hopByHopHeaders {
|
||||
headers.Del(key)
|
||||
}
|
||||
|
||||
for _, value := range connectionValues {
|
||||
for key := range strings.SplitSeq(value, ",") {
|
||||
headers.Del(strings.TrimSpace(key))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Return a new set of headers that has sensitive headers redacted.
|
||||
//
|
||||
// TODO: Maybe use '***' length of header?
|
||||
func redactHeaders(headers http.Header) http.Header {
|
||||
clone := headers.Clone()
|
||||
for key := range clone {
|
||||
if _, ok := sensitiveHeaders[http.CanonicalHeaderKey(key)]; ok {
|
||||
clone.Set(key, "[REDACTED]")
|
||||
}
|
||||
}
|
||||
return clone
|
||||
}
|
||||
|
||||
func copyHeaders(src, dest http.Header) {
|
||||
for key, values := range src {
|
||||
for _, value := range values {
|
||||
dest.Add(key, value)
|
||||
}
|
||||
}
|
||||
}
|
||||
50
internal/proxy/preview.go
Normal file
50
internal/proxy/preview.go
Normal file
@ -0,0 +1,50 @@
|
||||
package proxy
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"strings"
|
||||
)
|
||||
|
||||
const maxPreviewBytes = 1024 * 64 // 64 kb (maybe we want 256kb)
|
||||
|
||||
type bodyPreview struct {
|
||||
enabled bool
|
||||
truncated bool
|
||||
buf bytes.Buffer
|
||||
}
|
||||
|
||||
func newBodyPreview(contentType string) *bodyPreview {
|
||||
return &bodyPreview{enabled: canDisplayContent(contentType)}
|
||||
}
|
||||
|
||||
func (p *bodyPreview) Write(data []byte) {
|
||||
if p == nil || !p.enabled || len(data) == 0 {
|
||||
return
|
||||
}
|
||||
|
||||
remaining := maxPreviewBytes - p.buf.Len()
|
||||
if remaining <= 0 {
|
||||
p.truncated = true
|
||||
return
|
||||
}
|
||||
|
||||
if len(data) > remaining {
|
||||
data = data[:remaining]
|
||||
p.truncated = true
|
||||
}
|
||||
|
||||
_, _ = p.buf.Write(data)
|
||||
}
|
||||
|
||||
func (p *bodyPreview) Preview() []byte {
|
||||
if p == nil || !p.enabled || p.buf.Len() == 0 {
|
||||
return []byte{}
|
||||
}
|
||||
|
||||
text := strings.ReplaceAll(p.buf.String(), "\n", "\\n")
|
||||
if p.truncated {
|
||||
text += "..."
|
||||
}
|
||||
|
||||
return []byte(text)
|
||||
}
|
||||
128
internal/proxy/requests.go
Normal file
128
internal/proxy/requests.go
Normal file
@ -0,0 +1,128 @@
|
||||
package proxy
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net/http"
|
||||
"time"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"termtap.dev/internal/model"
|
||||
)
|
||||
|
||||
func roundTripCapturedRequest(req *http.Request, transport http.RoundTripper, ch chan<- model.Event, defaultHost string, interceptedTLS bool) (*http.Response, model.Request, *bodyPreview, error) {
|
||||
start := time.Now()
|
||||
request := model.Request{
|
||||
ID: uuid.New(),
|
||||
ResponseData: []byte{},
|
||||
RequestData: []byte{},
|
||||
URL: "",
|
||||
Status: -1,
|
||||
Method: "",
|
||||
Duration: 0,
|
||||
Pending: true,
|
||||
Failed: false,
|
||||
StartTime: start,
|
||||
}
|
||||
|
||||
outReq := req.Clone(req.Context())
|
||||
outReq.RequestURI = ""
|
||||
if interceptedTLS {
|
||||
if outReq.URL.Scheme == "" {
|
||||
outReq.URL.Scheme = "https"
|
||||
}
|
||||
if outReq.URL.Host == "" {
|
||||
outReq.URL.Host = defaultHost
|
||||
}
|
||||
if outReq.Host == "" {
|
||||
outReq.Host = defaultHost
|
||||
}
|
||||
}
|
||||
stripHopByHopHeaders(outReq.Header)
|
||||
requestPreview := newBodyPreview(outReq.Header.Get("Content-Type"))
|
||||
if outReq.Body != nil {
|
||||
outReq.Body = &previewReadCloser{ReadCloser: outReq.Body, preview: requestPreview}
|
||||
}
|
||||
|
||||
request.URL = outReq.URL.Path
|
||||
request.QueryString = outReq.URL.RawQuery
|
||||
request.QueryMap = outReq.URL.Query()
|
||||
request.Host = outReq.Host
|
||||
request.Method = outReq.Method
|
||||
request.RequestHeaders = redactHeaders(outReq.Header)
|
||||
request.RawURL = outReq.URL.String()
|
||||
if request.RawURL == "" {
|
||||
request.RawURL = outReq.Host + outReq.URL.RequestURI()
|
||||
}
|
||||
|
||||
startRequest(ch, request)
|
||||
|
||||
resp, err := transport.RoundTrip(outReq)
|
||||
request.RequestData = requestPreview.Preview()
|
||||
if err != nil {
|
||||
return resp, request, nil, err
|
||||
}
|
||||
|
||||
stripHopByHopHeaders(resp.Header)
|
||||
responsePreview := newBodyPreview(resp.Header.Get("Content-Type"))
|
||||
if resp.Body != nil {
|
||||
resp.Body = &previewReadCloser{ReadCloser: resp.Body, preview: responsePreview}
|
||||
}
|
||||
|
||||
request.ResponseHeaders = redactHeaders(resp.Header)
|
||||
return resp, request, responsePreview, nil
|
||||
}
|
||||
|
||||
func newConnectRequest(req *http.Request, start time.Time) model.Request {
|
||||
// CONNECT requests do not have as much data, which is why we use Host for most of the pieces
|
||||
return model.Request{
|
||||
ID: uuid.New(),
|
||||
ResponseData: []byte{},
|
||||
RequestData: []byte{},
|
||||
URL: req.Host,
|
||||
RawURL: req.Host,
|
||||
Host: req.Host,
|
||||
Status: -1,
|
||||
Method: req.Method,
|
||||
Duration: 0,
|
||||
Pending: true,
|
||||
Failed: false,
|
||||
StartTime: start,
|
||||
}
|
||||
}
|
||||
|
||||
func finishRequest(ch chan<- model.Event, request model.Request, status int) {
|
||||
request.Pending = false
|
||||
request.Failed = false
|
||||
request.Status = status
|
||||
request.Duration = time.Since(request.StartTime).Round(time.Microsecond)
|
||||
|
||||
ch <- model.Event{
|
||||
Time: time.Now().Local(),
|
||||
Type: model.EventTypeRequestFinished,
|
||||
Body: fmt.Sprintf("(%s) %s %s %d %dms", getEndOfUUID(request.ID), request.Method, request.RawURL, request.Status, request.Duration.Milliseconds()),
|
||||
Request: request,
|
||||
}
|
||||
}
|
||||
|
||||
func failRequest(ch chan<- model.Event, request model.Request, status int, body string) {
|
||||
request.Pending = false
|
||||
request.Failed = true
|
||||
request.Status = status
|
||||
request.Duration = time.Since(request.StartTime).Round(time.Microsecond)
|
||||
|
||||
ch <- model.Event{
|
||||
Time: time.Now().Local(),
|
||||
Type: model.EventTypeRequestFailed,
|
||||
Body: fmt.Sprintf("(%s) %s", getEndOfUUID(request.ID), body),
|
||||
Request: request,
|
||||
}
|
||||
}
|
||||
|
||||
func startRequest(ch chan<- model.Event, request model.Request) {
|
||||
ch <- model.Event{
|
||||
Time: time.Now().Local(),
|
||||
Type: model.EventTypeRequestStarted,
|
||||
Body: fmt.Sprintf("(%s) %s %s", getEndOfUUID(request.ID), request.Method, request.RawURL),
|
||||
Request: request,
|
||||
}
|
||||
}
|
||||
30
internal/proxy/secure_utils.go
Normal file
30
internal/proxy/secure_utils.go
Normal file
@ -0,0 +1,30 @@
|
||||
package proxy
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"io"
|
||||
"net"
|
||||
)
|
||||
|
||||
const maxDiscardBodyBytes = 1 << 20
|
||||
|
||||
func writeConnectEstablished(conn net.Conn, readWriter *bufio.ReadWriter) error {
|
||||
if readWriter != nil {
|
||||
if _, err := readWriter.WriteString("HTTP/1.1 200 Connection Established\r\n\r\n"); err != nil {
|
||||
return err
|
||||
}
|
||||
return readWriter.Flush()
|
||||
}
|
||||
|
||||
_, err := conn.Write([]byte("HTTP/1.1 200 Connection Established\r\n\r\n"))
|
||||
return err
|
||||
}
|
||||
|
||||
func discardAndCloseBody(body io.ReadCloser) {
|
||||
if body == nil {
|
||||
return
|
||||
}
|
||||
|
||||
_, _ = io.Copy(io.Discard, io.LimitReader(body, maxDiscardBodyBytes))
|
||||
_ = body.Close()
|
||||
}
|
||||
@ -10,7 +10,22 @@ import (
|
||||
"termtap.dev/internal/model"
|
||||
)
|
||||
|
||||
const (
|
||||
proxyReadHeaderTimeout = 10 * time.Second
|
||||
proxyIdleTimeout = 30 * time.Second
|
||||
)
|
||||
|
||||
func NewProxyServer(addr string, ch chan<- model.Event) (*model.ProxyServer, error) {
|
||||
ca, err := loadOrCreateCertificateAuthority()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
trusted, err := ca.IsTrustedBySystem()
|
||||
if err != nil {
|
||||
trusted = false
|
||||
}
|
||||
|
||||
listener, err := net.Listen("tcp", addr)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
@ -20,8 +35,17 @@ func NewProxyServer(addr string, ch chan<- model.Event) (*model.ProxyServer, err
|
||||
|
||||
ps := &model.ProxyServer{
|
||||
Listener: &listener,
|
||||
Server: &http.Server{Handler: proxyHandler(ch)},
|
||||
Url: url,
|
||||
CACertPath: ca.CertPath(),
|
||||
CAReady: true,
|
||||
CACreated: ca.WasCreated(),
|
||||
CATrusted: trusted,
|
||||
Conns: make(map[net.Conn]struct{}),
|
||||
}
|
||||
ps.Server = &http.Server{
|
||||
Handler: proxyHandler(ch, ca, ps),
|
||||
ReadHeaderTimeout: proxyReadHeaderTimeout,
|
||||
IdleTimeout: proxyIdleTimeout,
|
||||
}
|
||||
|
||||
return ps, nil
|
||||
@ -33,11 +57,53 @@ func Destroy(ps *model.ProxyServer, ch chan<- model.Event) {
|
||||
defer cancel()
|
||||
|
||||
if ps != nil && ps.Server != nil {
|
||||
closeTrackedConnections(ps)
|
||||
_ = ps.Server.Shutdown(ctx)
|
||||
ch <- model.Event{
|
||||
Time: time.Now().Local(),
|
||||
Type: model.EventTypeProxyStarted,
|
||||
Type: model.EventTypeProxyStopped,
|
||||
Body: "proxy server was destroyed",
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func trackConnection(ps *model.ProxyServer, conn net.Conn) {
|
||||
if ps == nil || conn == nil {
|
||||
return
|
||||
}
|
||||
|
||||
ps.ConnMu.Lock()
|
||||
defer ps.ConnMu.Unlock()
|
||||
ps.Conns[conn] = struct{}{}
|
||||
}
|
||||
|
||||
func untrackConnection(ps *model.ProxyServer, conn net.Conn) {
|
||||
if ps == nil || conn == nil {
|
||||
return
|
||||
}
|
||||
|
||||
ps.ConnMu.Lock()
|
||||
defer ps.ConnMu.Unlock()
|
||||
delete(ps.Conns, conn)
|
||||
}
|
||||
|
||||
func closeTrackedConnections(ps *model.ProxyServer) {
|
||||
if ps == nil {
|
||||
return
|
||||
}
|
||||
|
||||
// Get all of the connections while claiming the mutex.
|
||||
// Then close the mutex to allow access to the server object quicker.
|
||||
// Then a loop can run to close the connections, without needing access
|
||||
// to the server's mutex.
|
||||
ps.ConnMu.Lock()
|
||||
conns := make([]net.Conn, 0, len(ps.Conns))
|
||||
for conn := range ps.Conns {
|
||||
conns = append(conns, conn)
|
||||
}
|
||||
ps.ConnMu.Unlock()
|
||||
|
||||
for _, conn := range conns {
|
||||
_ = conn.Close()
|
||||
}
|
||||
}
|
||||
|
||||
111
internal/proxy/utils.go
Normal file
111
internal/proxy/utils.go
Normal file
@ -0,0 +1,111 @@
|
||||
package proxy
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
"net/http"
|
||||
"sort"
|
||||
"strings"
|
||||
|
||||
"github.com/google/uuid"
|
||||
)
|
||||
|
||||
var validContentTypes = []string{
|
||||
"application/graphql",
|
||||
"application/javascript",
|
||||
"application/json",
|
||||
"application/x-www-form-urlencoded",
|
||||
"application/xml",
|
||||
"+json",
|
||||
"+xml",
|
||||
}
|
||||
|
||||
func canDisplayContent(contentType string) bool {
|
||||
if contentType == "" {
|
||||
return false
|
||||
}
|
||||
|
||||
contentType = strings.ToLower(contentType)
|
||||
if strings.HasPrefix(contentType, "text/") {
|
||||
return true
|
||||
}
|
||||
|
||||
for _, t := range validContentTypes {
|
||||
if strings.Contains(contentType, t) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
// NOTE: Currently unused, will be reference for the future header rendering
|
||||
func formatHeaders(headers http.Header) string {
|
||||
if len(headers) == 0 {
|
||||
return "<none>"
|
||||
}
|
||||
|
||||
parts := make([]string, 0, len(headers))
|
||||
for key, values := range headers {
|
||||
parts = append(parts, fmt.Sprintf("%s=%q", key, strings.Join(values, ",")))
|
||||
}
|
||||
sort.Strings(parts)
|
||||
|
||||
return strings.Join(parts, ", ")
|
||||
}
|
||||
|
||||
func getEndOfUUID(id uuid.UUID) string {
|
||||
return id.String()[24:]
|
||||
}
|
||||
|
||||
// BUG: Not sure if this actually works, seems to favor the 502
|
||||
func statusFromUpstreamError(req *http.Request, resp *http.Response, err error) int {
|
||||
if resp != nil {
|
||||
return resp.StatusCode
|
||||
}
|
||||
|
||||
if errors.Is(req.Context().Err(), context.Canceled) {
|
||||
return http.StatusBadGateway
|
||||
}
|
||||
|
||||
if errors.Is(err, context.DeadlineExceeded) {
|
||||
return http.StatusGatewayTimeout
|
||||
}
|
||||
|
||||
var netErr net.Error
|
||||
if errors.As(err, &netErr) && netErr.Timeout() {
|
||||
return http.StatusGatewayTimeout
|
||||
}
|
||||
|
||||
return http.StatusBadGateway
|
||||
}
|
||||
|
||||
func newUpstreamTransport() http.RoundTripper {
|
||||
transport := http.DefaultTransport.(*http.Transport).Clone()
|
||||
transport.Proxy = nil
|
||||
return transport
|
||||
}
|
||||
|
||||
func writePlainHTTPError(w *bufio.Writer, status int) error {
|
||||
resp := &http.Response{
|
||||
StatusCode: status,
|
||||
Status: fmt.Sprintf("%d %s", status, http.StatusText(status)),
|
||||
Proto: "HTTP/1.1",
|
||||
ProtoMajor: 1,
|
||||
ProtoMinor: 1,
|
||||
Header: make(http.Header),
|
||||
Body: io.NopCloser(strings.NewReader(http.StatusText(status))),
|
||||
ContentLength: int64(len(http.StatusText(status))),
|
||||
Close: false,
|
||||
}
|
||||
resp.Header.Set("Content-Type", "text/plain; charset=utf-8")
|
||||
resp.Header.Set("Content-Length", fmt.Sprintf("%d", len(http.StatusText(status))))
|
||||
if err := resp.Write(w); err != nil {
|
||||
return err
|
||||
}
|
||||
return w.Flush()
|
||||
}
|
||||
@ -23,7 +23,7 @@ func (m Model) renderStatusBar(w int) string {
|
||||
|
||||
avg := int(msSum) / max(1, len(m.requests))
|
||||
left := fmt.Sprintf(" tap %3d reqs | %d err | avg %dms", len(m.requests), errCount, avg)
|
||||
right := "j/k nav / search tab panel e events o output r/^r restart q quit "
|
||||
right := "j/k nav / search tab panel e events o output ^r restart q quit "
|
||||
|
||||
spaceSize := max(w-(len(left)+len(right)), 0)
|
||||
space := strings.Repeat(" ", spaceSize)
|
||||
@ -57,7 +57,7 @@ func (m Model) renderRequestPane(w, h int) []string {
|
||||
left := fmt.Sprintf(
|
||||
" %-7s %-24s %s",
|
||||
strings.ToUpper(req.Method),
|
||||
req.Host,
|
||||
truncate(req.Host, 24),
|
||||
req.URL,
|
||||
)
|
||||
right := fmt.Sprintf(
|
||||
@ -98,6 +98,7 @@ func (m Model) renderDetailsPane(w, h int) []string {
|
||||
for y := range lines {
|
||||
lines[y] = m.theme.Text.Render(strings.Repeat(" ", w))
|
||||
}
|
||||
|
||||
return lines
|
||||
}
|
||||
|
||||
|
||||
@ -27,7 +27,7 @@ func (m Model) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
|
||||
switch msg.String() {
|
||||
case "ctrl+c", "q":
|
||||
return m, tea.Quit
|
||||
case "r", tea.KeyCtrlR.String():
|
||||
case tea.KeyCtrlR.String():
|
||||
if m.restarting {
|
||||
return m, nil
|
||||
}
|
||||
@ -95,6 +95,10 @@ func (m *Model) applyMessage(msg model.Event) {
|
||||
}
|
||||
|
||||
func (m *Model) createRequest(req model.Request) {
|
||||
if req.Method == "CONNECT" {
|
||||
return
|
||||
}
|
||||
|
||||
m.requests = append(m.requests, req)
|
||||
|
||||
// If we passed the max, delete the first one
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user