github.com/XiaoMi/Gaea@v1.2.5/proxy/server/session.go (about)

     1  // Copyright 2019 The Gaea Authors. All Rights Reserved.
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //     http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License.
    14  
    15  package server
    16  
    17  import (
    18  	"fmt"
    19  	"net"
    20  	"runtime"
    21  	"strings"
    22  	"sync"
    23  	"sync/atomic"
    24  
    25  	"github.com/XiaoMi/Gaea/log"
    26  	"github.com/XiaoMi/Gaea/mysql"
    27  	"github.com/XiaoMi/Gaea/util"
    28  )
    29  
    30  // DefaultCapability means default capability
    31  var DefaultCapability = mysql.ClientLongPassword | mysql.ClientLongFlag |
    32  	mysql.ClientConnectWithDB | mysql.ClientProtocol41 |
    33  	mysql.ClientTransactions | mysql.ClientSecureConnection
    34  
    35  //下面的会根据配置文件参数加进去
    36  //mysql.ClientPluginAuth
    37  
    38  var baseConnID uint32 = 10000
    39  
    40  const initClientConnStatus = mysql.ServerStatusAutocommit
    41  
    42  // Session means session between client and proxy
    43  type Session struct {
    44  	sync.Mutex
    45  
    46  	c     *ClientConn
    47  	proxy *Server
    48  
    49  	manager *Manager
    50  
    51  	namespace string
    52  
    53  	executor *SessionExecutor
    54  
    55  	closed atomic.Value
    56  }
    57  
    58  // create session between client<->proxy
    59  func newSession(s *Server, co net.Conn) *Session {
    60  	cc := new(Session)
    61  	tcpConn := co.(*net.TCPConn)
    62  
    63  	//SetNoDelay controls whether the operating system should delay packet transmission
    64  	// in hopes of sending fewer packets (Nagle's algorithm).
    65  	// The default is true (no delay),
    66  	// meaning that data is sent as soon as possible after a Write.
    67  	//I set this option false.
    68  	tcpConn.SetNoDelay(true)
    69  	cc.c = NewClientConn(mysql.NewConn(tcpConn), s.manager)
    70  	cc.proxy = s
    71  	cc.manager = s.manager
    72  
    73  	cc.c.SetConnectionID(atomic.AddUint32(&baseConnID, 1))
    74  	cc.c.proxy = s
    75  
    76  	cc.executor = newSessionExecutor(s.manager)
    77  	cc.executor.clientAddr = co.RemoteAddr().String()
    78  	cc.closed.Store(false)
    79  	return cc
    80  }
    81  
    82  func (cc *Session) getNamespace() *Namespace {
    83  	return cc.manager.GetNamespace(cc.namespace)
    84  }
    85  
    86  // IsAllowConnect check if allow to connect
    87  func (cc *Session) IsAllowConnect() bool {
    88  	ns := cc.getNamespace() // maybe nil, and panic!
    89  	clientHost, _, err := net.SplitHostPort(cc.c.RemoteAddr().String())
    90  	if err != nil {
    91  		log.Warn("[server] Session parse host error: %v", err)
    92  	}
    93  	clientIP := net.ParseIP(clientHost)
    94  
    95  	return ns.IsClientIPAllowed(clientIP)
    96  }
    97  
    98  // Handshake with client
    99  // step1: server send plain handshake packets to client
   100  // step2: client send handshake response packets to server
   101  // step3: server send ok/err packets to client
   102  func (cc *Session) Handshake() error {
   103  	// First build and send the server handshake packet.
   104  	if err := cc.c.writeInitialHandshakeV10(); err != nil {
   105  		clientHost, _, innerErr := net.SplitHostPort(cc.c.RemoteAddr().String())
   106  		if innerErr != nil {
   107  			log.Warn("[server] Session parse host error: %v", innerErr)
   108  		}
   109  		// filter lvs detect liveness
   110  		hostname, _ := util.HostName(clientHost)
   111  		if len(hostname) > 0 && strings.Contains(hostname, "lvs") {
   112  			return err
   113  		}
   114  
   115  		log.Warn("[server] Session writeInitialHandshake error, connId: %d, ip: %s, msg: %s, error: %s",
   116  			cc.c.GetConnectionID(), clientHost, " send initial handshake error", err.Error())
   117  		return err
   118  	}
   119  
   120  	info, err := cc.c.readHandshakeResponse()
   121  	if err != nil {
   122  		clientHost, _, innerErr := net.SplitHostPort(cc.c.RemoteAddr().String())
   123  		if innerErr != nil {
   124  			log.Warn("[server] Session parse host error: %v", innerErr)
   125  		}
   126  		// filter lvs detect liveness
   127  		hostname, _ := util.HostName(clientHost)
   128  		if len(hostname) > 0 && strings.Contains(hostname, "lvs") {
   129  			return err
   130  		}
   131  
   132  		log.Warn("[server] Session readHandshakeResponse error, connId: %d, ip: %s, msg: %s, error: %s",
   133  			cc.c.GetConnectionID(), clientHost, "read Handshake Response error", err.Error())
   134  		return err
   135  	}
   136  
   137  	if err := cc.handleHandshakeResponse(info); err != nil {
   138  		log.Warn("handleHandshakeResponse error, connId: %d, err: %v", cc.c.GetConnectionID(), err)
   139  		return err
   140  	}
   141  
   142  	if err := cc.c.writeOK(cc.executor.GetStatus()); err != nil {
   143  		log.Warn("[server] Session readHandshakeResponse error, connId %d, msg: %s, error: %s",
   144  			cc.c.GetConnectionID(), "write ok fail", err.Error())
   145  		return err
   146  	}
   147  
   148  	return nil
   149  }
   150  
   151  func (cc *Session) handleHandshakeResponse(info HandshakeResponseInfo) error {
   152  	// check and set user
   153  	var password string
   154  	var succ bool
   155  	user := info.User
   156  	if !cc.manager.CheckUser(user) {
   157  		return mysql.NewDefaultError(mysql.ErrAccessDenied, user, cc.c.RemoteAddr().String(), "Yes")
   158  	}
   159  	cc.executor.user = user
   160  
   161  	// check password
   162  	if len(info.AuthPlugin) == 0 {
   163  		if len(info.AuthResponse) == 32 {
   164  			succ, password = cc.manager.CheckSha2Password(user, info.Salt, info.AuthResponse)
   165  		} else {
   166  			succ, password = cc.manager.CheckPassword(user, info.Salt, info.AuthResponse)
   167  		}
   168  	} else if info.AuthPlugin == mysql.CachingSHA2Password {
   169  		succ, password = cc.manager.CheckSha2Password(user, info.Salt, info.AuthResponse)
   170  	} else {
   171  		succ, password = cc.manager.CheckPassword(user, info.Salt, info.AuthResponse)
   172  	}
   173  
   174  	if !succ {
   175  		return mysql.NewDefaultError(mysql.ErrAccessDenied, user, cc.c.RemoteAddr().String(), "Yes")
   176  	}
   177  
   178  	// handle collation
   179  	collationID := info.CollationID
   180  	collationName, ok := mysql.Collations[mysql.CollationID(collationID)]
   181  	if !ok {
   182  		return mysql.NewError(mysql.ErrInternal, "invalid collation")
   183  	}
   184  	charset, ok := mysql.CollationNameToCharset[collationName]
   185  	if !ok {
   186  		return mysql.NewError(mysql.ErrInternal, "invalid collation")
   187  	}
   188  	cc.executor.SetCollationID(mysql.CollationID(collationID))
   189  	cc.executor.SetCharset(charset)
   190  
   191  	// set database
   192  	cc.executor.SetDatabase(info.Database)
   193  
   194  	// set namespace
   195  	namespace := cc.manager.GetNamespaceByUser(user, password)
   196  	cc.namespace = namespace
   197  	cc.executor.namespace = namespace
   198  	cc.c.namespace = namespace // TODO: remove it when refactor is done
   199  	return nil
   200  }
   201  
   202  // Close close session with it's resources
   203  func (cc *Session) Close() {
   204  	if cc.IsClosed() {
   205  		return
   206  	}
   207  	cc.closed.Store(true)
   208  	if err := cc.executor.rollback(); err != nil {
   209  		log.Warn("executor rollback error when Session close: %v", err)
   210  	}
   211  	cc.c.Close()
   212  	log.Debug("client closed, %d", cc.c.GetConnectionID())
   213  
   214  	return
   215  }
   216  
   217  // IsClosed check if closed
   218  func (cc *Session) IsClosed() bool {
   219  	return cc.closed.Load().(bool)
   220  }
   221  
   222  // Run start session to server client request packets
   223  func (cc *Session) Run() {
   224  	defer func() {
   225  		r := recover()
   226  		if err, ok := r.(error); ok {
   227  			const size = 4096
   228  			buf := make([]byte, size)
   229  			buf = buf[:runtime.Stack(buf, false)]
   230  
   231  			log.Warn("[server] Session Run panic error, error: %s, stack: %s", err.Error(), string(buf))
   232  		}
   233  		cc.Close()
   234  		cc.proxy.tw.Remove(cc)
   235  		cc.manager.GetStatisticManager().DescSessionCount(cc.namespace)
   236  	}()
   237  
   238  	cc.manager.GetStatisticManager().IncrSessionCount(cc.namespace)
   239  
   240  	for !cc.IsClosed() {
   241  		cc.c.SetSequence(0)
   242  		data, err := cc.c.ReadEphemeralPacket()
   243  		if err != nil {
   244  			cc.c.RecycleReadPacket()
   245  			return
   246  		}
   247  
   248  		cc.proxy.tw.Add(cc.proxy.sessionTimeout, cc, cc.Close)
   249  		cc.manager.GetStatisticManager().AddReadFlowCount(cc.namespace, len(data))
   250  
   251  		cmd := data[0]
   252  		data = data[1:]
   253  		rs := cc.executor.ExecuteCommand(cmd, data)
   254  		cc.c.RecycleReadPacket()
   255  
   256  		if err = cc.writeResponse(rs); err != nil {
   257  			log.Warn("Session write response error, connId: %d, err: %v", cc.c.GetConnectionID(), err)
   258  			cc.Close()
   259  			return
   260  		}
   261  
   262  		if cmd == mysql.ComQuit {
   263  			cc.Close()
   264  		}
   265  	}
   266  }
   267  
   268  func (cc *Session) writeResponse(r Response) error {
   269  	switch r.RespType {
   270  	case RespEOF:
   271  		return cc.c.writeEOFPacket(r.Status)
   272  	case RespResult:
   273  		rs := r.Data.(*mysql.Result)
   274  		if rs == nil {
   275  			return cc.c.writeOK(r.Status)
   276  		}
   277  		return cc.c.writeOKResult(r.Status, r.Data.(*mysql.Result))
   278  	case RespPrepare:
   279  		stmt := r.Data.(*Stmt)
   280  		if stmt == nil {
   281  			return cc.c.writeOK(r.Status)
   282  		}
   283  		return cc.c.writePrepareResponse(r.Status, stmt)
   284  	case RespFieldList:
   285  		rs := r.Data.([]*mysql.Field)
   286  		if rs == nil {
   287  			return cc.c.writeOK(r.Status)
   288  		}
   289  		return cc.c.writeFieldList(r.Status, rs)
   290  	case RespError:
   291  		rs := r.Data.(error)
   292  		if rs == nil {
   293  			return cc.c.writeOK(r.Status)
   294  		}
   295  		err := cc.c.writeErrorPacket(rs)
   296  		if err != nil {
   297  			return err
   298  		}
   299  		if rs == mysql.ErrBadConn { // 后端连接如果断开, 应该返回通知Session关闭
   300  			return rs
   301  		}
   302  		return nil
   303  	case RespOK:
   304  		return cc.c.writeOK(r.Status)
   305  	case RespNoop:
   306  		return nil
   307  	default:
   308  		err := fmt.Errorf("invalid response type: %T", r)
   309  		log.Fatal(err.Error())
   310  		return cc.c.writeErrorPacket(err)
   311  	}
   312  }