github.com/dolthub/go-mysql-server@v0.18.0/sql/analyzer/tables.go (about) 1 // Copyright 2020-2021 Dolthub, Inc. 2 // 3 // Licensed under the Apache License, Version 2.0 (the "License"); 4 // you may not use this file except in compliance with the License. 5 // You may obtain a copy of the License at 6 // 7 // http://www.apache.org/licenses/LICENSE-2.0 8 // 9 // Unless required by applicable law or agreed to in writing, software 10 // distributed under the License is distributed on an "AS IS" BASIS, 11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 // See the License for the specific language governing permissions and 13 // limitations under the License. 14 15 package analyzer 16 17 import ( 18 "github.com/dolthub/go-mysql-server/sql" 19 "github.com/dolthub/go-mysql-server/sql/expression" 20 "github.com/dolthub/go-mysql-server/sql/plan" 21 "github.com/dolthub/go-mysql-server/sql/transform" 22 ) 23 24 // Returns the underlying table name for the node given 25 func getTableName(node sql.Node) string { 26 var tableName string 27 transform.Inspect(node, func(node sql.Node) bool { 28 switch node := node.(type) { 29 case *plan.TableAlias: 30 tableName = node.Name() 31 return false 32 case *plan.ResolvedTable: 33 tableName = node.Name() 34 return false 35 case *plan.UnresolvedTable: 36 tableName = node.Name() 37 return false 38 case *plan.IndexedTableAccess: 39 tableName = node.Name() 40 return false 41 } 42 return true 43 }) 44 45 return tableName 46 } 47 48 // Returns the underlying table name for the node given, ignoring table aliases 49 func getUnaliasedTableName(node sql.Node) string { 50 var tableName string 51 transform.Inspect(node, func(node sql.Node) bool { 52 switch node := node.(type) { 53 case *plan.ResolvedTable: 54 tableName = node.Name() 55 return false 56 case *plan.UnresolvedTable: 57 tableName = node.Name() 58 return false 59 case *plan.IndexedTableAccess: 60 tableName = node.Name() 61 return false 62 } 63 return true 64 }) 65 66 return tableName 67 } 68 69 // Finds first table node that is a descendant of the node given 70 func getTable(node sql.Node) sql.Table { 71 var table sql.Table 72 transform.Inspect(node, func(node sql.Node) bool { 73 if table != nil { 74 return false 75 } 76 77 switch n := node.(type) { 78 case sql.TableNode: 79 table = n.UnderlyingTable() 80 // TODO unwinding a table wrapper here causes infinite analyzer recursion 81 return false 82 case *plan.IndexedTableAccess: 83 table = n.TableNode.UnderlyingTable() 84 return false 85 } 86 return true 87 }) 88 return table 89 } 90 91 // Finds first ResolvedTable node that is a descendant of the node given 92 func getResolvedTable(node sql.Node) *plan.ResolvedTable { 93 var table *plan.ResolvedTable 94 transform.Inspect(node, func(node sql.Node) bool { 95 // plan.Inspect will get called on all children of a node even if one of the children's calls returns false. We 96 // only want the first TableNode match. 97 if table != nil { 98 return false 99 } 100 101 switch n := node.(type) { 102 case *plan.ResolvedTable: 103 if !plan.IsDualTable(n) { 104 table = n 105 return false 106 } 107 case *plan.IndexedTableAccess: 108 rt, ok := n.TableNode.(*plan.ResolvedTable) 109 if ok { 110 table = rt 111 return false 112 } 113 } 114 return true 115 }) 116 return table 117 } 118 119 // getTablesByName takes a node and returns all found resolved tables in a map. 120 func getTablesByName(node sql.Node) map[string]*plan.ResolvedTable { 121 ret := make(map[string]*plan.ResolvedTable) 122 123 transform.Inspect(node, func(node sql.Node) bool { 124 switch n := node.(type) { 125 case *plan.ResolvedTable: 126 ret[n.Table.Name()] = n 127 case *plan.IndexedTableAccess: 128 rt, ok := n.TableNode.(*plan.ResolvedTable) 129 if ok { 130 ret[rt.Name()] = rt 131 return false 132 } 133 case *plan.TableAlias: 134 rt := getResolvedTable(n) 135 if rt != nil { 136 ret[n.Name()] = rt 137 } 138 default: 139 } 140 return true 141 }) 142 143 return ret 144 } 145 146 // Returns the tables used in the expressions given 147 func findTables(exprs ...sql.Expression) []string { 148 tables := make(map[string]bool) 149 for _, e := range exprs { 150 sql.Inspect(e, func(e sql.Expression) bool { 151 switch e := e.(type) { 152 case *expression.GetField: 153 tables[e.Table()] = true 154 return false 155 default: 156 return true 157 } 158 }) 159 } 160 161 var names []string 162 for table := range tables { 163 names = append(names, table) 164 } 165 166 return names 167 }