github.com/pingcap/tiflow@v0.0.0-20240520035814-5bf52d54e205/tests/integration_tests/cdc/dailytest/parser.go (about) 1 // Copyright 2020 PingCAP, Inc. 2 // 3 // Licensed under the Apache License, Version 2.0 (the "License"); 4 // you may not use this file except in compliance with the License. 5 // You may obtain a copy of the License at 6 // 7 // http://www.apache.org/licenses/LICENSE-2.0 8 // 9 // Unless required by applicable law or agreed to in writing, software 10 // distributed under the License is distributed on an "AS IS" BASIS, 11 // See the License for the specific language governing permissions and 12 // limitations under the License. 13 14 package dailytest 15 16 import ( 17 "fmt" 18 "strconv" 19 "strings" 20 21 "github.com/pingcap/errors" 22 "github.com/pingcap/log" 23 "github.com/pingcap/tidb/pkg/parser" 24 "github.com/pingcap/tidb/pkg/parser/ast" 25 "github.com/pingcap/tidb/pkg/types" 26 _ "github.com/pingcap/tidb/pkg/types/parser_driver" // import parser_driver to avoid panic 27 ) 28 29 type column struct { 30 idx int 31 name string 32 data *datum 33 tp *types.FieldType 34 comment string 35 min string 36 max string 37 step int64 38 set []string 39 40 table *table 41 } 42 43 func (col *column) String() string { 44 if col == nil { 45 return "<nil>" 46 } 47 48 return fmt.Sprintf("[column]idx: %d, name: %s, tp: %v, min: %s, max: %s, step: %d, set: %v\n", 49 col.idx, col.name, col.tp, col.min, col.max, col.step, col.set) 50 } 51 52 func (col *column) parseRule(kvs []string) { 53 if len(kvs) != 2 { 54 return 55 } 56 57 key := strings.TrimSpace(kvs[0]) 58 value := strings.TrimSpace(kvs[1]) 59 if key == "range" { 60 fields := strings.Split(value, ",") 61 if len(fields) == 1 { 62 col.min = strings.TrimSpace(fields[0]) 63 } else if len(fields) == 2 { 64 col.min = strings.TrimSpace(fields[0]) 65 col.max = strings.TrimSpace(fields[1]) 66 } 67 } else if key == "step" { 68 var err error 69 col.step, err = strconv.ParseInt(value, 10, 64) 70 if err != nil { 71 log.S().Fatal(err) 72 } 73 } else if key == "set" { 74 fields := strings.Split(value, ",") 75 for _, field := range fields { 76 col.set = append(col.set, strings.TrimSpace(field)) 77 } 78 } 79 } 80 81 // parse the data rules. 82 // rules like `a int unique comment '[[range=1,10;step=1]]'`, 83 // then we will get value from 1,2...10 84 func (col *column) parseColumnComment() { 85 comment := strings.TrimSpace(col.comment) 86 start := strings.Index(comment, "[[") 87 end := strings.Index(comment, "]]") 88 var content string 89 if start < end { 90 content = comment[start+2 : end] 91 } 92 93 fields := strings.Split(content, ";") 94 for _, field := range fields { 95 field = strings.TrimSpace(field) 96 kvs := strings.Split(field, "=") 97 col.parseRule(kvs) 98 } 99 } 100 101 func (col *column) parseColumn(cd *ast.ColumnDef) { 102 col.name = cd.Name.Name.L 103 col.tp = cd.Tp 104 col.parseColumnOptions(cd.Options) 105 col.parseColumnComment() 106 col.table.columns = append(col.table.columns, col) 107 } 108 109 func (col *column) parseColumnOptions(ops []*ast.ColumnOption) { 110 for _, op := range ops { 111 switch op.Tp { 112 case ast.ColumnOptionPrimaryKey, ast.ColumnOptionAutoIncrement, ast.ColumnOptionUniqKey: 113 col.table.uniqIndices[col.name] = col 114 case ast.ColumnOptionComment: 115 col.comment = op.Expr.(ast.ValueExpr).GetDatumString() 116 } 117 } 118 } 119 120 type table struct { 121 name string 122 columns []*column 123 columnList string 124 indices map[string]*column 125 uniqIndices map[string]*column 126 unsignedCols map[string]*column 127 } 128 129 func (t *table) printColumns() string { 130 ret := "" 131 for _, col := range t.columns { 132 ret += fmt.Sprintf("%v", col) 133 } 134 135 return ret 136 } 137 138 func (t *table) String() string { 139 if t == nil { 140 return "<nil>" 141 } 142 143 ret := fmt.Sprintf("[table]name: %s\n", t.name) 144 ret += "[table]columns:\n" 145 ret += t.printColumns() 146 147 ret += fmt.Sprintf("[table]column list: %s\n", t.columnList) 148 149 ret += "[table]indices:\n" 150 for k, v := range t.indices { 151 ret += fmt.Sprintf("key->%s, value->%v", k, v) 152 } 153 154 ret += "[table]unique indices:\n" 155 for k, v := range t.uniqIndices { 156 ret += fmt.Sprintf("key->%s, value->%v", k, v) 157 } 158 159 return ret 160 } 161 162 func newTable() *table { 163 return &table{ 164 indices: make(map[string]*column), 165 uniqIndices: make(map[string]*column), 166 unsignedCols: make(map[string]*column), 167 } 168 } 169 170 func (t *table) findCol(cols []*column, name string) *column { 171 for _, col := range cols { 172 if col.name == name { 173 return col 174 } 175 } 176 return nil 177 } 178 179 func (t *table) parseTableConstraint(cons *ast.Constraint) { 180 switch cons.Tp { 181 case ast.ConstraintPrimaryKey, ast.ConstraintUniq, 182 ast.ConstraintUniqKey, ast.ConstraintUniqIndex: 183 for _, indexCol := range cons.Keys { 184 name := indexCol.Column.Name.L 185 t.uniqIndices[name] = t.findCol(t.columns, name) 186 } 187 case ast.ConstraintIndex, ast.ConstraintKey: 188 for _, indexCol := range cons.Keys { 189 name := indexCol.Column.Name.L 190 t.indices[name] = t.findCol(t.columns, name) 191 } 192 } 193 } 194 195 func (t *table) buildColumnList() { 196 columns := make([]string, 0, len(t.columns)) 197 for _, column := range t.columns { 198 columns = append(columns, column.name) 199 } 200 201 t.columnList = strings.Join(columns, ",") 202 } 203 204 func parseTable(t *table, stmt *ast.CreateTableStmt) { 205 t.name = stmt.Table.Name.L 206 t.columns = make([]*column, 0, len(stmt.Cols)) 207 208 for i, col := range stmt.Cols { 209 column := &column{idx: i + 1, table: t, step: defaultStep, data: newDatum()} 210 column.parseColumn(col) 211 } 212 213 for _, cons := range stmt.Constraints { 214 t.parseTableConstraint(cons) 215 } 216 217 t.buildColumnList() 218 } 219 220 func parseTableSQL(table *table, sql string) error { 221 stmt, err := parser.New().ParseOneStmt(sql, "", "") 222 if err != nil { 223 return errors.Trace(err) 224 } 225 226 switch node := stmt.(type) { 227 case *ast.CreateTableStmt: 228 parseTable(table, node) 229 default: 230 err = errors.Errorf("invalid statement - %v", stmt.Text()) 231 } 232 233 return errors.Trace(err) 234 }