// Note: the enum tree is not supported, since they do not support it. Other
// than that, both MySQL and MariaDB appear to be supported.
// Go MySQL Driver - A MySQL-Driver for Go's database/sql package
//
// Copyright 2016 The Go-MySQL-Driver Authors. All rights reserved.
//
// This Source Code Form is subject to the terms of the Mozilla Public
// License, v. 2.0. If a copy of the MPL was not distributed with this file,
// You can obtain one at http://mozilla.org/MPL/2.0/.
|
|
|
|
package mysql
|
|
|
|
import (
|
|
"bytes"
|
|
"context"
|
|
"crypto/rsa"
|
|
"crypto/tls"
|
|
"errors"
|
|
"fmt"
|
|
"math/big"
|
|
"net"
|
|
"net/url"
|
|
"sort"
|
|
"strconv"
|
|
"strings"
|
|
"time"
|
|
)
|
|
|
|
// Errors reported while parsing or validating a DSN.
var (
	// errInvalidDSNUnescaped: an address was opened with '(' and is not closed
	// right before the database slash, yet a ')' appears inside it; this is
	// typically an unescaped ')' in a value.
	errInvalidDSNUnescaped = errors.New("invalid DSN: did you forget to escape a param value?")

	// errInvalidDSNAddr: an address was opened with '(' but the character
	// before the last '/' is not the closing ')'.
	errInvalidDSNAddr = errors.New("invalid DSN: network address not terminated (missing closing brace)")

	// errInvalidDSNNoSlash: a non-empty DSN contains no '/' at all.
	errInvalidDSNNoSlash = errors.New("invalid DSN: missing the slash separating the database name")

	// errInvalidDSNUnsafeCollation: interpolateParams was combined with a
	// collation listed in unsafeCollations.
	errInvalidDSNUnsafeCollation = errors.New("invalid DSN: interpolateParams can not be used with unsafe collations")
)
|
|
|
|
// Config is a configuration parsed from a DSN string.
|
|
// If a new Config is created instead of being parsed from a DSN string,
|
|
// the NewConfig function should be used, which sets default values.
|
|
type Config struct {
|
|
// non boolean fields
|
|
|
|
User string // Username
|
|
Passwd string // Password (requires User)
|
|
Net string // Network (e.g. "tcp", "tcp6", "unix". default: "tcp")
|
|
Addr string // Address (default: "127.0.0.1:3306" for "tcp" and "/tmp/mysql.sock" for "unix")
|
|
DBName string // Database name
|
|
Params map[string]string // Connection parameters
|
|
ConnectionAttributes string // Connection Attributes, comma-delimited string of user-defined "key:value" pairs
|
|
Collation string // Connection collation
|
|
Loc *time.Location // Location for time.Time values
|
|
MaxAllowedPacket int // Max packet size allowed
|
|
ServerPubKey string // Server public key name
|
|
TLSConfig string // TLS configuration name
|
|
TLS *tls.Config // TLS configuration, its priority is higher than TLSConfig
|
|
Timeout time.Duration // Dial timeout
|
|
ReadTimeout time.Duration // I/O read timeout
|
|
WriteTimeout time.Duration // I/O write timeout
|
|
Logger Logger // Logger
|
|
|
|
// boolean fields
|
|
|
|
AllowAllFiles bool // Allow all files to be used with LOAD DATA LOCAL INFILE
|
|
AllowCleartextPasswords bool // Allows the cleartext client side plugin
|
|
AllowFallbackToPlaintext bool // Allows fallback to unencrypted connection if server does not support TLS
|
|
AllowNativePasswords bool // Allows the native password authentication method
|
|
AllowOldPasswords bool // Allows the old insecure password method
|
|
CheckConnLiveness bool // Check connections for liveness before using them
|
|
ClientFoundRows bool // Return number of matching rows instead of rows changed
|
|
ColumnsWithAlias bool // Prepend table alias to column names
|
|
InterpolateParams bool // Interpolate placeholders into query string
|
|
MultiStatements bool // Allow multiple statements in one query
|
|
ParseTime bool // Parse time values to time.Time
|
|
RejectReadOnly bool // Reject read-only connections
|
|
|
|
// unexported fields. new options should be come here
|
|
|
|
beforeConnect func(context.Context, *Config) error // Invoked before a connection is established
|
|
pubKey *rsa.PublicKey // Server public key
|
|
timeTruncate time.Duration // Truncate time.Time values to the specified duration
|
|
}
|
|
|
|
// Functional Options Pattern
|
|
// https://dave.cheney.net/2014/10/17/functional-options-for-friendly-apis
|
|
type Option func(*Config) error
|
|
|
|
// NewConfig creates a new Config and sets default values.
|
|
func NewConfig() *Config {
|
|
cfg := &Config{
|
|
Loc: time.UTC,
|
|
MaxAllowedPacket: defaultMaxAllowedPacket,
|
|
Logger: defaultLogger,
|
|
AllowNativePasswords: true,
|
|
CheckConnLiveness: true,
|
|
}
|
|
|
|
return cfg
|
|
}
|
|
|
|
// Apply applies the given options to the Config object.
|
|
func (c *Config) Apply(opts ...Option) error {
|
|
for _, opt := range opts {
|
|
err := opt(c)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
}
|
|
return nil
|
|
}
|
|
|
|
// TimeTruncate sets the time duration to truncate time.Time values in
|
|
// query parameters.
|
|
func TimeTruncate(d time.Duration) Option {
|
|
return func(cfg *Config) error {
|
|
cfg.timeTruncate = d
|
|
return nil
|
|
}
|
|
}
|
|
|
|
// BeforeConnect sets the function to be invoked before a connection is established.
|
|
func BeforeConnect(fn func(context.Context, *Config) error) Option {
|
|
return func(cfg *Config) error {
|
|
cfg.beforeConnect = fn
|
|
return nil
|
|
}
|
|
}
|
|
|
|
func (cfg *Config) Clone() *Config {
|
|
cp := *cfg
|
|
if cp.TLS != nil {
|
|
cp.TLS = cfg.TLS.Clone()
|
|
}
|
|
if len(cp.Params) > 0 {
|
|
cp.Params = make(map[string]string, len(cfg.Params))
|
|
for k, v := range cfg.Params {
|
|
cp.Params[k] = v
|
|
}
|
|
}
|
|
if cfg.pubKey != nil {
|
|
cp.pubKey = &rsa.PublicKey{
|
|
N: new(big.Int).Set(cfg.pubKey.N),
|
|
E: cfg.pubKey.E,
|
|
}
|
|
}
|
|
return &cp
|
|
}
|
|
|
|
func (cfg *Config) normalize() error {
|
|
if cfg.InterpolateParams && cfg.Collation != "" && unsafeCollations[cfg.Collation] {
|
|
return errInvalidDSNUnsafeCollation
|
|
}
|
|
|
|
// Set default network if empty
|
|
if cfg.Net == "" {
|
|
cfg.Net = "tcp"
|
|
}
|
|
|
|
// Set default address if empty
|
|
if cfg.Addr == "" {
|
|
switch cfg.Net {
|
|
case "tcp":
|
|
cfg.Addr = "127.0.0.1:3306"
|
|
case "unix":
|
|
cfg.Addr = "/tmp/mysql.sock"
|
|
default:
|
|
return errors.New("default addr for network '" + cfg.Net + "' unknown")
|
|
}
|
|
} else if cfg.Net == "tcp" {
|
|
cfg.Addr = ensureHavePort(cfg.Addr)
|
|
}
|
|
|
|
if cfg.TLS == nil {
|
|
switch cfg.TLSConfig {
|
|
case "false", "":
|
|
// don't set anything
|
|
case "true":
|
|
cfg.TLS = &tls.Config{}
|
|
case "skip-verify":
|
|
cfg.TLS = &tls.Config{InsecureSkipVerify: true}
|
|
case "preferred":
|
|
cfg.TLS = &tls.Config{InsecureSkipVerify: true}
|
|
cfg.AllowFallbackToPlaintext = true
|
|
default:
|
|
cfg.TLS = getTLSConfigClone(cfg.TLSConfig)
|
|
if cfg.TLS == nil {
|
|
return errors.New("invalid value / unknown config name: " + cfg.TLSConfig)
|
|
}
|
|
}
|
|
}
|
|
|
|
if cfg.TLS != nil && cfg.TLS.ServerName == "" && !cfg.TLS.InsecureSkipVerify {
|
|
host, _, err := net.SplitHostPort(cfg.Addr)
|
|
if err == nil {
|
|
cfg.TLS.ServerName = host
|
|
}
|
|
}
|
|
|
|
if cfg.ServerPubKey != "" {
|
|
cfg.pubKey = getServerPubKey(cfg.ServerPubKey)
|
|
if cfg.pubKey == nil {
|
|
return errors.New("invalid value / unknown server pub key name: " + cfg.ServerPubKey)
|
|
}
|
|
}
|
|
|
|
if cfg.Logger == nil {
|
|
cfg.Logger = defaultLogger
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
// writeDSNParam appends one "name=value" query parameter to buf. The first
// parameter (written while *hasParam is false) is introduced with '?' and
// sets *hasParam; later parameters are separated with '&'. value is written
// verbatim, so callers must escape it themselves.
func writeDSNParam(buf *bytes.Buffer, hasParam *bool, name, value string) {
	buf.Grow(len(name) + len(value) + 2) // separator + name + '=' + value
	sep := byte('&')
	if !*hasParam {
		sep = '?'
		*hasParam = true
	}
	buf.WriteByte(sep)
	buf.WriteString(name)
	buf.WriteByte('=')
	buf.WriteString(value)
}
|
|
|
|
// FormatDSN formats the given Config into a DSN string which can be passed to
|
|
// the driver.
|
|
//
|
|
// Note: use [NewConnector] and [database/sql.OpenDB] to open a connection from a [*Config].
|
|
func (cfg *Config) FormatDSN() string {
|
|
var buf bytes.Buffer
|
|
|
|
// [username[:password]@]
|
|
if len(cfg.User) > 0 {
|
|
buf.WriteString(cfg.User)
|
|
if len(cfg.Passwd) > 0 {
|
|
buf.WriteByte(':')
|
|
buf.WriteString(cfg.Passwd)
|
|
}
|
|
buf.WriteByte('@')
|
|
}
|
|
|
|
// [protocol[(address)]]
|
|
if len(cfg.Net) > 0 {
|
|
buf.WriteString(cfg.Net)
|
|
if len(cfg.Addr) > 0 {
|
|
buf.WriteByte('(')
|
|
buf.WriteString(cfg.Addr)
|
|
buf.WriteByte(')')
|
|
}
|
|
}
|
|
|
|
// /dbname
|
|
buf.WriteByte('/')
|
|
buf.WriteString(url.PathEscape(cfg.DBName))
|
|
|
|
// [?param1=value1&...¶mN=valueN]
|
|
hasParam := false
|
|
|
|
if cfg.AllowAllFiles {
|
|
hasParam = true
|
|
buf.WriteString("?allowAllFiles=true")
|
|
}
|
|
|
|
if cfg.AllowCleartextPasswords {
|
|
writeDSNParam(&buf, &hasParam, "allowCleartextPasswords", "true")
|
|
}
|
|
|
|
if cfg.AllowFallbackToPlaintext {
|
|
writeDSNParam(&buf, &hasParam, "allowFallbackToPlaintext", "true")
|
|
}
|
|
|
|
if !cfg.AllowNativePasswords {
|
|
writeDSNParam(&buf, &hasParam, "allowNativePasswords", "false")
|
|
}
|
|
|
|
if cfg.AllowOldPasswords {
|
|
writeDSNParam(&buf, &hasParam, "allowOldPasswords", "true")
|
|
}
|
|
|
|
if !cfg.CheckConnLiveness {
|
|
writeDSNParam(&buf, &hasParam, "checkConnLiveness", "false")
|
|
}
|
|
|
|
if cfg.ClientFoundRows {
|
|
writeDSNParam(&buf, &hasParam, "clientFoundRows", "true")
|
|
}
|
|
|
|
if col := cfg.Collation; col != "" {
|
|
writeDSNParam(&buf, &hasParam, "collation", col)
|
|
}
|
|
|
|
if cfg.ColumnsWithAlias {
|
|
writeDSNParam(&buf, &hasParam, "columnsWithAlias", "true")
|
|
}
|
|
|
|
if cfg.InterpolateParams {
|
|
writeDSNParam(&buf, &hasParam, "interpolateParams", "true")
|
|
}
|
|
|
|
if cfg.Loc != time.UTC && cfg.Loc != nil {
|
|
writeDSNParam(&buf, &hasParam, "loc", url.QueryEscape(cfg.Loc.String()))
|
|
}
|
|
|
|
if cfg.MultiStatements {
|
|
writeDSNParam(&buf, &hasParam, "multiStatements", "true")
|
|
}
|
|
|
|
if cfg.ParseTime {
|
|
writeDSNParam(&buf, &hasParam, "parseTime", "true")
|
|
}
|
|
|
|
if cfg.timeTruncate > 0 {
|
|
writeDSNParam(&buf, &hasParam, "timeTruncate", cfg.timeTruncate.String())
|
|
}
|
|
|
|
if cfg.ReadTimeout > 0 {
|
|
writeDSNParam(&buf, &hasParam, "readTimeout", cfg.ReadTimeout.String())
|
|
}
|
|
|
|
if cfg.RejectReadOnly {
|
|
writeDSNParam(&buf, &hasParam, "rejectReadOnly", "true")
|
|
}
|
|
|
|
if len(cfg.ServerPubKey) > 0 {
|
|
writeDSNParam(&buf, &hasParam, "serverPubKey", url.QueryEscape(cfg.ServerPubKey))
|
|
}
|
|
|
|
if cfg.Timeout > 0 {
|
|
writeDSNParam(&buf, &hasParam, "timeout", cfg.Timeout.String())
|
|
}
|
|
|
|
if len(cfg.TLSConfig) > 0 {
|
|
writeDSNParam(&buf, &hasParam, "tls", url.QueryEscape(cfg.TLSConfig))
|
|
}
|
|
|
|
if cfg.WriteTimeout > 0 {
|
|
writeDSNParam(&buf, &hasParam, "writeTimeout", cfg.WriteTimeout.String())
|
|
}
|
|
|
|
if cfg.MaxAllowedPacket != defaultMaxAllowedPacket {
|
|
writeDSNParam(&buf, &hasParam, "maxAllowedPacket", strconv.Itoa(cfg.MaxAllowedPacket))
|
|
}
|
|
|
|
// other params
|
|
if cfg.Params != nil {
|
|
var params []string
|
|
for param := range cfg.Params {
|
|
params = append(params, param)
|
|
}
|
|
sort.Strings(params)
|
|
for _, param := range params {
|
|
writeDSNParam(&buf, &hasParam, param, url.QueryEscape(cfg.Params[param]))
|
|
}
|
|
}
|
|
|
|
return buf.String()
|
|
}
|
|
|
|
// ParseDSN parses the DSN string to a Config
|
|
func ParseDSN(dsn string) (cfg *Config, err error) {
|
|
// New config with some default values
|
|
cfg = NewConfig()
|
|
|
|
// [user[:password]@][net[(addr)]]/dbname[?param1=value1¶mN=valueN]
|
|
// Find the last '/' (since the password or the net addr might contain a '/')
|
|
foundSlash := false
|
|
for i := len(dsn) - 1; i >= 0; i-- {
|
|
if dsn[i] == '/' {
|
|
foundSlash = true
|
|
var j, k int
|
|
|
|
// left part is empty if i <= 0
|
|
if i > 0 {
|
|
// [username[:password]@][protocol[(address)]]
|
|
// Find the last '@' in dsn[:i]
|
|
for j = i; j >= 0; j-- {
|
|
if dsn[j] == '@' {
|
|
// username[:password]
|
|
// Find the first ':' in dsn[:j]
|
|
for k = 0; k < j; k++ {
|
|
if dsn[k] == ':' {
|
|
cfg.Passwd = dsn[k+1 : j]
|
|
break
|
|
}
|
|
}
|
|
cfg.User = dsn[:k]
|
|
|
|
break
|
|
}
|
|
}
|
|
|
|
// [protocol[(address)]]
|
|
// Find the first '(' in dsn[j+1:i]
|
|
for k = j + 1; k < i; k++ {
|
|
if dsn[k] == '(' {
|
|
// dsn[i-1] must be == ')' if an address is specified
|
|
if dsn[i-1] != ')' {
|
|
if strings.ContainsRune(dsn[k+1:i], ')') {
|
|
return nil, errInvalidDSNUnescaped
|
|
}
|
|
return nil, errInvalidDSNAddr
|
|
}
|
|
cfg.Addr = dsn[k+1 : i-1]
|
|
break
|
|
}
|
|
}
|
|
cfg.Net = dsn[j+1 : k]
|
|
}
|
|
|
|
// dbname[?param1=value1&...¶mN=valueN]
|
|
// Find the first '?' in dsn[i+1:]
|
|
for j = i + 1; j < len(dsn); j++ {
|
|
if dsn[j] == '?' {
|
|
if err = parseDSNParams(cfg, dsn[j+1:]); err != nil {
|
|
return
|
|
}
|
|
break
|
|
}
|
|
}
|
|
|
|
dbname := dsn[i+1 : j]
|
|
if cfg.DBName, err = url.PathUnescape(dbname); err != nil {
|
|
return nil, fmt.Errorf("invalid dbname %q: %w", dbname, err)
|
|
}
|
|
|
|
break
|
|
}
|
|
}
|
|
|
|
if !foundSlash && len(dsn) > 0 {
|
|
return nil, errInvalidDSNNoSlash
|
|
}
|
|
|
|
if err = cfg.normalize(); err != nil {
|
|
return nil, err
|
|
}
|
|
return
|
|
}
|
|
|
|
// parseDSNParams parses the DSN "query string"
|
|
// Values must be url.QueryEscape'ed
|
|
func parseDSNParams(cfg *Config, params string) (err error) {
|
|
for _, v := range strings.Split(params, "&") {
|
|
key, value, found := strings.Cut(v, "=")
|
|
if !found {
|
|
continue
|
|
}
|
|
|
|
// cfg params
|
|
switch key {
|
|
// Disable INFILE allowlist / enable all files
|
|
case "allowAllFiles":
|
|
var isBool bool
|
|
cfg.AllowAllFiles, isBool = readBool(value)
|
|
if !isBool {
|
|
return errors.New("invalid bool value: " + value)
|
|
}
|
|
|
|
// Use cleartext authentication mode (MySQL 5.5.10+)
|
|
case "allowCleartextPasswords":
|
|
var isBool bool
|
|
cfg.AllowCleartextPasswords, isBool = readBool(value)
|
|
if !isBool {
|
|
return errors.New("invalid bool value: " + value)
|
|
}
|
|
|
|
// Allow fallback to unencrypted connection if server does not support TLS
|
|
case "allowFallbackToPlaintext":
|
|
var isBool bool
|
|
cfg.AllowFallbackToPlaintext, isBool = readBool(value)
|
|
if !isBool {
|
|
return errors.New("invalid bool value: " + value)
|
|
}
|
|
|
|
// Use native password authentication
|
|
case "allowNativePasswords":
|
|
var isBool bool
|
|
cfg.AllowNativePasswords, isBool = readBool(value)
|
|
if !isBool {
|
|
return errors.New("invalid bool value: " + value)
|
|
}
|
|
|
|
// Use old authentication mode (pre MySQL 4.1)
|
|
case "allowOldPasswords":
|
|
var isBool bool
|
|
cfg.AllowOldPasswords, isBool = readBool(value)
|
|
if !isBool {
|
|
return errors.New("invalid bool value: " + value)
|
|
}
|
|
|
|
// Check connections for Liveness before using them
|
|
case "checkConnLiveness":
|
|
var isBool bool
|
|
cfg.CheckConnLiveness, isBool = readBool(value)
|
|
if !isBool {
|
|
return errors.New("invalid bool value: " + value)
|
|
}
|
|
|
|
// Switch "rowsAffected" mode
|
|
case "clientFoundRows":
|
|
var isBool bool
|
|
cfg.ClientFoundRows, isBool = readBool(value)
|
|
if !isBool {
|
|
return errors.New("invalid bool value: " + value)
|
|
}
|
|
|
|
// Collation
|
|
case "collation":
|
|
cfg.Collation = value
|
|
|
|
case "columnsWithAlias":
|
|
var isBool bool
|
|
cfg.ColumnsWithAlias, isBool = readBool(value)
|
|
if !isBool {
|
|
return errors.New("invalid bool value: " + value)
|
|
}
|
|
|
|
// Compression
|
|
case "compress":
|
|
return errors.New("compression not implemented yet")
|
|
|
|
// Enable client side placeholder substitution
|
|
case "interpolateParams":
|
|
var isBool bool
|
|
cfg.InterpolateParams, isBool = readBool(value)
|
|
if !isBool {
|
|
return errors.New("invalid bool value: " + value)
|
|
}
|
|
|
|
// Time Location
|
|
case "loc":
|
|
if value, err = url.QueryUnescape(value); err != nil {
|
|
return
|
|
}
|
|
cfg.Loc, err = time.LoadLocation(value)
|
|
if err != nil {
|
|
return
|
|
}
|
|
|
|
// multiple statements in one query
|
|
case "multiStatements":
|
|
var isBool bool
|
|
cfg.MultiStatements, isBool = readBool(value)
|
|
if !isBool {
|
|
return errors.New("invalid bool value: " + value)
|
|
}
|
|
|
|
// time.Time parsing
|
|
case "parseTime":
|
|
var isBool bool
|
|
cfg.ParseTime, isBool = readBool(value)
|
|
if !isBool {
|
|
return errors.New("invalid bool value: " + value)
|
|
}
|
|
|
|
// time.Time truncation
|
|
case "timeTruncate":
|
|
cfg.timeTruncate, err = time.ParseDuration(value)
|
|
if err != nil {
|
|
return fmt.Errorf("invalid timeTruncate value: %v, error: %w", value, err)
|
|
}
|
|
|
|
// I/O read Timeout
|
|
case "readTimeout":
|
|
cfg.ReadTimeout, err = time.ParseDuration(value)
|
|
if err != nil {
|
|
return
|
|
}
|
|
|
|
// Reject read-only connections
|
|
case "rejectReadOnly":
|
|
var isBool bool
|
|
cfg.RejectReadOnly, isBool = readBool(value)
|
|
if !isBool {
|
|
return errors.New("invalid bool value: " + value)
|
|
}
|
|
|
|
// Server public key
|
|
case "serverPubKey":
|
|
name, err := url.QueryUnescape(value)
|
|
if err != nil {
|
|
return fmt.Errorf("invalid value for server pub key name: %v", err)
|
|
}
|
|
cfg.ServerPubKey = name
|
|
|
|
// Strict mode
|
|
case "strict":
|
|
panic("strict mode has been removed. See https://github.com/go-sql-driver/mysql/wiki/strict-mode")
|
|
|
|
// Dial Timeout
|
|
case "timeout":
|
|
cfg.Timeout, err = time.ParseDuration(value)
|
|
if err != nil {
|
|
return
|
|
}
|
|
|
|
// TLS-Encryption
|
|
case "tls":
|
|
boolValue, isBool := readBool(value)
|
|
if isBool {
|
|
if boolValue {
|
|
cfg.TLSConfig = "true"
|
|
} else {
|
|
cfg.TLSConfig = "false"
|
|
}
|
|
} else if vl := strings.ToLower(value); vl == "skip-verify" || vl == "preferred" {
|
|
cfg.TLSConfig = vl
|
|
} else {
|
|
name, err := url.QueryUnescape(value)
|
|
if err != nil {
|
|
return fmt.Errorf("invalid value for TLS config name: %v", err)
|
|
}
|
|
cfg.TLSConfig = name
|
|
}
|
|
|
|
// I/O write Timeout
|
|
case "writeTimeout":
|
|
cfg.WriteTimeout, err = time.ParseDuration(value)
|
|
if err != nil {
|
|
return
|
|
}
|
|
case "maxAllowedPacket":
|
|
cfg.MaxAllowedPacket, err = strconv.Atoi(value)
|
|
if err != nil {
|
|
return
|
|
}
|
|
|
|
// Connection attributes
|
|
case "connectionAttributes":
|
|
connectionAttributes, err := url.QueryUnescape(value)
|
|
if err != nil {
|
|
return fmt.Errorf("invalid connectionAttributes value: %v", err)
|
|
}
|
|
cfg.ConnectionAttributes = connectionAttributes
|
|
|
|
default:
|
|
// lazy init
|
|
if cfg.Params == nil {
|
|
cfg.Params = make(map[string]string)
|
|
}
|
|
|
|
if cfg.Params[key], err = url.QueryUnescape(value); err != nil {
|
|
return
|
|
}
|
|
}
|
|
}
|
|
|
|
return
|
|
}
|
|
|
|
// ensureHavePort returns addr unchanged when it already splits into host and
// port. Otherwise it treats all of addr as the host and appends MySQL's
// default port 3306; net.JoinHostPort adds brackets around IPv6 literals.
func ensureHavePort(addr string) string {
	_, _, err := net.SplitHostPort(addr)
	if err == nil {
		return addr
	}
	return net.JoinHostPort(addr, "3306")
}
|