github.com/insionng/yougam@v0.0.0-20170714101924-2bc18d833463/libraries/pingcap/tidb/server/conn.go (about) 1 // Copyright 2013 The Go-MySQL-Driver Authors. All rights reserved. 2 // 3 // This Source Code Form is subject to the terms of the Mozilla Public 4 // License, v. 2.0. If a copy of the MPL was not distributed with this file, 5 // You can obtain one at http://mozilla.org/MPL/2.0/. 6 7 // The MIT License (MIT) 8 // 9 // Copyright (c) 2014 wandoulabs 10 // Copyright (c) 2014 siddontang 11 // 12 // Permission is hereby granted, free of charge, to any person obtaining a copy of 13 // this software and associated documentation files (the "Software"), to deal in 14 // the Software without restriction, including without limitation the rights to 15 // use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of 16 // the Software, and to permit persons to whom the Software is furnished to do so, 17 // subject to the following conditions: 18 // 19 // The above copyright notice and this permission notice shall be included in all 20 // copies or substantial portions of the Software. 21 22 // Copyright 2015 PingCAP, Inc. 23 // 24 // Licensed under the Apache License, Version 2.0 (the "License"); 25 // you may not use this file except in compliance with the License. 26 // You may obtain a copy of the License at 27 // 28 // http://www.apache.org/licenses/LICENSE-2.0 29 // 30 // Unless required by applicable law or agreed to in writing, software 31 // distributed under the License is distributed on an "AS IS" BASIS, 32 // See the License for the specific language governing permissions and 33 // limitations under the License. 34 35 package server 36 37 import ( 38 "bytes" 39 "encoding/binary" 40 "fmt" 41 "io" 42 "net" 43 "runtime" 44 "strings" 45 "time" 46 47 "github.com/insionng/yougam/libraries/juju/errors" 48 "github.com/insionng/yougam/libraries/ngaut/log" 49 "github.com/insionng/yougam/libraries/pingcap/tidb/mysql" 50 "github.com/insionng/yougam/libraries/pingcap/tidb/terror" 51 "github.com/insionng/yougam/libraries/pingcap/tidb/util/arena" 52 "github.com/insionng/yougam/libraries/pingcap/tidb/util/hack" 53 "github.com/insionng/yougam/libraries/pingcap/tidb/util/types" 54 ) 55 56 var defaultCapability = mysql.ClientLongPassword | mysql.ClientLongFlag | 57 mysql.ClientConnectWithDB | mysql.ClientProtocol41 | 58 mysql.ClientTransactions | mysql.ClientSecureConnection | mysql.ClientFoundRows 59 60 type clientConn struct { 61 pkg *packetIO 62 conn net.Conn 63 server *Server 64 capability uint32 65 connectionID uint32 66 collation uint8 67 charset string 68 user string 69 dbname string 70 salt []byte 71 alloc arena.Allocator 72 lastCmd string 73 ctx IContext 74 } 75 76 func (cc *clientConn) String() string { 77 return fmt.Sprintf("conn: %s, status: %d, charset: %s, user: %s, lastInsertId: %d", 78 cc.conn.RemoteAddr(), cc.ctx.Status(), cc.charset, cc.user, cc.ctx.LastInsertID(), 79 ) 80 } 81 82 func (cc *clientConn) handshake() error { 83 if err := cc.writeInitialHandshake(); err != nil { 84 return errors.Trace(err) 85 } 86 if err := cc.readHandshakeResponse(); err != nil { 87 cc.writeError(err) 88 return errors.Trace(err) 89 } 90 data := cc.alloc.AllocWithLen(4, 32) 91 data = append(data, mysql.OKHeader) 92 data = append(data, 0, 0) 93 if cc.capability&mysql.ClientProtocol41 > 0 { 94 data = append(data, dumpUint16(mysql.ServerStatusAutocommit)...) 95 data = append(data, 0, 0) 96 } 97 98 err := cc.writePacket(data) 99 cc.pkg.sequence = 0 100 if err != nil { 101 return errors.Trace(err) 102 } 103 104 return errors.Trace(cc.flush()) 105 } 106 107 func (cc *clientConn) Close() error { 108 cc.server.rwlock.Lock() 109 delete(cc.server.clients, cc.connectionID) 110 cc.server.rwlock.Unlock() 111 cc.conn.Close() 112 if cc.ctx != nil { 113 return cc.ctx.Close() 114 } 115 return nil 116 } 117 118 func (cc *clientConn) writeInitialHandshake() error { 119 data := make([]byte, 4, 128) 120 121 // min version 10 122 data = append(data, 10) 123 // server version[00] 124 data = append(data, mysql.ServerVersion...) 125 data = append(data, 0) 126 // connection id 127 data = append(data, byte(cc.connectionID), byte(cc.connectionID>>8), byte(cc.connectionID>>16), byte(cc.connectionID>>24)) 128 // auth-plugin-data-part-1 129 data = append(data, cc.salt[0:8]...) 130 // filler [00] 131 data = append(data, 0) 132 // capability flag lower 2 bytes, using default capability here 133 data = append(data, byte(defaultCapability), byte(defaultCapability>>8)) 134 // charset, utf-8 default 135 data = append(data, uint8(mysql.DefaultCollationID)) 136 //status 137 data = append(data, dumpUint16(mysql.ServerStatusAutocommit)...) 138 // below 13 byte may not be used 139 // capability flag upper 2 bytes, using default capability here 140 data = append(data, byte(defaultCapability>>16), byte(defaultCapability>>24)) 141 // filler [0x15], for wireshark dump, value is 0x15 142 data = append(data, 0x15) 143 // reserved 10 [00] 144 data = append(data, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0) 145 // auth-plugin-data-part-2 146 data = append(data, cc.salt[8:]...) 147 // filler [00] 148 data = append(data, 0) 149 err := cc.writePacket(data) 150 if err != nil { 151 return errors.Trace(err) 152 } 153 return errors.Trace(cc.flush()) 154 } 155 156 func (cc *clientConn) readPacket() ([]byte, error) { 157 return cc.pkg.readPacket() 158 } 159 160 func (cc *clientConn) writePacket(data []byte) error { 161 return cc.pkg.writePacket(data) 162 } 163 164 func (cc *clientConn) readHandshakeResponse() error { 165 data, err := cc.readPacket() 166 if err != nil { 167 return errors.Trace(err) 168 } 169 170 pos := 0 171 // capability 172 cc.capability = binary.LittleEndian.Uint32(data[:4]) 173 pos += 4 174 // skip max packet size 175 pos += 4 176 // charset, skip, if you want to use another charset, use set names 177 cc.collation = data[pos] 178 pos++ 179 // skip reserved 23[00] 180 pos += 23 181 // user name 182 cc.user = string(data[pos : pos+bytes.IndexByte(data[pos:], 0)]) 183 pos += len(cc.user) + 1 184 // auth length and auth 185 authLen := int(data[pos]) 186 pos++ 187 auth := data[pos : pos+authLen] 188 pos += authLen 189 if cc.capability&mysql.ClientConnectWithDB > 0 { 190 if len(data[pos:]) > 0 { 191 idx := bytes.IndexByte(data[pos:], 0) 192 cc.dbname = string(data[pos : pos+idx]) 193 } 194 } 195 // Open session and do auth 196 cc.ctx, err = cc.server.driver.OpenCtx(uint64(cc.connectionID), cc.capability, uint8(cc.collation), cc.dbname) 197 if err != nil { 198 cc.Close() 199 return errors.Trace(err) 200 } 201 if !cc.server.skipAuth() { 202 // Do Auth 203 addr := cc.conn.RemoteAddr().String() 204 host, _, err1 := net.SplitHostPort(addr) 205 if err1 != nil { 206 return errors.Trace(mysql.NewErr(mysql.ErrAccessDenied, cc.user, addr, "Yes")) 207 } 208 user := fmt.Sprintf("%s@%s", cc.user, host) 209 if !cc.ctx.Auth(user, auth, cc.salt) { 210 return errors.Trace(mysql.NewErr(mysql.ErrAccessDenied, cc.user, host, "Yes")) 211 } 212 } 213 return nil 214 } 215 216 func (cc *clientConn) Run() { 217 defer func() { 218 r := recover() 219 if r != nil { 220 const size = 4096 221 buf := make([]byte, size) 222 buf = buf[:runtime.Stack(buf, false)] 223 log.Errorf("lastCmd %s, %v, %s", cc.lastCmd, r, buf) 224 } 225 cc.Close() 226 }() 227 228 for { 229 cc.alloc.Reset() 230 data, err := cc.readPacket() 231 if err != nil { 232 if terror.ErrorNotEqual(err, io.EOF) { 233 log.Error(err) 234 } 235 return 236 } 237 238 if err := cc.dispatch(data); err != nil { 239 if terror.ErrorEqual(err, io.EOF) { 240 return 241 } 242 log.Errorf("dispatch error %s, %s", errors.ErrorStack(err), cc) 243 log.Errorf("cmd: %s", string(data[1:])) 244 cc.writeError(err) 245 } 246 247 cc.pkg.sequence = 0 248 } 249 } 250 251 func (cc *clientConn) dispatch(data []byte) error { 252 cmd := data[0] 253 data = data[1:] 254 cc.lastCmd = hack.String(data) 255 256 token := cc.server.getToken() 257 258 startTs := time.Now() 259 defer func() { 260 cc.server.releaseToken(token) 261 log.Debugf("[TIME_CMD] %v %d", time.Now().Sub(startTs), cmd) 262 }() 263 264 switch cmd { 265 case mysql.ComQuit: 266 return io.EOF 267 case mysql.ComQuery: 268 return cc.handleQuery(hack.String(data)) 269 case mysql.ComPing: 270 return cc.writeOK() 271 case mysql.ComInitDB: 272 log.Debug("init db", hack.String(data)) 273 if err := cc.useDB(hack.String(data)); err != nil { 274 return errors.Trace(err) 275 } 276 return cc.writeOK() 277 case mysql.ComFieldList: 278 return cc.handleFieldList(hack.String(data)) 279 case mysql.ComStmtPrepare: 280 return cc.handleStmtPrepare(hack.String(data)) 281 case mysql.ComStmtExecute: 282 return cc.handleStmtExecute(data) 283 case mysql.ComStmtClose: 284 return cc.handleStmtClose(data) 285 case mysql.ComStmtSendLongData: 286 return cc.handleStmtSendLongData(data) 287 case mysql.ComStmtReset: 288 return cc.handleStmtReset(data) 289 default: 290 return mysql.NewErrf(mysql.ErrUnknown, "command %d not supported now", cmd) 291 } 292 } 293 294 func (cc *clientConn) useDB(db string) (err error) { 295 _, err = cc.ctx.Execute("use " + db) 296 if err != nil { 297 return errors.Trace(err) 298 } 299 cc.dbname = db 300 return 301 } 302 303 func (cc *clientConn) flush() error { 304 return cc.pkg.flush() 305 } 306 307 func (cc *clientConn) writeOK() error { 308 data := cc.alloc.AllocWithLen(4, 32) 309 data = append(data, mysql.OKHeader) 310 data = append(data, dumpLengthEncodedInt(uint64(cc.ctx.AffectedRows()))...) 311 data = append(data, dumpLengthEncodedInt(uint64(cc.ctx.LastInsertID()))...) 312 if cc.capability&mysql.ClientProtocol41 > 0 { 313 data = append(data, dumpUint16(cc.ctx.Status())...) 314 data = append(data, dumpUint16(cc.ctx.WarningCount())...) 315 } 316 317 err := cc.writePacket(data) 318 if err != nil { 319 return errors.Trace(err) 320 } 321 322 return errors.Trace(cc.flush()) 323 } 324 325 func (cc *clientConn) writeError(e error) error { 326 var ( 327 m *mysql.SQLError 328 te *terror.Error 329 ok bool 330 ) 331 originErr := errors.Cause(e) 332 if te, ok = originErr.(*terror.Error); ok { 333 m = te.ToSQLError() 334 } else { 335 m = mysql.NewErrf(mysql.ErrUnknown, e.Error()) 336 } 337 338 data := make([]byte, 4, 16+len(m.Message)) 339 data = append(data, mysql.ErrHeader) 340 data = append(data, byte(m.Code), byte(m.Code>>8)) 341 if cc.capability&mysql.ClientProtocol41 > 0 { 342 data = append(data, '#') 343 data = append(data, m.State...) 344 } 345 346 data = append(data, m.Message...) 347 348 err := cc.writePacket(data) 349 if err != nil { 350 return errors.Trace(err) 351 } 352 return errors.Trace(cc.flush()) 353 } 354 355 func (cc *clientConn) writeEOF() error { 356 data := cc.alloc.AllocWithLen(4, 9) 357 358 data = append(data, mysql.EOFHeader) 359 if cc.capability&mysql.ClientProtocol41 > 0 { 360 data = append(data, dumpUint16(cc.ctx.WarningCount())...) 361 data = append(data, dumpUint16(cc.ctx.Status())...) 362 } 363 364 err := cc.writePacket(data) 365 return errors.Trace(err) 366 } 367 368 func (cc *clientConn) handleQuery(sql string) (err error) { 369 startTs := time.Now() 370 rs, err := cc.ctx.Execute(sql) 371 if err != nil { 372 return errors.Trace(err) 373 } 374 if rs != nil { 375 err = cc.writeResultset(rs, false) 376 } else { 377 err = cc.writeOK() 378 } 379 log.Debugf("[TIME_QUERY] %v %s", time.Now().Sub(startTs), sql) 380 return errors.Trace(err) 381 } 382 383 func (cc *clientConn) handleFieldList(sql string) (err error) { 384 parts := strings.Split(sql, "\x00") 385 columns, err := cc.ctx.FieldList(parts[0]) 386 if err != nil { 387 return errors.Trace(err) 388 } 389 data := make([]byte, 4, 1024) 390 for _, v := range columns { 391 data = data[0:4] 392 data = append(data, v.Dump(cc.alloc)...) 393 if err := cc.writePacket(data); err != nil { 394 return errors.Trace(err) 395 } 396 } 397 if err := cc.writeEOF(); err != nil { 398 return errors.Trace(err) 399 } 400 return errors.Trace(cc.flush()) 401 } 402 403 func (cc *clientConn) writeResultset(rs ResultSet, binary bool) error { 404 defer rs.Close() 405 // We need to call Next before we get columns. 406 // Otherwise, we will get incorrect columns info. 407 row, err := rs.Next() 408 if err != nil { 409 return errors.Trace(err) 410 } 411 412 columns, err := rs.Columns() 413 if err != nil { 414 return errors.Trace(err) 415 } 416 columnLen := dumpLengthEncodedInt(uint64(len(columns))) 417 data := cc.alloc.AllocWithLen(4, 1024) 418 data = append(data, columnLen...) 419 if err = cc.writePacket(data); err != nil { 420 return errors.Trace(err) 421 } 422 423 for _, v := range columns { 424 data = data[0:4] 425 data = append(data, v.Dump(cc.alloc)...) 426 if err = cc.writePacket(data); err != nil { 427 return errors.Trace(err) 428 } 429 } 430 431 if err = cc.writeEOF(); err != nil { 432 return errors.Trace(err) 433 } 434 435 for { 436 if err != nil { 437 return errors.Trace(err) 438 } 439 if row == nil { 440 break 441 } 442 data = data[0:4] 443 if binary { 444 var rowData []byte 445 rowData, err = dumpRowValuesBinary(cc.alloc, columns, row) 446 if err != nil { 447 return errors.Trace(err) 448 } 449 data = append(data, rowData...) 450 } else { 451 for i, value := range row { 452 if value.Kind() == types.KindNull { 453 data = append(data, 0xfb) 454 continue 455 } 456 var valData []byte 457 valData, err = dumpTextValue(columns[i].Type, value) 458 if err != nil { 459 return errors.Trace(err) 460 } 461 data = append(data, dumpLengthEncodedString(valData, cc.alloc)...) 462 } 463 } 464 465 if err = cc.writePacket(data); err != nil { 466 return errors.Trace(err) 467 } 468 row, err = rs.Next() 469 } 470 471 err = cc.writeEOF() 472 if err != nil { 473 return errors.Trace(err) 474 } 475 476 return errors.Trace(cc.flush()) 477 }