github.com/dolthub/go-mysql-server@v0.18.0/sql/base_session.go (about)

     1  // Copyright 2022 Dolthub, Inc.
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //     http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License.
    14  
    15  package sql
    16  
    17  import (
    18  	"strings"
    19  	"sync"
    20  	"sync/atomic"
    21  
    22  	"github.com/sirupsen/logrus"
    23  )
    24  
    25  // BaseSession is the basic session implementation. Integrators should typically embed this type into their custom
    26  // session implementations to get base functionality.
    27  type BaseSession struct {
    28  	id     uint32
    29  	addr   string
    30  	client Client
    31  
    32  	// TODO(andy): in principle, we shouldn't
    33  	//   have concurrent access to the session.
    34  	//   Needs investigation.
    35  	mu sync.RWMutex
    36  
    37  	// |mu| protects the following state
    38  	logger           *logrus.Entry
    39  	currentDB        string
    40  	transactionDb    string
    41  	systemVars       map[string]SystemVarValue
    42  	userVars         SessionUserVariables
    43  	idxReg           *IndexRegistry
    44  	viewReg          *ViewRegistry
    45  	warnings         []*Warning
    46  	warncnt          uint16
    47  	locks            map[string]bool
    48  	queriedDb        string
    49  	lastQueryInfo    map[string]int64
    50  	tx               Transaction
    51  	ignoreAutocommit bool
    52  
    53  	// When the MySQL database updates any tables related to privileges, it increments its counter. We then update our
    54  	// privilege set if our counter doesn't equal the database's counter.
    55  	privSetCounter uint64
    56  	privilegeSet   PrivilegeSet
    57  }
    58  
    59  func (s *BaseSession) GetLogger() *logrus.Entry {
    60  	s.mu.Lock()
    61  	defer s.mu.Unlock()
    62  
    63  	if s.logger == nil {
    64  		s.logger = s.newLogger()
    65  	}
    66  	return s.logger
    67  }
    68  
    69  func (s *BaseSession) newLogger() *logrus.Entry {
    70  	log := logrus.StandardLogger()
    71  	return logrus.NewEntry(log)
    72  }
    73  
    74  func (s *BaseSession) SetLogger(logger *logrus.Entry) {
    75  	s.mu.Lock()
    76  	defer s.mu.Unlock()
    77  	s.logger = logger
    78  }
    79  
    80  func (s *BaseSession) SetIgnoreAutoCommit(ignore bool) {
    81  	s.mu.Lock()
    82  	defer s.mu.Unlock()
    83  	s.ignoreAutocommit = ignore
    84  }
    85  
    86  func (s *BaseSession) GetIgnoreAutoCommit() bool {
    87  	s.mu.RLock()
    88  	defer s.mu.RUnlock()
    89  	return s.ignoreAutocommit
    90  }
    91  
    92  var _ Session = (*BaseSession)(nil) // compile-time check that *BaseSession implements Session
    93  
    94  func (s *BaseSession) SetTransactionDatabase(dbName string) {
    95  	s.mu.Lock()
    96  	defer s.mu.Unlock()
    97  	s.transactionDb = dbName
    98  }
    99  
   100  func (s *BaseSession) GetTransactionDatabase() string {
   101  	s.mu.RLock()
   102  	defer s.mu.RUnlock()
   103  	return s.transactionDb
   104  }
   105  
   106  // Address returns the server address.
   107  func (s *BaseSession) Address() string { return s.addr }
   108  
   109  // Client returns session's client information.
   110  func (s *BaseSession) Client() Client { return s.client }
   111  
   112  // SetClient implements the Session interface.
   113  func (s *BaseSession) SetClient(c Client) {
   114  	s.client = c
   115  	return
   116  }
   117  
   118  // GetAllSessionVariables implements the Session interface.
   119  func (s *BaseSession) GetAllSessionVariables() map[string]interface{} {
   120  	m := make(map[string]interface{})
   121  	s.mu.RLock()
   122  	defer s.mu.RUnlock()
   123  
   124  	for k, v := range s.systemVars {
   125  		if sysType, ok := v.Var.Type.(SetType); ok {
   126  			if sv, ok := v.Val.(uint64); ok {
   127  				if svStr, err := sysType.BitsToString(sv); err == nil {
   128  					m[k] = svStr
   129  				}
   130  				continue
   131  			}
   132  		}
   133  		m[k] = v.Val
   134  	}
   135  	return m
   136  }
   137  
   138  // SetSessionVariable implements the Session interface.
   139  func (s *BaseSession) SetSessionVariable(ctx *Context, sysVarName string, value interface{}) error {
   140  	sysVarName = strings.ToLower(sysVarName)
   141  	sysVar, ok := s.systemVars[sysVarName]
   142  
   143  	// Since we initialized the system variables in this session at session start time, any variables that were added since that time
   144  	// will need to be added dynamically here.
   145  	// TODO: fix this with proper session lifecycle management
   146  	if !ok {
   147  		if SystemVariables != nil {
   148  			sv, _, ok := SystemVariables.GetGlobal(sysVarName)
   149  			if !ok {
   150  				return ErrUnknownSystemVariable.New(sysVarName)
   151  			}
   152  			return s.setSessVar(ctx, sv, value)
   153  		} else {
   154  			return ErrUnknownSystemVariable.New(sysVarName)
   155  		}
   156  	}
   157  
   158  	if !sysVar.Var.Dynamic || sysVar.Var.ValueFunction != nil {
   159  		return ErrSystemVariableReadOnly.New(sysVarName)
   160  	}
   161  	return s.setSessVar(ctx, sysVar.Var, value)
   162  }
   163  
   164  // InitSessionVariable implements the Session interface and is used to initialize variables (Including read-only variables)
   165  func (s *BaseSession) InitSessionVariable(ctx *Context, sysVarName string, value interface{}) error {
   166  	sysVar, _, ok := SystemVariables.GetGlobal(sysVarName)
   167  	if !ok {
   168  		return ErrUnknownSystemVariable.New(sysVarName)
   169  	}
   170  
   171  	val, ok := s.systemVars[sysVar.Name]
   172  	if ok && val.Val != sysVar.Default {
   173  		return ErrSystemVariableReinitialized.New(sysVarName)
   174  	}
   175  
   176  	return s.setSessVar(ctx, sysVar, value)
   177  }
   178  
   179  func (s *BaseSession) setSessVar(ctx *Context, sysVar SystemVariable, value interface{}) error {
   180  	if sysVar.Scope == SystemVariableScope_Global {
   181  		return ErrSystemVariableGlobalOnly.New(sysVar.Name)
   182  	}
   183  	convertedVal, _, err := sysVar.Type.Convert(value)
   184  	if err != nil {
   185  		return err
   186  	}
   187  	s.mu.Lock()
   188  	defer s.mu.Unlock()
   189  	svv := SystemVarValue{
   190  		Var: sysVar,
   191  		Val: convertedVal,
   192  	}
   193  
   194  	if sysVar.NotifyChanged != nil {
   195  		err := sysVar.NotifyChanged(SystemVariableScope_Session, svv)
   196  		if err != nil {
   197  			return err
   198  		}
   199  	}
   200  	s.systemVars[sysVar.Name] = svv
   201  	return nil
   202  }
   203  
   204  // SetUserVariable implements the Session interface.
   205  func (s *BaseSession) SetUserVariable(ctx *Context, varName string, value interface{}, typ Type) error {
   206  	return s.userVars.SetUserVariable(ctx, varName, value, typ)
   207  }
   208  
   209  // GetSessionVariable implements the Session interface.
   210  func (s *BaseSession) GetSessionVariable(ctx *Context, sysVarName string) (interface{}, error) {
   211  	s.mu.Lock()
   212  	defer s.mu.Unlock()
   213  
   214  	sysVarName = strings.ToLower(sysVarName)
   215  	sysVar, ok := s.systemVars[sysVarName]
   216  	if !ok {
   217  		return nil, ErrUnknownSystemVariable.New(sysVarName)
   218  	}
   219  	// TODO: this is duplicated from within variables.globalSystemVariables, suggesting the need for an interface
   220  	if sysType, ok := sysVar.Var.Type.(SetType); ok {
   221  		if sv, ok := sysVar.Val.(uint64); ok {
   222  			return sysType.BitsToString(sv)
   223  		}
   224  	}
   225  	return sysVar.Val, nil
   226  }
   227  
   228  // GetUserVariable implements the Session interface.
   229  func (s *BaseSession) GetUserVariable(ctx *Context, varName string) (Type, interface{}, error) {
   230  	return s.userVars.GetUserVariable(ctx, varName)
   231  }
   232  
   233  // GetCharacterSet returns the character set for this session (defined by the system variable `character_set_connection`).
   234  func (s *BaseSession) GetCharacterSet() CharacterSetID {
   235  	s.mu.RLock()
   236  	defer s.mu.RUnlock()
   237  	sysVar, _ := s.systemVars[characterSetConnectionSysVarName]
   238  	if sysVar.Val == nil {
   239  		return CharacterSet_Unspecified
   240  	}
   241  	charSet, err := ParseCharacterSet(sysVar.Val.(string))
   242  	if err != nil {
   243  		panic(err) // shouldn't happen
   244  	}
   245  	return charSet
   246  }
   247  
   248  // GetCharacterSetResults returns the result character set for this session (defined by the system variable `character_set_results`).
   249  func (s *BaseSession) GetCharacterSetResults() CharacterSetID {
   250  	s.mu.RLock()
   251  	defer s.mu.RUnlock()
   252  	sysVar, _ := s.systemVars[characterSetResultsSysVarName]
   253  	if sysVar.Val == nil {
   254  		return CharacterSet_Unspecified
   255  	}
   256  	charSet, err := ParseCharacterSet(sysVar.Val.(string))
   257  	if err != nil {
   258  		panic(err) // shouldn't happen
   259  	}
   260  	return charSet
   261  }
   262  
   263  // GetCollation returns the collation for this session (defined by the system variable `collation_connection`).
   264  func (s *BaseSession) GetCollation() CollationID {
   265  	s.mu.Lock()
   266  	defer s.mu.Unlock()
   267  	sysVar, ok := s.systemVars[collationConnectionSysVarName]
   268  
   269  	// In tests, the collation may not be set because the sys vars haven't been initialized
   270  	if !ok {
   271  		return Collation_Default
   272  	}
   273  	if sysVar.Val == nil {
   274  		return Collation_Unspecified
   275  	}
   276  	valStr := sysVar.Val.(string)
   277  	collation, err := ParseCollation(nil, &valStr, false)
   278  	if err != nil {
   279  		panic(err) // shouldn't happen
   280  	}
   281  	return collation
   282  }
   283  
   284  // ValidateSession provides integrators a chance to do any custom validation of this session before any query is executed in it.
   285  func (s *BaseSession) ValidateSession(ctx *Context) error {
   286  	return nil
   287  }
   288  
   289  // GetCurrentDatabase gets the current database for this session
   290  func (s *BaseSession) GetCurrentDatabase() string {
   291  	s.mu.RLock()
   292  	defer s.mu.RUnlock()
   293  	return s.currentDB
   294  }
   295  
   296  // SetCurrentDatabase sets the current database for this session
   297  func (s *BaseSession) SetCurrentDatabase(dbName string) {
   298  	s.mu.Lock()
   299  	defer s.mu.Unlock()
   300  	s.currentDB = dbName
   301  	logger := s.logger
   302  	if logger == nil {
   303  		logger = s.newLogger()
   304  	}
   305  	s.logger = logger.WithField(ConnectionDbLogField, dbName)
   306  }
   307  
   308  func (s *BaseSession) UseDatabase(ctx *Context, db Database) error {
   309  	// Nothing to do for default implementation
   310  	// Integrators should override this method on custom session implementations as necessary
   311  	return nil
   312  }
   313  
   314  // ID implements the Session interface.
   315  func (s *BaseSession) ID() uint32 { return s.id }
   316  
   317  // SetConnectionId sets the [id] for this session
   318  func (s *BaseSession) SetConnectionId(id uint32) {
   319  	s.id = id
   320  	return
   321  }
   322  
   323  // Warn stores the warning in the session.
   324  func (s *BaseSession) Warn(warn *Warning) {
   325  	s.mu.Lock()
   326  	defer s.mu.Unlock()
   327  	s.warnings = append(s.warnings, warn)
   328  }
   329  
   330  // Warnings returns a copy of session warnings (from the most recent - the last one)
   331  // The function implements sql.Session interface
   332  func (s *BaseSession) Warnings() []*Warning {
   333  	s.mu.RLock()
   334  	defer s.mu.RUnlock()
   335  
   336  	n := len(s.warnings)
   337  	warns := make([]*Warning, n)
   338  	for i := 0; i < n; i++ {
   339  		warns[i] = s.warnings[n-i-1]
   340  	}
   341  
   342  	return warns
   343  }
   344  
   345  // ClearWarnings cleans up session warnings
   346  func (s *BaseSession) ClearWarnings() {
   347  	s.mu.Lock()
   348  	defer s.mu.Unlock()
   349  
   350  	cnt := uint16(len(s.warnings))
   351  	if s.warncnt == cnt {
   352  		if s.warnings != nil {
   353  			s.warnings = s.warnings[:0]
   354  		}
   355  		s.warncnt = 0
   356  	} else {
   357  		s.warncnt = cnt
   358  	}
   359  }
   360  
   361  // WarningCount returns a number of session warnings
   362  func (s *BaseSession) WarningCount() uint16 {
   363  	s.mu.RLock()
   364  	defer s.mu.RUnlock()
   365  	return uint16(len(s.warnings))
   366  }
   367  
   368  // AddLock adds a lock to the set of locks owned by this user which will need to be released if this session terminates
   369  func (s *BaseSession) AddLock(lockName string) error {
   370  	s.mu.Lock()
   371  	defer s.mu.Unlock()
   372  
   373  	s.locks[lockName] = true
   374  	return nil
   375  }
   376  
   377  // DelLock removes a lock from the set of locks owned by this user
   378  func (s *BaseSession) DelLock(lockName string) error {
   379  	s.mu.Lock()
   380  	defer s.mu.Unlock()
   381  
   382  	delete(s.locks, lockName)
   383  	return nil
   384  }
   385  
   386  // IterLocks iterates through all locks owned by this user
   387  func (s *BaseSession) IterLocks(cb func(name string) error) error {
   388  	s.mu.RLock()
   389  	defer s.mu.RUnlock()
   390  
   391  	for name := range s.locks {
   392  		err := cb(name)
   393  
   394  		if err != nil {
   395  			return err
   396  		}
   397  	}
   398  
   399  	return nil
   400  }
   401  
   402  // GetQueriedDatabase implements the Session interface.
   403  func (s *BaseSession) GetQueriedDatabase() string {
   404  	s.mu.RLock()
   405  	defer s.mu.RUnlock()
   406  	return s.queriedDb
   407  }
   408  
   409  // SetQueriedDatabase implements the Session interface.
   410  func (s *BaseSession) SetQueriedDatabase(dbName string) {
   411  	s.mu.Lock()
   412  	defer s.mu.Unlock()
   413  	s.queriedDb = dbName
   414  }
   415  
   416  func (s *BaseSession) GetIndexRegistry() *IndexRegistry {
   417  	s.mu.Lock()
   418  	defer s.mu.Unlock()
   419  	return s.idxReg
   420  }
   421  
   422  func (s *BaseSession) GetViewRegistry() *ViewRegistry {
   423  	s.mu.Lock()
   424  	defer s.mu.Unlock()
   425  	return s.viewReg
   426  }
   427  
   428  func (s *BaseSession) SetIndexRegistry(reg *IndexRegistry) {
   429  	s.mu.Lock()
   430  	defer s.mu.Unlock()
   431  	s.idxReg = reg
   432  }
   433  
   434  func (s *BaseSession) SetViewRegistry(reg *ViewRegistry) {
   435  	s.mu.Lock()
   436  	defer s.mu.Unlock()
   437  	s.viewReg = reg
   438  }
   439  
   440  func (s *BaseSession) SetLastQueryInfo(key string, value int64) {
   441  	s.mu.Lock()
   442  	defer s.mu.Unlock()
   443  	s.lastQueryInfo[key] = value
   444  }
   445  
   446  func (s *BaseSession) GetLastQueryInfo(key string) int64 {
   447  	s.mu.RLock()
   448  	defer s.mu.RUnlock()
   449  	return s.lastQueryInfo[key]
   450  }
   451  
   452  func (s *BaseSession) GetTransaction() Transaction {
   453  	s.mu.RLock()
   454  	defer s.mu.RUnlock()
   455  	return s.tx
   456  }
   457  
   458  func (s *BaseSession) SetTransaction(tx Transaction) {
   459  	s.mu.Lock()
   460  	defer s.mu.Unlock()
   461  	s.tx = tx
   462  }
   463  
   464  func (s *BaseSession) GetPrivilegeSet() (PrivilegeSet, uint64) {
   465  	return s.privilegeSet, s.privSetCounter
   466  }
   467  
   468  func (s *BaseSession) SetPrivilegeSet(newPs PrivilegeSet, counter uint64) {
   469  	s.privSetCounter = counter
   470  	s.privilegeSet = newPs
   471  }
   472  
   473  // NewBaseSessionWithClientServer creates a new session with data.
   474  func NewBaseSessionWithClientServer(server string, client Client, id uint32) *BaseSession {
   475  	// TODO: if system variable "activate_all_roles_on_login" if set, activate all roles
   476  	var sessionVars map[string]SystemVarValue
   477  	if SystemVariables != nil {
   478  		sessionVars = SystemVariables.NewSessionMap()
   479  	} else {
   480  		sessionVars = make(map[string]SystemVarValue)
   481  	}
   482  	return &BaseSession{
   483  		addr:           server,
   484  		client:         client,
   485  		id:             id,
   486  		systemVars:     sessionVars,
   487  		userVars:       NewUserVars(),
   488  		idxReg:         NewIndexRegistry(),
   489  		viewReg:        NewViewRegistry(),
   490  		mu:             sync.RWMutex{},
   491  		locks:          make(map[string]bool),
   492  		lastQueryInfo:  defaultLastQueryInfo(),
   493  		privSetCounter: 0,
   494  	}
   495  }
   496  
   497  // NewBaseSession creates a new empty session.
   498  func NewBaseSession() *BaseSession {
   499  	// TODO: if system variable "activate_all_roles_on_login" if set, activate all roles
   500  	var sessionVars map[string]SystemVarValue
   501  	if SystemVariables != nil {
   502  		sessionVars = SystemVariables.NewSessionMap()
   503  	} else {
   504  		sessionVars = make(map[string]SystemVarValue)
   505  	}
   506  	return &BaseSession{
   507  		id:             atomic.AddUint32(&autoSessionIDs, 1),
   508  		systemVars:     sessionVars,
   509  		userVars:       NewUserVars(),
   510  		idxReg:         NewIndexRegistry(),
   511  		viewReg:        NewViewRegistry(),
   512  		mu:             sync.RWMutex{},
   513  		locks:          make(map[string]bool),
   514  		lastQueryInfo:  defaultLastQueryInfo(),
   515  		privSetCounter: 0,
   516  	}
   517  }