github.com/dolthub/go-mysql-server@v0.18.0/sql/planbuilder/builder.go (about) 1 // Copyright 2023 Dolthub, Inc. 2 // 3 // Licensed under the Apache License, Version 2.0 (the "License"); 4 // you may not use this file except in compliance with the License. 5 // You may obtain a copy of the License at 6 // 7 // http://www.apache.org/licenses/LICENSE-2.0 8 // 9 // Unless required by applicable law or agreed to in writing, software 10 // distributed under the License is distributed on an "AS IS" BASIS, 11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 // See the License for the specific language governing permissions and 13 // limitations under the License. 14 15 package planbuilder 16 17 import ( 18 "strings" 19 "sync" 20 21 querypb "github.com/dolthub/vitess/go/vt/proto/query" 22 ast "github.com/dolthub/vitess/go/vt/sqlparser" 23 24 "github.com/dolthub/go-mysql-server/sql" 25 "github.com/dolthub/go-mysql-server/sql/binlogreplication" 26 "github.com/dolthub/go-mysql-server/sql/expression" 27 "github.com/dolthub/go-mysql-server/sql/plan" 28 "github.com/dolthub/go-mysql-server/sql/transform" 29 ) 30 31 var BinderFactory = &sync.Pool{New: func() interface{} { 32 return &Builder{f: &factory{}} 33 }} 34 35 type Builder struct { 36 ctx *sql.Context 37 cat sql.Catalog 38 parserOpts ast.ParserOptions 39 f *factory 40 currentDatabase sql.Database 41 colId columnId 42 tabId sql.TableId 43 multiDDL bool 44 viewCtx *ViewContext 45 procCtx *ProcContext 46 triggerCtx *TriggerContext 47 bindCtx *BindvarContext 48 insertActive bool 49 nesting int 50 } 51 52 // BindvarContext holds bind variable replacement literals. 53 type BindvarContext struct { 54 Bindings map[string]*querypb.BindVariable 55 used map[string]struct{} 56 // resolveOnly indicates that we are resolving plan names, 57 // but will not error for missing bindvar replacements. 58 resolveOnly bool 59 } 60 61 func (bv *BindvarContext) GetSubstitute(s string) (*querypb.BindVariable, bool) { 62 if bv.Bindings != nil { 63 ret, ok := bv.Bindings[s] 64 bv.used[s] = struct{}{} 65 return ret, ok 66 } 67 return nil, false 68 } 69 70 func (bv *BindvarContext) UnusedBindings() []string { 71 if len(bv.used) == len(bv.Bindings) { 72 return nil 73 } 74 var unused []string 75 for k, _ := range bv.Bindings { 76 if _, ok := bv.used[k]; !ok { 77 unused = append(unused, k) 78 } 79 } 80 return unused 81 } 82 83 // ViewContext overwrites database root source of nested 84 // calls. 85 type ViewContext struct { 86 AsOf interface{} 87 DbName string 88 } 89 90 type TriggerContext struct { 91 Active bool 92 Call bool 93 UnresolvedTables []string 94 ResolveErr error 95 } 96 97 // ProcContext allows nested CALLs to use the same database for resolving 98 // procedure definitions without changing the underlying database roots. 99 type ProcContext struct { 100 AsOf interface{} 101 DbName string 102 } 103 104 func New(ctx *sql.Context, cat sql.Catalog) *Builder { 105 sqlMode := sql.LoadSqlMode(ctx) 106 return &Builder{ctx: ctx, cat: cat, parserOpts: sqlMode.ParserOptions(), f: &factory{}} 107 } 108 109 func (b *Builder) Initialize(ctx *sql.Context, cat sql.Catalog, opts ast.ParserOptions) { 110 b.ctx = ctx 111 b.cat = cat 112 b.f.ctx = ctx 113 b.parserOpts = opts 114 } 115 116 func (b *Builder) SetDebug(val bool) { 117 b.f.debug = val 118 } 119 120 func (b *Builder) SetBindings(bindings map[string]*querypb.BindVariable) { 121 b.bindCtx = &BindvarContext{ 122 Bindings: bindings, 123 used: make(map[string]struct{}), 124 } 125 } 126 127 func (b *Builder) SetParserOptions(opts ast.ParserOptions) { 128 b.parserOpts = opts 129 } 130 131 func (b *Builder) BindCtx() *BindvarContext { 132 return b.bindCtx 133 } 134 135 func (b *Builder) ViewCtx() *ViewContext { 136 if b.viewCtx == nil { 137 b.viewCtx = &ViewContext{} 138 } 139 return b.viewCtx 140 } 141 142 func (b *Builder) ProcCtx() *ProcContext { 143 if b.procCtx == nil { 144 b.procCtx = &ProcContext{} 145 } 146 return b.procCtx 147 } 148 149 func (b *Builder) TriggerCtx() *TriggerContext { 150 if b.triggerCtx == nil { 151 b.triggerCtx = &TriggerContext{} 152 } 153 return b.triggerCtx 154 } 155 156 func (b *Builder) newScope() *scope { 157 return &scope{b: b} 158 } 159 160 func (b *Builder) Reset() { 161 b.colId = 0 162 b.tabId = 0 163 b.bindCtx = nil 164 b.currentDatabase = nil 165 b.procCtx = nil 166 b.multiDDL = false 167 b.insertActive = false 168 b.triggerCtx = nil 169 b.viewCtx = nil 170 b.nesting = 0 171 } 172 173 type parseErr struct { 174 err error 175 } 176 177 func (b *Builder) handleErr(err error) { 178 panic(parseErr{err}) 179 } 180 181 func (b *Builder) build(inScope *scope, stmt ast.Statement, query string) (outScope *scope) { 182 if inScope == nil { 183 inScope = b.newScope() 184 } 185 switch n := stmt.(type) { 186 default: 187 b.handleErr(sql.ErrUnsupportedSyntax.New(ast.String(n))) 188 case ast.SelectStatement: 189 outScope = b.buildSelectStmt(inScope, n) 190 if into := n.GetInto(); into != nil { 191 b.buildInto(outScope, into) 192 } 193 return outScope 194 case *ast.Analyze: 195 return b.buildAnalyze(inScope, n, query) 196 case *ast.CreateSpatialRefSys: 197 return b.buildCreateSpatialRefSys(inScope, n) 198 case *ast.Show: 199 // When a query is empty it means it comes from a subquery, as we don't 200 // have the query itself in a subquery. Hence, a SHOW could not be 201 // parsed. 202 if query == "" { 203 b.handleErr(sql.ErrUnsupportedFeature.New("SHOW in subquery")) 204 } 205 return b.buildShow(inScope, n) 206 case *ast.DDL: 207 return b.buildDDL(inScope, query, n) 208 case *ast.AlterTable: 209 return b.buildAlterTable(inScope, query, n) 210 case *ast.DBDDL: 211 return b.buildDBDDL(inScope, n) 212 case *ast.Explain: 213 return b.buildExplain(inScope, n) 214 case *ast.Insert: 215 if n.With != nil { 216 cteScope := b.buildWith(inScope, n.With) 217 return b.buildInsert(cteScope, n) 218 } 219 return b.buildInsert(inScope, n) 220 case *ast.Delete: 221 if n.With != nil { 222 cteScope := b.buildWith(inScope, n.With) 223 return b.buildDelete(cteScope, n) 224 } 225 return b.buildDelete(inScope, n) 226 case *ast.Update: 227 if n.With != nil { 228 cteScope := b.buildWith(inScope, n.With) 229 return b.buildUpdate(cteScope, n) 230 } 231 return b.buildUpdate(inScope, n) 232 case *ast.Load: 233 return b.buildLoad(inScope, n) 234 case *ast.Set: 235 return b.buildSet(inScope, n) 236 case *ast.Use: 237 return b.buildUse(inScope, n) 238 case *ast.Begin: 239 outScope = inScope.push() 240 transChar := sql.ReadWrite 241 if n.TransactionCharacteristic == ast.TxReadOnly { 242 transChar = sql.ReadOnly 243 } 244 245 outScope.node = plan.NewStartTransaction(transChar) 246 case *ast.Commit: 247 outScope = inScope.push() 248 outScope.node = plan.NewCommit() 249 case *ast.Rollback: 250 outScope = inScope.push() 251 outScope.node = plan.NewRollback() 252 case *ast.Savepoint: 253 outScope = inScope.push() 254 outScope.node = plan.NewCreateSavepoint(n.Identifier) 255 case *ast.RollbackSavepoint: 256 outScope = inScope.push() 257 outScope.node = plan.NewRollbackSavepoint(n.Identifier) 258 case *ast.ReleaseSavepoint: 259 outScope = inScope.push() 260 outScope.node = plan.NewReleaseSavepoint(n.Identifier) 261 case *ast.ChangeReplicationSource: 262 return b.buildChangeReplicationSource(inScope, n) 263 case *ast.ChangeReplicationFilter: 264 return b.buildChangeReplicationFilter(inScope, n) 265 case *ast.StartReplica: 266 outScope = inScope.push() 267 startRep := plan.NewStartReplica() 268 if binCat, ok := b.cat.(binlogreplication.BinlogReplicaCatalog); ok && binCat.IsBinlogReplicaCatalog() { 269 startRep.ReplicaController = binCat.GetBinlogReplicaController() 270 } 271 outScope.node = startRep 272 case *ast.StopReplica: 273 outScope = inScope.push() 274 stopRep := plan.NewStopReplica() 275 if binCat, ok := b.cat.(binlogreplication.BinlogReplicaCatalog); ok && binCat.IsBinlogReplicaCatalog() { 276 stopRep.ReplicaController = binCat.GetBinlogReplicaController() 277 } 278 outScope.node = stopRep 279 case *ast.ResetReplica: 280 outScope = inScope.push() 281 resetRep := plan.NewResetReplica(n.All) 282 if binCat, ok := b.cat.(binlogreplication.BinlogReplicaCatalog); ok && binCat.IsBinlogReplicaCatalog() { 283 resetRep.ReplicaController = binCat.GetBinlogReplicaController() 284 } 285 outScope.node = resetRep 286 case *ast.BeginEndBlock: 287 return b.buildBeginEndBlock(inScope, n) 288 case *ast.IfStatement: 289 return b.buildIfBlock(inScope, n) 290 case *ast.CaseStatement: 291 return b.buildCaseStatement(inScope, n) 292 case *ast.Call: 293 return b.buildCall(inScope, n) 294 case *ast.Declare: 295 return b.buildDeclare(inScope, n, query) 296 case *ast.FetchCursor: 297 return b.buildFetchCursor(inScope, n) 298 case *ast.OpenCursor: 299 return b.buildOpenCursor(inScope, n) 300 case *ast.CloseCursor: 301 return b.buildCloseCursor(inScope, n) 302 case *ast.Loop: 303 return b.buildLoop(inScope, n) 304 case *ast.Repeat: 305 return b.buildRepeat(inScope, n) 306 case *ast.While: 307 return b.buildWhile(inScope, n) 308 case *ast.Leave: 309 return b.buildLeave(inScope, n) 310 case *ast.Iterate: 311 return b.buildIterate(inScope, n) 312 case *ast.Kill: 313 return b.buildKill(inScope, n) 314 case *ast.Signal: 315 return b.buildSignal(inScope, n) 316 case *ast.LockTables: 317 return b.buildLockTables(inScope, n) 318 case *ast.UnlockTables: 319 return b.buildUnlockTables(inScope, n) 320 case *ast.CreateUser: 321 return b.buildCreateUser(inScope, n) 322 case *ast.RenameUser: 323 return b.buildRenameUser(inScope, n) 324 case *ast.DropUser: 325 return b.buildDropUser(inScope, n) 326 case *ast.CreateRole: 327 return b.buildCreateRole(inScope, n) 328 case *ast.DropRole: 329 return b.buildDropRole(inScope, n) 330 case *ast.GrantPrivilege: 331 return b.buildGrantPrivilege(inScope, n) 332 case *ast.GrantRole: 333 return b.buildGrantRole(inScope, n) 334 case *ast.GrantProxy: 335 return b.buildGrantProxy(inScope, n) 336 case *ast.RevokePrivilege: 337 return b.buildRevokePrivilege(inScope, n) 338 case *ast.RevokeAllPrivileges: 339 return b.buildRevokeAllPrivileges(inScope, n) 340 case *ast.RevokeRole: 341 return b.buildRevokeRole(inScope, n) 342 case *ast.RevokeProxy: 343 return b.buildRevokeProxy(inScope, n) 344 case *ast.ShowGrants: 345 return b.buildShowGrants(inScope, n) 346 case *ast.ShowPrivileges: 347 return b.buildShowPrivileges(inScope, n) 348 case *ast.Flush: 349 return b.buildFlush(inScope, n) 350 case *ast.Prepare: 351 return b.buildPrepare(inScope, n) 352 case *ast.Execute: 353 return b.buildExecute(inScope, n) 354 case *ast.Deallocate: 355 return b.buildDeallocate(inScope, n) 356 } 357 return 358 } 359 360 // buildVirtualTableScan returns a ProjectNode for a table that has virtual columns, projecting the values of any 361 // generated columns 362 func (b *Builder) buildVirtualTableScan(db string, tab sql.Table) *plan.VirtualColumnTable { 363 tableScope := b.newScope() 364 schema := tab.Schema() 365 for _, c := range schema { 366 tableScope.newColumn(scopeColumn{ 367 table: strings.ToLower(tab.Name()), 368 db: strings.ToLower(db), 369 col: strings.ToLower(c.Name), 370 originalCol: c.Name, 371 typ: c.Type, 372 nullable: c.Nullable, 373 }) 374 } 375 376 tableId := tableScope.tables[strings.ToLower(tab.Name())] 377 projections := make([]sql.Expression, len(schema)) 378 for i, c := range schema { 379 if !c.Virtual { 380 projections[i] = expression.NewGetFieldWithTable(i+1, int(tableId), c.Type, db, tab.Name(), c.Name, c.Nullable) 381 } else { 382 projections[i] = b.resolveColumnDefaultExpression(tableScope, c, c.Generated) 383 } 384 } 385 386 // Unlike other kinds of nodes, the projection on this table wrapper is invisible to the analyzer, so we need to 387 // get the column indexes correct here, they won't be fixed later like other kinds of expressions. 388 for i, p := range projections { 389 projections[i] = assignColumnIndexes(p, schema) 390 } 391 392 return plan.NewVirtualColumnTable(tab, projections) 393 } 394 395 // assignColumnIndexes fixes the column indexes in the expression to match the schema given 396 func assignColumnIndexes(e sql.Expression, schema sql.Schema) sql.Expression { 397 e, _, _ = transform.Expr(e, func(e sql.Expression) (sql.Expression, transform.TreeIdentity, error) { 398 if gf, ok := e.(*expression.GetField); ok { 399 idx := schema.IndexOfColName(gf.Name()) 400 return gf.WithIndex(idx), transform.NewTree, nil 401 } 402 return e, transform.SameTree, nil 403 }) 404 return e 405 }