github.com/dolthub/go-mysql-server@v0.18.0/sql/base_session.go (about) 1 // Copyright 2022 Dolthub, Inc. 2 // 3 // Licensed under the Apache License, Version 2.0 (the "License"); 4 // you may not use this file except in compliance with the License. 5 // You may obtain a copy of the License at 6 // 7 // http://www.apache.org/licenses/LICENSE-2.0 8 // 9 // Unless required by applicable law or agreed to in writing, software 10 // distributed under the License is distributed on an "AS IS" BASIS, 11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 // See the License for the specific language governing permissions and 13 // limitations under the License. 14 15 package sql 16 17 import ( 18 "strings" 19 "sync" 20 "sync/atomic" 21 22 "github.com/sirupsen/logrus" 23 ) 24 25 // BaseSession is the basic session implementation. Integrators should typically embed this type into their custom 26 // session implementations to get base functionality. 27 type BaseSession struct { 28 id uint32 29 addr string 30 client Client 31 32 // TODO(andy): in principle, we shouldn't 33 // have concurrent access to the session. 34 // Needs investigation. 35 mu sync.RWMutex 36 37 // |mu| protects the following state 38 logger *logrus.Entry 39 currentDB string 40 transactionDb string 41 systemVars map[string]SystemVarValue 42 userVars SessionUserVariables 43 idxReg *IndexRegistry 44 viewReg *ViewRegistry 45 warnings []*Warning 46 warncnt uint16 47 locks map[string]bool 48 queriedDb string 49 lastQueryInfo map[string]int64 50 tx Transaction 51 ignoreAutocommit bool 52 53 // When the MySQL database updates any tables related to privileges, it increments its counter. We then update our 54 // privilege set if our counter doesn't equal the database's counter. 55 privSetCounter uint64 56 privilegeSet PrivilegeSet 57 } 58 59 func (s *BaseSession) GetLogger() *logrus.Entry { 60 s.mu.Lock() 61 defer s.mu.Unlock() 62 63 if s.logger == nil { 64 s.logger = s.newLogger() 65 } 66 return s.logger 67 } 68 69 func (s *BaseSession) newLogger() *logrus.Entry { 70 log := logrus.StandardLogger() 71 return logrus.NewEntry(log) 72 } 73 74 func (s *BaseSession) SetLogger(logger *logrus.Entry) { 75 s.mu.Lock() 76 defer s.mu.Unlock() 77 s.logger = logger 78 } 79 80 func (s *BaseSession) SetIgnoreAutoCommit(ignore bool) { 81 s.mu.Lock() 82 defer s.mu.Unlock() 83 s.ignoreAutocommit = ignore 84 } 85 86 func (s *BaseSession) GetIgnoreAutoCommit() bool { 87 s.mu.RLock() 88 defer s.mu.RUnlock() 89 return s.ignoreAutocommit 90 } 91 92 var _ Session = (*BaseSession)(nil) 93 94 func (s *BaseSession) SetTransactionDatabase(dbName string) { 95 s.mu.Lock() 96 defer s.mu.Unlock() 97 s.transactionDb = dbName 98 } 99 100 func (s *BaseSession) GetTransactionDatabase() string { 101 s.mu.RLock() 102 defer s.mu.RUnlock() 103 return s.transactionDb 104 } 105 106 // Address returns the server address. 107 func (s *BaseSession) Address() string { return s.addr } 108 109 // Client returns session's client information. 110 func (s *BaseSession) Client() Client { return s.client } 111 112 // SetClient implements the Session interface. 113 func (s *BaseSession) SetClient(c Client) { 114 s.client = c 115 return 116 } 117 118 // GetAllSessionVariables implements the Session interface. 119 func (s *BaseSession) GetAllSessionVariables() map[string]interface{} { 120 m := make(map[string]interface{}) 121 s.mu.RLock() 122 defer s.mu.RUnlock() 123 124 for k, v := range s.systemVars { 125 if sysType, ok := v.Var.Type.(SetType); ok { 126 if sv, ok := v.Val.(uint64); ok { 127 if svStr, err := sysType.BitsToString(sv); err == nil { 128 m[k] = svStr 129 } 130 continue 131 } 132 } 133 m[k] = v.Val 134 } 135 return m 136 } 137 138 // SetSessionVariable implements the Session interface. 139 func (s *BaseSession) SetSessionVariable(ctx *Context, sysVarName string, value interface{}) error { 140 sysVarName = strings.ToLower(sysVarName) 141 sysVar, ok := s.systemVars[sysVarName] 142 143 // Since we initialized the system variables in this session at session start time, any variables that were added since that time 144 // will need to be added dynamically here. 145 // TODO: fix this with proper session lifecycle management 146 if !ok { 147 if SystemVariables != nil { 148 sv, _, ok := SystemVariables.GetGlobal(sysVarName) 149 if !ok { 150 return ErrUnknownSystemVariable.New(sysVarName) 151 } 152 return s.setSessVar(ctx, sv, value) 153 } else { 154 return ErrUnknownSystemVariable.New(sysVarName) 155 } 156 } 157 158 if !sysVar.Var.Dynamic || sysVar.Var.ValueFunction != nil { 159 return ErrSystemVariableReadOnly.New(sysVarName) 160 } 161 return s.setSessVar(ctx, sysVar.Var, value) 162 } 163 164 // InitSessionVariable implements the Session interface and is used to initialize variables (Including read-only variables) 165 func (s *BaseSession) InitSessionVariable(ctx *Context, sysVarName string, value interface{}) error { 166 sysVar, _, ok := SystemVariables.GetGlobal(sysVarName) 167 if !ok { 168 return ErrUnknownSystemVariable.New(sysVarName) 169 } 170 171 val, ok := s.systemVars[sysVar.Name] 172 if ok && val.Val != sysVar.Default { 173 return ErrSystemVariableReinitialized.New(sysVarName) 174 } 175 176 return s.setSessVar(ctx, sysVar, value) 177 } 178 179 func (s *BaseSession) setSessVar(ctx *Context, sysVar SystemVariable, value interface{}) error { 180 if sysVar.Scope == SystemVariableScope_Global { 181 return ErrSystemVariableGlobalOnly.New(sysVar.Name) 182 } 183 convertedVal, _, err := sysVar.Type.Convert(value) 184 if err != nil { 185 return err 186 } 187 s.mu.Lock() 188 defer s.mu.Unlock() 189 svv := SystemVarValue{ 190 Var: sysVar, 191 Val: convertedVal, 192 } 193 194 if sysVar.NotifyChanged != nil { 195 err := sysVar.NotifyChanged(SystemVariableScope_Session, svv) 196 if err != nil { 197 return err 198 } 199 } 200 s.systemVars[sysVar.Name] = svv 201 return nil 202 } 203 204 // SetUserVariable implements the Session interface. 205 func (s *BaseSession) SetUserVariable(ctx *Context, varName string, value interface{}, typ Type) error { 206 return s.userVars.SetUserVariable(ctx, varName, value, typ) 207 } 208 209 // GetSessionVariable implements the Session interface. 210 func (s *BaseSession) GetSessionVariable(ctx *Context, sysVarName string) (interface{}, error) { 211 s.mu.Lock() 212 defer s.mu.Unlock() 213 214 sysVarName = strings.ToLower(sysVarName) 215 sysVar, ok := s.systemVars[sysVarName] 216 if !ok { 217 return nil, ErrUnknownSystemVariable.New(sysVarName) 218 } 219 // TODO: this is duplicated from within variables.globalSystemVariables, suggesting the need for an interface 220 if sysType, ok := sysVar.Var.Type.(SetType); ok { 221 if sv, ok := sysVar.Val.(uint64); ok { 222 return sysType.BitsToString(sv) 223 } 224 } 225 return sysVar.Val, nil 226 } 227 228 // GetUserVariable implements the Session interface. 229 func (s *BaseSession) GetUserVariable(ctx *Context, varName string) (Type, interface{}, error) { 230 return s.userVars.GetUserVariable(ctx, varName) 231 } 232 233 // GetCharacterSet returns the character set for this session (defined by the system variable `character_set_connection`). 234 func (s *BaseSession) GetCharacterSet() CharacterSetID { 235 s.mu.RLock() 236 defer s.mu.RUnlock() 237 sysVar, _ := s.systemVars[characterSetConnectionSysVarName] 238 if sysVar.Val == nil { 239 return CharacterSet_Unspecified 240 } 241 charSet, err := ParseCharacterSet(sysVar.Val.(string)) 242 if err != nil { 243 panic(err) // shouldn't happen 244 } 245 return charSet 246 } 247 248 // GetCharacterSetResults returns the result character set for this session (defined by the system variable `character_set_results`). 249 func (s *BaseSession) GetCharacterSetResults() CharacterSetID { 250 s.mu.RLock() 251 defer s.mu.RUnlock() 252 sysVar, _ := s.systemVars[characterSetResultsSysVarName] 253 if sysVar.Val == nil { 254 return CharacterSet_Unspecified 255 } 256 charSet, err := ParseCharacterSet(sysVar.Val.(string)) 257 if err != nil { 258 panic(err) // shouldn't happen 259 } 260 return charSet 261 } 262 263 // GetCollation returns the collation for this session (defined by the system variable `collation_connection`). 264 func (s *BaseSession) GetCollation() CollationID { 265 s.mu.Lock() 266 defer s.mu.Unlock() 267 sysVar, ok := s.systemVars[collationConnectionSysVarName] 268 269 // In tests, the collation may not be set because the sys vars haven't been initialized 270 if !ok { 271 return Collation_Default 272 } 273 if sysVar.Val == nil { 274 return Collation_Unspecified 275 } 276 valStr := sysVar.Val.(string) 277 collation, err := ParseCollation(nil, &valStr, false) 278 if err != nil { 279 panic(err) // shouldn't happen 280 } 281 return collation 282 } 283 284 // ValidateSession provides integrators a chance to do any custom validation of this session before any query is executed in it. 285 func (s *BaseSession) ValidateSession(ctx *Context) error { 286 return nil 287 } 288 289 // GetCurrentDatabase gets the current database for this session 290 func (s *BaseSession) GetCurrentDatabase() string { 291 s.mu.RLock() 292 defer s.mu.RUnlock() 293 return s.currentDB 294 } 295 296 // SetCurrentDatabase sets the current database for this session 297 func (s *BaseSession) SetCurrentDatabase(dbName string) { 298 s.mu.Lock() 299 defer s.mu.Unlock() 300 s.currentDB = dbName 301 logger := s.logger 302 if logger == nil { 303 logger = s.newLogger() 304 } 305 s.logger = logger.WithField(ConnectionDbLogField, dbName) 306 } 307 308 func (s *BaseSession) UseDatabase(ctx *Context, db Database) error { 309 // Nothing to do for default implementation 310 // Integrators should override this method on custom session implementations as necessary 311 return nil 312 } 313 314 // ID implements the Session interface. 315 func (s *BaseSession) ID() uint32 { return s.id } 316 317 // SetConnectionId sets the [id] for this session 318 func (s *BaseSession) SetConnectionId(id uint32) { 319 s.id = id 320 return 321 } 322 323 // Warn stores the warning in the session. 324 func (s *BaseSession) Warn(warn *Warning) { 325 s.mu.Lock() 326 defer s.mu.Unlock() 327 s.warnings = append(s.warnings, warn) 328 } 329 330 // Warnings returns a copy of session warnings (from the most recent - the last one) 331 // The function implements sql.Session interface 332 func (s *BaseSession) Warnings() []*Warning { 333 s.mu.RLock() 334 defer s.mu.RUnlock() 335 336 n := len(s.warnings) 337 warns := make([]*Warning, n) 338 for i := 0; i < n; i++ { 339 warns[i] = s.warnings[n-i-1] 340 } 341 342 return warns 343 } 344 345 // ClearWarnings cleans up session warnings 346 func (s *BaseSession) ClearWarnings() { 347 s.mu.Lock() 348 defer s.mu.Unlock() 349 350 cnt := uint16(len(s.warnings)) 351 if s.warncnt == cnt { 352 if s.warnings != nil { 353 s.warnings = s.warnings[:0] 354 } 355 s.warncnt = 0 356 } else { 357 s.warncnt = cnt 358 } 359 } 360 361 // WarningCount returns a number of session warnings 362 func (s *BaseSession) WarningCount() uint16 { 363 s.mu.RLock() 364 defer s.mu.RUnlock() 365 return uint16(len(s.warnings)) 366 } 367 368 // AddLock adds a lock to the set of locks owned by this user which will need to be released if this session terminates 369 func (s *BaseSession) AddLock(lockName string) error { 370 s.mu.Lock() 371 defer s.mu.Unlock() 372 373 s.locks[lockName] = true 374 return nil 375 } 376 377 // DelLock removes a lock from the set of locks owned by this user 378 func (s *BaseSession) DelLock(lockName string) error { 379 s.mu.Lock() 380 defer s.mu.Unlock() 381 382 delete(s.locks, lockName) 383 return nil 384 } 385 386 // IterLocks iterates through all locks owned by this user 387 func (s *BaseSession) IterLocks(cb func(name string) error) error { 388 s.mu.RLock() 389 defer s.mu.RUnlock() 390 391 for name := range s.locks { 392 err := cb(name) 393 394 if err != nil { 395 return err 396 } 397 } 398 399 return nil 400 } 401 402 // GetQueriedDatabase implements the Session interface. 403 func (s *BaseSession) GetQueriedDatabase() string { 404 s.mu.RLock() 405 defer s.mu.RUnlock() 406 return s.queriedDb 407 } 408 409 // SetQueriedDatabase implements the Session interface. 410 func (s *BaseSession) SetQueriedDatabase(dbName string) { 411 s.mu.Lock() 412 defer s.mu.Unlock() 413 s.queriedDb = dbName 414 } 415 416 func (s *BaseSession) GetIndexRegistry() *IndexRegistry { 417 s.mu.Lock() 418 defer s.mu.Unlock() 419 return s.idxReg 420 } 421 422 func (s *BaseSession) GetViewRegistry() *ViewRegistry { 423 s.mu.Lock() 424 defer s.mu.Unlock() 425 return s.viewReg 426 } 427 428 func (s *BaseSession) SetIndexRegistry(reg *IndexRegistry) { 429 s.mu.Lock() 430 defer s.mu.Unlock() 431 s.idxReg = reg 432 } 433 434 func (s *BaseSession) SetViewRegistry(reg *ViewRegistry) { 435 s.mu.Lock() 436 defer s.mu.Unlock() 437 s.viewReg = reg 438 } 439 440 func (s *BaseSession) SetLastQueryInfo(key string, value int64) { 441 s.mu.Lock() 442 defer s.mu.Unlock() 443 s.lastQueryInfo[key] = value 444 } 445 446 func (s *BaseSession) GetLastQueryInfo(key string) int64 { 447 s.mu.RLock() 448 defer s.mu.RUnlock() 449 return s.lastQueryInfo[key] 450 } 451 452 func (s *BaseSession) GetTransaction() Transaction { 453 s.mu.RLock() 454 defer s.mu.RUnlock() 455 return s.tx 456 } 457 458 func (s *BaseSession) SetTransaction(tx Transaction) { 459 s.mu.Lock() 460 defer s.mu.Unlock() 461 s.tx = tx 462 } 463 464 func (s *BaseSession) GetPrivilegeSet() (PrivilegeSet, uint64) { 465 return s.privilegeSet, s.privSetCounter 466 } 467 468 func (s *BaseSession) SetPrivilegeSet(newPs PrivilegeSet, counter uint64) { 469 s.privSetCounter = counter 470 s.privilegeSet = newPs 471 } 472 473 // NewBaseSessionWithClientServer creates a new session with data. 474 func NewBaseSessionWithClientServer(server string, client Client, id uint32) *BaseSession { 475 // TODO: if system variable "activate_all_roles_on_login" if set, activate all roles 476 var sessionVars map[string]SystemVarValue 477 if SystemVariables != nil { 478 sessionVars = SystemVariables.NewSessionMap() 479 } else { 480 sessionVars = make(map[string]SystemVarValue) 481 } 482 return &BaseSession{ 483 addr: server, 484 client: client, 485 id: id, 486 systemVars: sessionVars, 487 userVars: NewUserVars(), 488 idxReg: NewIndexRegistry(), 489 viewReg: NewViewRegistry(), 490 mu: sync.RWMutex{}, 491 locks: make(map[string]bool), 492 lastQueryInfo: defaultLastQueryInfo(), 493 privSetCounter: 0, 494 } 495 } 496 497 // NewBaseSession creates a new empty session. 498 func NewBaseSession() *BaseSession { 499 // TODO: if system variable "activate_all_roles_on_login" if set, activate all roles 500 var sessionVars map[string]SystemVarValue 501 if SystemVariables != nil { 502 sessionVars = SystemVariables.NewSessionMap() 503 } else { 504 sessionVars = make(map[string]SystemVarValue) 505 } 506 return &BaseSession{ 507 id: atomic.AddUint32(&autoSessionIDs, 1), 508 systemVars: sessionVars, 509 userVars: NewUserVars(), 510 idxReg: NewIndexRegistry(), 511 viewReg: NewViewRegistry(), 512 mu: sync.RWMutex{}, 513 locks: make(map[string]bool), 514 lastQueryInfo: defaultLastQueryInfo(), 515 privSetCounter: 0, 516 } 517 }