github.com/dolthub/go-mysql-server@v0.18.0/sql/planbuilder/builder.go (about)

     1  // Copyright 2023 Dolthub, Inc.
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //     http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License.
    14  
    15  package planbuilder
    16  
    17  import (
    18  	"strings"
    19  	"sync"
    20  
    21  	querypb "github.com/dolthub/vitess/go/vt/proto/query"
    22  	ast "github.com/dolthub/vitess/go/vt/sqlparser"
    23  
    24  	"github.com/dolthub/go-mysql-server/sql"
    25  	"github.com/dolthub/go-mysql-server/sql/binlogreplication"
    26  	"github.com/dolthub/go-mysql-server/sql/expression"
    27  	"github.com/dolthub/go-mysql-server/sql/plan"
    28  	"github.com/dolthub/go-mysql-server/sql/transform"
    29  )
    30  
    31  var BinderFactory = &sync.Pool{New: func() interface{} {
    32  	return &Builder{f: &factory{}}
    33  }}
    34  
// Builder converts parsed vitess AST statements into plan nodes; build is the
// per-statement entry point. A Builder comes either from New or from the
// BinderFactory pool (presumably followed by Initialize, which sets ctx, cat
// and parserOpts). Reset clears the counters and nested contexts below so the
// Builder can be reused.
type Builder struct {
	// ctx is the session context supplied to New or Initialize.
	ctx             *sql.Context
	// cat is the catalog supplied to New or Initialize; build also checks it
	// for binlog replica support.
	cat             sql.Catalog
	// parserOpts is derived from the session SQL mode by New, or set by
	// Initialize / SetParserOptions.
	parserOpts      ast.ParserOptions
	// f is the builder's factory; SetDebug toggles its debug flag.
	f               *factory
	// currentDatabase is cleared by Reset. NOTE(review): presumably the
	// database used for unqualified names — confirm at call sites.
	currentDatabase sql.Database
	// colId and tabId are reset to zero by Reset. NOTE(review): presumably
	// counters used to allocate column / table ids — confirm.
	colId           columnId
	tabId           sql.TableId
	// multiDDL is cleared by Reset; its meaning is not visible in this file.
	multiDDL        bool
	// viewCtx, procCtx and triggerCtx are created lazily by ViewCtx, ProcCtx
	// and TriggerCtx and dropped by Reset.
	viewCtx         *ViewContext
	procCtx         *ProcContext
	triggerCtx      *TriggerContext
	// bindCtx is installed by SetBindings and dropped by Reset.
	bindCtx         *BindvarContext
	// insertActive and nesting are cleared by Reset; their use is not visible
	// in this file.
	insertActive    bool
	nesting         int
}
    51  
    52  // BindvarContext holds bind variable replacement literals.
    53  type BindvarContext struct {
    54  	Bindings map[string]*querypb.BindVariable
    55  	used     map[string]struct{}
    56  	// resolveOnly indicates that we are resolving plan names,
    57  	// but will not error for missing bindvar replacements.
    58  	resolveOnly bool
    59  }
    60  
    61  func (bv *BindvarContext) GetSubstitute(s string) (*querypb.BindVariable, bool) {
    62  	if bv.Bindings != nil {
    63  		ret, ok := bv.Bindings[s]
    64  		bv.used[s] = struct{}{}
    65  		return ret, ok
    66  	}
    67  	return nil, false
    68  }
    69  
    70  func (bv *BindvarContext) UnusedBindings() []string {
    71  	if len(bv.used) == len(bv.Bindings) {
    72  		return nil
    73  	}
    74  	var unused []string
    75  	for k, _ := range bv.Bindings {
    76  		if _, ok := bv.used[k]; !ok {
    77  			unused = append(unused, k)
    78  		}
    79  	}
    80  	return unused
    81  }
    82  
// ViewContext carries a database name and AS OF value that override the
// database root used when resolving nested calls. NOTE(review): presumably
// populated while building a view's definition — confirm at call sites. It is
// created lazily by Builder.ViewCtx and dropped by Builder.Reset.
type ViewContext struct {
	// AsOf is the AS OF value to apply; its concrete type is not fixed here.
	AsOf   interface{}
	// DbName is the database name to resolve against.
	DbName string
}

// TriggerContext holds trigger-related state for the current build. It is
// created lazily by Builder.TriggerCtx and dropped by Builder.Reset.
// NOTE(review): the field meanings below are inferred from their names and
// should be confirmed against the trigger-building code.
type TriggerContext struct {
	// Active: presumably true while a trigger body is being built.
	Active           bool
	// Call: presumably true while building a CALL made from a trigger.
	Call             bool
	// UnresolvedTables: presumably names of tables that could not be resolved.
	UnresolvedTables []string
	// ResolveErr: presumably the deferred error for those unresolved tables.
	ResolveErr       error
}

