vitess.io/vitess@v0.16.2/go/vt/vttablet/tabletserver/rules/rules_test.go (about)

     1  /*
     2  Copyright 2019 The Vitess Authors.
     3  
     4  Licensed under the Apache License, Version 2.0 (the "License");
     5  you may not use this file except in compliance with the License.
     6  You may obtain a copy of the License at
     7  
     8      http://www.apache.org/licenses/LICENSE-2.0
     9  
    10  Unless required by applicable law or agreed to in writing, software
    11  distributed under the License is distributed on an "AS IS" BASIS,
    12  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    13  See the License for the specific language governing permissions and
    14  limitations under the License.
    15  */
    16  
    17  package rules
    18  
    19  import (
    20  	"bytes"
    21  	"encoding/json"
    22  	"reflect"
    23  	"regexp"
    24  	"strings"
    25  	"testing"
    26  
    27  	"github.com/stretchr/testify/assert"
    28  
    29  	"vitess.io/vitess/go/sqltypes"
    30  	"vitess.io/vitess/go/vt/sqlparser"
    31  	"vitess.io/vitess/go/vt/vterrors"
    32  	"vitess.io/vitess/go/vt/vttablet/tabletserver/planbuilder"
    33  
    34  	querypb "vitess.io/vitess/go/vt/proto/query"
    35  	vtrpcpb "vitess.io/vitess/go/vt/proto/vtrpc"
    36  )
    37  
    38  func TestQueryRules(t *testing.T) {
    39  	qrs := New()
    40  	qr1 := NewQueryRule("rule 1", "r1", QRFail)
    41  	qr2 := NewQueryRule("rule 2", "r2", QRFail)
    42  	qrs.Add(qr1)
    43  	qrs.Add(qr2)
    44  
    45  	qrf := qrs.Find("r1")
    46  	if qrf != qr1 {
    47  		t.Errorf("want:\n%#v\ngot:\n%#v", qr1, qrf)
    48  	}
    49  
    50  	qrf = qrs.Find("r2")
    51  	if qrf != qr2 {
    52  		t.Errorf("want:\n%#v\ngot:\n%#v", qr2, qrf)
    53  	}
    54  
    55  	qrf = qrs.Find("unknown_rule")
    56  	if qrf != nil {
    57  		t.Fatalf("rule: unknown_rule does not exist, should get nil")
    58  	}
    59  
    60  	if qrs.rules[0] != qr1 {
    61  		t.Errorf("want:\n%#v\ngot:\n%#v", qr1, qrs.rules[0])
    62  	}
    63  
    64  	qrf = qrs.Delete("r1")
    65  	if qrf != qr1 {
    66  		t.Errorf("want:\n%#v\ngot:\n%#v", qr1, qrf)
    67  	}
    68  
    69  	if len(qrs.rules) != 1 {
    70  		t.Errorf("want 1, got %d", len(qrs.rules))
    71  	}
    72  
    73  	if qrs.rules[0] != qr2 {
    74  		t.Errorf("want:\n%#v\ngot:\n%#v", qr2, qrf)
    75  	}
    76  
    77  	qrf = qrs.Delete("unknown_rule")
    78  	if qrf != nil {
    79  		t.Fatalf("delete an unknown_rule, should return nil")
    80  	}
    81  }
    82  
    83  // TestCopy tests for deep copy
    84  func TestCopy(t *testing.T) {
    85  	qrs1 := New()
    86  	qr1 := NewQueryRule("rule 1", "r1", QRFail)
    87  	qr1.AddPlanCond(planbuilder.PlanSelect)
    88  	qr1.AddTableCond("aa")
    89  	qr1.AddBindVarCond("a", true, false, QRNoOp, nil)
    90  
    91  	qr2 := NewQueryRule("rule 2", "r2", QRFail)
    92  	qrs1.Add(qr1)
    93  	qrs1.Add(qr2)
    94  
    95  	qrs2 := qrs1.Copy()
    96  	if !reflect.DeepEqual(qrs2, qrs1) {
    97  		t.Errorf("qrs1: %+v, not equal to %+v", qrs2, qrs1)
    98  	}
    99  
   100  	qrs1 = New()
   101  	qrs2 = qrs1.Copy()
   102  	if !reflect.DeepEqual(qrs2, qrs1) {
   103  		t.Errorf("qrs1: %+v, not equal to %+v", qrs2, qrs1)
   104  	}
   105  }
   106  
   107  func TestFilterByPlan(t *testing.T) {
   108  	qrs := New()
   109  
   110  	qr1 := NewQueryRule("rule 1", "r1", QRFail)
   111  	qr1.SetIPCond("123")
   112  	qr1.SetQueryCond("select")
   113  	qr1.AddPlanCond(planbuilder.PlanSelect)
   114  	qr1.AddBindVarCond("a", true, false, QRNoOp, nil)
   115  
   116  	qr2 := NewQueryRule("rule 2", "r2", QRFail)
   117  	qr2.AddPlanCond(planbuilder.PlanSelect)
   118  	qr2.AddPlanCond(planbuilder.PlanSelect)
   119  	qr2.AddBindVarCond("a", true, false, QRNoOp, nil)
   120  
   121  	qr3 := NewQueryRule("rule 3", "r3", QRFail)
   122  	qr3.SetQueryCond("sele.*")
   123  	qr3.AddBindVarCond("a", true, false, QRNoOp, nil)
   124  
   125  	qr4 := NewQueryRule("rule 4", "r4", QRFail)
   126  	qr4.AddTableCond("b")
   127  	qr4.AddTableCond("c")
   128  
   129  	qrs.Add(qr1)
   130  	qrs.Add(qr2)
   131  	qrs.Add(qr3)
   132  	qrs.Add(qr4)
   133  
   134  	qrs1 := qrs.FilterByPlan("select", planbuilder.PlanSelect, "a")
   135  	want := compacted(`[{
   136  		"Description":"rule 1",
   137  		"Name":"r1",
   138  		"RequestIP":"123",
   139  		"BindVarConds":[{
   140  			"Name":"a",
   141  			"OnAbsent":true,
   142  			"Operator":""
   143  		}],
   144  		"Action":"FAIL"
   145  	},{
   146  		"Description":"rule 2",
   147  		"Name":"r2",
   148  		"BindVarConds":[{
   149  			"Name":"a",
   150  			"OnAbsent":true,
   151  			"Operator":""
   152  		}],
   153  		"Action":"FAIL"
   154  	},{
   155  		"Description":"rule 3",
   156  		"Name":"r3",
   157  		"BindVarConds":[{
   158  			"Name":"a",
   159  			"OnAbsent":true,
   160  			"Operator":""
   161  		}],
   162  		"Action":"FAIL"
   163  	}]`)
   164  	got := marshalled(qrs1)
   165  	if got != want {
   166  		t.Errorf("qrs1:\n%s, want\n%s", got, want)
   167  	}
   168  
   169  	qrs1 = qrs.FilterByPlan("insert", planbuilder.PlanSelect, "a")
   170  	want = compacted(`[{
   171  		"Description":"rule 2",
   172  		"Name":"r2",
   173  		"BindVarConds":[{
   174  			"Name":"a",
   175  			"OnAbsent":true,
   176  			"Operator":""
   177  		}],
   178  		"Action":"FAIL"
   179  	}]`)
   180  	got = marshalled(qrs1)
   181  	if got != want {
   182  		t.Errorf("qrs1:\n%s, want\n%s", got, want)
   183  	}
   184  	{
   185  		// test multiple tables:
   186  		qrs1 := qrs.FilterByPlan("insert", planbuilder.PlanSelect, "a", "other_table")
   187  		want := compacted(`[{
   188  			"Description":"rule 2",
   189  			"Name":"r2",
   190  			"BindVarConds":[{
   191  				"Name":"a",
   192  				"OnAbsent":true,
   193  				"Operator":""
   194  			}],
   195  			"Action":"FAIL"
   196  		}]`)
   197  		got = marshalled(qrs1)
   198  		if got != want {
   199  			t.Errorf("qrs1:\n%s, want\n%s", got, want)
   200  		}
   201  
   202  	}
   203  	{
   204  		// test multiple tables:
   205  		qrs1 := qrs.FilterByPlan("insert", planbuilder.PlanSelect, "other_table", "a")
   206  		want := compacted(`[{
   207  			"Description":"rule 2",
   208  			"Name":"r2",
   209  			"BindVarConds":[{
   210  				"Name":"a",
   211  				"OnAbsent":true,
   212  				"Operator":""
   213  			}],
   214  			"Action":"FAIL"
   215  		}]`)
   216  		got = marshalled(qrs1)
   217  		if got != want {
   218  			t.Errorf("qrs1:\n%s, want\n%s", got, want)
   219  		}
   220  	}
   221  
   222  	qrs1 = qrs.FilterByPlan("insert", planbuilder.PlanSelect, "a")
   223  	got = marshalled(qrs1)
   224  	if got != want {
   225  		t.Errorf("qrs1:\n%s, want\n%s", got, want)
   226  	}
   227  
   228  	qrs1 = qrs.FilterByPlan("select", planbuilder.PlanInsert, "a")
   229  	want = compacted(`[{
   230  		"Description":"rule 3",
   231  		"Name":"r3",
   232  		"BindVarConds":[{
   233  			"Name":"a",
   234  			"OnAbsent":true,
   235  			"Operator":""
   236  		}],
   237  		"Action":"FAIL"
   238  	}]`)
   239  	got = marshalled(qrs1)
   240  	if got != want {
   241  		t.Errorf("qrs1:\n%s, want\n%s", got, want)
   242  	}
   243  
   244  	qrs1 = qrs.FilterByPlan("sel", planbuilder.PlanInsert, "a")
   245  	if qrs1.rules != nil {
   246  		t.Errorf("want nil, got non-nil")
   247  	}
   248  
   249  	qrs1 = qrs.FilterByPlan("table", planbuilder.PlanInsert, "b")
   250  	want = compacted(`[{
   251  		"Description":"rule 4",
   252  		"Name":"r4",
   253  		"Action":"FAIL"
   254  	}]`)
   255  	got = marshalled(qrs1)
   256  	if got != want {
   257  		t.Errorf("qrs1:\n%s, want\n%s", got, want)
   258  	}
   259  
   260  	qr5 := NewQueryRule("rule 5", "r5", QRFail)
   261  	qrs.Add(qr5)
   262  
   263  	qrs1 = qrs.FilterByPlan("sel", planbuilder.PlanInsert, "a")
   264  	want = compacted(`[{
   265  		"Description":"rule 5",
   266  		"Name":"r5",
   267  		"Action":"FAIL"
   268  	}]`)
   269  	got = marshalled(qrs1)
   270  	if got != want {
   271  		t.Errorf("qrs1:\n%s, want\n%s", got, want)
   272  	}
   273  
   274  	qrsnil1 := New()
   275  	if qrsnil2 := qrsnil1.FilterByPlan("", planbuilder.PlanSelect, "a"); qrsnil2.rules != nil {
   276  		t.Errorf("want nil, got non-nil")
   277  	}
   278  }
   279  
   280  func TestQueryRule(t *testing.T) {
   281  	qr := NewQueryRule("rule 1", "r1", QRFail)
   282  	err := qr.SetIPCond("123")
   283  	if err != nil {
   284  		t.Errorf("unexpected: %v", err)
   285  	}
   286  	if !qr.requestIP.MatchString("123") {
   287  		t.Errorf("want match")
   288  	}
   289  	if qr.requestIP.MatchString("1234") {
   290  		t.Errorf("want no match")
   291  	}
   292  	if qr.requestIP.MatchString("12") {
   293  		t.Errorf("want no match")
   294  	}
   295  	err = qr.SetIPCond("[")
   296  	if err == nil {
   297  		t.Errorf("want error")
   298  	}
   299  
   300  	qr.AddPlanCond(planbuilder.PlanSelect)
   301  	qr.AddPlanCond(planbuilder.PlanInsert)
   302  
   303  	if qr.plans[0] != planbuilder.PlanSelect {
   304  		t.Errorf("want PASS_SELECT, got %s", qr.plans[0].String())
   305  	}
   306  	if qr.plans[1] != planbuilder.PlanInsert {
   307  		t.Errorf("want INSERT_PK, got %s", qr.plans[1].String())
   308  	}
   309  
   310  	qr.AddTableCond("a")
   311  	if qr.tableNames[0] != "a" {
   312  		t.Errorf("want a, got %s", qr.tableNames[0])
   313  	}
   314  }
   315  
   316  func TestBindVarStruct(t *testing.T) {
   317  	qr := NewQueryRule("rule 1", "r1", QRFail)
   318  
   319  	err := qr.AddBindVarCond("b", false, true, QRNoOp, nil)
   320  	if err != nil {
   321  		t.Errorf("unexpected: %v", err)
   322  	}
   323  	err = qr.AddBindVarCond("a", true, false, QRNoOp, nil)
   324  	if err != nil {
   325  		t.Errorf("unexpected: %v", err)
   326  	}
   327  	if qr.bindVarConds[1].name != "a" {
   328  		t.Errorf("want a, got %s", qr.bindVarConds[1].name)
   329  	}
   330  	if !qr.bindVarConds[1].onAbsent {
   331  		t.Errorf("want true, got false")
   332  	}
   333  	if qr.bindVarConds[1].onMismatch {
   334  		t.Errorf("want false, got true")
   335  	}
   336  	if qr.bindVarConds[1].op != QRNoOp {
   337  		t.Errorf("exepecting no-op, got %v", qr.bindVarConds[1])
   338  	}
   339  	if qr.bindVarConds[1].value != nil {
   340  		t.Errorf("want nil, got %#v", qr.bindVarConds[1].value)
   341  	}
   342  }
   343  
   344  type BVCreation struct {
   345  	name       string
   346  	onAbsent   bool
   347  	onMismatch bool
   348  	op         Operator
   349  	value      any
   350  	expecterr  bool
   351  }
   352  
