github.com/matrixorigin/matrixone@v1.2.0/pkg/frontend/protocol.go (about)

     1  // Copyright 2021 Matrix Origin
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //      http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License.
    14  
    15  package frontend
    16  
    17  import (
    18  	"context"
    19  	"fmt"
    20  	"math"
    21  	"sync"
    22  	"sync/atomic"
    23  
    24  	"github.com/matrixorigin/matrixone/pkg/common/moerr"
    25  	"github.com/matrixorigin/matrixone/pkg/vm/process"
    26  
    27  	"github.com/fagongzi/goetty/v2"
    28  
    29  	"github.com/matrixorigin/matrixone/pkg/logutil"
    30  )
    31  
// Response categories. The value stored in Response.category selects which
// packet MysqlProtocolImpl.SendResponse writes to the client.
const (
	// OkResponse OK message
	OkResponse = iota
	// ErrorResponse Error message
	ErrorResponse
	// EoFResponse EOF message
	EoFResponse
	// ResultResponse result message
	ResultResponse
	// LocalInfileRequest local infile message
	LocalInfileRequest
)
    45  
// Request is a single command decoded from a client packet.
type Request struct {
	//the command type from the client
	cmd CommandType
	// sequence num
	seq uint8
	//the data from the client
	data interface{}
}
    54  
// GetData returns the payload carried by the request.
func (req *Request) GetData() interface{} {
	return req.data
}
    58  
// SetData replaces the payload carried by the request.
func (req *Request) SetData(data interface{}) {
	req.data = data
}
    62  
// GetCmd returns the command type of the request.
func (req *Request) GetCmd() CommandType {
	return req.cmd
}
    66  
// SetCmd replaces the command type of the request.
func (req *Request) SetCmd(cmd CommandType) {
	req.cmd = cmd
}
    70  
// Response describes what should be sent back to the client for a request.
// The category field decides which of the remaining fields are consulted
// when the response is sent.
type Response struct {
	//the category of the response
	category int
	//the status of executing the peer request
	status uint16
	//the command type which generates the response
	cmd int
	//the data of the response
	data interface{}

	/*
		ok response
	*/
	affectedRows, lastInsertId uint64
	warnings                   uint16
}
    87  
    88  func NewResponse(category int, affectedRows, lastInsertId uint64, warnings, status uint16, cmd int, d interface{}) *Response {
    89  	return &Response{
    90  		category:     category,
    91  		affectedRows: affectedRows,
    92  		lastInsertId: lastInsertId,
    93  		warnings:     warnings,
    94  		status:       status,
    95  		cmd:          cmd,
    96  		data:         d,
    97  	}
    98  }
    99  
// NewGeneralErrorResponse builds an ErrorResponse for cmd that carries err as
// its data, with zero affected rows, insert id and warnings.
func NewGeneralErrorResponse(cmd CommandType, status uint16, err error) *Response {
	return NewResponse(ErrorResponse, 0, 0, 0, status, int(cmd), err)
}
   103  
// NewGeneralOkResponse builds an OkResponse for cmd with no data and zero
// affected rows, insert id and warnings.
func NewGeneralOkResponse(cmd CommandType, status uint16) *Response {
	return NewResponse(OkResponse, 0, 0, 0, status, int(cmd), nil)
}
   107  
// NewOkResponse builds an OkResponse with the given counters, status, command
// and data.
func NewOkResponse(affectedRows, lastInsertId uint64, warnings, status uint16, cmd int, d interface{}) *Response {
	return NewResponse(OkResponse, affectedRows, lastInsertId, warnings, status, cmd, d)
}
   111  
// GetData returns the data attached to the response.
func (resp *Response) GetData() interface{} {
	return resp.data
}
   115  
// SetData replaces the data attached to the response.
func (resp *Response) SetData(data interface{}) {
	resp.data = data
}
   119  
// GetStatus returns the status value of the response.
func (resp *Response) GetStatus() uint16 {
	return resp.status
}
   123  
// SetStatus replaces the status value of the response.
func (resp *Response) SetStatus(status uint16) {
	resp.status = status
}
   127  
// GetCategory returns the response category (one of the *Response /
// LocalInfileRequest constants).
func (resp *Response) GetCategory() int {
	return resp.category
}
   131  
// SetCategory replaces the response category.
func (resp *Response) SetCategory(category int) {
	resp.category = category
}
   135  
