However, the enum tree is not supported, since they do not support it. Other than that, both MySQL and MariaDB appear to be supported.
689 lines · 15 KiB · Go
// Go MySQL Driver - A MySQL-Driver for Go's database/sql package
|
|
//
|
|
// Copyright 2012 The Go-MySQL-Driver Authors. All rights reserved.
|
|
//
|
|
// This Source Code Form is subject to the terms of the Mozilla Public
|
|
// License, v. 2.0. If a copy of the MPL was not distributed with this file,
|
|
// You can obtain one at http://mozilla.org/MPL/2.0/.
|
|
|
|
package mysql
|
|
|
|
import (
|
|
"context"
|
|
"database/sql"
|
|
"database/sql/driver"
|
|
"encoding/json"
|
|
"io"
|
|
"net"
|
|
"strconv"
|
|
"strings"
|
|
"time"
|
|
)
|
|
|
|
// mysqlConn holds the state of one client connection to the server. Its
// methods implement driver.Conn and the optional database/sql/driver
// interfaces named in their doc comments (Pinger, ConnBeginTx,
// SessionResetter, Validator, ...).
type mysqlConn struct {
	// buf is the connection's buffer; interpolateParams borrows it via
	// takeCompleteBuffer to assemble queries.
	buf     buffer
	netConn net.Conn
	rawConn net.Conn    // underlying connection when netConn is TLS connection.
	result  mysqlResult // managed by clearResult() and handleOkPacket().
	cfg     *Config
	// NOTE(review): connector is not referenced in this file.
	connector *connector
	// maxAllowedPacket bounds the size of client-side interpolated queries.
	maxAllowedPacket int
	// NOTE(review): maxWriteSize, writeTimeout, flags, sequence and parseTime
	// are not referenced in this file; presumably they are the write chunk
	// size, per-write deadline, negotiated client capability flags, packet
	// sequence id and the parseTime DSN option — confirm in the packet code.
	maxWriteSize int
	writeTimeout time.Duration
	flags        clientFlag
	// status holds server status flags; statusNoBackslashEscapes selects the
	// escaping mode used by interpolateParams.
	status    statusFlag
	sequence  uint8
	parseTime bool

	// for context support (Go 1.8+)
	watching bool                   // set by watchCancel, cleared by finish
	watcher  chan<- context.Context // feeds contexts to the startWatcher goroutine
	closech  chan struct{}          // closed by cleanup; stops the watcher goroutine
	finished chan<- struct{}        // finish signals completion to the watcher on it
	canceled atomicError            // set non-nil if conn is canceled
	closed   atomicBool             // set when conn is closed, before closech is closed
}
|
|
|
|
// Helper function to call per-connection logger.
|
|
func (mc *mysqlConn) log(v ...any) {
|
|
mc.cfg.Logger.Print(v...)
|
|
}
|
|
|
|
// Handles parameters set in DSN after the connection is established
|
|
func (mc *mysqlConn) handleParams() (err error) {
|
|
var cmdSet strings.Builder
|
|
|
|
for param, val := range mc.cfg.Params {
|
|
switch param {
|
|
// Charset: character_set_connection, character_set_client, character_set_results
|
|
case "charset":
|
|
charsets := strings.Split(val, ",")
|
|
for _, cs := range charsets {
|
|
// ignore errors here - a charset may not exist
|
|
if mc.cfg.Collation != "" {
|
|
err = mc.exec("SET NAMES " + cs + " COLLATE " + mc.cfg.Collation)
|
|
} else {
|
|
err = mc.exec("SET NAMES " + cs)
|
|
}
|
|
if err == nil {
|
|
break
|
|
}
|
|
}
|
|
if err != nil {
|
|
return
|
|
}
|
|
|
|
// Other system vars accumulated in a single SET command
|
|
default:
|
|
if cmdSet.Len() == 0 {
|
|
// Heuristic: 29 chars for each other key=value to reduce reallocations
|
|
cmdSet.Grow(4 + len(param) + 3 + len(val) + 30*(len(mc.cfg.Params)-1))
|
|
cmdSet.WriteString("SET ")
|
|
} else {
|
|
cmdSet.WriteString(", ")
|
|
}
|
|
cmdSet.WriteString(param)
|
|
cmdSet.WriteString(" = ")
|
|
cmdSet.WriteString(val)
|
|
}
|
|
}
|
|
|
|
if cmdSet.Len() > 0 {
|
|
err = mc.exec(cmdSet.String())
|
|
if err != nil {
|
|
return
|
|
}
|
|
}
|
|
|
|
return
|
|
}
|
|
|
|
func (mc *mysqlConn) markBadConn(err error) error {
|
|
if mc == nil {
|
|
return err
|
|
}
|
|
if err != errBadConnNoWrite {
|
|
return err
|
|
}
|
|
return driver.ErrBadConn
|
|
}
|
|
|
|
func (mc *mysqlConn) Begin() (driver.Tx, error) {
|
|
return mc.begin(false)
|
|
}
|
|
|
|
func (mc *mysqlConn) begin(readOnly bool) (driver.Tx, error) {
|
|
if mc.closed.Load() {
|
|
mc.log(ErrInvalidConn)
|
|
return nil, driver.ErrBadConn
|
|
}
|
|
var q string
|
|
if readOnly {
|
|
q = "START TRANSACTION READ ONLY"
|
|
} else {
|
|
q = "START TRANSACTION"
|
|
}
|
|
err := mc.exec(q)
|
|
if err == nil {
|
|
return &mysqlTx{mc}, err
|
|
}
|
|
return nil, mc.markBadConn(err)
|
|
}
|
|
|
|
func (mc *mysqlConn) Close() (err error) {
|
|
// Makes Close idempotent
|
|
if !mc.closed.Load() {
|
|
err = mc.writeCommandPacket(comQuit)
|
|
}
|
|
|
|
mc.cleanup()
|
|
mc.clearResult()
|
|
return
|
|
}
|
|
|
|
// Closes the network connection and unsets internal variables. Do not call this
|
|
// function after successfully authentication, call Close instead. This function
|
|
// is called before auth or on auth failure because MySQL will have already
|
|
// closed the network connection.
|
|
func (mc *mysqlConn) cleanup() {
|
|
if mc.closed.Swap(true) {
|
|
return
|
|
}
|
|
|
|
// Makes cleanup idempotent
|
|
close(mc.closech)
|
|
conn := mc.rawConn
|
|
if conn == nil {
|
|
return
|
|
}
|
|
if err := conn.Close(); err != nil {
|
|
mc.log(err)
|
|
}
|
|
// This function can be called from multiple goroutines.
|
|
// So we can not mc.clearResult() here.
|
|
// Caller should do it if they are in safe goroutine.
|
|
}
|
|
|
|
func (mc *mysqlConn) error() error {
|
|
if mc.closed.Load() {
|
|
if err := mc.canceled.Value(); err != nil {
|
|
return err
|
|
}
|
|
return ErrInvalidConn
|
|
}
|
|
return nil
|
|
}
|
|
|
|
func (mc *mysqlConn) Prepare(query string) (driver.Stmt, error) {
|
|
if mc.closed.Load() {
|
|
mc.log(ErrInvalidConn)
|
|
return nil, driver.ErrBadConn
|
|
}
|
|
// Send command
|
|
err := mc.writeCommandPacketStr(comStmtPrepare, query)
|
|
if err != nil {
|
|
// STMT_PREPARE is safe to retry. So we can return ErrBadConn here.
|
|
mc.log(err)
|
|
return nil, driver.ErrBadConn
|
|
}
|
|
|
|
stmt := &mysqlStmt{
|
|
mc: mc,
|
|
}
|
|
|
|
// Read Result
|
|
columnCount, err := stmt.readPrepareResultPacket()
|
|
if err == nil {
|
|
if stmt.paramCount > 0 {
|
|
if err = mc.readUntilEOF(); err != nil {
|
|
return nil, err
|
|
}
|
|
}
|
|
|
|
if columnCount > 0 {
|
|
err = mc.readUntilEOF()
|
|
}
|
|
}
|
|
|
|
return stmt, err
|
|
}
|
|
|
|
// interpolateParams returns query with every '?' placeholder replaced by the
// SQL literal for the corresponding element of args, so the statement can be
// sent as plain text instead of through a server-side prepared statement.
//
// Literals are produced per argument type:
//   - nil, and a nil []byte, become NULL.
//   - int64 and uint64 are written in base 10; float64 uses the shortest 'g'
//     form.
//   - bool becomes 1 or 0.
//   - time.Time is converted to mc.cfg.Loc and quoted; the zero time becomes
//     '0000-00-00'.
//   - json.RawMessage and string are quoted.
//   - []byte is quoted with a _binary introducer.
//
// Quoted values are escaped with the escape*Backslash helpers, unless the
// server status has statusNoBackslashEscapes set, in which case the
// escape*Quotes helpers are used.
//
// It returns driver.ErrSkip when the '?' count differs from len(args), when an
// argument has a type not listed above, or when the query plus 4 bytes would
// exceed mc.maxAllowedPacket. It returns ErrInvalidConn if the connection
// buffer cannot be taken, and passes through any error from appendDateTime.
//
// NOTE(review): placeholders are located by a plain byte search, so a '?'
// inside a string literal, quoted identifier or comment in query is also
// counted and substituted.
func (mc *mysqlConn) interpolateParams(query string, args []driver.Value) (string, error) {
	// Number of ? should be same to len(args)
	if strings.Count(query, "?") != len(args) {
		return "", driver.ErrSkip
	}

	// Build the query inside the connection's buffer (mc.buf). The result is
	// copied out by string(buf) before returning.
	buf, err := mc.buf.takeCompleteBuffer()
	if err != nil {
		// can not take the buffer. Something must be wrong with the connection
		mc.log(err)
		return "", ErrInvalidConn
	}
	buf = buf[:0]
	argPos := 0 // index of the next arg to substitute

	for i := 0; i < len(query); i++ {
		// Offset of the next placeholder, relative to i.
		q := strings.IndexByte(query[i:], '?')
		if q == -1 {
			// No placeholders left: copy the remainder verbatim.
			buf = append(buf, query[i:]...)
			break
		}
		// Copy the text up to the '?' and move i onto it; the loop's i++ then
		// steps past the '?' itself.
		buf = append(buf, query[i:i+q]...)
		i += q

		arg := args[argPos]
		argPos++

		if arg == nil {
			buf = append(buf, "NULL"...)
			continue
		}

		switch v := arg.(type) {
		case int64:
			buf = strconv.AppendInt(buf, v, 10)
		case uint64:
			// Handle uint64 explicitly because our custom ConvertValue emits unsigned values
			buf = strconv.AppendUint(buf, v, 10)
		case float64:
			buf = strconv.AppendFloat(buf, v, 'g', -1, 64)
		case bool:
			if v {
				buf = append(buf, '1')
			} else {
				buf = append(buf, '0')
			}
		case time.Time:
			if v.IsZero() {
				// The zero time is sent as the zero date literal.
				buf = append(buf, "'0000-00-00'"...)
			} else {
				buf = append(buf, '\'')
				buf, err = appendDateTime(buf, v.In(mc.cfg.Loc), mc.cfg.timeTruncate)
				if err != nil {
					return "", err
				}
				buf = append(buf, '\'')
			}
		case json.RawMessage:
			buf = append(buf, '\'')
			if mc.status&statusNoBackslashEscapes == 0 {
				buf = escapeBytesBackslash(buf, v)
			} else {
				buf = escapeBytesQuotes(buf, v)
			}
			buf = append(buf, '\'')
		case []byte:
			// A nil []byte stored in a non-nil interface is not caught by the
			// arg == nil check above, so it is mapped to NULL here.
			if v == nil {
				buf = append(buf, "NULL"...)
			} else {
				buf = append(buf, "_binary'"...)
				if mc.status&statusNoBackslashEscapes == 0 {
					buf = escapeBytesBackslash(buf, v)
				} else {
					buf = escapeBytesQuotes(buf, v)
				}
				buf = append(buf, '\'')
			}
		case string:
			buf = append(buf, '\'')
			if mc.status&statusNoBackslashEscapes == 0 {
				buf = escapeStringBackslash(buf, v)
			} else {
				buf = escapeStringQuotes(buf, v)
			}
			buf = append(buf, '\'')
		default:
			// Unsupported type: signal the caller to use another path.
			return "", driver.ErrSkip
		}

		// Give up once the query would not fit in one packet; the +4 presumably
		// accounts for the packet header. This runs only after non-NULL
		// arguments, and the text after the last '?' is not checked here.
		if len(buf)+4 > mc.maxAllowedPacket {
			return "", driver.ErrSkip
		}
	}
	if argPos != len(args) {
		return "", driver.ErrSkip
	}
	return string(buf), nil
}
|
|
|
|
func (mc *mysqlConn) Exec(query string, args []driver.Value) (driver.Result, error) {
|
|
if mc.closed.Load() {
|
|
mc.log(ErrInvalidConn)
|
|
return nil, driver.ErrBadConn
|
|
}
|
|
if len(args) != 0 {
|
|
if !mc.cfg.InterpolateParams {
|
|
return nil, driver.ErrSkip
|
|
}
|
|
// try to interpolate the parameters to save extra roundtrips for preparing and closing a statement
|
|
prepared, err := mc.interpolateParams(query, args)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
query = prepared
|
|
}
|
|
|
|
err := mc.exec(query)
|
|
if err == nil {
|
|
copied := mc.result
|
|
return &copied, err
|
|
}
|
|
return nil, mc.markBadConn(err)
|
|
}
|
|
|
|
// Internal function to execute commands
|
|
func (mc *mysqlConn) exec(query string) error {
|
|
handleOk := mc.clearResult()
|
|
// Send command
|
|
if err := mc.writeCommandPacketStr(comQuery, query); err != nil {
|
|
return mc.markBadConn(err)
|
|
}
|
|
|
|
// Read Result
|
|
resLen, err := handleOk.readResultSetHeaderPacket()
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
if resLen > 0 {
|
|
// columns
|
|
if err := mc.readUntilEOF(); err != nil {
|
|
return err
|
|
}
|
|
|
|
// rows
|
|
if err := mc.readUntilEOF(); err != nil {
|
|
return err
|
|
}
|
|
}
|
|
|
|
return handleOk.discardResults()
|
|
}
|
|
|
|
func (mc *mysqlConn) Query(query string, args []driver.Value) (driver.Rows, error) {
|
|
return mc.query(query, args)
|
|
}
|
|
|
|
func (mc *mysqlConn) query(query string, args []driver.Value) (*textRows, error) {
|
|
handleOk := mc.clearResult()
|
|
|
|
if mc.closed.Load() {
|
|
mc.log(ErrInvalidConn)
|
|
return nil, driver.ErrBadConn
|
|
}
|
|
if len(args) != 0 {
|
|
if !mc.cfg.InterpolateParams {
|
|
return nil, driver.ErrSkip
|
|
}
|
|
// try client-side prepare to reduce roundtrip
|
|
prepared, err := mc.interpolateParams(query, args)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
query = prepared
|
|
}
|
|
// Send command
|
|
err := mc.writeCommandPacketStr(comQuery, query)
|
|
if err == nil {
|
|
// Read Result
|
|
var resLen int
|
|
resLen, err = handleOk.readResultSetHeaderPacket()
|
|
if err == nil {
|
|
rows := new(textRows)
|
|
rows.mc = mc
|
|
|
|
if resLen == 0 {
|
|
rows.rs.done = true
|
|
|
|
switch err := rows.NextResultSet(); err {
|
|
case nil, io.EOF:
|
|
return rows, nil
|
|
default:
|
|
return nil, err
|
|
}
|
|
}
|
|
|
|
// Columns
|
|
rows.rs.columns, err = mc.readColumns(resLen)
|
|
return rows, err
|
|
}
|
|
}
|
|
return nil, mc.markBadConn(err)
|
|
}
|
|
|
|
// Gets the value of the given MySQL System Variable
|
|
// The returned byte slice is only valid until the next read
|
|
func (mc *mysqlConn) getSystemVar(name string) ([]byte, error) {
|
|
// Send command
|
|
handleOk := mc.clearResult()
|
|
if err := mc.writeCommandPacketStr(comQuery, "SELECT @@"+name); err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
// Read Result
|
|
resLen, err := handleOk.readResultSetHeaderPacket()
|
|
if err == nil {
|
|
rows := new(textRows)
|
|
rows.mc = mc
|
|
rows.rs.columns = []mysqlField{{fieldType: fieldTypeVarChar}}
|
|
|
|
if resLen > 0 {
|
|
// Columns
|
|
if err := mc.readUntilEOF(); err != nil {
|
|
return nil, err
|
|
}
|
|
}
|
|
|
|
dest := make([]driver.Value, resLen)
|
|
if err = rows.readRow(dest); err == nil {
|
|
return dest[0].([]byte), mc.readUntilEOF()
|
|
}
|
|
}
|
|
return nil, err
|
|
}
|
|
|
|
// finish is called when the query has canceled.
|
|
func (mc *mysqlConn) cancel(err error) {
|
|
mc.canceled.Set(err)
|
|
mc.cleanup()
|
|
}
|
|
|
|
// finish is called when the query has succeeded.
|
|
func (mc *mysqlConn) finish() {
|
|
if !mc.watching || mc.finished == nil {
|
|
return
|
|
}
|
|
select {
|
|
case mc.finished <- struct{}{}:
|
|
mc.watching = false
|
|
case <-mc.closech:
|
|
}
|
|
}
|
|
|
|
// Ping implements driver.Pinger interface
|
|
func (mc *mysqlConn) Ping(ctx context.Context) (err error) {
|
|
if mc.closed.Load() {
|
|
mc.log(ErrInvalidConn)
|
|
return driver.ErrBadConn
|
|
}
|
|
|
|
if err = mc.watchCancel(ctx); err != nil {
|
|
return
|
|
}
|
|
defer mc.finish()
|
|
|
|
handleOk := mc.clearResult()
|
|
if err = mc.writeCommandPacket(comPing); err != nil {
|
|
return mc.markBadConn(err)
|
|
}
|
|
|
|
return handleOk.readResultOK()
|
|
}
|
|
|
|
// BeginTx implements driver.ConnBeginTx interface
|
|
func (mc *mysqlConn) BeginTx(ctx context.Context, opts driver.TxOptions) (driver.Tx, error) {
|
|
if mc.closed.Load() {
|
|
return nil, driver.ErrBadConn
|
|
}
|
|
|
|
if err := mc.watchCancel(ctx); err != nil {
|
|
return nil, err
|
|
}
|
|
defer mc.finish()
|
|
|
|
if sql.IsolationLevel(opts.Isolation) != sql.LevelDefault {
|
|
level, err := mapIsolationLevel(opts.Isolation)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
err = mc.exec("SET TRANSACTION ISOLATION LEVEL " + level)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
}
|
|
|
|
return mc.begin(opts.ReadOnly)
|
|
}
|
|
|
|
func (mc *mysqlConn) QueryContext(ctx context.Context, query string, args []driver.NamedValue) (driver.Rows, error) {
|
|
dargs, err := namedValueToValue(args)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
if err := mc.watchCancel(ctx); err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
rows, err := mc.query(query, dargs)
|
|
if err != nil {
|
|
mc.finish()
|
|
return nil, err
|
|
}
|
|
rows.finish = mc.finish
|
|
return rows, err
|
|
}
|
|
|
|
func (mc *mysqlConn) ExecContext(ctx context.Context, query string, args []driver.NamedValue) (driver.Result, error) {
|
|
dargs, err := namedValueToValue(args)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
if err := mc.watchCancel(ctx); err != nil {
|
|
return nil, err
|
|
}
|
|
defer mc.finish()
|
|
|
|
return mc.Exec(query, dargs)
|
|
}
|
|
|
|
func (mc *mysqlConn) PrepareContext(ctx context.Context, query string) (driver.Stmt, error) {
|
|
if err := mc.watchCancel(ctx); err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
stmt, err := mc.Prepare(query)
|
|
mc.finish()
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
select {
|
|
default:
|
|
case <-ctx.Done():
|
|
stmt.Close()
|
|
return nil, ctx.Err()
|
|
}
|
|
return stmt, nil
|
|
}
|
|
|
|
func (stmt *mysqlStmt) QueryContext(ctx context.Context, args []driver.NamedValue) (driver.Rows, error) {
|
|
dargs, err := namedValueToValue(args)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
if err := stmt.mc.watchCancel(ctx); err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
rows, err := stmt.query(dargs)
|
|
if err != nil {
|
|
stmt.mc.finish()
|
|
return nil, err
|
|
}
|
|
rows.finish = stmt.mc.finish
|
|
return rows, err
|
|
}
|
|
|
|
func (stmt *mysqlStmt) ExecContext(ctx context.Context, args []driver.NamedValue) (driver.Result, error) {
|
|
dargs, err := namedValueToValue(args)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
if err := stmt.mc.watchCancel(ctx); err != nil {
|
|
return nil, err
|
|
}
|
|
defer stmt.mc.finish()
|
|
|
|
return stmt.Exec(dargs)
|
|
}
|
|
|
|
func (mc *mysqlConn) watchCancel(ctx context.Context) error {
|
|
if mc.watching {
|
|
// Reach here if canceled,
|
|
// so the connection is already invalid
|
|
mc.cleanup()
|
|
return nil
|
|
}
|
|
// When ctx is already cancelled, don't watch it.
|
|
if err := ctx.Err(); err != nil {
|
|
return err
|
|
}
|
|
// When ctx is not cancellable, don't watch it.
|
|
if ctx.Done() == nil {
|
|
return nil
|
|
}
|
|
// When watcher is not alive, can't watch it.
|
|
if mc.watcher == nil {
|
|
return nil
|
|
}
|
|
|
|
mc.watching = true
|
|
mc.watcher <- ctx
|
|
return nil
|
|
}
|
|
|
|
func (mc *mysqlConn) startWatcher() {
|
|
watcher := make(chan context.Context, 1)
|
|
mc.watcher = watcher
|
|
finished := make(chan struct{})
|
|
mc.finished = finished
|
|
go func() {
|
|
for {
|
|
var ctx context.Context
|
|
select {
|
|
case ctx = <-watcher:
|
|
case <-mc.closech:
|
|
return
|
|
}
|
|
|
|
select {
|
|
case <-ctx.Done():
|
|
mc.cancel(ctx.Err())
|
|
case <-finished:
|
|
case <-mc.closech:
|
|
return
|
|
}
|
|
}
|
|
}()
|
|
}
|
|
|
|
func (mc *mysqlConn) CheckNamedValue(nv *driver.NamedValue) (err error) {
|
|
nv.Value, err = converter{}.ConvertValue(nv.Value)
|
|
return
|
|
}
|
|
|
|
// ResetSession implements driver.SessionResetter.
|
|
// (From Go 1.10)
|
|
func (mc *mysqlConn) ResetSession(ctx context.Context) error {
|
|
if mc.closed.Load() {
|
|
return driver.ErrBadConn
|
|
}
|
|
|
|
// Perform a stale connection check. We only perform this check for
|
|
// the first query on a connection that has been checked out of the
|
|
// connection pool: a fresh connection from the pool is more likely
|
|
// to be stale, and it has not performed any previous writes that
|
|
// could cause data corruption, so it's safe to return ErrBadConn
|
|
// if the check fails.
|
|
if mc.cfg.CheckConnLiveness {
|
|
conn := mc.netConn
|
|
if mc.rawConn != nil {
|
|
conn = mc.rawConn
|
|
}
|
|
var err error
|
|
if mc.cfg.ReadTimeout != 0 {
|
|
err = conn.SetReadDeadline(time.Now().Add(mc.cfg.ReadTimeout))
|
|
}
|
|
if err == nil {
|
|
err = connCheck(conn)
|
|
}
|
|
if err != nil {
|
|
mc.log("closing bad idle connection: ", err)
|
|
return driver.ErrBadConn
|
|
}
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
// IsValid implements driver.Validator interface
|
|
// (From Go 1.15)
|
|
func (mc *mysqlConn) IsValid() bool {
|
|
return !mc.closed.Load()
|
|
}
|