termtap/internal/proxy/certs_error_test.go
Hayden Hargreaves 002773e77f test: AI generated all of these tests
Just for the MVP of course. Need to validate the idea.
2026-04-23 19:47:04 -07:00

243 lines
6.5 KiB
Go

package proxy
import (
"crypto/tls"
"encoding/pem"
"os"
"path/filepath"
"strings"
"testing"
)
// TestLoadOrCreateCertificateAuthority_RecreatesWhenKeyMissing seeds the CA
// directory with a placeholder certificate but no key. It then expects
// loadOrCreateCertificateAuthority to report a freshly created CA and to leave
// a key file behind.
//
// NOTE(review): the test assumes the CA directory is resolved relative to
// XDG_CONFIG_HOME; confirm this holds on every platform the suite runs on.
// No t.Parallel here: t.Setenv cannot be used in parallel tests.
func TestLoadOrCreateCertificateAuthority_RecreatesWhenKeyMissing(t *testing.T) {
	root := t.TempDir()
	t.Setenv("XDG_CONFIG_HOME", root)

	caDir := filepath.Join(root, caDirName)
	if mkErr := os.MkdirAll(caDir, 0o700); mkErr != nil {
		t.Fatalf("MkdirAll() error = %v", mkErr)
	}

	// Only the certificate is seeded; the key is intentionally absent.
	staleCert := filepath.Join(caDir, caCertName)
	if writeErr := os.WriteFile(staleCert, []byte("stale-cert"), 0o600); writeErr != nil {
		t.Fatalf("WriteFile(cert) error = %v", writeErr)
	}

	authority, loadErr := loadOrCreateCertificateAuthority()
	switch {
	case loadErr != nil:
		t.Fatalf("loadOrCreateCertificateAuthority() error = %v", loadErr)
	case !authority.WasCreated():
		t.Fatal("WasCreated = false, want true when key is missing")
	}

	keyFile := filepath.Join(caDir, caKeyName)
	if _, statErr := os.Stat(keyFile); statErr != nil {
		t.Fatalf("expected key file to be created, stat error = %v", statErr)
	}
}
// TestLoadOrCreateCertificateAuthority_LoadErrorOnCorruptFiles places non-PEM
// bytes in both the CA certificate and key files. It then expects
// loadOrCreateCertificateAuthority to return a non-nil error.
//
// No t.Parallel here: t.Setenv cannot be used in parallel tests.
func TestLoadOrCreateCertificateAuthority_LoadErrorOnCorruptFiles(t *testing.T) {
	root := t.TempDir()
	t.Setenv("XDG_CONFIG_HOME", root)

	caDir := filepath.Join(root, caDirName)
	if mkErr := os.MkdirAll(caDir, 0o700); mkErr != nil {
		t.Fatalf("MkdirAll() error = %v", mkErr)
	}

	// Seed the certificate first, then the key, each with the same garbage.
	seeds := []struct {
		label string
		name  string
	}{
		{label: "cert", name: caCertName},
		{label: "key", name: caKeyName},
	}
	for _, seed := range seeds {
		target := filepath.Join(caDir, seed.name)
		if writeErr := os.WriteFile(target, []byte("not-a-pem"), 0o600); writeErr != nil {
			t.Fatalf("WriteFile(%s) error = %v", seed.label, writeErr)
		}
	}

	if _, loadErr := loadOrCreateCertificateAuthority(); loadErr == nil {
		t.Fatal("loadOrCreateCertificateAuthority() error = nil, want non-nil")
	}
}
// TestCertificateAuthorityLoad_ErrorPaths checks that load returns an error
// naming the failing stage for each malformed cert/key combination.
//
// The cases break, in turn, the cert PEM framing, the cert DER payload, the key
// PEM framing and the key DER payload. The key-stage cases pair a bad key with
// a valid certificate so that the expected key-stage message is the one
// reported.
func TestCertificateAuthorityLoad_ErrorPaths(t *testing.T) {
	t.Parallel()

	// Build the test CA once and share its PEM-encoded certificate across the
	// key-stage cases instead of constructing a separate CA per table entry.
	// The slice is only read by the subtests, so sharing it is safe.
	validCertPEM := pem.EncodeToMemory(&pem.Block{Type: "CERTIFICATE", Bytes: newTestCA(t).cert.Raw})

	tests := []struct {
		name      string
		certBytes []byte
		keyBytes  []byte
		wantPart  string
	}{
		{
			name:      "invalid cert pem",
			certBytes: []byte("bad-cert"),
			keyBytes:  []byte("bad-key"),
			wantPart:  "decode ca cert pem",
		},
		{
			// Well-formed PEM block whose DER body is not a certificate.
			name:      "parse cert fails",
			certBytes: pem.EncodeToMemory(&pem.Block{Type: "CERTIFICATE", Bytes: []byte("bogus")}),
			keyBytes:  []byte("bad-key"),
			wantPart:  "parse ca cert",
		},
		{
			name:      "invalid key pem",
			certBytes: validCertPEM,
			keyBytes:  []byte("bad-key"),
			wantPart:  "decode ca key pem",
		},
		{
			// Well-formed PEM block whose DER body is not a private key.
			name:      "parse key fails",
			certBytes: validCertPEM,
			keyBytes:  pem.EncodeToMemory(&pem.Block{Type: "EC PRIVATE KEY", Bytes: []byte("bogus")}),
			wantPart:  "parse ca key",
		},
	}
	for _, tt := range tests {
		tt := tt // keep per-iteration capture correct on pre-1.22 toolchains
		t.Run(tt.name, func(t *testing.T) {
			t.Parallel()
			dir := t.TempDir()
			ca := &CertificateAuthority{
				certPath: filepath.Join(dir, caCertName),
				keyPath:  filepath.Join(dir, caKeyName),
				leafCert: make(map[string]*tls.Certificate),
			}
			if err := os.WriteFile(ca.certPath, tt.certBytes, 0o600); err != nil {
				t.Fatalf("write cert file error = %v", err)
			}
			if err := os.WriteFile(ca.keyPath, tt.keyBytes, 0o600); err != nil {
				t.Fatalf("write key file error = %v", err)
			}
			err := ca.load()
			if err == nil {
				t.Fatal("load() error = nil, want non-nil")
			}
			if !strings.Contains(err.Error(), tt.wantPart) {
				t.Fatalf("load() error = %q, want contains %q", err.Error(), tt.wantPart)
			}
		})
	}
}
// TestCertificateAuthorityCreate_ErrorWhenWritePathInvalid verifies that
// create returns an error when its output paths are unwritable.
//
// Both paths live beneath a regular file inside a per-test temp directory.
// Nothing can be created under a regular file, so the write fails whether or
// not create tries to make parent directories, and regardless of the user's
// privileges. No file can ever be written outside the temp directory.
func TestCertificateAuthorityCreate_ErrorWhenWritePathInvalid(t *testing.T) {
	t.Parallel()
	dir := t.TempDir()

	// A regular file used as a path component: anything "under" it is invalid.
	blocker := filepath.Join(dir, "not-a-dir")
	if err := os.WriteFile(blocker, []byte("blocker"), 0o600); err != nil {
		t.Fatalf("WriteFile(blocker) error = %v", err)
	}

	ca := &CertificateAuthority{
		certPath: filepath.Join(blocker, "missing", "ca-cert.pem"),
		keyPath:  filepath.Join(blocker, "missing", "ca-key.pem"),
		leafCert: make(map[string]*tls.Certificate),
	}
	if err := ca.create(); err == nil {
		t.Fatal("create() error = nil, want non-nil")
	}
}
// TestCertificateAuthorityCreate_WriteErrorPaths checks that create wraps
// file-write failures with a message naming the file it was writing.
//
// In each case one output path is pre-created as a directory, so writing a
// file at that path fails. The error must mention "write ca cert" or
// "write ca key" accordingly.
func TestCertificateAuthorityCreate_WriteErrorPaths(t *testing.T) {
	t.Parallel()

	cases := []struct {
		name     string
		certFile string // base name of certPath inside the temp dir
		keyFile  string // base name of keyPath inside the temp dir
		asDir    string // which of certFile/keyFile is pre-created as a directory
		dirLabel string // label used in the MkdirAll failure message
		wantPart string
	}{
		{
			name:     "write ca cert wraps error",
			certFile: "cert-as-dir",
			keyFile:  "ca-key.pem",
			asDir:    "cert-as-dir",
			dirLabel: "cert dir",
			wantPart: "write ca cert",
		},
		{
			name:     "write ca key wraps error",
			certFile: "ca-cert.pem",
			keyFile:  "key-as-dir",
			asDir:    "key-as-dir",
			dirLabel: "key dir",
			wantPart: "write ca key",
		},
	}

	for _, tc := range cases {
		tc := tc // keep per-iteration capture correct on pre-1.22 toolchains
		t.Run(tc.name, func(t *testing.T) {
			t.Parallel()
			dir := t.TempDir()

			if mkErr := os.MkdirAll(filepath.Join(dir, tc.asDir), 0o700); mkErr != nil {
				t.Fatalf("MkdirAll(%s) error = %v", tc.dirLabel, mkErr)
			}

			authority := &CertificateAuthority{
				certPath: filepath.Join(dir, tc.certFile),
				keyPath:  filepath.Join(dir, tc.keyFile),
				leafCert: make(map[string]*tls.Certificate),
			}

			createErr := authority.create()
			if createErr == nil {
				t.Fatal("create() error = nil, want non-nil")
			}
			if msg := createErr.Error(); !strings.Contains(msg, tc.wantPart) {
				t.Fatalf("create() error = %q, want contains %q", msg, tc.wantPart)
			}
		})
	}
}
// TestCertificateAuthorityLoad_ReadErrorPaths checks that load wraps
// file-read failures with a message naming the file it could not read.
//
// When neither file exists, the test expects the certificate read error. When
// only a valid certificate exists, it expects the key read error.
func TestCertificateAuthorityLoad_ReadErrorPaths(t *testing.T) {
	t.Parallel()

	// expectLoadErr calls load on authority and fails the (sub)test unless it
	// returns an error whose message contains want.
	expectLoadErr := func(t *testing.T, authority *CertificateAuthority, want string) {
		t.Helper()
		loadErr := authority.load()
		if loadErr == nil {
			t.Fatal("load() error = nil, want non-nil")
		}
		if msg := loadErr.Error(); !strings.Contains(msg, want) {
			t.Fatalf("load() error = %q, want contains %q", msg, want)
		}
	}

	t.Run("read cert failure", func(t *testing.T) {
		t.Parallel()
		dir := t.TempDir()
		expectLoadErr(t, &CertificateAuthority{
			certPath: filepath.Join(dir, "missing-cert.pem"),
			keyPath:  filepath.Join(dir, "missing-key.pem"),
			leafCert: make(map[string]*tls.Certificate),
		}, "read ca cert")
	})

	t.Run("read key failure", func(t *testing.T) {
		t.Parallel()
		dir := t.TempDir()
		certFile := filepath.Join(dir, "cert.pem")

		// A real certificate on disk, so only the missing key can cause the failure.
		encoded := pem.EncodeToMemory(&pem.Block{Type: "CERTIFICATE", Bytes: newTestCA(t).cert.Raw})
		if writeErr := os.WriteFile(certFile, encoded, 0o600); writeErr != nil {
			t.Fatalf("WriteFile(cert) error = %v", writeErr)
		}

		expectLoadErr(t, &CertificateAuthority{
			certPath: certFile,
			keyPath:  filepath.Join(dir, "missing-key.pem"),
			leafCert: make(map[string]*tls.Certificate),
		}, "read ca key")
	})
}