github.com/matrixorigin/matrixone@v1.2.0/pkg/proxy/client_conn.go (about)

     1  // Copyright 2021 - 2023 Matrix Origin
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //      http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License.
    14  
    15  package proxy
    16  
    17  import (
    18  	"context"
    19  	"crypto/tls"
    20  	"errors"
    21  	"fmt"
    22  	"net"
    23  	"strconv"
    24  	"strings"
    25  	"sync/atomic"
    26  	"time"
    27  
    28  	"github.com/fagongzi/goetty/v2"
    29  	"go.uber.org/zap"
    30  
    31  	"github.com/matrixorigin/matrixone/pkg/clusterservice"
    32  	"github.com/matrixorigin/matrixone/pkg/common/log"
    33  	"github.com/matrixorigin/matrixone/pkg/common/moerr"
    34  	"github.com/matrixorigin/matrixone/pkg/common/morpc"
    35  	"github.com/matrixorigin/matrixone/pkg/config"
    36  	"github.com/matrixorigin/matrixone/pkg/frontend"
    37  	"github.com/matrixorigin/matrixone/pkg/logservice"
    38  	qclient "github.com/matrixorigin/matrixone/pkg/queryservice/client"
    39  	v2 "github.com/matrixorigin/matrixone/pkg/util/metric/v2"
    40  )
    41  
    42  // clientBaseConnID is the base connection ID for client.
    43  var clientBaseConnID uint32 = 1000
    44  
    45  // parse parses the account information from whole username.
    46  // The whole username parameter is like: tenant1:user1:role1?key1:value1,key2:value2
    47  func (c *clientInfo) parse(full string) error {
    48  	var labelPart string
    49  	labelDelPos := strings.IndexByte(full, '?')
    50  	userPart := full[:]
    51  	if labelDelPos >= 0 {
    52  		userPart = full[:labelDelPos]
    53  		if len(full) > labelDelPos+1 {
    54  			labelPart = full[labelDelPos+1:]
    55  		}
    56  	}
    57  	tenant, err := frontend.GetTenantInfo(context.Background(), userPart)
    58  	if err != nil {
    59  		return err
    60  	}
    61  	c.labelInfo.Tenant = Tenant(tenant.GetTenant())
    62  	c.username = tenant.GetUser()
    63  
    64  	// For label part.
    65  	if len(labelPart) > 0 {
    66  		labels, err := frontend.ParseLabel(strings.TrimSpace(labelPart))
    67  		if err != nil {
    68  			return err
    69  		}
    70  		c.labelInfo.Labels = labels
    71  	}
    72  	return nil
    73  }
    74  
    75  // ClientConn is the connection to the client.
    76  type ClientConn interface {
    77  	// ConnID returns the connection ID.
    78  	ConnID() uint32
    79  	// GetSalt returns the salt value of this connection.
    80  	GetSalt() []byte
    81  	// GetHandshakePack returns the handshake response packet
    82  	// which is received from client.
    83  	GetHandshakePack() *frontend.Packet
    84  	// RawConn return the raw connection.
    85  	RawConn() net.Conn
    86  	// GetTenant returns the tenant which this connection belongs to.
    87  	GetTenant() Tenant
    88  	// SendErrToClient sends access error to MySQL client.
    89  	SendErrToClient(err error)
    90  	// BuildConnWithServer selects a CN server and connects to it, then
    91  	// returns the connection. If sendToClient is true, means that the
    92  	// packet received from CN server should be sent to client because
    93  	// it is the first time to build connection and login. And if it is
    94  	// false, means that the packet received from CN server should NOT
    95  	// be sent to client because we are transferring CN server connection,
    96  	// and it is not the first time to build connection and login has been
    97  	// finished already.
    98  	// prevAddr is empty if it is the first time to build connection with
    99  	// a cn server; otherwise, it is the address of the previous cn node
   100  	// when it is transferring connection and the handshake phase is ignored.
   101  	BuildConnWithServer(prevAddr string) (ServerConn, error)
   102  	// HandleEvent handles event that comes from tunnel data flow.
   103  	HandleEvent(ctx context.Context, e IEvent, resp chan<- []byte) error
   104  	// Close closes the client connection.
   105  	Close() error
   106  }
   107  
   108  type migration struct {
   109  	setVarStmts []string
   110  }
   111  
