github.com/team-ide/go-dialect@v1.9.20/vitess/sqlparser/normalizer.go (about)

     1  /*
     2  Copyright 2019 The Vitess Authors.
     3  
     4  Licensed under the Apache License, Version 2.0 (the "License");
     5  you may not use this file except in compliance with the License.
     6  You may obtain a copy of the License at
     7  
     8      http://www.apache.org/licenses/LICENSE-2.0
     9  
    10  Unless required by applicable law or agreed to in writing, software
    11  distributed under the License is distributed on an "AS IS" BASIS,
    12  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    13  See the License for the specific language governing permissions and
    14  limitations under the License.
    15  */
    16  
    17  package sqlparser
    18  
    19  import (
    20  	"github.com/team-ide/go-dialect/vitess/sqltypes"
    21  
    22  	querypb "github.com/team-ide/go-dialect/vitess/query"
    23  )
    24  
// BindVars is a set of reserved bind variables from a SQL statement.
// It is keyed by bind variable name; the empty-struct value carries no data.
type BindVars map[string]struct{}
    27  
    28  // Normalize changes the statement to use bind values, and
    29  // updates the bind vars to those values. The supplied prefix
    30  // is used to generate the bind var names. The function ensures
    31  // that there are no collisions with existing bind vars.
    32  // Within Select constructs, bind vars are deduped. This allows
    33  // us to identify vindex equality. Otherwise, every value is
    34  // treated as distinct.
    35  func Normalize(stmt Statement, reserved *ReservedVars, bindVars map[string]*querypb.BindVariable) error {
    36  	nz := newNormalizer(reserved, bindVars)
    37  	_ = Rewrite(stmt, nz.WalkStatement, nil)
    38  	return nz.err
    39  }
    40  
