139 lines
3.9 KiB
Go
139 lines
3.9 KiB
Go
package proxy
|
|
|
|
import (
|
|
"bufio"
|
|
"bytes"
|
|
"io"
|
|
"net"
|
|
"strings"
|
|
"testing"
|
|
"time"
|
|
)
|
|
|
|
// captureConn is a net.Conn test double that records every byte written to it
// in the embedded bytes.Buffer (inspect with String or Len). Read always
// reports io.EOF, and the remaining net.Conn methods are no-ops that succeed.
type captureConn struct {
	bytes.Buffer
}

// Compile-time check that *captureConn satisfies net.Conn.
var _ net.Conn = (*captureConn)(nil)

// Read always reports EOF. It shadows the promoted bytes.Buffer.Read, so
// captured output is never consumed by the read side of the connection.
func (c *captureConn) Read(_ []byte) (int, error) { return 0, io.EOF }

func (c *captureConn) Close() error                       { return nil }
func (c *captureConn) LocalAddr() net.Addr                { return &net.TCPAddr{} }
func (c *captureConn) RemoteAddr() net.Addr               { return &net.TCPAddr{} }
func (c *captureConn) SetDeadline(_ time.Time) error      { return nil }
func (c *captureConn) SetReadDeadline(_ time.Time) error  { return nil }
func (c *captureConn) SetWriteDeadline(_ time.Time) error { return nil }
|
|
|
|
// failWriteConn is a net.Conn test double whose Write always fails with
// io.ErrClosedPipe, for exercising write-error paths. Read reports io.EOF and
// the remaining net.Conn methods are no-ops that succeed.
type failWriteConn struct{}

// Compile-time check that the value type failWriteConn satisfies net.Conn
// (tests pass it as failWriteConn{}, not a pointer).
var _ net.Conn = failWriteConn{}

func (failWriteConn) Read(_ []byte) (int, error) { return 0, io.EOF }

// Write never accepts data and always returns io.ErrClosedPipe.
func (failWriteConn) Write(_ []byte) (int, error) { return 0, io.ErrClosedPipe }

func (failWriteConn) Close() error                       { return nil }
func (failWriteConn) LocalAddr() net.Addr                { return &net.TCPAddr{} }
func (failWriteConn) RemoteAddr() net.Addr               { return &net.TCPAddr{} }
func (failWriteConn) SetDeadline(_ time.Time) error      { return nil }
func (failWriteConn) SetReadDeadline(_ time.Time) error  { return nil }
func (failWriteConn) SetWriteDeadline(_ time.Time) error { return nil }
|
|
|
|
// trackingBody is an io.ReadCloser test double that serves bytes from data
// while recording how many bytes were read and whether Close was called.
//
// Reads after Close fail with io.ErrClosedPipe, mirroring net/http response
// bodies, which also reject reads once closed. An implementation that closes
// a body before draining it therefore shows up as a short readN instead of
// silently passing.
type trackingBody struct {
	data   *bytes.Reader
	readN  int  // total bytes returned by Read so far
	closed bool // set once Close has been called
}

// Compile-time check that *trackingBody satisfies io.ReadCloser.
var _ io.ReadCloser = (*trackingBody)(nil)

// Read delegates to data and adds the number of bytes returned to readN.
// After Close it returns (0, io.ErrClosedPipe) without touching data.
func (b *trackingBody) Read(p []byte) (int, error) {
	if b.closed {
		return 0, io.ErrClosedPipe
	}
	n, err := b.data.Read(p)
	b.readN += n
	return n, err
}

// Close marks the body as closed and always succeeds; repeated calls are harmless.
func (b *trackingBody) Close() error {
	b.closed = true
	return nil
}
|
|
|
|
func TestWriteConnectEstablished(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
t.Run("writes directly to raw conn", func(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
conn := &captureConn{}
|
|
if err := writeConnectEstablished(conn, nil); err != nil {
|
|
t.Fatalf("writeConnectEstablished() error = %v", err)
|
|
}
|
|
|
|
if got, want := conn.String(), "HTTP/1.1 200 Connection Established\r\n\r\n"; got != want {
|
|
t.Fatalf("raw write = %q, want %q", got, want)
|
|
}
|
|
})
|
|
|
|
t.Run("writes and flushes with buffered readWriter", func(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
conn := &captureConn{}
|
|
rw := bufio.NewReadWriter(bufio.NewReader(strings.NewReader("")), bufio.NewWriter(conn))
|
|
if err := writeConnectEstablished(conn, rw); err != nil {
|
|
t.Fatalf("writeConnectEstablished() error = %v", err)
|
|
}
|
|
|
|
if got, want := conn.String(), "HTTP/1.1 200 Connection Established\r\n\r\n"; got != want {
|
|
t.Fatalf("buffered write = %q, want %q", got, want)
|
|
}
|
|
})
|
|
|
|
t.Run("returns flush error", func(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
rw := bufio.NewReadWriter(bufio.NewReader(strings.NewReader("")), bufio.NewWriter(errWriter{}))
|
|
err := writeConnectEstablished(&captureConn{}, rw)
|
|
if err == nil {
|
|
t.Fatal("writeConnectEstablished() error = nil, want non-nil")
|
|
}
|
|
})
|
|
|
|
t.Run("returns buffered write error when writer already failed", func(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
bw := bufio.NewWriter(errWriter{})
|
|
_ = bw.Flush() // set sticky error to force WriteString error path
|
|
rw := bufio.NewReadWriter(bufio.NewReader(strings.NewReader("")), bw)
|
|
|
|
err := writeConnectEstablished(&captureConn{}, rw)
|
|
if err == nil {
|
|
t.Fatal("writeConnectEstablished() error = nil, want non-nil")
|
|
}
|
|
})
|
|
|
|
t.Run("returns raw conn write error", func(t *testing.T) {
|
|
t.Parallel()
|
|
err := writeConnectEstablished(failWriteConn{}, nil)
|
|
if err == nil {
|
|
t.Fatal("writeConnectEstablished() error = nil, want non-nil")
|
|
}
|
|
})
|
|
}
|
|
|
|
func TestDiscardAndCloseBody(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
t.Run("nil body is safe", func(t *testing.T) {
|
|
t.Parallel()
|
|
discardAndCloseBody(nil)
|
|
})
|
|
|
|
t.Run("closes body and discards at most limit", func(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
payload := bytes.Repeat([]byte("x"), maxDiscardBodyBytes+128)
|
|
body := &trackingBody{data: bytes.NewReader(payload)}
|
|
|
|
discardAndCloseBody(body)
|
|
|
|
if !body.closed {
|
|
t.Fatal("body was not closed")
|
|
}
|
|
if body.readN != maxDiscardBodyBytes {
|
|
t.Fatalf("bytes read = %d, want %d", body.readN, maxDiscardBodyBytes)
|
|
}
|
|
})
|
|
}
|