// clientConn is the connection between proxy and client.
// It performs the MySQL handshake with the client, selects and connects to a
// CN server through the router, and handles events (kill query, set variable)
// raised by the tunnel.
type clientConn struct {
	// ctx is the context passed to newClientConn. It is used for routing,
	// connection ID allocation and internal error construction.
	// NOTE(review): storing a context in a struct is discouraged in Go; kept
	// as-is because many methods rely on it.
	ctx context.Context
	// log is the logger of this connection, tagged with the connection ID.
	log *log.MOLogger
	// counterSet counts the events in proxy.
	counterSet *counterSet
	// conn is the raw TCP connection between proxy and client.
	conn goetty.IOSession
	// mysqlProto is mainly used to build handshake. It also provides the
	// salt and writes packets/error payloads to the client.
	mysqlProto *frontend.MysqlProtocolImpl
	// handshakePack is a cached info, used in connection migration.
	// When connection is transferred, we use it to rebuild handshake.
	handshakePack *frontend.Packet
	// connID records the connection ID, allocated by genConnID.
	connID uint32
	// clientInfo is the information of the client: tenant, user, labels and
	// the origin IP/port of the client.
	clientInfo clientInfo
	// haKeeperClient is the client of HAKeeper, used to allocate globally
	// unique connection IDs. When nil, a process-local counter is used.
	haKeeperClient logservice.ClusterHAKeeperClient
	// moCluster is the CN server cache, which used to filter CN servers
	// by CN labels.
	moCluster clusterservice.MOCluster
	// Router is used to select and connect to a best CN server.
	router Router
	// tun is the tunnel which this client connection belongs to.
	tun *tunnel
	// tlsConfig is the config of TLS. It is only built when TLS is enabled
	// in the proxy Config; otherwise it stays nil.
	tlsConfig *tls.Config
	// tlsConnectTimeout is the TLS connect timeout value.
	tlsConnectTimeout time.Duration
	// ipNetList is the list of ip net, which is parsed from CIDRs. A client
	// whose origin IP falls into it is flagged as an internal connection.
	ipNetList []*net.IPNet
	// queryClient is used to send query request to CN servers. It is closed
	// by Close.
	queryClient qclient.QueryClient
	// testHelper is used for testing.
	testHelper struct {
		// connectToBackend, when set, replaces the real backend connection
		// logic in connectToBackend.
		connectToBackend func() (ServerConn, error)
	}
	// migration records session state (SET statements) collected from
	// events, presumably replayed when the connection is transferred.
	migration migration
}
   152  
// internalStmt is used internally in proxy, which indicates the stmt
// need to execute. The proxy builds one itself (for example the KILL QUERY
// statement issued in connAndExec) and executes it on a CN server.
type internalStmt struct {
	cmdType MySQLCmd // command type used to send the statement
	s       string   // statement text
}

