github.com/dolthub/go-mysql-server@v0.18.0/memory/session.go (about) 1 // Copyright 2023 Dolthub, Inc. 2 // 3 // Licensed under the Apache License, Version 2.0 (the "License"); 4 // you may not use this file except in compliance with the License. 5 // You may obtain a copy of the License at 6 // 7 // http://www.apache.org/licenses/LICENSE-2.0 8 // 9 // Unless required by applicable law or agreed to in writing, software 10 // distributed under the License is distributed on an "AS IS" BASIS, 11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 // See the License for the specific language governing permissions and 13 // limitations under the License. 14 15 package memory 16 17 import ( 18 "fmt" 19 "strings" 20 21 "github.com/dolthub/go-mysql-server/sql" 22 ) 23 24 type GlobalsMap = map[string]interface{} 25 type Session struct { 26 *sql.BaseSession 27 dbProvider sql.DatabaseProvider 28 tables map[tableKey]*TableData 29 editAccumulators map[tableKey]tableEditAccumulator 30 persistedGlobals GlobalsMap 31 validateCallback func() 32 } 33 34 var _ sql.Session = (*Session)(nil) 35 var _ sql.TransactionSession = (*Session)(nil) 36 var _ sql.Transaction = (*Transaction)(nil) 37 var _ sql.PersistableSession = (*Session)(nil) 38 39 // NewSession returns the new session for this object 40 func NewSession(baseSession *sql.BaseSession, provider sql.DatabaseProvider) *Session { 41 return &Session{ 42 BaseSession: baseSession, 43 dbProvider: provider, 44 tables: make(map[tableKey]*TableData), 45 editAccumulators: make(map[tableKey]tableEditAccumulator), 46 } 47 } 48 49 func SessionFromContext(ctx *sql.Context) *Session { 50 return ctx.Session.(*Session) 51 } 52 53 type Transaction struct { 54 readOnly bool 55 } 56 57 var _ sql.Transaction = (*Transaction)(nil) 58 59 func (s *Transaction) String() string { 60 return "in-memory transaction" 61 } 62 63 func (s *Transaction) IsReadOnly() bool { 64 return s.readOnly 65 } 66 67 type tableKey struct { 68 db string 69 table string 70 } 71 72 func key(t *TableData) tableKey { 73 return tableKey{strings.ToLower(t.dbName), strings.ToLower(t.tableName)} 74 } 75 76 // editAccumulator returns the edit accumulator for this session for the table provided. Some statement types, like 77 // updates with an on duplicate key clause, require an accumulator to be shared among all table editors 78 func (s *Session) editAccumulator(t *Table) tableEditAccumulator { 79 ea, ok := s.editAccumulators[key(t.data)] 80 if !ok { 81 ea = newTableEditAccumulator(t.data) 82 s.editAccumulators[key(t.data)] = ea 83 } 84 return ea 85 } 86 87 func (s *Session) clearEditAccumulator(t *Table) { 88 delete(s.editAccumulators, key(t.data)) 89 } 90 91 func keyFromNames(dbName, tableName string) tableKey { 92 return tableKey{strings.ToLower(dbName), strings.ToLower(tableName)} 93 } 94 95 // tableData returns the table data for this session for the table provided 96 func (s *Session) tableData(t *Table) *TableData { 97 td, ok := s.tables[key(t.data)] 98 if !ok { 99 s.tables[key(t.data)] = t.data 100 return t.data 101 } 102 103 return td 104 } 105 106 // putTable stores the table data for this session for the table provided 107 func (s *Session) putTable(d *TableData) { 108 s.tables[key(d)] = d 109 delete(s.editAccumulators, key(d)) 110 } 111 112 // dropTable clears the table data for the session 113 func (s *Session) dropTable(d *TableData) { 114 delete(s.tables, key(d)) 115 } 116 117 // StartTransaction clears session state and returns a new transaction object. 118 // Because we don't support concurrency, we store table data changes in the session, rather than the transaction itself. 119 func (s *Session) StartTransaction(ctx *sql.Context, tCharacteristic sql.TransactionCharacteristic) (sql.Transaction, error) { 120 s.tables = make(map[tableKey]*TableData) 121 s.editAccumulators = make(map[tableKey]tableEditAccumulator) 122 return &Transaction{tCharacteristic == sql.ReadOnly}, nil 123 } 124 125 func (s *Session) CommitTransaction(ctx *sql.Context, tx sql.Transaction) error { 126 for key := range s.tables { 127 if key.db == "" && key.table == "" { 128 // dual table 129 continue 130 } 131 db, err := s.dbProvider.Database(ctx, key.db) 132 if err != nil { 133 return err 134 } 135 136 var baseDb *BaseDatabase 137 switch db := db.(type) { 138 case *BaseDatabase: 139 baseDb = db 140 case *Database: 141 baseDb = db.BaseDatabase 142 case *HistoryDatabase: 143 baseDb = db.BaseDatabase 144 default: 145 return fmt.Errorf("unknown database type %T", db) 146 } 147 baseDb.putTable(s.tables[key].Table(baseDb)) 148 } 149 150 return nil 151 } 152 153 func (s *Session) Rollback(ctx *sql.Context, transaction sql.Transaction) error { 154 s.tables = make(map[tableKey]*TableData) 155 s.editAccumulators = make(map[tableKey]tableEditAccumulator) 156 return nil 157 } 158 159 func (s *Session) CreateSavepoint(ctx *sql.Context, transaction sql.Transaction, name string) error { 160 return fmt.Errorf("savepoints are not supported in memory sessions") 161 } 162 163 func (s *Session) RollbackToSavepoint(ctx *sql.Context, transaction sql.Transaction, name string) error { 164 return fmt.Errorf("savepoints are not supported in memory sessions") 165 } 166 167 func (s *Session) ReleaseSavepoint(ctx *sql.Context, transaction sql.Transaction, name string) error { 168 return fmt.Errorf("savepoints are not supported in memory sessions") 169 } 170 171 // PersistGlobal implements sql.PersistableSession 172 func (s *Session) PersistGlobal(sysVarName string, value interface{}) error { 173 sysVar, _, ok := sql.SystemVariables.GetGlobal(sysVarName) 174 if !ok { 175 return sql.ErrUnknownSystemVariable.New(sysVarName) 176 } 177 val, _, err := sysVar.Type.Convert(value) 178 if err != nil { 179 return err 180 } 181 s.persistedGlobals[sysVarName] = val 182 return nil 183 } 184 185 func (s *Session) SetGlobals(globals map[string]interface{}) *Session { 186 s.persistedGlobals = globals 187 return s 188 } 189 190 func (s *Session) SetValidationCallback(validationCallback func()) *Session { 191 s.validateCallback = validationCallback 192 return s 193 } 194 195 // RemovePersistedGlobal implements sql.PersistableSession 196 func (s *Session) RemovePersistedGlobal(sysVarName string) error { 197 if _, _, ok := sql.SystemVariables.GetGlobal(sysVarName); !ok { 198 return sql.ErrUnknownSystemVariable.New(sysVarName) 199 } 200 delete(s.persistedGlobals, sysVarName) 201 return nil 202 } 203 204 // RemoveAllPersistedGlobals implements sql.PersistableSession 205 func (s *Session) RemoveAllPersistedGlobals() error { 206 s.persistedGlobals = GlobalsMap{} 207 return nil 208 } 209 210 // GetPersistedValue implements sql.PersistableSession 211 func (s *Session) GetPersistedValue(k string) (interface{}, error) { 212 return s.persistedGlobals[k], nil 213 } 214 215 // ValidateSession counts the number of times this method is called. 216 func (s *Session) ValidateSession(ctx *sql.Context) error { 217 if s.validateCallback != nil { 218 s.validateCallback() 219 } 220 return s.BaseSession.ValidateSession(ctx) 221 }