267 lines
6.6 KiB
Go
267 lines
6.6 KiB
Go
package proxy
|
|
|
|
import (
|
|
"errors"
|
|
"math/big"
|
|
"os"
|
|
"path/filepath"
|
|
"testing"
|
|
)
|
|
|
|
func TestNormalizeCertHost(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
tests := []struct {
|
|
name string
|
|
in string
|
|
want string
|
|
}{
|
|
{name: "host and port", in: "example.com:443", want: "example.com"},
|
|
{name: "plain host", in: "example.com", want: "example.com"},
|
|
{name: "whitespace trims", in: " example.com:8443 ", want: "example.com"},
|
|
{name: "empty", in: " ", want: ""},
|
|
}
|
|
|
|
for _, tt := range tests {
|
|
tt := tt
|
|
t.Run(tt.name, func(t *testing.T) {
|
|
t.Parallel()
|
|
if got := normalizeCertHost(tt.in); got != tt.want {
|
|
t.Fatalf("normalizeCertHost(%q) = %q, want %q", tt.in, got, tt.want)
|
|
}
|
|
})
|
|
}
|
|
}
|
|
|
|
func TestRandSerialNumber(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
serial, err := randSerialNumber()
|
|
if err != nil {
|
|
t.Fatalf("randSerialNumber() error = %v", err)
|
|
}
|
|
if serial == nil {
|
|
t.Fatal("serial is nil")
|
|
}
|
|
if serial.Sign() < 0 {
|
|
t.Fatalf("serial must be non-negative, got %v", serial)
|
|
}
|
|
|
|
limit := new(big.Int).Lsh(big.NewInt(1), 128)
|
|
if serial.Cmp(limit) >= 0 {
|
|
t.Fatalf("serial must be < 2^128, got %v", serial)
|
|
}
|
|
}
|
|
|
|
func TestWriteFileAtomically(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
dir := t.TempDir()
|
|
path := filepath.Join(dir, "cert.pem")
|
|
|
|
if err := writeFileAtomically(path, []byte("first"), 0o600); err != nil {
|
|
t.Fatalf("first writeFileAtomically() error = %v", err)
|
|
}
|
|
if err := writeFileAtomically(path, []byte("second"), 0o600); err != nil {
|
|
t.Fatalf("second writeFileAtomically() error = %v", err)
|
|
}
|
|
|
|
data, err := os.ReadFile(path)
|
|
if err != nil {
|
|
t.Fatalf("ReadFile() error = %v", err)
|
|
}
|
|
if got, want := string(data), "second"; got != want {
|
|
t.Fatalf("file contents = %q, want %q", got, want)
|
|
}
|
|
|
|
info, err := os.Stat(path)
|
|
if err != nil {
|
|
t.Fatalf("Stat() error = %v", err)
|
|
}
|
|
if got := info.Mode().Perm(); got != 0o600 {
|
|
t.Fatalf("file permissions = %#o, want %#o", got, 0o600)
|
|
}
|
|
}
|
|
|
|
func TestCertificateAuthority_Basics(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
var nilCA *CertificateAuthority
|
|
if got := nilCA.CertPath(); got != "" {
|
|
t.Fatalf("nil CertPath() = %q, want empty", got)
|
|
}
|
|
if got := nilCA.WasCreated(); got {
|
|
t.Fatalf("nil WasCreated() = %v, want false", got)
|
|
}
|
|
|
|
ca := newTestCA(t)
|
|
ca.certPath = "/tmp/test-ca.pem"
|
|
ca.wasCreated = true
|
|
if got, want := ca.CertPath(), "/tmp/test-ca.pem"; got != want {
|
|
t.Fatalf("CertPath() = %q, want %q", got, want)
|
|
}
|
|
if !ca.WasCreated() {
|
|
t.Fatal("WasCreated() = false, want true")
|
|
}
|
|
}
|
|
|
|
func TestCertificateForHost(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
ca := newTestCA(t)
|
|
|
|
t.Run("empty host returns error", func(t *testing.T) {
|
|
t.Parallel()
|
|
cert, err := ca.CertificateForHost(" ")
|
|
if err == nil {
|
|
t.Fatal("CertificateForHost() error = nil, want non-nil")
|
|
}
|
|
if cert != nil {
|
|
t.Fatalf("cert = %#v, want nil", cert)
|
|
}
|
|
})
|
|
|
|
t.Run("cache hit returns same pointer", func(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
c1, err := ca.CertificateForHost("example.com:443")
|
|
if err != nil {
|
|
t.Fatalf("first CertificateForHost() error = %v", err)
|
|
}
|
|
c2, err := ca.CertificateForHost("example.com")
|
|
if err != nil {
|
|
t.Fatalf("second CertificateForHost() error = %v", err)
|
|
}
|
|
|
|
if c1 != c2 {
|
|
t.Fatal("expected same certificate pointer from cache")
|
|
}
|
|
})
|
|
|
|
t.Run("ip and dns SAN selection", func(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
ipCert, err := ca.CertificateForHost("127.0.0.1:443")
|
|
if err != nil {
|
|
t.Fatalf("ip CertificateForHost() error = %v", err)
|
|
}
|
|
if ipCert.Leaf == nil {
|
|
t.Fatal("ip cert leaf is nil")
|
|
}
|
|
if len(ipCert.Leaf.IPAddresses) == 0 {
|
|
t.Fatal("ip cert should contain IP SAN")
|
|
}
|
|
if len(ipCert.Leaf.DNSNames) != 0 {
|
|
t.Fatalf("ip cert DNSNames = %v, want empty", ipCert.Leaf.DNSNames)
|
|
}
|
|
|
|
dnsCert, err := ca.CertificateForHost("service.local")
|
|
if err != nil {
|
|
t.Fatalf("dns CertificateForHost() error = %v", err)
|
|
}
|
|
if dnsCert.Leaf == nil {
|
|
t.Fatal("dns cert leaf is nil")
|
|
}
|
|
if len(dnsCert.Leaf.DNSNames) == 0 {
|
|
t.Fatal("dns cert should contain DNS SAN")
|
|
}
|
|
})
|
|
|
|
t.Run("evicts oldest entry over maxLeafCerts", func(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
ca2 := newTestCA(t)
|
|
for i := 0; i < maxLeafCerts+1; i++ {
|
|
host := filepath.Base(filepath.Join("h", big.NewInt(int64(i)).String()+".example"))
|
|
if _, err := ca2.CertificateForHost(host); err != nil {
|
|
t.Fatalf("CertificateForHost(%q) error = %v", host, err)
|
|
}
|
|
}
|
|
|
|
if len(ca2.leafOrder) != maxLeafCerts {
|
|
t.Fatalf("leafOrder len = %d, want %d", len(ca2.leafOrder), maxLeafCerts)
|
|
}
|
|
if _, ok := ca2.leafCert["0.example"]; ok {
|
|
t.Fatal("expected oldest cert to be evicted")
|
|
}
|
|
})
|
|
}
|
|
|
|
func TestIsTrustedBySystem(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
var nilCA *CertificateAuthority
|
|
_, err := nilCA.IsTrustedBySystem()
|
|
if err == nil {
|
|
t.Fatal("nil IsTrustedBySystem() error = nil, want non-nil")
|
|
}
|
|
|
|
ca := &CertificateAuthority{}
|
|
_, err = ca.IsTrustedBySystem()
|
|
if err == nil {
|
|
t.Fatal("missing-cert IsTrustedBySystem() error = nil, want non-nil")
|
|
}
|
|
|
|
t.Run("untrusted generated CA returns false without error", func(t *testing.T) {
|
|
t.Parallel()
|
|
ca := newTestCA(t)
|
|
|
|
trusted, err := ca.IsTrustedBySystem()
|
|
if err != nil {
|
|
t.Fatalf("IsTrustedBySystem() error = %v, want nil for unknown authority", err)
|
|
}
|
|
if trusted {
|
|
t.Fatal("trusted = true, want false for generated test CA")
|
|
}
|
|
})
|
|
}
|
|
|
|
func TestEnsureCertificateAuthority(t *testing.T) {
|
|
configRoot := t.TempDir()
|
|
t.Setenv("XDG_CONFIG_HOME", configRoot)
|
|
|
|
ca, err := EnsureCertificateAuthority()
|
|
if err != nil {
|
|
t.Fatalf("EnsureCertificateAuthority() error = %v", err)
|
|
}
|
|
if ca == nil {
|
|
t.Fatal("EnsureCertificateAuthority() returned nil CA")
|
|
}
|
|
if ca.CertPath() == "" {
|
|
t.Fatal("EnsureCertificateAuthority() returned empty cert path")
|
|
}
|
|
if _, statErr := os.Stat(ca.CertPath()); statErr != nil {
|
|
t.Fatalf("expected cert on disk, stat error = %v", statErr)
|
|
}
|
|
}
|
|
|
|
func TestWriteFileAtomically_ErrorPath(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
err := writeFileAtomically(filepath.Join("/nope", "bad", "path.pem"), []byte("x"), 0o600)
|
|
if err == nil {
|
|
t.Fatal("writeFileAtomically() error = nil, want non-nil")
|
|
}
|
|
if errors.Is(err, os.ErrNotExist) {
|
|
return
|
|
}
|
|
// Accept platform-dependent fs errors as long as function fails.
|
|
}
|
|
|
|
func TestWriteFileAtomically_RenameErrorWhenTargetIsDirectory(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
dir := t.TempDir()
|
|
targetDir := filepath.Join(dir, "target-as-dir")
|
|
if err := os.MkdirAll(targetDir, 0o700); err != nil {
|
|
t.Fatalf("MkdirAll(targetDir) error = %v", err)
|
|
}
|
|
|
|
err := writeFileAtomically(targetDir, []byte("x"), 0o600)
|
|
if err == nil {
|
|
t.Fatal("writeFileAtomically() error = nil, want non-nil")
|
|
}
|
|
}
|
|
|
|
// TODO: Add deterministic tests for loadOrCreateCertificateAuthority trust-store interactions.
|