github.com/matrixorigin/matrixone@v1.2.0/pkg/frontend/protocol.go (about) 1 // Copyright 2021 Matrix Origin 2 // 3 // Licensed under the Apache License, Version 2.0 (the "License"); 4 // you may not use this file except in compliance with the License. 5 // You may obtain a copy of the License at 6 // 7 // http://www.apache.org/licenses/LICENSE-2.0 8 // 9 // Unless required by applicable law or agreed to in writing, software 10 // distributed under the License is distributed on an "AS IS" BASIS, 11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 // See the License for the specific language governing permissions and 13 // limitations under the License. 14 15 package frontend 16 17 import ( 18 "context" 19 "fmt" 20 "math" 21 "sync" 22 "sync/atomic" 23 24 "github.com/matrixorigin/matrixone/pkg/common/moerr" 25 "github.com/matrixorigin/matrixone/pkg/vm/process" 26 27 "github.com/fagongzi/goetty/v2" 28 29 "github.com/matrixorigin/matrixone/pkg/logutil" 30 ) 31 32 // Response Categories 33 const ( 34 // OkResponse OK message 35 OkResponse = iota 36 // ErrorResponse Error message 37 ErrorResponse 38 // EoFResponse EOF message 39 EoFResponse 40 // ResultResponse result message 41 ResultResponse 42 // LocalInfileRequest local infile message 43 LocalInfileRequest 44 ) 45 46 type Request struct { 47 //the command type from the client 48 cmd CommandType 49 // sequence num 50 seq uint8 51 //the data from the client 52 data interface{} 53 } 54 55 func (req *Request) GetData() interface{} { 56 return req.data 57 } 58 59 func (req *Request) SetData(data interface{}) { 60 req.data = data 61 } 62 63 func (req *Request) GetCmd() CommandType { 64 return req.cmd 65 } 66 67 func (req *Request) SetCmd(cmd CommandType) { 68 req.cmd = cmd 69 } 70 71 type Response struct { 72 //the category of the response 73 category int 74 //the status of executing the peer request 75 status uint16 76 //the command type which generates the response 77 cmd int 78 //the data of the response 79 data interface{} 80 81 /* 82 ok response 83 */ 84 affectedRows, lastInsertId uint64 85 warnings uint16 86 } 87 88 func NewResponse(category int, affectedRows, lastInsertId uint64, warnings, status uint16, cmd int, d interface{}) *Response { 89 return &Response{ 90 category: category, 91 affectedRows: affectedRows, 92 lastInsertId: lastInsertId, 93 warnings: warnings, 94 status: status, 95 cmd: cmd, 96 data: d, 97 } 98 } 99 100 func NewGeneralErrorResponse(cmd CommandType, status uint16, err error) *Response { 101 return NewResponse(ErrorResponse, 0, 0, 0, status, int(cmd), err) 102 } 103 104 func NewGeneralOkResponse(cmd CommandType, status uint16) *Response { 105 return NewResponse(OkResponse, 0, 0, 0, status, int(cmd), nil) 106 } 107 108 func NewOkResponse(affectedRows, lastInsertId uint64, warnings, status uint16, cmd int, d interface{}) *Response { 109 return NewResponse(OkResponse, affectedRows, lastInsertId, warnings, status, cmd, d) 110 } 111 112 func (resp *Response) GetData() interface{} { 113 return resp.data 114 } 115 116 func (resp *Response) SetData(data interface{}) { 117 resp.data = data 118 } 119 120 func (resp *Response) GetStatus() uint16 { 121 return resp.status 122 } 123 124 func (resp *Response) SetStatus(status uint16) { 125 resp.status = status 126 } 127 128 func (resp *Response) GetCategory() int { 129 return resp.category 130 } 131 132 func (resp *Response) SetCategory(category int) { 133 resp.category = category 134 } 135 136 type Protocol interface { 137 IsEstablished() bool 138 139 SetEstablished() 140 141 // GetRequest gets Request from Packet 142 GetRequest(payload []byte) *Request 143 144 // SendResponse sends a response to the client for the application request 145 SendResponse(context.Context, *Response) error 146 147 // ConnectionID the identity of the client 148 ConnectionID() uint32 149 150 // Peer gets the address [Host:Port,Host:Port] of the client and the server 151 Peer() string 152 153 GetDatabaseName() string 154 155 SetDatabaseName(string) 156 157 GetUserName() string 158 159 SetUserName(string) 160 161 GetSequenceId() uint8 162 163 SetSequenceID(value uint8) 164 165 GetDebugString() string 166 167 GetTcpConnection() goetty.IOSession 168 169 GetCapability() uint32 170 171 SetCapability(uint32) 172 173 GetConnectAttrs() map[string]string 174 175 IsTlsEstablished() bool 176 177 SetTlsEstablished() 178 179 HandleHandshake(ctx context.Context, payload []byte) (bool, error) 180 181 Authenticate(ctx context.Context) error 182 183 SendPrepareResponse(ctx context.Context, stmt *PrepareStmt) error 184 185 Quit() 186 187 incDebugCount(int) 188 189 resetDebugCount() []uint64 190 191 UpdateCtx(context.Context) 192 } 193 194 type ProtocolImpl struct { 195 m sync.Mutex 196 197 io IOPackage 198 199 tcpConn goetty.IOSession 200 201 quit atomic.Bool 202 203 //random bytes 204 salt []byte 205 206 //the id of the connection 207 connectionID uint32 208 209 // whether the handshake succeeded 210 established atomic.Bool 211 212 // whether the tls handshake succeeded 213 tlsEstablished atomic.Bool 214 215 //The sequence-id is incremented with each packet and may wrap around. 216 //It starts at 0 and is reset to 0 when a new command begins in the Command Phase. 217 sequenceId atomic.Uint32 218 219 //for debug 220 debugCount [16]uint64 221 222 ctx context.Context 223 } 224 225 func (pi *ProtocolImpl) UpdateCtx(ctx context.Context) { 226 pi.ctx = ctx 227 } 228 229 func (pi *ProtocolImpl) incDebugCount(i int) { 230 if i >= 0 && i < len(pi.debugCount) { 231 atomic.AddUint64(&pi.debugCount[i], 1) 232 } 233 } 234 235 func (pi *ProtocolImpl) resetDebugCount() []uint64 { 236 ret := make([]uint64, len(pi.debugCount)) 237 for i := 0; i < len(pi.debugCount); i++ { 238 ret[i] = atomic.LoadUint64(&pi.debugCount[i]) 239 } 240 return ret 241 } 242 243 func (pi *ProtocolImpl) setQuit(b bool) bool { 244 return pi.quit.Swap(b) 245 } 246 247 func (pi *ProtocolImpl) GetSequenceId() uint8 { 248 return uint8(pi.sequenceId.Load()) 249 } 250 251 func (pi *ProtocolImpl) getDebugStringUnsafe() string { 252 if pi.tcpConn != nil { 253 return fmt.Sprintf("connectionId %d|%s", pi.connectionID, pi.tcpConn.RemoteAddress()) 254 } 255 return "" 256 } 257 258 func (pi *ProtocolImpl) GetDebugString() string { 259 pi.m.Lock() 260 defer pi.m.Unlock() 261 return pi.getDebugStringUnsafe() 262 } 263 264 func (pi *ProtocolImpl) GetSalt() []byte { 265 pi.m.Lock() 266 defer pi.m.Unlock() 267 return pi.salt 268 } 269 270 // SetSalt updates the salt value. This happens with proxy mode enabled. 271 func (pi *ProtocolImpl) SetSalt(s []byte) { 272 pi.m.Lock() 273 defer pi.m.Unlock() 274 pi.salt = s 275 } 276 277 func (pi *ProtocolImpl) IsEstablished() bool { 278 return pi.established.Load() 279 } 280 281 func (pi *ProtocolImpl) SetEstablished() { 282 logDebugf(pi.GetDebugString(), "SWITCH ESTABLISHED to true") 283 pi.established.Store(true) 284 } 285 286 func (pi *ProtocolImpl) IsTlsEstablished() bool { 287 return pi.tlsEstablished.Load() 288 } 289 290 func (pi *ProtocolImpl) SetTlsEstablished() { 291 logutil.Debugf("SWITCH TLS_ESTABLISHED to true") 292 pi.tlsEstablished.Store(true) 293 } 294 295 func (pi *ProtocolImpl) ConnectionID() uint32 { 296 return pi.connectionID 297 } 298 299 // Quit kill tcpConn still connected. 300 // before calling NewMysqlClientProtocol, tcpConn.Connected() must be true 301 // please check goetty/application.go::doStart() and goetty/application.go::NewIOSession(...) for details 302 func (pi *ProtocolImpl) Quit() { 303 //if it was quit, do nothing 304 if pi.setQuit(true) { 305 return 306 } 307 if pi.tcpConn != nil { 308 if err := pi.tcpConn.Disconnect(); err != nil { 309 return 310 } 311 } 312 //release salt 313 if pi.salt != nil { 314 pi.salt = nil 315 } 316 } 317 318 func (pi *ProtocolImpl) GetTcpConnection() goetty.IOSession { 319 return pi.tcpConn 320 } 321 322 func (pi *ProtocolImpl) Peer() string { 323 tcp := pi.GetTcpConnection() 324 if tcp == nil { 325 return "" 326 } 327 return tcp.RemoteAddress() 328 } 329 330 func (mp *MysqlProtocolImpl) GetRequest(payload []byte) *Request { 331 req := &Request{ 332 cmd: CommandType(payload[0]), 333 data: payload[1:], 334 } 335 336 return req 337 } 338 339 func (mp *MysqlProtocolImpl) SendResponse(ctx context.Context, resp *Response) error { 340 //move here to prohibit potential recursive lock 341 var attachAbort string 342 343 mp.m.Lock() 344 defer mp.m.Unlock() 345 346 switch resp.category { 347 case OkResponse: 348 s, ok := resp.data.(string) 349 if !ok { 350 return mp.sendOKPacket(resp.affectedRows, resp.lastInsertId, uint16(resp.status), resp.warnings, "") 351 } 352 return mp.sendOKPacket(resp.affectedRows, resp.lastInsertId, uint16(resp.status), resp.warnings, s) 353 case EoFResponse: 354 return mp.sendEOFPacket(0, uint16(resp.status)) 355 case ErrorResponse: 356 err := resp.data.(error) 357 if err == nil { 358 return mp.sendOKPacket(0, 0, uint16(resp.status), 0, "") 359 } 360 switch myerr := err.(type) { 361 case *moerr.Error: 362 var code uint16 363 if myerr.MySQLCode() != moerr.ER_UNKNOWN_ERROR { 364 code = myerr.MySQLCode() 365 } else { 366 code = myerr.ErrorCode() 367 } 368 errMsg := myerr.Error() 369 if attachAbort != "" { 370 errMsg = fmt.Sprintf("%s\n%s", myerr.Error(), attachAbort) 371 } 372 return mp.sendErrPacket(code, myerr.SqlState(), errMsg) 373 } 374 errMsg := "" 375 if attachAbort != "" { 376 errMsg = fmt.Sprintf("%s\n%s", err, attachAbort) 377 } else { 378 errMsg = fmt.Sprintf("%v", err) 379 } 380 return mp.sendErrPacket(moerr.ER_UNKNOWN_ERROR, DefaultMySQLState, errMsg) 381 case ResultResponse: 382 mer := resp.data.(*MysqlExecutionResult) 383 if mer == nil { 384 return mp.sendOKPacket(0, 0, uint16(resp.status), 0, "") 385 } 386 if mer.Mrs() == nil { 387 return mp.sendOKPacket(mer.AffectedRows(), mer.InsertID(), uint16(resp.status), mer.Warnings(), "") 388 } 389 return mp.sendResultSet(ctx, mer.Mrs(), resp.cmd, mer.Warnings(), uint16(resp.status)) 390 case LocalInfileRequest: 391 s, _ := resp.data.(string) 392 return mp.sendLocalInfileRequest(s) 393 default: 394 return moerr.NewInternalError(ctx, "unsupported response:%d ", resp.category) 395 } 396 } 397 398 func (mp *MysqlProtocolImpl) DisableAutoFlush() { 399 mp.disableAutoFlush = true 400 } 401 402 func (mp *MysqlProtocolImpl) EnableAutoFlush() { 403 mp.disableAutoFlush = false 404 } 405 406 func (mp *MysqlProtocolImpl) Flush() error { 407 return nil 408 } 409 410 var _ MysqlProtocol = &FakeProtocol{} 411 412 const ( 413 fakeConnectionID uint32 = math.MaxUint32 414 ) 415 416 // FakeProtocol works for the background transaction that does not use the network protocol. 417 type FakeProtocol struct { 418 username string 419 database string 420 ioses goetty.IOSession 421 } 422 423 func (fp *FakeProtocol) UpdateCtx(ctx context.Context) { 424 425 } 426 427 func (fp *FakeProtocol) GetCapability() uint32 { 428 return DefaultCapability 429 } 430 431 func (fp *FakeProtocol) SetCapability(uint32) { 432 433 } 434 435 func (fp *FakeProtocol) IsTlsEstablished() bool { 436 return true 437 } 438 439 func (fp *FakeProtocol) SetTlsEstablished() { 440 441 } 442 443 func (fp *FakeProtocol) HandleHandshake(ctx context.Context, payload []byte) (bool, error) { 444 return false, nil 445 } 446 447 func (fp *FakeProtocol) Authenticate(ctx context.Context) error { 448 return nil 449 } 450 451 func (fp *FakeProtocol) GetTcpConnection() goetty.IOSession { 452 return fp.ioses 453 } 454 455 func (fp *FakeProtocol) GetDebugString() string { 456 return "fake protocol" 457 } 458 459 func (fp *FakeProtocol) GetSequenceId() uint8 { 460 return 0 461 } 462 463 func (fp *FakeProtocol) SetSequenceID(value uint8) { 464 } 465 466 func (fp *FakeProtocol) GetConnectAttrs() map[string]string { 467 return nil 468 } 469 470 func (fp *FakeProtocol) SendPrepareResponse(ctx context.Context, stmt *PrepareStmt) error { 471 return nil 472 } 473 474 func (fp *FakeProtocol) ParseSendLongData(ctx context.Context, proc *process.Process, stmt *PrepareStmt, data []byte, pos int) error { 475 return nil 476 } 477 478 func (fp *FakeProtocol) ParseExecuteData(ctx context.Context, proc *process.Process, stmt *PrepareStmt, data []byte, pos int) error { 479 return nil 480 } 481 482 func (fp *FakeProtocol) SendResultSetTextBatchRow(mrs *MysqlResultSet, cnt uint64) error { 483 return nil 484 } 485 486 func (fp *FakeProtocol) SendResultSetTextBatchRowSpeedup(mrs *MysqlResultSet, cnt uint64) error { 487 return nil 488 } 489 490 func (fp *FakeProtocol) SendColumnDefinitionPacket(ctx context.Context, column Column, cmd int) error { 491 return nil 492 } 493 494 func (fp *FakeProtocol) SendColumnCountPacket(count uint64) error { 495 return nil 496 } 497 498 func (fp *FakeProtocol) SendEOFPacketIf(warnings uint16, status uint16) error { 499 return nil 500 } 501 502 func (fp *FakeProtocol) sendOKPacket(affectedRows uint64, lastInsertId uint64, status uint16, warnings uint16, message string) error { 503 return nil 504 } 505 506 func (fp *FakeProtocol) sendEOFOrOkPacket(warnings uint16, status uint16) error { 507 return nil 508 } 509 510 func (fp *FakeProtocol) ResetStatistics() {} 511 512 func (fp *FakeProtocol) GetStats() string { 513 return "" 514 } 515 516 func (fp *FakeProtocol) CalculateOutTrafficBytes(reset bool) (int64, int64) { return 0, 0 } 517 518 func (fp *FakeProtocol) IsEstablished() bool { 519 return true 520 } 521 522 func (fp *FakeProtocol) SetEstablished() {} 523 524 func (fp *FakeProtocol) GetRequest(payload []byte) *Request { 525 return nil 526 } 527 528 func (fp *FakeProtocol) SendResponse(ctx context.Context, resp *Response) error { 529 return nil 530 } 531 532 func (fp *FakeProtocol) ConnectionID() uint32 { 533 return fakeConnectionID 534 } 535 536 func (fp *FakeProtocol) Peer() string { 537 return "0.0.0.0:0" 538 } 539 540 func (fp *FakeProtocol) GetDatabaseName() string { 541 return fp.database 542 } 543 544 func (fp *FakeProtocol) SetDatabaseName(s string) { 545 fp.database = s 546 } 547 548 func (fp *FakeProtocol) GetUserName() string { 549 return fp.username 550 } 551 552 func (fp *FakeProtocol) SetUserName(s string) { 553 fp.username = s 554 } 555 556 func (fp *FakeProtocol) Quit() {} 557 558 func (fp *FakeProtocol) sendLocalInfileRequest(filename string) error { 559 return nil 560 } 561 562 func (fp *FakeProtocol) incDebugCount(int) {} 563 564 func (fp *FakeProtocol) resetDebugCount() []uint64 { 565 return nil 566 } 567 568 func (fp *FakeProtocol) DisableAutoFlush() { 569 } 570 571 func (fp *FakeProtocol) EnableAutoFlush() { 572 } 573 574 func (fp *FakeProtocol) Flush() error { 575 return nil 576 }