243 lines
6.5 KiB
Go
243 lines
6.5 KiB
Go
package proxy
|
|
|
|
import (
|
|
"crypto/tls"
|
|
"encoding/pem"
|
|
"os"
|
|
"path/filepath"
|
|
"strings"
|
|
"testing"
|
|
)
|
|
|
|
func TestLoadOrCreateCertificateAuthority_RecreatesWhenKeyMissing(t *testing.T) {
|
|
configRoot := t.TempDir()
|
|
t.Setenv("XDG_CONFIG_HOME", configRoot)
|
|
|
|
baseDir := filepath.Join(configRoot, caDirName)
|
|
if err := os.MkdirAll(baseDir, 0o700); err != nil {
|
|
t.Fatalf("MkdirAll() error = %v", err)
|
|
}
|
|
|
|
certPath := filepath.Join(baseDir, caCertName)
|
|
if err := os.WriteFile(certPath, []byte("stale-cert"), 0o600); err != nil {
|
|
t.Fatalf("WriteFile(cert) error = %v", err)
|
|
}
|
|
|
|
ca, err := loadOrCreateCertificateAuthority()
|
|
if err != nil {
|
|
t.Fatalf("loadOrCreateCertificateAuthority() error = %v", err)
|
|
}
|
|
if !ca.WasCreated() {
|
|
t.Fatal("WasCreated = false, want true when key is missing")
|
|
}
|
|
if _, err := os.Stat(filepath.Join(baseDir, caKeyName)); err != nil {
|
|
t.Fatalf("expected key file to be created, stat error = %v", err)
|
|
}
|
|
}
|
|
|
|
func TestLoadOrCreateCertificateAuthority_LoadErrorOnCorruptFiles(t *testing.T) {
|
|
configRoot := t.TempDir()
|
|
t.Setenv("XDG_CONFIG_HOME", configRoot)
|
|
|
|
baseDir := filepath.Join(configRoot, caDirName)
|
|
if err := os.MkdirAll(baseDir, 0o700); err != nil {
|
|
t.Fatalf("MkdirAll() error = %v", err)
|
|
}
|
|
|
|
certPath := filepath.Join(baseDir, caCertName)
|
|
keyPath := filepath.Join(baseDir, caKeyName)
|
|
if err := os.WriteFile(certPath, []byte("not-a-pem"), 0o600); err != nil {
|
|
t.Fatalf("WriteFile(cert) error = %v", err)
|
|
}
|
|
if err := os.WriteFile(keyPath, []byte("not-a-pem"), 0o600); err != nil {
|
|
t.Fatalf("WriteFile(key) error = %v", err)
|
|
}
|
|
|
|
_, err := loadOrCreateCertificateAuthority()
|
|
if err == nil {
|
|
t.Fatal("loadOrCreateCertificateAuthority() error = nil, want non-nil")
|
|
}
|
|
}
|
|
|
|
func TestCertificateAuthorityLoad_ErrorPaths(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
tests := []struct {
|
|
name string
|
|
certBytes []byte
|
|
keyBytes []byte
|
|
wantPart string
|
|
}{
|
|
{
|
|
name: "invalid cert pem",
|
|
certBytes: []byte("bad-cert"),
|
|
keyBytes: []byte("bad-key"),
|
|
wantPart: "decode ca cert pem",
|
|
},
|
|
{
|
|
name: "parse cert fails",
|
|
certBytes: pem.EncodeToMemory(&pem.Block{Type: "CERTIFICATE", Bytes: []byte("bogus")}),
|
|
keyBytes: []byte("bad-key"),
|
|
wantPart: "parse ca cert",
|
|
},
|
|
{
|
|
name: "invalid key pem",
|
|
certBytes: pem.EncodeToMemory(&pem.Block{Type: "CERTIFICATE", Bytes: newTestCA(t).cert.Raw}),
|
|
keyBytes: []byte("bad-key"),
|
|
wantPart: "decode ca key pem",
|
|
},
|
|
{
|
|
name: "parse key fails",
|
|
certBytes: pem.EncodeToMemory(&pem.Block{Type: "CERTIFICATE", Bytes: newTestCA(t).cert.Raw}),
|
|
keyBytes: pem.EncodeToMemory(&pem.Block{Type: "EC PRIVATE KEY", Bytes: []byte("bogus")}),
|
|
wantPart: "parse ca key",
|
|
},
|
|
}
|
|
|
|
for _, tt := range tests {
|
|
tt := tt
|
|
t.Run(tt.name, func(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
dir := t.TempDir()
|
|
ca := &CertificateAuthority{
|
|
certPath: filepath.Join(dir, caCertName),
|
|
keyPath: filepath.Join(dir, caKeyName),
|
|
leafCert: make(map[string]*tls.Certificate),
|
|
}
|
|
|
|
if err := os.WriteFile(ca.certPath, tt.certBytes, 0o600); err != nil {
|
|
t.Fatalf("write cert file error = %v", err)
|
|
}
|
|
if err := os.WriteFile(ca.keyPath, tt.keyBytes, 0o600); err != nil {
|
|
t.Fatalf("write key file error = %v", err)
|
|
}
|
|
|
|
err := ca.load()
|
|
if err == nil {
|
|
t.Fatal("load() error = nil, want non-nil")
|
|
}
|
|
if !strings.Contains(err.Error(), tt.wantPart) {
|
|
t.Fatalf("load() error = %q, want contains %q", err.Error(), tt.wantPart)
|
|
}
|
|
})
|
|
}
|
|
}
|
|
|
|
func TestCertificateAuthorityCreate_ErrorWhenWritePathInvalid(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
ca := &CertificateAuthority{
|
|
certPath: filepath.Join("/nope", "missing", "ca-cert.pem"),
|
|
keyPath: filepath.Join("/nope", "missing", "ca-key.pem"),
|
|
leafCert: make(map[string]*tls.Certificate),
|
|
}
|
|
|
|
err := ca.create()
|
|
if err == nil {
|
|
t.Fatal("create() error = nil, want non-nil")
|
|
}
|
|
}
|
|
|
|
func TestCertificateAuthorityCreate_WriteErrorPaths(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
t.Run("write ca cert wraps error", func(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
dir := t.TempDir()
|
|
badCertPath := filepath.Join(dir, "cert-as-dir")
|
|
if err := os.MkdirAll(badCertPath, 0o700); err != nil {
|
|
t.Fatalf("MkdirAll(cert dir) error = %v", err)
|
|
}
|
|
|
|
ca := &CertificateAuthority{
|
|
certPath: badCertPath,
|
|
keyPath: filepath.Join(dir, "ca-key.pem"),
|
|
leafCert: make(map[string]*tls.Certificate),
|
|
}
|
|
|
|
err := ca.create()
|
|
if err == nil {
|
|
t.Fatal("create() error = nil, want non-nil")
|
|
}
|
|
if !strings.Contains(err.Error(), "write ca cert") {
|
|
t.Fatalf("create() error = %q, want contains %q", err.Error(), "write ca cert")
|
|
}
|
|
})
|
|
|
|
t.Run("write ca key wraps error", func(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
dir := t.TempDir()
|
|
badKeyPath := filepath.Join(dir, "key-as-dir")
|
|
if err := os.MkdirAll(badKeyPath, 0o700); err != nil {
|
|
t.Fatalf("MkdirAll(key dir) error = %v", err)
|
|
}
|
|
|
|
ca := &CertificateAuthority{
|
|
certPath: filepath.Join(dir, "ca-cert.pem"),
|
|
keyPath: badKeyPath,
|
|
leafCert: make(map[string]*tls.Certificate),
|
|
}
|
|
|
|
err := ca.create()
|
|
if err == nil {
|
|
t.Fatal("create() error = nil, want non-nil")
|
|
}
|
|
if !strings.Contains(err.Error(), "write ca key") {
|
|
t.Fatalf("create() error = %q, want contains %q", err.Error(), "write ca key")
|
|
}
|
|
})
|
|
}
|
|
|
|
func TestCertificateAuthorityLoad_ReadErrorPaths(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
t.Run("read cert failure", func(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
dir := t.TempDir()
|
|
ca := &CertificateAuthority{
|
|
certPath: filepath.Join(dir, "missing-cert.pem"),
|
|
keyPath: filepath.Join(dir, "missing-key.pem"),
|
|
leafCert: make(map[string]*tls.Certificate),
|
|
}
|
|
|
|
err := ca.load()
|
|
if err == nil {
|
|
t.Fatal("load() error = nil, want non-nil")
|
|
}
|
|
if !strings.Contains(err.Error(), "read ca cert") {
|
|
t.Fatalf("load() error = %q, want contains %q", err.Error(), "read ca cert")
|
|
}
|
|
})
|
|
|
|
t.Run("read key failure", func(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
dir := t.TempDir()
|
|
goodCA := newTestCA(t)
|
|
certPEM := pem.EncodeToMemory(&pem.Block{Type: "CERTIFICATE", Bytes: goodCA.cert.Raw})
|
|
|
|
certPath := filepath.Join(dir, "cert.pem")
|
|
if err := os.WriteFile(certPath, certPEM, 0o600); err != nil {
|
|
t.Fatalf("WriteFile(cert) error = %v", err)
|
|
}
|
|
|
|
ca := &CertificateAuthority{
|
|
certPath: certPath,
|
|
keyPath: filepath.Join(dir, "missing-key.pem"),
|
|
leafCert: make(map[string]*tls.Certificate),
|
|
}
|
|
|
|
err := ca.load()
|
|
if err == nil {
|
|
t.Fatal("load() error = nil, want non-nil")
|
|
}
|
|
if !strings.Contains(err.Error(), "read ca key") {
|
|
t.Fatalf("load() error = %q, want contains %q", err.Error(), "read ca key")
|
|
}
|
|
})
|
|
}
|