// ProcContext allows nested CALLs to use the same database for resolving
// procedure definitions without changing the underlying database roots.
// It is created lazily by Builder.ProcCtx and dropped by Builder.Reset.
type ProcContext struct {
	// AsOf is the AS OF value to apply; its concrete type is not fixed here.
	AsOf   interface{}
	// DbName is the database used to resolve procedure definitions.
	DbName string
}
   103  
   104  func New(ctx *sql.Context, cat sql.Catalog) *Builder {
   105  	sqlMode := sql.LoadSqlMode(ctx)
   106  	return &Builder{ctx: ctx, cat: cat, parserOpts: sqlMode.ParserOptions(), f: &factory{}}
   107  }
   108  
   109  func (b *Builder) Initialize(ctx *sql.Context, cat sql.Catalog, opts ast.ParserOptions) {
   110  	b.ctx = ctx
   111  	b.cat = cat
   112  	b.f.ctx = ctx
   113  	b.parserOpts = opts
   114  }
   115  
   116  func (b *Builder) SetDebug(val bool) {
   117  	b.f.debug = val
   118  }
   119  
   120  func (b *Builder) SetBindings(bindings map[string]*querypb.BindVariable) {
   121  	b.bindCtx = &BindvarContext{
   122  		Bindings: bindings,
   123  		used:     make(map[string]struct{}),
   124  	}
   125  }
   126  
   127  func (b *Builder) SetParserOptions(opts ast.ParserOptions) {
   128  	b.parserOpts = opts
   129  }
   130  
   131  func (b *Builder) BindCtx() *BindvarContext {
   132  	return b.bindCtx
   133  }
   134  
   135  func (b *Builder) ViewCtx() *ViewContext {
   136  	if b.viewCtx == nil {
   137  		b.viewCtx = &ViewContext{}
   138  	}
   139  	return b.viewCtx
   140  }
   141  
   142  func (b *Builder) ProcCtx() *ProcContext {
   143  	if b.procCtx == nil {
   144  		b.procCtx = &ProcContext{}
   145  	}
   146  	return b.procCtx
   147  }
   148  
   149  func (b *Builder) TriggerCtx() *TriggerContext {
   150  	if b.triggerCtx == nil {
   151  		b.triggerCtx = &TriggerContext{}
   152  	}
   153  	return b.triggerCtx
   154  }
   155  
   156  func (b *Builder) newScope() *scope {
   157  	return &scope{b: b}
   158  }
   159  
   160  func (b *Builder) Reset() {
   161  	b.colId = 0
   162  	b.tabId = 0
   163  	b.bindCtx = nil
   164  	b.currentDatabase = nil
   165  	b.procCtx = nil
   166  	b.multiDDL = false
   167  	b.insertActive = false
   168  	b.triggerCtx = nil
   169  	b.viewCtx = nil
   170  	b.nesting = 0
   171  }
   172  
   173  type parseErr struct {
   174  	err error
   175  }
   176  
   177  func (b *Builder) handleErr(err error) {
   178  	panic(parseErr{err})
   179  }
   180  
