github.com/pingcap/ticdc@v0.0.0-20220526033649-485a10ef2652/tests/dailytest/parser.go (about)

     1  // Copyright 2020 PingCAP, Inc.
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //     http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // See the License for the specific language governing permissions and
    12  // limitations under the License.
    13  
    14  package dailytest
    15  
import (
	"fmt"
	"sort"
	"strconv"
	"strings"

	"github.com/pingcap/errors"
	"github.com/pingcap/log"
	"github.com/pingcap/parser"
	"github.com/pingcap/parser/ast"
	"github.com/pingcap/tidb/types"

	// import parser_driver to avoid panic
	_ "github.com/pingcap/tidb/types/parser_driver"
)
    30  
// column describes one column of a parsed CREATE TABLE statement together
// with the data-generation rules extracted from its COMMENT.
type column struct {
	// idx is the 1-based position of the column in the CREATE TABLE statement.
	idx int
	// name is the lower-cased column name.
	name string
	// data is created by newDatum() in parseTable; presumably per-column
	// generator state — TODO confirm against datum's definition.
	data *datum
	// tp is the column's declared field type, taken verbatim from the AST.
	tp *types.FieldType
	// comment is the raw text of the column's COMMENT option, if any.
	comment string
	// min and max come from a `range=min[,max]` rule; they stay as strings here.
	min string
	max string
	// step comes from a `step=N` rule; parseTable initialises it to defaultStep.
	step int64
	// set holds the values of a `set=a,b,c` rule, whitespace-trimmed.
	set []string

	// table is the owning table; used to register the column and its indices.
	table *table
}
    44  
    45  func (col *column) String() string {
    46  	if col == nil {
    47  		return "<nil>"
    48  	}
    49  
    50  	return fmt.Sprintf("[column]idx: %d, name: %s, tp: %v, min: %s, max: %s, step: %d, set: %v\n",
    51  		col.idx, col.name, col.tp, col.min, col.max, col.step, col.set)
    52  }
    53  
    54  func (col *column) parseRule(kvs []string) {
    55  	if len(kvs) != 2 {
    56  		return
    57  	}
    58  
    59  	key := strings.TrimSpace(kvs[0])
    60  	value := strings.TrimSpace(kvs[1])
    61  	if key == "range" {
    62  		fields := strings.Split(value, ",")
    63  		if len(fields) == 1 {
    64  			col.min = strings.TrimSpace(fields[0])
    65  		} else if len(fields) == 2 {
    66  			col.min = strings.TrimSpace(fields[0])
    67  			col.max = strings.TrimSpace(fields[1])
    68  		}
    69  	} else if key == "step" {
    70  		var err error
    71  		col.step, err = strconv.ParseInt(value, 10, 64)
    72  		if err != nil {
    73  			log.S().Fatal(err)
    74  		}
    75  	} else if key == "set" {
    76  		fields := strings.Split(value, ",")
    77  		for _, field := range fields {
    78  			col.set = append(col.set, strings.TrimSpace(field))
    79  		}
    80  	}
    81  }
    82  
    83  // parse the data rules.
    84  // rules like `a int unique comment '[[range=1,10;step=1]]'`,
    85  // then we will get value from 1,2...10
    86  func (col *column) parseColumnComment() {
    87  	comment := strings.TrimSpace(col.comment)
    88  	start := strings.Index(comment, "[[")
    89  	end := strings.Index(comment, "]]")
    90  	var content string
    91  	if start < end {
    92  		content = comment[start+2 : end]
    93  	}
    94  
    95  	fields := strings.Split(content, ";")
    96  	for _, field := range fields {
    97  		field = strings.TrimSpace(field)
    98  		kvs := strings.Split(field, "=")
    99  		col.parseRule(kvs)
   100  	}
   101  }
   102  
   103  func (col *column) parseColumn(cd *ast.ColumnDef) {
   104  	col.name = cd.Name.Name.L
   105  	col.tp = cd.Tp
   106  	col.parseColumnOptions(cd.Options)
   107  	col.parseColumnComment()
   108  	col.table.columns = append(col.table.columns, col)
   109  }
   110  
   111  func (col *column) parseColumnOptions(ops []*ast.ColumnOption) {
   112  	for _, op := range ops {
   113  		switch op.Tp {
   114  		case ast.ColumnOptionPrimaryKey, ast.ColumnOptionAutoIncrement, ast.ColumnOptionUniqKey:
   115  			col.table.uniqIndices[col.name] = col
   116  		case ast.ColumnOptionComment:
   117  			col.comment = op.Expr.(ast.ValueExpr).GetDatumString()
   118  		}
   119  	}
   120  }
   121  
