github.com/dolthub/go-mysql-server@v0.18.0/sql/plan/common.go (about)

     1  // Copyright 2020-2021 Dolthub, Inc.
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //     http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License.
    14  
    15  package plan
    16  
    17  import (
    18  	"github.com/dolthub/go-mysql-server/sql"
    19  	"github.com/dolthub/go-mysql-server/sql/mysql_db"
    20  	"github.com/dolthub/go-mysql-server/sql/transform"
    21  )
    22  
    23  // IsUnary returns whether the node is unary or not.
    24  func IsUnary(node sql.Node) bool {
    25  	return len(node.Children()) == 1
    26  }
    27  
    28  // IsBinary returns whether the node is binary or not.
    29  func IsBinary(node sql.Node) bool {
    30  	return len(node.Children()) == 2
    31  }
    32  
    33  // NillaryNode is a node with no children. This is a common WithChildren implementation for all nodes that have none.
    34  func NillaryWithChildren(node sql.Node, children ...sql.Node) (sql.Node, error) {
    35  	if len(children) != 0 {
    36  		return nil, sql.ErrInvalidChildrenNumber.New(node, len(children), 0)
    37  	}
    38  	return node, nil
    39  }
    40  
    41  // UnaryNode is a node that has only one child.
    42  type UnaryNode struct {
    43  	Child sql.Node
    44  }
    45  
    46  // Schema implements the Node interface.
    47  func (n *UnaryNode) Schema() sql.Schema {
    48  	return n.Child.Schema()
    49  }
    50  
    51  // Resolved implements the Resolvable interface.
    52  func (n UnaryNode) Resolved() bool {
    53  	return n.Child.Resolved()
    54  }
    55  
    56  // Children implements the Node interface.
    57  func (n UnaryNode) Children() []sql.Node {
    58  	return []sql.Node{n.Child}
    59  }
    60  
    61  // BinaryNode is a node with two children.
    62  type BinaryNode struct {
    63  	left  sql.Node
    64  	right sql.Node
    65  }
    66  
    67  func (n BinaryNode) Left() sql.Node {
    68  	return n.left
    69  }
    70  
    71  func (n BinaryNode) Right() sql.Node {
    72  	return n.right
    73  }
    74  
    75  // Children implements the Node interface.
    76  func (n BinaryNode) Children() []sql.Node {
    77  	return []sql.Node{n.left, n.right}
    78  }
    79  
    80  // Resolved implements the Resolvable interface.
    81  func (n BinaryNode) Resolved() bool {
    82  	return n.left.Resolved() && n.right.Resolved()
    83  }
    84  
    85  // BlockRowIter is an iterator that produces rows. It is an extended interface over RowIter. This is primarily used
    86  // by block statements. In order to track the schema of a sql.RowIter from nested blocks, this extended row iter returns
    87  // the relevant information inside of the iter itself. In addition, the most specific top-level Node for that iter is
    88  // returned, as stored procedures use that Node to determine whether the iter represents a SELECT statement.
    89  type BlockRowIter interface {
    90  	sql.RowIter
    91  	// RepresentingNode returns the Node that most directly represents this RowIter. For example, in the case of
    92  	// an IF/ELSE block, the RowIter represents the Node where the condition evaluated to true.
    93  	RepresentingNode() sql.Node
    94  	// Schema returns the schema of this RowIter.
    95  	Schema() sql.Schema
    96  }
    97  
    98  // NodeRepresentsSelect attempts to walk a sql.Node to determine if it represents a SELECT statement.
    99  func NodeRepresentsSelect(s sql.Node) bool {
   100  	if s == nil {
   101  		return false
   102  	}
   103  	isSelect := false
   104  	// All SELECT statements, including those that do not specify a table (using "dual"), have a TableNode.
   105  	transform.Inspect(s, func(node sql.Node) bool {
   106  		switch node.(type) {
   107  		case *AlterAutoIncrement, *AlterIndex, *CreateForeignKey, *CreateIndex, *CreateTable, *CreateTrigger,
   108  			*DeleteFrom, *DropForeignKey, *InsertInto, *ShowCreateTable, *ShowIndexes, *Truncate, *Update, *Into:
   109  			return false
   110  		case *ResolvedTable, *ProcedureResolvedTable:
   111  			isSelect = true
   112  			return false
   113  		default:
   114  			return true
   115  		}
   116  	})
   117  	return isSelect
   118  }
   119  
   120  // getTableName attempts to fetch the table name from the node. If not found directly on the node, searches the
   121  // children. Returns the first table name found, regardless of whether there are more, therefore this is only intended
   122  // to be used in situations where only a single table is expected to be found.
   123  func getTableName(nodeToSearch sql.Node) string {
   124  	nodeStack := []sql.Node{nodeToSearch}
   125  	for len(nodeStack) > 0 {
   126  		node := nodeStack[len(nodeStack)-1]
   127  		nodeStack = nodeStack[:len(nodeStack)-1]
   128  		switch n := node.(type) {
   129  		case *TableAlias:
   130  			if n.UnaryNode != nil {
   131  				nodeStack = append(nodeStack, n.UnaryNode.Child)
   132  				continue
   133  			}
   134  		case sql.TableNode:
   135  			return n.UnderlyingTable().Name()
   136  		case *UnresolvedTable:
   137  			return n.name
   138  		case *IndexedTableAccess:
   139  			return n.Name()
   140  		case sql.TableWrapper:
   141  			return n.Underlying().Name()
   142  		}
   143  		nodeStack = append(nodeStack, node.Children()...)
   144  	}
   145  	return ""
   146  }
   147  
   148  // GetDatabaseName attempts to fetch the database name from the node. If not found directly on the node, searches the
   149  // children. Returns the first database name found, regardless of whether there are more, therefore this is only
   150  // intended to be used in situations where only a single database is expected to be found. Unlike how tables are handled
   151  // in most nodes, databases may be stored as a string field therefore there will be situations where a database name
   152  // exists on a node, but cannot be found through inspection.
   153  func GetDatabaseName(nodeToSearch sql.Node) string {
   154  	nodeStack := []sql.Node{nodeToSearch}
   155  	for len(nodeStack) > 0 {
   156  		node := nodeStack[len(nodeStack)-1]
   157  		nodeStack = nodeStack[:len(nodeStack)-1]
   158  		switch n := node.(type) {
   159  		case sql.Databaser:
   160  			return n.Database().Name()
   161  		case *ResolvedTable:
   162  			return n.SqlDatabase.Name()
   163  		case *UnresolvedTable:
   164  			return n.Database().Name()
   165  		case *IndexedTableAccess:
   166  			return n.Database().Name()
   167  		}
   168  		nodeStack = append(nodeStack, node.Children()...)
   169  	}
   170  	return ""
   171  }
   172  
   173  // CheckPrivilegeNameForDatabase returns the name of the database to check privileges for, which may not be the result
   174  // of db.Name()
   175  func CheckPrivilegeNameForDatabase(db sql.Database) string {
   176  	if db == nil {
   177  		return ""
   178  	}
   179  
   180  	checkDbName := db.Name()
   181  	if pdb, ok := db.(mysql_db.PrivilegedDatabase); ok {
   182  		db = pdb.Unwrap()
   183  	}
   184  	if adb, ok := db.(sql.AliasedDatabase); ok {
   185  		checkDbName = adb.AliasedName()
   186  	}
   187  	return checkDbName
   188  }