// Protocol declares the per-connection operations exposed by a client
// protocol: handshake and authentication state, decoding requests, sending
// responses, and connection attributes such as user, database, capability
// and sequence id.
type Protocol interface {
	IsEstablished() bool

	SetEstablished()

	// GetRequest gets Request from Packet
	GetRequest(payload []byte) *Request

	// SendResponse sends a response to the client for the application request
	SendResponse(context.Context, *Response) error

	// ConnectionID the identity of the client
	ConnectionID() uint32

	// Peer gets the address [Host:Port,Host:Port] of the client and the server
	Peer() string

	GetDatabaseName() string

	SetDatabaseName(string)

	GetUserName() string

	SetUserName(string)

	GetSequenceId() uint8

	SetSequenceID(value uint8)

	GetDebugString() string

	GetTcpConnection() goetty.IOSession

	GetCapability() uint32

	SetCapability(uint32)

	GetConnectAttrs() map[string]string

	IsTlsEstablished() bool

	SetTlsEstablished()

	HandleHandshake(ctx context.Context, payload []byte) (bool, error)

	Authenticate(ctx context.Context) error

	SendPrepareResponse(ctx context.Context, stmt *PrepareStmt) error

	Quit()

	// incDebugCount and resetDebugCount maintain debug-only counters.
	incDebugCount(int)

	resetDebugCount() []uint64

	UpdateCtx(context.Context)
}
   193  
// ProtocolImpl holds the connection-level state shared by protocol
// implementations in this file.
type ProtocolImpl struct {
	// m is taken by GetDebugString, GetSalt, SetSalt and by
	// MysqlProtocolImpl.SendResponse for the duration of a send.
	m sync.Mutex

	io IOPackage

	tcpConn goetty.IOSession

	// set by Quit; makes subsequent Quit calls no-ops
	quit atomic.Bool

	//random bytes
	salt []byte

	//the id of the connection
	connectionID uint32

	// whether the handshake succeeded
	established atomic.Bool

	// whether the tls handshake succeeded
	tlsEstablished atomic.Bool

	//The sequence-id is incremented with each packet and may wrap around.
	//It starts at 0 and is reset to 0 when a new command begins in the Command Phase.
	sequenceId atomic.Uint32

	//for debug
	debugCount [16]uint64

	ctx context.Context
}
   224  
// UpdateCtx replaces the context stored on the protocol. The assignment is
// not guarded by pi.m.
func (pi *ProtocolImpl) UpdateCtx(ctx context.Context) {
	pi.ctx = ctx
}
   228  
   229  func (pi *ProtocolImpl) incDebugCount(i int) {
   230  	if i >= 0 && i < len(pi.debugCount) {
   231  		atomic.AddUint64(&pi.debugCount[i], 1)
   232  	}
   233  }
   234  
   235  func (pi *ProtocolImpl) resetDebugCount() []uint64 {
   236  	ret := make([]uint64, len(pi.debugCount))
   237  	for i := 0; i < len(pi.debugCount); i++ {
   238  		ret[i] = atomic.LoadUint64(&pi.debugCount[i])
   239  	}
   240  	return ret
   241  }
   242  
// setQuit atomically stores b in the quit flag and returns the previous value.
func (pi *ProtocolImpl) setQuit(b bool) bool {
	return pi.quit.Swap(b)
}
   246  
// GetSequenceId returns the current sequence id, truncated to its low 8 bits.
func (pi *ProtocolImpl) GetSequenceId() uint8 {
	return uint8(pi.sequenceId.Load())
}
   250  
   251  func (pi *ProtocolImpl) getDebugStringUnsafe() string {
   252  	if pi.tcpConn != nil {
   253  		return fmt.Sprintf("connectionId %d|%s", pi.connectionID, pi.tcpConn.RemoteAddress())
   254  	}
   255  	return ""
   256  }
   257  
// GetDebugString returns the debug string for the connection while holding
// pi.m. It must not be called by a goroutine that already holds pi.m.
func (pi *ProtocolImpl) GetDebugString() string {
	pi.m.Lock()
	defer pi.m.Unlock()
	return pi.getDebugStringUnsafe()
}
   263  
// GetSalt returns the salt slice under pi.m. The slice is returned without
// copying, so the caller shares its backing array.
func (pi *ProtocolImpl) GetSalt() []byte {
	pi.m.Lock()
	defer pi.m.Unlock()
	return pi.salt
}
   269  
// SetSalt updates the salt value. This happens with proxy mode enabled.
// The slice is stored without copying, under pi.m.
func (pi *ProtocolImpl) SetSalt(s []byte) {
	pi.m.Lock()
	defer pi.m.Unlock()
	pi.salt = s
}
   276  
