github.com/dolthub/go-mysql-server@v0.18.0/server/golden/proxy.go (about) 1 // Copyright 2022 Dolthub, Inc. 2 // 3 // Licensed under the Apache License, Version 2.0 (the "License"); 4 // you may not use this file except in compliance with the License. 5 // You may obtain a copy of the License at 6 // 7 // http://www.apache.org/licenses/LICENSE-2.0 8 // 9 // Unless required by applicable law or agreed to in writing, software 10 // distributed under the License is distributed on an "AS IS" BASIS, 11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 // See the License for the specific language governing permissions and 13 // limitations under the License. 14 15 package golden 16 17 import ( 18 dsql "database/sql" 19 "fmt" 20 "math" 21 "reflect" 22 "strings" 23 "time" 24 25 "github.com/dolthub/vitess/go/mysql" 26 "github.com/dolthub/vitess/go/sqltypes" 27 querypb "github.com/dolthub/vitess/go/vt/proto/query" 28 "github.com/dolthub/vitess/go/vt/sqlparser" 29 mysql2 "github.com/go-sql-driver/mysql" 30 "github.com/gocraft/dbr/v2" 31 "github.com/sirupsen/logrus" 32 33 "github.com/dolthub/go-mysql-server/sql" 34 "github.com/dolthub/go-mysql-server/sql/planbuilder" 35 ) 36 37 type MySqlProxy struct { 38 ctx *sql.Context 39 connStr string 40 logger *logrus.Logger 41 conns map[uint32]proxyConn 42 } 43 44 func (h MySqlProxy) ParserOptionsForConnection(_ *mysql.Conn) (sqlparser.ParserOptions, error) { 45 return sqlparser.ParserOptions{}, nil 46 } 47 48 type proxyConn struct { 49 *dbr.Connection 50 *logrus.Entry 51 } 52 53 // NewMySqlProxyHandler creates a new MySqlProxy. 54 func NewMySqlProxyHandler(logger *logrus.Logger, connStr string) (MySqlProxy, error) { 55 // ensure parseTime=true 56 cfg, err := mysql2.ParseDSN(connStr) 57 if err != nil { 58 return MySqlProxy{}, err 59 } 60 cfg.ParseTime = true 61 connStr = cfg.FormatDSN() 62 63 conn, err := newConn(connStr, 0, logger) 64 if err != nil { 65 return MySqlProxy{}, err 66 } 67 defer func() { _ = conn.Close() }() 68 69 if err = conn.Ping(); err != nil { 70 return MySqlProxy{}, err 71 } 72 73 return MySqlProxy{ 74 ctx: sql.NewEmptyContext(), 75 connStr: connStr, 76 logger: logger, 77 conns: make(map[uint32]proxyConn), 78 }, nil 79 } 80 81 var _ mysql.Handler = MySqlProxy{} 82 83 func newConn(connStr string, connId uint32, lgr *logrus.Logger) (conn proxyConn, err error) { 84 l := logrus.NewEntry(lgr).WithField("dsn", connStr).WithField(sql.ConnectionIdLogField, connId) 85 var c *dbr.Connection 86 for d := 100.0; d < 10000.0; d *= 1.6 { 87 l.Debugf("Attempting connection to MySQL") 88 if c, err = dbr.Open("mysql", connStr, nil); err == nil { 89 if err = c.Ping(); err == nil { 90 break 91 } 92 } 93 time.Sleep(time.Duration(d) * time.Millisecond) 94 } 95 if err != nil { 96 l.Debugf("Failed to establish connection %d", connId) 97 return proxyConn{}, err 98 } 99 l.Debugf("Succesfully established connection") 100 return proxyConn{Connection: c, Entry: l}, nil 101 } 102 103 // NewConnection implements mysql.Handler. 104 func (h MySqlProxy) NewConnection(c *mysql.Conn) { 105 conn, err := newConn(h.connStr, c.ConnectionID, h.logger) 106 if err == nil { 107 h.conns[c.ConnectionID] = conn 108 } 109 } 110 111 func (h MySqlProxy) getConn(connId uint32) (conn proxyConn, err error) { 112 var ok bool 113 conn, ok = h.conns[connId] 114 if ok { 115 return conn, nil 116 } else { 117 conn, err = newConn(h.connStr, connId, h.logger) 118 if err != nil { 119 return proxyConn{}, err 120 } 121 } 122 if err = conn.Ping(); err != nil { 123 return proxyConn{}, err 124 } 125 h.conns[connId] = conn 126 return conn, nil 127 } 128 129 // ComInitDB implements mysql.Handler. 130 func (h MySqlProxy) ComInitDB(c *mysql.Conn, schemaName string) error { 131 conn, err := h.getConn(c.ConnectionID) 132 if err != nil { 133 return err 134 } 135 if schemaName != "" { 136 _, err = conn.Exec("USE " + schemaName + " ;") 137 } 138 return err 139 } 140 141 // ComPrepare implements mysql.Handler. 142 func (h MySqlProxy) ComPrepare(_ *mysql.Conn, _ string, _ *mysql.PrepareData) ([]*querypb.Field, error) { 143 return nil, fmt.Errorf("ComPrepare unsupported") 144 } 145 146 // ComStmtExecute implements mysql.Handler. 147 func (h MySqlProxy) ComStmtExecute(c *mysql.Conn, prepare *mysql.PrepareData, callback func(*sqltypes.Result) error) error { 148 return fmt.Errorf("ComStmtExecute unsupported") 149 } 150 151 // ComResetConnection implements mysql.Handler. 152 func (h MySqlProxy) ComResetConnection(_ *mysql.Conn) error { 153 return nil 154 } 155 156 // ConnectionClosed implements mysql.Handler. 157 func (h MySqlProxy) ConnectionClosed(c *mysql.Conn) { 158 conn, ok := h.conns[c.ConnectionID] 159 if !ok { 160 return 161 } 162 if err := conn.Close(); err != nil { 163 lgr := logrus.WithField(sql.ConnectionIdLogField, c.ConnectionID) 164 lgr.Errorf("Error closing connection") 165 } 166 delete(h.conns, c.ConnectionID) 167 } 168 169 // ComMultiQuery implements mysql.Handler. 170 func (h MySqlProxy) ComMultiQuery( 171 c *mysql.Conn, 172 query string, 173 callback mysql.ResultSpoolFn, 174 ) (string, error) { 175 conn, err := h.getConn(c.ConnectionID) 176 if err != nil { 177 return "", err 178 } 179 conn.Entry = conn.Entry.WithField("query", query) 180 181 remainder, err := h.processQuery(c, conn, query, true, callback) 182 if err != nil { 183 conn.Errorf("Failed to process MySQL results: %s", err) 184 } 185 return remainder, err 186 } 187 188 // ComQuery implements mysql.Handler. 189 func (h MySqlProxy) ComQuery( 190 c *mysql.Conn, 191 query string, 192 callback mysql.ResultSpoolFn, 193 ) error { 194 conn, err := h.getConn(c.ConnectionID) 195 if err != nil { 196 return err 197 } 198 conn.Entry = conn.Entry.WithField("query", query) 199 200 _, err = h.processQuery(c, conn, query, false, callback) 201 if err != nil { 202 conn.Errorf("Failed to process MySQL results: %s", err) 203 } 204 return err 205 } 206 207 // ComParsedQuery implements mysql.Handler. 208 func (h MySqlProxy) ComParsedQuery( 209 c *mysql.Conn, 210 query string, 211 parsed sqlparser.Statement, 212 callback func(*sqltypes.Result, bool) error, 213 ) error { 214 return h.ComQuery(c, query, callback) 215 } 216 217 func (h MySqlProxy) processQuery( 218 c *mysql.Conn, 219 proxy proxyConn, 220 query string, 221 isMultiStatement bool, 222 callback func(*sqltypes.Result, bool) error, 223 ) (string, error) { 224 ctx := sql.NewContext(h.ctx) 225 var remainder string 226 if isMultiStatement { 227 _, ri, err := sqlparser.ParseOne(query) 228 if err != nil { 229 return "", err 230 } 231 if ri != 0 && ri < len(query) { 232 remainder = query[ri:] 233 query = query[:ri] 234 query = planbuilder.RemoveSpaceAndDelimiter(query, ';') 235 } 236 } 237 238 ctx = ctx.WithQuery(query) 239 more := remainder != "" 240 241 proxy.Debugf("Sending query to MySQL") 242 rows, err := proxy.Query(query) 243 if err != nil { 244 return "", err 245 } 246 defer func() { 247 if cerr := rows.Close(); cerr != nil { 248 err = cerr 249 } 250 }() 251 252 var processedAtLeastOneBatch bool 253 res := &sqltypes.Result{} 254 ok := true 255 for ok { 256 if res, ok, err = fetchMySqlRows(ctx, rows, 128); err != nil { 257 return "", err 258 } 259 if err := callback(res, more); err != nil { 260 return "", err 261 } 262 processedAtLeastOneBatch = true 263 } 264 265 if err := setConnStatusFlags(ctx, c); err != nil { 266 return remainder, err 267 } 268 269 switch len(res.Rows) { 270 case 0: 271 if len(res.Info) > 0 { 272 ctx.GetLogger().Tracef("returning result %s", res.Info) 273 } else { 274 ctx.GetLogger().Tracef("returning empty result") 275 } 276 case 1: 277 ctx.GetLogger().Tracef("returning result %v", res) 278 } 279 280 // processedAtLeastOneBatch means we already called resultsCB() at least 281 // once, so no need to call it if RowsAffected == 0. 282 if res != nil && (res.RowsAffected == 0 && processedAtLeastOneBatch) { 283 return remainder, nil 284 } 285 286 return remainder, nil 287 } 288 289 // WarningCount is called at the end of each query to obtain 290 // the value to be returned to the client in the EOF packet. 291 // Note that this will be called either in the context of the 292 // ComQuery resultsCB if the result does not contain any fields, 293 // or after the last ComQuery call completes. 294 func (h MySqlProxy) WarningCount(c *mysql.Conn) uint16 { 295 return 0 296 } 297 298 // See https://dev.mysql.com/doc/internals/en/status-flags.html 299 func setConnStatusFlags(ctx *sql.Context, c *mysql.Conn) error { 300 ok, err := isSessionAutocommit(ctx) 301 if err != nil { 302 return err 303 } 304 if ok { 305 c.StatusFlags |= uint16(mysql.ServerStatusAutocommit) 306 } else { 307 c.StatusFlags &= ^uint16(mysql.ServerStatusAutocommit) 308 } 309 if t := ctx.GetTransaction(); t != nil { 310 c.StatusFlags |= uint16(mysql.ServerInTransaction) 311 } else { 312 c.StatusFlags &= ^uint16(mysql.ServerInTransaction) 313 } 314 return nil 315 } 316 317 func isSessionAutocommit(ctx *sql.Context) (bool, error) { 318 autoCommitSessionVar, err := ctx.GetSessionVariable(ctx, sql.AutoCommitSessionVar) 319 if err != nil { 320 return false, err 321 } 322 return sql.ConvertToBool(ctx, autoCommitSessionVar) 323 } 324 325 func fetchMySqlRows(ctx *sql.Context, results *dsql.Rows, count int) (res *sqltypes.Result, more bool, err error) { 326 cols, err := results.ColumnTypes() 327 if err != nil { 328 return nil, false, err 329 } 330 331 types, fields, err := schemaToFields(ctx, cols) 332 if err != nil { 333 return nil, false, err 334 } 335 336 rows := make([][]sqltypes.Value, 0, count) 337 for results.Next() { 338 if len(rows) == count { 339 more = true 340 break 341 } 342 343 scanRow, err := scanResultRow(results) 344 if err != nil { 345 return nil, false, err 346 } 347 348 row := make([]sqltypes.Value, len(fields)) 349 for i := range row { 350 scanRow[i], _, err = types[i].Convert(scanRow[i]) 351 if err != nil { 352 return nil, false, err 353 } 354 row[i], err = types[i].SQL(ctx, nil, scanRow[i]) 355 if err != nil { 356 return nil, false, err 357 } 358 } 359 rows = append(rows, row) 360 } 361 362 res = &sqltypes.Result{ 363 Fields: fields, 364 RowsAffected: uint64(len(rows)), 365 Rows: rows, 366 } 367 return 368 } 369 370 var typeDefaults = map[string]string{ 371 "char": "char(255)", 372 "binary": "binary(255)", 373 "varchar": "varchar(65535)", 374 "varbinary": "varbinary(65535)", 375 } 376 377 func schemaToFields(ctx *sql.Context, cols []*dsql.ColumnType) ([]sql.Type, []*querypb.Field, error) { 378 types := make([]sql.Type, len(cols)) 379 fields := make([]*querypb.Field, len(cols)) 380 381 var err error 382 for i, col := range cols { 383 typeStr := strings.ToLower(col.DatabaseTypeName()) 384 if length, ok := col.Length(); ok { 385 // append length specifier to type 386 typeStr = fmt.Sprintf("%s(%d)", typeStr, length) 387 } else if ts, ok := typeDefaults[typeStr]; ok { 388 // if no length specifier if given, 389 // default to the maximum width 390 typeStr = ts 391 } 392 types[i], err = planbuilder.ParseColumnTypeString(typeStr) 393 if err != nil { 394 return nil, nil, err 395 } 396 397 var charset uint32 398 switch types[i].Type() { 399 case sqltypes.Binary, sqltypes.VarBinary, sqltypes.Blob: 400 charset = mysql.CharacterSetBinary 401 default: 402 charset = mysql.CharacterSetUtf8 403 } 404 405 fields[i] = &querypb.Field{ 406 Name: col.Name(), 407 Type: types[i].Type(), 408 Charset: charset, 409 ColumnLength: math.MaxUint32, 410 } 411 } 412 return types, fields, nil 413 } 414 415 func scanResultRow(results *dsql.Rows) (sql.Row, error) { 416 cols, err := results.ColumnTypes() 417 if err != nil { 418 return nil, err 419 } 420 421 scanRow := make(sql.Row, len(cols)) 422 for i := range cols { 423 scanRow[i] = reflect.New(cols[i].ScanType()).Interface() 424 } 425 426 for i, columnType := range cols { 427 scanRow[i] = reflect.New(columnType.ScanType()).Interface() 428 } 429 430 if err = results.Scan(scanRow...); err != nil { 431 return nil, err 432 } 433 for i, val := range scanRow { 434 v := reflect.ValueOf(val).Elem().Interface() 435 switch t := v.(type) { 436 case dsql.RawBytes: 437 if t == nil { 438 scanRow[i] = nil 439 } else { 440 scanRow[i] = string(t) 441 } 442 case dsql.NullBool: 443 if t.Valid { 444 scanRow[i] = t.Bool 445 } else { 446 scanRow[i] = nil 447 } 448 case dsql.NullByte: 449 if t.Valid { 450 scanRow[i] = t.Byte 451 } else { 452 scanRow[i] = nil 453 } 454 case dsql.NullFloat64: 455 if t.Valid { 456 scanRow[i] = t.Float64 457 } else { 458 scanRow[i] = nil 459 } 460 case dsql.NullInt16: 461 if t.Valid { 462 scanRow[i] = t.Int16 463 } else { 464 scanRow[i] = nil 465 } 466 case dsql.NullInt32: 467 if t.Valid { 468 scanRow[i] = t.Int32 469 } else { 470 scanRow[i] = nil 471 } 472 case dsql.NullInt64: 473 if t.Valid { 474 scanRow[i] = t.Int64 475 } else { 476 scanRow[i] = nil 477 } 478 case dsql.NullString: 479 if t.Valid { 480 scanRow[i] = t.String 481 } else { 482 scanRow[i] = nil 483 } 484 case dsql.NullTime: 485 if t.Valid { 486 scanRow[i] = t.Time 487 } else { 488 scanRow[i] = nil 489 } 490 default: 491 scanRow[i] = t 492 } 493 } 494 return scanRow, nil 495 }