github.com/matrixorigin/matrixone@v1.2.0/pkg/sql/plan/build_replace.go (about)

     1  // Copyright 2021 - 2022 Matrix Origin
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //      http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License.
    14  
    15  package plan
    16  
import (
	"fmt"
	"sort"
	"strings"
	"time"

	"github.com/matrixorigin/matrixone/pkg/catalog"
	"github.com/matrixorigin/matrixone/pkg/common/moerr"
	"github.com/matrixorigin/matrixone/pkg/pb/plan"
	"github.com/matrixorigin/matrixone/pkg/sql/parsers/dialect"
	"github.com/matrixorigin/matrixone/pkg/sql/parsers/tree"
	v2 "github.com/matrixorigin/matrixone/pkg/util/metric/v2"
)
    29  
    30  func buildReplace(stmt *tree.Replace, ctx CompilerContext, isPrepareStmt bool) (p *Plan, err error) {
    31  	start := time.Now()
    32  	defer func() {
    33  		v2.TxnStatementBuildReplaceHistogram.Observe(time.Since(start).Seconds())
    34  	}()
    35  	tblInfo, err := getDmlTableInfo(ctx, tree.TableExprs{stmt.Table}, nil, nil, "replace")
    36  	if err != nil {
    37  		return nil, err
    38  	}
    39  	if len(tblInfo.tableDefs) != 1 {
    40  		return nil, moerr.NewInvalidInput(ctx.GetContext(), "replace does not support multi-table")
    41  	}
    42  	tableDef := tblInfo.tableDefs[0]
    43  	keys := getAllKeys(tableDef)
    44  	deleteCond := ""
    45  
    46  	if keys != nil {
    47  		// table has keys
    48  		if len(stmt.Columns) != 0 {
    49  			// replace into table set col1 = val1, col2 = val2, ...
    50  			row := stmt.Rows.Select.(*tree.ValuesClause).Rows[0]
    51  			keyToRow := getKeyToRowMatch(stmt.Columns)
    52  			keepKeys := filterKeys(keys, stmt.Columns)
    53  			disjunction := make([]string, 0, len(keepKeys))
    54  			for _, key := range keepKeys {
    55  				disjunction = append(disjunction, buildConjunction(key, row, keyToRow))
    56  			}
    57  			deleteCond = strings.Join(disjunction, " or ")
    58  		} else {
    59  			// replace into table values (...);
    60  			keyToRow := make(map[string]int, len(tableDef.Cols))
    61  			for i, col := range tableDef.Cols {
    62  				keyToRow[col.Name] = i
    63  			}
    64  
    65  			rows := stmt.Rows.Select.(*tree.ValuesClause).Rows
    66  			disjunction := make([]string, 0, len(rows)*len(keys))
    67  			for _, row := range rows {
    68  				for _, key := range keys {
    69  					disjunction = append(disjunction, buildConjunction(key, row, keyToRow))
    70  				}
    71  			}
    72  			deleteCond = strings.Join(disjunction, " or ")
    73  		}
    74  	}
    75  
    76  	return &Plan{
    77  		Plan: &plan.Plan_Query{
    78  			Query: &plan.Query{
    79  				StmtType: plan.Query_REPLACE,
    80  				Nodes: []*plan.Node{
    81  					{NodeType: plan.Node_REPLACE, ReplaceCtx: &plan.ReplaceCtx{TableDef: tableDef, DeleteCond: deleteCond}},
    82  				},
    83  			},
    84  		},
    85  	}, nil
    86  }
    87  
    88  func isMapSubset(m, sub map[string]struct{}) bool {
    89  	if len(sub) > len(m) {
    90  		return false
    91  	}
    92  	for k := range sub {
    93  		if _, ok := m[k]; !ok {
    94  			return false
    95  		}
    96  	}
    97  	return true
    98  }
    99  
   100  func getAllKeys(tableDef *plan.TableDef) []map[string]struct{} {
   101  	n := 0
   102  	for _, index := range tableDef.Indexes {
   103  		if index.Unique {
   104  			n++
   105  		}
   106  	}
   107  	if tableDef.Pkey.PkeyColName != catalog.FakePrimaryKeyColName {
   108  		n++
   109  	}
   110  	if n == 0 {
   111  		return nil
   112  	}
   113  
   114  	keys := make([]map[string]struct{}, 0, n)
   115  	if tableDef.Pkey.PkeyColName != catalog.FakePrimaryKeyColName {
   116  		keys = append(keys, make(map[string]struct{}))
   117  		for _, part := range tableDef.Pkey.Names {
   118  			keys[0][part] = struct{}{}
   119  		}
   120  	}
   121  	for _, index := range tableDef.Indexes {
   122  		if index.Unique {
   123  			keys = append(keys, make(map[string]struct{}))
   124  			for _, key := range index.Parts {
   125  				keys[len(keys)-1][key] = struct{}{}
   126  			}
   127  		}
   128  	}
   129  	return keys
   130  }
   131  
   132  func getInsertedCol(cols tree.IdentifierList) map[string]struct{} {
   133  	insertedCol := make(map[string]struct{}, len(cols))
   134  	for _, col := range cols {
   135  		insertedCol[string(col)] = struct{}{}
   136  	}
   137  	return insertedCol
   138  }
   139  
   140  func filterKeys(keys []map[string]struct{}, cols tree.IdentifierList) []map[string]struct{} {
   141  	keepKeys := keys[:0]
   142  	insertedCol := getInsertedCol(cols)
   143  	for _, key := range keys {
   144  		if isMapSubset(insertedCol, key) {
   145  			keepKeys = append(keepKeys, key)
   146  		}
   147  	}
   148  	for i := len(keepKeys); i < len(keys); i++ {
   149  		keys[i] = nil // or the zero value of T
   150  	}
   151  	return keepKeys
   152  }
   153  
   154  func getKeyToRowMatch(columns tree.IdentifierList) map[string]int {
   155  	keyToRow := make(map[string]int, len(columns))
   156  	for i, col := range columns {
   157  		keyToRow[string(col)] = i
   158  	}
   159  	return keyToRow
   160  }
   161  
   162  func buildConjunction(key map[string]struct{}, row tree.Exprs, keyToRow map[string]int) string {
   163  	conjunctions := make([]string, 0, len(key))
   164  	for k := range key {
   165  		fmtctx := tree.NewFmtCtx(dialect.MYSQL, tree.WithQuoteString(true))
   166  		row[keyToRow[k]].Format(fmtctx)
   167  		conjunctions = append(conjunctions, fmt.Sprintf("%s in (select %s)", k, fmtctx.String()))
   168  	}
   169  	return "(" + strings.Join(conjunctions, " and ") + ")"
   170  }