// IsEstablished reports whether the established flag has been set.
func (pi *ProtocolImpl) IsEstablished() bool {
	return pi.established.Load()
}
   280  
   281  func (pi *ProtocolImpl) SetEstablished() {
   282  	logDebugf(pi.GetDebugString(), "SWITCH ESTABLISHED to true")
   283  	pi.established.Store(true)
   284  }
   285  
// IsTlsEstablished reports whether the tlsEstablished flag has been set.
func (pi *ProtocolImpl) IsTlsEstablished() bool {
	return pi.tlsEstablished.Load()
}
   289  
// SetTlsEstablished logs the transition at debug level and then sets the
// tlsEstablished flag. Unlike SetEstablished, the log line carries no
// connection prefix and pi.m is not taken.
func (pi *ProtocolImpl) SetTlsEstablished() {
	logutil.Debugf("SWITCH TLS_ESTABLISHED to true")
	pi.tlsEstablished.Store(true)
}
   294  
// ConnectionID returns the id of the connection.
func (pi *ProtocolImpl) ConnectionID() uint32 {
	return pi.connectionID
}
   298  
   299  // Quit kill tcpConn still connected.
   300  // before calling NewMysqlClientProtocol, tcpConn.Connected() must be true
   301  // please check goetty/application.go::doStart() and goetty/application.go::NewIOSession(...) for details
   302  func (pi *ProtocolImpl) Quit() {
   303  	//if it was quit, do nothing
   304  	if pi.setQuit(true) {
   305  		return
   306  	}
   307  	if pi.tcpConn != nil {
   308  		if err := pi.tcpConn.Disconnect(); err != nil {
   309  			return
   310  		}
   311  	}
   312  	//release salt
   313  	if pi.salt != nil {
   314  		pi.salt = nil
   315  	}
   316  }
   317  
// GetTcpConnection returns the underlying goetty session; it may be nil.
func (pi *ProtocolImpl) GetTcpConnection() goetty.IOSession {
	return pi.tcpConn
}
   321  
   322  func (pi *ProtocolImpl) Peer() string {
   323  	tcp := pi.GetTcpConnection()
   324  	if tcp == nil {
   325  		return ""
   326  	}
   327  	return tcp.RemoteAddress()
   328  }
   329  
   330  func (mp *MysqlProtocolImpl) GetRequest(payload []byte) *Request {
   331  	req := &Request{
   332  		cmd:  CommandType(payload[0]),
   333  		data: payload[1:],
   334  	}
   335  
   336  	return req
   337  }
   338  
   339  func (mp *MysqlProtocolImpl) SendResponse(ctx context.Context, resp *Response) error {
   340  	//move here to prohibit potential recursive lock
   341  	var attachAbort string
   342  
   343  	mp.m.Lock()
   344  	defer mp.m.Unlock()
   345  
   346  	switch resp.category {
   347  	case OkResponse:
   348  		s, ok := resp.data.(string)
   349  		if !ok {
   350  			return mp.sendOKPacket(resp.affectedRows, resp.lastInsertId, uint16(resp.status), resp.warnings, "")
   351  		}
   352  		return mp.sendOKPacket(resp.affectedRows, resp.lastInsertId, uint16(resp.status), resp.warnings, s)
   353  	case EoFResponse:
   354  		return mp.sendEOFPacket(0, uint16(resp.status))
   355  	case ErrorResponse:
   356  		err := resp.data.(error)
   357  		if err == nil {
   358  			return mp.sendOKPacket(0, 0, uint16(resp.status), 0, "")
   359  		}
   360  		switch myerr := err.(type) {
   361  		case *moerr.Error:
   362  			var code uint16
   363  			if myerr.MySQLCode() != moerr.ER_UNKNOWN_ERROR {
   364  				code = myerr.MySQLCode()
   365  			} else {
   366  				code = myerr.ErrorCode()
   367  			}
   368  			errMsg := myerr.Error()
   369  			if attachAbort != "" {
   370  				errMsg = fmt.Sprintf("%s\n%s", myerr.Error(), attachAbort)
   371  			}
   372  			return mp.sendErrPacket(code, myerr.SqlState(), errMsg)
   373  		}
   374  		errMsg := ""
   375  		if attachAbort != "" {
   376  			errMsg = fmt.Sprintf("%s\n%s", err, attachAbort)
   377  		} else {
   378  			errMsg = fmt.Sprintf("%v", err)
   379  		}
   380  		return mp.sendErrPacket(moerr.ER_UNKNOWN_ERROR, DefaultMySQLState, errMsg)
   381  	case ResultResponse:
   382  		mer := resp.data.(*MysqlExecutionResult)
   383  		if mer == nil {
   384  			return mp.sendOKPacket(0, 0, uint16(resp.status), 0, "")
   385  		}
   386  		if mer.Mrs() == nil {
   387  			return mp.sendOKPacket(mer.AffectedRows(), mer.InsertID(), uint16(resp.status), mer.Warnings(), "")
   388  		}
   389  		return mp.sendResultSet(ctx, mer.Mrs(), resp.cmd, mer.Warnings(), uint16(resp.status))
   390  	case LocalInfileRequest:
   391  		s, _ := resp.data.(string)
   392  		return mp.sendLocalInfileRequest(s)
   393  	default:
   394  		return moerr.NewInternalError(ctx, "unsupported response:%d ", resp.category)
   395  	}
   396  }
   397  
