github.com/dolthub/go-mysql-server@v0.18.0/memory/session.go (about)

     1  // Copyright 2023 Dolthub, Inc.
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //     http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License.
    14  
    15  package memory
    16  
    17  import (
    18  	"fmt"
    19  	"strings"
    20  
    21  	"github.com/dolthub/go-mysql-server/sql"
    22  )
    23  
    24  type GlobalsMap = map[string]interface{}
    25  type Session struct {
    26  	*sql.BaseSession
    27  	dbProvider       sql.DatabaseProvider
    28  	tables           map[tableKey]*TableData
    29  	editAccumulators map[tableKey]tableEditAccumulator
    30  	persistedGlobals GlobalsMap
    31  	validateCallback func()
    32  }
    33  
    34  var _ sql.Session = (*Session)(nil)
    35  var _ sql.TransactionSession = (*Session)(nil)
    36  var _ sql.Transaction = (*Transaction)(nil)
    37  var _ sql.PersistableSession = (*Session)(nil)
    38  
    39  // NewSession returns the new session for this object
    40  func NewSession(baseSession *sql.BaseSession, provider sql.DatabaseProvider) *Session {
    41  	return &Session{
    42  		BaseSession:      baseSession,
    43  		dbProvider:       provider,
    44  		tables:           make(map[tableKey]*TableData),
    45  		editAccumulators: make(map[tableKey]tableEditAccumulator),
    46  	}
    47  }
    48  
    49  func SessionFromContext(ctx *sql.Context) *Session {
    50  	return ctx.Session.(*Session)
    51  }
    52  
    53  type Transaction struct {
    54  	readOnly bool
    55  }
    56  
    57  var _ sql.Transaction = (*Transaction)(nil)
    58  
    59  func (s *Transaction) String() string {
    60  	return "in-memory transaction"
    61  }
    62  
    63  func (s *Transaction) IsReadOnly() bool {
    64  	return s.readOnly
    65  }
    66  
    67  type tableKey struct {
    68  	db    string
    69  	table string
    70  }
    71  
    72  func key(t *TableData) tableKey {
    73  	return tableKey{strings.ToLower(t.dbName), strings.ToLower(t.tableName)}
    74  }
    75  
    76  // editAccumulator returns the edit accumulator for this session for the table provided. Some statement types, like
    77  // updates with an on duplicate key clause, require an accumulator to be shared among all table editors
    78  func (s *Session) editAccumulator(t *Table) tableEditAccumulator {
    79  	ea, ok := s.editAccumulators[key(t.data)]
    80  	if !ok {
    81  		ea = newTableEditAccumulator(t.data)
    82  		s.editAccumulators[key(t.data)] = ea
    83  	}
    84  	return ea
    85  }
    86  
    87  func (s *Session) clearEditAccumulator(t *Table) {
    88  	delete(s.editAccumulators, key(t.data))
    89  }
    90  
    91  func keyFromNames(dbName, tableName string) tableKey {
    92  	return tableKey{strings.ToLower(dbName), strings.ToLower(tableName)}
    93  }
    94  
    95  // tableData returns the table data for this session for the table provided
    96  func (s *Session) tableData(t *Table) *TableData {
    97  	td, ok := s.tables[key(t.data)]
    98  	if !ok {
    99  		s.tables[key(t.data)] = t.data
   100  		return t.data
   101  	}
   102  
   103  	return td
   104  }
   105  
   106  // putTable stores the table data for this session for the table provided
   107  func (s *Session) putTable(d *TableData) {
   108  	s.tables[key(d)] = d
   109  	delete(s.editAccumulators, key(d))
   110  }
   111  
   112  // dropTable clears the table data for the session
   113  func (s *Session) dropTable(d *TableData) {
   114  	delete(s.tables, key(d))
   115  }
   116  
   117  // StartTransaction clears session state and returns a new transaction object.
   118  // Because we don't support concurrency, we store table data changes in the session, rather than the transaction itself.
   119  func (s *Session) StartTransaction(ctx *sql.Context, tCharacteristic sql.TransactionCharacteristic) (sql.Transaction, error) {
   120  	s.tables = make(map[tableKey]*TableData)
   121  	s.editAccumulators = make(map[tableKey]tableEditAccumulator)
   122  	return &Transaction{tCharacteristic == sql.ReadOnly}, nil
   123  }
   124  
   125  func (s *Session) CommitTransaction(ctx *sql.Context, tx sql.Transaction) error {
   126  	for key := range s.tables {
   127  		if key.db == "" && key.table == "" {
   128  			// dual table
   129  			continue
   130  		}
   131  		db, err := s.dbProvider.Database(ctx, key.db)
   132  		if err != nil {
   133  			return err
   134  		}
   135  
   136  		var baseDb *BaseDatabase
   137  		switch db := db.(type) {
   138  		case *BaseDatabase:
   139  			baseDb = db
   140  		case *Database:
   141  			baseDb = db.BaseDatabase
   142  		case *HistoryDatabase:
   143  			baseDb = db.BaseDatabase
   144  		default:
   145  			return fmt.Errorf("unknown database type %T", db)
   146  		}
   147  		baseDb.putTable(s.tables[key].Table(baseDb))
   148  	}
   149  
   150  	return nil
   151  }
   152  
   153  func (s *Session) Rollback(ctx *sql.Context, transaction sql.Transaction) error {
   154  	s.tables = make(map[tableKey]*TableData)
   155  	s.editAccumulators = make(map[tableKey]tableEditAccumulator)
   156  	return nil
   157  }
   158  
   159  func (s *Session) CreateSavepoint(ctx *sql.Context, transaction sql.Transaction, name string) error {
   160  	return fmt.Errorf("savepoints are not supported in memory sessions")
   161  }
   162  
   163  func (s *Session) RollbackToSavepoint(ctx *sql.Context, transaction sql.Transaction, name string) error {
   164  	return fmt.Errorf("savepoints are not supported in memory sessions")
   165  }
   166  
   167  func (s *Session) ReleaseSavepoint(ctx *sql.Context, transaction sql.Transaction, name string) error {
   168  	return fmt.Errorf("savepoints are not supported in memory sessions")
   169  }
   170  
   171  // PersistGlobal implements sql.PersistableSession
   172  func (s *Session) PersistGlobal(sysVarName string, value interface{}) error {
   173  	sysVar, _, ok := sql.SystemVariables.GetGlobal(sysVarName)
   174  	if !ok {
   175  		return sql.ErrUnknownSystemVariable.New(sysVarName)
   176  	}
   177  	val, _, err := sysVar.Type.Convert(value)
   178  	if err != nil {
   179  		return err
   180  	}
   181  	s.persistedGlobals[sysVarName] = val
   182  	return nil
   183  }
   184  
   185  func (s *Session) SetGlobals(globals map[string]interface{}) *Session {
   186  	s.persistedGlobals = globals
   187  	return s
   188  }
   189  
   190  func (s *Session) SetValidationCallback(validationCallback func()) *Session {
   191  	s.validateCallback = validationCallback
   192  	return s
   193  }
   194  
   195  // RemovePersistedGlobal implements sql.PersistableSession
   196  func (s *Session) RemovePersistedGlobal(sysVarName string) error {
   197  	if _, _, ok := sql.SystemVariables.GetGlobal(sysVarName); !ok {
   198  		return sql.ErrUnknownSystemVariable.New(sysVarName)
   199  	}
   200  	delete(s.persistedGlobals, sysVarName)
   201  	return nil
   202  }
   203  
   204  // RemoveAllPersistedGlobals implements sql.PersistableSession
   205  func (s *Session) RemoveAllPersistedGlobals() error {
   206  	s.persistedGlobals = GlobalsMap{}
   207  	return nil
   208  }
   209  
   210  // GetPersistedValue implements sql.PersistableSession
   211  func (s *Session) GetPersistedValue(k string) (interface{}, error) {
   212  	return s.persistedGlobals[k], nil
   213  }
   214  
   215  // ValidateSession counts the number of times this method is called.
   216  func (s *Session) ValidateSession(ctx *sql.Context) error {
   217  	if s.validateCallback != nil {
   218  		s.validateCallback()
   219  	}
   220  	return s.BaseSession.ValidateSession(ctx)
   221  }