diff --git a/examples/echo/main.go b/examples/echo/main.go index 3f1b675..436712e 100644 --- a/examples/echo/main.go +++ b/examples/echo/main.go @@ -4,13 +4,18 @@ package main // which hits the other and response with the data provided. import ( + "bytes" + "encoding/json" "fmt" "io" "log" "net" "net/http" "net/url" + "strconv" + "strings" "sync" + "time" ) func main() { @@ -41,34 +46,100 @@ func main() { func startFrontend(upstreamHost string) error { mux := http.NewServeMux() - mux.HandleFunc("/echo", func(w http.ResponseWriter, req *http.Request) { + mux.HandleFunc("/", func(w http.ResponseWriter, req *http.Request) { if req.Method != http.MethodGet { http.Error(w, "method not allowed", http.StatusMethodNotAllowed) return } - message := req.URL.Query().Get("message") - upstreamURL := fmt.Sprintf("http://%s:3001/echo?message=%s", upstreamHost, url.QueryEscape(message)) - - resp, err := http.Get(upstreamURL) - if err != nil { - http.Error(w, err.Error(), http.StatusBadGateway) - return - } - defer resp.Body.Close() - - body, err := io.ReadAll(resp.Body) - if err != nil { - http.Error(w, err.Error(), http.StatusBadGateway) - return - } - - w.Header().Set("Content-Type", "text/plain; charset=utf-8") - w.WriteHeader(resp.StatusCode) - _, _ = w.Write(body) + w.Header().Set("Content-Type", "text/html; charset=utf-8") + _, _ = w.Write([]byte(frontendHTML)) }) - log.Printf("frontend listening on http://127.0.0.1:3000/echo?message=hello") + mux.HandleFunc("/echo", func(w http.ResponseWriter, req *http.Request) { + client := &http.Client{Timeout: parseTimeout(req.URL.Query().Get("timeoutMs"))} + + switch req.Method { + case http.MethodGet: + message := req.URL.Query().Get("message") + upstreamURL := fmt.Sprintf( + "http://%s:3001/echo?message=%s&code=%s&fail=%s&sleepMs=%s", + upstreamHost, + url.QueryEscape(message), + url.QueryEscape(req.URL.Query().Get("code")), + url.QueryEscape(req.URL.Query().Get("fail")), + url.QueryEscape(req.URL.Query().Get("sleepMs")), + ) + + resp, err := client.Get(upstreamURL) + if err != nil { + http.Error(w, err.Error(), http.StatusBadGateway) + return + } + defer resp.Body.Close() + + body, err := io.ReadAll(resp.Body) + if err != nil { + http.Error(w, err.Error(), http.StatusBadGateway) + return + } + + w.Header().Set("Content-Type", "text/plain; charset=utf-8") + w.WriteHeader(resp.StatusCode) + _, _ = w.Write(body) + case http.MethodPost: + body, err := io.ReadAll(req.Body) + if err != nil { + http.Error(w, err.Error(), http.StatusBadRequest) + return + } + + if !json.Valid(body) { + http.Error(w, "invalid JSON payload", http.StatusBadRequest) + return + } + + upstreamURL := fmt.Sprintf( + "http://%s:3001/echo?code=%s&fail=%s&sleepMs=%s", + upstreamHost, + url.QueryEscape(req.URL.Query().Get("code")), + url.QueryEscape(req.URL.Query().Get("fail")), + url.QueryEscape(req.URL.Query().Get("sleepMs")), + ) + + upstreamReq, err := http.NewRequest(http.MethodPost, upstreamURL, bytes.NewReader(body)) + if err != nil { + http.Error(w, err.Error(), http.StatusBadGateway) + return + } + upstreamReq.Header.Set("Content-Type", "application/json") + + resp, err := client.Do(upstreamReq) + if err != nil { + http.Error(w, err.Error(), http.StatusBadGateway) + return + } + defer resp.Body.Close() + + upstreamBody, err := io.ReadAll(resp.Body) + if err != nil { + http.Error(w, err.Error(), http.StatusBadGateway) + return + } + + w.Header().Set("Content-Type", "application/json; charset=utf-8") + w.WriteHeader(resp.StatusCode) + _, _ = w.Write(upstreamBody) + default: + http.Error(w, "method not allowed", http.StatusMethodNotAllowed) + } + }) + + log.Printf("frontend UI on http://127.0.0.1:3000") + log.Printf("frontend GET example: http://127.0.0.1:3000/echo?message=hello&code=201&sleepMs=200") + log.Printf("frontend POST example: curl -i -X POST 'http://127.0.0.1:3000/echo?code=202&sleepMs=200' -H 'content-type: application/json' -d '{\"message\":\"hello\"}'") + log.Printf("frontend timeout example: http://127.0.0.1:3000/echo?message=late&sleepMs=4000&timeoutMs=1000") + log.Printf("frontend failure examples: fail=true, fail=drop, fail=timeout, fail=status") log.Printf("frontend calls upstream at http://%s:3001/echo", upstreamHost) return http.ListenAndServe("127.0.0.1:3000", mux) } @@ -76,20 +147,291 @@ func startFrontend(upstreamHost string) error { func startUpstream(upstreamHost string) error { mux := http.NewServeMux() mux.HandleFunc("/echo", func(w http.ResponseWriter, req *http.Request) { - if req.Method != http.MethodGet { - http.Error(w, "method not allowed", http.StatusMethodNotAllowed) + code := parseStatusCode(req.URL.Query().Get("code")) + time.Sleep(parseSleep(req.URL.Query().Get("sleepMs"))) + if handleFailureMode(w, req, req.URL.Query().Get("fail"), code) { return } - message := req.URL.Query().Get("message") - w.Header().Set("Content-Type", "text/plain; charset=utf-8") - _, _ = w.Write([]byte(message)) + switch req.Method { + case http.MethodGet: + message := req.URL.Query().Get("message") + w.Header().Set("Content-Type", "text/plain; charset=utf-8") + w.WriteHeader(code) + _, _ = w.Write([]byte(message)) + case http.MethodPost: + body, err := io.ReadAll(req.Body) + if err != nil { + http.Error(w, err.Error(), http.StatusBadRequest) + return + } + + w.Header().Set("Content-Type", "application/json; charset=utf-8") + w.WriteHeader(code) + _, _ = w.Write(body) + default: + http.Error(w, "method not allowed", http.StatusMethodNotAllowed) + } }) - log.Printf("upstream listening on http://%s:3001/echo?message=hello", upstreamHost) + log.Printf("upstream listening on http://%s:3001/echo?message=hello&code=201", upstreamHost) + log.Printf("upstream POST example: curl -i -X POST 'http://%s:3001/echo?code=202&sleepMs=200' -H 'content-type: application/json' -d '{\"message\":\"hello\"}'", upstreamHost) return http.ListenAndServe(":3001", mux) } +const frontendHTML = ` + + + + + Echo JSON Demo + + + +
+

