github.com/dolthub/go-mysql-server@v0.18.0/server/context.go (about) 1 // Copyright 2020-2021 Dolthub, Inc. 2 // 3 // Licensed under the Apache License, Version 2.0 (the "License"); 4 // you may not use this file except in compliance with the License. 5 // You may obtain a copy of the License at 6 // 7 // http://www.apache.org/licenses/LICENSE-2.0 8 // 9 // Unless required by applicable law or agreed to in writing, software 10 // distributed under the License is distributed on an "AS IS" BASIS, 11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 // See the License for the specific language governing permissions and 13 // limitations under the License. 14 15 package server 16 17 import ( 18 "context" 19 "sync" 20 "time" 21 22 "github.com/dolthub/vitess/go/mysql" 23 "github.com/sirupsen/logrus" 24 "go.opentelemetry.io/otel/trace" 25 26 "github.com/dolthub/go-mysql-server/sql" 27 "github.com/dolthub/go-mysql-server/sql/mysql_db" 28 ) 29 30 // SessionBuilder creates sessions given a MySQL connection and a server address. 31 type SessionBuilder func(ctx context.Context, conn *mysql.Conn, addr string) (sql.Session, error) 32 33 // DoneFunc is a function that must be executed when the session is used and 34 // it can be disposed. 35 type DoneFunc func() 36 37 // DefaultSessionBuilder is a SessionBuilder that returns a base session. 38 func DefaultSessionBuilder(ctx context.Context, c *mysql.Conn, addr string) (sql.Session, error) { 39 host := "" 40 user := "" 41 mysqlConnectionUser, ok := c.UserData.(mysql_db.MysqlConnectionUser) 42 if ok { 43 host = mysqlConnectionUser.Host 44 user = mysqlConnectionUser.User 45 } 46 client := sql.Client{Address: host, User: user, Capabilities: c.Capabilities} 47 return sql.NewBaseSessionWithClientServer(addr, client, c.ConnectionID), nil 48 } 49 50 // SessionManager is in charge of creating new sessions for the given 51 // connections and keep track of which sessions are in each connection, so 52 // they can be cancelled if the connection is closed. 53 type SessionManager struct { 54 addr string 55 tracer trace.Tracer 56 getDbFunc func(ctx *sql.Context, db string) (sql.Database, error) 57 memory *sql.MemoryManager 58 processlist sql.ProcessList 59 mu *sync.Mutex 60 builder SessionBuilder 61 sessions map[uint32]sql.Session 62 connections map[uint32]*mysql.Conn 63 lastPid uint64 64 } 65 66 // NewSessionManager creates a SessionManager with the given SessionBuilder. 67 func NewSessionManager( 68 builder SessionBuilder, 69 tracer trace.Tracer, 70 getDbFunc func(ctx *sql.Context, db string) (sql.Database, error), 71 memory *sql.MemoryManager, 72 processlist sql.ProcessList, 73 addr string, 74 ) *SessionManager { 75 return &SessionManager{ 76 addr: addr, 77 tracer: tracer, 78 getDbFunc: getDbFunc, 79 memory: memory, 80 processlist: processlist, 81 mu: new(sync.Mutex), 82 builder: builder, 83 sessions: make(map[uint32]sql.Session), 84 connections: make(map[uint32]*mysql.Conn), 85 } 86 } 87 88 func (s *SessionManager) nextPid() uint64 { 89 s.mu.Lock() 90 defer s.mu.Unlock() 91 s.lastPid++ 92 return s.lastPid 93 } 94 95 // Add a connection to be tracked by the SessionManager. Should be called as 96 // soon as possible after the server has accepted the connection. Results in 97 // the connection being tracked by ProcessList and being available through 98 // KillConnection. The connection will be tracked until RemoveConn is called, 99 // so clients should ensure a call to AddConn is always paired up with a call 100 // to RemoveConn. 101 func (s *SessionManager) AddConn(conn *mysql.Conn) { 102 s.mu.Lock() 103 defer s.mu.Unlock() 104 s.connections[conn.ConnectionID] = conn 105 s.processlist.AddConnection(conn.ConnectionID, conn.RemoteAddr().String()) 106 } 107 108 // NewSession creates a Session for the given connection and saves it to the session pool. 109 func (s *SessionManager) NewSession(ctx context.Context, conn *mysql.Conn) error { 110 s.mu.Lock() 111 defer s.mu.Unlock() 112 session, err := s.builder(ctx, conn, s.addr) 113 if err != nil { 114 return err 115 } 116 117 session.SetConnectionId(conn.ConnectionID) 118 119 s.sessions[conn.ConnectionID] = session 120 121 logger := session.GetLogger() 122 if logger == nil { 123 log := logrus.StandardLogger() 124 logger = logrus.NewEntry(log) 125 } 126 127 session.SetLogger( 128 logger.WithField(sql.ConnectionIdLogField, conn.ConnectionID). 129 WithField(sql.ConnectTimeLogKey, time.Now()), 130 ) 131 132 return err 133 } 134 135 func (s *SessionManager) SetDB(conn *mysql.Conn, dbName string) error { 136 sess, err := s.getOrCreateSession(context.Background(), conn) 137 if err != nil { 138 return err 139 } 140 141 ctx := sql.NewContext(context.Background(), sql.WithSession(sess)) 142 var db sql.Database 143 if dbName != "" { 144 db, err = s.getDbFunc(ctx, dbName) 145 if err != nil { 146 return err 147 } 148 } 149 150 sess.SetCurrentDatabase(dbName) 151 if dbName != "" { 152 if pdb, ok := db.(mysql_db.PrivilegedDatabase); ok { 153 db = pdb.Unwrap() 154 } 155 err = sess.UseDatabase(ctx, db) 156 if err != nil { 157 return err 158 } 159 } 160 161 s.processlist.ConnectionReady(sess) 162 return nil 163 } 164 165 // Iter iterates over the active sessions and executes the specified callback function on each one. 166 func (s *SessionManager) Iter(f func(session sql.Session) (stop bool, err error)) error { 167 // Lock the mutex guarding the sessions map while we make a copy of it to prevent errors from 168 // mutating a map while iterating over it. Making a copy of the map also allows us to guard 169 // against long running callback functions being passed in that could cause long mutex blocking. 170 s.mu.Lock() 171 sessions := make([]sql.Session, 0, len(s.sessions)) 172 for _, value := range s.sessions { 173 sessions = append(sessions, value) 174 } 175 s.mu.Unlock() 176 177 for _, sess := range sessions { 178 stop, err := f(sess) 179 if stop == true || err != nil { 180 return err 181 } 182 } 183 return nil 184 } 185 186 func (s *SessionManager) session(conn *mysql.Conn) sql.Session { 187 s.mu.Lock() 188 defer s.mu.Unlock() 189 return s.sessions[conn.ConnectionID] 190 } 191 192 // NewContext creates a new context for the session at the given conn. 193 func (s *SessionManager) NewContext(conn *mysql.Conn) (*sql.Context, error) { 194 return s.NewContextWithQuery(conn, "") 195 } 196 197 func (s *SessionManager) getOrCreateSession(ctx context.Context, conn *mysql.Conn) (sql.Session, error) { 198 s.mu.Lock() 199 sess, ok := s.sessions[conn.ConnectionID] 200 // Release this lock immediately. If we call NewSession below, we 201 // cannot hold the lock. We will relock if we need to. 202 s.mu.Unlock() 203 204 if !ok { 205 err := s.NewSession(ctx, conn) 206 if err != nil { 207 return nil, err 208 } 209 210 s.mu.Lock() 211 sess = s.sessions[conn.ConnectionID] 212 s.mu.Unlock() 213 } 214 215 return sess, nil 216 } 217 218 // NewContextWithQuery creates a new context for the session at the given conn. 219 func (s *SessionManager) NewContextWithQuery(conn *mysql.Conn, query string) (*sql.Context, error) { 220 ctx := context.Background() 221 sess, err := s.getOrCreateSession(ctx, conn) 222 223 if err != nil { 224 return nil, err 225 } 226 227 ctx, span := s.tracer.Start(ctx, "query") 228 229 context := sql.NewContext( 230 ctx, 231 sql.WithSession(sess), 232 sql.WithTracer(s.tracer), 233 sql.WithPid(s.nextPid()), 234 sql.WithQuery(query), 235 sql.WithMemoryManager(s.memory), 236 sql.WithProcessList(s.processlist), 237 sql.WithRootSpan(span), 238 sql.WithServices(sql.Services{ 239 KillConnection: s.KillConnection, 240 LoadInfile: conn.LoadInfile, 241 }), 242 ) 243 244 return context, nil 245 } 246 247 // Exposed through (*sql.Context).Services.KillConnection. Calls Close on the 248 // tracked connection with |connID|. The full teardown of the connection is 249 // asychronous, similar to how |Process.Kill| for tearing down an inflight 250 // query is asynchronous. The connection and any running query will remain in 251 // the ProcessList and in the SessionManager until it has been torn down by the 252 // server handler. 253 func (s *SessionManager) KillConnection(connID uint32) error { 254 s.mu.Lock() 255 defer s.mu.Unlock() 256 if conn, ok := s.connections[connID]; ok { 257 conn.Close() 258 } 259 return nil 260 } 261 262 // Remove the session assosiated with |conn| from the session manager. 263 func (s *SessionManager) RemoveConn(conn *mysql.Conn) { 264 s.mu.Lock() 265 defer s.mu.Unlock() 266 delete(s.sessions, conn.ConnectionID) 267 delete(s.connections, conn.ConnectionID) 268 s.processlist.RemoveConnection(conn.ConnectionID) 269 }