github.com/matrixorigin/matrixone@v0.7.0/pkg/frontend/protocol.go (about)

     1  // Copyright 2021 Matrix Origin
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //      http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License.
    14  
    15  package frontend
    16  
    17  import (
    18  	"bytes"
    19  	"context"
    20  	"fmt"
    21  	"math"
    22  	"net"
    23  	"sync"
    24  	"sync/atomic"
    25  
    26  	"github.com/matrixorigin/matrixone/pkg/common/moerr"
    27  
    28  	"github.com/fagongzi/goetty/v2"
    29  	"github.com/matrixorigin/matrixone/pkg/logutil"
    30  )
    31  
    32  // Response Categories
    33  const (
    34  	// OkResponse OK message
    35  	OkResponse = iota
    36  	// ErrorResponse Error message
    37  	ErrorResponse
    38  	// EoFResponse EOF message
    39  	EoFResponse
    40  	// ResultResponse result message
    41  	ResultResponse
    42  	// LocalInfileRequest local infile message
    43  	LocalInfileRequest
    44  )
    45  
    46  type Request struct {
    47  	//the command type from the client
    48  	cmd CommandType
    49  	// sequence num
    50  	seq uint8
    51  	//the data from the client
    52  	data interface{}
    53  }
    54  
    55  func (req *Request) GetData() interface{} {
    56  	return req.data
    57  }
    58  
    59  func (req *Request) SetData(data interface{}) {
    60  	req.data = data
    61  }
    62  
    63  func (req *Request) GetCmd() CommandType {
    64  	return req.cmd
    65  }
    66  
    67  func (req *Request) SetCmd(cmd CommandType) {
    68  	req.cmd = cmd
    69  }
    70  
    71  type Response struct {
    72  	//the category of the response
    73  	category int
    74  	//the status of executing the peer request
    75  	status uint16
    76  	//the command type which generates the response
    77  	cmd int
    78  	//the data of the response
    79  	data interface{}
    80  
    81  	/*
    82  		ok response
    83  	*/
    84  	affectedRows, lastInsertId uint64
    85  	warnings                   uint16
    86  }
    87  
    88  func NewResponse(category int, status uint16, cmd int, d interface{}) *Response {
    89  	return &Response{
    90  		category: category,
    91  		status:   status,
    92  		cmd:      cmd,
    93  		data:     d,
    94  	}
    95  }
    96  
    97  func NewGeneralErrorResponse(cmd CommandType, err error) *Response {
    98  	return NewResponse(ErrorResponse, 0, int(cmd), err)
    99  }
   100  
   101  func NewGeneralOkResponse(cmd CommandType) *Response {
   102  	return NewResponse(OkResponse, 0, int(cmd), nil)
   103  }
   104  
   105  func NewOkResponse(affectedRows, lastInsertId uint64, warnings, status uint16, cmd int, d interface{}) *Response {
   106  	resp := &Response{
   107  		category:     OkResponse,
   108  		status:       status,
   109  		cmd:          cmd,
   110  		data:         d,
   111  		affectedRows: affectedRows,
   112  		lastInsertId: lastInsertId,
   113  		warnings:     warnings,
   114  	}
   115  
   116  	return resp
   117  }
   118  
   119  func (resp *Response) GetData() interface{} {
   120  	return resp.data
   121  }
   122  
   123  func (resp *Response) SetData(data interface{}) {
   124  	resp.data = data
   125  }
   126  
   127  func (resp *Response) GetStatus() uint16 {
   128  	return resp.status
   129  }
   130  
   131  func (resp *Response) SetStatus(status uint16) {
   132  	resp.status = status
   133  }
   134  
   135  func (resp *Response) GetCategory() int {
   136  	return resp.category
   137  }
   138  
   139  func (resp *Response) SetCategory(category int) {
   140  	resp.category = category
   141  }
   142  
   143  type Protocol interface {
   144  	profile
   145  	IsEstablished() bool
   146  
   147  	SetEstablished()
   148  
   149  	// GetRequest gets Request from Packet
   150  	GetRequest(payload []byte) *Request
   151  
   152  	// SendResponse sends a response to the client for the application request
   153  	SendResponse(context.Context, *Response) error
   154  
   155  	// ConnectionID the identity of the client
   156  	ConnectionID() uint32
   157  
   158  	// Peer gets the address [Host:Port,Host:Port] of the client and the server
   159  	Peer() (string, string, string, string)
   160  
   161  	GetDatabaseName() string
   162  
   163  	SetDatabaseName(string)
   164  
   165  	GetUserName() string
   166  
   167  	SetUserName(string)
   168  
   169  	GetSequenceId() uint8
   170  
   171  	SetSequenceID(value uint8)
   172  
   173  	GetConciseProfile() string
   174  
   175  	GetTcpConnection() goetty.IOSession
   176  
   177  	GetCapability() uint32
   178  
   179  	IsTlsEstablished() bool
   180  
   181  	SetTlsEstablished()
   182  
   183  	HandleHandshake(ctx context.Context, payload []byte) (bool, error)
   184  
   185  	SendPrepareResponse(ctx context.Context, stmt *PrepareStmt) error
   186  
   187  	Quit()
   188  }
   189  
   190  type ProtocolImpl struct {
   191  	m sync.Mutex
   192  
   193  	io IOPackage
   194  
   195  	tcpConn goetty.IOSession
   196  
   197  	quit atomic.Bool
   198  
   199  	//random bytes
   200  	salt []byte
   201  
   202  	//the id of the connection
   203  	connectionID uint32
   204  
   205  	// whether the handshake succeeded
   206  	established atomic.Bool
   207  
   208  	// whether the tls handshake succeeded
   209  	tlsEstablished atomic.Bool
   210  
   211  	//The sequence-id is incremented with each packet and may wrap around.
   212  	//It starts at 0 and is reset to 0 when a new command begins in the Command Phase.
   213  	sequenceId atomic.Uint32
   214  
   215  	profiles [8]string
   216  }
   217  
   218  func (pi *ProtocolImpl) setQuit(b bool) bool {
   219  	return pi.quit.Swap(b)
   220  }
   221  
   222  func (pi *ProtocolImpl) GetSequenceId() uint8 {
   223  	return uint8(pi.sequenceId.Load())
   224  }
   225  
   226  func (pi *ProtocolImpl) SetSequenceID(value uint8) {
   227  	pi.sequenceId.Store(uint32(value))
   228  }
   229  
   230  func (pi *ProtocolImpl) makeProfile(profileTyp profileType) {
   231  	var mask profileType
   232  	var profile string
   233  	for i := uint8(0); i < 8; i++ {
   234  		mask = 1 << i
   235  		switch mask & profileTyp {
   236  		case profileTypeConnectionWithId:
   237  			if pi.tcpConn != nil {
   238  				profile = fmt.Sprintf("connectionId %d", pi.connectionID)
   239  			}
   240  		case profileTypeConnectionWithIp:
   241  			if pi.tcpConn != nil {
   242  				client := pi.tcpConn.RemoteAddress()
   243  				profile = "client " + client
   244  			}
   245  		default:
   246  			profile = ""
   247  		}
   248  		pi.profiles[i] = profile
   249  	}
   250  }
   251  
   252  func (pi *ProtocolImpl) getProfile(profileTyp profileType) string {
   253  	var mask profileType
   254  	sb := bytes.Buffer{}
   255  	for i := uint8(0); i < 8; i++ {
   256  		mask = 1 << i
   257  		if mask&profileTyp != 0 {
   258  			if sb.Len() != 0 {
   259  				sb.WriteByte(' ')
   260  			}
   261  			sb.WriteString(pi.profiles[i])
   262  		}
   263  	}
   264  	return sb.String()
   265  }
   266  
   267  func (pi *ProtocolImpl) MakeProfile() {
   268  	pi.m.Lock()
   269  	defer pi.m.Unlock()
   270  	pi.makeProfile(profileTypeAll)
   271  }
   272  
   273  func (pi *ProtocolImpl) GetConciseProfile() string {
   274  	pi.m.Lock()
   275  	defer pi.m.Unlock()
   276  	return pi.getProfile(profileTypeConcise)
   277  }
   278  
   279  func (pi *ProtocolImpl) GetSalt() []byte {
   280  	pi.m.Lock()
   281  	defer pi.m.Unlock()
   282  	return pi.salt
   283  }
   284  
   285  func (pi *ProtocolImpl) IsEstablished() bool {
   286  	return pi.established.Load()
   287  }
   288  
   289  func (pi *ProtocolImpl) SetEstablished() {
   290  	logDebugf(pi.GetConciseProfile(), "SWITCH ESTABLISHED to true")
   291  	pi.established.Store(true)
   292  }
   293  
   294  func (pi *ProtocolImpl) IsTlsEstablished() bool {
   295  	return pi.tlsEstablished.Load()
   296  }
   297  
   298  func (pi *ProtocolImpl) SetTlsEstablished() {
   299  	logutil.Debugf("SWITCH TLS_ESTABLISHED to true")
   300  	pi.tlsEstablished.Store(true)
   301  }
   302  
   303  func (pi *ProtocolImpl) ConnectionID() uint32 {
   304  	pi.m.Lock()
   305  	defer pi.m.Unlock()
   306  	return pi.connectionID
   307  }
   308  
   309  // Quit kill tcpConn still connected.
   310  // before calling NewMysqlClientProtocol, tcpConn.Connected() must be true
   311  // please check goetty/application.go::doStart() and goetty/application.go::NewIOSession(...) for details
   312  func (pi *ProtocolImpl) Quit() {
   313  	pi.m.Lock()
   314  	defer pi.m.Unlock()
   315  	//if it was quit, do nothing
   316  	if pi.setQuit(true) {
   317  		return
   318  	}
   319  	if pi.tcpConn != nil {
   320  		if err := pi.tcpConn.Disconnect(); err != nil {
   321  			return
   322  		}
   323  	}
   324  	//release salt
   325  	if pi.salt != nil {
   326  		pi.salt = nil
   327  	}
   328  }
   329  
   330  func (pi *ProtocolImpl) GetLock() sync.Locker {
   331  	return &pi.m
   332  }
   333  
   334  func (pi *ProtocolImpl) GetTcpConnection() goetty.IOSession {
   335  	pi.m.Lock()
   336  	defer pi.m.Unlock()
   337  	return pi.tcpConn
   338  }
   339  
   340  func (pi *ProtocolImpl) Peer() (string, string, string, string) {
   341  	tcp := pi.GetTcpConnection()
   342  	if tcp == nil {
   343  		return "", "", "", ""
   344  	}
   345  	addr := tcp.RemoteAddress()
   346  	rawConn := tcp.RawConn()
   347  	var local net.Addr
   348  	if rawConn != nil {
   349  		local = rawConn.LocalAddr()
   350  	}
   351  	host, port, err := net.SplitHostPort(addr)
   352  	if err != nil {
   353  		logutil.Errorf("get peer host:port failed. error:%v ", err)
   354  		return "failed", "0", "", ""
   355  	}
   356  	localHost, localPort, err := net.SplitHostPort(local.String())
   357  	if err != nil {
   358  		logutil.Errorf("get peer host:port failed. error:%v ", err)
   359  		return "failed", "0", "failed", "0"
   360  	}
   361  	return host, port, localHost, localPort
   362  }
   363  
   364  func (mp *MysqlProtocolImpl) GetRequest(payload []byte) *Request {
   365  	req := &Request{
   366  		cmd:  CommandType(payload[0]),
   367  		data: payload[1:],
   368  	}
   369  
   370  	return req
   371  }
   372  
   373  func (mp *MysqlProtocolImpl) getAbortTransactionErrorInfo() string {
   374  	ses := mp.GetSession()
   375  	//update error message in Case1,Case3,Case4.
   376  	if ses != nil && ses.OptionBitsIsSet(OPTION_ATTACH_ABORT_TRANSACTION_ERROR) {
   377  		ses.ClearOptionBits(OPTION_ATTACH_ABORT_TRANSACTION_ERROR)
   378  		return abortTransactionErrorInfo()
   379  	}
   380  	return ""
   381  }
   382  
   383  func (mp *MysqlProtocolImpl) SendResponse(ctx context.Context, resp *Response) error {
   384  	mp.GetLock().Lock()
   385  	defer mp.GetLock().Unlock()
   386  
   387  	switch resp.category {
   388  	case OkResponse:
   389  		s, ok := resp.data.(string)
   390  		if !ok {
   391  			return mp.sendOKPacket(resp.affectedRows, resp.lastInsertId, uint16(resp.status), resp.warnings, "")
   392  		}
   393  		return mp.sendOKPacket(resp.affectedRows, resp.lastInsertId, uint16(resp.status), resp.warnings, s)
   394  	case EoFResponse:
   395  		return mp.sendEOFPacket(0, uint16(resp.status))
   396  	case ErrorResponse:
   397  		err := resp.data.(error)
   398  		if err == nil {
   399  			return mp.sendOKPacket(0, 0, uint16(resp.status), 0, "")
   400  		}
   401  		attachAbort := mp.getAbortTransactionErrorInfo()
   402  		switch myerr := err.(type) {
   403  		case *moerr.Error:
   404  			var code uint16
   405  			if myerr.MySQLCode() != moerr.ER_UNKNOWN_ERROR {
   406  				code = myerr.MySQLCode()
   407  			} else {
   408  				code = myerr.ErrorCode()
   409  			}
   410  			errMsg := myerr.Error()
   411  			if attachAbort != "" {
   412  				errMsg = fmt.Sprintf("%s\n%s", myerr.Error(), attachAbort)
   413  			}
   414  			return mp.sendErrPacket(code, myerr.SqlState(), errMsg)
   415  		}
   416  		errMsg := ""
   417  		if attachAbort != "" {
   418  			errMsg = fmt.Sprintf("%s\n%s", err, attachAbort)
   419  		} else {
   420  			errMsg = fmt.Sprintf("%v", err)
   421  		}
   422  		return mp.sendErrPacket(moerr.ER_UNKNOWN_ERROR, DefaultMySQLState, errMsg)
   423  	case ResultResponse:
   424  		mer := resp.data.(*MysqlExecutionResult)
   425  		if mer == nil {
   426  			return mp.sendOKPacket(0, 0, uint16(resp.status), 0, "")
   427  		}
   428  		if mer.Mrs() == nil {
   429  			return mp.sendOKPacket(mer.AffectedRows(), mer.InsertID(), uint16(resp.status), mer.Warnings(), "")
   430  		}
   431  		return mp.sendResultSet(ctx, mer.Mrs(), resp.cmd, mer.Warnings(), uint16(resp.status))
   432  	case LocalInfileRequest:
   433  		s, _ := resp.data.(string)
   434  		return mp.sendLocalInfileRequest(s)
   435  	default:
   436  		return moerr.NewInternalError(ctx, "unsupported response:%d ", resp.category)
   437  	}
   438  }
   439  
   440  var _ MysqlProtocol = &FakeProtocol{}
   441  
   442  const (
   443  	fakeConnectionID uint32 = math.MaxUint32
   444  )
   445  
   446  // FakeProtocol works for the background transaction that does not use the network protocol.
   447  type FakeProtocol struct {
   448  	username string
   449  	database string
   450  	ioses    goetty.IOSession
   451  }
   452  
   453  func (fp *FakeProtocol) GetCapability() uint32 {
   454  	return DefaultCapability
   455  }
   456  
   457  func (fp *FakeProtocol) IsTlsEstablished() bool {
   458  	return true
   459  }
   460  
   461  func (fp *FakeProtocol) SetTlsEstablished() {
   462  
   463  }
   464  
   465  func (fp *FakeProtocol) HandleHandshake(ctx context.Context, payload []byte) (bool, error) {
   466  	return false, nil
   467  }
   468  
   469  func (fp *FakeProtocol) GetTcpConnection() goetty.IOSession {
   470  	return fp.ioses
   471  }
   472  
   473  func (fp *FakeProtocol) GetConciseProfile() string {
   474  	return "fake protocol"
   475  }
   476  
   477  func (fp *FakeProtocol) GetSequenceId() uint8 {
   478  	return 0
   479  }
   480  
   481  func (fp *FakeProtocol) SetSequenceID(value uint8) {
   482  }
   483  
   484  func (fp *FakeProtocol) makeProfile(profileTyp profileType) {
   485  }
   486  
   487  func (fp *FakeProtocol) getProfile(profileTyp profileType) string {
   488  	return ""
   489  }
   490  
   491  func (fp *FakeProtocol) SendPrepareResponse(ctx context.Context, stmt *PrepareStmt) error {
   492  	return nil
   493  }
   494  
   495  func (fp *FakeProtocol) ParseExecuteData(ctx context.Context, stmt *PrepareStmt, data []byte, pos int) (names []string, vars []any, err error) {
   496  	return nil, nil, nil
   497  }
   498  
   499  func (fp *FakeProtocol) SendResultSetTextBatchRow(mrs *MysqlResultSet, cnt uint64) error {
   500  	return nil
   501  }
   502  
   503  func (fp *FakeProtocol) SendResultSetTextBatchRowSpeedup(mrs *MysqlResultSet, cnt uint64) error {
   504  	return nil
   505  }
   506  
   507  func (fp *FakeProtocol) SendColumnDefinitionPacket(ctx context.Context, column Column, cmd int) error {
   508  	return nil
   509  }
   510  
   511  func (fp *FakeProtocol) SendColumnCountPacket(count uint64) error {
   512  	return nil
   513  }
   514  
   515  func (fp *FakeProtocol) SendEOFPacketIf(warnings uint16, status uint16) error {
   516  	return nil
   517  }
   518  
   519  func (fp *FakeProtocol) sendOKPacket(affectedRows uint64, lastInsertId uint64, status uint16, warnings uint16, message string) error {
   520  	return nil
   521  }
   522  
   523  func (fp *FakeProtocol) sendEOFOrOkPacket(warnings uint16, status uint16) error {
   524  	return nil
   525  }
   526  
   527  func (fp *FakeProtocol) ResetStatistics() {}
   528  
   529  func (fp *FakeProtocol) GetStats() string {
   530  	return ""
   531  }
   532  
   533  func (fp *FakeProtocol) IsEstablished() bool {
   534  	return true
   535  }
   536  
   537  func (fp *FakeProtocol) SetEstablished() {}
   538  
   539  func (fp *FakeProtocol) GetRequest(payload []byte) *Request {
   540  	return nil
   541  }
   542  
   543  func (fp *FakeProtocol) SendResponse(ctx context.Context, resp *Response) error {
   544  	return nil
   545  }
   546  
   547  func (fp *FakeProtocol) ConnectionID() uint32 {
   548  	return fakeConnectionID
   549  }
   550  
   551  func (fp *FakeProtocol) Peer() (string, string, string, string) {
   552  	return "", "", "", ""
   553  }
   554  
   555  func (fp *FakeProtocol) GetDatabaseName() string {
   556  	return fp.database
   557  }
   558  
   559  func (fp *FakeProtocol) SetDatabaseName(s string) {
   560  	fp.database = s
   561  }
   562  
   563  func (fp *FakeProtocol) GetUserName() string {
   564  	return fp.username
   565  }
   566  
   567  func (fp *FakeProtocol) SetUserName(s string) {
   568  	fp.username = s
   569  }
   570  
   571  func (fp *FakeProtocol) Quit() {}
   572  
   573  func (fp *FakeProtocol) sendLocalInfileRequest(filename string) error {
   574  	return nil
   575  }