// creationCases drives TestBVCreation. In these cases every ordered comparison
// operator accepts uint64, int64 and string values. MATCH/NOMATCH are expected
// to fail (expecterr == true) for a numeric value or for the pattern "[",
// which does not compile as a regular expression.
var creationCases = []BVCreation{
	// Ordered comparisons against a uint64 value.
	{"a", true, true, QREqual, uint64(1), false},
	{"a", true, true, QRNotEqual, uint64(1), false},
	{"a", true, true, QRLessThan, uint64(1), false},
	{"a", true, true, QRGreaterEqual, uint64(1), false},
	{"a", true, true, QRGreaterThan, uint64(1), false},
	{"a", true, true, QRLessEqual, uint64(1), false},

	// Ordered comparisons against an int64 value.
	{"a", true, true, QREqual, int64(1), false},
	{"a", true, true, QRNotEqual, int64(1), false},
	{"a", true, true, QRLessThan, int64(1), false},
	{"a", true, true, QRGreaterEqual, int64(1), false},
	{"a", true, true, QRGreaterThan, int64(1), false},
	{"a", true, true, QRLessEqual, int64(1), false},

	// Every operator, including regexp matching, against a string value.
	{"a", true, true, QREqual, "a", false},
	{"a", true, true, QRNotEqual, "a", false},
	{"a", true, true, QRLessThan, "a", false},
	{"a", true, true, QRGreaterEqual, "a", false},
	{"a", true, true, QRGreaterThan, "a", false},
	{"a", true, true, QRLessEqual, "a", false},
	{"a", true, true, QRMatch, "a", false},
	{"a", true, true, QRNoMatch, "a", false},

	// Invalid: regexp operators with a non-string value or a bad pattern.
	{"a", true, true, QRMatch, int64(1), true},
	{"a", true, true, QRNoMatch, int64(1), true},
	{"a", true, true, QRMatch, "[", true},
	{"a", true, true, QRNoMatch, "[", true},
}
   382  
   383  func TestBVCreation(t *testing.T) {
   384  	qr := NewQueryRule("rule 1", "r1", QRFail)
   385  	for i, tcase := range creationCases {
   386  		err := qr.AddBindVarCond(tcase.name, tcase.onAbsent, tcase.onMismatch, tcase.op, tcase.value)
   387  		haserr := (err != nil)
   388  		if haserr != tcase.expecterr {
   389  			t.Errorf("test %d: got %v for %#v", i, haserr, tcase)
   390  		}
   391  	}
   392  }
   393  
