vitess.io/vitess@v0.16.2/go/vt/vttablet/tabletserver/rules/rules_test.go (about) 1 /* 2 Copyright 2019 The Vitess Authors. 3 4 Licensed under the Apache License, Version 2.0 (the "License"); 5 you may not use this file except in compliance with the License. 6 You may obtain a copy of the License at 7 8 http://www.apache.org/licenses/LICENSE-2.0 9 10 Unless required by applicable law or agreed to in writing, software 11 distributed under the License is distributed on an "AS IS" BASIS, 12 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 See the License for the specific language governing permissions and 14 limitations under the License. 15 */ 16 17 package rules 18 19 import ( 20 "bytes" 21 "encoding/json" 22 "reflect" 23 "regexp" 24 "strings" 25 "testing" 26 27 "github.com/stretchr/testify/assert" 28 29 "vitess.io/vitess/go/sqltypes" 30 "vitess.io/vitess/go/vt/sqlparser" 31 "vitess.io/vitess/go/vt/vterrors" 32 "vitess.io/vitess/go/vt/vttablet/tabletserver/planbuilder" 33 34 querypb "vitess.io/vitess/go/vt/proto/query" 35 vtrpcpb "vitess.io/vitess/go/vt/proto/vtrpc" 36 ) 37 38 func TestQueryRules(t *testing.T) { 39 qrs := New() 40 qr1 := NewQueryRule("rule 1", "r1", QRFail) 41 qr2 := NewQueryRule("rule 2", "r2", QRFail) 42 qrs.Add(qr1) 43 qrs.Add(qr2) 44 45 qrf := qrs.Find("r1") 46 if qrf != qr1 { 47 t.Errorf("want:\n%#v\ngot:\n%#v", qr1, qrf) 48 } 49 50 qrf = qrs.Find("r2") 51 if qrf != qr2 { 52 t.Errorf("want:\n%#v\ngot:\n%#v", qr2, qrf) 53 } 54 55 qrf = qrs.Find("unknown_rule") 56 if qrf != nil { 57 t.Fatalf("rule: unknown_rule does not exist, should get nil") 58 } 59 60 if qrs.rules[0] != qr1 { 61 t.Errorf("want:\n%#v\ngot:\n%#v", qr1, qrs.rules[0]) 62 } 63 64 qrf = qrs.Delete("r1") 65 if qrf != qr1 { 66 t.Errorf("want:\n%#v\ngot:\n%#v", qr1, qrf) 67 } 68 69 if len(qrs.rules) != 1 { 70 t.Errorf("want 1, got %d", len(qrs.rules)) 71 } 72 73 if qrs.rules[0] != qr2 { 74 t.Errorf("want:\n%#v\ngot:\n%#v", qr2, qrf) 75 } 76 77 qrf = qrs.Delete("unknown_rule") 78 if qrf != nil { 79 t.Fatalf("delete an unknown_rule, should return nil") 80 } 81 } 82 83 // TestCopy tests for deep copy 84 func TestCopy(t *testing.T) { 85 qrs1 := New() 86 qr1 := NewQueryRule("rule 1", "r1", QRFail) 87 qr1.AddPlanCond(planbuilder.PlanSelect) 88 qr1.AddTableCond("aa") 89 qr1.AddBindVarCond("a", true, false, QRNoOp, nil) 90 91 qr2 := NewQueryRule("rule 2", "r2", QRFail) 92 qrs1.Add(qr1) 93 qrs1.Add(qr2) 94 95 qrs2 := qrs1.Copy() 96 if !reflect.DeepEqual(qrs2, qrs1) { 97 t.Errorf("qrs1: %+v, not equal to %+v", qrs2, qrs1) 98 } 99 100 qrs1 = New() 101 qrs2 = qrs1.Copy() 102 if !reflect.DeepEqual(qrs2, qrs1) { 103 t.Errorf("qrs1: %+v, not equal to %+v", qrs2, qrs1) 104 } 105 } 106 107 func TestFilterByPlan(t *testing.T) { 108 qrs := New() 109 110 qr1 := NewQueryRule("rule 1", "r1", QRFail) 111 qr1.SetIPCond("123") 112 qr1.SetQueryCond("select") 113 qr1.AddPlanCond(planbuilder.PlanSelect) 114 qr1.AddBindVarCond("a", true, false, QRNoOp, nil) 115 116 qr2 := NewQueryRule("rule 2", "r2", QRFail) 117 qr2.AddPlanCond(planbuilder.PlanSelect) 118 qr2.AddPlanCond(planbuilder.PlanSelect) 119 qr2.AddBindVarCond("a", true, false, QRNoOp, nil) 120 121 qr3 := NewQueryRule("rule 3", "r3", QRFail) 122 qr3.SetQueryCond("sele.*") 123 qr3.AddBindVarCond("a", true, false, QRNoOp, nil) 124 125 qr4 := NewQueryRule("rule 4", "r4", QRFail) 126 qr4.AddTableCond("b") 127 qr4.AddTableCond("c") 128 129 qrs.Add(qr1) 130 qrs.Add(qr2) 131 qrs.Add(qr3) 132 qrs.Add(qr4) 133 134 qrs1 := qrs.FilterByPlan("select", planbuilder.PlanSelect, "a") 135 want := compacted(`[{ 136 "Description":"rule 1", 137 "Name":"r1", 138 "RequestIP":"123", 139 "BindVarConds":[{ 140 "Name":"a", 141 "OnAbsent":true, 142 "Operator":"" 143 }], 144 "Action":"FAIL" 145 },{ 146 "Description":"rule 2", 147 "Name":"r2", 148 "BindVarConds":[{ 149 "Name":"a", 150 "OnAbsent":true, 151 "Operator":"" 152 }], 153 "Action":"FAIL" 154 },{ 155 "Description":"rule 3", 156 "Name":"r3", 157 "BindVarConds":[{ 158 "Name":"a", 159 "OnAbsent":true, 160 "Operator":"" 161 }], 162 "Action":"FAIL" 163 }]`) 164 got := marshalled(qrs1) 165 if got != want { 166 t.Errorf("qrs1:\n%s, want\n%s", got, want) 167 } 168 169 qrs1 = qrs.FilterByPlan("insert", planbuilder.PlanSelect, "a") 170 want = compacted(`[{ 171 "Description":"rule 2", 172 "Name":"r2", 173 "BindVarConds":[{ 174 "Name":"a", 175 "OnAbsent":true, 176 "Operator":"" 177 }], 178 "Action":"FAIL" 179 }]`) 180 got = marshalled(qrs1) 181 if got != want { 182 t.Errorf("qrs1:\n%s, want\n%s", got, want) 183 } 184 { 185 // test multiple tables: 186 qrs1 := qrs.FilterByPlan("insert", planbuilder.PlanSelect, "a", "other_table") 187 want := compacted(`[{ 188 "Description":"rule 2", 189 "Name":"r2", 190 "BindVarConds":[{ 191 "Name":"a", 192 "OnAbsent":true, 193 "Operator":"" 194 }], 195 "Action":"FAIL" 196 }]`) 197 got = marshalled(qrs1) 198 if got != want { 199 t.Errorf("qrs1:\n%s, want\n%s", got, want) 200 } 201 202 } 203 { 204 // test multiple tables: 205 qrs1 := qrs.FilterByPlan("insert", planbuilder.PlanSelect, "other_table", "a") 206 want := compacted(`[{ 207 "Description":"rule 2", 208 "Name":"r2", 209 "BindVarConds":[{ 210 "Name":"a", 211 "OnAbsent":true, 212 "Operator":"" 213 }], 214 "Action":"FAIL" 215 }]`) 216 got = marshalled(qrs1) 217 if got != want { 218 t.Errorf("qrs1:\n%s, want\n%s", got, want) 219 } 220 } 221 222 qrs1 = qrs.FilterByPlan("insert", planbuilder.PlanSelect, "a") 223 got = marshalled(qrs1) 224 if got != want { 225 t.Errorf("qrs1:\n%s, want\n%s", got, want) 226 } 227 228 qrs1 = qrs.FilterByPlan("select", planbuilder.PlanInsert, "a") 229 want = compacted(`[{ 230 "Description":"rule 3", 231 "Name":"r3", 232 "BindVarConds":[{ 233 "Name":"a", 234 "OnAbsent":true, 235 "Operator":"" 236 }], 237 "Action":"FAIL" 238 }]`) 239 got = marshalled(qrs1) 240 if got != want { 241 t.Errorf("qrs1:\n%s, want\n%s", got, want) 242 } 243 244 qrs1 = qrs.FilterByPlan("sel", planbuilder.PlanInsert, "a") 245 if qrs1.rules != nil { 246 t.Errorf("want nil, got non-nil") 247 } 248 249 qrs1 = qrs.FilterByPlan("table", planbuilder.PlanInsert, "b") 250 want = compacted(`[{ 251 "Description":"rule 4", 252 "Name":"r4", 253 "Action":"FAIL" 254 }]`) 255 got = marshalled(qrs1) 256 if got != want { 257 t.Errorf("qrs1:\n%s, want\n%s", got, want) 258 } 259 260 qr5 := NewQueryRule("rule 5", "r5", QRFail) 261 qrs.Add(qr5) 262 263 qrs1 = qrs.FilterByPlan("sel", planbuilder.PlanInsert, "a") 264 want = compacted(`[{ 265 "Description":"rule 5", 266 "Name":"r5", 267 "Action":"FAIL" 268 }]`) 269 got = marshalled(qrs1) 270 if got != want { 271 t.Errorf("qrs1:\n%s, want\n%s", got, want) 272 } 273 274 qrsnil1 := New() 275 if qrsnil2 := qrsnil1.FilterByPlan("", planbuilder.PlanSelect, "a"); qrsnil2.rules != nil { 276 t.Errorf("want nil, got non-nil") 277 } 278 } 279 280 func TestQueryRule(t *testing.T) { 281 qr := NewQueryRule("rule 1", "r1", QRFail) 282 err := qr.SetIPCond("123") 283 if err != nil { 284 t.Errorf("unexpected: %v", err) 285 } 286 if !qr.requestIP.MatchString("123") { 287 t.Errorf("want match") 288 } 289 if qr.requestIP.MatchString("1234") { 290 t.Errorf("want no match") 291 } 292 if qr.requestIP.MatchString("12") { 293 t.Errorf("want no match") 294 } 295 err = qr.SetIPCond("[") 296 if err == nil { 297 t.Errorf("want error") 298 } 299 300 qr.AddPlanCond(planbuilder.PlanSelect) 301 qr.AddPlanCond(planbuilder.PlanInsert) 302 303 if qr.plans[0] != planbuilder.PlanSelect { 304 t.Errorf("want PASS_SELECT, got %s", qr.plans[0].String()) 305 } 306 if qr.plans[1] != planbuilder.PlanInsert { 307 t.Errorf("want INSERT_PK, got %s", qr.plans[1].String()) 308 } 309 310 qr.AddTableCond("a") 311 if qr.tableNames[0] != "a" { 312 t.Errorf("want a, got %s", qr.tableNames[0]) 313 } 314 } 315 316 func TestBindVarStruct(t *testing.T) { 317 qr := NewQueryRule("rule 1", "r1", QRFail) 318 319 err := qr.AddBindVarCond("b", false, true, QRNoOp, nil) 320 if err != nil { 321 t.Errorf("unexpected: %v", err) 322 } 323 err = qr.AddBindVarCond("a", true, false, QRNoOp, nil) 324 if err != nil { 325 t.Errorf("unexpected: %v", err) 326 } 327 if qr.bindVarConds[1].name != "a" { 328 t.Errorf("want a, got %s", qr.bindVarConds[1].name) 329 } 330 if !qr.bindVarConds[1].onAbsent { 331 t.Errorf("want true, got false") 332 } 333 if qr.bindVarConds[1].onMismatch { 334 t.Errorf("want false, got true") 335 } 336 if qr.bindVarConds[1].op != QRNoOp { 337 t.Errorf("exepecting no-op, got %v", qr.bindVarConds[1]) 338 } 339 if qr.bindVarConds[1].value != nil { 340 t.Errorf("want nil, got %#v", qr.bindVarConds[1].value) 341 } 342 } 343 344 type BVCreation struct { 345 name string 346 onAbsent bool 347 onMismatch bool 348 op Operator 349 value any 350 expecterr bool 351 } 352 353 var creationCases = []BVCreation{ 354 {"a", true, true, QREqual, uint64(1), false}, 355 {"a", true, true, QRNotEqual, uint64(1), false}, 356 {"a", true, true, QRLessThan, uint64(1), false}, 357 {"a", true, true, QRGreaterEqual, uint64(1), false}, 358 {"a", true, true, QRGreaterThan, uint64(1), false}, 359 {"a", true, true, QRLessEqual, uint64(1), false}, 360 361 {"a", true, true, QREqual, int64(1), false}, 362 {"a", true, true, QRNotEqual, int64(1), false}, 363 {"a", true, true, QRLessThan, int64(1), false}, 364 {"a", true, true, QRGreaterEqual, int64(1), false}, 365 {"a", true, true, QRGreaterThan, int64(1), false}, 366 {"a", true, true, QRLessEqual, int64(1), false}, 367 368 {"a", true, true, QREqual, "a", false}, 369 {"a", true, true, QRNotEqual, "a", false}, 370 {"a", true, true, QRLessThan, "a", false}, 371 {"a", true, true, QRGreaterEqual, "a", false}, 372 {"a", true, true, QRGreaterThan, "a", false}, 373 {"a", true, true, QRLessEqual, "a", false}, 374 {"a", true, true, QRMatch, "a", false}, 375 {"a", true, true, QRNoMatch, "a", false}, 376 377 {"a", true, true, QRMatch, int64(1), true}, 378 {"a", true, true, QRNoMatch, int64(1), true}, 379 {"a", true, true, QRMatch, "[", true}, 380 {"a", true, true, QRNoMatch, "[", true}, 381 } 382 383 func TestBVCreation(t *testing.T) { 384 qr := NewQueryRule("rule 1", "r1", QRFail) 385 for i, tcase := range creationCases { 386 err := qr.AddBindVarCond(tcase.name, tcase.onAbsent, tcase.onMismatch, tcase.op, tcase.value) 387 haserr := (err != nil) 388 if haserr != tcase.expecterr { 389 t.Errorf("test %d: got %v for %#v", i, haserr, tcase) 390 } 391 } 392 } 393 394 type BindVarTestCase struct { 395 bvc BindVarCond 396 bvval *querypb.BindVariable 397 expected bool 398 } 399 400 var bvtestcases = []BindVarTestCase{ 401 {BindVarCond{"b", true, true, QRNoOp, nil}, sqltypes.Int64BindVariable(1), true}, 402 {BindVarCond{"b", false, true, QRNoOp, nil}, sqltypes.Int64BindVariable(1), false}, 403 {BindVarCond{"a", true, true, QRNoOp, nil}, sqltypes.Int64BindVariable(1), false}, 404 {BindVarCond{"a", false, true, QRNoOp, nil}, sqltypes.Int64BindVariable(1), true}, 405 406 {BindVarCond{"a", true, true, QREqual, bvcuint64(10)}, sqltypes.Int64BindVariable(1), false}, 407 {BindVarCond{"a", true, true, QREqual, bvcuint64(10)}, sqltypes.Int64BindVariable(10), true}, 408 {BindVarCond{"a", true, true, QREqual, bvcuint64(10)}, sqltypes.Uint64BindVariable(1), false}, 409 {BindVarCond{"a", true, true, QREqual, bvcuint64(10)}, sqltypes.Uint64BindVariable(10), true}, 410 {BindVarCond{"a", true, true, QREqual, bvcuint64(10)}, sqltypes.StringBindVariable("abc"), false}, 411 412 {BindVarCond{"a", true, true, QRNotEqual, bvcuint64(10)}, sqltypes.Int64BindVariable(1), true}, 413 {BindVarCond{"a", true, true, QRNotEqual, bvcuint64(10)}, sqltypes.Int64BindVariable(10), false}, 414 {BindVarCond{"a", true, true, QRNotEqual, bvcuint64(10)}, sqltypes.Int64BindVariable(11), true}, 415 {BindVarCond{"a", true, true, QRNotEqual, bvcuint64(10)}, sqltypes.Int64BindVariable(-1), true}, 416 417 {BindVarCond{"a", true, true, QRLessThan, bvcuint64(10)}, sqltypes.Int64BindVariable(1), true}, 418 {BindVarCond{"a", true, true, QRLessThan, bvcuint64(10)}, sqltypes.Int64BindVariable(10), false}, 419 {BindVarCond{"a", true, true, QRLessThan, bvcuint64(10)}, sqltypes.Int64BindVariable(11), false}, 420 {BindVarCond{"a", true, true, QRLessThan, bvcuint64(10)}, sqltypes.Int64BindVariable(-1), true}, 421 422 {BindVarCond{"a", true, true, QRGreaterEqual, bvcuint64(10)}, sqltypes.Int64BindVariable(1), false}, 423 {BindVarCond{"a", true, true, QRGreaterEqual, bvcuint64(10)}, sqltypes.Int64BindVariable(10), true}, 424 {BindVarCond{"a", true, true, QRGreaterEqual, bvcuint64(10)}, sqltypes.Int64BindVariable(11), true}, 425 {BindVarCond{"a", true, true, QRGreaterEqual, bvcuint64(10)}, sqltypes.Int64BindVariable(-1), false}, 426 427 {BindVarCond{"a", true, true, QRGreaterThan, bvcuint64(10)}, sqltypes.Int64BindVariable(1), false}, 428 {BindVarCond{"a", true, true, QRGreaterThan, bvcuint64(10)}, sqltypes.Int64BindVariable(10), false}, 429 {BindVarCond{"a", true, true, QRGreaterThan, bvcuint64(10)}, sqltypes.Int64BindVariable(11), true}, 430 {BindVarCond{"a", true, true, QRGreaterThan, bvcuint64(10)}, sqltypes.Int64BindVariable(-1), false}, 431 432 {BindVarCond{"a", true, true, QRLessEqual, bvcuint64(10)}, sqltypes.Int64BindVariable(1), true}, 433 {BindVarCond{"a", true, true, QRLessEqual, bvcuint64(10)}, sqltypes.Int64BindVariable(10), true}, 434 {BindVarCond{"a", true, true, QRLessEqual, bvcuint64(10)}, sqltypes.Int64BindVariable(11), false}, 435 {BindVarCond{"a", true, true, QRLessEqual, bvcuint64(10)}, sqltypes.Int64BindVariable(-1), true}, 436 437 {BindVarCond{"a", true, true, QREqual, bvcint64(10)}, sqltypes.Int64BindVariable(1), false}, 438 {BindVarCond{"a", true, true, QREqual, bvcint64(10)}, sqltypes.Int64BindVariable(10), true}, 439 {BindVarCond{"a", true, true, QREqual, bvcint64(10)}, sqltypes.Uint64BindVariable(1), false}, 440 {BindVarCond{"a", true, true, QREqual, bvcint64(10)}, sqltypes.Uint64BindVariable(0xFFFFFFFFFFFFFFFF), false}, 441 {BindVarCond{"a", true, true, QREqual, bvcint64(10)}, sqltypes.Uint64BindVariable(10), true}, 442 {BindVarCond{"a", true, true, QREqual, bvcint64(10)}, sqltypes.StringBindVariable("abc"), false}, 443 444 {BindVarCond{"a", true, true, QRNotEqual, bvcint64(10)}, sqltypes.Int64BindVariable(1), true}, 445 {BindVarCond{"a", true, true, QRNotEqual, bvcint64(10)}, sqltypes.Int64BindVariable(10), false}, 446 {BindVarCond{"a", true, true, QRNotEqual, bvcint64(10)}, sqltypes.Int64BindVariable(11), true}, 447 {BindVarCond{"a", true, true, QRNotEqual, bvcint64(10)}, sqltypes.Uint64BindVariable(0xFFFFFFFFFFFFFFFF), true}, 448 449 {BindVarCond{"a", true, true, QRLessThan, bvcint64(10)}, sqltypes.Int64BindVariable(1), true}, 450 {BindVarCond{"a", true, true, QRLessThan, bvcint64(10)}, sqltypes.Int64BindVariable(10), false}, 451 {BindVarCond{"a", true, true, QRLessThan, bvcint64(10)}, sqltypes.Int64BindVariable(11), false}, 452 {BindVarCond{"a", true, true, QRLessThan, bvcint64(10)}, sqltypes.Uint64BindVariable(0xFFFFFFFFFFFFFFFF), false}, 453 454 {BindVarCond{"a", true, true, QRGreaterEqual, bvcint64(10)}, sqltypes.Int64BindVariable(1), false}, 455 {BindVarCond{"a", true, true, QRGreaterEqual, bvcint64(10)}, sqltypes.Int64BindVariable(10), true}, 456 {BindVarCond{"a", true, true, QRGreaterEqual, bvcint64(10)}, sqltypes.Int64BindVariable(11), true}, 457 {BindVarCond{"a", true, true, QRGreaterEqual, bvcint64(10)}, sqltypes.Uint64BindVariable(0xFFFFFFFFFFFFFFFF), true}, 458 459 {BindVarCond{"a", true, true, QRGreaterThan, bvcint64(10)}, sqltypes.Int64BindVariable(1), false}, 460 {BindVarCond{"a", true, true, QRGreaterThan, bvcint64(10)}, sqltypes.Int64BindVariable(10), false}, 461 {BindVarCond{"a", true, true, QRGreaterThan, bvcint64(10)}, sqltypes.Int64BindVariable(11), true}, 462 {BindVarCond{"a", true, true, QRGreaterThan, bvcint64(10)}, sqltypes.Uint64BindVariable(0xFFFFFFFFFFFFFFFF), true}, 463 464 {BindVarCond{"a", true, true, QRLessEqual, bvcint64(10)}, sqltypes.Int64BindVariable(1), true}, 465 {BindVarCond{"a", true, true, QRLessEqual, bvcint64(10)}, sqltypes.Int64BindVariable(10), true}, 466 {BindVarCond{"a", true, true, QRLessEqual, bvcint64(10)}, sqltypes.Int64BindVariable(11), false}, 467 {BindVarCond{"a", true, true, QRLessEqual, bvcint64(10)}, sqltypes.Uint64BindVariable(0xFFFFFFFFFFFFFFFF), false}, 468 469 {BindVarCond{"a", true, true, QREqual, bvcstring("b")}, sqltypes.StringBindVariable("a"), false}, 470 {BindVarCond{"a", true, true, QREqual, bvcstring("b")}, sqltypes.StringBindVariable("b"), true}, 471 {BindVarCond{"a", true, true, QREqual, bvcstring("b")}, sqltypes.StringBindVariable("c"), false}, 472 {BindVarCond{"a", true, true, QREqual, bvcstring("b")}, sqltypes.BytesBindVariable([]byte("a")), false}, 473 {BindVarCond{"a", true, true, QREqual, bvcstring("b")}, sqltypes.BytesBindVariable([]byte("b")), true}, 474 {BindVarCond{"a", true, true, QREqual, bvcstring("b")}, sqltypes.BytesBindVariable([]byte("c")), false}, 475 {BindVarCond{"a", true, true, QREqual, bvcstring("b")}, sqltypes.Int64BindVariable(1), false}, 476 477 {BindVarCond{"a", true, true, QRNotEqual, bvcstring("b")}, sqltypes.StringBindVariable("a"), true}, 478 {BindVarCond{"a", true, true, QRNotEqual, bvcstring("b")}, sqltypes.StringBindVariable("b"), false}, 479 {BindVarCond{"a", true, true, QRNotEqual, bvcstring("b")}, sqltypes.StringBindVariable("c"), true}, 480 481 {BindVarCond{"a", true, true, QRLessThan, bvcstring("b")}, sqltypes.StringBindVariable("a"), true}, 482 {BindVarCond{"a", true, true, QRLessThan, bvcstring("b")}, sqltypes.StringBindVariable("b"), false}, 483 {BindVarCond{"a", true, true, QRLessThan, bvcstring("b")}, sqltypes.StringBindVariable("c"), false}, 484 485 {BindVarCond{"a", true, true, QRGreaterEqual, bvcstring("b")}, sqltypes.StringBindVariable("a"), false}, 486 {BindVarCond{"a", true, true, QRGreaterEqual, bvcstring("b")}, sqltypes.StringBindVariable("b"), true}, 487 {BindVarCond{"a", true, true, QRGreaterEqual, bvcstring("b")}, sqltypes.StringBindVariable("c"), true}, 488 489 {BindVarCond{"a", true, true, QRGreaterThan, bvcstring("b")}, sqltypes.StringBindVariable("a"), false}, 490 {BindVarCond{"a", true, true, QRGreaterThan, bvcstring("b")}, sqltypes.StringBindVariable("b"), false}, 491 {BindVarCond{"a", true, true, QRGreaterThan, bvcstring("b")}, sqltypes.StringBindVariable("c"), true}, 492 493 {BindVarCond{"a", true, true, QRLessEqual, bvcstring("b")}, sqltypes.StringBindVariable("a"), true}, 494 {BindVarCond{"a", true, true, QRLessEqual, bvcstring("b")}, sqltypes.StringBindVariable("b"), true}, 495 {BindVarCond{"a", true, true, QRLessEqual, bvcstring("b")}, sqltypes.StringBindVariable("c"), false}, 496 497 {BindVarCond{"a", true, true, QRMatch, makere("a.*")}, sqltypes.StringBindVariable("c"), false}, 498 {BindVarCond{"a", true, true, QRMatch, makere("a.*")}, sqltypes.StringBindVariable("a"), true}, 499 {BindVarCond{"a", true, true, QRMatch, makere("a.*")}, sqltypes.Int64BindVariable(1), false}, 500 501 {BindVarCond{"a", true, true, QRNoMatch, makere("a.*")}, sqltypes.StringBindVariable("c"), true}, 502 {BindVarCond{"a", true, true, QRNoMatch, makere("a.*")}, sqltypes.StringBindVariable("a"), false}, 503 {BindVarCond{"a", true, true, QRNoMatch, makere("a.*")}, sqltypes.Int64BindVariable(1), true}, 504 } 505 506 func makere(s string) bvcre { 507 re, _ := regexp.Compile(s) 508 return bvcre{re} 509 } 510 511 func TestBVConditions(t *testing.T) { 512 bv := make(map[string]*querypb.BindVariable) 513 for _, tcase := range bvtestcases { 514 bv["a"] = tcase.bvval 515 if bvMatch(tcase.bvc, bv) != tcase.expected { 516 t.Errorf("bvmatch(%+v, %v): %v, want %v", tcase.bvc, tcase.bvval, !tcase.expected, tcase.expected) 517 } 518 } 519 } 520 521 func TestAction(t *testing.T) { 522 qrs := New() 523 524 qr1 := NewQueryRule("rule 1", "r1", QRFail) 525 qr1.SetIPCond("123") 526 527 qr2 := NewQueryRule("rule 2", "r2", QRFailRetry) 528 qr2.SetUserCond("user") 529 530 qr3 := NewQueryRule("rule 3", "r3", QRFail) 531 qr3.AddBindVarCond("a", true, true, QREqual, uint64(1)) 532 533 qrs.Add(qr1) 534 qrs.Add(qr2) 535 qrs.Add(qr3) 536 537 bv := make(map[string]*querypb.BindVariable) 538 bv["a"] = sqltypes.Uint64BindVariable(0) 539 540 mc := sqlparser.MarginComments{ 541 Leading: "some comments leading the query", 542 Trailing: "other trailing comments", 543 } 544 545 action, cancelCtx, desc := qrs.GetAction("123", "user1", bv, mc) 546 assert.Equalf(t, action, QRFail, "expected fail, got %v", action) 547 assert.Equalf(t, desc, "rule 1", "want rule 1, got %s", desc) 548 assert.Nil(t, cancelCtx) 549 550 action, cancelCtx, desc = qrs.GetAction("1234", "user", bv, mc) 551 assert.Equalf(t, action, QRFailRetry, "want fail_retry, got: %s", action) 552 assert.Equalf(t, desc, "rule 2", "want rule 2, got %s", desc) 553 assert.Nil(t, cancelCtx) 554 555 action, _, _ = qrs.GetAction("1234", "user1", bv, mc) 556 assert.Equalf(t, action, QRContinue, "want continue, got %s", action) 557 558 bv["a"] = sqltypes.Uint64BindVariable(1) 559 action, _, desc = qrs.GetAction("1234", "user1", bv, mc) 560 assert.Equalf(t, action, QRFail, "want fail, got %s", action) 561 assert.Equalf(t, desc, "rule 3", "want rule 3, got %s", desc) 562 563 // reset bound variable 'a' to 0 so it doesn't match rule 3 564 bv["a"] = sqltypes.Uint64BindVariable(0) 565 566 qr4 := NewQueryRule("rule 4", "r4", QRFail) 567 qr4.SetTrailingCommentCond(".*trailing.*") 568 569 newQrs := qrs.Copy() 570 newQrs.Add(qr4) 571 572 action, _, desc = newQrs.GetAction("1234", "user1", bv, mc) 573 assert.Equalf(t, action, QRFail, "want fail, got %s", action) 574 assert.Equalf(t, desc, "rule 4", "want rule 4, got %s", desc) 575 576 qr5 := NewQueryRule("rule 5", "r4", QRFail) 577 qr5.SetLeadingCommentCond(".*leading.*") 578 579 newQrs = qrs.Copy() 580 newQrs.Add(qr5) 581 action, _, desc = newQrs.GetAction("1234", "user1", bv, mc) 582 assert.Equalf(t, action, QRFail, "want fail, got %s", action) 583 assert.Equalf(t, desc, "rule 5", "want rule 5, got %s", desc) 584 } 585 586 func TestImport(t *testing.T) { 587 var qrs = New() 588 jsondata := `[{ 589 "Description": "desc1", 590 "Name": "name1", 591 "RequestIP": "123.123.123", 592 "User": "user", 593 "Query": "query", 594 "Plans": ["Select", "Insert"], 595 "TableNames":["a", "b"], 596 "BindVarConds": [{ 597 "Name": "bvname1", 598 "OnAbsent": true, 599 "Operator": "" 600 },{ 601 "Name": "bvname2", 602 "OnAbsent": true, 603 "OnMismatch": true, 604 "Operator": "==", 605 "Value": 123 606 }], 607 "Action": "FAIL_RETRY" 608 },{ 609 "Description": "desc2", 610 "Name": "name2", 611 "Action": "FAIL" 612 }]` 613 err := qrs.UnmarshalJSON([]byte(jsondata)) 614 if err != nil { 615 t.Error(err) 616 return 617 } 618 got := marshalled(qrs) 619 want := compacted(jsondata) 620 if got != want { 621 t.Errorf("qrs:\n%s, want\n%s", got, want) 622 } 623 } 624 625 type ValidJSONCase struct { 626 input string 627 op Operator 628 typ int 629 } 630 631 const ( 632 UINT = iota 633 INT 634 STR 635 REGEXP 636 ) 637 638 var validjsons = []ValidJSONCase{ 639 {`[{"BindVarConds": [{"Name": "bvname1", "OnAbsent": true, "OnMismatch": true, "Operator": "==", "Value": 18446744073709551615}]}]`, QREqual, UINT}, 640 {`[{"BindVarConds": [{"Name": "bvname1", "OnAbsent": true, "OnMismatch": true, "Operator": "!=", "Value": 18446744073709551615}]}]`, QRNotEqual, UINT}, 641 {`[{"BindVarConds": [{"Name": "bvname1", "OnAbsent": true, "OnMismatch": true, "Operator": "<", "Value": 18446744073709551615}]}]`, QRLessThan, UINT}, 642 {`[{"BindVarConds": [{"Name": "bvname1", "OnAbsent": true, "OnMismatch": true, "Operator": ">=", "Value": 18446744073709551615}]}]`, QRGreaterEqual, UINT}, 643 {`[{"BindVarConds": [{"Name": "bvname1", "OnAbsent": true, "OnMismatch": true, "Operator": ">", "Value": 18446744073709551615}]}]`, QRGreaterThan, UINT}, 644 {`[{"BindVarConds": [{"Name": "bvname1", "OnAbsent": true, "OnMismatch": true, "Operator": "<=", "Value": 18446744073709551615}]}]`, QRLessEqual, UINT}, 645 646 {`[{"BindVarConds": [{"Name": "bvname1", "OnAbsent": true, "OnMismatch": true, "Operator": "==", "Value": -123}]}]`, QREqual, INT}, 647 {`[{"BindVarConds": [{"Name": "bvname1", "OnAbsent": true, "OnMismatch": true, "Operator": "!=", "Value": -123}]}]`, QRNotEqual, INT}, 648 {`[{"BindVarConds": [{"Name": "bvname1", "OnAbsent": true, "OnMismatch": true, "Operator": "<", "Value": -123}]}]`, QRLessThan, INT}, 649 {`[{"BindVarConds": [{"Name": "bvname1", "OnAbsent": true, "OnMismatch": true, "Operator": ">=", "Value": -123}]}]`, QRGreaterEqual, INT}, 650 {`[{"BindVarConds": [{"Name": "bvname1", "OnAbsent": true, "OnMismatch": true, "Operator": ">", "Value": -123}]}]`, QRGreaterThan, INT}, 651 {`[{"BindVarConds": [{"Name": "bvname1", "OnAbsent": true, "OnMismatch": true, "Operator": "<=", "Value": -123}]}]`, QRLessEqual, INT}, 652 653 {`[{"BindVarConds": [{"Name": "bvname1", "OnAbsent": true, "OnMismatch": true, "Operator": "==", "Value": "123"}]}]`, QREqual, STR}, 654 {`[{"BindVarConds": [{"Name": "bvname1", "OnAbsent": true, "OnMismatch": true, "Operator": "!=", "Value": "123"}]}]`, QRNotEqual, STR}, 655 {`[{"BindVarConds": [{"Name": "bvname1", "OnAbsent": true, "OnMismatch": true, "Operator": "<", "Value": "123"}]}]`, QRLessThan, STR}, 656 {`[{"BindVarConds": [{"Name": "bvname1", "OnAbsent": true, "OnMismatch": true, "Operator": ">=", "Value": "123"}]}]`, QRGreaterEqual, STR}, 657 {`[{"BindVarConds": [{"Name": "bvname1", "OnAbsent": true, "OnMismatch": true, "Operator": ">", "Value": "123"}]}]`, QRGreaterThan, STR}, 658 {`[{"BindVarConds": [{"Name": "bvname1", "OnAbsent": true, "OnMismatch": true, "Operator": "<=", "Value": "123"}]}]`, QRLessEqual, STR}, 659 660 {`[{"BindVarConds": [{"Name": "bvname1", "OnAbsent": true, "OnMismatch": true, "Operator": "MATCH", "Value": "123"}]}]`, QRMatch, REGEXP}, 661 {`[{"BindVarConds": [{"Name": "bvname1", "OnAbsent": true, "OnMismatch": true, "Operator": "NOMATCH", "Value": "123"}]}]`, QRNoMatch, REGEXP}, 662 } 663 664 func TestValidJSON(t *testing.T) { 665 for i, tcase := range validjsons { 666 qrs := New() 667 err := qrs.UnmarshalJSON([]byte(tcase.input)) 668 if err != nil { 669 t.Fatalf("Unexpected error for case %d: %v", i, err) 670 } 671 bvc := qrs.rules[0].bindVarConds[0] 672 if bvc.op != tcase.op { 673 t.Errorf("want %v, got %v", tcase.op, bvc.op) 674 } 675 switch tcase.typ { 676 case UINT: 677 if bvc.value.(bvcuint64) != bvcuint64(18446744073709551615) { 678 t.Errorf("want %v, got %v", uint64(18446744073709551615), bvc.value.(bvcuint64)) 679 } 680 case INT: 681 if bvc.value.(bvcint64) != -123 { 682 t.Errorf("want %v, got %v", -123, bvc.value.(bvcint64)) 683 } 684 case STR: 685 if bvc.value.(bvcstring) != "123" { 686 t.Errorf("want %v, got %v", "123", bvc.value.(bvcint64)) 687 } 688 case REGEXP: 689 if bvc.value.(bvcre).re == nil { 690 t.Errorf("want non-nil") 691 } 692 } 693 } 694 } 695 696 type InvalidJSONCase struct { 697 input, err string 698 } 699 700 var invalidjsons = []InvalidJSONCase{ 701 {`[{"Name": 1 }]`, "want string for Name"}, 702 {`[{"Description": 1 }]`, "want string for Description"}, 703 {`[{"RequestIP": 1 }]`, "want string for RequestIP"}, 704 {`[{"User": 1 }]`, "want string for User"}, 705 {`[{"Query": 1 }]`, "want string for Query"}, 706 {`[{"Plans": 1 }]`, "want list for Plans"}, 707 {`[{"TableNames": 1 }]`, "want list for TableNames"}, 708 {`[{"BindVarConds": 1 }]`, "want list for BindVarConds"}, 709 {`[{"RequestIP": "[" }]`, "could not set IP condition: ["}, 710 {`[{"User": "[" }]`, "could not set User condition: ["}, 711 {`[{"Query": "[" }]`, "could not set Query condition: ["}, 712 {`[{"Plans": [1] }]`, "want string for Plans"}, 713 {`[{"Plans": ["invalid"] }]`, "invalid plan name: invalid"}, 714 {`[{"TableNames": [1] }]`, "want string for TableNames"}, 715 {`[{"BindVarConds": [1] }]`, "want json object for bind var conditions"}, 716 {`[{"BindVarConds": [{}] }]`, "Name missing in BindVarConds"}, 717 {`[{"BindVarConds": [{"Name": 1}] }]`, "want string for Name in BindVarConds"}, 718 {`[{"BindVarConds": [{"Name": "a"}] }]`, "OnAbsent missing in BindVarConds"}, 719 {`[{"BindVarConds": [{"Name": "a", "OnAbsent": 1}] }]`, "want bool for OnAbsent"}, 720 {`[{"BindVarConds": [{"Name": "a", "OnAbsent": true}]}]`, "Operator missing in BindVarConds"}, 721 {`[{"BindVarConds": [{"Name": "a", "OnAbsent": true, "Operator": "a"}]}]`, "invalid Operator a"}, 722 {`[{"BindVarConds": [{"Name": "a", "OnAbsent": true, "Operator": "=="}]}]`, "Value missing in BindVarConds"}, 723 {`[{"BindVarConds": [{"Name": "a", "OnAbsent": true, "Operator": "==", "Value": 1.2}]}]`, "want int64/uint64: 1.2"}, 724 {`[{"BindVarConds": [{"Name": "a", "OnAbsent": true, "Operator": "==", "Value": {}}]}]`, "want string or number: map[]"}, 725 {`[{"BindVarConds": [{"Name": "a", "OnAbsent": true, "Operator": "MATCH", "Value": 1}]}]`, "want string: 1"}, 726 {`[{"BindVarConds": [{"Name": "a", "OnAbsent": true, "Operator": "NOMATCH", "Value": 1}]}]`, "want string: 1"}, 727 {`[{"BindVarConds": [{"Name": "a", "OnAbsent": true, "Operator": 123, "Value": "1"}]}]`, "want string for Operator"}, 728 {`[{"Unknown": [{"Name": "a"}]}]`, "unrecognized tag Unknown"}, 729 {`[{"BindVarConds": [{"Name": "a", "OnAbsent": true, "Operator": "<=", "Value": "1"}]}]`, "OnMismatch missing in BindVarConds"}, 730 {`[{"BindVarConds": [{"Name": "a", "OnAbsent": true, "OnMismatch": true, "Operator": "MATCH", "Value": "["}]}]`, "processing [: error parsing regexp: missing closing ]: `[$`"}, 731 {`[{"BindVarConds": [{"Name": "a", "OnAbsent": true, "OnMismatch": true, "Operator": "NOMATCH", "Value": "["}]}]`, "processing [: error parsing regexp: missing closing ]: `[$`"}, 732 {`[{"Action": 1 }]`, "want string for Action"}, 733 {`[{"Action": "foo" }]`, "invalid Action foo"}, 734 } 735 736 func TestInvalidJSON(t *testing.T) { 737 for _, tcase := range invalidjsons { 738 qrs := New() 739 err := qrs.UnmarshalJSON([]byte(tcase.input)) 740 if err == nil { 741 t.Errorf("want error for case %q", tcase.input) 742 continue 743 } 744 recvd := strings.Replace(err.Error(), "fatal: ", "", 1) 745 if recvd != tcase.err { 746 t.Errorf("invalid json: %s, want '%v', got '%v'", tcase.input, tcase.err, recvd) 747 } 748 } 749 qrs := New() 750 err := qrs.UnmarshalJSON([]byte(`{`)) 751 if code := vterrors.Code(err); code != vtrpcpb.Code_INVALID_ARGUMENT { 752 t.Errorf("qrs.UnmarshalJSON: %v, want %v", code, vtrpcpb.Code_INVALID_ARGUMENT) 753 } 754 } 755 756 func TestBuildQueryRuleActionFail(t *testing.T) { 757 var ruleInfo map[string]any 758 err := json.Unmarshal([]byte(`{"Action": "FAIL" }`), &ruleInfo) 759 if err != nil { 760 t.Fatalf("failed to unmarshal json, got error: %v", err) 761 } 762 qr, err := BuildQueryRule(ruleInfo) 763 if err != nil { 764 t.Fatalf("build query rule should succeed") 765 } 766 if qr.act != QRFail { 767 t.Fatalf("action should fail") 768 } 769 } 770 771 func TestBadAddBindVarCond(t *testing.T) { 772 qr1 := NewQueryRule("rule 1", "r1", QRFail) 773 err := qr1.AddBindVarCond("a", true, false, QRMatch, uint64(1)) 774 if err == nil { 775 t.Fatalf("invalid op: QRMatch for value type: uint64") 776 } 777 } 778 779 func TestOpNames(t *testing.T) { 780 want := []string{ 781 "", 782 "==", 783 "!=", 784 "<", 785 ">=", 786 ">", 787 "<=", 788 "MATCH", 789 "NOMATCH", 790 } 791 if !reflect.DeepEqual(opnames, want) { 792 t.Errorf("opnames: \n%v, want \n%v", opnames, want) 793 } 794 } 795 796 func compacted(in string) string { 797 dst := bytes.NewBuffer(nil) 798 err := json.Compact(dst, []byte(in)) 799 if err != nil { 800 panic(err) 801 } 802 return dst.String() 803 } 804 805 func marshalled(in any) string { 806 b, err := json.Marshal(in) 807 if err != nil { 808 panic(err) 809 } 810 return string(b) 811 }