github.com/XiaoMi/Gaea@v1.2.5/proxy/server/session.go (about) 1 // Copyright 2019 The Gaea Authors. All Rights Reserved. 2 // 3 // Licensed under the Apache License, Version 2.0 (the "License"); 4 // you may not use this file except in compliance with the License. 5 // You may obtain a copy of the License at 6 // 7 // http://www.apache.org/licenses/LICENSE-2.0 8 // 9 // Unless required by applicable law or agreed to in writing, software 10 // distributed under the License is distributed on an "AS IS" BASIS, 11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 // See the License for the specific language governing permissions and 13 // limitations under the License. 14 15 package server 16 17 import ( 18 "fmt" 19 "net" 20 "runtime" 21 "strings" 22 "sync" 23 "sync/atomic" 24 25 "github.com/XiaoMi/Gaea/log" 26 "github.com/XiaoMi/Gaea/mysql" 27 "github.com/XiaoMi/Gaea/util" 28 ) 29 30 // DefaultCapability means default capability 31 var DefaultCapability = mysql.ClientLongPassword | mysql.ClientLongFlag | 32 mysql.ClientConnectWithDB | mysql.ClientProtocol41 | 33 mysql.ClientTransactions | mysql.ClientSecureConnection 34 35 //下面的会根据配置文件参数加进去 36 //mysql.ClientPluginAuth 37 38 var baseConnID uint32 = 10000 39 40 const initClientConnStatus = mysql.ServerStatusAutocommit 41 42 // Session means session between client and proxy 43 type Session struct { 44 sync.Mutex 45 46 c *ClientConn 47 proxy *Server 48 49 manager *Manager 50 51 namespace string 52 53 executor *SessionExecutor 54 55 closed atomic.Value 56 } 57 58 // create session between client<->proxy 59 func newSession(s *Server, co net.Conn) *Session { 60 cc := new(Session) 61 tcpConn := co.(*net.TCPConn) 62 63 //SetNoDelay controls whether the operating system should delay packet transmission 64 // in hopes of sending fewer packets (Nagle's algorithm). 65 // The default is true (no delay), 66 // meaning that data is sent as soon as possible after a Write. 67 //I set this option false. 68 tcpConn.SetNoDelay(true) 69 cc.c = NewClientConn(mysql.NewConn(tcpConn), s.manager) 70 cc.proxy = s 71 cc.manager = s.manager 72 73 cc.c.SetConnectionID(atomic.AddUint32(&baseConnID, 1)) 74 cc.c.proxy = s 75 76 cc.executor = newSessionExecutor(s.manager) 77 cc.executor.clientAddr = co.RemoteAddr().String() 78 cc.closed.Store(false) 79 return cc 80 } 81 82 func (cc *Session) getNamespace() *Namespace { 83 return cc.manager.GetNamespace(cc.namespace) 84 } 85 86 // IsAllowConnect check if allow to connect 87 func (cc *Session) IsAllowConnect() bool { 88 ns := cc.getNamespace() // maybe nil, and panic! 89 clientHost, _, err := net.SplitHostPort(cc.c.RemoteAddr().String()) 90 if err != nil { 91 log.Warn("[server] Session parse host error: %v", err) 92 } 93 clientIP := net.ParseIP(clientHost) 94 95 return ns.IsClientIPAllowed(clientIP) 96 } 97 98 // Handshake with client 99 // step1: server send plain handshake packets to client 100 // step2: client send handshake response packets to server 101 // step3: server send ok/err packets to client 102 func (cc *Session) Handshake() error { 103 // First build and send the server handshake packet. 104 if err := cc.c.writeInitialHandshakeV10(); err != nil { 105 clientHost, _, innerErr := net.SplitHostPort(cc.c.RemoteAddr().String()) 106 if innerErr != nil { 107 log.Warn("[server] Session parse host error: %v", innerErr) 108 } 109 // filter lvs detect liveness 110 hostname, _ := util.HostName(clientHost) 111 if len(hostname) > 0 && strings.Contains(hostname, "lvs") { 112 return err 113 } 114 115 log.Warn("[server] Session writeInitialHandshake error, connId: %d, ip: %s, msg: %s, error: %s", 116 cc.c.GetConnectionID(), clientHost, " send initial handshake error", err.Error()) 117 return err 118 } 119 120 info, err := cc.c.readHandshakeResponse() 121 if err != nil { 122 clientHost, _, innerErr := net.SplitHostPort(cc.c.RemoteAddr().String()) 123 if innerErr != nil { 124 log.Warn("[server] Session parse host error: %v", innerErr) 125 } 126 // filter lvs detect liveness 127 hostname, _ := util.HostName(clientHost) 128 if len(hostname) > 0 && strings.Contains(hostname, "lvs") { 129 return err 130 } 131 132 log.Warn("[server] Session readHandshakeResponse error, connId: %d, ip: %s, msg: %s, error: %s", 133 cc.c.GetConnectionID(), clientHost, "read Handshake Response error", err.Error()) 134 return err 135 } 136 137 if err := cc.handleHandshakeResponse(info); err != nil { 138 log.Warn("handleHandshakeResponse error, connId: %d, err: %v", cc.c.GetConnectionID(), err) 139 return err 140 } 141 142 if err := cc.c.writeOK(cc.executor.GetStatus()); err != nil { 143 log.Warn("[server] Session readHandshakeResponse error, connId %d, msg: %s, error: %s", 144 cc.c.GetConnectionID(), "write ok fail", err.Error()) 145 return err 146 } 147 148 return nil 149 } 150 151 func (cc *Session) handleHandshakeResponse(info HandshakeResponseInfo) error { 152 // check and set user 153 var password string 154 var succ bool 155 user := info.User 156 if !cc.manager.CheckUser(user) { 157 return mysql.NewDefaultError(mysql.ErrAccessDenied, user, cc.c.RemoteAddr().String(), "Yes") 158 } 159 cc.executor.user = user 160 161 // check password 162 if len(info.AuthPlugin) == 0 { 163 if len(info.AuthResponse) == 32 { 164 succ, password = cc.manager.CheckSha2Password(user, info.Salt, info.AuthResponse) 165 } else { 166 succ, password = cc.manager.CheckPassword(user, info.Salt, info.AuthResponse) 167 } 168 } else if info.AuthPlugin == mysql.CachingSHA2Password { 169 succ, password = cc.manager.CheckSha2Password(user, info.Salt, info.AuthResponse) 170 } else { 171 succ, password = cc.manager.CheckPassword(user, info.Salt, info.AuthResponse) 172 } 173 174 if !succ { 175 return mysql.NewDefaultError(mysql.ErrAccessDenied, user, cc.c.RemoteAddr().String(), "Yes") 176 } 177 178 // handle collation 179 collationID := info.CollationID 180 collationName, ok := mysql.Collations[mysql.CollationID(collationID)] 181 if !ok { 182 return mysql.NewError(mysql.ErrInternal, "invalid collation") 183 } 184 charset, ok := mysql.CollationNameToCharset[collationName] 185 if !ok { 186 return mysql.NewError(mysql.ErrInternal, "invalid collation") 187 } 188 cc.executor.SetCollationID(mysql.CollationID(collationID)) 189 cc.executor.SetCharset(charset) 190 191 // set database 192 cc.executor.SetDatabase(info.Database) 193 194 // set namespace 195 namespace := cc.manager.GetNamespaceByUser(user, password) 196 cc.namespace = namespace 197 cc.executor.namespace = namespace 198 cc.c.namespace = namespace // TODO: remove it when refactor is done 199 return nil 200 } 201 202 // Close close session with it's resources 203 func (cc *Session) Close() { 204 if cc.IsClosed() { 205 return 206 } 207 cc.closed.Store(true) 208 if err := cc.executor.rollback(); err != nil { 209 log.Warn("executor rollback error when Session close: %v", err) 210 } 211 cc.c.Close() 212 log.Debug("client closed, %d", cc.c.GetConnectionID()) 213 214 return 215 } 216 217 // IsClosed check if closed 218 func (cc *Session) IsClosed() bool { 219 return cc.closed.Load().(bool) 220 } 221 222 // Run start session to server client request packets 223 func (cc *Session) Run() { 224 defer func() { 225 r := recover() 226 if err, ok := r.(error); ok { 227 const size = 4096 228 buf := make([]byte, size) 229 buf = buf[:runtime.Stack(buf, false)] 230 231 log.Warn("[server] Session Run panic error, error: %s, stack: %s", err.Error(), string(buf)) 232 } 233 cc.Close() 234 cc.proxy.tw.Remove(cc) 235 cc.manager.GetStatisticManager().DescSessionCount(cc.namespace) 236 }() 237 238 cc.manager.GetStatisticManager().IncrSessionCount(cc.namespace) 239 240 for !cc.IsClosed() { 241 cc.c.SetSequence(0) 242 data, err := cc.c.ReadEphemeralPacket() 243 if err != nil { 244 cc.c.RecycleReadPacket() 245 return 246 } 247 248 cc.proxy.tw.Add(cc.proxy.sessionTimeout, cc, cc.Close) 249 cc.manager.GetStatisticManager().AddReadFlowCount(cc.namespace, len(data)) 250 251 cmd := data[0] 252 data = data[1:] 253 rs := cc.executor.ExecuteCommand(cmd, data) 254 cc.c.RecycleReadPacket() 255 256 if err = cc.writeResponse(rs); err != nil { 257 log.Warn("Session write response error, connId: %d, err: %v", cc.c.GetConnectionID(), err) 258 cc.Close() 259 return 260 } 261 262 if cmd == mysql.ComQuit { 263 cc.Close() 264 } 265 } 266 } 267 268 func (cc *Session) writeResponse(r Response) error { 269 switch r.RespType { 270 case RespEOF: 271 return cc.c.writeEOFPacket(r.Status) 272 case RespResult: 273 rs := r.Data.(*mysql.Result) 274 if rs == nil { 275 return cc.c.writeOK(r.Status) 276 } 277 return cc.c.writeOKResult(r.Status, r.Data.(*mysql.Result)) 278 case RespPrepare: 279 stmt := r.Data.(*Stmt) 280 if stmt == nil { 281 return cc.c.writeOK(r.Status) 282 } 283 return cc.c.writePrepareResponse(r.Status, stmt) 284 case RespFieldList: 285 rs := r.Data.([]*mysql.Field) 286 if rs == nil { 287 return cc.c.writeOK(r.Status) 288 } 289 return cc.c.writeFieldList(r.Status, rs) 290 case RespError: 291 rs := r.Data.(error) 292 if rs == nil { 293 return cc.c.writeOK(r.Status) 294 } 295 err := cc.c.writeErrorPacket(rs) 296 if err != nil { 297 return err 298 } 299 if rs == mysql.ErrBadConn { // 后端连接如果断开, 应该返回通知Session关闭 300 return rs 301 } 302 return nil 303 case RespOK: 304 return cc.c.writeOK(r.Status) 305 case RespNoop: 306 return nil 307 default: 308 err := fmt.Errorf("invalid response type: %T", r) 309 log.Fatal(err.Error()) 310 return cc.c.writeErrorPacket(err) 311 } 312 }