// Compile-time check that *clientConn implements ClientConn.
var _ ClientConn = (*clientConn)(nil)
   161  
   162  // newClientConn creates a new client connection.
   163  func newClientConn(
   164  	ctx context.Context,
   165  	cfg *Config,
   166  	logger *log.MOLogger,
   167  	cs *counterSet,
   168  	conn goetty.IOSession,
   169  	haKeeperClient logservice.ClusterHAKeeperClient,
   170  	mc clusterservice.MOCluster,
   171  	router Router,
   172  	tun *tunnel,
   173  	ipNetList []*net.IPNet,
   174  ) (ClientConn, error) {
   175  	var originIP net.IP
   176  	var port int
   177  	host, portStr, err := net.SplitHostPort(conn.RemoteAddress())
   178  	if err == nil {
   179  		originIP = net.ParseIP(host)
   180  		port, _ = strconv.Atoi(portStr)
   181  	}
   182  	qc, err := qclient.NewQueryClient(cfg.UUID, morpc.Config{})
   183  	if err != nil {
   184  		return nil, err
   185  	}
   186  	c := &clientConn{
   187  		ctx:            ctx,
   188  		counterSet:     cs,
   189  		conn:           conn,
   190  		haKeeperClient: haKeeperClient,
   191  		moCluster:      mc,
   192  		router:         router,
   193  		tun:            tun,
   194  		clientInfo: clientInfo{
   195  			originIP:   originIP,
   196  			originPort: uint16(port),
   197  		},
   198  		ipNetList: ipNetList,
   199  		// set the connection timeout value.
   200  		tlsConnectTimeout: cfg.TLSConnectTimeout.Duration,
   201  		queryClient:       qc,
   202  	}
   203  	c.connID, err = c.genConnID()
   204  	if err != nil {
   205  		return nil, err
   206  	}
   207  	c.log = logger.With(zap.Uint32("ConnID", c.connID))
   208  	fp := config.FrontendParameters{
   209  		EnableTls: cfg.TLSEnabled,
   210  	}
   211  	fp.SetDefaultValues()
   212  	c.mysqlProto = frontend.NewMysqlClientProtocol(c.connID, c.conn, 0, &fp)
   213  	if cfg.TLSEnabled {
   214  		tlsConfig, err := frontend.ConstructTLSConfig(
   215  			ctx, cfg.TLSCAFile, cfg.TLSCertFile, cfg.TLSKeyFile)
   216  		if err != nil {
   217  			return nil, err
   218  		}
   219  		c.tlsConfig = tlsConfig
   220  	}
   221  	return c, nil
   222  }
   223  
   224  // ConnID implements the ClientConn interface.
   225  func (c *clientConn) ConnID() uint32 {
   226  	return c.connID
   227  }
   228  
   229  // GetSalt implements the ClientConn interface.
   230  func (c *clientConn) GetSalt() []byte {
   231  	return c.mysqlProto.GetSalt()
   232  }
   233  
   234  // GetHandshakePack implements the ClientConn interface.
   235  func (c *clientConn) GetHandshakePack() *frontend.Packet {
   236  	return c.handshakePack
   237  }
   238  
   239  // RawConn implements the ClientConn interface.
   240  func (c *clientConn) RawConn() net.Conn {
   241  	if c != nil {
   242  		return c.conn.RawConn()
   243  	}
   244  	return nil
   245  }
   246  
   247  // GetTenant implements the ClientConn interface.
   248  func (c *clientConn) GetTenant() Tenant {
   249  	if c != nil {
   250  		return c.clientInfo.Tenant
   251  	}
   252  	return EmptyTenant
   253  }
   254  
   255  // SendErrToClient implements the ClientConn interface.
   256  func (c *clientConn) SendErrToClient(err error) {
   257  	errorCode, sqlState, msg := frontend.RewriteError(err, "")
   258  	p := c.mysqlProto.MakeErrPayload(errorCode, sqlState, msg)
   259  	if err := c.mysqlProto.WritePacket(p); err != nil {
   260  		c.log.Error("failed to send access error to client", zap.Error(err))
   261  	}
   262  }
   263  
   264  // BuildConnWithServer implements the ClientConn interface.
   265  func (c *clientConn) BuildConnWithServer(prevAddr string) (ServerConn, error) {
   266  	if prevAddr == "" {
   267  		// Step 1, proxy write initial handshake to client.
   268  		if err := c.writeInitialHandshake(); err != nil {
   269  			c.log.Debug("failed to write Handshake packet", zap.Error(err))
   270  			return nil, err
   271  		}
   272  		// Step 2, client send handshake response, which is auth request,
   273  		// to proxy.
   274  		if err := c.handleHandshakeResp(); err != nil {
   275  			c.log.Error("failed to handle Handshake response", zap.Error(err))
   276  			return nil, err
   277  		}
   278  	}
   279  	// Step 3, proxy connects to a CN server to build connection.
   280  	conn, err := c.connectToBackend(prevAddr)
   281  	if err != nil {
   282  		c.log.Error("failed to connect to backend", zap.Error(err))
   283  		return nil, err
   284  	}
   285  	return conn, nil
   286  }
   287  
   288  // HandleEvent implements the ClientConn interface.
   289  func (c *clientConn) HandleEvent(ctx context.Context, e IEvent, resp chan<- []byte) error {
   290  	switch ev := e.(type) {
   291  	case *killQueryEvent:
   292  		return c.handleKillQuery(ev, resp)
   293  	case *setVarEvent:
   294  		return c.handleSetVar(ev)
   295  	default:
   296  	}
   297  	return nil
   298  }
   299  
   300  func (c *clientConn) sendErr(err error, resp chan<- []byte) {
   301  	errCode, sqlState, errMsg := frontend.RewriteError(err, "")
   302  	payload := c.mysqlProto.MakeErrPayload(
   303  		errCode, sqlState, errMsg)
   304  	r := &frontend.Packet{
   305  		Length:     0,
   306  		SequenceID: 1,
   307  		Payload:    payload,
   308  	}
   309  	sendResp(packetToBytes(r), resp)
   310  }
   311  
   312  func (c *clientConn) connAndExec(cn *CNServer, stmt string, resp chan<- []byte) error {
   313  	sc, r, err := c.router.Connect(cn, c.handshakePack, c.tun)
   314  	if err != nil {
   315  		c.log.Error("failed to connect to backend server", zap.Error(err))
   316  		if resp != nil {
   317  			c.sendErr(err, resp)
   318  		}
   319  		return err
   320  	}
   321  	defer func() { _ = sc.Close() }()
   322  
   323  	if !isOKPacket(r) {
   324  		c.log.Error("failed to connect to cn to handle event",
   325  			zap.String("query", stmt), zap.String("error", string(r)))
   326  		if resp != nil {
   327  			sendResp(r, resp)
   328  		}
   329  		return moerr.NewInternalErrorNoCtx("access error")
   330  	}
   331  
   332  	ok, err := sc.ExecStmt(internalStmt{cmdType: cmdQuery, s: stmt}, resp)
   333  	if err != nil {
   334  		c.log.Error("failed to send query to server",
   335  			zap.String("query", stmt), zap.Error(err))
   336  		return err
   337  	}
   338  	if !ok {
   339  		return moerr.NewInternalErrorNoCtx("exec error")
   340  	}
   341  	return nil
   342  }
   343  
   344  // handleKillQuery handles the kill query event.
   345  func (c *clientConn) handleKillQuery(e *killQueryEvent, resp chan<- []byte) error {
   346  	cn, err := c.router.SelectByConnID(e.connID)
   347  	if err != nil {
   348  		// If no server found, means that the query has been terminated.
   349  		if errors.Is(err, noCNServerErr) {
   350  			sendResp(makeOKPacket(8), resp)
   351  			return nil
   352  		}
   353  		c.log.Error("failed to select CN server", zap.Error(err))
   354  		c.sendErr(err, resp)
   355  		return err
   356  	}
   357  	// Before connect to backend server, update the salt.
   358  	cn.salt = c.mysqlProto.GetSalt()
   359  
   360  	return c.connAndExec(cn, fmt.Sprintf("KILL QUERY %d", cn.connID), resp)
   361  }
   362  
   363  // handleSetVar handles the set variable event.
   364  func (c *clientConn) handleSetVar(e *setVarEvent) error {
   365  	c.migration.setVarStmts = append(c.migration.setVarStmts, e.stmt)
   366  	return nil
   367  }
   368  
   369  // Close implements the ClientConn interface.
   370  func (c *clientConn) Close() error {
   371  	return c.queryClient.Close()
   372  }
   373  
   374  // connectToBackend connect to the real CN server.
   375  func (c *clientConn) connectToBackend(prevAdd string) (ServerConn, error) {
   376  	// Testing path.
   377  	if c.testHelper.connectToBackend != nil {
   378  		return c.testHelper.connectToBackend()
   379  	}
   380  
   381  	if c.router == nil {
   382  		v2.ProxyConnectCommonFailCounter.Inc()
   383  		return nil, moerr.NewInternalErrorNoCtx("no router available")
   384  	}
   385  
   386  	badCNServers := make(map[string]struct{})
   387  	if prevAdd != "" {
   388  		badCNServers[prevAdd] = struct{}{}
   389  	}
   390  	filterFn := func(str string) bool {
   391  		if _, ok := badCNServers[str]; ok {
   392  			return true
   393  		}
   394  		return false
   395  	}
   396  
   397  	var err error
   398  	var cn *CNServer
   399  	var sc ServerConn
   400  	var r []byte
   401  	for {
   402  		// Select the best CN server from backend.
   403  		//
   404  		// NB: The selected CNServer must have label hash in it.
   405  		cn, err = c.router.Route(c.ctx, c.clientInfo, filterFn)
   406  		if err != nil {
   407  			v2.ProxyConnectRouteFailCounter.Inc()
   408  			c.log.Error("route failed", zap.Error(err))
   409  			return nil, err
   410  		}
   411  		// We have to set connection ID after cn is returned.
   412  		cn.connID = c.connID
   413  
   414  		// Set the salt value of cn server.
   415  		cn.salt = c.mysqlProto.GetSalt()
   416  
   417  		// Update the internal connection.
   418  		cn.internalConn = containIP(c.ipNetList, c.clientInfo.originIP)
   419  		cn.clientAddr = fmt.Sprintf("%s:%d", c.clientInfo.originIP.String(), c.clientInfo.originPort)
   420  
   421  		// After select a CN server, we try to connect to it. If connect
   422  		// fails, and it is a retryable error, we reselect another CN server.
   423  		sc, r, err = c.router.Connect(cn, c.handshakePack, c.tun)
   424  		if err != nil {
   425  			if isRetryableErr(err) {
   426  				v2.ProxyConnectRetryCounter.Inc()
   427  				badCNServers[cn.addr] = struct{}{}
   428  				c.log.Warn("failed to connect to CN server, will retry",
   429  					zap.String("current server uuid", cn.uuid),
   430  					zap.String("current server address", cn.addr),
   431  					zap.Any("bad backend servers", badCNServers),
   432  					zap.String("client->proxy",
   433  						fmt.Sprintf("%s -> %s", c.RawConn().RemoteAddr(),
   434  							c.RawConn().LocalAddr())),
   435  					zap.Error(err),
   436  				)
   437  				continue
   438  			} else {
   439  				v2.ProxyConnectCommonFailCounter.Inc()
   440  				c.log.Error("failed to connect to CN server, cannot retry", zap.Error(err))
   441  				return nil, err
   442  			}
   443  		}
   444  
   445  		if prevAdd == "" {
   446  			// r is the packet received from CN server, send r to client.
   447  			if err := c.mysqlProto.WritePacket(r[4:]); err != nil {
   448  				c.log.Error("failed to write packet to client", zap.Error(err))
   449  				v2.ProxyConnectCommonFailCounter.Inc()
   450  				closeErr := sc.Close()
   451  				if closeErr != nil {
   452  					c.log.Error("failed to close server connection", zap.Error(closeErr))
   453  				}
   454  				return nil, err
   455  			}
   456  		} else {
   457  			// The connection has been transferred to a new server, but migration fails,
   458  			// but we don't return error, which will cause unknown issue.
   459  			if err := c.migrateConn(prevAdd, sc); err != nil {
   460  				closeErr := sc.Close()
   461  				if closeErr != nil {
   462  					c.log.Error("failed to close server connection", zap.Error(closeErr))
   463  				}
   464  				c.log.Error("failed to migrate connection to cn, will retry",
   465  					zap.Uint32("conn ID", c.connID),
   466  					zap.String("current uuid", cn.uuid),
   467  					zap.String("current addr", cn.addr),
   468  					zap.Any("bad backend servers", badCNServers),
   469  					zap.Error(err),
   470  				)
   471  				badCNServers[cn.addr] = struct{}{}
   472  				continue
   473  			}
   474  		}
   475  
   476  		// connection to cn server successfully.
   477  		break
   478  	}
   479  	if !isOKPacket(r) {
   480  		c.log.Error("response is not OK", zap.Any("packet", err))
   481  		// If we do not close here, there will be a lot of unused connections
   482  		// in connManager.
   483  		if sc != nil {
   484  			if closeErr := sc.Close(); closeErr != nil {
   485  				c.log.Error("failed to close server connection", zap.Error(closeErr))
   486  			}
   487  		}
   488  		v2.ProxyConnectCommonFailCounter.Inc()
   489  		return nil, withCode(moerr.NewInternalErrorNoCtx("access error"),
   490  			codeAuthFailed)
   491  	}
   492  	v2.ProxyConnectSuccessCounter.Inc()
   493  	return sc, nil
   494  }
   495  
   496  // readPacket reads MySQL packets from clients. It is mainly used in
   497  // handshake phase.
   498  func (c *clientConn) readPacket() (*frontend.Packet, error) {
   499  	msg, err := c.conn.Read(goetty.ReadOptions{})
   500  	if err != nil {
   501  		return nil, err
   502  	}
   503  	if proxyAddr, ok := msg.(*ProxyAddr); ok {
   504  		if proxyAddr.SourceAddress != nil {
   505  			c.clientInfo.originIP = proxyAddr.SourceAddress
   506  			c.clientInfo.originPort = proxyAddr.SourcePort
   507  		}
   508  		return c.readPacket()
   509  	}
   510  	packet, ok := msg.(*frontend.Packet)
   511  	if !ok {
   512  		return nil, moerr.NewInternalError(c.ctx, "message is not a Packet")
   513  	}
   514  	return packet, nil
   515  }
   516  
   517  // nextClientConnID increases baseConnID by 1 and returns the result.
   518  func nextClientConnID() uint32 {
   519  	return atomic.AddUint32(&clientBaseConnID, 1)
   520  }
   521  
   522  // genConnID is used to generate globally unique connection ID.
   523  func (c *clientConn) genConnID() (uint32, error) {
   524  	if c.haKeeperClient == nil {
   525  		return nextClientConnID(), nil
   526  	}
   527  	ctx, cancel := context.WithTimeout(c.ctx, time.Second*3)
   528  	defer cancel()
   529  	// Use the same key with frontend module to make sure the connection ID
   530  	// is unique globally.
   531  	connID, err := c.haKeeperClient.AllocateIDByKey(ctx, frontend.ConnIDAllocKey)
   532  	if err != nil {
   533  		return 0, err
   534  	}
   535  	// Convert uint64 to uint32 to adapt MySQL protocol.
   536  	return uint32(connID), nil
   537  }