github.com/pingcap/tiflow@v0.0.0-20240520035814-5bf52d54e205/dm/simulator/sqlgen/impl.go (about)

     1  // Copyright 2022 PingCAP, Inc.
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //     http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // See the License for the specific language governing permissions and
    12  // limitations under the License.
    13  
    14  package sqlgen
    15  
    16  import (
    17  	"strings"
    18  
    19  	"github.com/chaos-mesh/go-sqlsmith/util"
    20  	"github.com/pingcap/errors"
    21  	"github.com/pingcap/tidb/pkg/parser/ast"
    22  	"github.com/pingcap/tidb/pkg/parser/format"
    23  	"github.com/pingcap/tidb/pkg/parser/model"
    24  	"github.com/pingcap/tidb/pkg/parser/opcode"
    25  	_ "github.com/pingcap/tidb/pkg/types/parser_driver" // import this to make the parser work
    26  	"github.com/pingcap/tiflow/dm/pkg/log"
    27  	"github.com/pingcap/tiflow/dm/simulator/config"
    28  	"github.com/pingcap/tiflow/dm/simulator/mcp"
    29  	"go.uber.org/zap"
    30  )
    31  
// sqlGeneratorImpl generates SQL statements for one configured table.
// Its methods build TiDB parser AST nodes and restore them into SQL strings;
// the method comments state that they implement the SQLGenerator interface.
type sqlGeneratorImpl struct {
	// tableConfig is the configuration (database, table, columns, UK columns)
	// of the target table.
	tableConfig *config.TableConfig
	// columnMap maps each column name to its definition.
	columnMap   map[string]*config.ColumnDefinition
	// ukMap is the set of unique key column names. NewSQLGeneratorImpl only
	// keeps UK names that also appear in columnMap.
	ukMap       map[string]struct{}
}
    37  
    38  // NewSQLGeneratorImpl generates a new implementation object for SQL generator.
    39  func NewSQLGeneratorImpl(tableConfig *config.TableConfig) *sqlGeneratorImpl {
    40  	colDefMap := make(map[string]*config.ColumnDefinition)
    41  	for _, colDef := range tableConfig.Columns {
    42  		colDefMap[colDef.ColumnName] = colDef
    43  	}
    44  	ukMap := make(map[string]struct{})
    45  	for _, ukColName := range tableConfig.UniqueKeyColumnNames {
    46  		if _, ok := colDefMap[ukColName]; ok {
    47  			ukMap[ukColName] = struct{}{}
    48  		}
    49  	}
    50  	return &sqlGeneratorImpl{
    51  		tableConfig: tableConfig,
    52  		columnMap:   colDefMap,
    53  		ukMap:       ukMap,
    54  	}
    55  }
    56  
    57  // outputString parses an ast node to SQL string.
    58  func outputString(node ast.Node) (string, error) {
    59  	var sb strings.Builder
    60  	err := node.Restore(format.NewRestoreCtx(format.DefaultRestoreFlags, &sb))
    61  	if err != nil {
    62  		return "", errors.Annotate(err, "restore AST into SQL string error")
    63  	}
    64  	return sb.String(), nil
    65  }
    66  
    67  // GenTruncateTable generates a TRUNCATE TABLE SQL.
    68  // It implements the SQLGenerator interface.
    69  func (g *sqlGeneratorImpl) GenTruncateTable() (string, error) {
    70  	truncateTree := &ast.TruncateTableStmt{
    71  		Table: &ast.TableName{
    72  			Schema: model.NewCIStr(g.tableConfig.DatabaseName),
    73  			Name:   model.NewCIStr(g.tableConfig.TableName),
    74  		},
    75  	}
    76  	return outputString(truncateTree)
    77  }
    78  
    79  func (g *sqlGeneratorImpl) generateWhereClause(theUK map[string]interface{}) (ast.ExprNode, error) {
    80  	compareExprs := make([]ast.ExprNode, 0)
    81  	// iterate the existing UKs, to make sure all the uk columns has values
    82  	for ukColName := range g.ukMap {
    83  		val, ok := theUK[ukColName]
    84  		if !ok {
    85  			log.L().Error(ErrUKColValueNotProvided.Error(), zap.String("column_name", ukColName))
    86  			return nil, errors.Trace(ErrUKColValueNotProvided)
    87  		}
    88  		var compareExpr ast.ExprNode
    89  		if val == nil {
    90  			compareExpr = &ast.IsNullExpr{
    91  				Expr: &ast.ColumnNameExpr{
    92  					Name: &ast.ColumnName{
    93  						Name: model.NewCIStr(ukColName),
    94  					},
    95  				},
    96  			}
    97  		} else {
    98  			compareExpr = &ast.BinaryOperationExpr{
    99  				Op: opcode.EQ,
   100  				L: &ast.ColumnNameExpr{
   101  					Name: &ast.ColumnName{
   102  						Name: model.NewCIStr(ukColName),
   103  					},
   104  				},
   105  				R: ast.NewValueExpr(val, "", ""),
   106  			}
   107  		}
   108  		compareExprs = append(compareExprs, compareExpr)
   109  	}
   110  	resultExpr := generateCompoundBinaryOpExpr(compareExprs)
   111  	if resultExpr == nil {
   112  		return nil, ErrWhereFiltersEmpty
   113  	}
   114  	return resultExpr, nil
   115  }
   116  
   117  func generateCompoundBinaryOpExpr(compExprs []ast.ExprNode) ast.ExprNode {
   118  	switch len(compExprs) {
   119  	case 0:
   120  		return nil
   121  	case 1:
   122  		return compExprs[0]
   123  	default:
   124  		return &ast.BinaryOperationExpr{
   125  			Op: opcode.LogicAnd,
   126  			L:  compExprs[0],
   127  			R:  generateCompoundBinaryOpExpr(compExprs[1:]),
   128  		}
   129  	}
   130  }
   131  
   132  // GenUpdateRow generates an UPDATE SQL for the given unique key.
   133  // It implements the SQLGenerator interface.
   134  func (g *sqlGeneratorImpl) GenUpdateRow(theUK *mcp.UniqueKey) (string, error) {
   135  	if theUK == nil {
   136  		return "", errors.Trace(ErrMissingUKValue)
   137  	}
   138  	assignments := make([]*ast.Assignment, 0)
   139  	for _, colInfo := range g.columnMap {
   140  		if _, ok := g.ukMap[colInfo.ColumnName]; ok {
   141  			// this is a UK column, skip from modifying it
   142  			// TODO: support UK modification in the future
   143  			continue
   144  		}
   145  		assignments = append(assignments, &ast.Assignment{
   146  			Column: &ast.ColumnName{
   147  				Name: model.NewCIStr(colInfo.ColumnName),
   148  			},
   149  			Expr: ast.NewValueExpr(util.GenerateDataItem(colInfo.DataType), "", ""),
   150  		})
   151  	}
   152  	whereClause, err := g.generateWhereClause(theUK.GetValue())
   153  	if err != nil {
   154  		return "", errors.Annotate(err, "generate where clause error")
   155  	}
   156  	updateTree := &ast.UpdateStmt{
   157  		List: assignments,
   158  		TableRefs: &ast.TableRefsClause{
   159  			TableRefs: &ast.Join{
   160  				Left: &ast.TableName{
   161  					Schema: model.NewCIStr(g.tableConfig.DatabaseName),
   162  					Name:   model.NewCIStr(g.tableConfig.TableName),
   163  				},
   164  			},
   165  		},
   166  		Where: whereClause,
   167  	}
   168  	return outputString(updateTree)
   169  }
   170  
   171  // GenInsertRow generates an INSERT SQL.
   172  // It implements the SQLGenerator interface.
   173  // The new row's unique key is also provided,
   174  // so that it can be further added into an MCP.
   175  func (g *sqlGeneratorImpl) GenInsertRow() (string, *mcp.UniqueKey, error) {
   176  	ukValues := make(map[string]interface{})
   177  	columnNames := []*ast.ColumnName{}
   178  	values := []ast.ExprNode{}
   179  	for _, col := range g.columnMap {
   180  		columnNames = append(columnNames, &ast.ColumnName{
   181  			Name: model.NewCIStr(col.ColumnName),
   182  		})
   183  		newValue := util.GenerateDataItem(col.DataType)
   184  		values = append(values, ast.NewValueExpr(newValue, "", ""))
   185  		if _, ok := g.ukMap[col.ColumnName]; ok {
   186  			// add UK value
   187  			ukValues[col.ColumnName] = newValue
   188  		}
   189  	}
   190  	insertTree := &ast.InsertStmt{
   191  		Table: &ast.TableRefsClause{
   192  			TableRefs: &ast.Join{
   193  				Left: &ast.TableName{
   194  					Schema: model.NewCIStr(g.tableConfig.DatabaseName),
   195  					Name:   model.NewCIStr(g.tableConfig.TableName),
   196  				},
   197  			},
   198  		},
   199  		Lists:   [][]ast.ExprNode{values},
   200  		Columns: columnNames,
   201  	}
   202  	sql, err := outputString(insertTree)
   203  	if err != nil {
   204  		return "", nil, errors.Annotate(err, "output INSERT AST into SQL string error")
   205  	}
   206  	return sql, mcp.NewUniqueKey(-1, ukValues), nil
   207  }
   208  
   209  // GenDeleteRow generates a DELETE SQL for the given unique key.
   210  // It implements the SQLGenerator interface.
   211  func (g *sqlGeneratorImpl) GenDeleteRow(theUK *mcp.UniqueKey) (string, error) {
   212  	if theUK == nil {
   213  		return "", errors.Trace(ErrMissingUKValue)
   214  	}
   215  	whereClause, err := g.generateWhereClause(theUK.GetValue())
   216  	if err != nil {
   217  		return "", errors.Annotate(err, "generate where clause error")
   218  	}
   219  	updateTree := &ast.DeleteStmt{
   220  		TableRefs: &ast.TableRefsClause{
   221  			TableRefs: &ast.Join{
   222  				Left: &ast.TableName{
   223  					Schema: model.NewCIStr(g.tableConfig.DatabaseName),
   224  					Name:   model.NewCIStr(g.tableConfig.TableName),
   225  				},
   226  			},
   227  		},
   228  		Where: whereClause,
   229  	}
   230  	return outputString(updateTree)
   231  }
   232  
   233  // GenLoadUniqueKeySQL generates a SELECT SQL fetching all the uniques of a table.
   234  // It implements the SQLGenerator interface.
   235  // The column definitions of the returned data is also provided,
   236  // so that the values can be stored to different variables.
   237  func (g *sqlGeneratorImpl) GenLoadUniqueKeySQL() (string, []*config.ColumnDefinition, error) {
   238  	selectFields := make([]*ast.SelectField, 0)
   239  	cols := make([]*config.ColumnDefinition, 0)
   240  	for ukColName := range g.ukMap {
   241  		selectFields = append(selectFields, &ast.SelectField{
   242  			Expr: &ast.ColumnNameExpr{
   243  				Name: &ast.ColumnName{
   244  					Name: model.NewCIStr(ukColName),
   245  				},
   246  			},
   247  		})
   248  		cols = append(cols, g.columnMap[ukColName])
   249  	}
   250  	selectTree := &ast.SelectStmt{
   251  		SelectStmtOpts: &ast.SelectStmtOpts{
   252  			SQLCache: true,
   253  		},
   254  		Fields: &ast.FieldList{
   255  			Fields: selectFields,
   256  		},
   257  		From: &ast.TableRefsClause{
   258  			TableRefs: &ast.Join{
   259  				Left: &ast.TableName{
   260  					Schema: model.NewCIStr(g.tableConfig.DatabaseName),
   261  					Name:   model.NewCIStr(g.tableConfig.TableName),
   262  				},
   263  			},
   264  		},
   265  	}
   266  	sql, err := outputString(selectTree)
   267  	if err != nil {
   268  		return "", nil, errors.Annotate(err, "output SELECT AST into SQL string error")
   269  	}
   270  	return sql, cols, nil
   271  }
   272  
   273  func (g *sqlGeneratorImpl) GenCreateTable() string {
   274  	return g.tableConfig.GenCreateTable()
   275  }