github.com/XiaoMi/Gaea@v1.2.5/proxy/server/client_conn.go (about) 1 // Copyright 2019 The Gaea Authors. All Rights Reserved. 2 // 3 // Licensed under the Apache License, Version 2.0 (the "License"); 4 // you may not use this file except in compliance with the License. 5 // You may obtain a copy of the License at 6 // 7 // http://www.apache.org/licenses/LICENSE-2.0 8 // 9 // Unless required by applicable law or agreed to in writing, software 10 // distributed under the License is distributed on an "AS IS" BASIS, 11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 // See the License for the specific language governing permissions and 13 // limitations under the License. 14 15 package server 16 17 import ( 18 "fmt" 19 "github.com/XiaoMi/Gaea/log" 20 "github.com/XiaoMi/Gaea/mysql" 21 "strings" 22 ) 23 24 // ClientConn session client connection 25 type ClientConn struct { 26 *mysql.Conn 27 28 salt []byte 29 30 manager *Manager 31 32 capability uint32 33 34 namespace string // TODO: remove it when refactor is done 35 36 proxy *Server 37 } 38 39 // HandshakeResponseInfo handshake response information 40 type HandshakeResponseInfo struct { 41 CollationID mysql.CollationID 42 User string 43 AuthResponse []byte 44 Salt []byte 45 Database string 46 AuthPlugin string 47 } 48 49 // NewClientConn constructor of ClientConn 50 func NewClientConn(c *mysql.Conn, manager *Manager) *ClientConn { 51 salt, _ := mysql.RandomBuf(20) 52 return &ClientConn{ 53 Conn: c, 54 salt: salt, 55 manager: manager, 56 } 57 } 58 59 func (cc *ClientConn) CompactVersion(sv string) string { 60 version := strings.Trim(sv, " ") 61 if version != "" { 62 v := strings.Split(sv, ".") 63 if len(v) < 3 { 64 return mysql.ServerVersion 65 } 66 return version 67 } else { 68 return mysql.ServerVersion 69 } 70 } 71 72 func (cc *ClientConn) writeInitialHandshakeV10() error { 73 ServerVersion := cc.CompactVersion(cc.proxy.ServerVersion) 74 length := 75 1 + // protocol version 76 mysql.LenNullString(ServerVersion) + 77 4 + // connection ID 78 8 + // first part of salt data 79 1 + // filler byte 80 2 + // capability flags (lower 2 bytes) 81 1 + // character set 82 2 + // status flag 83 2 + // capability flags (upper 2 bytes) 84 1 + // length of auth plugin data 85 10 + // reserved (0) 86 13 // auth-plugin-data 87 // mysql.LenNullString(mysql.MysqlNativePassword) // auth-plugin-name 88 if cc.proxy.AuthPlugin != "" { 89 length += mysql.LenNullString(cc.proxy.AuthPlugin) 90 } 91 92 data := cc.StartEphemeralPacket(length) 93 pos := 0 94 95 // Protocol version. 96 pos = mysql.WriteByte(data, pos, mysql.ProtocolVersion) 97 98 // Copy server version. 99 // server version data with terminate character 0x00, type: string[NUL]. 100 pos = mysql.WriteNullString(data, pos, ServerVersion) 101 102 // Add connectionID in. 103 // connection id type: 4 bytes. 104 pos = mysql.WriteUint32(data, pos, cc.GetConnectionID()) 105 106 // auth-plugin-data-part-1 type: string[8]. 107 pos += copy(data[pos:], cc.salt[:8]) 108 109 // One filler byte, always 0. 110 pos = mysql.WriteByte(data, pos, 0) 111 112 // Lower part of the capability flags, lower 2 bytes. 113 pos = mysql.WriteUint16(data, pos, uint16(DefaultCapability)) 114 115 // Character set. 116 pos = mysql.WriteByte(data, pos, byte(mysql.DefaultCollationID)) 117 118 // Status flag. 119 pos = mysql.WriteUint16(data, pos, initClientConnStatus) 120 121 // Upper part of the capability flags. 122 pos = mysql.WriteUint16(data, pos, uint16(DefaultCapability>>16)) 123 124 // Length of auth plugin data. 125 // Always 21 (8 + 13). 126 pos = mysql.WriteByte(data, pos, 21) 127 128 // Reserved 10 bytes: all 0 129 pos = mysql.WriteZeroes(data, pos, 10) 130 131 // Second part of auth plugin data. 132 pos += copy(data[pos:], cc.salt[8:]) 133 data[pos] = 0 134 pos++ 135 //authentication plugin 136 if cc.proxy.AuthPlugin != "" { 137 pos += copy(data[pos:], cc.proxy.AuthPlugin) 138 data[pos] = 0 139 pos++ 140 } 141 142 // Copy authPluginName. We always start with mysql_native_password. 143 // pos = mysql.WriteNullString(data, pos, mysql.MysqlNativePassword) 144 145 // Sanity check. 146 if pos != len(data) { 147 return fmt.Errorf("error building Handshake packet: got %v bytes expected %v", pos, len(data)) 148 } 149 150 if err := cc.WriteEphemeralPacket(); err != nil { 151 return err 152 } 153 154 return nil 155 } 156 157 func (cc *ClientConn) readHandshakeResponse() (HandshakeResponseInfo, error) { 158 info := HandshakeResponseInfo{} 159 info.Salt = cc.salt 160 161 data, err := cc.ReadEphemeralPacketDirect() 162 defer cc.RecycleReadPacket() 163 if err != nil { 164 return info, err 165 } 166 167 pos := 0 168 169 // Client flags, 4 bytes. 170 var ok bool 171 var capability uint32 172 capability, pos, ok = mysql.ReadUint32(data, pos) 173 if !ok { 174 return info, fmt.Errorf("readHandshakeResponse: can't read client flags") 175 } 176 if capability&mysql.ClientProtocol41 == 0 { 177 return info, fmt.Errorf("readHandshakeResponse: only support protocol 4.1") 178 } 179 180 cc.capability = capability 181 // Max packet size. Don't do anything with this now. 182 _, pos, ok = mysql.ReadUint32(data, pos) 183 if !ok { 184 return info, fmt.Errorf("readHandshakeResponse: can't read maxPacketSize") 185 } 186 187 // Character set 188 collationID, pos, ok := mysql.ReadByte(data, pos) 189 if !ok { 190 return info, fmt.Errorf("readHandshakeResponse: can't read characterSet") 191 } 192 info.CollationID = mysql.CollationID(collationID) 193 194 // reserved 23 zero bytes, skipped 195 pos += 23 196 197 // username 198 var user string 199 user, pos, ok = mysql.ReadNullString(data, pos) 200 if !ok { 201 return info, fmt.Errorf("readHandshakeResponse: can't read username") 202 } 203 info.User = user 204 205 // TODO auth-response can have three forms. 206 var authResponse []byte 207 var l uint64 208 l, pos, _, ok = mysql.ReadLenEncInt(data, pos) 209 if !ok { 210 return info, fmt.Errorf("readHandshakeResponse: can't read auth-response variable length") 211 } 212 213 if capability&mysql.ClientPluginAuthLenencClientData > 0 || capability&mysql.ClientSecureConnection > 0 { 214 authResponse, pos, ok = mysql.ReadBytesCopy(data, pos, int(l)) 215 } else { 216 authResponse, pos, ok = mysql.ReadNullByte(data, pos) 217 } 218 if !ok { 219 return info, fmt.Errorf("readHandshakeResponse: can't read auth-response") 220 } 221 222 info.AuthResponse = authResponse 223 224 // check if with database 225 if capability&mysql.ClientConnectWithDB > 0 { 226 var db string 227 db, pos, ok = mysql.ReadNullString(data, pos) 228 if !ok { 229 return info, fmt.Errorf("readHandshakeResponse: can't read db") 230 } 231 info.Database = db 232 } 233 if capability&mysql.ClientPluginAuth > 0 { 234 var authPlugin string 235 authPlugin, pos, ok = mysql.ReadNullString(data, pos) 236 if ok && (authPlugin != cc.proxy.AuthPlugin) { 237 info.AuthPlugin = cc.proxy.AuthPlugin 238 cc.RecycleReadPacket() 239 cc.WriteAuthSwitchRequest(info.AuthPlugin) 240 // readAuthSwitchRequestResponse 241 info.AuthResponse, err = cc.ReadEphemeralPacketDirect() 242 if err != nil { 243 return info, fmt.Errorf("readHandshakeResponse: can't read auth switch response") 244 } 245 } 246 } 247 248 // TODO auth plugin namećclient conn attrs .etc 249 return info, nil 250 } 251 252 func (cc *ClientConn) writeOK(status uint16) error { 253 err := cc.WriteOKPacket(0, 0, status, 0) 254 if err != nil { 255 log.Warn("write ok packet failed, %v", err) 256 return err 257 } 258 return nil 259 } 260 261 func (cc *ClientConn) writeOKResult(status uint16, r *mysql.Result) error { 262 if r.Resultset == nil { 263 return cc.WriteOKPacket(r.AffectedRows, r.InsertID, status, 0) 264 } 265 return cc.writeResultset(status, r.Resultset) 266 } 267 268 func (cc *ClientConn) writeEOFPacket(status uint16) error { 269 err := cc.WriteEOFPacket(status, 0) 270 if err != nil { 271 log.Warn("write eof packet failed, %v", err) 272 return err 273 } 274 return nil 275 } 276 277 func (cc *ClientConn) writeErrorPacket(err error) error { 278 e := cc.WriteErrorPacketFromError(err) 279 if e != nil { 280 log.Warn("write error packet failed, %v", err) 281 return e 282 } 283 return nil 284 } 285 286 func (cc *ClientConn) writeColumnCount(count uint64) error { 287 length := mysql.LenEncIntSize(count) 288 data := cc.StartEphemeralPacket(length) 289 cc.manager.GetStatisticManager().AddWriteFlowCount(cc.namespace, length) 290 mysql.WriteLenEncInt(data, 0, count) 291 return cc.WriteEphemeralPacket() 292 } 293 294 func (cc *ClientConn) writeRow(row []byte) error { 295 length := len(row) 296 data := cc.StartEphemeralPacket(length) 297 pos := 0 298 copy(data[pos:], row) 299 cc.manager.GetStatisticManager().AddWriteFlowCount(cc.namespace, length) 300 return cc.WriteEphemeralPacket() 301 } 302 303 // https://dev.mysql.com/doc/internals/en/com-query-response.html#packet-ProtocolText::Resultset 304 func (cc *ClientConn) writeResultset(status uint16, r *mysql.Resultset) error { 305 var err error 306 cc.StartWriterBuffering() 307 308 // write column count 309 columnCount := uint64(len(r.Fields)) 310 err = cc.writeColumnCount(columnCount) 311 if err != nil { 312 return err 313 } 314 315 // write columns 316 err = cc.writeFieldList(status, r.Fields) 317 if err != nil { 318 return err 319 } 320 321 // write rows data 322 // resultset row, NULL is sent as 0xfb, everything else is converted into a string and is sent as Protocol::LengthEncodedString 323 for _, v := range r.RowDatas { 324 err = cc.writeRow(v) 325 if err != nil { 326 return err 327 } 328 } 329 330 err = cc.writeEOFPacket(status) 331 if err != nil { 332 return err 333 } 334 335 err = cc.Flush() 336 if err != nil { 337 return err 338 } 339 340 return nil 341 } 342 343 func (cc *ClientConn) writeFieldList(status uint16, fs []*mysql.Field) error { 344 var err error 345 for _, f := range fs { 346 err = cc.writeColumnDefinition(f) 347 if err != nil { 348 return err 349 } 350 } 351 352 err = cc.writeEOFPacket(status) 353 return err 354 } 355 356 func (cc *ClientConn) writeColumnDefinition(field *mysql.Field) error { 357 schemaLen := uint64(len(field.Schema)) 358 tableLen := uint64(len(field.Table)) 359 orgTableLen := uint64(len(field.OrgTable)) 360 nameLen := uint64(len(field.Name)) 361 orgNameLen := uint64(len(field.OrgName)) 362 length := 4 + // lenEncStringSize("def") 363 mysql.LenEncIntSize(schemaLen) + 364 len(field.Schema) + 365 mysql.LenEncIntSize(tableLen) + 366 len(field.Table) + 367 mysql.LenEncIntSize(orgTableLen) + 368 len(field.OrgTable) + 369 mysql.LenEncIntSize(nameLen) + 370 len(field.Name) + 371 mysql.LenEncIntSize(orgNameLen) + 372 len(field.OrgName) + 373 1 + // length of fixed length fields 374 2 + // character set 375 4 + // column length 376 1 + // type 377 2 + // flags 378 1 + // decimals 379 2 // filler 380 if field.DefaultValue != nil { 381 length += mysql.LenEncIntSize(uint64(len(field.DefaultValue))) + len(field.DefaultValue) 382 } 383 384 data := cc.StartEphemeralPacket(length) 385 pos := 0 386 pos = mysql.WriteLenEncString(data, pos, "def") // Always the same. 387 388 pos = mysql.WriteLenEncInt(data, pos, schemaLen) 389 copy(data[pos:], field.Schema) 390 pos += len(field.Schema) 391 392 pos = mysql.WriteLenEncInt(data, pos, tableLen) 393 copy(data[pos:], field.Table) 394 pos += len(field.Table) 395 396 pos = mysql.WriteLenEncInt(data, pos, orgTableLen) 397 copy(data[pos:], field.OrgTable) 398 pos += len(field.OrgTable) 399 400 pos = mysql.WriteLenEncInt(data, pos, nameLen) 401 copy(data[pos:], field.Name) 402 pos += len(field.Name) 403 404 pos = mysql.WriteLenEncInt(data, pos, orgNameLen) 405 copy(data[pos:], field.OrgName) 406 pos += len(field.OrgName) 407 408 pos = mysql.WriteByte(data, pos, 0x0c) 409 pos = mysql.WriteUint16(data, pos, field.Charset) 410 pos = mysql.WriteUint32(data, pos, field.ColumnLength) 411 pos = mysql.WriteByte(data, pos, byte(field.Type)) 412 pos = mysql.WriteUint16(data, pos, field.Flag) 413 pos = mysql.WriteByte(data, pos, byte(field.Decimal)) 414 pos = mysql.WriteUint16(data, pos, uint16(0x0000)) 415 416 if field.DefaultValue != nil { 417 pos = mysql.WriteLenEncInt(data, pos, field.DefaultValueLength) 418 copy(data[pos:], field.DefaultValue) 419 pos += len(field.DefaultValue) 420 } 421 if pos != len(data) { 422 return fmt.Errorf("internal error: packing of column definition used %v bytes instead of %v", pos, len(data)) 423 } 424 cc.manager.GetStatisticManager().AddWriteFlowCount(cc.namespace, len(data)) 425 426 return cc.WriteEphemeralPacket() 427 } 428 429 // writePrepareResponse write prepare response 430 func (cc *ClientConn) writePrepareResponse(status uint16, s *Stmt) error { 431 var err error 432 length := 1 + // status 433 4 + // statement-id 434 2 + // number of columns 435 2 + // number of params 436 1 + // filler 437 2 // number of warnings 438 data := cc.StartEphemeralPacket(length) 439 pos := 0 440 // status ok 441 pos = mysql.WriteByte(data, pos, 0) 442 // stmt id 443 pos = mysql.WriteUint32(data, pos, s.id) 444 // number columns 445 pos = mysql.WriteUint16(data, pos, uint16(s.columnCount)) 446 // number params 447 pos = mysql.WriteUint16(data, pos, uint16(s.paramCount)) 448 // filler [00] 449 pos = mysql.WriteByte(data, pos, 0) 450 // number of warnings 451 pos = mysql.WriteUint16(data, pos, 0) 452 if pos != length { 453 return fmt.Errorf("internal error packet row: got %v bytes but expected %v", pos, length) 454 } 455 456 err = cc.WriteEphemeralPacket() 457 if err != nil { 458 return err 459 } 460 461 if s.paramCount > 0 { 462 for i := 0; i < s.paramCount; i++ { 463 err = cc.writeColumnDefinition(p) 464 if err != nil { 465 return err 466 } 467 } 468 err = cc.writeEOFPacket(status) 469 return err 470 } 471 472 if s.columnCount > 0 { 473 for i := 0; i < s.columnCount; i++ { 474 err = cc.writeColumnDefinition(c) 475 if err != nil { 476 return err 477 } 478 } 479 err = cc.writeEOFPacket(status) 480 return err 481 } 482 483 return nil 484 } 485 486 func (cc *ClientConn) WriteAuthSwitchRequest(authMethod string) error { 487 l := 1 + len(authMethod) + 1 + len(cc.salt) + 1 488 data := cc.StartEphemeralPacket(l) 489 pos := 0 490 pos = mysql.WriteByte(data, pos, mysql.AuthSwitchHeader) 491 pos = mysql.WriteNullString(data, pos, authMethod) 492 pos = mysql.WriteBytes(data, pos, cc.salt) 493 mysql.WriteByte(data, pos, 0) 494 return cc.WriteEphemeralPacket() 495 }