github.com/dolthub/go-mysql-server@v0.18.0/sql/transform/expr.go (about)

     1  // Copyright 2020-2021 Dolthub, Inc.
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //     http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License.
    14  
    15  package transform
    16  
    17  import (
    18  	"errors"
    19  	"fmt"
    20  
    21  	"github.com/dolthub/go-mysql-server/sql"
    22  	"github.com/dolthub/go-mysql-server/sql/expression"
    23  )
    24  
    25  // Expr applies a transformation function to the given expression
    26  // tree from the bottom up. Each callback [f] returns a TreeIdentity
    27  // that is aggregated into a final output indicating whether the
    28  // expression tree was changed.
    29  func Expr(e sql.Expression, f ExprFunc) (sql.Expression, TreeIdentity, error) {
    30  	children := e.Children()
    31  	if len(children) == 0 {
    32  		return f(e)
    33  	}
    34  
    35  	var (
    36  		newChildren []sql.Expression
    37  		err         error
    38  	)
    39  
    40  	for i := 0; i < len(children); i++ {
    41  		c := children[i]
    42  		c, same, err := Expr(c, f)
    43  		if err != nil {
    44  			return nil, SameTree, err
    45  		}
    46  		if !same {
    47  			if newChildren == nil {
    48  				newChildren = make([]sql.Expression, len(children))
    49  				copy(newChildren, children)
    50  			}
    51  			newChildren[i] = c
    52  		}
    53  	}
    54  
    55  	sameC := SameTree
    56  	if len(newChildren) > 0 {
    57  		sameC = NewTree
    58  		e, err = e.WithChildren(newChildren...)
    59  		if err != nil {
    60  			return nil, SameTree, err
    61  		}
    62  	}
    63  
    64  	e, sameN, err := f(e)
    65  	if err != nil {
    66  		return nil, SameTree, err
    67  	}
    68  	return e, sameC && sameN, nil
    69  }
    70  
    71  // Exprs applies a transformation function to the given set of expressions and returns the result.
    72  func Exprs(e []sql.Expression, f ExprFunc) ([]sql.Expression, TreeIdentity, error) {
    73  	var (
    74  		newExprs []sql.Expression
    75  	)
    76  
    77  	for i := 0; i < len(e); i++ {
    78  		c := e[i]
    79  		c, same, err := Expr(c, f)
    80  		if err != nil {
    81  			return nil, SameTree, err
    82  		}
    83  		if !same {
    84  			if newExprs == nil {
    85  				newExprs = make([]sql.Expression, len(e))
    86  				copy(newExprs, e)
    87  			}
    88  			newExprs[i] = c
    89  		}
    90  	}
    91  
    92  	if len(newExprs) == 0 {
    93  		return e, SameTree, nil
    94  	}
    95  
    96  	return newExprs, NewTree, nil
    97  }
    98  
    99  var stopInspect = errors.New("stop")
   100  
   101  // InspectExpr traverses the given expression tree from the bottom up, breaking if
   102  // stop = true. Returns a bool indicating whether traversal was interrupted.
   103  func InspectExpr(node sql.Expression, f func(sql.Expression) bool) bool {
   104  	_, _, err := Expr(node, func(e sql.Expression) (sql.Expression, TreeIdentity, error) {
   105  		ok := f(e)
   106  		if ok {
   107  			return nil, SameTree, stopInspect
   108  		}
   109  		return e, SameTree, nil
   110  	})
   111  	return errors.Is(err, stopInspect)
   112  }
   113  
   114  // InspectUp traverses the given node tree from the bottom up, breaking if
   115  // stop = true. Returns a bool indicating whether traversal was interrupted.
   116  func InspectUp(node sql.Node, f func(sql.Node) bool) bool {
   117  	stop := errors.New("stop")
   118  	_, _, err := Node(node, func(e sql.Node) (sql.Node, TreeIdentity, error) {
   119  		ok := f(e)
   120  		if ok {
   121  			return nil, SameTree, stop
   122  		}
   123  		return e, SameTree, nil
   124  	})
   125  	return errors.Is(err, stop)
   126  }
   127  
   128  // Clone duplicates an existing sql.Expression, returning new nodes with the
   129  // same structure and internal values. It can be useful when dealing with
   130  // stateful expression nodes where an evaluation needs to create multiple
   131  // independent histories of the internal state of the expression nodes.
   132  func Clone(expr sql.Expression) (sql.Expression, error) {
   133  	expr, _, err := Expr(expr, func(e sql.Expression) (sql.Expression, TreeIdentity, error) {
   134  		return e, NewTree, nil
   135  	})
   136  	return expr, err
   137  }
   138  
   139  // ExprWithNode applies a transformation function to the given expression from the bottom up.
   140  func ExprWithNode(n sql.Node, e sql.Expression, f ExprWithNodeFunc) (sql.Expression, TreeIdentity, error) {
   141  	children := e.Children()
   142  	if len(children) == 0 {
   143  		return f(n, e)
   144  	}
   145  
   146  	var (
   147  		newChildren []sql.Expression
   148  		err         error
   149  	)
   150  
   151  	for i := 0; i < len(children); i++ {
   152  		c := children[i]
   153  		c, sameC, err := ExprWithNode(n, c, f)
   154  		if err != nil {
   155  			return nil, SameTree, err
   156  		}
   157  		if !sameC {
   158  			if newChildren == nil {
   159  				newChildren = make([]sql.Expression, len(children))
   160  				copy(newChildren, children)
   161  			}
   162  			newChildren[i] = c
   163  		}
   164  	}
   165  
   166  	sameC := SameTree
   167  	if len(newChildren) > 0 {
   168  		sameC = NewTree
   169  		e, err = e.WithChildren(newChildren...)
   170  		if err != nil {
   171  			return nil, SameTree, err
   172  		}
   173  	}
   174  
   175  	e, sameN, err := f(n, e)
   176  	if err != nil {
   177  		return nil, SameTree, err
   178  	}
   179  	return e, sameC && sameN, nil
   180  }
   181  
   182  // ExpressionToColumn converts the expression to the form that should be used in a Schema. Expressions that have Name()
   183  // and Table() methods will use these; otherwise, String() and "" are used, respectively. The type and nullability are
   184  // taken from the expression directly.
   185  func ExpressionToColumn(e sql.Expression, name string) *sql.Column {
   186  	if n, ok := e.(sql.Nameable); ok {
   187  		name = n.Name()
   188  	}
   189  
   190  	var table string
   191  	if t, ok := e.(sql.Tableable); ok {
   192  		table = t.Table()
   193  	}
   194  
   195  	var db string
   196  	if t, ok := e.(sql.Databaseable); ok {
   197  		db = t.Database()
   198  	}
   199  
   200  	// TODO: Is this still necessary?
   201  	if e.Resolved() {
   202  		return &sql.Column{
   203  			Name:           name,
   204  			Source:         table,
   205  			DatabaseSource: db,
   206  			Type:           e.Type(),
   207  			Nullable:       e.IsNullable(),
   208  		}
   209  	} else {
   210  		return &sql.Column{
   211  			Name:           name,
   212  			Source:         table,
   213  			DatabaseSource: db,
   214  		}
   215  	}
   216  }
   217  
   218  // SchemaWithDefaults returns a copy of the schema given with the defaults provided. Default expressions must be
   219  // wrapped with expression.Wrapper.
   220  func SchemaWithDefaults(schema sql.Schema, defaultExprs []sql.Expression) (sql.Schema, error) {
   221  	if len(schema) != len(defaultExprs) {
   222  		return nil, fmt.Errorf("expected %d default expressions, got %d", len(schema), len(defaultExprs))
   223  	}
   224  
   225  	sch := schema.Copy()
   226  	for i, col := range sch {
   227  		wrapper, ok := defaultExprs[i].(*expression.Wrapper)
   228  		if !ok {
   229  			return nil, fmt.Errorf("expected expression.Wrapper, got %T", defaultExprs[i])
   230  		}
   231  		wrappedExpr := wrapper.Unwrap()
   232  		if wrappedExpr == nil {
   233  			continue
   234  		}
   235  
   236  		defaultExpr, ok := wrappedExpr.(*sql.ColumnDefaultValue)
   237  		if !ok {
   238  			return nil, fmt.Errorf("expected *sql.ColumnDefaultValue, got %T", wrappedExpr)
   239  		}
   240  		if col.Default != nil {
   241  			col.Default = defaultExpr
   242  		} else {
   243  			col.Generated = defaultExpr
   244  		}
   245  	}
   246  
   247  	return sch, nil
   248  }
   249  
   250  // WrappedColumnDefaults returns the column defaults / generated expressions for the schema given,
   251  // wrapped with expression.Wrapper
   252  func WrappedColumnDefaults(schema sql.Schema) []sql.Expression {
   253  	defs := make([]sql.Expression, len(schema))
   254  	for i, col := range schema {
   255  		defaultVal := col.Default
   256  		if defaultVal == nil && col.Generated != nil {
   257  			defaultVal = col.Generated
   258  		}
   259  		defs[i] = expression.WrapExpression(defaultVal)
   260  	}
   261  	return defs
   262  }