// Copyright 2020-2021 Dolthub, Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package server

import (
	"errors"
	"fmt"
	"net"
	"time"

	"github.com/dolthub/vitess/go/mysql"
	"github.com/sirupsen/logrus"
	"go.opentelemetry.io/otel/trace"

	sqle "github.com/dolthub/go-mysql-server"
	"github.com/dolthub/go-mysql-server/server/golden"
	"github.com/dolthub/go-mysql-server/sql"
)

// ProtocolListener accepts and serves client connections according to the configuration it was
// created with. Each ProtocolListener implements its own wire protocol. The default one (see
// DefaultProtocolListenerFunc) speaks the MySQL wire protocol, but another protocol may be provided.
type ProtocolListener interface {
	// Addr returns the network address of the listener.
	// NOTE(review): presumably mirrors net.Listener.Addr — confirm for custom implementations.
	Addr() net.Addr
	// Accept runs the listener's accept loop. Server.Start calls it and then returns nil,
	// so it presumably blocks until the listener is closed — TODO confirm.
	Accept()
	// Close stops the listener from accepting new connections; Server.Close calls it.
	Close()
}

// ProtocolListenerFunc constructs a ProtocolListener from the given listener configuration.
type ProtocolListenerFunc func(cfg mysql.ListenerConfig) (ProtocolListener, error)

