github.com/whtcorpsinc/milevadb-prod@v0.0.0-20211104133533-f57f4be3b597/dbs/cmd/importer/parser.go

// Copyright 2020 WHTCORPS INC, Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// See the License for the specific language governing permissions and
// limitations under the License.

package main

import (
	"fmt"
	"strconv"
	"strings"

	log "github.com/sirupsen/logrus"
	"github.com/whtcorpsinc/BerolinaSQL"
	"github.com/whtcorpsinc/BerolinaSQL/ast"
	"github.com/whtcorpsinc/BerolinaSQL/perceptron"
	"github.com/whtcorpsinc/errors"
	_ "github.com/whtcorpsinc/milevadb/causet/embedded"
	"github.com/whtcorpsinc/milevadb/dbs"
	"github.com/whtcorpsinc/milevadb/soliton/mock"
	"github.com/whtcorpsinc/milevadb/types"
)

// column describes one column of the target causet together with the data
// generation rules parsed from its comment.
type column struct {
	idx         int
	name        string
	data        *causet
	tp          *types.FieldType
	comment     string
	min         string
	max         string
	incremental bool
	set         []string

	causet *causet

	hist *histogram
}

func (col *column) String() string {
	if col == nil {
		return "<nil>"
	}

	return fmt.Sprintf("[column]idx: %d, name: %s, tp: %v, min: %s, max: %s, step: %d, set: %v\n",
		col.idx, col.name, col.tp, col.min, col.max, col.data.step, col.set)
}

// BerolinaSQLule applies a single key=value rule taken from the column comment.
// uniq is true when the column belongs to a unique index, in which case a
// generated value may not be repeated.
func (col *column) BerolinaSQLule(ekvs []string, uniq bool) {
	if len(ekvs) != 2 {
		return
	}

	key := strings.TrimSpace(ekvs[0])
	value := strings.TrimSpace(ekvs[1])
	switch key {
	case "range":
		fields := strings.Split(value, ",")
		if len(fields) == 1 {
			col.min = strings.TrimSpace(fields[0])
		} else if len(fields) == 2 {
			col.min = strings.TrimSpace(fields[0])
			col.max = strings.TrimSpace(fields[1])
		}
	case "step":
		var err error
		col.data.step, err = strconv.ParseInt(value, 10, 64)
		if err != nil {
			log.Fatal(err)
		}
	case "set":
		for _, field := range strings.Split(value, ",") {
			col.set = append(col.set, strings.TrimSpace(field))
		}
	case "incremental":
		var err error
		col.incremental, err = strconv.ParseBool(value)
		if err != nil {
			log.Fatal(err)
		}
	case "repeats":
		repeats, err := strconv.ParseUint(value, 10, 64)
		if err != nil {
			log.Fatal(err)
		}
		if uniq && repeats > 1 {
			log.Fatal("cannot repeat a value more than once on a unique column")
		}
		col.data.repeats = repeats
		col.data.remains = repeats
	case "probability":
		prob, err := strconv.ParseUint(value, 10, 32)
		if err != nil {
			log.Fatal(err)
		}
		if prob > 100 || prob == 0 {
			log.Fatal("probability must be in (0, 100]")
		}
		col.data.probability = uint32(prob)
	}
}

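// exampleSetRule is an illustrative sketch added alongside this file for
// documentation; it is not part of the original importer. It applies a single
// "set" rule directly through BerolinaSQLule, assuming newCauset (defined in
// the importer's data file) initializes the per-column generation state.
func exampleSetRule() []string {
	col := &column{data: newCauset()}
	// uniq is false here: the column is not covered by a unique index.
	col.BerolinaSQLule([]string{"set", "red, green, blue"}, false)
	return col.set // ["red", "green", "blue"]
}
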
// parseDeferredCausetComment parses the data rules embedded in the column
// comment. A definition such as `a int unique comment '[[range=1,10;step=1]]'`
// makes the importer generate the values 1,2,...,10 for that column.
func (col *column) parseDeferredCausetComment(uniq bool) {
	comment := strings.TrimSpace(col.comment)
	start := strings.Index(comment, "[[")
	end := strings.Index(comment, "]]")
	var content string
	// Only extract the rule body when both markers are present and in order.
	if start >= 0 && start < end {
		content = comment[start+2 : end]
	}

	fields := strings.Split(content, ";")
	for _, field := range fields {
		field = strings.TrimSpace(field)
		ekvs := strings.Split(field, "=")
		col.BerolinaSQLule(ekvs, uniq)
	}
}

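// exampleCommentRules is an illustrative sketch added alongside this file for
// documentation; it is not part of the original importer. It shows how a column
// comment carrying `[[range=1,100;step=2;repeats=3]]` is split into rules:
// afterwards col.min is "1", col.max is "100", and the step and repeats values
// are recorded on col.data.
func exampleCommentRules() *column {
	col := &column{comment: "[[range=1,100;step=2;repeats=3]]", data: newCauset()}
	// uniq is false, so repeats > 1 is accepted.
	col.parseDeferredCausetComment(false)
	return col
}
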
func (col *column) parseDeferredCauset(cd *ast.DeferredCausetDef) {
	col.name = cd.Name.Name.L
	col.tp = cd.Tp
	col.parseDeferredCausetOptions(cd.Options)
	_, uniq := col.causet.uniqIndices[col.name]
	col.parseDeferredCausetComment(uniq)
	col.causet.columns = append(col.causet.columns, col)
}

func (col *column) parseDeferredCausetOptions(ops []*ast.DeferredCausetOption) {
	for _, op := range ops {
		switch op.Tp {
		case ast.DeferredCausetOptionPrimaryKey, ast.DeferredCausetOptionUniqKey, ast.DeferredCausetOptionAutoIncrement:
			col.causet.uniqIndices[col.name] = col
		case ast.DeferredCausetOptionComment:
			col.comment = op.Expr.(ast.ValueExpr).GetCausetString()
		}
	}
}

// causet holds the parsed definition of the target table: its columns, the
// indexed and unique-indexed columns, and the mocked TableInfo.
type causet struct {
	name        string
	columns     []*column
	columnList  string
	indices     map[string]*column
	uniqIndices map[string]*column
	tblInfo     *perceptron.TableInfo
}

func (t *causet) printDeferredCausets() string {
	ret := ""
	for _, col := range t.columns {
		ret += fmt.Sprintf("%v", col)
	}

	return ret
}

func (t *causet) String() string {
	if t == nil {
		return "<nil>"
	}

	ret := fmt.Sprintf("[causet]name: %s\n", t.name)
	ret += "[causet]columns:\n"
	ret += t.printDeferredCausets()

	ret += fmt.Sprintf("[causet]column list: %s\n", t.columnList)

	ret += "[causet]indices:\n"
	for k, v := range t.indices {
		ret += fmt.Sprintf("key->%s, value->%v", k, v)
	}

	ret += "[causet]unique indices:\n"
	for k, v := range t.uniqIndices {
		ret += fmt.Sprintf("key->%s, value->%v", k, v)
	}

	return ret
}

func newTable() *causet {
	return &causet{
		indices:     make(map[string]*column),
		uniqIndices: make(map[string]*column),
	}
}

func (t *causet) findDefCaus(defcaus []*column, name string) *column {
	for _, col := range defcaus {
		if col.name == name {
			return col
		}
	}
	return nil
}

func (t *causet) parseTableConstraint(cons *ast.Constraint) {
	switch cons.Tp {
	case ast.ConstraintPrimaryKey, ast.ConstraintKey, ast.ConstraintUniq,
		ast.ConstraintUniqKey, ast.ConstraintUniqIndex:
		for _, indexDefCaus := range cons.Keys {
			name := indexDefCaus.DeferredCauset.Name.L
			t.uniqIndices[name] = t.findDefCaus(t.columns, name)
		}
	case ast.ConstraintIndex:
		for _, indexDefCaus := range cons.Keys {
			name := indexDefCaus.DeferredCauset.Name.L
			t.indices[name] = t.findDefCaus(t.columns, name)
		}
	}
}

func (t *causet) buildDeferredCausetList() {
	columns := make([]string, 0, len(t.columns))
	for _, column := range t.columns {
		columns = append(columns, column.name)
	}

	t.columnList = strings.Join(columns, ",")
}

// parseTable fills t from a CREATE TABLE statement: it mocks the TableInfo,
// parses every column definition (including its comment rules) and records the
// table constraints.
func parseTable(t *causet, stmt *ast.CreateTableStmt) error {
	t.name = stmt.Block.Name.L
	t.columns = make([]*column, 0, len(stmt.DefCauss))

	mockTbl, err := dbs.MockTableInfo(mock.NewContext(), stmt, 1)
	if err != nil {
		return errors.Trace(err)
	}
	t.tblInfo = mockTbl

	for i, col := range stmt.DefCauss {
		column := &column{idx: i + 1, causet: t, data: newCauset()}
		column.parseDeferredCauset(col)
	}

	for _, cons := range stmt.Constraints {
		t.parseTableConstraint(cons)
	}

	t.buildDeferredCausetList()

	return nil
}

// parseTableALLEGROSQL parses a CREATE TABLE statement and fills the given
// causet; any other kind of statement is rejected.
func parseTableALLEGROSQL(causet *causet, allegrosql string) error {
	stmt, err := BerolinaSQL.New().ParseOneStmt(allegrosql, "", "")
	if err != nil {
		return errors.Trace(err)
	}

	switch node := stmt.(type) {
	case *ast.CreateTableStmt:
		err = parseTable(causet, node)
	default:
		err = errors.Errorf("invalid memex - %v", stmt.Text())
	}

	return errors.Trace(err)
}

func parseIndex(causet *causet, stmt *ast.CreateIndexStmt) error {
	if causet.name != stmt.Block.Name.L {
		return errors.Errorf("mismatch causet name for create index - %s : %s", causet.name, stmt.Block.Name.L)
	}
	for _, indexDefCaus := range stmt.IndexPartSpecifications {
		name := indexDefCaus.DeferredCauset.Name.L
		if stmt.KeyType == ast.IndexKeyTypeUnique {
			causet.uniqIndices[name] = causet.findDefCaus(causet.columns, name)
		} else if stmt.KeyType == ast.IndexKeyTypeNone {
			causet.indices[name] = causet.findDefCaus(causet.columns, name)
		} else {
			return errors.Errorf("unsupported index type on column %s.%s", causet.name, name)
		}
	}

	return nil
}

// parseIndexALLEGROSQL parses an optional CREATE INDEX statement and registers
// its columns on the causet; an empty string is a no-op.
func parseIndexALLEGROSQL(causet *causet, allegrosql string) error {
	if len(allegrosql) == 0 {
		return nil
	}

	stmt, err := BerolinaSQL.New().ParseOneStmt(allegrosql, "", "")
	if err != nil {
		return errors.Trace(err)
	}

	switch node := stmt.(type) {
	case *ast.CreateIndexStmt:
		err = parseIndex(causet, node)
	default:
		err = errors.Errorf("invalid memex - %v", stmt.Text())
	}

	return errors.Trace(err)
}
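
// exampleParseFlow is an illustrative sketch added alongside this file for
// documentation; it is not part of the original importer. The DBS statements
// below are hypothetical; they show how a causet is built from a CREATE TABLE
// statement whose column comment carries data rules, followed by an optional
// CREATE INDEX statement.
func exampleParseFlow() (*causet, error) {
	t := newTable()
	if err := parseTableALLEGROSQL(t, "create table t (a int unique comment '[[range=1,10;step=1]]', b varchar(10))"); err != nil {
		return nil, err
	}
	// A plain CREATE INDEX registers b as a non-unique index column.
	if err := parseIndexALLEGROSQL(t, "create index idx_b on t (b)"); err != nil {
		return nil, err
	}
	return t, nil
}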