// DisableAutoFlush sets the disableAutoFlush flag. The write is not guarded
// by mp.m.
func (mp *MysqlProtocolImpl) DisableAutoFlush() {
	mp.disableAutoFlush = true
}
   401  
// EnableAutoFlush clears the disableAutoFlush flag. The write is not guarded
// by mp.m.
func (mp *MysqlProtocolImpl) EnableAutoFlush() {
	mp.disableAutoFlush = false
}
   405  
// Flush currently does nothing and always returns nil.
func (mp *MysqlProtocolImpl) Flush() error {
	return nil
}
   409  
// Compile-time check that *FakeProtocol implements MysqlProtocol.
var _ MysqlProtocol = &FakeProtocol{}

const (
	// fakeConnectionID is the connection id reported by FakeProtocol.
	fakeConnectionID uint32 = math.MaxUint32
)
   415  
// FakeProtocol works for the background transaction that does not use the network protocol.
// Apart from the user and database names, which it stores, its methods are
// no-ops or return fixed values.
type FakeProtocol struct {
	username string
	database string
	ioses    goetty.IOSession
}
   422  
// UpdateCtx is a no-op; FakeProtocol keeps no context.
func (fp *FakeProtocol) UpdateCtx(ctx context.Context) {

}
   426  
// GetCapability always returns DefaultCapability.
func (fp *FakeProtocol) GetCapability() uint32 {
	return DefaultCapability
}
   430  
// SetCapability is a no-op; the argument is discarded.
func (fp *FakeProtocol) SetCapability(uint32) {

}
   434  
// IsTlsEstablished always reports true.
func (fp *FakeProtocol) IsTlsEstablished() bool {
	return true
}
   438  
// SetTlsEstablished is a no-op.
func (fp *FakeProtocol) SetTlsEstablished() {

}
   442  
// HandleHandshake ignores its arguments and always returns (false, nil).
func (fp *FakeProtocol) HandleHandshake(ctx context.Context, payload []byte) (bool, error) {
	return false, nil
}
   446  
// Authenticate performs no check and always returns nil.
func (fp *FakeProtocol) Authenticate(ctx context.Context) error {
	return nil
}
   450  
// GetTcpConnection returns the stored session, which may be nil.
func (fp *FakeProtocol) GetTcpConnection() goetty.IOSession {
	return fp.ioses
}
   454  
// GetDebugString returns a fixed marker identifying the fake protocol.
func (fp *FakeProtocol) GetDebugString() string {
	return "fake protocol"
}
   458  
// GetSequenceId always returns 0.
func (fp *FakeProtocol) GetSequenceId() uint8 {
	return 0
}
   462  
// SetSequenceID is a no-op; the value is discarded.
func (fp *FakeProtocol) SetSequenceID(value uint8) {
}
   465  
// GetConnectAttrs always returns nil.
func (fp *FakeProtocol) GetConnectAttrs() map[string]string {
	return nil
}
   469  
// SendPrepareResponse sends nothing and always returns nil.
func (fp *FakeProtocol) SendPrepareResponse(ctx context.Context, stmt *PrepareStmt) error {
	return nil
}
   473  
// ParseSendLongData parses nothing and always returns nil.
func (fp *FakeProtocol) ParseSendLongData(ctx context.Context, proc *process.Process, stmt *PrepareStmt, data []byte, pos int) error {
	return nil
}
   477  