// DefaultProtocolListenerFunc is the function used to create the protocol listener for every new
// server. By default it builds Vitess' MySQL-protocol listener. Reassigning it changes the protocol
// listener used by all servers created afterwards. Servers keep the protocol they were created with,
// so if several servers need different protocols, set this variable before creating each one.
47 var DefaultProtocolListenerFunc ProtocolListenerFunc = func(cfg mysql.ListenerConfig) (ProtocolListener, error) { 48 return mysql.NewListenerWithConfig(cfg) 49 } 50 51 type ServerEventListener interface { 52 ClientConnected() 53 ClientDisconnected() 54 QueryStarted() 55 QueryCompleted(success bool, duration time.Duration) 56 } 57 58 // NewDefaultServer creates a Server with the default session builder. 59 func NewDefaultServer(cfg Config, e *sqle.Engine) (*Server, error) { 60 return NewServer(cfg, e, DefaultSessionBuilder, nil) 61 } 62 63 // NewServer creates a server with the given protocol, address, authentication 64 // details given a SQLe engine and a session builder. 65 func NewServer(cfg Config, e *sqle.Engine, sb SessionBuilder, listener ServerEventListener) (*Server, error) { 66 var tracer trace.Tracer 67 if cfg.Tracer != nil { 68 tracer = cfg.Tracer 69 } else { 70 tracer = sql.NoopTracer 71 } 72 73 sm := NewSessionManager(sb, tracer, e.Analyzer.Catalog.Database, e.MemoryManager, e.ProcessList, cfg.Address) 74 handler := &Handler{ 75 e: e, 76 sm: sm, 77 readTimeout: cfg.ConnReadTimeout, 78 disableMultiStmts: cfg.DisableClientMultiStatements, 79 maxLoggedQueryLen: cfg.MaxLoggedQueryLen, 80 encodeLoggedQuery: cfg.EncodeLoggedQuery, 81 sel: listener, 82 } 83 //handler = NewHandler_(e, sm, cfg.ConnReadTimeout, cfg.DisableClientMultiStatements, cfg.MaxLoggedQueryLen, cfg.EncodeLoggedQuery, listener) 84 return newServerFromHandler(cfg, e, sm, handler) 85 } 86 87 // NewValidatingServer creates a Server that validates its query results using a MySQL connection 88 // as a source of golden-value query result sets. 
89 func NewValidatingServer( 90 cfg Config, 91 e *sqle.Engine, 92 sb SessionBuilder, 93 listener ServerEventListener, 94 mySqlConn string, 95 ) (*Server, error) { 96 var tracer trace.Tracer 97 if cfg.Tracer != nil { 98 tracer = cfg.Tracer 99 } else { 100 tracer = sql.NoopTracer 101 } 102 103 sm := NewSessionManager(sb, tracer, e.Analyzer.Catalog.Database, e.MemoryManager, e.ProcessList, cfg.Address) 104 h := &Handler{ 105 e: e, 106 sm: sm, 107 readTimeout: cfg.ConnReadTimeout, 108 disableMultiStmts: cfg.DisableClientMultiStatements, 109 maxLoggedQueryLen: cfg.MaxLoggedQueryLen, 110 encodeLoggedQuery: cfg.EncodeLoggedQuery, 111 sel: listener, 112 } 113 114 handler, err := golden.NewValidatingHandler(h, mySqlConn, logrus.StandardLogger()) 115 if err != nil { 116 return nil, err 117 } 118 return newServerFromHandler(cfg, e, sm, handler) 119 } 120 121 func portInUse(hostPort string) bool { 122 timeout := time.Second 123 conn, _ := net.DialTimeout("tcp", hostPort, timeout) 124 if conn != nil { 125 defer conn.Close() 126 return true 127 } 128 return false 129 } 130 131 func newServerFromHandler(cfg Config, e *sqle.Engine, sm *SessionManager, handler mysql.Handler) (*Server, error) { 132 for _, option := range cfg.Options { 133 option(e, sm, handler) 134 } 135 136 if cfg.ConnReadTimeout < 0 { 137 cfg.ConnReadTimeout = 0 138 } 139 if cfg.ConnWriteTimeout < 0 { 140 cfg.ConnWriteTimeout = 0 141 } 142 if cfg.MaxConnections < 0 { 143 cfg.MaxConnections = 0 144 } 145 146 l := cfg.Listener 147 var unixSocketInUse error 148 if l == nil { 149 if portInUse(cfg.Address) { 150 unixSocketInUse = fmt.Errorf("Port %s already in use.", cfg.Address) 151 } 152 153 var err error 154 l, err = NewListener(cfg.Protocol, cfg.Address, cfg.Socket) 155 if err != nil { 156 if errors.Is(err, UnixSocketInUseError) { 157 unixSocketInUse = err 158 } else { 159 return nil, err 160 } 161 } 162 } 163 164 listenerCfg := mysql.ListenerConfig{ 165 Listener: l, 166 AuthServer: e.Analyzer.Catalog.MySQLDb, 167 
Handler: handler, 168 ConnReadTimeout: cfg.ConnReadTimeout, 169 ConnWriteTimeout: cfg.ConnWriteTimeout, 170 MaxConns: cfg.MaxConnections, 171 ConnReadBufferSize: mysql.DefaultConnBufferSize, 172 AllowClearTextWithoutTLS: cfg.AllowClearTextWithoutTLS, 173 } 174 protocolListener, err := DefaultProtocolListenerFunc(listenerCfg) 175 if err != nil { 176 return nil, err 177 } 178 179 if vtListener, ok := protocolListener.(*mysql.Listener); ok { 180 if cfg.Version != "" { 181 vtListener.ServerVersion = cfg.Version 182 } 183 vtListener.TLSConfig = cfg.TLSConfig 184 vtListener.RequireSecureTransport = cfg.RequireSecureTransport 185 } 186 187 return &Server{ 188 Listener: protocolListener, 189 handler: handler, 190 sessionMgr: sm, 191 Engine: e, 192 }, unixSocketInUse 193 } 194 195 // Start starts accepting connections on the server. 196 func (s *Server) Start() error { 197 logrus.Infof("Server ready. Accepting connections.") 198 s.WarnIfLoadFileInsecure() 199 s.Listener.Accept() 200 return nil 201 } 202 203 func (s *Server) WarnIfLoadFileInsecure() { 204 _, v, ok := sql.SystemVariables.GetGlobal("secure_file_priv") 205 if ok { 206 if v == "" { 207 logrus.Warn("secure_file_priv is set to \"\", which is insecure.") 208 logrus.Warn("Any user with GRANT FILE privileges will be able to read any file which the sql-server process can read.") 209 logrus.Warn("Please consider restarting the server with secure_file_priv set to a safe (or non-existent) directory.") 210 } 211 } 212 } 213 214 // Close closes the server connection. 215 func (s *Server) Close() error { 216 logrus.Infof("Server closing listener. No longer accepting connections.") 217 s.Listener.Close() 218 return nil 219 } 220 221 // SessionManager returns the session manager for this server. 222 func (s *Server) SessionManager() *SessionManager { 223 return s.sessionMgr 224 }