github.com/expr-lang/expr@v1.16.9/optimizer/optimizer_test.go (about)

     1  package optimizer_test
     2  
     3  import (
     4  	"fmt"
     5  	"reflect"
     6  	"strings"
     7  	"testing"
     8  
     9  	"github.com/expr-lang/expr/internal/testify/assert"
    10  	"github.com/expr-lang/expr/internal/testify/require"
    11  
    12  	"github.com/expr-lang/expr"
    13  	"github.com/expr-lang/expr/ast"
    14  	"github.com/expr-lang/expr/checker"
    15  	"github.com/expr-lang/expr/conf"
    16  	"github.com/expr-lang/expr/optimizer"
    17  	"github.com/expr-lang/expr/parser"
    18  )
    19  
    20  func TestOptimize(t *testing.T) {
    21  	env := map[string]any{
    22  		"a": 1,
    23  		"b": 2,
    24  		"c": 3,
    25  	}
    26  
    27  	tests := []struct {
    28  		expr string
    29  		want any
    30  	}{
    31  		{`1 + 2`, 3},
    32  		{`sum([])`, 0},
    33  		{`sum([a])`, 1},
    34  		{`sum([a, b])`, 3},
    35  		{`sum([a, b, c])`, 6},
    36  		{`sum([a, b, c, 4])`, 10},
    37  		{`sum(1..10, # * 1000)`, 55000},
    38  		{`sum(map(1..10, # * 1000), # / 1000)`, float64(55)},
    39  		{`all(1..3, {# > 0}) && all(1..3, {# < 4})`, true},
    40  		{`all(1..3, {# > 2}) && all(1..3, {# < 4})`, false},
    41  		{`all(1..3, {# > 0}) && all(1..3, {# < 2})`, false},
    42  		{`all(1..3, {# > 2}) && all(1..3, {# < 2})`, false},
    43  		{`all(1..3, {# > 0}) || all(1..3, {# < 4})`, true},
    44  		{`all(1..3, {# > 0}) || all(1..3, {# != 2})`, true},
    45  		{`all(1..3, {# != 3}) || all(1..3, {# < 4})`, true},
    46  		{`all(1..3, {# != 3}) || all(1..3, {# != 2})`, false},
    47  		{`none(1..3, {# == 0})`, true},
    48  		{`none(1..3, {# == 0}) && none(1..3, {# == 4})`, true},
    49  		{`none(1..3, {# == 0}) && none(1..3, {# == 3})`, false},
    50  		{`none(1..3, {# == 1}) && none(1..3, {# == 4})`, false},
    51  		{`none(1..3, {# == 1}) && none(1..3, {# == 3})`, false},
    52  		{`none(1..3, {# == 0}) || none(1..3, {# == 4})`, true},
    53  		{`none(1..3, {# == 0}) || none(1..3, {# == 3})`, true},
    54  		{`none(1..3, {# == 1}) || none(1..3, {# == 4})`, true},
    55  		{`none(1..3, {# == 1}) || none(1..3, {# == 3})`, false},
    56  		{`any([1, 1, 0, 1], {# == 0})`, true},
    57  		{`any(1..3, {# == 1}) && any(1..3, {# == 2})`, true},
    58  		{`any(1..3, {# == 0}) && any(1..3, {# == 2})`, false},
    59  		{`any(1..3, {# == 1}) && any(1..3, {# == 4})`, false},
    60  		{`any(1..3, {# == 0}) && any(1..3, {# == 4})`, false},
    61  		{`any(1..3, {# == 1}) || any(1..3, {# == 2})`, true},
    62  		{`any(1..3, {# == 0}) || any(1..3, {# == 2})`, true},
    63  		{`any(1..3, {# == 1}) || any(1..3, {# == 4})`, true},
    64  		{`any(1..3, {# == 0}) || any(1..3, {# == 4})`, false},
    65  		{`one([1, 1, 0, 1], {# == 0}) and not one([1, 0, 0, 1], {# == 0})`, true},
    66  		{`one(1..3, {# == 1}) and one(1..3, {# == 2})`, true},
    67  		{`one(1..3, {# == 1 || # == 2}) and one(1..3, {# == 2})`, false},
    68  		{`one(1..3, {# == 1}) and one(1..3, {# == 2 || # == 3})`, false},
    69  		{`one(1..3, {# == 1 || # == 2}) and one(1..3, {# == 2 || # == 3})`, false},
    70  		{`one(1..3, {# == 1}) or one(1..3, {# == 2})`, true},
    71  		{`one(1..3, {# == 1 || # == 2}) or one(1..3, {# == 2})`, true},
    72  		{`one(1..3, {# == 1}) or one(1..3, {# == 2 || # == 3})`, true},
    73  		{`one(1..3, {# == 1 || # == 2}) or one(1..3, {# == 2 || # == 3})`, false},
    74  	}
    75  
    76  	for _, tt := range tests {
    77  		t.Run(tt.expr, func(t *testing.T) {
    78  			program, err := expr.Compile(tt.expr, expr.Env(env))
    79  			require.NoError(t, err)
    80  
    81  			output, err := expr.Run(program, env)
    82  			require.NoError(t, err)
    83  			assert.Equal(t, tt.want, output)
    84  
    85  			unoptimizedProgram, err := expr.Compile(tt.expr, expr.Env(env), expr.Optimize(false))
    86  			require.NoError(t, err)
    87  
    88  			unoptimizedOutput, err := expr.Run(unoptimizedProgram, env)
    89  			require.NoError(t, err)
    90  			assert.Equal(t, tt.want, unoptimizedOutput)
    91  		})
    92  	}
    93  }
    94  
    95  func TestOptimize_constant_folding(t *testing.T) {
    96  	tree, err := parser.Parse(`[1,2,3][5*5-25]`)
    97  	require.NoError(t, err)
    98  
    99  	err = optimizer.Optimize(&tree.Node, nil)
   100  	require.NoError(t, err)
   101  
   102  	expected := &ast.MemberNode{
   103  		Node:     &ast.ConstantNode{Value: []any{1, 2, 3}},
   104  		Property: &ast.IntegerNode{Value: 0},
   105  	}
   106  
   107  	assert.Equal(t, ast.Dump(expected), ast.Dump(tree.Node))
   108  }
   109  
   110  func TestOptimize_constant_folding_with_floats(t *testing.T) {
   111  	tree, err := parser.Parse(`1 + 2.0 * ((1.0 * 2) / 2) - 0`)
   112  	require.NoError(t, err)
   113  
   114  	err = optimizer.Optimize(&tree.Node, nil)
   115  	require.NoError(t, err)
   116  
   117  	expected := &ast.FloatNode{Value: 3.0}
   118  
   119  	assert.Equal(t, ast.Dump(expected), ast.Dump(tree.Node))
   120  	assert.Equal(t, reflect.Float64, tree.Node.Type().Kind())
   121  }
   122  
   123  func TestOptimize_constant_folding_with_bools(t *testing.T) {
   124  	tree, err := parser.Parse(`(true and false) or (true or false) or (false and false) or (true and (true == false))`)
   125  	require.NoError(t, err)
   126  
   127  	err = optimizer.Optimize(&tree.Node, nil)
   128  	require.NoError(t, err)
   129  
   130  	expected := &ast.BoolNode{Value: true}
   131  
   132  	assert.Equal(t, ast.Dump(expected), ast.Dump(tree.Node))
   133  }
   134  
   135  func TestOptimize_in_array(t *testing.T) {
   136  	config := conf.New(map[string]int{"v": 0})
   137  
   138  	tree, err := parser.Parse(`v in [1,2,3]`)
   139  	require.NoError(t, err)
   140  
   141  	_, err = checker.Check(tree, config)
   142  	require.NoError(t, err)
   143  
   144  	err = optimizer.Optimize(&tree.Node, nil)
   145  	require.NoError(t, err)
   146  
   147  	expected := &ast.BinaryNode{
   148  		Operator: "in",
   149  		Left:     &ast.IdentifierNode{Value: "v"},
   150  		Right:    &ast.ConstantNode{Value: map[int]struct{}{1: {}, 2: {}, 3: {}}},
   151  	}
   152  
   153  	assert.Equal(t, ast.Dump(expected), ast.Dump(tree.Node))
   154  }
   155  
   156  func TestOptimize_in_range(t *testing.T) {
   157  	tree, err := parser.Parse(`age in 18..31`)
   158  	require.NoError(t, err)
   159  
   160  	config := conf.New(map[string]int{"age": 30})
   161  	_, err = checker.Check(tree, config)
   162  
   163  	err = optimizer.Optimize(&tree.Node, nil)
   164  	require.NoError(t, err)
   165  
   166  	left := &ast.IdentifierNode{
   167  		Value: "age",
   168  	}
   169  	expected := &ast.BinaryNode{
   170  		Operator: "and",
   171  		Left: &ast.BinaryNode{
   172  			Operator: ">=",
   173  			Left:     left,
   174  			Right: &ast.IntegerNode{
   175  				Value: 18,
   176  			},
   177  		},
   178  		Right: &ast.BinaryNode{
   179  			Operator: "<=",
   180  			Left:     left,
   181  			Right: &ast.IntegerNode{
   182  				Value: 31,
   183  			},
   184  		},
   185  	}
   186  
   187  	assert.Equal(t, ast.Dump(expected), ast.Dump(tree.Node))
   188  }
   189  
   190  func TestOptimize_in_range_with_floats(t *testing.T) {
   191  	out, err := expr.Eval(`f in 1..3`, map[string]any{"f": 1.5})
   192  	require.NoError(t, err)
   193  	assert.Equal(t, false, out)
   194  }
   195  
   196  func TestOptimize_const_expr(t *testing.T) {
   197  	tree, err := parser.Parse(`toUpper("hello")`)
   198  	require.NoError(t, err)
   199  
   200  	env := map[string]any{
   201  		"toUpper": strings.ToUpper,
   202  	}
   203  
   204  	config := conf.New(env)
   205  	config.ConstExpr("toUpper")
   206  
   207  	err = optimizer.Optimize(&tree.Node, config)
   208  	require.NoError(t, err)
   209  
   210  	expected := &ast.ConstantNode{
   211  		Value: "HELLO",
   212  	}
   213  
   214  	assert.Equal(t, ast.Dump(expected), ast.Dump(tree.Node))
   215  }
   216  
   217  func TestOptimize_filter_len(t *testing.T) {
   218  	tree, err := parser.Parse(`len(filter(users, .Name == "Bob"))`)
   219  	require.NoError(t, err)
   220  
   221  	err = optimizer.Optimize(&tree.Node, nil)
   222  	require.NoError(t, err)
   223  
   224  	expected := &ast.BuiltinNode{
   225  		Name: "count",
   226  		Arguments: []ast.Node{
   227  			&ast.IdentifierNode{Value: "users"},
   228  			&ast.ClosureNode{
   229  				Node: &ast.BinaryNode{
   230  					Operator: "==",
   231  					Left: &ast.MemberNode{
   232  						Node:     &ast.PointerNode{},
   233  						Property: &ast.StringNode{Value: "Name"},
   234  					},
   235  					Right: &ast.StringNode{Value: "Bob"},
   236  				},
   237  			},
   238  		},
   239  	}
   240  
   241  	assert.Equal(t, ast.Dump(expected), ast.Dump(tree.Node))
   242  }
   243  
   244  func TestOptimize_filter_0(t *testing.T) {
   245  	tree, err := parser.Parse(`filter(users, .Name == "Bob")[0]`)
   246  	require.NoError(t, err)
   247  
   248  	err = optimizer.Optimize(&tree.Node, nil)
   249  	require.NoError(t, err)
   250  
   251  	expected := &ast.BuiltinNode{
   252  		Name: "find",
   253  		Arguments: []ast.Node{
   254  			&ast.IdentifierNode{Value: "users"},
   255  			&ast.ClosureNode{
   256  				Node: &ast.BinaryNode{
   257  					Operator: "==",
   258  					Left: &ast.MemberNode{
   259  						Node:     &ast.PointerNode{},
   260  						Property: &ast.StringNode{Value: "Name"},
   261  					},
   262  					Right: &ast.StringNode{Value: "Bob"},
   263  				},
   264  			},
   265  		},
   266  		Throws: true,
   267  	}
   268  
   269  	assert.Equal(t, ast.Dump(expected), ast.Dump(tree.Node))
   270  }
   271  
   272  func TestOptimize_filter_first(t *testing.T) {
   273  	tree, err := parser.Parse(`first(filter(users, .Name == "Bob"))`)
   274  	require.NoError(t, err)
   275  
   276  	err = optimizer.Optimize(&tree.Node, nil)
   277  	require.NoError(t, err)
   278  
   279  	expected := &ast.BuiltinNode{
   280  		Name: "find",
   281  		Arguments: []ast.Node{
   282  			&ast.IdentifierNode{Value: "users"},
   283  			&ast.ClosureNode{
   284  				Node: &ast.BinaryNode{
   285  					Operator: "==",
   286  					Left: &ast.MemberNode{
   287  						Node:     &ast.PointerNode{},
   288  						Property: &ast.StringNode{Value: "Name"},
   289  					},
   290  					Right: &ast.StringNode{Value: "Bob"},
   291  				},
   292  			},
   293  		},
   294  		Throws: false,
   295  	}
   296  
   297  	assert.Equal(t, ast.Dump(expected), ast.Dump(tree.Node))
   298  }
   299  
   300  func TestOptimize_filter_minus_1(t *testing.T) {
   301  	tree, err := parser.Parse(`filter(users, .Name == "Bob")[-1]`)
   302  	require.NoError(t, err)
   303  
   304  	err = optimizer.Optimize(&tree.Node, nil)
   305  	require.NoError(t, err)
   306  
   307  	expected := &ast.BuiltinNode{
   308  		Name: "findLast",
   309  		Arguments: []ast.Node{
   310  			&ast.IdentifierNode{Value: "users"},
   311  			&ast.ClosureNode{
   312  				Node: &ast.BinaryNode{
   313  					Operator: "==",
   314  					Left: &ast.MemberNode{
   315  						Node:     &ast.PointerNode{},
   316  						Property: &ast.StringNode{Value: "Name"},
   317  					},
   318  					Right: &ast.StringNode{Value: "Bob"},
   319  				},
   320  			},
   321  		},
   322  		Throws: true,
   323  	}
   324  
   325  	assert.Equal(t, ast.Dump(expected), ast.Dump(tree.Node))
   326  }
   327  
   328  func TestOptimize_filter_last(t *testing.T) {
   329  	tree, err := parser.Parse(`last(filter(users, .Name == "Bob"))`)
   330  	require.NoError(t, err)
   331  
   332  	err = optimizer.Optimize(&tree.Node, nil)
   333  	require.NoError(t, err)
   334  
   335  	expected := &ast.BuiltinNode{
   336  		Name: "findLast",
   337  		Arguments: []ast.Node{
   338  			&ast.IdentifierNode{Value: "users"},
   339  			&ast.ClosureNode{
   340  				Node: &ast.BinaryNode{
   341  					Operator: "==",
   342  					Left: &ast.MemberNode{
   343  						Node:     &ast.PointerNode{},
   344  						Property: &ast.StringNode{Value: "Name"},
   345  					},
   346  					Right: &ast.StringNode{Value: "Bob"},
   347  				},
   348  			},
   349  		},
   350  		Throws: false,
   351  	}
   352  
   353  	assert.Equal(t, ast.Dump(expected), ast.Dump(tree.Node))
   354  }
   355  
   356  func TestOptimize_filter_map(t *testing.T) {
   357  	tree, err := parser.Parse(`map(filter(users, .Name == "Bob"), .Age)`)
   358  	require.NoError(t, err)
   359  
   360  	err = optimizer.Optimize(&tree.Node, nil)
   361  	require.NoError(t, err)
   362  
   363  	expected := &ast.BuiltinNode{
   364  		Name: "filter",
   365  		Arguments: []ast.Node{
   366  			&ast.IdentifierNode{Value: "users"},
   367  			&ast.ClosureNode{
   368  				Node: &ast.BinaryNode{
   369  					Operator: "==",
   370  					Left: &ast.MemberNode{
   371  						Node:     &ast.PointerNode{},
   372  						Property: &ast.StringNode{Value: "Name"},
   373  					},
   374  					Right: &ast.StringNode{Value: "Bob"},
   375  				},
   376  			},
   377  		},
   378  		Map: &ast.MemberNode{
   379  			Node:     &ast.PointerNode{},
   380  			Property: &ast.StringNode{Value: "Age"},
   381  		},
   382  	}
   383  
   384  	assert.Equal(t, ast.Dump(expected), ast.Dump(tree.Node))
   385  }
   386  
   387  func TestOptimize_filter_map_first(t *testing.T) {
   388  	tree, err := parser.Parse(`first(map(filter(users, .Name == "Bob"), .Age))`)
   389  	require.NoError(t, err)
   390  
   391  	err = optimizer.Optimize(&tree.Node, nil)
   392  	require.NoError(t, err)
   393  
   394  	expected := &ast.BuiltinNode{
   395  		Name: "find",
   396  		Arguments: []ast.Node{
   397  			&ast.IdentifierNode{Value: "users"},
   398  			&ast.ClosureNode{
   399  				Node: &ast.BinaryNode{
   400  					Operator: "==",
   401  					Left: &ast.MemberNode{
   402  						Node:     &ast.PointerNode{},
   403  						Property: &ast.StringNode{Value: "Name"},
   404  					},
   405  					Right: &ast.StringNode{Value: "Bob"},
   406  				},
   407  			},
   408  		},
   409  		Map: &ast.MemberNode{
   410  			Node:     &ast.PointerNode{},
   411  			Property: &ast.StringNode{Value: "Age"},
   412  		},
   413  		Throws: false,
   414  	}
   415  
   416  	assert.Equal(t, ast.Dump(expected), ast.Dump(tree.Node))
   417  }
   418  
   419  func TestOptimize_predicate_combination(t *testing.T) {
   420  	tests := []struct {
   421  		op     string
   422  		fn     string
   423  		wantOp string
   424  	}{
   425  		{"and", "all", "and"},
   426  		{"&&", "all", "&&"},
   427  		{"or", "any", "or"},
   428  		{"||", "any", "||"},
   429  		{"and", "none", "or"},
   430  		{"&&", "none", "||"},
   431  	}
   432  
   433  	for _, tt := range tests {
   434  		rule := fmt.Sprintf(`%s(users, .Age > 18 and .Name != "Bob") %s %s(users, .Age < 30)`, tt.fn, tt.op, tt.fn)
   435  		t.Run(rule, func(t *testing.T) {
   436  			tree, err := parser.Parse(rule)
   437  			require.NoError(t, err)
   438  
   439  			err = optimizer.Optimize(&tree.Node, nil)
   440  			require.NoError(t, err)
   441  
   442  			expected := &ast.BuiltinNode{
   443  				Name: tt.fn,
   444  				Arguments: []ast.Node{
   445  					&ast.IdentifierNode{Value: "users"},
   446  					&ast.ClosureNode{
   447  						Node: &ast.BinaryNode{
   448  							Operator: tt.wantOp,
   449  							Left: &ast.BinaryNode{
   450  								Operator: "and",
   451  								Left: &ast.BinaryNode{
   452  									Operator: ">",
   453  									Left: &ast.MemberNode{
   454  										Node:     &ast.PointerNode{},
   455  										Property: &ast.StringNode{Value: "Age"},
   456  									},
   457  									Right: &ast.IntegerNode{Value: 18},
   458  								},
   459  								Right: &ast.BinaryNode{
   460  									Operator: "!=",
   461  									Left: &ast.MemberNode{
   462  										Node:     &ast.PointerNode{},
   463  										Property: &ast.StringNode{Value: "Name"},
   464  									},
   465  									Right: &ast.StringNode{Value: "Bob"},
   466  								},
   467  							},
   468  							Right: &ast.BinaryNode{
   469  								Operator: "<",
   470  								Left: &ast.MemberNode{
   471  									Node:     &ast.PointerNode{},
   472  									Property: &ast.StringNode{Value: "Age"},
   473  								},
   474  								Right: &ast.IntegerNode{Value: 30},
   475  							},
   476  						},
   477  					},
   478  				},
   479  			}
   480  			assert.Equal(t, ast.Dump(expected), ast.Dump(tree.Node))
   481  		})
   482  	}
   483  }
   484  
   485  func TestOptimize_predicate_combination_nested(t *testing.T) {
   486  	tree, err := parser.Parse(`all(users, {all(.Friends, {.Age == 18 })}) && all(users, {all(.Friends, {.Name != "Bob" })})`)
   487  	require.NoError(t, err)
   488  
   489  	err = optimizer.Optimize(&tree.Node, nil)
   490  	require.NoError(t, err)
   491  
   492  	expected := &ast.BuiltinNode{
   493  		Name: "all",
   494  		Arguments: []ast.Node{
   495  			&ast.IdentifierNode{Value: "users"},
   496  			&ast.ClosureNode{
   497  				Node: &ast.BuiltinNode{
   498  					Name: "all",
   499  					Arguments: []ast.Node{
   500  						&ast.MemberNode{
   501  							Node:     &ast.PointerNode{},
   502  							Property: &ast.StringNode{Value: "Friends"},
   503  						},
   504  						&ast.ClosureNode{
   505  							Node: &ast.BinaryNode{
   506  								Operator: "&&",
   507  								Left: &ast.BinaryNode{
   508  									Operator: "==",
   509  									Left: &ast.MemberNode{
   510  										Node:     &ast.PointerNode{},
   511  										Property: &ast.StringNode{Value: "Age"},
   512  									},
   513  									Right: &ast.IntegerNode{Value: 18},
   514  								},
   515  								Right: &ast.BinaryNode{
   516  									Operator: "!=",
   517  									Left: &ast.MemberNode{
   518  										Node:     &ast.PointerNode{},
   519  										Property: &ast.StringNode{Value: "Name"},
   520  									},
   521  									Right: &ast.StringNode{Value: "Bob"},
   522  								},
   523  							},
   524  						},
   525  					},
   526  				},
   527  			},
   528  		},
   529  	}
   530  
   531  	assert.Equal(t, ast.Dump(expected), ast.Dump(tree.Node))
   532  }