github.com/pingcap/ticdc@v0.0.0-20220526033649-485a10ef2652/tests/dailytest/parser.go (about) 1 // Copyright 2020 PingCAP, Inc. 2 // 3 // Licensed under the Apache License, Version 2.0 (the "License"); 4 // you may not use this file except in compliance with the License. 5 // You may obtain a copy of the License at 6 // 7 // http://www.apache.org/licenses/LICENSE-2.0 8 // 9 // Unless required by applicable law or agreed to in writing, software 10 // distributed under the License is distributed on an "AS IS" BASIS, 11 // See the License for the specific language governing permissions and 12 // limitations under the License. 13 14 package dailytest 15 16 import ( 17 "fmt" 18 "strconv" 19 "strings" 20 21 "github.com/pingcap/errors" 22 "github.com/pingcap/log" 23 "github.com/pingcap/parser" 24 "github.com/pingcap/parser/ast" 25 "github.com/pingcap/tidb/types" 26 27 // import parser_drive to avoid panic 28 _ "github.com/pingcap/tidb/types/parser_driver" 29 ) 30 31 type column struct { 32 idx int 33 name string 34 data *datum 35 tp *types.FieldType 36 comment string 37 min string 38 max string 39 step int64 40 set []string 41 42 table *table 43 } 44 45 func (col *column) String() string { 46 if col == nil { 47 return "<nil>" 48 } 49 50 return fmt.Sprintf("[column]idx: %d, name: %s, tp: %v, min: %s, max: %s, step: %d, set: %v\n", 51 col.idx, col.name, col.tp, col.min, col.max, col.step, col.set) 52 } 53 54 func (col *column) parseRule(kvs []string) { 55 if len(kvs) != 2 { 56 return 57 } 58 59 key := strings.TrimSpace(kvs[0]) 60 value := strings.TrimSpace(kvs[1]) 61 if key == "range" { 62 fields := strings.Split(value, ",") 63 if len(fields) == 1 { 64 col.min = strings.TrimSpace(fields[0]) 65 } else if len(fields) == 2 { 66 col.min = strings.TrimSpace(fields[0]) 67 col.max = strings.TrimSpace(fields[1]) 68 } 69 } else if key == "step" { 70 var err error 71 col.step, err = strconv.ParseInt(value, 10, 64) 72 if err != nil { 73 log.S().Fatal(err) 74 } 75 } else if key == "set" { 76 fields := strings.Split(value, ",") 77 for _, field := range fields { 78 col.set = append(col.set, strings.TrimSpace(field)) 79 } 80 } 81 } 82 83 // parse the data rules. 84 // rules like `a int unique comment '[[range=1,10;step=1]]'`, 85 // then we will get value from 1,2...10 86 func (col *column) parseColumnComment() { 87 comment := strings.TrimSpace(col.comment) 88 start := strings.Index(comment, "[[") 89 end := strings.Index(comment, "]]") 90 var content string 91 if start < end { 92 content = comment[start+2 : end] 93 } 94 95 fields := strings.Split(content, ";") 96 for _, field := range fields { 97 field = strings.TrimSpace(field) 98 kvs := strings.Split(field, "=") 99 col.parseRule(kvs) 100 } 101 } 102 103 func (col *column) parseColumn(cd *ast.ColumnDef) { 104 col.name = cd.Name.Name.L 105 col.tp = cd.Tp 106 col.parseColumnOptions(cd.Options) 107 col.parseColumnComment() 108 col.table.columns = append(col.table.columns, col) 109 } 110 111 func (col *column) parseColumnOptions(ops []*ast.ColumnOption) { 112 for _, op := range ops { 113 switch op.Tp { 114 case ast.ColumnOptionPrimaryKey, ast.ColumnOptionAutoIncrement, ast.ColumnOptionUniqKey: 115 col.table.uniqIndices[col.name] = col 116 case ast.ColumnOptionComment: 117 col.comment = op.Expr.(ast.ValueExpr).GetDatumString() 118 } 119 } 120 } 121 122 type table struct { 123 name string 124 columns []*column 125 columnList string 126 indices map[string]*column 127 uniqIndices map[string]*column 128 unsignedCols map[string]*column 129 } 130 131 func (t *table) printColumns() string { 132 ret := "" 133 for _, col := range t.columns { 134 ret += fmt.Sprintf("%v", col) 135 } 136 137 return ret 138 } 139 140 func (t *table) String() string { 141 if t == nil { 142 return "<nil>" 143 } 144 145 ret := fmt.Sprintf("[table]name: %s\n", t.name) 146 ret += "[table]columns:\n" 147 ret += t.printColumns() 148 149 ret += fmt.Sprintf("[table]column list: %s\n", t.columnList) 150 151 ret += "[table]indices:\n" 152 for k, v := range t.indices { 153 ret += fmt.Sprintf("key->%s, value->%v", k, v) 154 } 155 156 ret += "[table]unique indices:\n" 157 for k, v := range t.uniqIndices { 158 ret += fmt.Sprintf("key->%s, value->%v", k, v) 159 } 160 161 return ret 162 } 163 164 func newTable() *table { 165 return &table{ 166 indices: make(map[string]*column), 167 uniqIndices: make(map[string]*column), 168 unsignedCols: make(map[string]*column), 169 } 170 } 171 172 func (t *table) findCol(cols []*column, name string) *column { 173 for _, col := range cols { 174 if col.name == name { 175 return col 176 } 177 } 178 return nil 179 } 180 181 func (t *table) parseTableConstraint(cons *ast.Constraint) { 182 switch cons.Tp { 183 case ast.ConstraintPrimaryKey, ast.ConstraintUniq, 184 ast.ConstraintUniqKey, ast.ConstraintUniqIndex: 185 for _, indexCol := range cons.Keys { 186 name := indexCol.Column.Name.L 187 t.uniqIndices[name] = t.findCol(t.columns, name) 188 } 189 case ast.ConstraintIndex, ast.ConstraintKey: 190 for _, indexCol := range cons.Keys { 191 name := indexCol.Column.Name.L 192 t.indices[name] = t.findCol(t.columns, name) 193 } 194 } 195 } 196 197 func (t *table) buildColumnList() { 198 columns := make([]string, 0, len(t.columns)) 199 for _, column := range t.columns { 200 columns = append(columns, column.name) 201 } 202 203 t.columnList = strings.Join(columns, ",") 204 } 205 206 func parseTable(t *table, stmt *ast.CreateTableStmt) { 207 t.name = stmt.Table.Name.L 208 t.columns = make([]*column, 0, len(stmt.Cols)) 209 210 for i, col := range stmt.Cols { 211 column := &column{idx: i + 1, table: t, step: defaultStep, data: newDatum()} 212 column.parseColumn(col) 213 } 214 215 for _, cons := range stmt.Constraints { 216 t.parseTableConstraint(cons) 217 } 218 219 t.buildColumnList() 220 } 221 222 func parseTableSQL(table *table, sql string) error { 223 stmt, err := parser.New().ParseOneStmt(sql, "", "") 224 if err != nil { 225 return errors.Trace(err) 226 } 227 228 switch node := stmt.(type) { 229 case *ast.CreateTableStmt: 230 parseTable(table, node) 231 default: 232 err = errors.Errorf("invalid statement - %v", stmt.Text()) 233 } 234 235 return errors.Trace(err) 236 }