github.com/dolthub/go-mysql-server@v0.18.0/sql/analyzer/tables.go (about)

     1  // Copyright 2020-2021 Dolthub, Inc.
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //     http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License.
    14  
    15  package analyzer
    16  
import (
	"sort"

	"github.com/dolthub/go-mysql-server/sql"
	"github.com/dolthub/go-mysql-server/sql/expression"
	"github.com/dolthub/go-mysql-server/sql/plan"
	"github.com/dolthub/go-mysql-server/sql/transform"
)
    23  
    24  // Returns the underlying table name for the node given
    25  func getTableName(node sql.Node) string {
    26  	var tableName string
    27  	transform.Inspect(node, func(node sql.Node) bool {
    28  		switch node := node.(type) {
    29  		case *plan.TableAlias:
    30  			tableName = node.Name()
    31  			return false
    32  		case *plan.ResolvedTable:
    33  			tableName = node.Name()
    34  			return false
    35  		case *plan.UnresolvedTable:
    36  			tableName = node.Name()
    37  			return false
    38  		case *plan.IndexedTableAccess:
    39  			tableName = node.Name()
    40  			return false
    41  		}
    42  		return true
    43  	})
    44  
    45  	return tableName
    46  }
    47  
    48  // Returns the underlying table name for the node given, ignoring table aliases
    49  func getUnaliasedTableName(node sql.Node) string {
    50  	var tableName string
    51  	transform.Inspect(node, func(node sql.Node) bool {
    52  		switch node := node.(type) {
    53  		case *plan.ResolvedTable:
    54  			tableName = node.Name()
    55  			return false
    56  		case *plan.UnresolvedTable:
    57  			tableName = node.Name()
    58  			return false
    59  		case *plan.IndexedTableAccess:
    60  			tableName = node.Name()
    61  			return false
    62  		}
    63  		return true
    64  	})
    65  
    66  	return tableName
    67  }
    68  
    69  // Finds first table node that is a descendant of the node given
    70  func getTable(node sql.Node) sql.Table {
    71  	var table sql.Table
    72  	transform.Inspect(node, func(node sql.Node) bool {
    73  		if table != nil {
    74  			return false
    75  		}
    76  
    77  		switch n := node.(type) {
    78  		case sql.TableNode:
    79  			table = n.UnderlyingTable()
    80  			// TODO unwinding a table wrapper here causes infinite analyzer recursion
    81  			return false
    82  		case *plan.IndexedTableAccess:
    83  			table = n.TableNode.UnderlyingTable()
    84  			return false
    85  		}
    86  		return true
    87  	})
    88  	return table
    89  }
    90  
    91  // Finds first ResolvedTable node that is a descendant of the node given
    92  func getResolvedTable(node sql.Node) *plan.ResolvedTable {
    93  	var table *plan.ResolvedTable
    94  	transform.Inspect(node, func(node sql.Node) bool {
    95  		// plan.Inspect will get called on all children of a node even if one of the children's calls returns false. We
    96  		// only want the first TableNode match.
    97  		if table != nil {
    98  			return false
    99  		}
   100  
   101  		switch n := node.(type) {
   102  		case *plan.ResolvedTable:
   103  			if !plan.IsDualTable(n) {
   104  				table = n
   105  				return false
   106  			}
   107  		case *plan.IndexedTableAccess:
   108  			rt, ok := n.TableNode.(*plan.ResolvedTable)
   109  			if ok {
   110  				table = rt
   111  				return false
   112  			}
   113  		}
   114  		return true
   115  	})
   116  	return table
   117  }
   118  
   119  // getTablesByName takes a node and returns all found resolved tables in a map.
   120  func getTablesByName(node sql.Node) map[string]*plan.ResolvedTable {
   121  	ret := make(map[string]*plan.ResolvedTable)
   122  
   123  	transform.Inspect(node, func(node sql.Node) bool {
   124  		switch n := node.(type) {
   125  		case *plan.ResolvedTable:
   126  			ret[n.Table.Name()] = n
   127  		case *plan.IndexedTableAccess:
   128  			rt, ok := n.TableNode.(*plan.ResolvedTable)
   129  			if ok {
   130  				ret[rt.Name()] = rt
   131  				return false
   132  			}
   133  		case *plan.TableAlias:
   134  			rt := getResolvedTable(n)
   135  			if rt != nil {
   136  				ret[n.Name()] = rt
   137  			}
   138  		default:
   139  		}
   140  		return true
   141  	})
   142  
   143  	return ret
   144  }
   145  
   146  // Returns the tables used in the expressions given
   147  func findTables(exprs ...sql.Expression) []string {
   148  	tables := make(map[string]bool)
   149  	for _, e := range exprs {
   150  		sql.Inspect(e, func(e sql.Expression) bool {
   151  			switch e := e.(type) {
   152  			case *expression.GetField:
   153  				tables[e.Table()] = true
   154  				return false
   155  			default:
   156  				return true
   157  			}
   158  		})
   159  	}
   160  
   161  	var names []string
   162  	for table := range tables {
   163  		names = append(names, table)
   164  	}
   165  
   166  	return names
   167  }