// BindVarTestCase is one bvMatch scenario for TestBVConditions.
type BindVarTestCase struct {
	bvc      BindVarCond           // condition under test
	bvval    *querypb.BindVariable // value TestBVConditions binds to variable "a"
	expected bool                  // expected result of bvMatch
}
   399  
// bvtestcases drives TestBVConditions. For each case the test stores bvval
// under bind variable "a" and checks bvMatch(bvc, bindVars) == expected.
var bvtestcases = []BindVarTestCase{
	// QRNoOp: "b" is never populated, so the first two cases see an absent
	// variable and the last two a present one.
	{BindVarCond{"b", true, true, QRNoOp, nil}, sqltypes.Int64BindVariable(1), true},
	{BindVarCond{"b", false, true, QRNoOp, nil}, sqltypes.Int64BindVariable(1), false},
	{BindVarCond{"a", true, true, QRNoOp, nil}, sqltypes.Int64BindVariable(1), false},
	{BindVarCond{"a", false, true, QRNoOp, nil}, sqltypes.Int64BindVariable(1), true},

	// Comparisons against a uint64 condition value.
	{BindVarCond{"a", true, true, QREqual, bvcuint64(10)}, sqltypes.Int64BindVariable(1), false},
	{BindVarCond{"a", true, true, QREqual, bvcuint64(10)}, sqltypes.Int64BindVariable(10), true},
	{BindVarCond{"a", true, true, QREqual, bvcuint64(10)}, sqltypes.Uint64BindVariable(1), false},
	{BindVarCond{"a", true, true, QREqual, bvcuint64(10)}, sqltypes.Uint64BindVariable(10), true},
	{BindVarCond{"a", true, true, QREqual, bvcuint64(10)}, sqltypes.StringBindVariable("abc"), false},

	{BindVarCond{"a", true, true, QRNotEqual, bvcuint64(10)}, sqltypes.Int64BindVariable(1), true},
	{BindVarCond{"a", true, true, QRNotEqual, bvcuint64(10)}, sqltypes.Int64BindVariable(10), false},
	{BindVarCond{"a", true, true, QRNotEqual, bvcuint64(10)}, sqltypes.Int64BindVariable(11), true},
	{BindVarCond{"a", true, true, QRNotEqual, bvcuint64(10)}, sqltypes.Int64BindVariable(-1), true},

	{BindVarCond{"a", true, true, QRLessThan, bvcuint64(10)}, sqltypes.Int64BindVariable(1), true},
	{BindVarCond{"a", true, true, QRLessThan, bvcuint64(10)}, sqltypes.Int64BindVariable(10), false},
	{BindVarCond{"a", true, true, QRLessThan, bvcuint64(10)}, sqltypes.Int64BindVariable(11), false},
	{BindVarCond{"a", true, true, QRLessThan, bvcuint64(10)}, sqltypes.Int64BindVariable(-1), true},

	{BindVarCond{"a", true, true, QRGreaterEqual, bvcuint64(10)}, sqltypes.Int64BindVariable(1), false},
	{BindVarCond{"a", true, true, QRGreaterEqual, bvcuint64(10)}, sqltypes.Int64BindVariable(10), true},
	{BindVarCond{"a", true, true, QRGreaterEqual, bvcuint64(10)}, sqltypes.Int64BindVariable(11), true},
	{BindVarCond{"a", true, true, QRGreaterEqual, bvcuint64(10)}, sqltypes.Int64BindVariable(-1), false},

	{BindVarCond{"a", true, true, QRGreaterThan, bvcuint64(10)}, sqltypes.Int64BindVariable(1), false},
	{BindVarCond{"a", true, true, QRGreaterThan, bvcuint64(10)}, sqltypes.Int64BindVariable(10), false},
	{BindVarCond{"a", true, true, QRGreaterThan, bvcuint64(10)}, sqltypes.Int64BindVariable(11), true},
	{BindVarCond{"a", true, true, QRGreaterThan, bvcuint64(10)}, sqltypes.Int64BindVariable(-1), false},

	{BindVarCond{"a", true, true, QRLessEqual, bvcuint64(10)}, sqltypes.Int64BindVariable(1), true},
	{BindVarCond{"a", true, true, QRLessEqual, bvcuint64(10)}, sqltypes.Int64BindVariable(10), true},
	{BindVarCond{"a", true, true, QRLessEqual, bvcuint64(10)}, sqltypes.Int64BindVariable(11), false},
	{BindVarCond{"a", true, true, QRLessEqual, bvcuint64(10)}, sqltypes.Int64BindVariable(-1), true},

	// Comparisons against an int64 condition value. 0xFFFFFFFFFFFFFFFF is
	// above the int64 range and checks the uint64-vs-int64 boundary.
	{BindVarCond{"a", true, true, QREqual, bvcint64(10)}, sqltypes.Int64BindVariable(1), false},
	{BindVarCond{"a", true, true, QREqual, bvcint64(10)}, sqltypes.Int64BindVariable(10), true},
	{BindVarCond{"a", true, true, QREqual, bvcint64(10)}, sqltypes.Uint64BindVariable(1), false},
	{BindVarCond{"a", true, true, QREqual, bvcint64(10)}, sqltypes.Uint64BindVariable(0xFFFFFFFFFFFFFFFF), false},
	{BindVarCond{"a", true, true, QREqual, bvcint64(10)}, sqltypes.Uint64BindVariable(10), true},
	{BindVarCond{"a", true, true, QREqual, bvcint64(10)}, sqltypes.StringBindVariable("abc"), false},

	{BindVarCond{"a", true, true, QRNotEqual, bvcint64(10)}, sqltypes.Int64BindVariable(1), true},
	{BindVarCond{"a", true, true, QRNotEqual, bvcint64(10)}, sqltypes.Int64BindVariable(10), false},
	{BindVarCond{"a", true, true, QRNotEqual, bvcint64(10)}, sqltypes.Int64BindVariable(11), true},
	{BindVarCond{"a", true, true, QRNotEqual, bvcint64(10)}, sqltypes.Uint64BindVariable(0xFFFFFFFFFFFFFFFF), true},

	{BindVarCond{"a", true, true, QRLessThan, bvcint64(10)}, sqltypes.Int64BindVariable(1), true},
	{BindVarCond{"a", true, true, QRLessThan, bvcint64(10)}, sqltypes.Int64BindVariable(10), false},
	{BindVarCond{"a", true, true, QRLessThan, bvcint64(10)}, sqltypes.Int64BindVariable(11), false},
	{BindVarCond{"a", true, true, QRLessThan, bvcint64(10)}, sqltypes.Uint64BindVariable(0xFFFFFFFFFFFFFFFF), false},

	{BindVarCond{"a", true, true, QRGreaterEqual, bvcint64(10)}, sqltypes.Int64BindVariable(1), false},
	{BindVarCond{"a", true, true, QRGreaterEqual, bvcint64(10)}, sqltypes.Int64BindVariable(10), true},
	{BindVarCond{"a", true, true, QRGreaterEqual, bvcint64(10)}, sqltypes.Int64BindVariable(11), true},
	{BindVarCond{"a", true, true, QRGreaterEqual, bvcint64(10)}, sqltypes.Uint64BindVariable(0xFFFFFFFFFFFFFFFF), true},

	{BindVarCond{"a", true, true, QRGreaterThan, bvcint64(10)}, sqltypes.Int64BindVariable(1), false},
	{BindVarCond{"a", true, true, QRGreaterThan, bvcint64(10)}, sqltypes.Int64BindVariable(10), false},
	{BindVarCond{"a", true, true, QRGreaterThan, bvcint64(10)}, sqltypes.Int64BindVariable(11), true},
	{BindVarCond{"a", true, true, QRGreaterThan, bvcint64(10)}, sqltypes.Uint64BindVariable(0xFFFFFFFFFFFFFFFF), true},

	{BindVarCond{"a", true, true, QRLessEqual, bvcint64(10)}, sqltypes.Int64BindVariable(1), true},
	{BindVarCond{"a", true, true, QRLessEqual, bvcint64(10)}, sqltypes.Int64BindVariable(10), true},
	{BindVarCond{"a", true, true, QRLessEqual, bvcint64(10)}, sqltypes.Int64BindVariable(11), false},
	{BindVarCond{"a", true, true, QRLessEqual, bvcint64(10)}, sqltypes.Uint64BindVariable(0xFFFFFFFFFFFFFFFF), false},

	// Comparisons against a string condition value. String and bytes bind
	// variables with the same content compare equal.
	{BindVarCond{"a", true, true, QREqual, bvcstring("b")}, sqltypes.StringBindVariable("a"), false},
	{BindVarCond{"a", true, true, QREqual, bvcstring("b")}, sqltypes.StringBindVariable("b"), true},
	{BindVarCond{"a", true, true, QREqual, bvcstring("b")}, sqltypes.StringBindVariable("c"), false},
	{BindVarCond{"a", true, true, QREqual, bvcstring("b")}, sqltypes.BytesBindVariable([]byte("a")), false},
	{BindVarCond{"a", true, true, QREqual, bvcstring("b")}, sqltypes.BytesBindVariable([]byte("b")), true},
	{BindVarCond{"a", true, true, QREqual, bvcstring("b")}, sqltypes.BytesBindVariable([]byte("c")), false},
	{BindVarCond{"a", true, true, QREqual, bvcstring("b")}, sqltypes.Int64BindVariable(1), false},

	{BindVarCond{"a", true, true, QRNotEqual, bvcstring("b")}, sqltypes.StringBindVariable("a"), true},
	{BindVarCond{"a", true, true, QRNotEqual, bvcstring("b")}, sqltypes.StringBindVariable("b"), false},
	{BindVarCond{"a", true, true, QRNotEqual, bvcstring("b")}, sqltypes.StringBindVariable("c"), true},

	{BindVarCond{"a", true, true, QRLessThan, bvcstring("b")}, sqltypes.StringBindVariable("a"), true},
	{BindVarCond{"a", true, true, QRLessThan, bvcstring("b")}, sqltypes.StringBindVariable("b"), false},
	{BindVarCond{"a", true, true, QRLessThan, bvcstring("b")}, sqltypes.StringBindVariable("c"), false},

	{BindVarCond{"a", true, true, QRGreaterEqual, bvcstring("b")}, sqltypes.StringBindVariable("a"), false},
	{BindVarCond{"a", true, true, QRGreaterEqual, bvcstring("b")}, sqltypes.StringBindVariable("b"), true},
	{BindVarCond{"a", true, true, QRGreaterEqual, bvcstring("b")}, sqltypes.StringBindVariable("c"), true},

	{BindVarCond{"a", true, true, QRGreaterThan, bvcstring("b")}, sqltypes.StringBindVariable("a"), false},
	{BindVarCond{"a", true, true, QRGreaterThan, bvcstring("b")}, sqltypes.StringBindVariable("b"), false},
	{BindVarCond{"a", true, true, QRGreaterThan, bvcstring("b")}, sqltypes.StringBindVariable("c"), true},

	{BindVarCond{"a", true, true, QRLessEqual, bvcstring("b")}, sqltypes.StringBindVariable("a"), true},
	{BindVarCond{"a", true, true, QRLessEqual, bvcstring("b")}, sqltypes.StringBindVariable("b"), true},
	{BindVarCond{"a", true, true, QRLessEqual, bvcstring("b")}, sqltypes.StringBindVariable("c"), false},

	// Regexp conditions. A non-string bind variable is treated as not
	// matching: MATCH yields false and NOMATCH yields true.
	{BindVarCond{"a", true, true, QRMatch, makere("a.*")}, sqltypes.StringBindVariable("c"), false},
	{BindVarCond{"a", true, true, QRMatch, makere("a.*")}, sqltypes.StringBindVariable("a"), true},
	{BindVarCond{"a", true, true, QRMatch, makere("a.*")}, sqltypes.Int64BindVariable(1), false},

	{BindVarCond{"a", true, true, QRNoMatch, makere("a.*")}, sqltypes.StringBindVariable("c"), true},
	{BindVarCond{"a", true, true, QRNoMatch, makere("a.*")}, sqltypes.StringBindVariable("a"), false},
	{BindVarCond{"a", true, true, QRNoMatch, makere("a.*")}, sqltypes.Int64BindVariable(1), true},
}
   505  
   506  func makere(s string) bvcre {
   507  	re, _ := regexp.Compile(s)
   508  	return bvcre{re}
   509  }
   510  
   511  func TestBVConditions(t *testing.T) {
   512  	bv := make(map[string]*querypb.BindVariable)
   513  	for _, tcase := range bvtestcases {
   514  		bv["a"] = tcase.bvval
   515  		if bvMatch(tcase.bvc, bv) != tcase.expected {
   516  			t.Errorf("bvmatch(%+v, %v): %v, want %v", tcase.bvc, tcase.bvval, !tcase.expected, tcase.expected)
   517  		}
   518  	}
   519  }
   520  
   521  func TestAction(t *testing.T) {
   522  	qrs := New()
   523  
   524  	qr1 := NewQueryRule("rule 1", "r1", QRFail)
   525  	qr1.SetIPCond("123")
   526  
   527  	qr2 := NewQueryRule("rule 2", "r2", QRFailRetry)
   528  	qr2.SetUserCond("user")
   529  
   530  	qr3 := NewQueryRule("rule 3", "r3", QRFail)
   531  	qr3.AddBindVarCond("a", true, true, QREqual, uint64(1))
   532  
   533  	qrs.Add(qr1)
   534  	qrs.Add(qr2)
   535  	qrs.Add(qr3)
   536  
   537  	bv := make(map[string]*querypb.BindVariable)
   538  	bv["a"] = sqltypes.Uint64BindVariable(0)
   539  
   540  	mc := sqlparser.MarginComments{
   541  		Leading:  "some comments leading the query",
   542  		Trailing: "other trailing comments",
   543  	}
   544  
   545  	action, cancelCtx, desc := qrs.GetAction("123", "user1", bv, mc)
   546  	assert.Equalf(t, action, QRFail, "expected fail, got %v", action)
   547  	assert.Equalf(t, desc, "rule 1", "want rule 1, got %s", desc)
   548  	assert.Nil(t, cancelCtx)
   549  
   550  	action, cancelCtx, desc = qrs.GetAction("1234", "user", bv, mc)
   551  	assert.Equalf(t, action, QRFailRetry, "want fail_retry, got: %s", action)
   552  	assert.Equalf(t, desc, "rule 2", "want rule 2, got %s", desc)
   553  	assert.Nil(t, cancelCtx)
   554  
   555  	action, _, _ = qrs.GetAction("1234", "user1", bv, mc)
   556  	assert.Equalf(t, action, QRContinue, "want continue, got %s", action)
   557  
   558  	bv["a"] = sqltypes.Uint64BindVariable(1)
   559  	action, _, desc = qrs.GetAction("1234", "user1", bv, mc)
   560  	assert.Equalf(t, action, QRFail, "want fail, got %s", action)
   561  	assert.Equalf(t, desc, "rule 3", "want rule 3, got %s", desc)
   562  
   563  	// reset bound variable 'a' to 0 so it doesn't match rule 3
   564  	bv["a"] = sqltypes.Uint64BindVariable(0)
   565  
   566  	qr4 := NewQueryRule("rule 4", "r4", QRFail)
   567  	qr4.SetTrailingCommentCond(".*trailing.*")
   568  
   569  	newQrs := qrs.Copy()
   570  	newQrs.Add(qr4)
   571  
   572  	action, _, desc = newQrs.GetAction("1234", "user1", bv, mc)
   573  	assert.Equalf(t, action, QRFail, "want fail, got %s", action)
   574  	assert.Equalf(t, desc, "rule 4", "want rule 4, got %s", desc)
   575  
   576  	qr5 := NewQueryRule("rule 5", "r4", QRFail)
   577  	qr5.SetLeadingCommentCond(".*leading.*")
   578  
   579  	newQrs = qrs.Copy()
   580  	newQrs.Add(qr5)
   581  	action, _, desc = newQrs.GetAction("1234", "user1", bv, mc)
   582  	assert.Equalf(t, action, QRFail, "want fail, got %s", action)
   583  	assert.Equalf(t, desc, "rule 5", "want rule 5, got %s", desc)
   584  }
   585  
   586  func TestImport(t *testing.T) {
   587  	var qrs = New()
   588  	jsondata := `[{
   589  		"Description": "desc1",
   590  		"Name": "name1",
   591  		"RequestIP": "123.123.123",
   592  		"User": "user",
   593  		"Query": "query",
   594  		"Plans": ["Select", "Insert"],
   595  		"TableNames":["a", "b"],
   596  		"BindVarConds": [{
   597  			"Name": "bvname1",
   598  			"OnAbsent": true,
   599  			"Operator": ""
   600  		},{
   601  			"Name": "bvname2",
   602  			"OnAbsent": true,
   603  			"OnMismatch": true,
   604  			"Operator": "==",
   605  			"Value": 123
   606  		}],
   607  		"Action": "FAIL_RETRY"
   608  	},{
   609  		"Description": "desc2",
   610  		"Name": "name2",
   611  		"Action": "FAIL"
   612  	}]`
   613  	err := qrs.UnmarshalJSON([]byte(jsondata))
   614  	if err != nil {
   615  		t.Error(err)
   616  		return
   617  	}
   618  	got := marshalled(qrs)
   619  	want := compacted(jsondata)
   620  	if got != want {
   621  		t.Errorf("qrs:\n%s, want\n%s", got, want)
   622  	}
   623  }
   624  