// ParseExecuteData parses nothing and always returns nil.
func (fp *FakeProtocol) ParseExecuteData(ctx context.Context, proc *process.Process, stmt *PrepareStmt, data []byte, pos int) error {
	return nil
}
   481  
// SendResultSetTextBatchRow sends nothing and always returns nil.
func (fp *FakeProtocol) SendResultSetTextBatchRow(mrs *MysqlResultSet, cnt uint64) error {
	return nil
}
   485  
// SendResultSetTextBatchRowSpeedup sends nothing and always returns nil.
func (fp *FakeProtocol) SendResultSetTextBatchRowSpeedup(mrs *MysqlResultSet, cnt uint64) error {
	return nil
}
   489  
// SendColumnDefinitionPacket sends nothing and always returns nil.
func (fp *FakeProtocol) SendColumnDefinitionPacket(ctx context.Context, column Column, cmd int) error {
	return nil
}
   493  
// SendColumnCountPacket sends nothing and always returns nil.
func (fp *FakeProtocol) SendColumnCountPacket(count uint64) error {
	return nil
}
   497  
// SendEOFPacketIf sends nothing and always returns nil.
func (fp *FakeProtocol) SendEOFPacketIf(warnings uint16, status uint16) error {
	return nil
}
   501  
// sendOKPacket sends nothing and always returns nil.
func (fp *FakeProtocol) sendOKPacket(affectedRows uint64, lastInsertId uint64, status uint16, warnings uint16, message string) error {
	return nil
}
   505  
// sendEOFOrOkPacket sends nothing and always returns nil.
func (fp *FakeProtocol) sendEOFOrOkPacket(warnings uint16, status uint16) error {
	return nil
}
   509  
// ResetStatistics is a no-op.
func (fp *FakeProtocol) ResetStatistics() {}
   511  
// GetStats always returns the empty string.
func (fp *FakeProtocol) GetStats() string {
	return ""
}
   515  
// CalculateOutTrafficBytes ignores reset and always returns (0, 0).
func (fp *FakeProtocol) CalculateOutTrafficBytes(reset bool) (int64, int64) { return 0, 0 }
   517  
// IsEstablished always reports true.
func (fp *FakeProtocol) IsEstablished() bool {
	return true
}
   521  
// SetEstablished is a no-op.
func (fp *FakeProtocol) SetEstablished() {}
   523  
// GetRequest ignores payload and always returns nil; callers must not
// dereference the result.
func (fp *FakeProtocol) GetRequest(payload []byte) *Request {
	return nil
}
   527  
// SendResponse discards resp and always returns nil.
func (fp *FakeProtocol) SendResponse(ctx context.Context, resp *Response) error {
	return nil
}
   531  
// ConnectionID always returns fakeConnectionID.
func (fp *FakeProtocol) ConnectionID() uint32 {
	return fakeConnectionID
}
   535  
// Peer returns a fixed placeholder address.
func (fp *FakeProtocol) Peer() string {
	return "0.0.0.0:0"
}
   539  
// GetDatabaseName returns the database name last set with SetDatabaseName.
func (fp *FakeProtocol) GetDatabaseName() string {
	return fp.database
}
   543  
// SetDatabaseName stores the database name.
func (fp *FakeProtocol) SetDatabaseName(s string) {
	fp.database = s
}
   547  
// GetUserName returns the user name last set with SetUserName.
func (fp *FakeProtocol) GetUserName() string {
	return fp.username
}
   551  
// SetUserName stores the user name.
func (fp *FakeProtocol) SetUserName(s string) {
	fp.username = s
}
   555  
// Quit is a no-op; the stored session is not touched.
func (fp *FakeProtocol) Quit() {}
   557  
// sendLocalInfileRequest sends nothing and always returns nil.
func (fp *FakeProtocol) sendLocalInfileRequest(filename string) error {
	return nil
}
   561  
// incDebugCount is a no-op; FakeProtocol keeps no debug counters.
func (fp *FakeProtocol) incDebugCount(int) {}
   563  
// resetDebugCount always returns nil.
func (fp *FakeProtocol) resetDebugCount() []uint64 {
	return nil
}
   567  
// DisableAutoFlush is a no-op.
func (fp *FakeProtocol) DisableAutoFlush() {
}
   570  
// EnableAutoFlush is a no-op.
func (fp *FakeProtocol) EnableAutoFlush() {
}
   573  
// Flush does nothing and always returns nil.
func (fp *FakeProtocol) Flush() error {
	return nil
}