github.com/vedadiyan/sqlparser@v1.0.0/pkg/sqlparser/normalizer.go (about)

     1  /*
     2  Copyright 2019 The Vitess Authors.
     3  
     4  Licensed under the Apache License, Version 2.0 (the "License");
     5  you may not use this file except in compliance with the License.
     6  You may obtain a copy of the License at
     7  
     8      http://www.apache.org/licenses/LICENSE-2.0
     9  
    10  Unless required by applicable law or agreed to in writing, software
    11  distributed under the License is distributed on an "AS IS" BASIS,
    12  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    13  See the License for the specific language governing permissions and
    14  limitations under the License.
    15  */
    16  
    17  package sqlparser
    18  
    19  import (
    20  	"fmt"
    21  	"strconv"
    22  
    23  	"github.com/vedadiyan/sqlparser/pkg/sqltypes"
    24  
    25  	querypb "github.com/vedadiyan/sqlparser/pkg/query"
    26  )
    27  
// BindVars is a set of reserved bind variables from a SQL statement.
// It is a map keyed by string with empty-struct values, so only key
// membership carries information.
type BindVars map[string]struct{}
    30  
    31  // Normalize changes the statement to use bind values, and
    32  // updates the bind vars to those values. The supplied prefix
    33  // is used to generate the bind var names. The function ensures
    34  // that there are no collisions with existing bind vars.
    35  // Within Select constructs, bind vars are deduped. This allows
    36  // us to identify vindex equality. Otherwise, every value is
    37  // treated as distinct.
    38  func Normalize(stmt Statement, reserved *ReservedVars, bindVars map[string]*querypb.BindVariable) error {
    39  	nz := newNormalizer(reserved, bindVars)
    40  	_ = SafeRewrite(stmt, nz.walkStatementDown, nz.walkStatementUp)
    41  	return nz.err
    42  }
    43  