// ValidJSONCase is one parseable rule set together with the operator and
// value kind expected for its single bind variable condition.
type ValidJSONCase struct {
	input string   // JSON passed to UnmarshalJSON
	op    Operator // expected operator of the parsed condition
	typ   int      // expected value kind: UINT, INT, STR or REGEXP
}
   630  
// Value kinds used in ValidJSONCase.typ. TestValidJSON switches on these to
// decide which concrete condition value type to expect.
// NOTE(review): ALL_CAPS names are not idiomatic Go; renaming them requires
// updating validjsons and TestValidJSON together.
const (
	UINT = iota
	INT
	STR
	REGEXP
)
   637  
// validjsons holds well-formed rule sets, each with a single bind variable
// condition. TestValidJSON expects the following value kinds:
//   - 18446744073709551615 (the uint64 maximum) parses as UINT;
//   - -123 parses as INT;
//   - "123" parses as STR under ordered operators;
//   - "123" parses as REGEXP under MATCH/NOMATCH.
var validjsons = []ValidJSONCase{
	// Unsigned values: only representable as uint64.
	{`[{"BindVarConds": [{"Name": "bvname1", "OnAbsent": true, "OnMismatch": true, "Operator": "==", "Value": 18446744073709551615}]}]`, QREqual, UINT},
	{`[{"BindVarConds": [{"Name": "bvname1", "OnAbsent": true, "OnMismatch": true, "Operator": "!=", "Value": 18446744073709551615}]}]`, QRNotEqual, UINT},
	{`[{"BindVarConds": [{"Name": "bvname1", "OnAbsent": true, "OnMismatch": true, "Operator": "<", "Value": 18446744073709551615}]}]`, QRLessThan, UINT},
	{`[{"BindVarConds": [{"Name": "bvname1", "OnAbsent": true, "OnMismatch": true, "Operator": ">=", "Value": 18446744073709551615}]}]`, QRGreaterEqual, UINT},
	{`[{"BindVarConds": [{"Name": "bvname1", "OnAbsent": true, "OnMismatch": true, "Operator": ">", "Value": 18446744073709551615}]}]`, QRGreaterThan, UINT},
	{`[{"BindVarConds": [{"Name": "bvname1", "OnAbsent": true, "OnMismatch": true, "Operator": "<=", "Value": 18446744073709551615}]}]`, QRLessEqual, UINT},

	// Signed values.
	{`[{"BindVarConds": [{"Name": "bvname1", "OnAbsent": true, "OnMismatch": true, "Operator": "==", "Value": -123}]}]`, QREqual, INT},
	{`[{"BindVarConds": [{"Name": "bvname1", "OnAbsent": true, "OnMismatch": true, "Operator": "!=", "Value": -123}]}]`, QRNotEqual, INT},
	{`[{"BindVarConds": [{"Name": "bvname1", "OnAbsent": true, "OnMismatch": true, "Operator": "<", "Value": -123}]}]`, QRLessThan, INT},
	{`[{"BindVarConds": [{"Name": "bvname1", "OnAbsent": true, "OnMismatch": true, "Operator": ">=", "Value": -123}]}]`, QRGreaterEqual, INT},
	{`[{"BindVarConds": [{"Name": "bvname1", "OnAbsent": true, "OnMismatch": true, "Operator": ">", "Value": -123}]}]`, QRGreaterThan, INT},
	{`[{"BindVarConds": [{"Name": "bvname1", "OnAbsent": true, "OnMismatch": true, "Operator": "<=", "Value": -123}]}]`, QRLessEqual, INT},

	// String values under ordered operators.
	{`[{"BindVarConds": [{"Name": "bvname1", "OnAbsent": true, "OnMismatch": true, "Operator": "==", "Value": "123"}]}]`, QREqual, STR},
	{`[{"BindVarConds": [{"Name": "bvname1", "OnAbsent": true, "OnMismatch": true, "Operator": "!=", "Value": "123"}]}]`, QRNotEqual, STR},
	{`[{"BindVarConds": [{"Name": "bvname1", "OnAbsent": true, "OnMismatch": true, "Operator": "<", "Value": "123"}]}]`, QRLessThan, STR},
	{`[{"BindVarConds": [{"Name": "bvname1", "OnAbsent": true, "OnMismatch": true, "Operator": ">=", "Value": "123"}]}]`, QRGreaterEqual, STR},
	{`[{"BindVarConds": [{"Name": "bvname1", "OnAbsent": true, "OnMismatch": true, "Operator": ">", "Value": "123"}]}]`, QRGreaterThan, STR},
	{`[{"BindVarConds": [{"Name": "bvname1", "OnAbsent": true, "OnMismatch": true, "Operator": "<=", "Value": "123"}]}]`, QRLessEqual, STR},

	// String values under regexp operators.
	{`[{"BindVarConds": [{"Name": "bvname1", "OnAbsent": true, "OnMismatch": true, "Operator": "MATCH", "Value": "123"}]}]`, QRMatch, REGEXP},
	{`[{"BindVarConds": [{"Name": "bvname1", "OnAbsent": true, "OnMismatch": true, "Operator": "NOMATCH", "Value": "123"}]}]`, QRNoMatch, REGEXP},
}
   663  
   664  func TestValidJSON(t *testing.T) {
   665  	for i, tcase := range validjsons {
   666  		qrs := New()
   667  		err := qrs.UnmarshalJSON([]byte(tcase.input))
   668  		if err != nil {
   669  			t.Fatalf("Unexpected error for case %d: %v", i, err)
   670  		}
   671  		bvc := qrs.rules[0].bindVarConds[0]
   672  		if bvc.op != tcase.op {
   673  			t.Errorf("want %v, got %v", tcase.op, bvc.op)
   674  		}
   675  		switch tcase.typ {
   676  		case UINT:
   677  			if bvc.value.(bvcuint64) != bvcuint64(18446744073709551615) {
   678  				t.Errorf("want %v, got %v", uint64(18446744073709551615), bvc.value.(bvcuint64))
   679  			}
   680  		case INT:
   681  			if bvc.value.(bvcint64) != -123 {
   682  				t.Errorf("want %v, got %v", -123, bvc.value.(bvcint64))
   683  			}
   684  		case STR:
   685  			if bvc.value.(bvcstring) != "123" {
   686  				t.Errorf("want %v, got %v", "123", bvc.value.(bvcint64))
   687  			}
   688  		case REGEXP:
   689  			if bvc.value.(bvcre).re == nil {
   690  				t.Errorf("want non-nil")
   691  			}
   692  		}
   693  	}
   694  }
   695  
   696  type InvalidJSONCase struct {
   697  	input, err string
   698  }
   699  