// table is the in-memory description of a parsed CREATE TABLE statement.
type table struct {
	// name is the lower-cased table name.
	name string
	// columns lists the columns in declaration order.
	columns []*column
	// columnList is the comma-separated column names, built by buildColumnList.
	columnList string
	// indices maps a column name to its column for plain INDEX/KEY
	// constraints. A value may be nil if the constraint names an unknown column.
	indices map[string]*column
	// uniqIndices maps a column name to its column for PRIMARY KEY, UNIQUE and
	// AUTO_INCREMENT columns. Composite keys register each member column.
	uniqIndices map[string]*column
	// unsignedCols is allocated by newTable but not populated in this section;
	// presumably filled elsewhere — verify.
	unsignedCols map[string]*column
}
   130  
   131  func (t *table) printColumns() string {
   132  	ret := ""
   133  	for _, col := range t.columns {
   134  		ret += fmt.Sprintf("%v", col)
   135  	}
   136  
   137  	return ret
   138  }
   139  
   140  func (t *table) String() string {
   141  	if t == nil {
   142  		return "<nil>"
   143  	}
   144  
   145  	ret := fmt.Sprintf("[table]name: %s\n", t.name)
   146  	ret += "[table]columns:\n"
   147  	ret += t.printColumns()
   148  
   149  	ret += fmt.Sprintf("[table]column list: %s\n", t.columnList)
   150  
   151  	ret += "[table]indices:\n"
   152  	for k, v := range t.indices {
   153  		ret += fmt.Sprintf("key->%s, value->%v", k, v)
   154  	}
   155  
   156  	ret += "[table]unique indices:\n"
   157  	for k, v := range t.uniqIndices {
   158  		ret += fmt.Sprintf("key->%s, value->%v", k, v)
   159  	}
   160  
   161  	return ret
   162  }
   163  
   164  func newTable() *table {
   165  	return &table{
   166  		indices:      make(map[string]*column),
   167  		uniqIndices:  make(map[string]*column),
   168  		unsignedCols: make(map[string]*column),
   169  	}
   170  }
   171  
   172  func (t *table) findCol(cols []*column, name string) *column {
   173  	for _, col := range cols {
   174  		if col.name == name {
   175  			return col
   176  		}
   177  	}
   178  	return nil
   179  }
   180  
   181  func (t *table) parseTableConstraint(cons *ast.Constraint) {
   182  	switch cons.Tp {
   183  	case ast.ConstraintPrimaryKey, ast.ConstraintUniq,
   184  		ast.ConstraintUniqKey, ast.ConstraintUniqIndex:
   185  		for _, indexCol := range cons.Keys {
   186  			name := indexCol.Column.Name.L
   187  			t.uniqIndices[name] = t.findCol(t.columns, name)
   188  		}
   189  	case ast.ConstraintIndex, ast.ConstraintKey:
   190  		for _, indexCol := range cons.Keys {
   191  			name := indexCol.Column.Name.L
   192  			t.indices[name] = t.findCol(t.columns, name)
   193  		}
   194  	}
   195  }
   196  
   197  func (t *table) buildColumnList() {
   198  	columns := make([]string, 0, len(t.columns))
   199  	for _, column := range t.columns {
   200  		columns = append(columns, column.name)
   201  	}
   202  
   203  	t.columnList = strings.Join(columns, ",")
   204  }
   205  
   206  func parseTable(t *table, stmt *ast.CreateTableStmt) {
   207  	t.name = stmt.Table.Name.L
   208  	t.columns = make([]*column, 0, len(stmt.Cols))
   209  
   210  	for i, col := range stmt.Cols {
   211  		column := &column{idx: i + 1, table: t, step: defaultStep, data: newDatum()}
   212  		column.parseColumn(col)
   213  	}
   214  
   215  	for _, cons := range stmt.Constraints {
   216  		t.parseTableConstraint(cons)
   217  	}
   218  
   219  	t.buildColumnList()
   220  }
   221  
   222  func parseTableSQL(table *table, sql string) error {
   223  	stmt, err := parser.New().ParseOneStmt(sql, "", "")
   224  	if err != nil {
   225  		return errors.Trace(err)
   226  	}
   227  
   228  	switch node := stmt.(type) {
   229  	case *ast.CreateTableStmt:
   230  		parseTable(table, node)
   231  	default:
   232  		err = errors.Errorf("invalid statement - %v", stmt.Text())
   233  	}
   234  
   235  	return errors.Trace(err)
   236  }