// build converts a single parsed statement into a plan, returned in a scope.
// For the statements handled inline below, the plan node is stored in
// outScope.node; all other statement types are delegated to a dedicated
// build* method.
//
// If inScope is nil, a fresh scope is created. query is the original SQL text
// for stmt. It is forwarded to the builders that need it (ANALYZE, DDL,
// ALTER TABLE, DECLARE) and is empty when stmt comes from a subquery.
// Unsupported statements are reported through handleErr, which panics with a
// parseErr instead of returning.
func (b *Builder) build(inScope *scope, stmt ast.Statement, query string) (outScope *scope) {
	if inScope == nil {
		inScope = b.newScope()
	}
	// NOTE: type-switch cases are tested in source order. ast.SelectStatement is
	// an interface, so it takes precedence over any later concrete case whose
	// type also implements it. default is used only when no case matches,
	// regardless of its position.
	switch n := stmt.(type) {
	default:
		b.handleErr(sql.ErrUnsupportedSyntax.New(ast.String(n)))
	case ast.SelectStatement:
		outScope = b.buildSelectStmt(inScope, n)
		// SELECT ... INTO is layered on top of the already-built select.
		if into := n.GetInto(); into != nil {
			b.buildInto(outScope, into)
		}
		return outScope
	case *ast.Analyze:
		return b.buildAnalyze(inScope, n, query)
	case *ast.CreateSpatialRefSys:
		return b.buildCreateSpatialRefSys(inScope, n)
	case *ast.Show:
		// When a query is empty it means it comes from a subquery, as we don't
		// have the query itself in a subquery. Hence, a SHOW could not be
		// parsed.
		if query == "" {
			b.handleErr(sql.ErrUnsupportedFeature.New("SHOW in subquery"))
		}
		return b.buildShow(inScope, n)
	case *ast.DDL:
		return b.buildDDL(inScope, query, n)
	case *ast.AlterTable:
		return b.buildAlterTable(inScope, query, n)
	case *ast.DBDDL:
		return b.buildDBDDL(inScope, n)
	case *ast.Explain:
		return b.buildExplain(inScope, n)
	// INSERT / DELETE / UPDATE: when a WITH clause is present, its CTEs are
	// built first and the statement is built in the resulting scope.
	case *ast.Insert:
		if n.With != nil {
			cteScope := b.buildWith(inScope, n.With)
			return b.buildInsert(cteScope, n)
		}
		return b.buildInsert(inScope, n)
	case *ast.Delete:
		if n.With != nil {
			cteScope := b.buildWith(inScope, n.With)
			return b.buildDelete(cteScope, n)
		}
		return b.buildDelete(inScope, n)
	case *ast.Update:
		if n.With != nil {
			cteScope := b.buildWith(inScope, n.With)
			return b.buildUpdate(cteScope, n)
		}
		return b.buildUpdate(inScope, n)
	case *ast.Load:
		return b.buildLoad(inScope, n)
	case *ast.Set:
		return b.buildSet(inScope, n)
	case *ast.Use:
		return b.buildUse(inScope, n)
	// Transaction-control statements are built inline: push a child scope and
	// attach the corresponding plan node. They fall through to the final
	// return, which yields the named result outScope.
	case *ast.Begin:
		outScope = inScope.push()
		transChar := sql.ReadWrite
		if n.TransactionCharacteristic == ast.TxReadOnly {
			transChar = sql.ReadOnly
		}

		outScope.node = plan.NewStartTransaction(transChar)
	case *ast.Commit:
		outScope = inScope.push()
		outScope.node = plan.NewCommit()
	case *ast.Rollback:
		outScope = inScope.push()
		outScope.node = plan.NewRollback()
	case *ast.Savepoint:
		outScope = inScope.push()
		outScope.node = plan.NewCreateSavepoint(n.Identifier)
	case *ast.RollbackSavepoint:
		outScope = inScope.push()
		outScope.node = plan.NewRollbackSavepoint(n.Identifier)
	case *ast.ReleaseSavepoint:
		outScope = inScope.push()
		outScope.node = plan.NewReleaseSavepoint(n.Identifier)
	case *ast.ChangeReplicationSource:
		return b.buildChangeReplicationSource(inScope, n)
	case *ast.ChangeReplicationFilter:
		return b.buildChangeReplicationFilter(inScope, n)
	// START / STOP / RESET REPLICA: the catalog's replica controller is
	// attached only when the catalog is a BinlogReplicaCatalog that reports
	// IsBinlogReplicaCatalog(). Otherwise ReplicaController keeps whatever value
	// the plan constructor gave it.
	case *ast.StartReplica:
		outScope = inScope.push()
		startRep := plan.NewStartReplica()
		if binCat, ok := b.cat.(binlogreplication.BinlogReplicaCatalog); ok && binCat.IsBinlogReplicaCatalog() {
			startRep.ReplicaController = binCat.GetBinlogReplicaController()
		}
		outScope.node = startRep
	case *ast.StopReplica:
		outScope = inScope.push()
		stopRep := plan.NewStopReplica()
		if binCat, ok := b.cat.(binlogreplication.BinlogReplicaCatalog); ok && binCat.IsBinlogReplicaCatalog() {
			stopRep.ReplicaController = binCat.GetBinlogReplicaController()
		}
		outScope.node = stopRep
	case *ast.ResetReplica:
		outScope = inScope.push()
		resetRep := plan.NewResetReplica(n.All)
		if binCat, ok := b.cat.(binlogreplication.BinlogReplicaCatalog); ok && binCat.IsBinlogReplicaCatalog() {
			resetRep.ReplicaController = binCat.GetBinlogReplicaController()
		}
		outScope.node = resetRep
	// Stored-program control flow and cursors.
	case *ast.BeginEndBlock:
		return b.buildBeginEndBlock(inScope, n)
	case *ast.IfStatement:
		return b.buildIfBlock(inScope, n)
	case *ast.CaseStatement:
		return b.buildCaseStatement(inScope, n)
	case *ast.Call:
		return b.buildCall(inScope, n)
	case *ast.Declare:
		return b.buildDeclare(inScope, n, query)
	case *ast.FetchCursor:
		return b.buildFetchCursor(inScope, n)
	case *ast.OpenCursor:
		return b.buildOpenCursor(inScope, n)
	case *ast.CloseCursor:
		return b.buildCloseCursor(inScope, n)
	case *ast.Loop:
		return b.buildLoop(inScope, n)
	case *ast.Repeat:
		return b.buildRepeat(inScope, n)
	case *ast.While:
		return b.buildWhile(inScope, n)
	case *ast.Leave:
		return b.buildLeave(inScope, n)
	case *ast.Iterate:
		return b.buildIterate(inScope, n)
	case *ast.Kill:
		return b.buildKill(inScope, n)
	case *ast.Signal:
		return b.buildSignal(inScope, n)
	case *ast.LockTables:
		return b.buildLockTables(inScope, n)
	case *ast.UnlockTables:
		return b.buildUnlockTables(inScope, n)
	// Account, role and privilege management.
	case *ast.CreateUser:
		return b.buildCreateUser(inScope, n)
	case *ast.RenameUser:
		return b.buildRenameUser(inScope, n)
	case *ast.DropUser:
		return b.buildDropUser(inScope, n)
	case *ast.CreateRole:
		return b.buildCreateRole(inScope, n)
	case *ast.DropRole:
		return b.buildDropRole(inScope, n)
	case *ast.GrantPrivilege:
		return b.buildGrantPrivilege(inScope, n)
	case *ast.GrantRole:
		return b.buildGrantRole(inScope, n)
	case *ast.GrantProxy:
		return b.buildGrantProxy(inScope, n)
	case *ast.RevokePrivilege:
		return b.buildRevokePrivilege(inScope, n)
	case *ast.RevokeAllPrivileges:
		return b.buildRevokeAllPrivileges(inScope, n)
	case *ast.RevokeRole:
		return b.buildRevokeRole(inScope, n)
	case *ast.RevokeProxy:
		return b.buildRevokeProxy(inScope, n)
	case *ast.ShowGrants:
		return b.buildShowGrants(inScope, n)
	case *ast.ShowPrivileges:
		return b.buildShowPrivileges(inScope, n)
	case *ast.Flush:
		return b.buildFlush(inScope, n)
	// Prepared statements.
	case *ast.Prepare:
		return b.buildPrepare(inScope, n)
	case *ast.Execute:
		return b.buildExecute(inScope, n)
	case *ast.Deallocate:
		return b.buildDeallocate(inScope, n)
	}
	return
}
   359  
   360  // buildVirtualTableScan returns a ProjectNode for a table that has virtual columns, projecting the values of any
   361  // generated columns
   362  func (b *Builder) buildVirtualTableScan(db string, tab sql.Table) *plan.VirtualColumnTable {
   363  	tableScope := b.newScope()
   364  	schema := tab.Schema()
   365  	for _, c := range schema {
   366  		tableScope.newColumn(scopeColumn{
   367  			table:       strings.ToLower(tab.Name()),
   368  			db:          strings.ToLower(db),
   369  			col:         strings.ToLower(c.Name),
   370  			originalCol: c.Name,
   371  			typ:         c.Type,
   372  			nullable:    c.Nullable,
   373  		})
   374  	}
   375  
   376  	tableId := tableScope.tables[strings.ToLower(tab.Name())]
   377  	projections := make([]sql.Expression, len(schema))
   378  	for i, c := range schema {
   379  		if !c.Virtual {
   380  			projections[i] = expression.NewGetFieldWithTable(i+1, int(tableId), c.Type, db, tab.Name(), c.Name, c.Nullable)
   381  		} else {
   382  			projections[i] = b.resolveColumnDefaultExpression(tableScope, c, c.Generated)
   383  		}
   384  	}
   385  
   386  	// Unlike other kinds of nodes, the projection on this table wrapper is invisible to the analyzer, so we need to
   387  	// get the column indexes correct here, they won't be fixed later like other kinds of expressions.
   388  	for i, p := range projections {
   389  		projections[i] = assignColumnIndexes(p, schema)
   390  	}
   391  
   392  	return plan.NewVirtualColumnTable(tab, projections)
   393  }
   394  
   395  // assignColumnIndexes fixes the column indexes in the expression to match the schema given
   396  func assignColumnIndexes(e sql.Expression, schema sql.Schema) sql.Expression {
   397  	e, _, _ = transform.Expr(e, func(e sql.Expression) (sql.Expression, transform.TreeIdentity, error) {
   398  		if gf, ok := e.(*expression.GetField); ok {
   399  			idx := schema.IndexOfColName(gf.Name())
   400  			return gf.WithIndex(idx), transform.NewTree, nil
   401  		}
   402  		return e, transform.SameTree, nil
   403  	})
   404  	return e
   405  }