// invalidjsons lists rule-set documents that UnmarshalJSON must reject,
// each with the exact error message TestInvalidJSON expects.
var invalidjsons = []InvalidJSONCase{
	// Top-level fields with the wrong JSON type.
	{`[{"Name": 1 }]`, "want string for Name"},
	{`[{"Description": 1 }]`, "want string for Description"},
	{`[{"RequestIP": 1 }]`, "want string for RequestIP"},
	{`[{"User": 1 }]`, "want string for User"},
	{`[{"Query": 1 }]`, "want string for Query"},
	{`[{"Plans": 1 }]`, "want list for Plans"},
	{`[{"TableNames": 1 }]`, "want list for TableNames"},
	{`[{"BindVarConds": 1 }]`, "want list for BindVarConds"},
	// Patterns that are not valid regular expressions.
	{`[{"RequestIP": "[" }]`, "could not set IP condition: ["},
	{`[{"User": "[" }]`, "could not set User condition: ["},
	{`[{"Query": "[" }]`, "could not set Query condition: ["},
	// Bad list elements.
	{`[{"Plans": [1] }]`, "want string for Plans"},
	{`[{"Plans": ["invalid"] }]`, "invalid plan name: invalid"},
	{`[{"TableNames": [1] }]`, "want string for TableNames"},
	// Malformed bind variable conditions.
	{`[{"BindVarConds": [1] }]`, "want json object for bind var conditions"},
	{`[{"BindVarConds": [{}] }]`, "Name missing in BindVarConds"},
	{`[{"BindVarConds": [{"Name": 1}] }]`, "want string for Name in BindVarConds"},
	{`[{"BindVarConds": [{"Name": "a"}] }]`, "OnAbsent missing in BindVarConds"},
	{`[{"BindVarConds": [{"Name": "a", "OnAbsent": 1}] }]`, "want bool for OnAbsent"},
	{`[{"BindVarConds": [{"Name": "a", "OnAbsent": true}]}]`, "Operator missing in BindVarConds"},
	{`[{"BindVarConds": [{"Name": "a", "OnAbsent": true, "Operator": "a"}]}]`, "invalid Operator a"},
	{`[{"BindVarConds": [{"Name": "a", "OnAbsent": true, "Operator": "=="}]}]`, "Value missing in BindVarConds"},
	{`[{"BindVarConds": [{"Name": "a", "OnAbsent": true, "Operator": "==", "Value": 1.2}]}]`, "want int64/uint64: 1.2"},
	{`[{"BindVarConds": [{"Name": "a", "OnAbsent": true, "Operator": "==", "Value": {}}]}]`, "want string or number: map[]"},
	{`[{"BindVarConds": [{"Name": "a", "OnAbsent": true, "Operator": "MATCH", "Value": 1}]}]`, "want string: 1"},
	{`[{"BindVarConds": [{"Name": "a", "OnAbsent": true, "Operator": "NOMATCH", "Value": 1}]}]`, "want string: 1"},
	{`[{"BindVarConds": [{"Name": "a", "OnAbsent": true, "Operator": 123, "Value": "1"}]}]`, "want string for Operator"},
	{`[{"Unknown": [{"Name": "a"}]}]`, "unrecognized tag Unknown"},
	{`[{"BindVarConds": [{"Name": "a", "OnAbsent": true, "Operator": "<=", "Value": "1"}]}]`, "OnMismatch missing in BindVarConds"},
	{`[{"BindVarConds": [{"Name": "a", "OnAbsent": true, "OnMismatch": true, "Operator": "MATCH", "Value": "["}]}]`, "processing [: error parsing regexp: missing closing ]: `[$`"},
	{`[{"BindVarConds": [{"Name": "a", "OnAbsent": true, "OnMismatch": true, "Operator": "NOMATCH", "Value": "["}]}]`, "processing [: error parsing regexp: missing closing ]: `[$`"},
	// Bad action values.
	{`[{"Action": 1 }]`, "want string for Action"},
	{`[{"Action": "foo" }]`, "invalid Action foo"},
}
   735  
   736  func TestInvalidJSON(t *testing.T) {
   737  	for _, tcase := range invalidjsons {
   738  		qrs := New()
   739  		err := qrs.UnmarshalJSON([]byte(tcase.input))
   740  		if err == nil {
   741  			t.Errorf("want error for case %q", tcase.input)
   742  			continue
   743  		}
   744  		recvd := strings.Replace(err.Error(), "fatal: ", "", 1)
   745  		if recvd != tcase.err {
   746  			t.Errorf("invalid json: %s, want '%v', got '%v'", tcase.input, tcase.err, recvd)
   747  		}
   748  	}
   749  	qrs := New()
   750  	err := qrs.UnmarshalJSON([]byte(`{`))
   751  	if code := vterrors.Code(err); code != vtrpcpb.Code_INVALID_ARGUMENT {
   752  		t.Errorf("qrs.UnmarshalJSON: %v, want %v", code, vtrpcpb.Code_INVALID_ARGUMENT)
   753  	}
   754  }
   755  
   756  func TestBuildQueryRuleActionFail(t *testing.T) {
   757  	var ruleInfo map[string]any
   758  	err := json.Unmarshal([]byte(`{"Action": "FAIL" }`), &ruleInfo)
   759  	if err != nil {
   760  		t.Fatalf("failed to unmarshal json, got error: %v", err)
   761  	}
   762  	qr, err := BuildQueryRule(ruleInfo)
   763  	if err != nil {
   764  		t.Fatalf("build query rule should succeed")
   765  	}
   766  	if qr.act != QRFail {
   767  		t.Fatalf("action should fail")
   768  	}
   769  }
   770  
   771  func TestBadAddBindVarCond(t *testing.T) {
   772  	qr1 := NewQueryRule("rule 1", "r1", QRFail)
   773  	err := qr1.AddBindVarCond("a", true, false, QRMatch, uint64(1))
   774  	if err == nil {
   775  		t.Fatalf("invalid op: QRMatch for value type: uint64")
   776  	}
   777  }
   778  
   779  func TestOpNames(t *testing.T) {
   780  	want := []string{
   781  		"",
   782  		"==",
   783  		"!=",
   784  		"<",
   785  		">=",
   786  		">",
   787  		"<=",
   788  		"MATCH",
   789  		"NOMATCH",
   790  	}
   791  	if !reflect.DeepEqual(opnames, want) {
   792  		t.Errorf("opnames: \n%v, want \n%v", opnames, want)
   793  	}
   794  }
   795  
   796  func compacted(in string) string {
   797  	dst := bytes.NewBuffer(nil)
   798  	err := json.Compact(dst, []byte(in))
   799  	if err != nil {
   800  		panic(err)
   801  	}
   802  	return dst.String()
   803  }
   804  
   805  func marshalled(in any) string {
   806  	b, err := json.Marshal(in)
   807  	if err != nil {
   808  		panic(err)
   809  	}
   810  	return string(b)
   811  }