github.com/pingcap/tiflow@v0.0.0-20240520035814-5bf52d54e205/pkg/column-mapping/column.go (about)

     1  // Copyright 2018 PingCAP, Inc.
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //     http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // See the License for the specific language governing permissions and
    12  // limitations under the License.
    13  
    14  package column
    15  
    16  import (
    17  	"fmt"
    18  	"strconv"
    19  	"strings"
    20  	"sync"
    21  
    22  	"github.com/pingcap/errors"
    23  	selector "github.com/pingcap/tidb/pkg/util/table-rule-selector"
    24  )
    25  
    26  var (
    27  	// for partition ID, ref definition of partitionID
    28  	instanceIDBitSize       = 4
    29  	schemaIDBitSize         = 7
    30  	tableIDBitSize          = 8
    31  	maxOriginID       int64 = 17592186044416
    32  )
    33  
    34  // SetPartitionRule sets bit size of schema ID and table ID
    35  func SetPartitionRule(instanceIDSize, schemaIDSize, tableIDSize int) {
    36  	instanceIDBitSize = instanceIDSize
    37  	schemaIDBitSize = schemaIDSize
    38  	tableIDBitSize = tableIDSize
    39  	maxOriginID = 1 << uint(64-instanceIDSize-schemaIDSize-tableIDSize-1)
    40  }
    41  
     42  // Expr names a built-in expression that describes how a column value is mapped.
     43  type Expr string
     44  
     45  // The (currently limited) set of built-in expressions; each one is a key of Exprs.
     46  const (
     47  	AddPrefix   Expr = "add prefix"
     48  	AddSuffix   Expr = "add suffix"
     49  	PartitionID Expr = "partition id"
     50  )
    51  
     52  // Exprs maps every built-in expression to the function that applies it to a row.
     53  // Only a few simple expressions are supported for now;
     54  // we would unify tableInfo later and support more.
     55  var Exprs = map[Expr]func(*mappingInfo, []interface{}) ([]interface{}, error){
     56  	AddPrefix: addPrefix, // arguments contains prefix
     57  	AddSuffix: addSuffix, // arguments contains suffix
     58  	// arguments contains [instance_id, prefix of schema, prefix of table, separator]
     59  	// we would compute an int64 ID like (using the default bit sizes 4/7/8)
     60  	// [1:1 bit][2:4 bits][3:7 bits][4:8 bits][5:44 bits]
     61  	// # 1 unused; # 2 instance ID (arguments[0])
     62  	// # 3 schema ID (schema suffix)
     63  	// # 4 table ID (table suffix)
     64  	// # 5 origin ID (>= 0, <= 17592186044415)
     65  	//
     66  	// others: schema = arguments[1]  or  arguments[1] + arguments[3] + schema suffix
     67  	//         table  = arguments[2]  or  arguments[2] + arguments[3] + table suffix
     68  	//  example: schema = schema_1 table = t_1
     69  	//        => arguments[1] = "schema", arguments[2] = "t", arguments[3] = "_"
     70  	//  if arguments[0]/arguments[1]/arguments[2] == "", the instance/schema/table ID is not used to compute the partition ID
     71  	//  if only 3 arguments are given, arguments[3] is set to "" (empty string) by Rule.Adjust
     72  	PartitionID: partitionID,
     73  }
    74  
     75  // Rule describes one column mapping: which schema/table patterns it applies to,
         // which column is rewritten (TargetColumn) and with which expression and arguments.
     76  // TODO: we will do it later, if we need to implement a real column mapping, we need table structure of source and target system
     77  type Rule struct {
     78  	PatternSchema    string   `yaml:"schema-pattern" json:"schema-pattern" toml:"schema-pattern"`
     79  	PatternTable     string   `yaml:"table-pattern" json:"table-pattern" toml:"table-pattern"`
     80  	SourceColumn     string   `yaml:"source-column" json:"source-column" toml:"source-column"` // modify, add refer column, ignore
     81  	TargetColumn     string   `yaml:"target-column" json:"target-column" toml:"target-column"` // add column, modify
     82  	Expression       Expr     `yaml:"expression" json:"expression" toml:"expression"`
     83  	Arguments        []string `yaml:"arguments" json:"arguments" toml:"arguments"`
     84  	CreateTableQuery string   `yaml:"create-table-query" json:"create-table-query" toml:"create-table-query"`
     85  }
    86  
    87  // ToLower covert schema/table parttern to lower case
    88  func (r *Rule) ToLower() {
    89  	r.PatternSchema = strings.ToLower(r.PatternSchema)
    90  	r.PatternTable = strings.ToLower(r.PatternTable)
    91  }
    92  
    93  // Valid checks validity of rule.
    94  // add prefix/suffix: it should have target column and one argument
    95  // partition id: it should have 3 to 4 arguments
    96  func (r *Rule) Valid() error {
    97  	if _, ok := Exprs[r.Expression]; !ok {
    98  		return errors.NotFoundf("expression %s", r.Expression)
    99  	}
   100  
   101  	if r.TargetColumn == "" {
   102  		return errors.NotValidf("rule need to be applied a target column")
   103  	}
   104  
   105  	if r.Expression == AddPrefix || r.Expression == AddSuffix {
   106  		if len(r.Arguments) != 1 {
   107  			return errors.NotValidf("arguments %v for add prefix/suffix", r.Arguments)
   108  		}
   109  	}
   110  
   111  	if r.Expression == PartitionID {
   112  		switch len(r.Arguments) {
   113  		case 3, 4:
   114  			break
   115  		default:
   116  			return errors.NotValidf("arguments %v for patition id", r.Arguments)
   117  		}
   118  	}
   119  
   120  	return nil
   121  }
   122  
   123  // Adjust normalizes the rule into an easier-to-process form, e.g. filling in
   124  // optional arguments with the default values.
   125  func (r *Rule) Adjust() {
   126  	if r.Expression == PartitionID && len(r.Arguments) == 3 {
   127  		r.Arguments = append(r.Arguments, "")
   128  	}
   129  }
   130  
   131  // check source and target position
   132  func (r *Rule) adjustColumnPosition(source, target int) (int, int, error) {
   133  	// if not found target, ignore it
   134  	if target == -1 {
   135  		return source, target, errors.NotFoundf("target column %s", r.TargetColumn)
   136  	}
   137  
   138  	return source, target, nil
   139  }
   140  
         // mappingInfo is the per-table result of resolving the mapping rules; it is
         // what queryColumnInfo stores in the Mapping cache.
    141  type mappingInfo struct {
    142  	ignore         bool // true when no rule applies to the table
    143  	sourcePosition int  // index of rule.SourceColumn in the column list, -1 if absent
    144  	targetPosition int  // index of rule.TargetColumn in the column list
    145  	rule           *Rule
    146  
         	// pre-shifted parts of the partition ID, only set for PartitionID rules
    147  	instanceID int64
    148  	schemaID   int64
    149  	tableID    int64
    150  }
   151  
    152  // Mapping applies column mapping rules, looked up by schema/table through the
         // embedded selector, to row values.
    153  type Mapping struct {
    154  	selector.Selector
    155  
         	// when false, rule patterns and looked-up schema/table names are lower-cased
    156  	caseSensitive bool
    157  
         	// cache of resolved mapping info keyed by the quoted `schema`.`table` name;
         	// it is reset whenever rules are added, updated or removed
    158  	cache struct {
    159  		sync.RWMutex
    160  		infos map[string]*mappingInfo
    161  	}
    162  }
   163  
   164  // NewMapping returns a column mapping
   165  func NewMapping(caseSensitive bool, rules []*Rule) (*Mapping, error) {
   166  	m := &Mapping{
   167  		Selector:      selector.NewTrieSelector(),
   168  		caseSensitive: caseSensitive,
   169  	}
   170  	m.resetCache()
   171  
   172  	for _, rule := range rules {
   173  		if err := m.AddRule(rule); err != nil {
   174  			return nil, errors.Annotatef(err, "initial rule %+v in mapping", rule)
   175  		}
   176  	}
   177  
   178  	return m, nil
   179  }
   180  
   181  func (m *Mapping) addOrUpdateRule(rule *Rule, isUpdate bool) error {
   182  	if m == nil || rule == nil {
   183  		return nil
   184  	}
   185  
   186  	err := rule.Valid()
   187  	if err != nil {
   188  		return errors.Trace(err)
   189  	}
   190  	if !m.caseSensitive {
   191  		rule.ToLower()
   192  	}
   193  	rule.Adjust()
   194  
   195  	m.resetCache()
   196  	if isUpdate {
   197  		err = m.Insert(rule.PatternSchema, rule.PatternTable, rule, selector.Replace)
   198  	} else {
   199  		err = m.Insert(rule.PatternSchema, rule.PatternTable, rule, selector.Insert)
   200  	}
   201  	if err != nil {
   202  		var method string
   203  		if isUpdate {
   204  			method = "update"
   205  		} else {
   206  			method = "add"
   207  		}
   208  		return errors.Annotatef(err, "%s rule %+v into mapping", method, rule)
   209  	}
   210  
   211  	return nil
   212  }
   213  
    214  // AddRule validates and normalizes rule and inserts it into the mapping as a
         // new entry (addOrUpdateRule with isUpdate=false).
    215  func (m *Mapping) AddRule(rule *Rule) error {
    216  	return m.addOrUpdateRule(rule, false)
    217  }
   218  
    219  // UpdateRule validates and normalizes rule and inserts it into the mapping in
         // replace mode (addOrUpdateRule with isUpdate=true).
    220  func (m *Mapping) UpdateRule(rule *Rule) error {
    221  	return m.addOrUpdateRule(rule, true)
    222  }
   223  
   224  // RemoveRule removes a rule from mapping
   225  func (m *Mapping) RemoveRule(rule *Rule) error {
   226  	if m == nil || rule == nil {
   227  		return nil
   228  	}
   229  	if !m.caseSensitive {
   230  		rule.ToLower()
   231  	}
   232  
   233  	m.resetCache()
   234  	err := m.Remove(rule.PatternSchema, rule.PatternTable)
   235  	if err != nil {
   236  		return errors.Annotatef(err, "remove rule %+v from mapping", rule)
   237  	}
   238  
   239  	return nil
   240  }
   241  
   242  // HandleRowValue handles row value
   243  func (m *Mapping) HandleRowValue(schema, table string, columns []string, vals []interface{}) ([]interface{}, []int, error) {
   244  	if m == nil {
   245  		return vals, nil, nil
   246  	}
   247  
   248  	schemaL, tableL := schema, table
   249  	if !m.caseSensitive {
   250  		schemaL, tableL = strings.ToLower(schema), strings.ToLower(table)
   251  	}
   252  
   253  	info, err := m.queryColumnInfo(schemaL, tableL, columns)
   254  	if err != nil {
   255  		return nil, nil, errors.Trace(err)
   256  	}
   257  	if info.ignore {
   258  		return vals, nil, nil
   259  	}
   260  
   261  	exp, ok := Exprs[info.rule.Expression]
   262  	if !ok {
   263  		return nil, nil, errors.NotFoundf("column mapping expression %s", info.rule.Expression)
   264  	}
   265  
   266  	vals, err = exp(info, vals)
   267  	if err != nil {
   268  		return nil, nil, errors.Trace(err)
   269  	}
   270  
   271  	return vals, []int{info.sourcePosition, info.targetPosition}, nil
   272  }
   273  
   274  // HandleDDL handles ddl
   275  func (m *Mapping) HandleDDL(schema, table string, columns []string, statement string) (string, []int, error) {
   276  	if m == nil {
   277  		return statement, nil, nil
   278  	}
   279  
   280  	schemaL, tableL := schema, table
   281  	if !m.caseSensitive {
   282  		schemaL, tableL = strings.ToLower(schema), strings.ToLower(table)
   283  	}
   284  
   285  	info, err := m.queryColumnInfo(schemaL, tableL, columns)
   286  	if err != nil {
   287  		return statement, nil, errors.Trace(err)
   288  	}
   289  
   290  	if info.ignore {
   291  		return statement, nil, nil
   292  	}
   293  
   294  	m.resetCache()
   295  	// only output erro now, wait fix it manually
   296  	return statement, nil, errors.Errorf("ddl %s @ column mapping rule %s/%s:%+v not implemented", statement, schema, table, info.rule)
   297  }
   298  
   299  func (m *Mapping) queryColumnInfo(schema, table string, columns []string) (*mappingInfo, error) {
   300  	m.cache.RLock()
   301  	ci, ok := m.cache.infos[tableName(schema, table)]
   302  	m.cache.RUnlock()
   303  	if ok {
   304  		return ci, nil
   305  	}
   306  
   307  	info := &mappingInfo{
   308  		ignore: true,
   309  	}
   310  	rules := m.Match(schema, table)
   311  	if len(rules) == 0 {
   312  		m.cache.Lock()
   313  		m.cache.infos[tableName(schema, table)] = info
   314  		m.cache.Unlock()
   315  
   316  		return info, nil
   317  	}
   318  
   319  	var (
   320  		schemaRules []*Rule
   321  		tableRules  = make([]*Rule, 0, 1)
   322  	)
   323  	// classify rules into schema level rules and table level
   324  	// table level rules have highest priority
   325  	for i := range rules {
   326  		rule, ok := rules[i].(*Rule)
   327  		if !ok {
   328  			return nil, errors.NotValidf("column mapping rule %+v", rules[i])
   329  		}
   330  
   331  		if len(rule.PatternTable) == 0 {
   332  			schemaRules = append(schemaRules, rule)
   333  		} else {
   334  			tableRules = append(tableRules, rule)
   335  		}
   336  	}
   337  
   338  	// only support one expression for one table now, refine it later
   339  	var rule *Rule
   340  	if len(table) == 0 || len(tableRules) == 0 {
   341  		if len(schemaRules) != 1 {
   342  			return nil, errors.NotSupportedf("`%s`.`%s` matches %d schema column mapping rules which should be one. It's", schema, table, len(schemaRules))
   343  		}
   344  
   345  		rule = schemaRules[0]
   346  	} else {
   347  		if len(tableRules) != 1 {
   348  			return nil, errors.NotSupportedf("`%s`.`%s` matches %d table column mapping rules which should be one. It's", schema, table, len(tableRules))
   349  		}
   350  
   351  		rule = tableRules[0]
   352  	}
   353  	if rule == nil {
   354  		m.cache.Lock()
   355  		m.cache.infos[tableName(schema, table)] = info
   356  		m.cache.Unlock()
   357  
   358  		return info, nil
   359  	}
   360  
   361  	// compute source and target column position
   362  	sourcePosition := findColumnPosition(columns, rule.SourceColumn)
   363  	targetPosition := findColumnPosition(columns, rule.TargetColumn)
   364  
   365  	sourcePosition, targetPosition, err := rule.adjustColumnPosition(sourcePosition, targetPosition)
   366  	if err != nil {
   367  		return nil, errors.Trace(err)
   368  	}
   369  
   370  	info = &mappingInfo{
   371  		sourcePosition: sourcePosition,
   372  		targetPosition: targetPosition,
   373  		rule:           rule,
   374  	}
   375  
   376  	// if expr is partition ID, compute schema and table ID
   377  	if rule.Expression == PartitionID {
   378  		info.instanceID, info.schemaID, info.tableID, err = computePartitionID(schema, table, rule)
   379  		if err != nil {
   380  			return nil, errors.Trace(err)
   381  		}
   382  	}
   383  
   384  	m.cache.Lock()
   385  	m.cache.infos[tableName(schema, table)] = info
   386  	m.cache.Unlock()
   387  
   388  	return info, nil
   389  }
   390  
   391  func (m *Mapping) resetCache() {
   392  	m.cache.Lock()
   393  	m.cache.infos = make(map[string]*mappingInfo)
   394  	m.cache.Unlock()
   395  }
   396  
   397  func findColumnPosition(cols []string, col string) int {
   398  	for i := range cols {
   399  		if cols[i] == col {
   400  			return i
   401  		}
   402  	}
   403  
   404  	return -1
   405  }
   406  
   407  func tableName(schema, table string) string {
   408  	return fmt.Sprintf("`%s`.`%s`", schema, table)
   409  }
   410  
   411  func addPrefix(info *mappingInfo, vals []interface{}) ([]interface{}, error) {
   412  	prefix := info.rule.Arguments[0]
   413  	originStr, ok := vals[info.targetPosition].(string)
   414  	if !ok {
   415  		return nil, errors.NotValidf("column %d value is not string, but %v, which is", info.targetPosition, vals[info.targetPosition])
   416  	}
   417  
   418  	// fast to concatenated string
   419  	rawByte := make([]byte, 0, len(prefix)+len(originStr))
   420  	rawByte = append(rawByte, prefix...)
   421  	rawByte = append(rawByte, originStr...)
   422  
   423  	vals[info.targetPosition] = string(rawByte)
   424  	return vals, nil
   425  }
   426  
   427  func addSuffix(info *mappingInfo, vals []interface{}) ([]interface{}, error) {
   428  	suffix := info.rule.Arguments[0]
   429  	originStr, ok := vals[info.targetPosition].(string)
   430  	if !ok {
   431  		return nil, errors.NotValidf("column %d value is not string, but %v, which is", info.targetPosition, vals[info.targetPosition])
   432  	}
   433  
   434  	rawByte := make([]byte, 0, len(suffix)+len(originStr))
   435  	rawByte = append(rawByte, originStr...)
   436  	rawByte = append(rawByte, suffix...)
   437  
   438  	vals[info.targetPosition] = string(rawByte)
   439  	return vals, nil
   440  }
   441  
   442  func partitionID(info *mappingInfo, vals []interface{}) ([]interface{}, error) {
   443  	// only int64 now
   444  	var (
   445  		originID int64
   446  		err      error
   447  		isChars  bool
   448  	)
   449  
   450  	switch rawID := vals[info.targetPosition].(type) {
   451  	case int:
   452  		originID = int64(rawID)
   453  	case int8:
   454  		originID = int64(rawID)
   455  	case int32:
   456  		originID = int64(rawID)
   457  	case int64:
   458  		originID = rawID
   459  	case uint:
   460  		originID = int64(rawID)
   461  	case uint16:
   462  		originID = int64(rawID)
   463  	case uint32:
   464  		originID = int64(rawID)
   465  	case uint64:
   466  		originID = int64(rawID)
   467  	case string:
   468  		originID, err = strconv.ParseInt(rawID, 10, 64)
   469  		if err != nil {
   470  			return nil, errors.NotValidf("column %d value is not int, but %v, which is", info.targetPosition, vals[info.targetPosition])
   471  		}
   472  		isChars = true
   473  	default:
   474  		return nil, errors.NotValidf("type %T(%v)", vals[info.targetPosition], vals[info.targetPosition])
   475  	}
   476  
   477  	if originID >= maxOriginID || originID < 0 {
   478  		return nil, errors.NotValidf("id must less than %d, greater than or equal to 0, but got %d, which is", maxOriginID, originID)
   479  	}
   480  
   481  	originID = info.instanceID | info.schemaID | info.tableID | originID
   482  	if isChars {
   483  		vals[info.targetPosition] = strconv.FormatInt(originID, 10)
   484  	} else {
   485  		vals[info.targetPosition] = originID
   486  	}
   487  
   488  	return vals, nil
   489  }
   490  
   491  func computePartitionID(schema, table string, rule *Rule) (instanceID int64, schemaID int64, tableID int64, err error) {
   492  	shiftCnt := uint(63)
   493  	if instanceIDBitSize > 0 && len(rule.Arguments[0]) > 0 {
   494  		var instanceIDUnsign uint64
   495  		shiftCnt = shiftCnt - uint(instanceIDBitSize)
   496  		instanceIDUnsign, err = strconv.ParseUint(rule.Arguments[0], 10, instanceIDBitSize)
   497  		if err != nil {
   498  			return
   499  		}
   500  		instanceID = int64(instanceIDUnsign << shiftCnt)
   501  	}
   502  
   503  	sep := rule.Arguments[3]
   504  
   505  	if schemaIDBitSize > 0 && len(rule.Arguments[1]) > 0 {
   506  		shiftCnt = shiftCnt - uint(schemaIDBitSize)
   507  		schemaID, err = computeID(schema, rule.Arguments[1], sep, schemaIDBitSize, shiftCnt)
   508  		if err != nil {
   509  			return
   510  		}
   511  	}
   512  
   513  	if tableIDBitSize > 0 && len(rule.Arguments[2]) > 0 {
   514  		shiftCnt = shiftCnt - uint(tableIDBitSize)
   515  		tableID, err = computeID(table, rule.Arguments[2], sep, tableIDBitSize, shiftCnt)
   516  	}
   517  
   518  	return
   519  }
   520  
   521  func computeID(name string, prefix, sep string, bitSize int, shiftCount uint) (int64, error) {
   522  	if name == prefix {
   523  		return 0, nil
   524  	}
   525  
   526  	prefix += sep
   527  	if len(prefix) >= len(name) || prefix != name[:len(prefix)] {
   528  		return 0, errors.NotValidf("%s is not the prefix of %s", prefix, name)
   529  	}
   530  
   531  	idStr := name[len(prefix):]
   532  	id, err := strconv.ParseUint(idStr, 10, bitSize)
   533  	if err != nil {
   534  		return 0, errors.NotValidf("the suffix of %s can't be converted to int64", idStr)
   535  	}
   536  
   537  	return int64(id << shiftCount), nil
   538  }