github.com/dolthub/go-mysql-server@v0.18.0/sql/analyzer/assign_update_join.go (about)

     1  package analyzer
     2  
     import (
	"sort"

	"github.com/dolthub/go-mysql-server/sql"
	"github.com/dolthub/go-mysql-server/sql/expression"
	"github.com/dolthub/go-mysql-server/sql/plan"
	"github.com/dolthub/go-mysql-server/sql/transform"
)
     9  
    10  // modifyUpdateExpressionsForJoin searches for a JOIN for UPDATE query and updates the child of the original update
    11  // node to use a plan.UpdateJoin node as a child.
    12  func modifyUpdateExpressionsForJoin(ctx *sql.Context, a *Analyzer, n sql.Node, scope *plan.Scope, sel RuleSelector) (sql.Node, transform.TreeIdentity, error) {
    13  	switch n := n.(type) {
    14  	case *plan.Update:
    15  		us, ok := n.Child.(*plan.UpdateSource)
    16  		if !ok {
    17  			return n, transform.SameTree, nil
    18  		}
    19  
    20  		var jn sql.Node
    21  		transform.Inspect(us, func(node sql.Node) bool {
    22  			switch node.(type) {
    23  			case *plan.JoinNode:
    24  				jn = node
    25  				return false
    26  			default:
    27  				return true
    28  			}
    29  		})
    30  
    31  		if jn == nil {
    32  			return n, transform.SameTree, nil
    33  		}
    34  
    35  		updaters, err := rowUpdatersByTable(ctx, us, jn)
    36  		if err != nil {
    37  			return nil, transform.SameTree, err
    38  		}
    39  
    40  		uj := plan.NewUpdateJoin(updaters, us)
    41  		ret, err := n.WithChildren(uj)
    42  		if err != nil {
    43  			return nil, transform.SameTree, err
    44  		}
    45  
    46  		return ret, transform.NewTree, nil
    47  	}
    48  
    49  	return n, transform.SameTree, nil
    50  }
    51  
    52  // rowUpdatersByTable maps a set of tables to their RowUpdater objects.
    53  func rowUpdatersByTable(ctx *sql.Context, node sql.Node, ij sql.Node) (map[string]sql.RowUpdater, error) {
    54  	namesOfTableToBeUpdated := getTablesToBeUpdated(node)
    55  	resolvedTables := getTablesByName(ij)
    56  
    57  	rowUpdatersByTable := make(map[string]sql.RowUpdater)
    58  	for tableToBeUpdated, _ := range namesOfTableToBeUpdated {
    59  		resolvedTable, ok := resolvedTables[tableToBeUpdated]
    60  		if !ok {
    61  			return nil, plan.ErrUpdateForTableNotSupported.New(tableToBeUpdated)
    62  		}
    63  
    64  		var table = resolvedTable.UnderlyingTable()
    65  
    66  		// If there is no UpdatableTable for a table being updated, error out
    67  		updatable, ok := table.(sql.UpdatableTable)
    68  		if !ok && updatable == nil {
    69  			return nil, plan.ErrUpdateForTableNotSupported.New(tableToBeUpdated)
    70  		}
    71  
    72  		keyless := sql.IsKeyless(updatable.Schema())
    73  		if keyless {
    74  			return nil, sql.ErrUnsupportedFeature.New("error: keyless tables unsupported for UPDATE JOIN")
    75  		}
    76  
    77  		rowUpdatersByTable[tableToBeUpdated] = updatable.Updater(ctx)
    78  	}
    79  
    80  	return rowUpdatersByTable, nil
    81  }
    82  
    83  // getTablesToBeUpdated takes a node and looks for the tables to modified by a SetField.
    84  func getTablesToBeUpdated(node sql.Node) map[string]struct{} {
    85  	ret := make(map[string]struct{})
    86  
    87  	transform.InspectExpressions(node, func(e sql.Expression) bool {
    88  		switch e := e.(type) {
    89  		case *expression.SetField:
    90  			gf := e.LeftChild.(*expression.GetField)
    91  			ret[gf.Table()] = struct{}{}
    92  			return false
    93  		}
    94  
    95  		return true
    96  	})
    97  
    98  	return ret
    99  }