Echo JSON Through Frontend

+
+ + +
+
+ + +
+
+ + +
+
+
+
+ + +
+
+ + +
+
+ +
+
Waiting for request...
+
+ + + +` + +func handleFailureMode(w http.ResponseWriter, req *http.Request, raw string, requestedCode int) bool { + mode := strings.ToLower(strings.TrimSpace(raw)) + if mode == "" || mode == "false" || mode == "0" || mode == "no" { + return false + } + + switch mode { + case "true", "drop": + hj, ok := w.(http.Hijacker) + if !ok { + http.Error(w, "drop mode not supported by server", http.StatusInternalServerError) + return true + } + + conn, _, err := hj.Hijack() + if err != nil { + http.Error(w, "failed to drop connection", http.StatusInternalServerError) + return true + } + _ = conn.Close() + return true + case "timeout", "hang": + <-req.Context().Done() + return true + case "status": + status := requestedCode + if status < 400 || status > 599 { + status = http.StatusInternalServerError + } + http.Error(w, fmt.Sprintf("forced failure (%d)", status), status) + return true + default: + if status, ok := parseFailureStatus(mode); ok { + http.Error(w, fmt.Sprintf("forced failure (%d)", status), status) + return true + } + + http.Error(w, "invalid fail mode", http.StatusBadRequest) + return true + } +} + +func parseFailureStatus(mode string) (int, bool) { + status, err := strconv.Atoi(mode) + if err != nil || status < 400 || status > 599 { + return 0, false + } + + return status, true +} + +func parseStatusCode(raw string) int { + if raw == "" { + return http.StatusOK + } + + code, err := strconv.Atoi(raw) + if err != nil || code < 100 || code > 999 { + return http.StatusOK + } + + return code +} + +func parseSleep(raw string) time.Duration { + ms, ok := parseMilliseconds(raw, 0) + if !ok { + return 0 + } + + return time.Duration(ms) * time.Millisecond +} + +func parseTimeout(raw string) time.Duration { + ms, ok := parseMilliseconds(raw, 5000) + if !ok { + return 5 * time.Second + } + if ms == 0 { + return 0 + } + + return time.Duration(ms) * time.Millisecond +} + +func parseMilliseconds(raw string, fallback int) (int, bool) { + if raw == "" { + return fallback, true + } + + ms, err := strconv.Atoi(raw) + if err != nil || ms < 0 { + return fallback, false + } + + return ms, true +} + func findNonLoopbackIPv4() (string, error) { addrs, err := net.InterfaceAddrs() if err != nil { diff --git a/internal/app/process.go b/internal/app/process.go index bdaae32..d049a96 100644 --- a/internal/app/process.go +++ b/internal/app/process.go @@ -18,13 +18,14 @@ func StartProcess(cmd model.Command, addr string, ch chan<- model.Message, sigCh proc := process.NewProcess(cmd, addr, ch) - if err := proc.Start(); err != nil { + if err := proc.Exec.Start(); err != nil { ch <- model.Message{ Type: model.MessageTypeProcessExited, Body: fmt.Sprintf("%q", err), } return } + process.UpdateStatus(proc, true, ch) // Listen for SIGTERM from main process go func() { @@ -32,23 +33,25 @@ func StartProcess(cmd model.Command, addr string, ch chan<- model.Message, sigCh ch <- model.Message{ Type: model.MessageTypeProcessSignaled, - Body: fmt.Sprintf("process with pid '%d' is being killed", proc.Process.Pid), - PID: proc.Process.Pid, + Body: fmt.Sprintf("process with pid '%d' is being killed", proc.Exec.Process.Pid), + PID: proc.Exec.Process.Pid, } - if proc.Process != nil { - _ = proc.Process.Signal(sig) + if proc.Exec != nil { + _ = proc.Exec.Process.Signal(sig) + process.UpdateStatus(proc, false, ch) } }() - if err := proc.Wait(); err != nil { + if err := proc.Exec.Wait(); err != nil { if exitErr, ok := errors.AsType[*exec.ExitError](err); ok { ch <- model.Message{ Type: model.MessageTypeProcessExited, - Body: "process killed itself", - PID: proc.Process.Pid, + Body: fmt.Sprintf("process pid '%d' exited by itself", proc.Exec.Process.Pid), + PID: proc.Exec.Process.Pid, ExitCode: exitErr.ExitCode(), } + process.UpdateStatus(proc, false, ch) return } @@ -56,6 +59,7 @@ func StartProcess(cmd model.Command, addr string, ch chan<- model.Message, sigCh Type: model.MessageTypeFatal, Body: fmt.Sprintf("%q", err), } + process.UpdateStatus(proc, false, ch) return } diff --git a/internal/app/proxy.go b/internal/app/proxy.go index b1d6856..002be92 100644 --- a/internal/app/proxy.go +++ b/internal/app/proxy.go @@ -16,7 +16,7 @@ func StartProxy(addr string, ch chan<- model.Message) { } return } - defer proxy.Destory(ps) + defer proxy.Destory(ps, ch) ch <- model.Message{ Type: model.MessageTypeProxyStarting, diff --git a/internal/app/session.go b/internal/app/session.go index f491af4..fddad2b 100644 --- a/internal/app/session.go +++ b/internal/app/session.go @@ -25,10 +25,15 @@ func StartSession(cmd model.Command, addr string) error { var events []model.Message + var requests []model.Request + for { select { case _ = <-sigCh: + fmt.Println("\n\nEVENTS") printEvents(events) + fmt.Println("\n\nREQUESTS") + printRequests(requests) return nil case msg := <-msgs: { @@ -36,6 +41,21 @@ func StartSession(cmd model.Command, addr string) error { switch msg.Type { case model.MessageTypeFatal: return fmt.Errorf("%s", msg.Body) + + case model.MessageTypeRequestStarted: + log.Printf("[%s] (%s) %s", msg.Type, msg.Request.ID.String(), msg.Body) + requests = append(requests, msg.Request) + + case model.MessageTypeRequestFinished, model.MessageTypeRequestFailed: + log.Printf("[%s] (%s) %s", msg.Type, msg.Request.ID.String(), msg.Body) + + for i := range requests { + if requests[i].ID == msg.Request.ID { + requests[i] = msg.Request + break + } + } + default: log.Printf("[%s] %s", msg.Type, msg.Body) } @@ -50,3 +70,12 @@ func printEvents(events []model.Message) { fmt.Printf("%+v\n", event) } } + +func printRequests(reqs []model.Request) { + for _, req := range reqs { + fmt.Printf("%+v\n", req) + for k, v := range req.QueryMap { + fmt.Printf("key: %s, vals: %+v\n", k, v) + } + } +} diff --git a/internal/model/command.go b/internal/model/command.go deleted file mode 100644 index 91becd5..0000000 --- a/internal/model/command.go +++ /dev/null @@ -1,6 +0,0 @@ -package model - -type Command struct { - Name string - Args []string -} diff --git a/internal/model/message.go b/internal/model/message.go index 3f5f957..207cd64 100644 --- a/internal/model/message.go +++ b/internal/model/message.go @@ -26,11 +26,9 @@ const ( ) type Message struct { - Type MessageType - Body string - PID int - RequestID string - URL string - Status int - ExitCode int + Type MessageType + Body string + PID int + ExitCode int + Request Request } diff --git a/internal/model/process.go b/internal/model/process.go new file mode 100644 index 0000000..59d600a --- /dev/null +++ b/internal/model/process.go @@ -0,0 +1,14 @@ +package model + +import "os/exec" + +type Command struct { + Name string + Args []string +} + +type Process struct { + Command Command + Exec *exec.Cmd + Running bool +} diff --git a/internal/model/proxy.go b/internal/model/proxy.go index 14b653f..ebb1e4e 100644 --- a/internal/model/proxy.go +++ b/internal/model/proxy.go @@ -3,6 +3,10 @@ package model import ( "net" "net/http" + "net/url" + "time" + + "github.com/google/uuid" ) type ProxyServer struct { @@ -10,3 +14,22 @@ type ProxyServer struct { Server *http.Server Url string } + +type Request struct { + ID uuid.UUID + Method string + ResponseData []byte + RequestData []byte + RawURL string + Host string + URL string + QueryString string + QueryMap url.Values + Status int + Duration time.Duration + Pending bool + Failed bool + StartTime time.Time + RequestHeaders http.Header + ResponseHeaders http.Header +} diff --git a/internal/process/runner.go b/internal/process/runner.go index 491ece4..4c192c9 100644 --- a/internal/process/runner.go +++ b/internal/process/runner.go @@ -15,7 +15,7 @@ func CommandString(c model.Command) string { return fmt.Sprintf("%s %s", c.Name, strings.Join(c.Args, " ")) } -func NewProcess(cmd model.Command, addr string, ch chan<- model.Message) *exec.Cmd { +func NewProcess(cmd model.Command, addr string, ch chan<- model.Message) *model.Process { proc := exec.Command(cmd.Name, cmd.Args...) injectEnv(proc, addr) @@ -42,7 +42,11 @@ func NewProcess(cmd model.Command, addr string, ch chan<- model.Message) *exec.C go readPipe(stderr, model.MessageTypeProcessStderr, ch) } - return proc + return &model.Process{ + Command: cmd, + Exec: proc, + Running: false, + } } func injectEnv(proc *exec.Cmd, addr string) { @@ -70,3 +74,33 @@ func readPipe(pipe io.Reader, t model.MessageType, ch chan<- model.Message) { } } } + +func UpdateStatus(proc *model.Process, running bool, ch chan<- model.Message) { + if proc == nil { + return + } + + if proc.Running == running { + return + } + + proc.Running = running + + var ( + t model.MessageType + status string + ) + if running { + t = model.MessageTypeProcessStarted + status = "running" + } else { + t = model.MessageTypeProcessExited + status = "stopped" + } + + ch <- model.Message{ + Type: t, + Body: fmt.Sprintf("Set process pid '%d' status to %s", proc.Exec.Process.Pid, status), + PID: proc.Exec.Process.Pid, + } +} diff --git a/internal/proxy/handler.go b/internal/proxy/handler.go index bc98600..3dc90fd 100644 --- a/internal/proxy/handler.go +++ b/internal/proxy/handler.go @@ -2,13 +2,17 @@ package proxy import ( "bytes" + "context" + "errors" "fmt" "io" + "net" "net/http" "sort" "strings" "time" + "github.com/google/uuid" "termtap.dev/internal/model" ) @@ -40,41 +44,86 @@ func proxyHandler(ch chan<- model.Message) http.Handler { } start := time.Now() - // requestPreview, err := readAndRestoreBody(&req.Body) - // if err != nil { - // http.Error(w, "failed to read request body", http.StatusBadRequest) - // log.Printf("!! read request body %s %s: %v", req.Method, req.URL.String(), err) - // return - // } + + request := model.Request{ + ID: uuid.New(), + ResponseData: []byte{}, + RequestData: []byte{}, + URL: "", + Status: -1, + Method: "", + Duration: 0, + Pending: true, + Failed: false, + StartTime: start, + } + + requestPreview, err := readAndRestoreBody(&req.Body) + if err != nil { + ch <- model.Message{ + Type: model.MessageTypeWarn, + Body: fmt.Sprintf("(%s) failed to read request body", request.ID), + Request: request, + } + } else { + request.RequestData = []byte(requestPreview) + } outReq := req.Clone(req.Context()) outReq.RequestURI = "" + + request.URL = outReq.URL.Path + request.QueryString = outReq.URL.RawQuery + request.QueryMap = outReq.URL.Query() + request.Host = outReq.Host + request.Method = outReq.Method + request.RequestHeaders = outReq.Header + request.RawURL = outReq.URL.String() + ch <- model.Message{ - Type: model.MessageTypeRequestStarted, - Body: fmt.Sprintf("-> %s %s", outReq.Method, outReq.URL.String()), + Type: model.MessageTypeRequestStarted, + Body: fmt.Sprintf("-> %+v", request), + Request: request, } resp, err := transport.RoundTrip(outReq) if err != nil { - http.Error(w, "bad gateway", http.StatusBadGateway) + status := statusFromUpstreamError(req, resp, err) + + http.Error(w, http.StatusText(status), status) + request.Pending = false + request.Failed = true + request.Duration = time.Since(start).Round(time.Microsecond) + request.Status = status + ch <- model.Message{ - Type: model.MessageTypeRequestFailed, - Body: fmt.Sprintf("upstream error for %s %s: %v", outReq.Method, outReq.URL.String(), err), + Type: model.MessageTypeRequestFailed, + Body: fmt.Sprintf("upstream error for %s %s: %v", outReq.Method, outReq.URL.String(), err), + Request: request, } return } defer resp.Body.Close() - // responsePreview, err := readAndRestoreBody(&resp.Body) - // if err != nil { - // http.Error(w, "bad gateway", http.StatusBadGateway) - // log.Printf("!! read response body %s %s: %v", outReq.Method, outReq.URL.String(), err) - // return - // } + responsePreview, err := readAndRestoreBody(&resp.Body) + if err != nil { + ch <- model.Message{ + Type: model.MessageTypeWarn, + Body: fmt.Sprintf("(%s) failed to read response body", request.ID), + Request: request, + } + } else { + request.ResponseData = []byte(responsePreview) + } copyHeader(w.Header(), resp.Header) w.WriteHeader(resp.StatusCode) if _, err := io.Copy(w, resp.Body); err != nil { + request.Pending = false + request.Failed = true + request.Duration = time.Since(start).Round(time.Microsecond) + request.Status = resp.StatusCode + ch <- model.Message{ Type: model.MessageTypeRequestFailed, Body: fmt.Sprintf("write response body %s %s: %v", outReq.Method, outReq.URL.String(), err), @@ -82,14 +131,15 @@ func proxyHandler(ch chan<- model.Message) http.Handler { return } + request.Duration = time.Since(start).Round(time.Microsecond) + request.Status = resp.StatusCode + request.ResponseHeaders = resp.Header + request.Pending = false + ch <- model.Message{ - Type: model.MessageTypeRequestFinished, - Body: fmt.Sprintf("<- %s %s %d %s", - outReq.Method, - outReq.URL.String(), - resp.StatusCode, - time.Since(start).Round(time.Millisecond), - ), + Type: model.MessageTypeRequestFinished, + Body: fmt.Sprintf("<- %+v %s", request, formatHeaders(resp.Request.Header)), + Request: request, } }) } @@ -140,3 +190,25 @@ func formatHeaders(headers http.Header) string { return strings.Join(parts, ", ") } + +// BUG: Not sure if this actually works, seems to favor the 502 +func statusFromUpstreamError(req *http.Request, resp *http.Response, err error) int { + if resp != nil { + return resp.StatusCode + } + + if errors.Is(req.Context().Err(), context.Canceled) { + return http.StatusBadGateway + } + + if errors.Is(err, context.DeadlineExceeded) { + return http.StatusGatewayTimeout + } + + var netErr net.Error + if errors.As(err, &netErr) && netErr.Timeout() { + return http.StatusGatewayTimeout + } + + return http.StatusBadGateway +} diff --git a/internal/proxy/server.go b/internal/proxy/server.go index 414c192..e4dd223 100644 --- a/internal/proxy/server.go +++ b/internal/proxy/server.go @@ -28,11 +28,15 @@ func NewProxyServer(addr string, ch chan<- model.Message) (*model.ProxyServer, e } // BUG: Not sure what all this does -func Destory(ps *model.ProxyServer) { +func Destory(ps *model.ProxyServer, ch chan<- model.Message) { ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second) defer cancel() if ps != nil && ps.Server != nil { _ = ps.Server.Shutdown(ctx) + ch <- model.Message{ + Type: model.MessageTypeProxyStopped, + Body: "proxy server was destroyed", + } } } diff --git a/proto/main.go b/proto/main.go deleted file mode 100644 index 40a76f3..0000000 --- a/proto/main.go +++ /dev/null @@ -1,224 +0,0 @@ -package main - -import ( - "bytes" - "context" - "errors" - "fmt" - "io" - "log" - "net" - "net/http" - "os" - "os/exec" - "sort" - "strings" - "time" - - "github.com/google/uuid" -) - -func main() { - if err := parseArgs(); err != nil { - panic(err) - } -} - -func parseArgs() error { - if len(os.Args) < 3 { - return fmt.Errorf("Must use this right") - } - - if os.Args[1] != "run" || os.Args[2] != "--" { - return fmt.Errorf("Must use this right") - } - - cmd := os.Args[3:] - return run(cmd) -} - -func run(cmd []string) error { - fmt.Printf("%+v\n", cmd) - - server, url, err := proxy() - if err != nil { - return err - } - - defer func() { - ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second) - defer cancel() - _ = server.Shutdown(ctx) - }() - - println(url) - - env := []string{ - "HTTP_PROXY=" + url, - "http_proxy=" + url, - "HTTPS_PROXY=" + url, - "https_proxy=" + url, - "ALL_PROXY=" + url, - "all_proxy=" + url, - "NO_PROXY=", - "no_proxy=", - } - - proc := exec.Command(cmd[0], cmd[1:]...) - proc.Stdin = os.Stdin - proc.Stdout = os.Stdout - proc.Stderr = os.Stderr - proc.Env = append(os.Environ(), env...) - - if err := proc.Start(); err != nil { - return err - } - - if err := proc.Wait(); err != nil { - var exitErr *exec.ExitError - if errors.As(err, &exitErr) { - os.Exit(exitErr.ExitCode()) - } - return fmt.Errorf("wait for command: %w", err) - } - - return nil -} - -func proxy() (*http.Server, string, error) { - addr := "127.0.0.1:8080" - listener, err := net.Listen("tcp", addr) - if err != nil { - return nil, "", err - } - - server := &http.Server{Handler: handler()} - - go func() { - if err := server.Serve(listener); err != nil { - fmt.Printf("%q", err) - } - }() - - url := "http://" + listener.Addr().String() - return server, url, nil -} - -func handler() http.Handler { - transport := http.DefaultTransport - - return http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { - if req.Method == http.MethodConnect { - http.Error(w, "CONNECT is not supported yet", http.StatusNotImplemented) - log.Printf("!! CONNECT %s not supported", req.Host) - return - } - - if req.URL.Scheme == "" || req.URL.Host == "" { - http.Error(w, "request must use absolute-form URLs through the proxy", http.StatusBadRequest) - log.Printf("!! rejected non-proxy request %s %s", req.Method, req.URL.String()) - return - } - - startedAt := time.Now() - id := uuid.New().String() - // requestPreview, err := readAndRestoreBody(&req.Body) - // if err != nil { - // http.Error(w, "failed to read request body", http.StatusBadRequest) - // log.Printf("!! read request body %s %s: %v", req.Method, req.URL.String(), err) - // return - // } - - outReq := req.Clone(req.Context()) - outReq.RequestURI = "" - - log.Printf( - "[%s] -> %s %s\n", - id, - outReq.Method, - outReq.URL.String(), - // formatHeaders(outReq.Header), - // requestPreview, - ) - - resp, err := transport.RoundTrip(outReq) - if err != nil { - http.Error(w, "bad gateway", http.StatusBadGateway) - log.Printf("!! upstream error for %s %s: %v", outReq.Method, outReq.URL.String(), err) - return - } - defer resp.Body.Close() - - // responsePreview, err := readAndRestoreBody(&resp.Body) - // if err != nil { - // http.Error(w, "bad gateway", http.StatusBadGateway) - // log.Printf("!! read response body %s %s: %v", outReq.Method, outReq.URL.String(), err) - // return - // } - - copyHeader(w.Header(), resp.Header) - w.WriteHeader(resp.StatusCode) - if _, err := io.Copy(w, resp.Body); err != nil { - log.Printf("!! write response body %s %s: %v", outReq.Method, outReq.URL.String(), err) - return - } - - log.Printf( - "[%s] <- %s %s %d %s\n", - id, - outReq.Method, - outReq.URL.String(), - resp.StatusCode, - time.Since(startedAt).Round(time.Millisecond), - ) - }) -} - -const maxPreviewBytes = 1024 - -func copyHeader(dst, src http.Header) { - for key, values := range src { - for _, value := range values { - dst.Add(key, value) - } - } -} - -func readAndRestoreBody(body *io.ReadCloser) (string, error) { - if body == nil || *body == nil { - return "", nil - } - - payload, err := io.ReadAll(*body) - if err != nil { - return "", err - } - - *body = io.NopCloser(bytes.NewReader(payload)) - - preview := payload - if len(preview) > maxPreviewBytes { - preview = preview[:maxPreviewBytes] - } - - text := strings.ReplaceAll(string(preview), "\n", "\\n") - if len(payload) > maxPreviewBytes { - text += "..." - } - - return text, nil -} - -func formatHeaders(headers http.Header) string { - if len(headers) == 0 { - return "" - } - - parts := make([]string, 0, len(headers)) - for key, values := range headers { - parts = append(parts, fmt.Sprintf("%s=%q", key, strings.Join(values, ","))) - } - sort.Strings(parts) - - return strings.Join(parts, ", ") -} diff --git a/proto/proxy.go b/proto/proxy.go deleted file mode 100644 index 0a60382..0000000 --- a/proto/proxy.go +++ /dev/null @@ -1,303 +0,0 @@ -package main - -import ( - "bytes" - "context" - "errors" - "flag" - "fmt" - "io" - "log" - "net" - "net/http" - "os" - "os/exec" - "os/signal" - "sort" - "strings" - "syscall" - "time" -) - -const maxPreviewBytes = 1024 - -func test() { - log.SetFlags(log.LstdFlags | log.Lmicroseconds) - - if len(os.Args) < 2 { - printUsage() - os.Exit(1) - } - - switch os.Args[1] { - case "run": - if err := runCommand(os.Args[2:]); err != nil { - log.Fatal(err) - } - case "proxy": - if err := runProxy(os.Args[2:]); err != nil { - log.Fatal(err) - } - default: - printUsage() - os.Exit(1) - } -} - -func printUsage() { - fmt.Fprintln(os.Stderr, "usage:") - fmt.Fprintln(os.Stderr, " tap run -- [args...]") - fmt.Fprintln(os.Stderr, " tap proxy [-listen 127.0.0.1:8080]") -} - -func runCommand(args []string) error { - runFlags := flag.NewFlagSet("run", flag.ExitOnError) - listenAddr := runFlags.String("listen", "127.0.0.1:0", "proxy listen address") - runFlags.SetOutput(io.Discard) - - if err := runFlags.Parse(args); err != nil { - return err - } - - commandArgs := runFlags.Args() - if len(commandArgs) == 0 { - return errors.New("run requires a command after `--`") - } - if commandArgs[0] == "--" { - commandArgs = commandArgs[1:] - } - if len(commandArgs) == 0 { - return errors.New("run requires a command after `--`") - } - - server, proxyURL, err := startProxy(*listenAddr) - if err != nil { - return err - } - defer shutdownServer(server) - - log.Printf("proxy listening on %s", proxyURL) - - cmd := exec.Command(commandArgs[0], commandArgs[1:]...) - cmd.Stdin = os.Stdin - cmd.Stdout = os.Stdout - cmd.Stderr = os.Stderr - cmd.Env = withProxyEnv(os.Environ(), proxyURL) - - if err := cmd.Start(); err != nil { - return fmt.Errorf("start command: %w", err) - } - - forwardSignals(cmd.Process) - - if err := cmd.Wait(); err != nil { - var exitErr *exec.ExitError - if errors.As(err, &exitErr) { - os.Exit(exitErr.ExitCode()) - } - return fmt.Errorf("wait for command: %w", err) - } - - return nil -} - -func runProxy(args []string) error { - proxyFlags := flag.NewFlagSet("proxy", flag.ExitOnError) - listenAddr := proxyFlags.String("listen", "127.0.0.1:8080", "proxy listen address") - if err := proxyFlags.Parse(args); err != nil { - return err - } - - server, proxyURL, err := startProxy(*listenAddr) - if err != nil { - return err - } - defer shutdownServer(server) - - log.Printf("proxy listening on %s", proxyURL) - - stop := make(chan os.Signal, 1) - signal.Notify(stop, os.Interrupt, syscall.SIGTERM) - defer signal.Stop(stop) - <-stop - - return nil -} - -func startProxy(listenAddr string) (*http.Server, string, error) { - listener, err := net.Listen("tcp", listenAddr) - if err != nil { - return nil, "", fmt.Errorf("listen on %s: %w", listenAddr, err) - } - - server := &http.Server{Handler: newForwardProxy()} - - go func() { - if err := server.Serve(listener); err != nil && !errors.Is(err, http.ErrServerClosed) { - log.Printf("proxy server error: %v", err) - } - }() - - proxyURL := "http://" + listener.Addr().String() - return server, proxyURL, nil -} - -func shutdownServer(server *http.Server) { - ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second) - defer cancel() - _ = server.Shutdown(ctx) -} - -func withProxyEnv(env []string, proxyURL string) []string { - filtered := make([]string, 0, len(env)+5) - for _, entry := range env { - if hasEnvKey(entry, "HTTP_PROXY") || hasEnvKey(entry, "http_proxy") || hasEnvKey(entry, "HTTPS_PROXY") || hasEnvKey(entry, "https_proxy") || hasEnvKey(entry, "ALL_PROXY") || hasEnvKey(entry, "all_proxy") || hasEnvKey(entry, "NO_PROXY") || hasEnvKey(entry, "no_proxy") { - continue - } - filtered = append(filtered, entry) - } - - filtered = append(filtered, - "HTTP_PROXY="+proxyURL, - "http_proxy="+proxyURL, - "HTTPS_PROXY="+proxyURL, - "https_proxy="+proxyURL, - "ALL_PROXY="+proxyURL, - "all_proxy="+proxyURL, - "NO_PROXY=", - "no_proxy=", - ) - - return filtered -} - -func hasEnvKey(entry, key string) bool { - return strings.HasPrefix(entry, key+"=") -} - -func forwardSignals(process *os.Process) { - ch := make(chan os.Signal, 1) - signal.Notify(ch, os.Interrupt, syscall.SIGTERM) - - go func() { - for sig := range ch { - _ = process.Signal(sig) - } - }() -} - -func newForwardProxy() http.Handler { - transport := http.DefaultTransport - - return http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { - if req.Method == http.MethodConnect { - http.Error(w, "CONNECT is not supported yet", http.StatusNotImplemented) - log.Printf("!! CONNECT %s not supported", req.Host) - return - } - - if req.URL.Scheme == "" || req.URL.Host == "" { - http.Error(w, "request must use absolute-form URLs through the proxy", http.StatusBadRequest) - log.Printf("!! rejected non-proxy request %s %s", req.Method, req.URL.String()) - return - } - - startedAt := time.Now() - requestPreview, err := readAndRestoreBody(&req.Body) - if err != nil { - http.Error(w, "failed to read request body", http.StatusBadRequest) - log.Printf("!! read request body %s %s: %v", req.Method, req.URL.String(), err) - return - } - - outReq := req.Clone(req.Context()) - outReq.RequestURI = "" - - log.Printf( - "-> %s %s\n request headers: %s\n request body: %q", - outReq.Method, - outReq.URL.String(), - formatHeaders(outReq.Header), - requestPreview, - ) - - resp, err := transport.RoundTrip(outReq) - if err != nil { - http.Error(w, "bad gateway", http.StatusBadGateway) - log.Printf("!! upstream error for %s %s: %v", outReq.Method, outReq.URL.String(), err) - return - } - defer resp.Body.Close() - - responsePreview, err := readAndRestoreBody(&resp.Body) - if err != nil { - http.Error(w, "bad gateway", http.StatusBadGateway) - log.Printf("!! read response body %s %s: %v", outReq.Method, outReq.URL.String(), err) - return - } - - copyHeader(w.Header(), resp.Header) - w.WriteHeader(resp.StatusCode) - if _, err := io.Copy(w, resp.Body); err != nil { - log.Printf("!! write response body %s %s: %v", outReq.Method, outReq.URL.String(), err) - return - } - - log.Printf( - "<- %s %s %d %s\n response headers: %s\n response body: %q", - outReq.Method, - outReq.URL.String(), - resp.StatusCode, - time.Since(startedAt).Round(time.Millisecond), - formatHeaders(resp.Header), - responsePreview, - ) - }) -} - -func copyHeader(dst, src http.Header) { - for key, values := range src { - for _, value := range values { - dst.Add(key, value) - } - } -} - -func readAndRestoreBody(body *io.ReadCloser) (string, error) { - if body == nil || *body == nil { - return "", nil - } - - payload, err := io.ReadAll(*body) - if err != nil { - return "", err - } - - *body = io.NopCloser(bytes.NewReader(payload)) - - preview := payload - if len(preview) > maxPreviewBytes { - preview = preview[:maxPreviewBytes] - } - - text := strings.ReplaceAll(string(preview), "\n", "\\n") - if len(payload) > maxPreviewBytes { - text += "..." - } - - return text, nil -} - -func formatHeaders(headers http.Header) string { - if len(headers) == 0 { - return "" - } - - parts := make([]string, 0, len(headers)) - for key, values := range headers { - parts = append(parts, fmt.Sprintf("%s=%q", key, strings.Join(values, ","))) - } - sort.Strings(parts) - - return strings.Join(parts, ", ") -} diff --git a/proto/server.go b/proto/server.go deleted file mode 100644 index 3b044b8..0000000 --- a/proto/server.go +++ /dev/null @@ -1,42 +0,0 @@ -package main - -import ( - "fmt" - "io" - "net/http" -) - -func main() { - if err := startDemoServer("127.0.0.1:3000"); err != nil { - panic(err) - } -} - -func startDemoServer(addr string) error { - mux := http.NewServeMux() - mux.HandleFunc("/send", func(w http.ResponseWriter, req *http.Request) { - if req.Method != http.MethodGet { - http.Error(w, "method not allowed", http.StatusMethodNotAllowed) - return - } - - resp, err := http.Get("http://example.com") - if err != nil { - http.Error(w, err.Error(), http.StatusBadGateway) - return - } - defer resp.Body.Close() - - body, err := io.ReadAll(resp.Body) - if err != nil { - http.Error(w, err.Error(), http.StatusBadGateway) - return - } - - w.WriteHeader(http.StatusOK) - _, _ = fmt.Fprintf(w, "sent request to http://example.com\nstatus: %s\nbytes: %d\n", resp.Status, len(body)) - }) - - fmt.Printf("demo server listening on http://%s/send\n", addr) - return http.ListenAndServe(addr, mux) -}