github.com/dolthub/go-mysql-server@v0.18.0/sql/session.go (about) 1 // Copyright 2020-2021 Dolthub, Inc. 2 // 3 // Licensed under the Apache License, Version 2.0 (the "License"); 4 // you may not use this file except in compliance with the License. 5 // You may obtain a copy of the License at 6 // 7 // http://www.apache.org/licenses/LICENSE-2.0 8 // 9 // Unless required by applicable law or agreed to in writing, software 10 // distributed under the License is distributed on an "AS IS" BASIS, 11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 // See the License for the specific language governing permissions and 13 // limitations under the License. 14 15 package sql 16 17 import ( 18 "context" 19 "fmt" 20 "io" 21 "os" 22 "sync" 23 "time" 24 25 "github.com/sirupsen/logrus" 26 "go.opentelemetry.io/otel/attribute" 27 "go.opentelemetry.io/otel/trace" 28 "golang.org/x/sync/errgroup" 29 ) 30 31 type key uint 32 33 const ( 34 // QueryKey to access query in the context. 35 QueryKey key = iota 36 ) 37 38 const ( 39 CurrentDBSessionVar = "current_database" 40 AutoCommitSessionVar = "autocommit" 41 characterSetConnectionSysVarName = "character_set_connection" 42 characterSetResultsSysVarName = "character_set_results" 43 collationConnectionSysVarName = "collation_connection" 44 ) 45 46 var NoopTracer = trace.NewNoopTracerProvider().Tracer("github.com/dolthub/go-mysql-server/sql") 47 var _, noopSpan = NoopTracer.Start(context.Background(), "noop") 48 49 // Client holds session user information. 50 type Client struct { 51 // User of the session. 52 User string 53 // Address of the client. 54 Address string 55 // Capabilities of the client 56 Capabilities uint32 57 } 58 59 // Session holds the session data. 60 type Session interface { 61 // Address of the server. 62 Address() string 63 // Client returns the user of the session. 64 Client() Client 65 // SetClient returns a new session with the given client. 66 SetClient(Client) 67 // SetSessionVariable sets the given system variable to the value given for this session. 68 SetSessionVariable(ctx *Context, sysVarName string, value interface{}) error 69 // InitSessionVariable sets the given system variable to the value given for this session and will allow for 70 // initialization of readonly variables. 71 InitSessionVariable(ctx *Context, sysVarName string, value interface{}) error 72 // SetUserVariable sets the given user variable to the value given for this session, or creates it for this session. 73 SetUserVariable(ctx *Context, varName string, value interface{}, typ Type) error 74 // GetSessionVariable returns this session's value of the system variable with the given name. 75 GetSessionVariable(ctx *Context, sysVarName string) (interface{}, error) 76 // GetUserVariable returns this session's value of the user variable with the given name, along with its most 77 // appropriate type. 78 GetUserVariable(ctx *Context, varName string) (Type, interface{}, error) 79 // GetAllSessionVariables returns a copy of all session variable values. 80 GetAllSessionVariables() map[string]interface{} 81 // GetCurrentDatabase gets the current database for this session 82 GetCurrentDatabase() string 83 // SetCurrentDatabase sets the current database for this session 84 SetCurrentDatabase(dbName string) 85 // UseDatabase notifies sessions that a particular database is now the default DB namespace 86 UseDatabase(ctx *Context, db Database) error 87 // ID returns the unique ID of the connection. 88 ID() uint32 89 // Warn stores the warning in the session. 90 Warn(warn *Warning) 91 // Warnings returns a copy of session warnings (from the most recent). 92 Warnings() []*Warning 93 // ClearWarnings cleans up session warnings. 94 ClearWarnings() 95 // WarningCount returns a number of session warnings 96 WarningCount() uint16 97 // AddLock adds a lock to the set of locks owned by this user which will need to be released if this session terminates 98 AddLock(lockName string) error 99 // DelLock removes a lock from the set of locks owned by this user 100 DelLock(lockName string) error 101 // IterLocks iterates through all locks owned by this user 102 IterLocks(cb func(name string) error) error 103 // SetLastQueryInfo sets session-level query info for the key given, applying to the query just executed. 104 SetLastQueryInfo(key string, value int64) 105 // GetLastQueryInfo returns the session-level query info for the key given, for the query most recently executed. 106 GetLastQueryInfo(key string) int64 107 // GetTransaction returns the active transaction, if any 108 GetTransaction() Transaction 109 // SetTransaction sets the session's transaction 110 SetTransaction(tx Transaction) 111 // SetIgnoreAutoCommit instructs the session to ignore the value of the @@autocommit variable, or consider it again 112 SetIgnoreAutoCommit(ignore bool) 113 // GetIgnoreAutoCommit returns whether this session should ignore the @@autocommit variable 114 GetIgnoreAutoCommit() bool 115 // GetLogger returns the logger for this session, useful if clients want to log messages with the same format / output 116 // as the running server. Clients should instantiate their own global logger with formatting options, and session 117 // implementations should return the logger to be used for the running server. 118 GetLogger() *logrus.Entry 119 // SetLogger sets the logger to use for this session, which will always be an extension of the one returned by 120 // GetLogger, extended with session information 121 SetLogger(*logrus.Entry) 122 // GetIndexRegistry returns the index registry for this session 123 GetIndexRegistry() *IndexRegistry 124 // GetViewRegistry returns the view registry for this session 125 GetViewRegistry() *ViewRegistry 126 // SetIndexRegistry sets the index registry for this session. Integrators should set an index registry in the event 127 // they are using an index driver. 128 SetIndexRegistry(*IndexRegistry) 129 // SetViewRegistry sets the view registry for this session. Integrators should set a view registry if their database 130 // doesn't implement ViewDatabase and they want views created to persist across sessions. 131 SetViewRegistry(*ViewRegistry) 132 // SetConnectionId sets this sessions unique ID 133 SetConnectionId(connId uint32) 134 // GetCharacterSet returns the character set for this session (defined by the system variable `character_set_connection`). 135 GetCharacterSet() CharacterSetID 136 // GetCharacterSetResults returns the result character set for this session (defined by the system variable `character_set_results`). 137 GetCharacterSetResults() CharacterSetID 138 // GetCollation returns the collation for this session (defined by the system variable `collation_connection`). 139 GetCollation() CollationID 140 // GetPrivilegeSet returns the cached privilege set associated with this session, along with its counter. The 141 // PrivilegeSet is only valid when the counter is greater than zero. 142 GetPrivilegeSet() (PrivilegeSet, uint64) 143 // SetPrivilegeSet updates this session's cache with the given counter and privilege set. Setting the counter to a 144 // value of zero will force the cache to reload. This is an internal function and is not intended to be used by 145 // integrators. 146 SetPrivilegeSet(newPs PrivilegeSet, counter uint64) 147 // ValidateSession provides integrators a chance to do any custom validation of this session before any query is 148 // executed in it. For example, Dolt uses this hook to validate that the session's working set is valid. 149 ValidateSession(ctx *Context) error 150 } 151 152 // PersistableSession supports serializing/deserializing global system variables/ 153 type PersistableSession interface { 154 Session 155 // PersistGlobal writes to the persisted global system variables file 156 PersistGlobal(sysVarName string, value interface{}) error 157 // RemovePersistedGlobal deletes a variable from the persisted globals file 158 RemovePersistedGlobal(sysVarName string) error 159 // RemoveAllPersistedGlobals clears the contents of the persisted globals file 160 RemoveAllPersistedGlobals() error 161 // GetPersistedValue returns persisted value for a global system variable 162 GetPersistedValue(k string) (interface{}, error) 163 } 164 165 // TransactionSession can BEGIN, ROLLBACK and COMMIT transactions, as well as create SAVEPOINTS and restore to them. 166 // Transactions can span multiple databases, and integrators must do their own error handling to prevent this if they 167 // cannot support multiple databases in a single transaction. Such integrators can use Session.GetTransactionDatabase 168 // to determine the database that was considered in scope when a transaction began. 169 type TransactionSession interface { 170 Session 171 // StartTransaction starts a new transaction and returns it 172 StartTransaction(ctx *Context, tCharacteristic TransactionCharacteristic) (Transaction, error) 173 // CommitTransaction commits the transaction given 174 CommitTransaction(ctx *Context, tx Transaction) error 175 // Rollback restores the database to the state recorded in the transaction given 176 Rollback(ctx *Context, transaction Transaction) error 177 // CreateSavepoint records a savepoint for the transaction given with the name given. If the name is already in use 178 // for this transaction, the new savepoint replaces the old one. 179 CreateSavepoint(ctx *Context, transaction Transaction, name string) error 180 // RollbackToSavepoint restores the database to the state named by the savepoint 181 RollbackToSavepoint(ctx *Context, transaction Transaction, name string) error 182 // ReleaseSavepoint removes the savepoint named from the transaction given 183 ReleaseSavepoint(ctx *Context, transaction Transaction, name string) error 184 } 185 186 type ( 187 // TypedValue is a value along with its type. 188 TypedValue struct { 189 Typ Type 190 Value interface{} 191 } 192 193 // Warning stands for mySQL warning record. 194 Warning struct { 195 Level string 196 Message string 197 Code int 198 } 199 ) 200 201 const ( 202 RowCount = "row_count" 203 FoundRows = "found_rows" 204 LastInsertId = "last_insert_id" 205 ) 206 207 // Session ID 0 used as invalid SessionID 208 var autoSessionIDs uint32 = 1 209 210 // Context of the query execution. 211 type Context struct { 212 context.Context 213 Session 214 Memory *MemoryManager 215 ProcessList ProcessList 216 services Services 217 pid uint64 218 query string 219 queryTime time.Time 220 tracer trace.Tracer 221 rootSpan trace.Span 222 Version AnalyzerVersion 223 } 224 225 // ContextOption is a function to configure the context. 226 type ContextOption func(*Context) 227 228 // WithSession adds the given session to the context. 229 func WithSession(s Session) ContextOption { 230 return func(ctx *Context) { 231 ctx.Session = s 232 } 233 } 234 235 // WithTracer adds the given tracer to the context. 236 func WithTracer(t trace.Tracer) ContextOption { 237 return func(ctx *Context) { 238 ctx.tracer = t 239 } 240 } 241 242 // WithPid adds the given pid to the context. 243 func WithPid(pid uint64) ContextOption { 244 return func(ctx *Context) { 245 ctx.pid = pid 246 } 247 } 248 249 // WithQuery adds the given query to the context. 250 func WithQuery(q string) ContextOption { 251 return func(ctx *Context) { 252 ctx.query = q 253 } 254 } 255 256 // WithMemoryManager adds the given memory manager to the context. 257 func WithMemoryManager(m *MemoryManager) ContextOption { 258 return func(ctx *Context) { 259 ctx.Memory = m 260 } 261 } 262 263 // WithRootSpan sets the root span of the context. 264 func WithRootSpan(s trace.Span) ContextOption { 265 return func(ctx *Context) { 266 ctx.rootSpan = s 267 } 268 } 269 270 func WithProcessList(p ProcessList) ContextOption { 271 return func(ctx *Context) { 272 ctx.ProcessList = p 273 } 274 } 275 276 // WithServices sets the services for the Context 277 func WithServices(services Services) ContextOption { 278 return func(ctx *Context) { 279 ctx.services = services 280 } 281 } 282 283 var ctxNowFunc = time.Now 284 var ctxNowFuncMutex = &sync.Mutex{} 285 286 func RunWithNowFunc(nowFunc func() time.Time, fn func() error) error { 287 oldNowFunc := swapNowFunc(nowFunc) 288 defer func() { 289 swapNowFunc(oldNowFunc) 290 }() 291 292 return fn() 293 } 294 295 func swapNowFunc(newNowFunc func() time.Time) func() time.Time { 296 ctxNowFuncMutex.Lock() 297 defer ctxNowFuncMutex.Unlock() 298 299 oldNowFunc := ctxNowFunc 300 ctxNowFunc = newNowFunc 301 return oldNowFunc 302 } 303 304 func Now() time.Time { 305 ctxNowFuncMutex.Lock() 306 defer ctxNowFuncMutex.Unlock() 307 308 return ctxNowFunc() 309 } 310 311 // NewContext creates a new query context. Options can be passed to configure 312 // the context. If some aspect of the context is not configure, the default 313 // value will be used. 314 // By default, the context will have an empty base session, a noop tracer and 315 // a memory manager using the process reporter. 316 func NewContext( 317 ctx context.Context, 318 opts ...ContextOption, 319 ) *Context { 320 c := &Context{ 321 Context: ctx, 322 Session: nil, 323 queryTime: ctxNowFunc(), 324 tracer: NoopTracer, 325 } 326 for _, opt := range opts { 327 opt(c) 328 } 329 330 if c.Memory == nil { 331 c.Memory = NewMemoryManager(ProcessMemory) 332 } 333 if c.ProcessList == nil { 334 c.ProcessList = EmptyProcessList{} 335 } 336 if c.Session == nil { 337 c.Session = NewBaseSession() 338 } 339 340 return c 341 } 342 343 // ApplyOpts the options given to the context. Mostly for tests, not safe for use after construction of the context. 344 func (c *Context) ApplyOpts(opts ...ContextOption) { 345 for _, opt := range opts { 346 opt(c) 347 } 348 } 349 350 // NewEmptyContext returns a default context with default values. 351 func NewEmptyContext() *Context { return NewContext(context.TODO()) } 352 353 // Pid returns the process id associated with this context. 354 func (c *Context) Pid() uint64 { 355 if c == nil { 356 return 0 357 } 358 return c.pid 359 } 360 361 // Query returns the query string associated with this context. 362 func (c *Context) Query() string { 363 if c == nil { 364 return "" 365 } 366 return c.query 367 } 368 369 func (c *Context) WithQuery(q string) *Context { 370 if c == nil { 371 return nil 372 } 373 374 nc := *c 375 nc.query = q 376 return &nc 377 } 378 379 // QueryTime returns the time.Time when the context associated with this query was created 380 func (c *Context) QueryTime() time.Time { 381 if c == nil { 382 return time.Time{} 383 } 384 return c.queryTime 385 } 386 387 // SetQueryTime updates the queryTime to the given time 388 func (c *Context) SetQueryTime(t time.Time) { 389 if c == nil { 390 return 391 } 392 c.queryTime = t 393 } 394 395 // Span creates a new tracing span with the given context. 396 // It will return the span and a new context that should be passed to all 397 // children of this span. 398 func (c *Context) Span( 399 opName string, 400 opts ...trace.SpanStartOption, 401 ) (trace.Span, *Context) { 402 if c == nil { 403 return noopSpan, nil 404 } 405 406 if c.tracer == nil || c.tracer == NoopTracer { 407 return noopSpan, c 408 } 409 410 ctx, span := c.tracer.Start(c.Context, opName, opts...) 411 return span, c.WithContext(ctx) 412 } 413 414 // NewSubContext creates a new sub-context with the current context as parent. Returns the resulting context.CancelFunc 415 // as well as the new *sql.Context, which be used to cancel the new context before the parent is finished. 416 func (c *Context) NewSubContext() (*Context, context.CancelFunc) { 417 if c == nil { 418 return nil, nil 419 } 420 421 ctx, cancelFunc := context.WithCancel(c.Context) 422 423 return c.WithContext(ctx), cancelFunc 424 } 425 426 // WithContext returns a new context with the given underlying context. 427 func (c *Context) WithContext(ctx context.Context) *Context { 428 if c == nil { 429 return nil 430 } 431 432 nc := *c 433 nc.Context = ctx 434 return &nc 435 } 436 437 // RootSpan returns the root span, if any. 438 func (c *Context) RootSpan() trace.Span { 439 if c == nil { 440 return noopSpan 441 } 442 return c.rootSpan 443 } 444 445 // Error adds an error as warning to the session. 446 func (c *Context) Error(code int, msg string, args ...interface{}) { 447 if c == nil || c.Session == nil { 448 return 449 } 450 451 c.Session.Warn(&Warning{ 452 Level: "Error", 453 Code: code, 454 Message: fmt.Sprintf(msg, args...), 455 }) 456 } 457 458 // Warn adds a warning to the session. 459 func (c *Context) Warn(code int, msg string, args ...interface{}) { 460 if c == nil || c.Session == nil { 461 return 462 } 463 c.Session.Warn(&Warning{ 464 Level: "Warning", 465 Code: code, 466 Message: fmt.Sprintf(msg, args...), 467 }) 468 } 469 470 // KillConnection terminates the connection associated with |connID|. 471 func (c *Context) KillConnection(connID uint32) error { 472 if c == nil || c.services.KillConnection == nil { 473 return nil 474 } 475 476 if c.services.KillConnection != nil { 477 return c.services.KillConnection(connID) 478 } 479 return nil 480 } 481 482 // LoadInfile loads the remote file |filename| from the client. Returns a |ReadCloser| for 483 // the file's contents. Returns an error if this functionality is not supported. 484 func (c *Context) LoadInfile(filename string) (io.ReadCloser, error) { 485 if c == nil || c.services.LoadInfile == nil { 486 return nil, ErrUnsupportedFeature.New("LOAD DATA LOCAL INFILE ...") 487 } 488 489 if c.services.LoadInfile != nil { 490 return c.services.LoadInfile(filename) 491 } 492 return nil, ErrUnsupportedFeature.New("LOAD DATA LOCAL INFILE ...") 493 } 494 495 func (c *Context) NewErrgroup() (*errgroup.Group, *Context) { 496 if c == nil { 497 return nil, nil 498 } 499 500 eg, egCtx := errgroup.WithContext(c.Context) 501 return eg, c.WithContext(egCtx) 502 } 503 504 // NewCtxWithClient returns a new Context with the given [client] 505 func (c *Context) NewCtxWithClient(client Client) *Context { 506 if c == nil { 507 return nil 508 } 509 510 nc := *c 511 nc.Session.SetClient(client) 512 nc.Session.SetPrivilegeSet(nil, 0) 513 return &nc 514 } 515 516 // Services are handles to optional or plugin functionality that can be 517 // used by the SQL implementation in certain situations. An integrator can set 518 // methods on Services for a given *Context and different parts of go-mysql-server 519 // will inspect it in order to fulfill their implementations. Currently, the 520 // KillConnection service is available. Set these with |WithServices|; the 521 // implementation will access them through the corresponding methods on 522 // *Context, such as |KillConnection|. 523 type Services struct { 524 KillConnection func(connID uint32) error 525 LoadInfile func(filename string) (io.ReadCloser, error) 526 } 527 528 // NewSpanIter creates a RowIter executed in the given span. 529 // Currently inactive, returns the iter returned unaltered. 530 func NewSpanIter(span trace.Span, iter RowIter) RowIter { 531 // In the default, non traced case, we should not bother with 532 // collecting the timings below. 533 if !span.IsRecording() { 534 return iter 535 } else { 536 return &spanIter{ 537 span: span, 538 iter: iter, 539 } 540 } 541 } 542 543 type spanIter struct { 544 span trace.Span 545 iter RowIter 546 count int 547 max time.Duration 548 min time.Duration 549 total time.Duration 550 done bool 551 } 552 553 var _ RowIter = (*spanIter)(nil) 554 555 func (i *spanIter) updateTimings(start time.Time) { 556 elapsed := time.Since(start) 557 if i.max < elapsed { 558 i.max = elapsed 559 } 560 561 if i.min > elapsed || i.min == 0 { 562 i.min = elapsed 563 } 564 565 i.total += elapsed 566 } 567 568 func (i *spanIter) Next(ctx *Context) (Row, error) { 569 start := time.Now() 570 571 row, err := i.iter.Next(ctx) 572 if err == io.EOF { 573 i.finish() 574 return nil, err 575 } 576 577 if err != nil { 578 i.finishWithError(err) 579 return nil, err 580 } 581 582 i.count++ 583 i.updateTimings(start) 584 return row, nil 585 } 586 587 func (i *spanIter) finish() { 588 var avg time.Duration 589 if i.count > 0 { 590 avg = i.total / time.Duration(i.count) 591 } 592 593 i.span.AddEvent("finish", trace.WithAttributes( 594 attribute.Int("rows", i.count), 595 attribute.Stringer("total_time", i.total), 596 attribute.Stringer("max_time", i.max), 597 attribute.Stringer("min_time", i.min), 598 attribute.Stringer("avg_time", avg), 599 )) 600 i.span.End() 601 i.done = true 602 } 603 604 func (i *spanIter) finishWithError(err error) { 605 var avg time.Duration 606 if i.count > 0 { 607 avg = i.total / time.Duration(i.count) 608 } 609 610 i.span.RecordError(err) 611 i.span.AddEvent("finish", trace.WithAttributes( 612 attribute.Int("rows", i.count), 613 attribute.Stringer("total_time", i.total), 614 attribute.Stringer("max_time", i.max), 615 attribute.Stringer("min_time", i.min), 616 attribute.Stringer("avg_time", avg), 617 )) 618 i.span.End() 619 i.done = true 620 } 621 622 func (i *spanIter) Close(ctx *Context) error { 623 if !i.done { 624 i.finish() 625 } 626 return i.iter.Close(ctx) 627 } 628 629 func defaultLastQueryInfo() map[string]int64 { 630 return map[string]int64{ 631 RowCount: 0, 632 FoundRows: 1, // this is kind of a hack -- it handles the case of `select found_rows()` before any select statement is issued 633 LastInsertId: 0, 634 } 635 } 636 637 // cc: https://dev.mysql.com/doc/refman/8.0/en/temporary-files.html 638 func GetTmpdirSessionVar() string { 639 ret := os.Getenv("TMPDIR") 640 if ret != "" { 641 return ret 642 } 643 644 ret = os.Getenv("TEMP") 645 if ret != "" { 646 return ret 647 } 648 649 ret = os.Getenv("TMP") 650 if ret != "" { 651 return ret 652 } 653 654 return "" 655 } 656 657 // HasDefaultValue checks if session variable value is the default one. 658 func HasDefaultValue(ctx *Context, s Session, key string) (bool, interface{}) { 659 val, err := s.GetSessionVariable(ctx, key) 660 if err == nil { 661 sysVar, _, ok := SystemVariables.GetGlobal(key) 662 if ok { 663 return sysVar.Default == val, val 664 } 665 } 666 return true, nil 667 } 668 669 type AnalyzerVersion uint8 670 671 const ( 672 VersionUnknown AnalyzerVersion = iota 673 VersionStable 674 VersionExperimental 675 )