vitess.io/vitess@v0.16.2/go/vt/vtgate/simplifier/simplifier_test.go (about)

     1  /*
     2  Copyright 2021 The Vitess Authors.
     3  
     4  Licensed under the Apache License, Version 2.0 (the "License");
     5  you may not use this file except in compliance with the License.
     6  You may obtain a copy of the License at
     7  
     8      http://www.apache.org/licenses/LICENSE-2.0
     9  
    10  Unless required by applicable law or agreed to in writing, software
    11  distributed under the License is distributed on an "AS IS" BASIS,
    12  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    13  See the License for the specific language governing permissions and
    14  limitations under the License.
    15  */
    16  
    17  package simplifier
    18  
    19  import (
    20  	"fmt"
    21  	"testing"
    22  
    23  	"vitess.io/vitess/go/vt/log"
    24  	"vitess.io/vitess/go/vt/vtgate/evalengine"
    25  
    26  	"github.com/stretchr/testify/require"
    27  
    28  	"vitess.io/vitess/go/vt/sqlparser"
    29  )
    30  
    31  func TestFindAllExpressions(t *testing.T) {
    32  	query := `
    33  select 
    34  	user.selectExpr1, 
    35  	unsharded.selectExpr2,
    36  	count(*) as leCount
    37  from 
    38  	user join 
    39  	unsharded on 
    40  		user.joinCond = unsharded.joinCond 
    41  where
    42  	unsharded.wherePred = 42 and
    43  	wherePred = 'foo' and 
    44  	user.id = unsharded.id
    45  group by 
    46  	user.groupByExpr1 + unsharded.groupByExpr2
    47  order by 
    48  	user.orderByExpr1 desc, 
    49  	unsharded.orderByExpr2 asc
    50  limit 123 offset 456
    51  `
    52  	ast, err := sqlparser.Parse(query)
    53  	require.NoError(t, err)
    54  	visitAllExpressionsInAST(ast.(sqlparser.SelectStatement), func(cursor expressionCursor) bool {
    55  		fmt.Printf(">> found expression: %s\n", sqlparser.String(cursor.expr))
    56  		cursor.replace(sqlparser.NewIntLiteral("1"))
    57  		fmt.Printf("remove: %s\n", sqlparser.String(ast))
    58  		cursor.restore()
    59  		fmt.Printf("restore: %s\n", sqlparser.String(ast))
    60  		cursor.remove()
    61  		fmt.Printf("replace it with literal: %s\n", sqlparser.String(ast))
    62  		cursor.restore()
    63  		fmt.Printf("restore: %s\n", sqlparser.String(ast))
    64  		return true
    65  	})
    66  }
    67  
    68  func TestAbortExpressionCursor(t *testing.T) {
    69  	query := "select user.id, count(*), unsharded.name from user join unsharded on 13 = 14 where unsharded.id = 42 and name = 'foo' and user.id = unsharded.id"
    70  	ast, err := sqlparser.Parse(query)
    71  	require.NoError(t, err)
    72  	visitAllExpressionsInAST(ast.(sqlparser.SelectStatement), func(cursor expressionCursor) bool {
    73  		fmt.Println(sqlparser.String(cursor.expr))
    74  		cursor.replace(sqlparser.NewIntLiteral("1"))
    75  		fmt.Println(sqlparser.String(ast))
    76  		cursor.replace(cursor.expr)
    77  		_, isFunc := cursor.expr.(sqlparser.AggrFunc)
    78  		return !isFunc
    79  	})
    80  }
    81  
    82  func TestSimplifyEvalEngineExpr(t *testing.T) {
    83  	// ast struct for L0         +
    84  	// L1                +            +
    85  	// L2             +     +      +     +
    86  	// L3            1 2   3 4    5 6   7 8
    87  
    88  	// L3
    89  	i1, i2, i3, i4, i5, i6, i7, i8 :=
    90  		sqlparser.NewIntLiteral("1"),
    91  		sqlparser.NewIntLiteral("2"),
    92  		sqlparser.NewIntLiteral("3"),
    93  		sqlparser.NewIntLiteral("4"),
    94  		sqlparser.NewIntLiteral("5"),
    95  		sqlparser.NewIntLiteral("6"),
    96  		sqlparser.NewIntLiteral("7"),
    97  		sqlparser.NewIntLiteral("8")
    98  	// L2
    99  	p21, p22, p23, p24 :=
   100  		plus(i1, i2),
   101  		plus(i3, i4),
   102  		plus(i5, i6),
   103  		plus(i7, i8)
   104  
   105  	// L1
   106  	p11, p12 :=
   107  		plus(p21, p22),
   108  		plus(p23, p24)
   109  
   110  	// L0
   111  	p0 := plus(p11, p12)
   112  
   113  	expr := SimplifyExpr(p0, func(expr sqlparser.Expr) bool {
   114  		local, err := evalengine.TranslateEx(expr, nil, true)
   115  		if err != nil {
   116  			return false
   117  		}
   118  		res, err := evalengine.EmptyExpressionEnv().Evaluate(local)
   119  		if err != nil {
   120  			return false
   121  		}
   122  		toInt64, err := res.Value().ToInt64()
   123  		if err != nil {
   124  			return false
   125  		}
   126  		return toInt64 >= 8
   127  	})
   128  	log.Infof("simplest expr to evaluate to >= 8: [%s], started from: [%s]", sqlparser.String(expr), sqlparser.String(p0))
   129  }
   130  
   131  func plus(a, b sqlparser.Expr) sqlparser.Expr {
   132  	return &sqlparser.BinaryExpr{
   133  		Operator: sqlparser.PlusOp,
   134  		Left:     a,
   135  		Right:    b,
   136  	}
   137  }