github.com/dolthub/go-mysql-server@v0.18.0/server/server.go (about)

     1  // Copyright 2020-2021 Dolthub, Inc.
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //     http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License.
    14  
    15  package server
    16  
    17  import (
    18  	"errors"
    19  	"fmt"
    20  	"net"
    21  	"time"
    22  
    23  	"github.com/dolthub/vitess/go/mysql"
    24  	"github.com/sirupsen/logrus"
    25  	"go.opentelemetry.io/otel/trace"
    26  
    27  	sqle "github.com/dolthub/go-mysql-server"
    28  	"github.com/dolthub/go-mysql-server/server/golden"
    29  	"github.com/dolthub/go-mysql-server/sql"
    30  )
    31  
    32  // ProtocolListener handles connections based on the configuration it was given. These listeners also implement
    33  // their own protocol, which by default will be the MySQL wire protocol, but another protocol may be provided.
    34  type ProtocolListener interface {
    35  	Addr() net.Addr
    36  	Accept()
    37  	Close()
    38  }
    39  
    40  // ProtocolListenerFunc returns a ProtocolListener based on the configuration it was given.
    41  type ProtocolListenerFunc func(cfg mysql.ListenerConfig) (ProtocolListener, error)
    42  
    43  // DefaultProtocolListenerFunc is the protocol listener, which defaults to Vitess' protocol listener. Changing
    44  // this function will change the protocol listener used when creating all servers. If multiple servers are needed
    45  // with different protocols, then create each server after changing this function. Servers retain the protocol that
    46  // they were created with.
    47  var DefaultProtocolListenerFunc ProtocolListenerFunc = func(cfg mysql.ListenerConfig) (ProtocolListener, error) {
    48  	return mysql.NewListenerWithConfig(cfg)
    49  }
    50  
    51  type ServerEventListener interface {
    52  	ClientConnected()
    53  	ClientDisconnected()
    54  	QueryStarted()
    55  	QueryCompleted(success bool, duration time.Duration)
    56  }
    57  
    58  // NewDefaultServer creates a Server with the default session builder.
    59  func NewDefaultServer(cfg Config, e *sqle.Engine) (*Server, error) {
    60  	return NewServer(cfg, e, DefaultSessionBuilder, nil)
    61  }
    62  
    63  // NewServer creates a server with the given protocol, address, authentication
    64  // details given a SQLe engine and a session builder.
    65  func NewServer(cfg Config, e *sqle.Engine, sb SessionBuilder, listener ServerEventListener) (*Server, error) {
    66  	var tracer trace.Tracer
    67  	if cfg.Tracer != nil {
    68  		tracer = cfg.Tracer
    69  	} else {
    70  		tracer = sql.NoopTracer
    71  	}
    72  
    73  	sm := NewSessionManager(sb, tracer, e.Analyzer.Catalog.Database, e.MemoryManager, e.ProcessList, cfg.Address)
    74  	handler := &Handler{
    75  		e:                 e,
    76  		sm:                sm,
    77  		readTimeout:       cfg.ConnReadTimeout,
    78  		disableMultiStmts: cfg.DisableClientMultiStatements,
    79  		maxLoggedQueryLen: cfg.MaxLoggedQueryLen,
    80  		encodeLoggedQuery: cfg.EncodeLoggedQuery,
    81  		sel:               listener,
    82  	}
    83  	//handler = NewHandler_(e, sm, cfg.ConnReadTimeout, cfg.DisableClientMultiStatements, cfg.MaxLoggedQueryLen, cfg.EncodeLoggedQuery, listener)
    84  	return newServerFromHandler(cfg, e, sm, handler)
    85  }
    86  
    87  // NewValidatingServer creates a Server that validates its query results using a MySQL connection
    88  // as a source of golden-value query result sets.
    89  func NewValidatingServer(
    90  	cfg Config,
    91  	e *sqle.Engine,
    92  	sb SessionBuilder,
    93  	listener ServerEventListener,
    94  	mySqlConn string,
    95  ) (*Server, error) {
    96  	var tracer trace.Tracer
    97  	if cfg.Tracer != nil {
    98  		tracer = cfg.Tracer
    99  	} else {
   100  		tracer = sql.NoopTracer
   101  	}
   102  
   103  	sm := NewSessionManager(sb, tracer, e.Analyzer.Catalog.Database, e.MemoryManager, e.ProcessList, cfg.Address)
   104  	h := &Handler{
   105  		e:                 e,
   106  		sm:                sm,
   107  		readTimeout:       cfg.ConnReadTimeout,
   108  		disableMultiStmts: cfg.DisableClientMultiStatements,
   109  		maxLoggedQueryLen: cfg.MaxLoggedQueryLen,
   110  		encodeLoggedQuery: cfg.EncodeLoggedQuery,
   111  		sel:               listener,
   112  	}
   113  
   114  	handler, err := golden.NewValidatingHandler(h, mySqlConn, logrus.StandardLogger())
   115  	if err != nil {
   116  		return nil, err
   117  	}
   118  	return newServerFromHandler(cfg, e, sm, handler)
   119  }
   120  
   121  func portInUse(hostPort string) bool {
   122  	timeout := time.Second
   123  	conn, _ := net.DialTimeout("tcp", hostPort, timeout)
   124  	if conn != nil {
   125  		defer conn.Close()
   126  		return true
   127  	}
   128  	return false
   129  }
   130  
   131  func newServerFromHandler(cfg Config, e *sqle.Engine, sm *SessionManager, handler mysql.Handler) (*Server, error) {
   132  	for _, option := range cfg.Options {
   133  		option(e, sm, handler)
   134  	}
   135  
   136  	if cfg.ConnReadTimeout < 0 {
   137  		cfg.ConnReadTimeout = 0
   138  	}
   139  	if cfg.ConnWriteTimeout < 0 {
   140  		cfg.ConnWriteTimeout = 0
   141  	}
   142  	if cfg.MaxConnections < 0 {
   143  		cfg.MaxConnections = 0
   144  	}
   145  
   146  	l := cfg.Listener
   147  	var unixSocketInUse error
   148  	if l == nil {
   149  		if portInUse(cfg.Address) {
   150  			unixSocketInUse = fmt.Errorf("Port %s already in use.", cfg.Address)
   151  		}
   152  
   153  		var err error
   154  		l, err = NewListener(cfg.Protocol, cfg.Address, cfg.Socket)
   155  		if err != nil {
   156  			if errors.Is(err, UnixSocketInUseError) {
   157  				unixSocketInUse = err
   158  			} else {
   159  				return nil, err
   160  			}
   161  		}
   162  	}
   163  
   164  	listenerCfg := mysql.ListenerConfig{
   165  		Listener:                 l,
   166  		AuthServer:               e.Analyzer.Catalog.MySQLDb,
   167  		Handler:                  handler,
   168  		ConnReadTimeout:          cfg.ConnReadTimeout,
   169  		ConnWriteTimeout:         cfg.ConnWriteTimeout,
   170  		MaxConns:                 cfg.MaxConnections,
   171  		ConnReadBufferSize:       mysql.DefaultConnBufferSize,
   172  		AllowClearTextWithoutTLS: cfg.AllowClearTextWithoutTLS,
   173  	}
   174  	protocolListener, err := DefaultProtocolListenerFunc(listenerCfg)
   175  	if err != nil {
   176  		return nil, err
   177  	}
   178  
   179  	if vtListener, ok := protocolListener.(*mysql.Listener); ok {
   180  		if cfg.Version != "" {
   181  			vtListener.ServerVersion = cfg.Version
   182  		}
   183  		vtListener.TLSConfig = cfg.TLSConfig
   184  		vtListener.RequireSecureTransport = cfg.RequireSecureTransport
   185  	}
   186  
   187  	return &Server{
   188  		Listener:   protocolListener,
   189  		handler:    handler,
   190  		sessionMgr: sm,
   191  		Engine:     e,
   192  	}, unixSocketInUse
   193  }
   194  
   195  // Start starts accepting connections on the server.
   196  func (s *Server) Start() error {
   197  	logrus.Infof("Server ready. Accepting connections.")
   198  	s.WarnIfLoadFileInsecure()
   199  	s.Listener.Accept()
   200  	return nil
   201  }
   202  
   203  func (s *Server) WarnIfLoadFileInsecure() {
   204  	_, v, ok := sql.SystemVariables.GetGlobal("secure_file_priv")
   205  	if ok {
   206  		if v == "" {
   207  			logrus.Warn("secure_file_priv is set to \"\", which is insecure.")
   208  			logrus.Warn("Any user with GRANT FILE privileges will be able to read any file which the sql-server process can read.")
   209  			logrus.Warn("Please consider restarting the server with secure_file_priv set to a safe (or non-existent) directory.")
   210  		}
   211  	}
   212  }
   213  
   214  // Close closes the server connection.
   215  func (s *Server) Close() error {
   216  	logrus.Infof("Server closing listener. No longer accepting connections.")
   217  	s.Listener.Close()
   218  	return nil
   219  }
   220  
   221  // SessionManager returns the session manager for this server.
   222  func (s *Server) SessionManager() *SessionManager {
   223  	return s.sessionMgr
   224  }