github.com/dolthub/go-mysql-server@v0.18.0/sql/transform/expr.go (about) 1 // Copyright 2020-2021 Dolthub, Inc. 2 // 3 // Licensed under the Apache License, Version 2.0 (the "License"); 4 // you may not use this file except in compliance with the License. 5 // You may obtain a copy of the License at 6 // 7 // http://www.apache.org/licenses/LICENSE-2.0 8 // 9 // Unless required by applicable law or agreed to in writing, software 10 // distributed under the License is distributed on an "AS IS" BASIS, 11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 // See the License for the specific language governing permissions and 13 // limitations under the License. 14 15 package transform 16 17 import ( 18 "errors" 19 "fmt" 20 21 "github.com/dolthub/go-mysql-server/sql" 22 "github.com/dolthub/go-mysql-server/sql/expression" 23 ) 24 25 // Expr applies a transformation function to the given expression 26 // tree from the bottom up. Each callback [f] returns a TreeIdentity 27 // that is aggregated into a final output indicating whether the 28 // expression tree was changed. 29 func Expr(e sql.Expression, f ExprFunc) (sql.Expression, TreeIdentity, error) { 30 children := e.Children() 31 if len(children) == 0 { 32 return f(e) 33 } 34 35 var ( 36 newChildren []sql.Expression 37 err error 38 ) 39 40 for i := 0; i < len(children); i++ { 41 c := children[i] 42 c, same, err := Expr(c, f) 43 if err != nil { 44 return nil, SameTree, err 45 } 46 if !same { 47 if newChildren == nil { 48 newChildren = make([]sql.Expression, len(children)) 49 copy(newChildren, children) 50 } 51 newChildren[i] = c 52 } 53 } 54 55 sameC := SameTree 56 if len(newChildren) > 0 { 57 sameC = NewTree 58 e, err = e.WithChildren(newChildren...) 59 if err != nil { 60 return nil, SameTree, err 61 } 62 } 63 64 e, sameN, err := f(e) 65 if err != nil { 66 return nil, SameTree, err 67 } 68 return e, sameC && sameN, nil 69 } 70 71 // Exprs applies a transformation function to the given set of expressions and returns the result. 72 func Exprs(e []sql.Expression, f ExprFunc) ([]sql.Expression, TreeIdentity, error) { 73 var ( 74 newExprs []sql.Expression 75 ) 76 77 for i := 0; i < len(e); i++ { 78 c := e[i] 79 c, same, err := Expr(c, f) 80 if err != nil { 81 return nil, SameTree, err 82 } 83 if !same { 84 if newExprs == nil { 85 newExprs = make([]sql.Expression, len(e)) 86 copy(newExprs, e) 87 } 88 newExprs[i] = c 89 } 90 } 91 92 if len(newExprs) == 0 { 93 return e, SameTree, nil 94 } 95 96 return newExprs, NewTree, nil 97 } 98 99 var stopInspect = errors.New("stop") 100 101 // InspectExpr traverses the given expression tree from the bottom up, breaking if 102 // stop = true. Returns a bool indicating whether traversal was interrupted. 103 func InspectExpr(node sql.Expression, f func(sql.Expression) bool) bool { 104 _, _, err := Expr(node, func(e sql.Expression) (sql.Expression, TreeIdentity, error) { 105 ok := f(e) 106 if ok { 107 return nil, SameTree, stopInspect 108 } 109 return e, SameTree, nil 110 }) 111 return errors.Is(err, stopInspect) 112 } 113 114 // InspectUp traverses the given node tree from the bottom up, breaking if 115 // stop = true. Returns a bool indicating whether traversal was interrupted. 116 func InspectUp(node sql.Node, f func(sql.Node) bool) bool { 117 stop := errors.New("stop") 118 _, _, err := Node(node, func(e sql.Node) (sql.Node, TreeIdentity, error) { 119 ok := f(e) 120 if ok { 121 return nil, SameTree, stop 122 } 123 return e, SameTree, nil 124 }) 125 return errors.Is(err, stop) 126 } 127 128 // Clone duplicates an existing sql.Expression, returning new nodes with the 129 // same structure and internal values. It can be useful when dealing with 130 // stateful expression nodes where an evaluation needs to create multiple 131 // independent histories of the internal state of the expression nodes. 132 func Clone(expr sql.Expression) (sql.Expression, error) { 133 expr, _, err := Expr(expr, func(e sql.Expression) (sql.Expression, TreeIdentity, error) { 134 return e, NewTree, nil 135 }) 136 return expr, err 137 } 138 139 // ExprWithNode applies a transformation function to the given expression from the bottom up. 140 func ExprWithNode(n sql.Node, e sql.Expression, f ExprWithNodeFunc) (sql.Expression, TreeIdentity, error) { 141 children := e.Children() 142 if len(children) == 0 { 143 return f(n, e) 144 } 145 146 var ( 147 newChildren []sql.Expression 148 err error 149 ) 150 151 for i := 0; i < len(children); i++ { 152 c := children[i] 153 c, sameC, err := ExprWithNode(n, c, f) 154 if err != nil { 155 return nil, SameTree, err 156 } 157 if !sameC { 158 if newChildren == nil { 159 newChildren = make([]sql.Expression, len(children)) 160 copy(newChildren, children) 161 } 162 newChildren[i] = c 163 } 164 } 165 166 sameC := SameTree 167 if len(newChildren) > 0 { 168 sameC = NewTree 169 e, err = e.WithChildren(newChildren...) 170 if err != nil { 171 return nil, SameTree, err 172 } 173 } 174 175 e, sameN, err := f(n, e) 176 if err != nil { 177 return nil, SameTree, err 178 } 179 return e, sameC && sameN, nil 180 } 181 182 // ExpressionToColumn converts the expression to the form that should be used in a Schema. Expressions that have Name() 183 // and Table() methods will use these; otherwise, String() and "" are used, respectively. The type and nullability are 184 // taken from the expression directly. 185 func ExpressionToColumn(e sql.Expression, name string) *sql.Column { 186 if n, ok := e.(sql.Nameable); ok { 187 name = n.Name() 188 } 189 190 var table string 191 if t, ok := e.(sql.Tableable); ok { 192 table = t.Table() 193 } 194 195 var db string 196 if t, ok := e.(sql.Databaseable); ok { 197 db = t.Database() 198 } 199 200 // TODO: Is this still necessary? 201 if e.Resolved() { 202 return &sql.Column{ 203 Name: name, 204 Source: table, 205 DatabaseSource: db, 206 Type: e.Type(), 207 Nullable: e.IsNullable(), 208 } 209 } else { 210 return &sql.Column{ 211 Name: name, 212 Source: table, 213 DatabaseSource: db, 214 } 215 } 216 } 217 218 // SchemaWithDefaults returns a copy of the schema given with the defaults provided. Default expressions must be 219 // wrapped with expression.Wrapper. 220 func SchemaWithDefaults(schema sql.Schema, defaultExprs []sql.Expression) (sql.Schema, error) { 221 if len(schema) != len(defaultExprs) { 222 return nil, fmt.Errorf("expected %d default expressions, got %d", len(schema), len(defaultExprs)) 223 } 224 225 sch := schema.Copy() 226 for i, col := range sch { 227 wrapper, ok := defaultExprs[i].(*expression.Wrapper) 228 if !ok { 229 return nil, fmt.Errorf("expected expression.Wrapper, got %T", defaultExprs[i]) 230 } 231 wrappedExpr := wrapper.Unwrap() 232 if wrappedExpr == nil { 233 continue 234 } 235 236 defaultExpr, ok := wrappedExpr.(*sql.ColumnDefaultValue) 237 if !ok { 238 return nil, fmt.Errorf("expected *sql.ColumnDefaultValue, got %T", wrappedExpr) 239 } 240 if col.Default != nil { 241 col.Default = defaultExpr 242 } else { 243 col.Generated = defaultExpr 244 } 245 } 246 247 return sch, nil 248 } 249 250 // WrappedColumnDefaults returns the column defaults / generated expressions for the schema given, 251 // wrapped with expression.Wrapper 252 func WrappedColumnDefaults(schema sql.Schema) []sql.Expression { 253 defs := make([]sql.Expression, len(schema)) 254 for i, col := range schema { 255 defaultVal := col.Default 256 if defaultVal == nil && col.Generated != nil { 257 defaultVal = col.Generated 258 } 259 defs[i] = expression.WrapExpression(defaultVal) 260 } 261 return defs 262 }