// normalizer carries the state shared by the walk callbacks of a single
// Normalize call.
//
//   - bindVars receives every bind variable created during the walk,
//     keyed by the bind variable name.
//   - reserved is the source of new bind variable names.
//   - vals maps a literal's dedup key (see keyFor) to the name of the
//     bind variable already created for that value.
//   - err holds a literal validation error recorded during the walk; the
//     walk callbacks return false once it is set.
//   - inDerived is true while walking a Select whose parent node is a
//     DerivedTable.
type normalizer struct {
	bindVars  map[string]*querypb.BindVariable
	reserved  *ReservedVars
	vals      map[string]string
	err       error
	inDerived bool
}
    51  
    52  func newNormalizer(reserved *ReservedVars, bindVars map[string]*querypb.BindVariable) *normalizer {
    53  	return &normalizer{
    54  		bindVars: bindVars,
    55  		reserved: reserved,
    56  		vals:     make(map[string]string),
    57  	}
    58  }
    59  
    60  // walkStatementUp is one half of the top level walk function.
    61  func (nz *normalizer) walkStatementUp(cursor *Cursor) bool {
    62  	if nz.err != nil {
    63  		return false
    64  	}
    65  	node, isLiteral := cursor.Node().(*Literal)
    66  	if !isLiteral {
    67  		return true
    68  	}
    69  	nz.convertLiteral(node, cursor)
    70  	return nz.err == nil // only continue if we haven't found any errors
    71  }
    72  
    73  // walkStatementDown is the top level walk function.
    74  // If it encounters a Select, it switches to a mode
    75  // where variables are deduped.
    76  func (nz *normalizer) walkStatementDown(node, parent SQLNode) bool {
    77  	switch node := node.(type) {
    78  	// no need to normalize the statement types
    79  	case *Set, *Show, *Begin, *Commit, *Rollback, *Savepoint, DDLStatement, *SRollback, *Release, *OtherAdmin, *OtherRead:
    80  		return false
    81  	case *Select:
    82  		_, isDerived := parent.(*DerivedTable)
    83  		var tmp bool
    84  		tmp, nz.inDerived = nz.inDerived, isDerived
    85  		_ = SafeRewrite(node, nz.walkDownSelect, nz.walkUpSelect)
    86  		// Don't continue
    87  		nz.inDerived = tmp
    88  		return false
    89  	case *ComparisonExpr:
    90  		nz.convertComparison(node)
    91  	case *UpdateExpr:
    92  		nz.convertUpdateExpr(node)
    93  	case *ColName, TableName:
    94  		// Common node types that never contain Literal or ListArgs but create a lot of object
    95  		// allocations.
    96  		return false
    97  	case *ConvertType: // we should not rewrite the type description
    98  		return false
    99  	}
   100  	return nz.err == nil // only continue if we haven't found any errors
   101  }
   102  
   103  // walkDownSelect normalizes the AST in Select mode.
   104  func (nz *normalizer) walkDownSelect(node, parent SQLNode) bool {
   105  	switch node := node.(type) {
   106  	case *Select:
   107  		_, isDerived := parent.(*DerivedTable)
   108  		if !isDerived {
   109  			return true
   110  		}
   111  		var tmp bool
   112  		tmp, nz.inDerived = nz.inDerived, isDerived
   113  		// initiating a new AST walk here means that we might change something while walking down on the tree,
   114  		// but since we are only changing literals, we can be safe that we are not changing the SELECT struct,
   115  		// only something much further down, and that should be safe
   116  		_ = SafeRewrite(node, nz.walkDownSelect, nz.walkUpSelect)
   117  		// Don't continue
   118  		nz.inDerived = tmp
   119  		return false
   120  	case SelectExprs:
   121  		return !nz.inDerived
   122  	case *ComparisonExpr:
   123  		nz.convertComparison(node)
   124  	case *FramePoint:
   125  		// do not make a bind var for rows and range
   126  		return false
   127  	case *ColName, TableName:
   128  		// Common node types that never contain Literals or ListArgs but create a lot of object
   129  		// allocations.
   130  		return false
   131  	case *ConvertType:
   132  		// we should not rewrite the type description
   133  		return false
   134  	}
   135  	return nz.err == nil // only continue if we haven't found any errors
   136  }
   137  
   138  // walkUpSelect normalizes the Literals in Select mode.
   139  func (nz *normalizer) walkUpSelect(cursor *Cursor) bool {
   140  	if nz.err != nil {
   141  		return false
   142  	}
   143  	node, isLiteral := cursor.Node().(*Literal)
   144  	if !isLiteral {
   145  		return true
   146  	}
   147  	parent := cursor.Parent()
   148  	switch parent.(type) {
   149  	case *Order, GroupBy:
   150  		return false
   151  	case *Limit:
   152  		nz.convertLiteral(node, cursor)
   153  	default:
   154  		nz.convertLiteralDedup(node, cursor)
   155  	}
   156  	return nz.err == nil // only continue if we haven't found any errors
   157  }
   158  
   159  func validateLiteral(node *Literal) (err error) {
   160  	switch node.Type {
   161  	case DateVal:
   162  		_, err = ParseDate(node.Val)
   163  	case TimeVal:
   164  		_, err = ParseTime(node.Val)
   165  	case TimestampVal:
   166  		_, err = ParseDateTime(node.Val)
   167  	}
   168  	return err
   169  }
   170  
   171  func (nz *normalizer) convertLiteralDedup(node *Literal, cursor *Cursor) {
   172  	err := validateLiteral(node)
   173  	if err != nil {
   174  		nz.err = err
   175  	}
   176  
   177  	// If value is too long, don't dedup.
   178  	// Such values are most likely not for vindexes.
   179  	// We save a lot of CPU because we avoid building
   180  	// the key for them.
   181  	if len(node.Val) > 256 {
   182  		nz.convertLiteral(node, cursor)
   183  		return
   184  	}
   185  
   186  	// Make the bindvar
   187  	bval := SQLToBindvar(node)
   188  	if bval == nil {
   189  		return
   190  	}
   191  
   192  	// Check if there's a bindvar for that value already.
   193  	key := keyFor(bval, node)
   194  	bvname, ok := nz.vals[key]
   195  	if !ok {
   196  		// If there's no such bindvar, make a new one.
   197  		bvname = nz.reserved.nextUnusedVar()
   198  		nz.vals[key] = bvname
   199  		nz.bindVars[bvname] = bval
   200  	}
   201  
   202  	// Modify the AST node to a bindvar.
   203  	cursor.Replace(NewArgument(bvname))
   204  }
   205  
   206  func keyFor(bval *querypb.BindVariable, lit *Literal) string {
   207  	if bval.Type != sqltypes.VarBinary && bval.Type != sqltypes.VarChar {
   208  		return lit.Val
   209  	}
   210  
   211  	// Prefixing strings with "'" ensures that a string
   212  	// and number that have the same representation don't
   213  	// collide.
   214  	return "'" + lit.Val
   215  
   216  }
   217  
   218  // convertLiteral converts an Literal without the dedup.
   219  func (nz *normalizer) convertLiteral(node *Literal, cursor *Cursor) {
   220  	err := validateLiteral(node)
   221  	if err != nil {
   222  		nz.err = err
   223  	}
   224  
   225  	bval := SQLToBindvar(node)
   226  	if bval == nil {
   227  		return
   228  	}
   229  
   230  	bvname := nz.reserved.nextUnusedVar()
   231  	nz.bindVars[bvname] = bval
   232  
   233  	cursor.Replace(NewArgument(bvname))
   234  }
   235  
   236  // convertComparison attempts to convert IN clauses to
   237  // use the list bind var construct. If it fails, it returns
   238  // with no change made. The walk function will then continue
   239  // and iterate on converting each individual value into separate
   240  // bind vars.
   241  func (nz *normalizer) convertComparison(node *ComparisonExpr) {
   242  	switch node.Operator {
   243  	case InOp, NotInOp:
   244  		nz.rewriteInComparisons(node)
   245  	default:
   246  		nz.rewriteOtherComparisons(node)
   247  	}
   248  }
   249  
   250  func (nz *normalizer) rewriteOtherComparisons(node *ComparisonExpr) {
   251  	newR := nz.parameterize(node.Left, node.Right)
   252  	if newR != nil {
   253  		node.Right = newR
   254  	}
   255  }
   256  
   257  func (nz *normalizer) parameterize(left, right Expr) Expr {
   258  	col, ok := left.(*ColName)
   259  	if !ok {
   260  		return nil
   261  	}
   262  	lit, ok := right.(*Literal)
   263  	if !ok {
   264  		return nil
   265  	}
   266  	err := validateLiteral(lit)
   267  	if err != nil {
   268  		nz.err = err
   269  		return nil
   270  	}
   271  
   272  	bval := SQLToBindvar(lit)
   273  	if bval == nil {
   274  		return nil
   275  	}
   276  	key := keyFor(bval, lit)
   277  	bvname := nz.decideBindVarName(key, lit, col, bval)
   278  	return Argument(bvname)
   279  }
   280  
   281  func (nz *normalizer) decideBindVarName(key string, lit *Literal, col *ColName, bval *querypb.BindVariable) string {
   282  	if len(lit.Val) <= 256 {
   283  		// first we check if we already have a bindvar for this value. if we do, we re-use that bindvar name
   284  		bvname, ok := nz.vals[key]
   285  		if ok {
   286  			return bvname
   287  		}
   288  	}
   289  
   290  	// If there's no such bindvar, or we have a big value, make a new one.
   291  	// Big values are most likely not for vindexes.
   292  	// We save a lot of CPU because we avoid building
   293  	bvname := nz.reserved.ReserveColName(col)
   294  	nz.vals[key] = bvname
   295  	nz.bindVars[bvname] = bval
   296  
   297  	return bvname
   298  }
   299  
   300  func (nz *normalizer) rewriteInComparisons(node *ComparisonExpr) {
   301  	tupleVals, ok := node.Right.(ValTuple)
   302  	if !ok {
   303  		return
   304  	}
   305  
   306  	// The RHS is a tuple of values.
   307  	// Make a list bindvar.
   308  	bvals := &querypb.BindVariable{
   309  		Type: querypb.Type_TUPLE,
   310  	}
   311  	for _, val := range tupleVals {
   312  		bval := SQLToBindvar(val)
   313  		if bval == nil {
   314  			return
   315  		}
   316  		bvals.Values = append(bvals.Values, &querypb.Value{
   317  			Type:  bval.Type,
   318  			Value: bval.Value,
   319  		})
   320  	}
   321  	bvname := nz.reserved.nextUnusedVar()
   322  	nz.bindVars[bvname] = bvals
   323  	// Modify RHS to be a list bindvar.
   324  	node.Right = ListArg(bvname)
   325  }
   326  
   327  func (nz *normalizer) convertUpdateExpr(node *UpdateExpr) {
   328  	newR := nz.parameterize(node.Name, node.Expr)
   329  	if newR != nil {
   330  		node.Expr = newR
   331  	}
   332  }
   333  
   334  func SQLToBindvar(node SQLNode) *querypb.BindVariable {
   335  	if node, ok := node.(*Literal); ok {
   336  		var v sqltypes.Value
   337  		var err error
   338  		switch node.Type {
   339  		case StrVal:
   340  			v, err = sqltypes.NewValue(sqltypes.VarChar, node.Bytes())
   341  		case IntVal:
   342  			v, err = sqltypes.NewValue(sqltypes.Int64, node.Bytes())
   343  		case FloatVal:
   344  			v, err = sqltypes.NewValue(sqltypes.Float64, node.Bytes())
   345  		case DecimalVal:
   346  			v, err = sqltypes.NewValue(sqltypes.Decimal, node.Bytes())
   347  		case HexNum:
   348  			v, err = sqltypes.NewValue(sqltypes.HexNum, node.Bytes())
   349  		case HexVal:
   350  			// We parse the `x'7b7d'` string literal into a hex encoded string of `7b7d` in the parser
   351  			// We need to re-encode it back to the original MySQL query format before passing it on as a bindvar value to MySQL
   352  			var vbytes []byte
   353  			vbytes, err = node.encodeHexOrBitValToMySQLQueryFormat()
   354  			if err != nil {
   355  				return nil
   356  			}
   357  			v, err = sqltypes.NewValue(sqltypes.HexVal, vbytes)
   358  		case BitVal:
   359  			// Convert bit value to hex number in parameterized query format
   360  			var ui uint64
   361  			ui, err = strconv.ParseUint(string(node.Bytes()), 2, 64)
   362  			if err != nil {
   363  				return nil
   364  			}
   365  			v, err = sqltypes.NewValue(sqltypes.HexNum, []byte(fmt.Sprintf("0x%x", ui)))
   366  		case DateVal:
   367  			v, err = sqltypes.NewValue(sqltypes.Date, node.Bytes())
   368  		case TimeVal:
   369  			v, err = sqltypes.NewValue(sqltypes.Time, node.Bytes())
   370  		case TimestampVal:
   371  			// This is actually a DATETIME MySQL type. The timestamp literal
   372  			// syntax is part of the SQL standard and MySQL DATETIME matches
   373  			// the type best.
   374  			v, err = sqltypes.NewValue(sqltypes.Datetime, node.Bytes())
   375  		default:
   376  			return nil
   377  		}
   378  		if err != nil {
   379  			return nil
   380  		}
   381  		return sqltypes.ValueBindVariable(v)
   382  	}
   383  	return nil
   384  }
   385  
   386  // GetBindvars returns a map of the bind vars referenced in the statement.
   387  func GetBindvars(stmt Statement) map[string]struct{} {
   388  	bindvars := make(map[string]struct{})
   389  	_ = Walk(func(node SQLNode) (kontinue bool, err error) {
   390  		switch node := node.(type) {
   391  		case *ColName, TableName:
   392  			// Common node types that never contain expressions but create a lot of object
   393  			// allocations.
   394  			return false, nil
   395  		case Argument:
   396  			bindvars[string(node)] = struct{}{}
   397  		case ListArg:
   398  			bindvars[string(node)] = struct{}{}
   399  		}
   400  		return true, nil
   401  	}, stmt)
   402  	return bindvars
   403  }