// normalizer carries the state of a single Normalize run.
type normalizer struct {
	// bindVars is the caller-supplied map that receives every bind
	// variable generated during the walk.
	bindVars map[string]*querypb.BindVariable
	// reserved supplies the next unused bind variable name.
	reserved *ReservedVars
	// vals maps a literal's dedup key to the bind variable name already
	// assigned to it (used only in Select mode).
	vals     map[string]string
	// err is returned by Normalize and stops the walk once non-nil.
	// NOTE(review): no code visible in this file assigns it.
	err      error
}
    47  
    48  func newNormalizer(reserved *ReservedVars, bindVars map[string]*querypb.BindVariable) *normalizer {
    49  	return &normalizer{
    50  		bindVars: bindVars,
    51  		reserved: reserved,
    52  		vals:     make(map[string]string),
    53  	}
    54  }
    55  
    56  // WalkStatement is the top level walk function.
    57  // If it encounters a Select, it switches to a mode
    58  // where variables are deduped.
    59  func (nz *normalizer) WalkStatement(cursor *Cursor) bool {
    60  	switch node := cursor.Node().(type) {
    61  	// no need to normalize the statement types
    62  	case *Set, *Show, *Begin, *Commit, *Rollback, *Savepoint, *SetTransaction, DDLStatement, *SRollback, *Release, *OtherAdmin, *OtherRead:
    63  		return false
    64  	case *Select:
    65  		_ = Rewrite(node, nz.WalkSelect, nil)
    66  		// Don't continue
    67  		return false
    68  	case *Literal:
    69  		nz.convertLiteral(node, cursor)
    70  	case *ComparisonExpr:
    71  		nz.convertComparison(node)
    72  	case *ColName, TableName:
    73  		// Common node types that never contain Literal or ListArgs but create a lot of object
    74  		// allocations.
    75  		return false
    76  	case *ConvertType: // we should not rewrite the type description
    77  		return false
    78  	}
    79  	return nz.err == nil // only continue if we haven't found any errors
    80  }
    81  
    82  // WalkSelect normalizes the AST in Select mode.
    83  func (nz *normalizer) WalkSelect(cursor *Cursor) bool {
    84  	switch node := cursor.Node().(type) {
    85  	case *Literal:
    86  		nz.convertLiteralDedup(node, cursor)
    87  	case *ComparisonExpr:
    88  		nz.convertComparison(node)
    89  	case *ColName, TableName:
    90  		// Common node types that never contain Literals or ListArgs but create a lot of object
    91  		// allocations.
    92  		return false
    93  	case OrderBy, GroupBy:
    94  		// do not make a bind var for order by column_position
    95  		return false
    96  	case *ConvertType:
    97  		// we should not rewrite the type description
    98  		return false
    99  	}
   100  	return nz.err == nil // only continue if we haven't found any errors
   101  }
   102  
   103  func (nz *normalizer) convertLiteralDedup(node *Literal, cursor *Cursor) {
   104  	// If value is too long, don't dedup.
   105  	// Such values are most likely not for vindexes.
   106  	// We save a lot of CPU because we avoid building
   107  	// the key for them.
   108  	if len(node.Val) > 256 {
   109  		nz.convertLiteral(node, cursor)
   110  		return
   111  	}
   112  
   113  	// Make the bindvar
   114  	bval := nz.sqlToBindvar(node)
   115  	if bval == nil {
   116  		return
   117  	}
   118  
   119  	// Check if there's a bindvar for that value already.
   120  	var key string
   121  	if bval.Type == sqltypes.VarBinary || bval.Type == sqltypes.VarChar {
   122  		// Prefixing strings with "'" ensures that a string
   123  		// and number that have the same representation don't
   124  		// collide.
   125  		key = "'" + node.Val
   126  	} else {
   127  		key = node.Val
   128  	}
   129  	bvname, ok := nz.vals[key]
   130  	if !ok {
   131  		// If there's no such bindvar, make a new one.
   132  		bvname = nz.reserved.nextUnusedVar()
   133  		nz.vals[key] = bvname
   134  		nz.bindVars[bvname] = bval
   135  	}
   136  
   137  	// Modify the AST node to a bindvar.
   138  	cursor.Replace(NewArgument(bvname))
   139  }
   140  
   141  // convertLiteral converts an Literal without the dedup.
   142  func (nz *normalizer) convertLiteral(node *Literal, cursor *Cursor) {
   143  	bval := nz.sqlToBindvar(node)
   144  	if bval == nil {
   145  		return
   146  	}
   147  
   148  	bvname := nz.reserved.nextUnusedVar()
   149  	nz.bindVars[bvname] = bval
   150  
   151  	cursor.Replace(NewArgument(bvname))
   152  }
   153  
   154  // convertComparison attempts to convert IN clauses to
   155  // use the list bind var construct. If it fails, it returns
   156  // with no change made. The walk function will then continue
   157  // and iterate on converting each individual value into separate
   158  // bind vars.
   159  func (nz *normalizer) convertComparison(node *ComparisonExpr) {
   160  	if node.Operator != InOp && node.Operator != NotInOp {
   161  		return
   162  	}
   163  	tupleVals, ok := node.Right.(ValTuple)
   164  	if !ok {
   165  		return
   166  	}
   167  	// The RHS is a tuple of values.
   168  	// Make a list bindvar.
   169  	bvals := &querypb.BindVariable{
   170  		Type: querypb.Type_TUPLE,
   171  	}
   172  	for _, val := range tupleVals {
   173  		bval := nz.sqlToBindvar(val)
   174  		if bval == nil {
   175  			return
   176  		}
   177  		bvals.Values = append(bvals.Values, &querypb.Value{
   178  			Type:  bval.Type,
   179  			Value: bval.Value,
   180  		})
   181  	}
   182  	bvname := nz.reserved.nextUnusedVar()
   183  	nz.bindVars[bvname] = bvals
   184  	// Modify RHS to be a list bindvar.
   185  	node.Right = ListArg(bvname)
   186  }
   187  
   188  func (nz *normalizer) sqlToBindvar(node SQLNode) *querypb.BindVariable {
   189  	if node, ok := node.(*Literal); ok {
   190  		var v sqltypes.Value
   191  		var err error
   192  		switch node.Type {
   193  		case StrVal:
   194  			v, err = sqltypes.NewValue(sqltypes.VarChar, node.Bytes())
   195  		case IntVal:
   196  			v, err = sqltypes.NewValue(sqltypes.Int64, node.Bytes())
   197  		case FloatVal:
   198  			v, err = sqltypes.NewValue(sqltypes.Float64, node.Bytes())
   199  		case DecimalVal:
   200  			v, err = sqltypes.NewValue(sqltypes.Decimal, node.Bytes())
   201  		case HexNum:
   202  			v, err = sqltypes.NewValue(sqltypes.HexNum, node.Bytes())
   203  		case HexVal:
   204  			// We parse the `x'7b7d'` string literal into a hex encoded string of `7b7d` in the parser
   205  			// We need to re-encode it back to the original MySQL query format before passing it on as a bindvar value to MySQL
   206  			var vbytes []byte
   207  			vbytes, err = node.encodeHexValToMySQLQueryFormat()
   208  			if err != nil {
   209  				return nil
   210  			}
   211  			v, err = sqltypes.NewValue(sqltypes.HexVal, vbytes)
   212  		default:
   213  			return nil
   214  		}
   215  		if err != nil {
   216  			return nil
   217  		}
   218  		return sqltypes.ValueBindVariable(v)
   219  	}
   220  	return nil
   221  }
   222  
   223  // GetBindvars returns a map of the bind vars referenced in the statement.
   224  func GetBindvars(stmt Statement) map[string]struct{} {
   225  	bindvars := make(map[string]struct{})
   226  	_ = Walk(func(node SQLNode) (kontinue bool, err error) {
   227  		switch node := node.(type) {
   228  		case *ColName, TableName:
   229  			// Common node types that never contain expressions but create a lot of object
   230  			// allocations.
   231  			return false, nil
   232  		case Argument:
   233  			bindvars[string(node)] = struct{}{}
   234  		case ListArg:
   235  			bindvars[string(node)] = struct{}{}
   236  		}
   237  		return true, nil
   238  	}, stmt)
   239  	return bindvars
   240  }