github.com/dolthub/go-mysql-server@v0.18.0/server/context.go (about)

     1  // Copyright 2020-2021 Dolthub, Inc.
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //     http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License.
    14  
    15  package server
    16  
    17  import (
    18  	"context"
    19  	"sync"
    20  	"time"
    21  
    22  	"github.com/dolthub/vitess/go/mysql"
    23  	"github.com/sirupsen/logrus"
    24  	"go.opentelemetry.io/otel/trace"
    25  
    26  	"github.com/dolthub/go-mysql-server/sql"
    27  	"github.com/dolthub/go-mysql-server/sql/mysql_db"
    28  )
    29  
    30  // SessionBuilder creates sessions given a MySQL connection and a server address.
    31  type SessionBuilder func(ctx context.Context, conn *mysql.Conn, addr string) (sql.Session, error)
    32  
    33  // DoneFunc is a function that must be executed when the session is used and
    34  // it can be disposed.
    35  type DoneFunc func()
    36  
    37  // DefaultSessionBuilder is a SessionBuilder that returns a base session.
    38  func DefaultSessionBuilder(ctx context.Context, c *mysql.Conn, addr string) (sql.Session, error) {
    39  	host := ""
    40  	user := ""
    41  	mysqlConnectionUser, ok := c.UserData.(mysql_db.MysqlConnectionUser)
    42  	if ok {
    43  		host = mysqlConnectionUser.Host
    44  		user = mysqlConnectionUser.User
    45  	}
    46  	client := sql.Client{Address: host, User: user, Capabilities: c.Capabilities}
    47  	return sql.NewBaseSessionWithClientServer(addr, client, c.ConnectionID), nil
    48  }
    49  
    50  // SessionManager is in charge of creating new sessions for the given
    51  // connections and keep track of which sessions are in each connection, so
    52  // they can be cancelled if the connection is closed.
    53  type SessionManager struct {
    54  	addr        string
    55  	tracer      trace.Tracer
    56  	getDbFunc   func(ctx *sql.Context, db string) (sql.Database, error)
    57  	memory      *sql.MemoryManager
    58  	processlist sql.ProcessList
    59  	mu          *sync.Mutex
    60  	builder     SessionBuilder
    61  	sessions    map[uint32]sql.Session
    62  	connections map[uint32]*mysql.Conn
    63  	lastPid     uint64
    64  }
    65  
    66  // NewSessionManager creates a SessionManager with the given SessionBuilder.
    67  func NewSessionManager(
    68  	builder SessionBuilder,
    69  	tracer trace.Tracer,
    70  	getDbFunc func(ctx *sql.Context, db string) (sql.Database, error),
    71  	memory *sql.MemoryManager,
    72  	processlist sql.ProcessList,
    73  	addr string,
    74  ) *SessionManager {
    75  	return &SessionManager{
    76  		addr:        addr,
    77  		tracer:      tracer,
    78  		getDbFunc:   getDbFunc,
    79  		memory:      memory,
    80  		processlist: processlist,
    81  		mu:          new(sync.Mutex),
    82  		builder:     builder,
    83  		sessions:    make(map[uint32]sql.Session),
    84  		connections: make(map[uint32]*mysql.Conn),
    85  	}
    86  }
    87  
    88  func (s *SessionManager) nextPid() uint64 {
    89  	s.mu.Lock()
    90  	defer s.mu.Unlock()
    91  	s.lastPid++
    92  	return s.lastPid
    93  }
    94  
    95  // Add a connection to be tracked by the SessionManager. Should be called as
    96  // soon as possible after the server has accepted the connection. Results in
    97  // the connection being tracked by ProcessList and being available through
    98  // KillConnection. The connection will be tracked until RemoveConn is called,
    99  // so clients should ensure a call to AddConn is always paired up with a call
   100  // to RemoveConn.
   101  func (s *SessionManager) AddConn(conn *mysql.Conn) {
   102  	s.mu.Lock()
   103  	defer s.mu.Unlock()
   104  	s.connections[conn.ConnectionID] = conn
   105  	s.processlist.AddConnection(conn.ConnectionID, conn.RemoteAddr().String())
   106  }
   107  
   108  // NewSession creates a Session for the given connection and saves it to the session pool.
   109  func (s *SessionManager) NewSession(ctx context.Context, conn *mysql.Conn) error {
   110  	s.mu.Lock()
   111  	defer s.mu.Unlock()
   112  	session, err := s.builder(ctx, conn, s.addr)
   113  	if err != nil {
   114  		return err
   115  	}
   116  
   117  	session.SetConnectionId(conn.ConnectionID)
   118  
   119  	s.sessions[conn.ConnectionID] = session
   120  
   121  	logger := session.GetLogger()
   122  	if logger == nil {
   123  		log := logrus.StandardLogger()
   124  		logger = logrus.NewEntry(log)
   125  	}
   126  
   127  	session.SetLogger(
   128  		logger.WithField(sql.ConnectionIdLogField, conn.ConnectionID).
   129  			WithField(sql.ConnectTimeLogKey, time.Now()),
   130  	)
   131  
   132  	return err
   133  }
   134  
   135  func (s *SessionManager) SetDB(conn *mysql.Conn, dbName string) error {
   136  	sess, err := s.getOrCreateSession(context.Background(), conn)
   137  	if err != nil {
   138  		return err
   139  	}
   140  
   141  	ctx := sql.NewContext(context.Background(), sql.WithSession(sess))
   142  	var db sql.Database
   143  	if dbName != "" {
   144  		db, err = s.getDbFunc(ctx, dbName)
   145  		if err != nil {
   146  			return err
   147  		}
   148  	}
   149  
   150  	sess.SetCurrentDatabase(dbName)
   151  	if dbName != "" {
   152  		if pdb, ok := db.(mysql_db.PrivilegedDatabase); ok {
   153  			db = pdb.Unwrap()
   154  		}
   155  		err = sess.UseDatabase(ctx, db)
   156  		if err != nil {
   157  			return err
   158  		}
   159  	}
   160  
   161  	s.processlist.ConnectionReady(sess)
   162  	return nil
   163  }
   164  
   165  // Iter iterates over the active sessions and executes the specified callback function on each one.
   166  func (s *SessionManager) Iter(f func(session sql.Session) (stop bool, err error)) error {
   167  	// Lock the mutex guarding the sessions map while we make a copy of it to prevent errors from
   168  	// mutating a map while iterating over it. Making a copy of the map also allows us to guard
   169  	// against long running callback functions being passed in that could cause long mutex blocking.
   170  	s.mu.Lock()
   171  	sessions := make([]sql.Session, 0, len(s.sessions))
   172  	for _, value := range s.sessions {
   173  		sessions = append(sessions, value)
   174  	}
   175  	s.mu.Unlock()
   176  
   177  	for _, sess := range sessions {
   178  		stop, err := f(sess)
   179  		if stop == true || err != nil {
   180  			return err
   181  		}
   182  	}
   183  	return nil
   184  }
   185  
   186  func (s *SessionManager) session(conn *mysql.Conn) sql.Session {
   187  	s.mu.Lock()
   188  	defer s.mu.Unlock()
   189  	return s.sessions[conn.ConnectionID]
   190  }
   191  
   192  // NewContext creates a new context for the session at the given conn.
   193  func (s *SessionManager) NewContext(conn *mysql.Conn) (*sql.Context, error) {
   194  	return s.NewContextWithQuery(conn, "")
   195  }
   196  
   197  func (s *SessionManager) getOrCreateSession(ctx context.Context, conn *mysql.Conn) (sql.Session, error) {
   198  	s.mu.Lock()
   199  	sess, ok := s.sessions[conn.ConnectionID]
   200  	// Release this lock immediately. If we call NewSession below, we
   201  	// cannot hold the lock. We will relock if we need to.
   202  	s.mu.Unlock()
   203  
   204  	if !ok {
   205  		err := s.NewSession(ctx, conn)
   206  		if err != nil {
   207  			return nil, err
   208  		}
   209  
   210  		s.mu.Lock()
   211  		sess = s.sessions[conn.ConnectionID]
   212  		s.mu.Unlock()
   213  	}
   214  
   215  	return sess, nil
   216  }
   217  
   218  // NewContextWithQuery creates a new context for the session at the given conn.
   219  func (s *SessionManager) NewContextWithQuery(conn *mysql.Conn, query string) (*sql.Context, error) {
   220  	ctx := context.Background()
   221  	sess, err := s.getOrCreateSession(ctx, conn)
   222  
   223  	if err != nil {
   224  		return nil, err
   225  	}
   226  
   227  	ctx, span := s.tracer.Start(ctx, "query")
   228  
   229  	context := sql.NewContext(
   230  		ctx,
   231  		sql.WithSession(sess),
   232  		sql.WithTracer(s.tracer),
   233  		sql.WithPid(s.nextPid()),
   234  		sql.WithQuery(query),
   235  		sql.WithMemoryManager(s.memory),
   236  		sql.WithProcessList(s.processlist),
   237  		sql.WithRootSpan(span),
   238  		sql.WithServices(sql.Services{
   239  			KillConnection: s.KillConnection,
   240  			LoadInfile:     conn.LoadInfile,
   241  		}),
   242  	)
   243  
   244  	return context, nil
   245  }
   246  
   247  // Exposed through (*sql.Context).Services.KillConnection. Calls Close on the
   248  // tracked connection with |connID|. The full teardown of the connection is
   249  // asychronous, similar to how |Process.Kill| for tearing down an inflight
   250  // query is asynchronous. The connection and any running query will remain in
   251  // the ProcessList and in the SessionManager until it has been torn down by the
   252  // server handler.
   253  func (s *SessionManager) KillConnection(connID uint32) error {
   254  	s.mu.Lock()
   255  	defer s.mu.Unlock()
   256  	if conn, ok := s.connections[connID]; ok {
   257  		conn.Close()
   258  	}
   259  	return nil
   260  }
   261  
   262  // Remove the session assosiated with |conn| from the session manager.
   263  func (s *SessionManager) RemoveConn(conn *mysql.Conn) {
   264  	s.mu.Lock()
   265  	defer s.mu.Unlock()
   266  	delete(s.sessions, conn.ConnectionID)
   267  	delete(s.connections, conn.ConnectionID)
   268  	s.processlist.RemoveConnection(conn.ConnectionID)
   269  }