github.com/whtcorpsinc/milevadb-prod@v0.0.0-20211104133533-f57f4be3b597/allegrosql/server/conn.go (about) 1 // Copyright 2020 The Go-MyALLEGROSQL-Driver Authors. All rights reserved. 2 // 3 // This Source Code Form is subject to the terms of the Mozilla Public 4 // License, v. 2.0. If a copy of the MPL was not distributed with this file, 5 // You can obtain one at http://mozilla.org/MPL/2.0/. 6 7 // The MIT License (MIT) 8 // 9 // Copyright (c) 2020 wandoulabs 10 // Copyright (c) 2020 siddontang 11 // 12 // Permission is hereby granted, free of charge, to any person obtaining a copy of 13 // this software and associated documentation files (the "Software"), to deal in 14 // the Software without restriction, including without limitation the rights to 15 // use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of 16 // the Software, and to permit persons to whom the Software is furnished to do so, 17 // subject to the following conditions: 18 // 19 // The above copyright notice and this permission notice shall be included in all 20 // copies or substantial portions of the Software. 21 22 // Copyright 2020 WHTCORPS INC, Inc. 23 // 24 // Licensed under the Apache License, Version 2.0 (the "License"); 25 // you may not use this file except in compliance with the License. 26 // You may obtain a copy of the License at 27 // 28 // http://www.apache.org/licenses/LICENSE-2.0 29 // 30 // Unless required by applicable law or agreed to in writing, software 31 // distributed under the License is distributed on an "AS IS" BASIS, 32 // See the License for the specific language governing permissions and 33 // limitations under the License. 
34 35 package server 36 37 import ( 38 "bytes" 39 "context" 40 "crypto/tls" 41 "encoding/binary" 42 "fmt" 43 "io" 44 "net" 45 "runtime" 46 "runtime/pprof" 47 "runtime/trace" 48 "strconv" 49 "strings" 50 "sync" 51 "sync/atomic" 52 "time" 53 54 "github.com/opentracing/opentracing-go" 55 "github.com/prometheus/client_golang/prometheus" 56 "github.com/whtcorpsinc/BerolinaSQL" 57 "github.com/whtcorpsinc/BerolinaSQL/allegrosql" 58 "github.com/whtcorpsinc/BerolinaSQL/ast" 59 "github.com/whtcorpsinc/BerolinaSQL/auth" 60 "github.com/whtcorpsinc/BerolinaSQL/terror" 61 "github.com/whtcorpsinc/errors" 62 "github.com/whtcorpsinc/failpoint" 63 "github.com/whtcorpsinc/milevadb/blockcodec" 64 causetembedded "github.com/whtcorpsinc/milevadb/causet/embedded" 65 "github.com/whtcorpsinc/milevadb/config" 66 "github.com/whtcorpsinc/milevadb/ekv" 67 "github.com/whtcorpsinc/milevadb/interlock" 68 "github.com/whtcorpsinc/milevadb/metrics" 69 "github.com/whtcorpsinc/milevadb/petri" 70 "github.com/whtcorpsinc/milevadb/plugin" 71 "github.com/whtcorpsinc/milevadb/schemareplicant" 72 "github.com/whtcorpsinc/milevadb/soliton/chunk" 73 "github.com/whtcorpsinc/milevadb/soliton/execdetails" 74 "github.com/whtcorpsinc/milevadb/soliton/logutil" 75 "github.com/whtcorpsinc/milevadb/soliton/memcam" 76 "github.com/whtcorpsinc/milevadb/soliton/memory" 77 "github.com/whtcorpsinc/milevadb/soliton/replog" 78 "github.com/whtcorpsinc/milevadb/soliton/sqlexec" 79 "github.com/whtcorpsinc/milevadb/stochastikctx" 80 "github.com/whtcorpsinc/milevadb/stochastikctx/stmtctx" 81 "github.com/whtcorpsinc/milevadb/stochastikctx/variable" 82 "go.uber.org/zap" 83 ) 84 85 const ( 86 connStatusDispatching int32 = iota 87 connStatusReading 88 connStatusShutdown // Closed by server. 89 connStatusWaitShutdown // Notified by server to close. 
90 ) 91 92 var ( 93 queryTotalCountOk = [...]prometheus.Counter{ 94 allegrosql.ComSleep: metrics.QueryTotalCounter.WithLabelValues("Sleep", "OK"), 95 allegrosql.ComQuit: metrics.QueryTotalCounter.WithLabelValues("Quit", "OK"), 96 allegrosql.ComInitDB: metrics.QueryTotalCounter.WithLabelValues("InitDB", "OK"), 97 allegrosql.ComQuery: metrics.QueryTotalCounter.WithLabelValues("Query", "OK"), 98 allegrosql.ComPing: metrics.QueryTotalCounter.WithLabelValues("Ping", "OK"), 99 allegrosql.ComFieldList: metrics.QueryTotalCounter.WithLabelValues("FieldList", "OK"), 100 allegrosql.ComStmtPrepare: metrics.QueryTotalCounter.WithLabelValues("StmtPrepare", "OK"), 101 allegrosql.ComStmtInterDircute: metrics.QueryTotalCounter.WithLabelValues("StmtInterDircute", "OK"), 102 allegrosql.ComStmtFetch: metrics.QueryTotalCounter.WithLabelValues("StmtFetch", "OK"), 103 allegrosql.ComStmtClose: metrics.QueryTotalCounter.WithLabelValues("StmtClose", "OK"), 104 allegrosql.ComStmtSendLongData: metrics.QueryTotalCounter.WithLabelValues("StmtSendLongData", "OK"), 105 allegrosql.ComStmtReset: metrics.QueryTotalCounter.WithLabelValues("StmtReset", "OK"), 106 allegrosql.ComSetOption: metrics.QueryTotalCounter.WithLabelValues("SetOption", "OK"), 107 } 108 queryTotalCountErr = [...]prometheus.Counter{ 109 allegrosql.ComSleep: metrics.QueryTotalCounter.WithLabelValues("Sleep", "Error"), 110 allegrosql.ComQuit: metrics.QueryTotalCounter.WithLabelValues("Quit", "Error"), 111 allegrosql.ComInitDB: metrics.QueryTotalCounter.WithLabelValues("InitDB", "Error"), 112 allegrosql.ComQuery: metrics.QueryTotalCounter.WithLabelValues("Query", "Error"), 113 allegrosql.ComPing: metrics.QueryTotalCounter.WithLabelValues("Ping", "Error"), 114 allegrosql.ComFieldList: metrics.QueryTotalCounter.WithLabelValues("FieldList", "Error"), 115 allegrosql.ComStmtPrepare: metrics.QueryTotalCounter.WithLabelValues("StmtPrepare", "Error"), 116 allegrosql.ComStmtInterDircute: 
metrics.QueryTotalCounter.WithLabelValues("StmtInterDircute", "Error"), 117 allegrosql.ComStmtFetch: metrics.QueryTotalCounter.WithLabelValues("StmtFetch", "Error"), 118 allegrosql.ComStmtClose: metrics.QueryTotalCounter.WithLabelValues("StmtClose", "Error"), 119 allegrosql.ComStmtSendLongData: metrics.QueryTotalCounter.WithLabelValues("StmtSendLongData", "Error"), 120 allegrosql.ComStmtReset: metrics.QueryTotalCounter.WithLabelValues("StmtReset", "Error"), 121 allegrosql.ComSetOption: metrics.QueryTotalCounter.WithLabelValues("SetOption", "Error"), 122 } 123 124 queryDurationHistogramUse = metrics.QueryDurationHistogram.WithLabelValues("Use") 125 queryDurationHistogramShow = metrics.QueryDurationHistogram.WithLabelValues("Show") 126 queryDurationHistogramBegin = metrics.QueryDurationHistogram.WithLabelValues("Begin") 127 queryDurationHistogramCommit = metrics.QueryDurationHistogram.WithLabelValues("Commit") 128 queryDurationHistogramRollback = metrics.QueryDurationHistogram.WithLabelValues("Rollback") 129 queryDurationHistogramInsert = metrics.QueryDurationHistogram.WithLabelValues("Insert") 130 queryDurationHistogramReplace = metrics.QueryDurationHistogram.WithLabelValues("Replace") 131 queryDurationHistogramDelete = metrics.QueryDurationHistogram.WithLabelValues("Delete") 132 queryDurationHistogramUFIDelate = metrics.QueryDurationHistogram.WithLabelValues("UFIDelate") 133 queryDurationHistogramSelect = metrics.QueryDurationHistogram.WithLabelValues("Select") 134 queryDurationHistogramInterDircute = metrics.QueryDurationHistogram.WithLabelValues("InterDircute") 135 queryDurationHistogramSet = metrics.QueryDurationHistogram.WithLabelValues("Set") 136 queryDurationHistogramGeneral = metrics.QueryDurationHistogram.WithLabelValues(metrics.LblGeneral) 137 138 disconnectNormal = metrics.DisconnectionCounter.WithLabelValues(metrics.LblOK) 139 disconnectByClientWithError = metrics.DisconnectionCounter.WithLabelValues(metrics.LblError) 140 disconnectErrorUndetermined = 
metrics.DisconnectionCounter.WithLabelValues("undetermined") 141 ) 142 143 // newClientConn creates a *clientConn object. 144 func newClientConn(s *Server) *clientConn { 145 return &clientConn{ 146 server: s, 147 connectionID: atomic.AddUint32(&baseConnID, 1), 148 defCauslation: allegrosql.DefaultDefCauslationID, 149 alloc: memcam.NewSlabPredictor(32 * 1024), 150 status: connStatusDispatching, 151 } 152 } 153 154 // clientConn represents a connection between server and client, it maintains connection specific state, 155 // handles client query. 156 type clientConn struct { 157 pkt *packetIO // a helper to read and write data in packet format. 158 bufReadConn *bufferedReadConn // a buffered-read net.Conn or buffered-read tls.Conn. 159 tlsConn *tls.Conn // TLS connection, nil if not TLS. 160 server *Server // a reference of server instance. 161 capability uint32 // client capability affects the way server handles client request. 162 connectionID uint32 // atomically allocated by a global variable, unique in process scope. 163 user string // user of the client. 164 dbname string // default database name. 165 salt []byte // random bytes used for authentication. 166 alloc memcam.SlabPredictor // an memory allocator for reducing memory allocation. 167 lastPacket []byte // latest allegrosql query string, currently used for logging error. 168 ctx *MilevaDBContext // an interface to execute allegrosql memexs. 169 attrs map[string]string // attributes parsed from client handshake response, not used for now. 170 peerHost string // peer host 171 peerPort string // peer port 172 status int32 // dispatching/reading/shutdown/waitshutdown 173 lastCode uint16 // last error code 174 defCauslation uint8 // defCauslation used by client, may be different from the defCauslation used by database. 
175 } 176 177 func (cc *clientConn) String() string { 178 defCauslationStr := allegrosql.DefCauslations[cc.defCauslation] 179 return fmt.Sprintf("id:%d, addr:%s status:%b, defCauslation:%s, user:%s", 180 cc.connectionID, cc.bufReadConn.RemoteAddr(), cc.ctx.Status(), defCauslationStr, cc.user, 181 ) 182 } 183 184 // authSwitchRequest is used when the client asked to speak something 185 // other than mysql_native_password. The server is allowed to ask 186 // the client to switch, so lets ask for mysql_native_password 187 // https://dev.allegrosql.com/doc/internals/en/connection-phase-packets.html#packet-ProtodefCaus::AuthSwitchRequest 188 func (cc *clientConn) authSwitchRequest(ctx context.Context) ([]byte, error) { 189 enclen := 1 + len("mysql_native_password") + 1 + len(cc.salt) + 1 190 data := cc.alloc.AllocWithLen(4, enclen) 191 data = append(data, 0xfe) // switch request 192 data = append(data, []byte("mysql_native_password")...) 193 data = append(data, byte(0x00)) // requires null 194 data = append(data, cc.salt...) 195 data = append(data, 0) 196 err := cc.writePacket(data) 197 if err != nil { 198 logutil.Logger(ctx).Debug("write response to client failed", zap.Error(err)) 199 return nil, err 200 } 201 err = cc.flush(ctx) 202 if err != nil { 203 logutil.Logger(ctx).Debug("flush response to client failed", zap.Error(err)) 204 return nil, err 205 } 206 resp, err := cc.readPacket() 207 if err != nil { 208 err = errors.SuspendStack(err) 209 if errors.Cause(err) == io.EOF { 210 logutil.Logger(ctx).Warn("authSwitchRequest response fail due to connection has be closed by client-side") 211 } else { 212 logutil.Logger(ctx).Warn("authSwitchRequest response fail", zap.Error(err)) 213 } 214 return nil, err 215 } 216 return resp, nil 217 } 218 219 // handshake works like TCP handshake, but in a higher level, it first writes initial packet to client, 220 // during handshake, client and server negotiate compatible features and do authentication. 
221 // After handshake, client can send allegrosql query to server. 222 func (cc *clientConn) handshake(ctx context.Context) error { 223 if err := cc.writeInitialHandshake(ctx); err != nil { 224 if errors.Cause(err) == io.EOF { 225 logutil.Logger(ctx).Debug("Could not send handshake due to connection has be closed by client-side") 226 } else { 227 logutil.Logger(ctx).Debug("Write init handshake to client fail", zap.Error(errors.SuspendStack(err))) 228 } 229 return err 230 } 231 if err := cc.readOptionalSSLRequestAndHandshakeResponse(ctx); err != nil { 232 err1 := cc.writeError(ctx, err) 233 if err1 != nil { 234 logutil.Logger(ctx).Debug("writeError failed", zap.Error(err1)) 235 } 236 return err 237 } 238 data := cc.alloc.AllocWithLen(4, 32) 239 data = append(data, allegrosql.OKHeader) 240 data = append(data, 0, 0) 241 if cc.capability&allegrosql.ClientProtodefCaus41 > 0 { 242 data = dumpUint16(data, allegrosql.ServerStatusAutocommit) 243 data = append(data, 0, 0) 244 } 245 246 err := cc.writePacket(data) 247 cc.pkt.sequence = 0 248 if err != nil { 249 err = errors.SuspendStack(err) 250 logutil.Logger(ctx).Debug("write response to client failed", zap.Error(err)) 251 return err 252 } 253 254 err = cc.flush(ctx) 255 if err != nil { 256 err = errors.SuspendStack(err) 257 logutil.Logger(ctx).Debug("flush response to client failed", zap.Error(err)) 258 return err 259 } 260 return err 261 } 262 263 func (cc *clientConn) Close() error { 264 cc.server.rwlock.Lock() 265 delete(cc.server.clients, cc.connectionID) 266 connections := len(cc.server.clients) 267 cc.server.rwlock.Unlock() 268 return closeConn(cc, connections) 269 } 270 271 func closeConn(cc *clientConn, connections int) error { 272 metrics.ConnGauge.Set(float64(connections)) 273 err := cc.bufReadConn.Close() 274 terror.Log(err) 275 if cc.ctx != nil { 276 return cc.ctx.Close() 277 } 278 return nil 279 } 280 281 func (cc *clientConn) closeWithoutLock() error { 282 delete(cc.server.clients, cc.connectionID) 283 
return closeConn(cc, len(cc.server.clients)) 284 } 285 286 // writeInitialHandshake sends server version, connection ID, server capability, defCauslation, server status 287 // and auth salt to the client. 288 func (cc *clientConn) writeInitialHandshake(ctx context.Context) error { 289 data := make([]byte, 4, 128) 290 291 // min version 10 292 data = append(data, 10) 293 // server version[00] 294 data = append(data, allegrosql.ServerVersion...) 295 data = append(data, 0) 296 // connection id 297 data = append(data, byte(cc.connectionID), byte(cc.connectionID>>8), byte(cc.connectionID>>16), byte(cc.connectionID>>24)) 298 // auth-plugin-data-part-1 299 data = append(data, cc.salt[0:8]...) 300 // filler [00] 301 data = append(data, 0) 302 // capability flag lower 2 bytes, using default capability here 303 data = append(data, byte(cc.server.capability), byte(cc.server.capability>>8)) 304 // charset 305 if cc.defCauslation == 0 { 306 cc.defCauslation = uint8(allegrosql.DefaultDefCauslationID) 307 } 308 data = append(data, cc.defCauslation) 309 // status 310 data = dumpUint16(data, allegrosql.ServerStatusAutocommit) 311 // below 13 byte may not be used 312 // capability flag upper 2 bytes, using default capability here 313 data = append(data, byte(cc.server.capability>>16), byte(cc.server.capability>>24)) 314 // length of auth-plugin-data 315 data = append(data, byte(len(cc.salt)+1)) 316 // reserved 10 [00] 317 data = append(data, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0) 318 // auth-plugin-data-part-2 319 data = append(data, cc.salt[8:]...) 320 data = append(data, 0) 321 // auth-plugin name 322 data = append(data, []byte("mysql_native_password")...) 
323 data = append(data, 0) 324 err := cc.writePacket(data) 325 if err != nil { 326 return err 327 } 328 return cc.flush(ctx) 329 } 330 331 func (cc *clientConn) readPacket() ([]byte, error) { 332 return cc.pkt.readPacket() 333 } 334 335 func (cc *clientConn) writePacket(data []byte) error { 336 failpoint.Inject("FakeClientConn", func() { 337 if cc.pkt == nil { 338 failpoint.Return(nil) 339 } 340 }) 341 return cc.pkt.writePacket(data) 342 } 343 344 // getStochastikVarsWaitTimeout get stochastik variable wait_timeout 345 func (cc *clientConn) getStochastikVarsWaitTimeout(ctx context.Context) uint64 { 346 valStr, exists := cc.ctx.GetStochastikVars().GetSystemVar(variable.WaitTimeout) 347 if !exists { 348 return variable.DefWaitTimeout 349 } 350 waitTimeout, err := strconv.ParseUint(valStr, 10, 64) 351 if err != nil { 352 logutil.Logger(ctx).Warn("get sysval wait_timeout failed, use default value", zap.Error(err)) 353 // if get waitTimeout error, use default value 354 return variable.DefWaitTimeout 355 } 356 return waitTimeout 357 } 358 359 type handshakeResponse41 struct { 360 Capability uint32 361 DefCauslation uint8 362 User string 363 DBName string 364 Auth []byte 365 AuthPlugin string 366 Attrs map[string]string 367 } 368 369 // parseOldHandshakeResponseHeader parses the old version handshake header HandshakeResponse320 370 func parseOldHandshakeResponseHeader(ctx context.Context, packet *handshakeResponse41, data []byte) (parsedBytes int, err error) { 371 // Ensure there are enough data to read: 372 // https://dev.allegrosql.com/doc/internals/en/connection-phase-packets.html#packet-ProtodefCaus::HandshakeResponse320 373 logutil.Logger(ctx).Debug("try to parse hanshake response as ProtodefCaus::HandshakeResponse320", zap.ByteString("packetData", data)) 374 if len(data) < 2+3 { 375 logutil.Logger(ctx).Error("got malformed handshake response", zap.ByteString("packetData", data)) 376 return 0, allegrosql.ErrMalformPacket 377 } 378 offset := 0 379 // capability 380 
capability := binary.LittleEndian.Uint16(data[:2]) 381 packet.Capability = uint32(capability) 382 383 // be compatible with ProtodefCaus::HandshakeResponse41 384 packet.Capability = packet.Capability | allegrosql.ClientProtodefCaus41 385 386 offset += 2 387 // skip max packet size 388 offset += 3 389 // usa default CharsetID 390 packet.DefCauslation = allegrosql.DefCauslationNames["utf8mb4_general_ci"] 391 392 return offset, nil 393 } 394 395 // parseOldHandshakeResponseBody parse the HandshakeResponse for ProtodefCaus::HandshakeResponse320 (except the common header part). 396 func parseOldHandshakeResponseBody(ctx context.Context, packet *handshakeResponse41, data []byte, offset int) (err error) { 397 defer func() { 398 // Check malformat packet cause out of range is disgusting, but don't panic! 399 if r := recover(); r != nil { 400 logutil.Logger(ctx).Error("handshake panic", zap.ByteString("packetData", data), zap.Stack("stack")) 401 err = allegrosql.ErrMalformPacket 402 } 403 }() 404 // user name 405 packet.User = string(data[offset : offset+bytes.IndexByte(data[offset:], 0)]) 406 offset += len(packet.User) + 1 407 408 if packet.Capability&allegrosql.ClientConnectWithDB > 0 { 409 if len(data[offset:]) > 0 { 410 idx := bytes.IndexByte(data[offset:], 0) 411 packet.DBName = string(data[offset : offset+idx]) 412 offset = offset + idx + 1 413 } 414 if len(data[offset:]) > 0 { 415 packet.Auth = data[offset : offset+bytes.IndexByte(data[offset:], 0)] 416 } 417 } else { 418 packet.Auth = data[offset : offset+bytes.IndexByte(data[offset:], 0)] 419 } 420 421 return nil 422 } 423 424 // parseHandshakeResponseHeader parses the common header of SSLRequest and HandshakeResponse41. 
425 func parseHandshakeResponseHeader(ctx context.Context, packet *handshakeResponse41, data []byte) (parsedBytes int, err error) { 426 // Ensure there are enough data to read: 427 // http://dev.allegrosql.com/doc/internals/en/connection-phase-packets.html#packet-ProtodefCaus::SSLRequest 428 if len(data) < 4+4+1+23 { 429 logutil.Logger(ctx).Error("got malformed handshake response", zap.ByteString("packetData", data)) 430 return 0, allegrosql.ErrMalformPacket 431 } 432 433 offset := 0 434 // capability 435 capability := binary.LittleEndian.Uint32(data[:4]) 436 packet.Capability = capability 437 offset += 4 438 // skip max packet size 439 offset += 4 440 // charset, skip, if you want to use another charset, use set names 441 packet.DefCauslation = data[offset] 442 offset++ 443 // skip reserved 23[00] 444 offset += 23 445 446 return offset, nil 447 } 448 449 // parseHandshakeResponseBody parse the HandshakeResponse (except the common header part). 450 func parseHandshakeResponseBody(ctx context.Context, packet *handshakeResponse41, data []byte, offset int) (err error) { 451 defer func() { 452 // Check malformat packet cause out of range is disgusting, but don't panic! 453 if r := recover(); r != nil { 454 logutil.Logger(ctx).Error("handshake panic", zap.ByteString("packetData", data)) 455 err = allegrosql.ErrMalformPacket 456 } 457 }() 458 // user name 459 packet.User = string(data[offset : offset+bytes.IndexByte(data[offset:], 0)]) 460 offset += len(packet.User) + 1 461 462 if packet.Capability&allegrosql.ClientPluginAuthLenencClientData > 0 { 463 // MyALLEGROSQL client sets the wrong capability, it will set this bit even server doesn't 464 // support ClientPluginAuthLenencClientData. 
465 // https://github.com/allegrosql/allegrosql-server/blob/5.7/allegrosql-common/client.c#L3478 466 num, null, off := parseLengthEncodedInt(data[offset:]) 467 offset += off 468 if !null { 469 packet.Auth = data[offset : offset+int(num)] 470 offset += int(num) 471 } 472 } else if packet.Capability&allegrosql.ClientSecureConnection > 0 { 473 // auth length and auth 474 authLen := int(data[offset]) 475 offset++ 476 packet.Auth = data[offset : offset+authLen] 477 offset += authLen 478 } else { 479 packet.Auth = data[offset : offset+bytes.IndexByte(data[offset:], 0)] 480 offset += len(packet.Auth) + 1 481 } 482 483 if packet.Capability&allegrosql.ClientConnectWithDB > 0 { 484 if len(data[offset:]) > 0 { 485 idx := bytes.IndexByte(data[offset:], 0) 486 packet.DBName = string(data[offset : offset+idx]) 487 offset += idx + 1 488 } 489 } 490 491 if packet.Capability&allegrosql.ClientPluginAuth > 0 { 492 idx := bytes.IndexByte(data[offset:], 0) 493 s := offset 494 f := offset + idx 495 if s < f { // handle unexpected bad packets 496 packet.AuthPlugin = string(data[s:f]) 497 } 498 offset += idx + 1 499 } 500 501 if packet.Capability&allegrosql.ClientConnectAtts > 0 { 502 if len(data[offset:]) == 0 { 503 // Defend some ill-formated packet, connection attribute is not important and can be ignored. 
504 return nil 505 } 506 if num, null, off := parseLengthEncodedInt(data[offset:]); !null { 507 offset += off 508 event := data[offset : offset+int(num)] 509 attrs, err := parseAttrs(event) 510 if err != nil { 511 logutil.Logger(ctx).Warn("parse attrs failed", zap.Error(err)) 512 return nil 513 } 514 packet.Attrs = attrs 515 } 516 } 517 518 return nil 519 } 520 521 func parseAttrs(data []byte) (map[string]string, error) { 522 attrs := make(map[string]string) 523 pos := 0 524 for pos < len(data) { 525 key, _, off, err := parseLengthEncodedBytes(data[pos:]) 526 if err != nil { 527 return attrs, err 528 } 529 pos += off 530 value, _, off, err := parseLengthEncodedBytes(data[pos:]) 531 if err != nil { 532 return attrs, err 533 } 534 pos += off 535 536 attrs[string(key)] = string(value) 537 } 538 return attrs, nil 539 } 540 541 func (cc *clientConn) readOptionalSSLRequestAndHandshakeResponse(ctx context.Context) error { 542 // Read a packet. It may be a SSLRequest or HandshakeResponse. 543 data, err := cc.readPacket() 544 if err != nil { 545 err = errors.SuspendStack(err) 546 if errors.Cause(err) == io.EOF { 547 logutil.Logger(ctx).Debug("wait handshake response fail due to connection has be closed by client-side") 548 } else { 549 logutil.Logger(ctx).Debug("wait handshake response fail", zap.Error(err)) 550 } 551 return err 552 } 553 554 isOldVersion := false 555 556 var resp handshakeResponse41 557 var pos int 558 559 if len(data) < 2 { 560 logutil.Logger(ctx).Error("got malformed handshake response", zap.ByteString("packetData", data)) 561 return allegrosql.ErrMalformPacket 562 } 563 564 capability := uint32(binary.LittleEndian.Uint16(data[:2])) 565 if capability&allegrosql.ClientProtodefCaus41 > 0 { 566 pos, err = parseHandshakeResponseHeader(ctx, &resp, data) 567 } else { 568 pos, err = parseOldHandshakeResponseHeader(ctx, &resp, data) 569 isOldVersion = true 570 } 571 572 if err != nil { 573 terror.Log(err) 574 return err 575 } 576 577 if 
resp.Capability&allegrosql.ClientSSL > 0 { 578 tlsConfig := (*tls.Config)(atomic.LoadPointer(&cc.server.tlsConfig)) 579 if tlsConfig != nil { 580 // The packet is a SSLRequest, let's switch to TLS. 581 if err = cc.upgradeToTLS(tlsConfig); err != nil { 582 return err 583 } 584 // Read the following HandshakeResponse packet. 585 data, err = cc.readPacket() 586 if err != nil { 587 logutil.Logger(ctx).Warn("read handshake response failure after upgrade to TLS", zap.Error(err)) 588 return err 589 } 590 if isOldVersion { 591 pos, err = parseOldHandshakeResponseHeader(ctx, &resp, data) 592 } else { 593 pos, err = parseHandshakeResponseHeader(ctx, &resp, data) 594 } 595 if err != nil { 596 terror.Log(err) 597 return err 598 } 599 } 600 } else if config.GetGlobalConfig().Security.RequireSecureTransport { 601 err := errSecureTransportRequired.FastGenByArgs() 602 terror.Log(err) 603 return err 604 } 605 606 // Read the remaining part of the packet. 607 if isOldVersion { 608 err = parseOldHandshakeResponseBody(ctx, &resp, data, pos) 609 } else { 610 err = parseHandshakeResponseBody(ctx, &resp, data, pos) 611 } 612 if err != nil { 613 terror.Log(err) 614 return err 615 } 616 617 // switching from other methods should work, but not tested 618 if resp.AuthPlugin == "caching_sha2_password" { 619 resp.Auth, err = cc.authSwitchRequest(ctx) 620 if err != nil { 621 logutil.Logger(ctx).Warn("attempt to send auth switch request packet failed", zap.Error(err)) 622 return err 623 } 624 } 625 cc.capability = resp.Capability & cc.server.capability 626 cc.user = resp.User 627 cc.dbname = resp.DBName 628 cc.defCauslation = resp.DefCauslation 629 cc.attrs = resp.Attrs 630 631 err = cc.openStochastikAndDoAuth(resp.Auth) 632 if err != nil { 633 logutil.Logger(ctx).Warn("open new stochastik failure", zap.Error(err)) 634 } 635 return err 636 } 637 638 func (cc *clientConn) StochastikStatusToString() string { 639 status := cc.ctx.Status() 640 inTxn, autoCommit := 0, 0 641 if 
status&allegrosql.ServerStatusInTrans > 0 { 642 inTxn = 1 643 } 644 if status&allegrosql.ServerStatusAutocommit > 0 { 645 autoCommit = 1 646 } 647 return fmt.Sprintf("inTxn:%d, autocommit:%d", 648 inTxn, autoCommit, 649 ) 650 } 651 652 func (cc *clientConn) openStochastikAndDoAuth(authData []byte) error { 653 var tlsStatePtr *tls.ConnectionState 654 if cc.tlsConn != nil { 655 tlsState := cc.tlsConn.ConnectionState() 656 tlsStatePtr = &tlsState 657 } 658 var err error 659 cc.ctx, err = cc.server.driver.OpenCtx(uint64(cc.connectionID), cc.capability, cc.defCauslation, cc.dbname, tlsStatePtr) 660 if err != nil { 661 return err 662 } 663 664 if err = cc.server.checkConnectionCount(); err != nil { 665 return err 666 } 667 hasPassword := "YES" 668 if len(authData) == 0 { 669 hasPassword = "NO" 670 } 671 host, err := cc.PeerHost(hasPassword) 672 if err != nil { 673 return err 674 } 675 if !cc.ctx.Auth(&auth.UserIdentity{Username: cc.user, Hostname: host}, authData, cc.salt) { 676 return errAccessDenied.FastGenByArgs(cc.user, host, hasPassword) 677 } 678 if cc.dbname != "" { 679 err = cc.useDB(context.Background(), cc.dbname) 680 if err != nil { 681 return err 682 } 683 } 684 cc.ctx.SetStochastikManager(cc.server) 685 return nil 686 } 687 688 func (cc *clientConn) PeerHost(hasPassword string) (host string, err error) { 689 if len(cc.peerHost) > 0 { 690 return cc.peerHost, nil 691 } 692 host = variable.DefHostname 693 if cc.server.isUnixSocket() { 694 cc.peerHost = host 695 return 696 } 697 addr := cc.bufReadConn.RemoteAddr().String() 698 var port string 699 host, port, err = net.SplitHostPort(addr) 700 if err != nil { 701 err = errAccessDenied.GenWithStackByArgs(cc.user, addr, hasPassword) 702 return 703 } 704 cc.peerHost = host 705 cc.peerPort = port 706 return 707 } 708 709 // Run reads client query and writes query result to client in for loop, if there is a panic during query handling, 710 // it will be recovered and log the panic error. 
711 // This function returns and the connection is closed if there is an IO error or there is a panic. 712 func (cc *clientConn) Run(ctx context.Context) { 713 const size = 4096 714 defer func() { 715 r := recover() 716 if r != nil { 717 buf := make([]byte, size) 718 stackSize := runtime.Stack(buf, false) 719 buf = buf[:stackSize] 720 logutil.Logger(ctx).Error("connection running loop panic", 721 zap.Stringer("lastALLEGROSQL", getLastStmtInConn{cc}), 722 zap.String("err", fmt.Sprintf("%v", r)), 723 zap.String("stack", string(buf)), 724 ) 725 err := cc.writeError(ctx, errors.New(fmt.Sprintf("%v", r))) 726 terror.Log(err) 727 metrics.PanicCounter.WithLabelValues(metrics.LabelStochastik).Inc() 728 } 729 if atomic.LoadInt32(&cc.status) != connStatusShutdown { 730 err := cc.Close() 731 terror.Log(err) 732 } 733 }() 734 // Usually, client connection status changes between [dispatching] <=> [reading]. 735 // When some event happens, server may notify this client connection by setting 736 // the status to special values, for example: kill or graceful shutdown. 737 // The client connection would detect the events when it fails to change status 738 // by CAS operation, it would then take some actions accordingly. 
739 for { 740 if !atomic.CompareAndSwapInt32(&cc.status, connStatusDispatching, connStatusReading) { 741 return 742 } 743 744 cc.alloc.Reset() 745 // close connection when idle time is more than wait_timeout 746 waitTimeout := cc.getStochastikVarsWaitTimeout(ctx) 747 cc.pkt.setReadTimeout(time.Duration(waitTimeout) * time.Second) 748 start := time.Now() 749 data, err := cc.readPacket() 750 if err != nil { 751 if terror.ErrorNotEqual(err, io.EOF) { 752 if netErr, isNetErr := errors.Cause(err).(net.Error); isNetErr && netErr.Timeout() { 753 idleTime := time.Since(start) 754 logutil.Logger(ctx).Info("read packet timeout, close this connection", 755 zap.Duration("idle", idleTime), 756 zap.Uint64("waitTimeout", waitTimeout), 757 zap.Error(err), 758 ) 759 } else { 760 errStack := errors.ErrorStack(err) 761 if !strings.Contains(errStack, "use of closed network connection") { 762 logutil.Logger(ctx).Warn("read packet failed, close this connection", 763 zap.Error(errors.SuspendStack(err))) 764 } 765 } 766 } 767 disconnectByClientWithError.Inc() 768 return 769 } 770 771 if !atomic.CompareAndSwapInt32(&cc.status, connStatusReading, connStatusDispatching) { 772 return 773 } 774 775 startTime := time.Now() 776 if err = cc.dispatch(ctx, data); err != nil { 777 if terror.ErrorEqual(err, io.EOF) { 778 cc.addMetrics(data[0], startTime, nil) 779 disconnectNormal.Inc() 780 return 781 } else if terror.ErrResultUndetermined.Equal(err) { 782 logutil.Logger(ctx).Error("result undetermined, close this connection", zap.Error(err)) 783 disconnectErrorUndetermined.Inc() 784 return 785 } else if terror.ErrCritical.Equal(err) { 786 metrics.CriticalErrorCounter.Add(1) 787 logutil.Logger(ctx).Fatal("critical error, stop the server", zap.Error(err)) 788 } 789 var txnMode string 790 if cc.ctx != nil { 791 txnMode = cc.ctx.GetStochastikVars().GetReadableTxnMode() 792 } 793 logutil.Logger(ctx).Info("command dispatched failed", 794 zap.String("connInfo", cc.String()), 795 zap.String("command", 
allegrosql.Command2Str[data[0]]), 796 zap.String("status", cc.StochastikStatusToString()), 797 zap.Stringer("allegrosql", getLastStmtInConn{cc}), 798 zap.String("txn_mode", txnMode), 799 zap.String("err", errStrForLog(err)), 800 ) 801 err1 := cc.writeError(ctx, err) 802 terror.Log(err1) 803 } 804 cc.addMetrics(data[0], startTime, err) 805 cc.pkt.sequence = 0 806 } 807 } 808 809 // ShutdownOrNotify will Shutdown this client connection, or do its best to notify. 810 func (cc *clientConn) ShutdownOrNotify() bool { 811 if (cc.ctx.Status() & allegrosql.ServerStatusInTrans) > 0 { 812 return false 813 } 814 // If the client connection status is reading, it's safe to shutdown it. 815 if atomic.CompareAndSwapInt32(&cc.status, connStatusReading, connStatusShutdown) { 816 return true 817 } 818 // If the client connection status is dispatching, we can't shutdown it immediately, 819 // so set the status to WaitShutdown as a notification, the client will detect it 820 // and then exit. 821 atomic.StoreInt32(&cc.status, connStatusWaitShutdown) 822 return false 823 } 824 825 func queryStrForLog(query string) string { 826 const size = 4096 827 if len(query) > size { 828 return query[:size] + fmt.Sprintf("(len: %d)", len(query)) 829 } 830 return query 831 } 832 833 func errStrForLog(err error) string { 834 if ekv.ErrKeyExists.Equal(err) || BerolinaSQL.ErrParse.Equal(err) || schemareplicant.ErrBlockNotExists.Equal(err) { 835 // Do not log stack for duplicated entry error. 836 return err.Error() 837 } 838 return errors.ErrorStack(err) 839 } 840 841 func (cc *clientConn) addMetrics(cmd byte, startTime time.Time, err error) { 842 if cmd == allegrosql.ComQuery && cc.ctx.Value(stochastikctx.LastInterDircuteDBS) != nil { 843 // Don't take DBS execute time into account. 844 // It's already recorded by other metrics in dbs package. 
845 return 846 } 847 848 var counter prometheus.Counter 849 if err != nil && int(cmd) < len(queryTotalCountErr) { 850 counter = queryTotalCountErr[cmd] 851 } else if err == nil && int(cmd) < len(queryTotalCountOk) { 852 counter = queryTotalCountOk[cmd] 853 } 854 if counter != nil { 855 counter.Inc() 856 } else { 857 label := strconv.Itoa(int(cmd)) 858 if err != nil { 859 metrics.QueryTotalCounter.WithLabelValues(label, "ERROR").Inc() 860 } else { 861 metrics.QueryTotalCounter.WithLabelValues(label, "OK").Inc() 862 } 863 } 864 865 stmtType := cc.ctx.GetStochastikVars().StmtCtx.StmtType 866 sqlType := metrics.LblGeneral 867 if stmtType != "" { 868 sqlType = stmtType 869 } 870 871 switch sqlType { 872 case "Use": 873 queryDurationHistogramUse.Observe(time.Since(startTime).Seconds()) 874 case "Show": 875 queryDurationHistogramShow.Observe(time.Since(startTime).Seconds()) 876 case "Begin": 877 queryDurationHistogramBegin.Observe(time.Since(startTime).Seconds()) 878 case "Commit": 879 queryDurationHistogramCommit.Observe(time.Since(startTime).Seconds()) 880 case "Rollback": 881 queryDurationHistogramRollback.Observe(time.Since(startTime).Seconds()) 882 case "Insert": 883 queryDurationHistogramInsert.Observe(time.Since(startTime).Seconds()) 884 case "Replace": 885 queryDurationHistogramReplace.Observe(time.Since(startTime).Seconds()) 886 case "Delete": 887 queryDurationHistogramDelete.Observe(time.Since(startTime).Seconds()) 888 case "UFIDelate": 889 queryDurationHistogramUFIDelate.Observe(time.Since(startTime).Seconds()) 890 case "Select": 891 queryDurationHistogramSelect.Observe(time.Since(startTime).Seconds()) 892 case "InterDircute": 893 queryDurationHistogramInterDircute.Observe(time.Since(startTime).Seconds()) 894 case "Set": 895 queryDurationHistogramSet.Observe(time.Since(startTime).Seconds()) 896 case metrics.LblGeneral: 897 queryDurationHistogramGeneral.Observe(time.Since(startTime).Seconds()) 898 default: 899 
metrics.QueryDurationHistogram.WithLabelValues(sqlType).Observe(time.Since(startTime).Seconds()) 900 } 901 } 902 903 // dispatch handles client request based on command which is the first byte of the data. 904 // It also gets a token from server which is used to limit the concurrently handling clients. 905 // The most frequently used command is ComQuery. 906 func (cc *clientConn) dispatch(ctx context.Context, data []byte) error { 907 defer func() { 908 // reset killed for each request 909 atomic.StoreUint32(&cc.ctx.GetStochastikVars().Killed, 0) 910 }() 911 span := opentracing.StartSpan("server.dispatch") 912 913 t := time.Now() 914 cc.lastPacket = data 915 cmd := data[0] 916 data = data[1:] 917 if variable.EnablePProfALLEGROSQLCPU.Load() { 918 label := getLastStmtInConn{cc}.PProfLabel() 919 if len(label) > 0 { 920 defer pprof.SetGoroutineLabels(ctx) 921 ctx = pprof.WithLabels(ctx, pprof.Labels("allegrosql", label)) 922 pprof.SetGoroutineLabels(ctx) 923 } 924 } 925 if trace.IsEnabled() { 926 lc := getLastStmtInConn{cc} 927 sqlType := lc.PProfLabel() 928 if len(sqlType) > 0 { 929 var task *trace.Task 930 ctx, task = trace.NewTask(ctx, sqlType) 931 trace.Log(ctx, "allegrosql", lc.String()) 932 defer task.End() 933 } 934 } 935 token := cc.server.getToken() 936 defer func() { 937 // if handleChangeUser failed, cc.ctx may be nil 938 if cc.ctx != nil { 939 cc.ctx.SetProcessInfo("", t, allegrosql.ComSleep, 0) 940 } 941 942 cc.server.releaseToken(token) 943 span.Finish() 944 }() 945 946 vars := cc.ctx.GetStochastikVars() 947 // reset killed for each request 948 atomic.StoreUint32(&vars.Killed, 0) 949 if cmd < allegrosql.ComEnd { 950 cc.ctx.SetCommandValue(cmd) 951 } 952 953 dataStr := string(replog.String(data)) 954 switch cmd { 955 case allegrosql.ComPing, allegrosql.ComStmtClose, allegrosql.ComStmtSendLongData, allegrosql.ComStmtReset, 956 allegrosql.ComSetOption, allegrosql.ComChangeUser: 957 cc.ctx.SetProcessInfo("", t, cmd, 0) 958 case allegrosql.ComInitDB: 959 
cc.ctx.SetProcessInfo("use "+dataStr, t, cmd, 0) 960 } 961 962 switch cmd { 963 case allegrosql.ComSleep: 964 // TODO: According to allegrosql document, this command is supposed to be used only internally. 965 // So it's just a temp fix, not sure if it's done right. 966 // Investigate this command and write test case later. 967 return nil 968 case allegrosql.ComQuit: 969 return io.EOF 970 case allegrosql.ComInitDB: 971 if err := cc.useDB(ctx, dataStr); err != nil { 972 return err 973 } 974 return cc.writeOK(ctx) 975 case allegrosql.ComQuery: // Most frequently used command. 976 // For issue 1989 977 // Input payload may end with byte '\0', we didn't find related allegrosql document about it, but allegrosql 978 // implementation accept that case. So trim the last '\0' here as if the payload an EOF string. 979 // See http://dev.allegrosql.com/doc/internals/en/com-query.html 980 if len(data) > 0 && data[len(data)-1] == 0 { 981 data = data[:len(data)-1] 982 dataStr = string(replog.String(data)) 983 } 984 return cc.handleQuery(ctx, dataStr) 985 case allegrosql.ComFieldList: 986 return cc.handleFieldList(ctx, dataStr) 987 // ComCreateDB, ComDroFIDelB 988 case allegrosql.ComRefresh: 989 return cc.handleRefresh(ctx, data[0]) 990 case allegrosql.ComShutdown: // redirect to ALLEGROALLEGROSQL 991 if err := cc.handleQuery(ctx, "SHUTDOWN"); err != nil { 992 return err 993 } 994 return cc.writeOK(ctx) 995 // ComStatistics, ComProcessInfo, ComConnect, ComProcessKill, ComDebug 996 case allegrosql.ComPing: 997 return cc.writeOK(ctx) 998 // ComTime, ComDelayedInsert 999 case allegrosql.ComChangeUser: 1000 return cc.handleChangeUser(ctx, data) 1001 // ComBinlogDump, ComBlockDump, ComConnectOut, ComRegisterSlave 1002 case allegrosql.ComStmtPrepare: 1003 return cc.handleStmtPrepare(ctx, dataStr) 1004 case allegrosql.ComStmtInterDircute: 1005 return cc.handleStmtInterDircute(ctx, data) 1006 case allegrosql.ComStmtSendLongData: 1007 return cc.handleStmtSendLongData(data) 1008 case 
allegrosql.ComStmtClose: 1009 return cc.handleStmtClose(data) 1010 case allegrosql.ComStmtReset: 1011 return cc.handleStmtReset(ctx, data) 1012 case allegrosql.ComSetOption: 1013 return cc.handleSetOption(ctx, data) 1014 case allegrosql.ComStmtFetch: 1015 return cc.handleStmtFetch(ctx, data) 1016 // ComDaemon, ComBinlogDumpGtid 1017 case allegrosql.ComResetConnection: 1018 return cc.handleResetConnection(ctx) 1019 // ComEnd 1020 default: 1021 return allegrosql.NewErrf(allegrosql.ErrUnknown, "command %d not supported now", cmd) 1022 } 1023 } 1024 1025 func (cc *clientConn) useDB(ctx context.Context, EDB string) (err error) { 1026 // if input is "use `SELECT`", allegrosql client just send "SELECT" 1027 // so we add `` around EDB. 1028 stmts, err := cc.ctx.Parse(ctx, "use `"+EDB+"`") 1029 if err != nil { 1030 return err 1031 } 1032 _, err = cc.ctx.InterDircuteStmt(ctx, stmts[0]) 1033 if err != nil { 1034 return err 1035 } 1036 cc.dbname = EDB 1037 return 1038 } 1039 1040 func (cc *clientConn) flush(ctx context.Context) error { 1041 defer trace.StartRegion(ctx, "FlushClientConn").End() 1042 failpoint.Inject("FakeClientConn", func() { 1043 if cc.pkt == nil { 1044 failpoint.Return(nil) 1045 } 1046 }) 1047 return cc.pkt.flush() 1048 } 1049 1050 func (cc *clientConn) writeOK(ctx context.Context) error { 1051 msg := cc.ctx.LastMessage() 1052 return cc.writeOkWith(ctx, msg, cc.ctx.AffectedRows(), cc.ctx.LastInsertID(), cc.ctx.Status(), cc.ctx.WarningCount()) 1053 } 1054 1055 func (cc *clientConn) writeOkWith(ctx context.Context, msg string, affectedRows, lastInsertID uint64, status, warnCnt uint16) error { 1056 enclen := 0 1057 if len(msg) > 0 { 1058 enclen = lengthEncodedIntSize(uint64(len(msg))) + len(msg) 1059 } 1060 1061 data := cc.alloc.AllocWithLen(4, 32+enclen) 1062 data = append(data, allegrosql.OKHeader) 1063 data = dumpLengthEncodedInt(data, affectedRows) 1064 data = dumpLengthEncodedInt(data, lastInsertID) 1065 if cc.capability&allegrosql.ClientProtodefCaus41 > 0 
{ 1066 data = dumpUint16(data, status) 1067 data = dumpUint16(data, warnCnt) 1068 } 1069 if enclen > 0 { 1070 // although MyALLEGROSQL manual says the info message is string<EOF>(https://dev.allegrosql.com/doc/internals/en/packet-OK_Packet.html), 1071 // it is actually string<lenenc> 1072 data = dumpLengthEncodedString(data, []byte(msg)) 1073 } 1074 1075 err := cc.writePacket(data) 1076 if err != nil { 1077 return err 1078 } 1079 1080 return cc.flush(ctx) 1081 } 1082 1083 func (cc *clientConn) writeError(ctx context.Context, e error) error { 1084 var ( 1085 m *allegrosql.ALLEGROSQLError 1086 te *terror.Error 1087 ok bool 1088 ) 1089 originErr := errors.Cause(e) 1090 if te, ok = originErr.(*terror.Error); ok { 1091 m = terror.ToALLEGROSQLError(te) 1092 } else { 1093 e := errors.Cause(originErr) 1094 switch y := e.(type) { 1095 case *terror.Error: 1096 m = terror.ToALLEGROSQLError(y) 1097 default: 1098 m = allegrosql.NewErrf(allegrosql.ErrUnknown, "%s", e.Error()) 1099 } 1100 } 1101 1102 cc.lastCode = m.Code 1103 data := cc.alloc.AllocWithLen(4, 16+len(m.Message)) 1104 data = append(data, allegrosql.ErrHeader) 1105 data = append(data, byte(m.Code), byte(m.Code>>8)) 1106 if cc.capability&allegrosql.ClientProtodefCaus41 > 0 { 1107 data = append(data, '#') 1108 data = append(data, m.State...) 1109 } 1110 1111 data = append(data, m.Message...) 1112 1113 err := cc.writePacket(data) 1114 if err != nil { 1115 return err 1116 } 1117 return cc.flush(ctx) 1118 } 1119 1120 // writeEOF writes an EOF packet. 1121 // Note this function won't flush the stream because maybe there are more 1122 // packets following it. 1123 // serverStatus, a flag bit represents server information 1124 // in the packet. 
func (cc *clientConn) writeEOF(serverStatus uint16) error {
	// 4 bytes are reserved for the packet header; the payload is at most
	// header byte + warning count + status.
	data := cc.alloc.AllocWithLen(4, 9)

	data = append(data, allegrosql.EOFHeader)
	if cc.capability&allegrosql.ClientProtodefCaus41 > 0 {
		data = dumpUint16(data, cc.ctx.WarningCount())
		// Merge the caller's extra flags into the stochastik status.
		status := cc.ctx.Status()
		status |= serverStatus
		data = dumpUint16(data, status)
	}

	err := cc.writePacket(data)
	return err
}

// writeReq sends a LOCAL INFILE request asking the client to upload the file
// at filePath, then flushes the connection.
func (cc *clientConn) writeReq(ctx context.Context, filePath string) error {
	data := cc.alloc.AllocWithLen(4, 5+len(filePath))
	data = append(data, allegrosql.LocalInFileHeader)
	data = append(data, filePath...)

	err := cc.writePacket(data)
	if err != nil {
		return err
	}

	return cc.flush(ctx)
}

// insertDataWithCommit feeds curData (after any leftover prevData) into
// loadDataInfo. Each time InsertData reports that the batch limit was reached,
// a commit task is enqueued and the leftover bytes are fed again.
// It returns the bytes that did not form a complete batch yet.
func insertDataWithCommit(ctx context.Context, prevData,
	curData []byte, loadDataInfo *interlock.LoadDataInfo) ([]byte, error) {
	var err error
	var reachLimit bool
	for {
		prevData, reachLimit, err = loadDataInfo.InsertData(ctx, prevData, curData)
		if err != nil {
			return nil, err
		}
		if !reachLimit {
			break
		}
		// push into commit task queue
		err = loadDataInfo.EnqOneTask(ctx)
		if err != nil {
			return prevData, err
		}
		// Re-process the leftover bytes as the new input.
		curData = prevData
		prevData = nil
	}
	return prevData, nil
}

// processStream process input stream from network
// It reads LOAD DATA packets from the client until an empty packet arrives,
// turns them into commit tasks, and finally either closes the task queue (on
// success) or force-quits the load (on error or panic). wg.Done is always
// called on exit.
func processStream(ctx context.Context, cc *clientConn, loadDataInfo *interlock.LoadDataInfo, wg *sync.WaitGroup) {
	var err error
	var shouldBreak bool
	var prevData, curData []byte
	defer func() {
		r := recover()
		if r != nil {
			logutil.Logger(ctx).Error("process routine panicked",
				zap.Reflect("r", r),
				zap.Stack("stack"))
		}
		// err is read here after the final EnqOneTask below, so a failure of
		// that last enqueue also leads to ForceQuit.
		if err != nil || r != nil {
			loadDataInfo.ForceQuit()
		} else {
			loadDataInfo.CloseTaskQueue()
		}
		wg.Done()
	}()
	for {
		curData, err = cc.readPacket()
		if err != nil {
			// NOTE(review): an io.EOF is not cleared here, so it ends the loop
			// through the err check below and is handled as a failure
			// (ForceQuit); confirm this is intended.
			if terror.ErrorNotEqual(err, io.EOF) {
				logutil.Logger(ctx).Error("read packet failed", zap.Error(err))
				break
			}
		}
		// An empty packet marks the end of the client's file.
		if len(curData) == 0 {
			loadDataInfo.Drained = true
			shouldBreak = true
			if len(prevData) == 0 {
				break
			}
		}
		select {
		case <-loadDataInfo.QuitCh:
			err = errors.New("processStream forced to quit")
		default:
		}
		if err != nil {
			break
		}
		// prepare batch and enqueue task
		prevData, err = insertDataWithCommit(ctx, prevData, curData, loadDataInfo)
		if err != nil {
			break
		}
		if shouldBreak {
			break
		}
	}
	if err != nil {
		logutil.Logger(ctx).Error("load data process stream error", zap.Error(err))
	} else {
		// Enqueue whatever remains as the final commit task.
		err = loadDataInfo.EnqOneTask(ctx)
		if err != nil {
			logutil.Logger(ctx).Error("load data process stream error", zap.Error(err))
		}
	}
}

// handleLoadData does the additional work after processing the 'load data' query.
// It sends client a file path, then reads the file content from client, inserts data into database.
// Reading runs in a processStream goroutine while this goroutine runs
// CommitWork. On a commit error, remaining client packets are drained so the
// connection stays usable, unless the query is killed.
func (cc *clientConn) handleLoadData(ctx context.Context, loadDataInfo *interlock.LoadDataInfo) error {
	// If the server handles the load data request, the client has to set the ClientLocalFiles capability.
	if cc.capability&allegrosql.ClientLocalFiles == 0 {
		return errNotAllowedCommand
	}
	if loadDataInfo == nil {
		return errors.New("load data info is empty")
	}

	err := cc.writeReq(ctx, loadDataInfo.Path)
	if err != nil {
		return err
	}

	loadDataInfo.InitQueues()
	loadDataInfo.SetMaxRowsInBatch(uint64(loadDataInfo.Ctx.GetStochastikVars().DMLBatchSize))
	loadDataInfo.StartStopWatcher()
	// let stop watcher goroutine quit
	defer loadDataInfo.ForceQuit()
	err = loadDataInfo.Ctx.NewTxn(ctx)
	if err != nil {
		return err
	}
	// processStream process input data, enqueue commit task
	wg := new(sync.WaitGroup)
	wg.Add(1)
	go processStream(ctx, cc, loadDataInfo, wg)
	err = loadDataInfo.CommitWork(ctx)
	// Wait for the reader so loadDataInfo.Drained is final before it is read below.
	wg.Wait()
	if err != nil {
		if !loadDataInfo.Drained {
			logutil.Logger(ctx).Info("not drained yet, try reading left data from client connection")
		}
		// drain the data from client conn until an empty packet is received, otherwise the connection will be reset
		for !loadDataInfo.Drained {
			// check kill flag again, let the draining loop could quit if empty packet could not be received
			if atomic.CompareAndSwapUint32(&loadDataInfo.Ctx.GetStochastikVars().Killed, 1, 0) {
				logutil.Logger(ctx).Warn("receiving kill, stop draining data, connection may be reset")
				return interlock.ErrQueryInterrupted
			}
			curData, err1 := cc.readPacket()
			if err1 != nil {
				logutil.Logger(ctx).Error("drain reading left data encounter errors", zap.Error(err1))
				break
			}
			if len(curData) == 0 {
				loadDataInfo.Drained = true
				logutil.Logger(ctx).Info("draining finished for error", zap.Error(err))
				break
			}
		}
	}
	loadDataInfo.SetMessage()
	return err
}

// getDataFromPath asks the client (through a LOCAL INFILE request) for the
// contents of path and concatenates the packets until an empty one arrives.
1296 func (cc *clientConn) getDataFromPath(ctx context.Context, path string) ([]byte, error) { 1297 err := cc.writeReq(ctx, path) 1298 if err != nil { 1299 return nil, err 1300 } 1301 var prevData, curData []byte 1302 for { 1303 curData, err = cc.readPacket() 1304 if err != nil && terror.ErrorNotEqual(err, io.EOF) { 1305 return nil, err 1306 } 1307 if len(curData) == 0 { 1308 break 1309 } 1310 prevData = append(prevData, curData...) 1311 } 1312 return prevData, nil 1313 } 1314 1315 // handleLoadStats does the additional work after processing the 'load stats' query. 1316 // It sends client a file path, then reads the file content from client, loads it into the storage. 1317 func (cc *clientConn) handleLoadStats(ctx context.Context, loadStatsInfo *interlock.LoadStatsInfo) error { 1318 // If the server handles the load data request, the client has to set the ClientLocalFiles capability. 1319 if cc.capability&allegrosql.ClientLocalFiles == 0 { 1320 return errNotAllowedCommand 1321 } 1322 if loadStatsInfo == nil { 1323 return errors.New("load stats: info is empty") 1324 } 1325 data, err := cc.getDataFromPath(ctx, loadStatsInfo.Path) 1326 if err != nil { 1327 return err 1328 } 1329 if len(data) == 0 { 1330 return nil 1331 } 1332 return loadStatsInfo.UFIDelate(data) 1333 } 1334 1335 // handleIndexAdvise does the index advise work and returns the advise result for index. 
1336 func (cc *clientConn) handleIndexAdvise(ctx context.Context, indexAdviseInfo *interlock.IndexAdviseInfo) error { 1337 if cc.capability&allegrosql.ClientLocalFiles == 0 { 1338 return errNotAllowedCommand 1339 } 1340 if indexAdviseInfo == nil { 1341 return errors.New("Index Advise: info is empty") 1342 } 1343 1344 data, err := cc.getDataFromPath(ctx, indexAdviseInfo.Path) 1345 if err != nil { 1346 return err 1347 } 1348 if len(data) == 0 { 1349 return errors.New("Index Advise: infile is empty") 1350 } 1351 1352 if err := indexAdviseInfo.GetIndexAdvice(ctx, data); err != nil { 1353 return err 1354 } 1355 1356 // TODO: Write the rss []ResultSet. It will be done in another PR. 1357 return nil 1358 } 1359 1360 // handleQuery executes the allegrosql query string and writes result set or result ok to the client. 1361 // As the execution time of this function represents the performance of MilevaDB, we do time log and metrics here. 1362 // There is a special query `load data` that does not return result, which is handled differently. 1363 // Query `load stats` does not return result either. 
1364 func (cc *clientConn) handleQuery(ctx context.Context, allegrosql string) (err error) { 1365 defer trace.StartRegion(ctx, "handleQuery").End() 1366 sc := cc.ctx.GetStochastikVars().StmtCtx 1367 prevWarns := sc.GetWarnings() 1368 stmts, err := cc.ctx.Parse(ctx, allegrosql) 1369 if err != nil { 1370 metrics.InterDircuteErrorCounter.WithLabelValues(metrics.InterDircuteErrorToLabel(err)).Inc() 1371 return err 1372 } 1373 1374 if len(stmts) == 0 { 1375 return cc.writeOK(ctx) 1376 } 1377 1378 warns := sc.GetWarnings() 1379 BerolinaSQLWarns := warns[len(prevWarns):] 1380 1381 var pointCausets []causetembedded.Causet 1382 if len(stmts) > 1 { 1383 // Only pre-build point plans for multi-memex query 1384 pointCausets, err = cc.prefetchPointCausetKeys(ctx, stmts) 1385 if err != nil { 1386 return err 1387 } 1388 } 1389 if len(pointCausets) > 0 { 1390 defer cc.ctx.ClearValue(causetembedded.PointCausetKey) 1391 } 1392 for i, stmt := range stmts { 1393 if len(pointCausets) > 0 { 1394 // Save the point plan in Stochastik so we don't need to build the point plan again. 1395 cc.ctx.SetValue(causetembedded.PointCausetKey, causetembedded.PointCausetVal{Causet: pointCausets[i]}) 1396 } 1397 err = cc.handleStmt(ctx, stmt, BerolinaSQLWarns, i == len(stmts)-1) 1398 if err != nil { 1399 break 1400 } 1401 } 1402 if err != nil { 1403 metrics.InterDircuteErrorCounter.WithLabelValues(metrics.InterDircuteErrorToLabel(err)).Inc() 1404 } 1405 return err 1406 } 1407 1408 // prefetchPointCausetKeys extracts the point keys in multi-memex query, 1409 // use BatchGet to get the keys, so the values will be cached in the snapshot cache, save RPC call cost. 1410 // For pessimistic transaction, the keys will be batch locked. 
1411 func (cc *clientConn) prefetchPointCausetKeys(ctx context.Context, stmts []ast.StmtNode) ([]causetembedded.Causet, error) { 1412 txn, err := cc.ctx.Txn(false) 1413 if err != nil { 1414 return nil, err 1415 } 1416 if !txn.Valid() { 1417 // Only prefetch in-transaction query for simplicity. 1418 // Later we can support out-transaction multi-memex query. 1419 return nil, nil 1420 } 1421 vars := cc.ctx.GetStochastikVars() 1422 if vars.TxnCtx.IsPessimistic { 1423 if vars.IsReadConsistencyTxn() { 1424 // TODO: to support READ-COMMITTED, we need to avoid getting new TS for each memex in the query. 1425 return nil, nil 1426 } 1427 if vars.TxnCtx.GetForUFIDelateTS() != vars.TxnCtx.StartTS { 1428 // Do not handle the case that ForUFIDelateTS is changed for simplicity. 1429 return nil, nil 1430 } 1431 } 1432 pointCausets := make([]causetembedded.Causet, len(stmts)) 1433 var idxKeys []ekv.Key 1434 var rowKeys []ekv.Key 1435 is := petri.GetPetri(cc.ctx).SchemaReplicant() 1436 sc := vars.StmtCtx 1437 for i, stmt := range stmts { 1438 // TODO: the preprocess is run twice, we should find some way to avoid do it again. 1439 if err = causetembedded.Preprocess(cc.ctx, stmt, is); err != nil { 1440 return nil, err 1441 } 1442 p := causetembedded.TryFastCauset(cc.ctx.Stochastik, stmt) 1443 pointCausets[i] = p 1444 if p == nil { 1445 continue 1446 } 1447 // Only support UFIDelate for now. 1448 // TODO: support other point plans. 
1449 switch x := p.(type) { 1450 case *causetembedded.UFIDelate: 1451 uFIDelateStmt := stmt.(*ast.UFIDelateStmt) 1452 if pp, ok := x.SelectCauset.(*causetembedded.PointGetCauset); ok { 1453 if pp.PartitionInfo != nil { 1454 continue 1455 } 1456 if pp.IndexInfo != nil { 1457 interlock.ResetUFIDelateStmtCtx(sc, uFIDelateStmt, vars) 1458 idxKey, err1 := interlock.EncodeUniqueIndexKey(cc.ctx, pp.TblInfo, pp.IndexInfo, pp.IndexValues, pp.TblInfo.ID) 1459 if err1 != nil { 1460 return nil, err1 1461 } 1462 idxKeys = append(idxKeys, idxKey) 1463 } else { 1464 rowKeys = append(rowKeys, blockcodec.EncodeRowKeyWithHandle(pp.TblInfo.ID, pp.Handle)) 1465 } 1466 } 1467 } 1468 } 1469 if len(idxKeys) == 0 && len(rowKeys) == 0 { 1470 return pointCausets, nil 1471 } 1472 snapshot := txn.GetSnapshot() 1473 idxVals, err1 := snapshot.BatchGet(ctx, idxKeys) 1474 if err1 != nil { 1475 return nil, err1 1476 } 1477 for idxKey, idxVal := range idxVals { 1478 h, err2 := blockcodec.DecodeHandleInUniqueIndexValue(idxVal, false) 1479 if err2 != nil { 1480 return nil, err2 1481 } 1482 tblID := blockcodec.DecodeBlockID(replog.Slice(idxKey)) 1483 rowKeys = append(rowKeys, blockcodec.EncodeRowKeyWithHandle(tblID, h)) 1484 } 1485 if vars.TxnCtx.IsPessimistic { 1486 allKeys := append(rowKeys, idxKeys...) 1487 err = interlock.LockKeys(ctx, cc.ctx, vars.LockWaitTimeout, allKeys...) 1488 if err != nil { 1489 // suppress the dagger error, we are not going to handle it here for simplicity. 
1490 err = nil 1491 logutil.BgLogger().Warn("dagger keys error on prefetch", zap.Error(err)) 1492 } 1493 } else { 1494 _, err = snapshot.BatchGet(ctx, rowKeys) 1495 if err != nil { 1496 return nil, err 1497 } 1498 } 1499 return pointCausets, nil 1500 } 1501 1502 func (cc *clientConn) handleStmt(ctx context.Context, stmt ast.StmtNode, warns []stmtctx.ALLEGROSQLWarn, lastStmt bool) error { 1503 ctx = context.WithValue(ctx, execdetails.StmtInterDircDetailKey, &execdetails.StmtInterDircDetails{}) 1504 reg := trace.StartRegion(ctx, "InterDircuteStmt") 1505 rs, err := cc.ctx.InterDircuteStmt(ctx, stmt) 1506 reg.End() 1507 // The stochastik tracker detachment from global tracker is solved in the `rs.Close` in most cases. 1508 // If the rs is nil, the detachment will be done in the `handleNoDelay`. 1509 if rs != nil { 1510 defer terror.Call(rs.Close) 1511 } 1512 if err != nil { 1513 return err 1514 } 1515 1516 if lastStmt { 1517 cc.ctx.GetStochastikVars().StmtCtx.AppendWarnings(warns) 1518 } 1519 1520 status := cc.ctx.Status() 1521 if !lastStmt { 1522 status |= allegrosql.ServerMoreResultsExists 1523 } 1524 1525 if rs != nil { 1526 connStatus := atomic.LoadInt32(&cc.status) 1527 if connStatus == connStatusShutdown || connStatus == connStatusWaitShutdown { 1528 return interlock.ErrQueryInterrupted 1529 } 1530 1531 err = cc.writeResultset(ctx, rs, false, status, 0) 1532 if err != nil { 1533 return err 1534 } 1535 } else { 1536 err = cc.handleQuerySpecial(ctx, status) 1537 if err != nil { 1538 return err 1539 } 1540 } 1541 return nil 1542 } 1543 1544 func (cc *clientConn) handleQuerySpecial(ctx context.Context, status uint16) error { 1545 loadDataInfo := cc.ctx.Value(interlock.LoadDataVarKey) 1546 if loadDataInfo != nil { 1547 defer cc.ctx.SetValue(interlock.LoadDataVarKey, nil) 1548 if err := cc.handleLoadData(ctx, loadDataInfo.(*interlock.LoadDataInfo)); err != nil { 1549 return err 1550 } 1551 } 1552 1553 loadStats := cc.ctx.Value(interlock.LoadStatsVarKey) 1554 if 
loadStats != nil { 1555 defer cc.ctx.SetValue(interlock.LoadStatsVarKey, nil) 1556 if err := cc.handleLoadStats(ctx, loadStats.(*interlock.LoadStatsInfo)); err != nil { 1557 return err 1558 } 1559 } 1560 1561 indexAdvise := cc.ctx.Value(interlock.IndexAdviseVarKey) 1562 if indexAdvise != nil { 1563 defer cc.ctx.SetValue(interlock.IndexAdviseVarKey, nil) 1564 if err := cc.handleIndexAdvise(ctx, indexAdvise.(*interlock.IndexAdviseInfo)); err != nil { 1565 return err 1566 } 1567 } 1568 return cc.writeOkWith(ctx, cc.ctx.LastMessage(), cc.ctx.AffectedRows(), cc.ctx.LastInsertID(), status, cc.ctx.WarningCount()) 1569 } 1570 1571 // handleFieldList returns the field list for a causet. 1572 // The allegrosql string is composed of a causet name and a terminating character \x00. 1573 func (cc *clientConn) handleFieldList(ctx context.Context, allegrosql string) (err error) { 1574 parts := strings.Split(allegrosql, "\x00") 1575 defCausumns, err := cc.ctx.FieldList(parts[0]) 1576 if err != nil { 1577 return err 1578 } 1579 data := cc.alloc.AllocWithLen(4, 1024) 1580 for _, defCausumn := range defCausumns { 1581 // Current we doesn't output defaultValue but reserve defaultValue length byte to make mariadb client happy. 1582 // https://dev.allegrosql.com/doc/internals/en/com-query-response.html#defCausumn-definition 1583 // TODO: fill the right DefaultValues. 1584 defCausumn.DefaultValueLength = 0 1585 defCausumn.DefaultValue = []byte{} 1586 1587 data = data[0:4] 1588 data = defCausumn.Dump(data) 1589 if err := cc.writePacket(data); err != nil { 1590 return err 1591 } 1592 } 1593 if err := cc.writeEOF(0); err != nil { 1594 return err 1595 } 1596 return cc.flush(ctx) 1597 } 1598 1599 // writeResultset writes data into a resultset and uses rs.Next to get event data back. 1600 // If binary is true, the data would be encoded in BINARY format. 1601 // serverStatus, a flag bit represents server information. 
1602 // fetchSize, the desired number of rows to be fetched each time when client uses cursor. 1603 func (cc *clientConn) writeResultset(ctx context.Context, rs ResultSet, binary bool, serverStatus uint16, fetchSize int) (runErr error) { 1604 defer func() { 1605 // close ResultSet when cursor doesn't exist 1606 r := recover() 1607 if r == nil { 1608 return 1609 } 1610 if str, ok := r.(string); !ok || !strings.HasPrefix(str, memory.PanicMemoryExceed) { 1611 panic(r) 1612 } 1613 // TODO(jianzhang.zj: add metrics here) 1614 runErr = errors.Errorf("%v", r) 1615 buf := make([]byte, 4096) 1616 stackSize := runtime.Stack(buf, false) 1617 buf = buf[:stackSize] 1618 logutil.Logger(ctx).Error("write query result panic", zap.Stringer("lastALLEGROSQL", getLastStmtInConn{cc}), zap.String("stack", string(buf))) 1619 }() 1620 var err error 1621 if allegrosql.HasCursorExistsFlag(serverStatus) { 1622 err = cc.writeChunksWithFetchSize(ctx, rs, serverStatus, fetchSize) 1623 } else { 1624 err = cc.writeChunks(ctx, rs, binary, serverStatus) 1625 } 1626 if err != nil { 1627 return err 1628 } 1629 1630 return cc.flush(ctx) 1631 } 1632 1633 func (cc *clientConn) writeDeferredCausetInfo(defCausumns []*DeferredCausetInfo, serverStatus uint16) error { 1634 data := cc.alloc.AllocWithLen(4, 1024) 1635 data = dumpLengthEncodedInt(data, uint64(len(defCausumns))) 1636 if err := cc.writePacket(data); err != nil { 1637 return err 1638 } 1639 for _, v := range defCausumns { 1640 data = data[0:4] 1641 data = v.Dump(data) 1642 if err := cc.writePacket(data); err != nil { 1643 return err 1644 } 1645 } 1646 return cc.writeEOF(serverStatus) 1647 } 1648 1649 // writeChunks writes data from a Chunk, which filled data by a ResultSet, into a connection. 1650 // binary specifies the way to dump data. It throws any error while dumping data. 
// serverStatus, a flag bit represents server information
// Column definitions are written after the first Next call; rows follow, and
// an EOF packet ends the result set. The stream is not flushed here.
func (cc *clientConn) writeChunks(ctx context.Context, rs ResultSet, binary bool, serverStatus uint16) error {
	data := cc.alloc.AllocWithLen(4, 1024)
	req := rs.NewChunk()
	gotDeferredCausetInfo := false
	// Optional per-memex details; when present, the time spent writing rows is accumulated.
	var stmtDetail *execdetails.StmtInterDircDetails
	stmtDetailRaw := ctx.Value(execdetails.StmtInterDircDetailKey)
	if stmtDetailRaw != nil {
		stmtDetail = stmtDetailRaw.(*execdetails.StmtInterDircDetails)
	}

	for {
		// Here server.milevadbResultSet implements Next method.
		err := rs.Next(ctx, req)
		if err != nil {
			return err
		}
		if !gotDeferredCausetInfo {
			// We need to call Next before we get defCausumns.
			// Otherwise, we will get incorrect defCausumns info.
			defCausumns := rs.DeferredCausets()
			err = cc.writeDeferredCausetInfo(defCausumns, serverStatus)
			if err != nil {
				return err
			}
			gotDeferredCausetInfo = true
		}
		rowCount := req.NumRows()
		if rowCount == 0 {
			break
		}
		reg := trace.StartRegion(ctx, "WriteClientConn")
		start := time.Now()
		for i := 0; i < rowCount; i++ {
			// Reuse the buffer, keeping only the 4-byte header slot.
			data = data[0:4]
			if binary {
				data, err = dumpBinaryRow(data, rs.DeferredCausets(), req.GetRow(i))
			} else {
				data, err = dumpTextRow(data, rs.DeferredCausets(), req.GetRow(i))
			}
			if err != nil {
				reg.End()
				return err
			}
			if err = cc.writePacket(data); err != nil {
				reg.End()
				return err
			}
		}
		reg.End()
		if stmtDetail != nil {
			stmtDetail.WriteALLEGROSQLResFIDeluration += time.Since(start)
		}
	}
	return cc.writeEOF(serverStatus)
}

// writeChunksWithFetchSize writes data from a Chunk, which filled data by a ResultSet, into a connection.
// Rows are always dumped in the binary format. It returns any error while dumping data.
// serverStatus, a flag bit represents server information.
// fetchSize, the desired number of rows to be fetched each time when client uses cursor.
// Rows beyond fetchSize are stored back into rs for the next COM_STMT_FETCH.
// When no rows remain, the cursor flag is cleared, LastRowSend is set and rs is closed.
func (cc *clientConn) writeChunksWithFetchSize(ctx context.Context, rs ResultSet, serverStatus uint16, fetchSize int) error {
	// Start from the rows buffered by a previous fetch on this cursor.
	fetchedRows := rs.GetFetchedRows()

	// if fetchedRows is not enough, getting data from recordSet.
	req := rs.NewChunk()
	for len(fetchedRows) < fetchSize {
		// Here server.milevadbResultSet implements Next method.
		err := rs.Next(ctx, req)
		if err != nil {
			return err
		}
		rowCount := req.NumRows()
		if rowCount == 0 {
			break
		}
		// filling fetchedRows with chunk
		for i := 0; i < rowCount; i++ {
			fetchedRows = append(fetchedRows, req.GetRow(i))
		}
		// Renew the chunk (presumably a fresh allocation) so rows already held in
		// fetchedRows are not overwritten by the next Next call.
		req = chunk.Renew(req, cc.ctx.GetStochastikVars().MaxChunkSize)
	}

	// tell the client COM_STMT_FETCH has finished by setting proper serverStatus,
	// and close ResultSet.
	if len(fetchedRows) == 0 {
		serverStatus &^= allegrosql.ServerStatusCursorExists
		serverStatus |= allegrosql.ServerStatusLastRowSend
		terror.Call(rs.Close)
		return cc.writeEOF(serverStatus)
	}

	// construct the rows sent to the client according to fetchSize.
	var curRows []chunk.Row
	if fetchSize < len(fetchedRows) {
		curRows = fetchedRows[:fetchSize]
		fetchedRows = fetchedRows[fetchSize:]
	} else {
		// The stored empty slice shares curRows' backing array; curRows is
		// fully written below before anything can append to it.
		curRows = fetchedRows[:]
		fetchedRows = fetchedRows[:0]
	}
	rs.StoreFetchedRows(fetchedRows)

	data := cc.alloc.AllocWithLen(4, 1024)
	var stmtDetail *execdetails.StmtInterDircDetails
	stmtDetailRaw := ctx.Value(execdetails.StmtInterDircDetailKey)
	if stmtDetailRaw != nil {
		stmtDetail = stmtDetailRaw.(*execdetails.StmtInterDircDetails)
	}
	start := time.Now()
	var err error
	for _, event := range curRows {
		data = data[0:4]
		data, err = dumpBinaryRow(data, rs.DeferredCausets(), event)
		if err != nil {
			return err
		}
		if err = cc.writePacket(data); err != nil {
			return err
		}
	}
	if stmtDetail != nil {
		stmtDetail.WriteALLEGROSQLResFIDeluration += time.Since(start)
	}
	if cl, ok := rs.(fetchNotifier); ok {
		cl.OnFetchReturned()
	}
	return cc.writeEOF(serverStatus)
}

// writeMultiResultset writes each result set in rss in order. Results whose
// record set is a MultiQueryNoDelayResult are sent as OK packets; the others
// are sent as row result sets. Every result except the last carries the
// MoreResultsExists flag.
func (cc *clientConn) writeMultiResultset(ctx context.Context, rss []ResultSet, binary bool) error {
	for i, rs := range rss {
		lastRs := i == len(rss)-1
		if r, ok := rs.(*milevadbResultSet).recordSet.(sqlexec.MultiQueryNoDelayResult); ok {
			status := r.Status()
			if !lastRs {
				status |= allegrosql.ServerMoreResultsExists
			}
			if err := cc.writeOkWith(ctx, r.LastMessage(), r.AffectedRows(), r.LastInsertID(), status, r.WarnCount()); err != nil {
				return err
			}
			continue
		}
		status := uint16(0)
		if !lastRs {
			status |= allegrosql.ServerMoreResultsExists
		}
		if err := cc.writeResultset(ctx, rs, binary, status, 0); err != nil {
			return err
		}
	}
	return nil
}

// setConn wraps conn in a buffered reader and points the packet IO at it,
// creating the packet IO on first use.
func (cc *clientConn) setConn(conn net.Conn) {
	cc.bufReadConn = newBufferedReadConn(conn)
	if cc.pkt == nil {
		cc.pkt = newPacketIO(cc.bufReadConn)
	} else {
		// Preserve current sequence number.
		cc.pkt.setBufferedReadConn(cc.bufReadConn)
	}
}

// upgradeToTLS performs a server-side TLS handshake on the current connection
// and, on success, switches the connection to the TLS stream.
func (cc *clientConn) upgradeToTLS(tlsConfig *tls.Config) error {
	// Important: read from buffered reader instead of the original net.Conn because it may contain data we need.
	tlsConn := tls.Server(cc.bufReadConn, tlsConfig)
	if err := tlsConn.Handshake(); err != nil {
		return err
	}
	cc.setConn(tlsConn)
	cc.tlsConn = tlsConn
	return nil
}

// handleChangeUser handles COM_CHANGE_USER. The payload is parsed as a
// NUL-terminated user, a 1-byte password length, the password bytes and a
// NUL-terminated database name. The old stochastik is closed and a new one is
// opened and authenticated.
// NOTE(review): cc.user and cc.dbname are overwritten before authentication
// succeeds.
func (cc *clientConn) handleChangeUser(ctx context.Context, data []byte) error {
	user, data := parseNullTermString(data)
	cc.user = string(replog.String(user))
	if len(data) < 1 {
		return allegrosql.ErrMalformPacket
	}
	passLen := int(data[0])
	data = data[1:]
	if passLen > len(data) {
		return allegrosql.ErrMalformPacket
	}
	pass := data[:passLen]
	data = data[passLen:]
	dbName, _ := parseNullTermString(data)
	cc.dbname = string(replog.String(dbName))

	err := cc.ctx.Close()
	if err != nil {
		logutil.Logger(ctx).Debug("close old context failed", zap.Error(err))
	}
	err = cc.openStochastikAndDoAuth(pass)
	if err != nil {
		return err
	}
	return cc.handleCommonConnectionReset(ctx)
}

// handleResetConnection handles COM_RESET_CONNECTION. It replaces the stochastik
// with a fresh one for the same user (no password check), restores the current
// database, and then performs the common reset steps.
func (cc *clientConn) handleResetConnection(ctx context.Context) error {
	user := cc.ctx.GetStochastikVars().User
	err := cc.ctx.Close()
	if err != nil {
		logutil.Logger(ctx).Debug("close old context failed", zap.Error(err))
	}
	var tlsStatePtr *tls.ConnectionState
	if cc.tlsConn != nil {
		tlsState := cc.tlsConn.ConnectionState()
		tlsStatePtr = &tlsState
	}
	cc.ctx, err = cc.server.driver.OpenCtx(uint64(cc.connectionID), cc.capability, cc.defCauslation, cc.dbname, tlsStatePtr)
	if err != nil {
		return err
	}
	if !cc.ctx.AuthWithoutVerification(user) {
		return errors.New("Could not reset connection")
	}
	if cc.dbname != "" { // Restore the current EDB
		err = cc.useDB(context.Background(), cc.dbname)
		if err != nil {
			return err
		}
	}
	cc.ctx.SetStochastikManager(cc.server)

	return cc.handleCommonConnectionReset(ctx)
}

// handleCommonConnectionReset refreshes the connection info used by audit
// plugins, fires their ChangeUser connection event, and writes an OK packet.
func (cc *clientConn) handleCommonConnectionReset(ctx context.Context) error {
	if plugin.IsEnable(plugin.Audit) {
		cc.ctx.GetStochastikVars().ConnectionInfo = cc.connectInfo()
	}

	err := plugin.ForeachPlugin(plugin.Audit, func(p *plugin.Plugin) error {
		authPlugin := plugin.DeclareAuditManifest(p.Manifest)
		if authPlugin.OnConnectionEvent != nil {
			connInfo := cc.ctx.GetStochastikVars().ConnectionInfo
			err := authPlugin.OnConnectionEvent(context.Background(), plugin.ChangeUser, connInfo)
			if err != nil {
				return err
			}
		}
		return nil
	})
	if err != nil {
		return err
	}
	return cc.writeOK(ctx)
}

// safe to noop except 0x01 "FLUSH PRIVILEGES"
// NOTE(review): handleQuery presumably already writes an OK packet for
// FLUSH PRIVILEGES (handleStmt -> handleQuerySpecial), so the writeOK below may
// send a second OK packet; confirm.
func (cc *clientConn) handleRefresh(ctx context.Context, subCommand byte) error {
	if subCommand == 0x01 {
		if err := cc.handleQuery(ctx, "FLUSH PRIVILEGES"); err != nil {
			return err
		}
	}
	return cc.writeOK(ctx)
}

var _ fmt.Stringer = getLastStmtInConn{}

// getLastStmtInConn renders the last packet received on the wrapped
// connection, for logging and profiling labels.
type getLastStmtInConn struct {
	*clientConn
}

func (cc getLastStmtInConn) String() string {
	if len(cc.lastPacket) == 0 {
		return ""
	}
	cmd, data := cc.lastPacket[0], cc.lastPacket[1:]
	switch cmd {
	case allegrosql.ComInitDB:
		return "Use " + string(data)
	case allegrosql.ComFieldList:
		return "ListFields " + string(data)
	case allegrosql.ComQuery, allegrosql.ComStmtPrepare:
		allegrosql := string(replog.String(data))
		if config.RedactLogEnabled() {
			allegrosql, _ = BerolinaSQL.NormalizeDigest(allegrosql)
		}
		return queryStrForLog(allegrosql)
	case allegrosql.ComStmtInterDircute,
allegrosql.ComStmtFetch: 1937 stmtID := binary.LittleEndian.Uint32(data[0:4]) 1938 return queryStrForLog(cc.preparedStmt2String(stmtID)) 1939 case allegrosql.ComStmtClose, allegrosql.ComStmtReset: 1940 stmtID := binary.LittleEndian.Uint32(data[0:4]) 1941 return allegrosql.Command2Str[cmd] + " " + strconv.Itoa(int(stmtID)) 1942 default: 1943 if cmdStr, ok := allegrosql.Command2Str[cmd]; ok { 1944 return cmdStr 1945 } 1946 return string(replog.String(data)) 1947 } 1948 } 1949 1950 // PProfLabel return allegrosql label used to tag pprof. 1951 func (cc getLastStmtInConn) PProfLabel() string { 1952 if len(cc.lastPacket) == 0 { 1953 return "" 1954 } 1955 cmd, data := cc.lastPacket[0], cc.lastPacket[1:] 1956 switch cmd { 1957 case allegrosql.ComInitDB: 1958 return "UseDB" 1959 case allegrosql.ComFieldList: 1960 return "ListFields" 1961 case allegrosql.ComStmtClose: 1962 return "CloseStmt" 1963 case allegrosql.ComStmtReset: 1964 return "ResetStmt" 1965 case allegrosql.ComQuery, allegrosql.ComStmtPrepare: 1966 return BerolinaSQL.Normalize(queryStrForLog(string(replog.String(data)))) 1967 case allegrosql.ComStmtInterDircute, allegrosql.ComStmtFetch: 1968 stmtID := binary.LittleEndian.Uint32(data[0:4]) 1969 return queryStrForLog(cc.preparedStmt2StringNoArgs(stmtID)) 1970 default: 1971 return "" 1972 } 1973 }