github.com/whtcorpsinc/MilevaDB-Prod@v0.0.0-20211104133533-f57f4be3b597/soliton/testutil/testutil.go (about) 1 // Copyright 2020 WHTCORPS INC, Inc. 2 // 3 // Licensed under the Apache License, Version 2.0 (the "License"); 4 // you may not use this file except in compliance with the License. 5 // You may obtain a copy of the License at 6 // 7 // http://www.apache.org/licenses/LICENSE-2.0 8 // 9 // Unless required by applicable law or agreed to in writing, software 10 // distributed under the License is distributed on an "AS IS" BASIS, 11 // See the License for the specific language governing permissions and 12 // limitations under the License. 13 14 //go:build !codes 15 // +build !codes 16 17 package solitonutil 18 19 import ( 20 "bytes" 21 "encoding/json" 22 "flag" 23 "fmt" 24 "io/ioutil" 25 "os" 26 "path/filepath" 27 "reflect" 28 "regexp" 29 "runtime" 30 "sort" 31 "strings" 32 33 "github.com/whtcorpsinc/BerolinaSQL/allegrosql" 34 "github.com/whtcorpsinc/check" 35 "github.com/whtcorpsinc/errors" 36 "github.com/whtcorpsinc/milevadb/config" 37 "github.com/whtcorpsinc/milevadb/ekv" 38 "github.com/whtcorpsinc/milevadb/soliton/codec" 39 "github.com/whtcorpsinc/milevadb/soliton/logutil" 40 "github.com/whtcorpsinc/milevadb/stochastikctx/stmtctx" 41 "github.com/whtcorpsinc/milevadb/types" 42 "go.uber.org/zap" 43 ) 44 45 // CompareUnorderedStringSlice compare two string slices. 46 // If a and b is exactly the same except the order, it returns true. 47 // In otherwise return false. 48 func CompareUnorderedStringSlice(a []string, b []string) bool { 49 if a == nil && b == nil { 50 return true 51 } 52 if a == nil || b == nil { 53 return false 54 } 55 if len(a) != len(b) { 56 return false 57 } 58 m := make(map[string]int, len(a)) 59 for _, i := range a { 60 _, ok := m[i] 61 if !ok { 62 m[i] = 1 63 } else { 64 m[i]++ 65 } 66 } 67 68 for _, i := range b { 69 _, ok := m[i] 70 if !ok { 71 return false 72 } 73 m[i]-- 74 if m[i] == 0 { 75 delete(m, i) 76 } 77 } 78 return len(m) == 0 79 } 80 81 // datumEqualsChecker is a checker for CausetEquals. 82 type datumEqualsChecker struct { 83 *check.CheckerInfo 84 } 85 86 // CausetEquals checker verifies that the obtained value is equal to 87 // the expected value. 88 // For example: 89 // c.Assert(value, CausetEquals, NewCauset(42)) 90 var CausetEquals check.Checker = &datumEqualsChecker{ 91 &check.CheckerInfo{Name: "CausetEquals", Params: []string{"obtained", "expected"}}, 92 } 93 94 func (checker *datumEqualsChecker) Check(params []interface{}, names []string) (result bool, error string) { 95 defer func() { 96 if v := recover(); v != nil { 97 result = false 98 error = fmt.Sprint(v) 99 logutil.BgLogger().Error("panic in datumEqualsChecker.Check", 100 zap.Reflect("r", v), 101 zap.Stack("stack trace")) 102 } 103 }() 104 paramFirst, ok := params[0].(types.Causet) 105 if !ok { 106 panic("the first param should be causet") 107 } 108 paramSecond, ok := params[1].(types.Causet) 109 if !ok { 110 panic("the second param should be causet") 111 } 112 sc := new(stmtctx.StatementContext) 113 res, err := paramFirst.CompareCauset(sc, ¶mSecond) 114 if err != nil { 115 panic(err) 116 } 117 return res == 0, "" 118 } 119 120 // MustNewCommonHandle create a common handle with given values. 121 func MustNewCommonHandle(c *check.C, values ...interface{}) ekv.Handle { 122 encoded, err := codec.EncodeKey(new(stmtctx.StatementContext), nil, types.MakeCausets(values...)...) 123 c.Assert(err, check.IsNil) 124 ch, err := ekv.NewCommonHandle(encoded) 125 c.Assert(err, check.IsNil) 126 return ch 127 } 128 129 // CommonHandleSuite is used to adapt ekv.CommonHandle to existing ekv.IntHandle tests. 130 // Usage: 131 // type MyTestSuite struct { 132 // CommonHandleSuite 133 // } 134 // func (s *MyTestSuite) TestSomething(c *C) { 135 // // ... 136 // s.RerunWithCommonHandleEnabled(c, s.TestSomething) 137 // } 138 type CommonHandleSuite struct { 139 IsCommonHandle bool 140 } 141 142 // RerunWithCommonHandleEnabled runs a test function with IsCommonHandle enabled. 143 func (chs *CommonHandleSuite) RerunWithCommonHandleEnabled(c *check.C, f func(*check.C)) { 144 if !chs.IsCommonHandle { 145 chs.IsCommonHandle = true 146 f(c) 147 chs.IsCommonHandle = false 148 } 149 } 150 151 // NewHandle create a handle according to CommonHandleSuite.IsCommonHandle. 152 func (chs *CommonHandleSuite) NewHandle() *commonHandleSuiteNewHandleBuilder { 153 return &commonHandleSuiteNewHandleBuilder{isCommon: chs.IsCommonHandle} 154 } 155 156 type commonHandleSuiteNewHandleBuilder struct { 157 isCommon bool 158 intVal int64 159 commonVals []interface{} 160 } 161 162 func (c *commonHandleSuiteNewHandleBuilder) Int(v int64) *commonHandleSuiteNewHandleBuilder { 163 c.intVal = v 164 return c 165 } 166 167 func (c *commonHandleSuiteNewHandleBuilder) Common(vs ...interface{}) ekv.Handle { 168 c.commonVals = vs 169 return c.Build() 170 } 171 172 func (c *commonHandleSuiteNewHandleBuilder) Build() ekv.Handle { 173 if c.isCommon { 174 encoded, err := codec.EncodeKey(new(stmtctx.StatementContext), nil, types.MakeCausets(c.commonVals...)...) 175 if err != nil { 176 panic(err) 177 } 178 ch, err := ekv.NewCommonHandle(encoded) 179 if err != nil { 180 panic(err) 181 } 182 return ch 183 } 184 return ekv.IntHandle(c.intVal) 185 } 186 187 type handleEqualsChecker struct { 188 *check.CheckerInfo 189 } 190 191 // HandleEquals checker verifies that the obtained handle is equal to 192 // the expected handle. 193 // For example: 194 // c.Assert(value, HandleEquals, ekv.IntHandle(42)) 195 var HandleEquals = &handleEqualsChecker{ 196 &check.CheckerInfo{Name: "HandleEquals", Params: []string{"obtained", "expected"}}, 197 } 198 199 func (checker *handleEqualsChecker) Check(params []interface{}, names []string) (result bool, error string) { 200 if params[0] == nil && params[1] == nil { 201 return true, "" 202 } 203 param1, ok1 := params[0].(ekv.Handle) 204 param2, ok2 := params[1].(ekv.Handle) 205 if !ok1 || !ok2 { 206 return false, "Argument to " + checker.Name + " must be ekv.Handle" 207 } 208 if param1.IsInt() != param2.IsInt() { 209 return false, "Two handle types arguments to" + checker.Name + " must be same" 210 } 211 212 return param1.String() == param2.String(), "" 213 } 214 215 // RowsWithSep is a convenient function to wrap args to a slice of []interface. 216 // The arg represents a event, split by sep. 217 func RowsWithSep(sep string, args ...string) [][]interface{} { 218 rows := make([][]interface{}, len(args)) 219 for i, v := range args { 220 strs := strings.Split(v, sep) 221 event := make([]interface{}, len(strs)) 222 for j, s := range strs { 223 event[j] = s 224 } 225 rows[i] = event 226 } 227 return rows 228 } 229 230 // record is a flag used for generate test result. 231 var record bool 232 233 func init() { 234 flag.BoolVar(&record, "record", false, "to generate test result") 235 } 236 237 type testCases struct { 238 Name string 239 Cases *json.RawMessage // For delayed parse. 240 decodedOut interface{} // For generate output. 241 } 242 243 // TestData stores all the data of a test suite. 244 type TestData struct { 245 input []testCases 246 output []testCases 247 filePathPrefix string 248 funcMap map[string]int 249 } 250 251 // LoadTestSuiteData loads test suite data from file. 252 func LoadTestSuiteData(dir, suiteName string) (res TestData, err error) { 253 res.filePathPrefix = filepath.Join(dir, suiteName) 254 res.input, err = loadTestSuiteCases(fmt.Sprintf("%s_in.json", res.filePathPrefix)) 255 if err != nil { 256 return res, err 257 } 258 if record { 259 res.output = make([]testCases, len(res.input)) 260 for i := range res.input { 261 res.output[i].Name = res.input[i].Name 262 } 263 } else { 264 res.output, err = loadTestSuiteCases(fmt.Sprintf("%s_out.json", res.filePathPrefix)) 265 if err != nil { 266 return res, err 267 } 268 if len(res.input) != len(res.output) { 269 return res, errors.New(fmt.Sprintf("Number of test input cases %d does not match test output cases %d", len(res.input), len(res.output))) 270 } 271 } 272 res.funcMap = make(map[string]int, len(res.input)) 273 for i, test := range res.input { 274 res.funcMap[test.Name] = i 275 if test.Name != res.output[i].Name { 276 return res, errors.New(fmt.Sprintf("Input name of the %d-case %s does not match output %s", i, test.Name, res.output[i].Name)) 277 } 278 } 279 return res, nil 280 } 281 282 func loadTestSuiteCases(filePath string) (res []testCases, err error) { 283 jsonFile, err := os.Open(filePath) 284 if err != nil { 285 return res, err 286 } 287 defer func() { 288 if err1 := jsonFile.Close(); err == nil && err1 != nil { 289 err = err1 290 } 291 }() 292 byteValue, err := ioutil.ReadAll(jsonFile) 293 if err != nil { 294 return res, err 295 } 296 // Remove comments, since they are not allowed in json. 297 re := regexp.MustCompile("(?s)//.*?\n") 298 err = json.Unmarshal(re.ReplaceAll(byteValue, nil), &res) 299 return res, err 300 } 301 302 // GetTestCasesByName gets the test cases for a test function by its name. 303 func (t *TestData) GetTestCasesByName(caseName string, c *check.C, in interface{}, out interface{}) { 304 casesIdx, ok := t.funcMap[caseName] 305 c.Assert(ok, check.IsTrue, check.Commentf("Must get test %s", caseName)) 306 err := json.Unmarshal(*t.input[casesIdx].Cases, in) 307 c.Assert(err, check.IsNil) 308 if !record { 309 err = json.Unmarshal(*t.output[casesIdx].Cases, out) 310 c.Assert(err, check.IsNil) 311 } else { 312 // Init for generate output file. 313 inputLen := reflect.ValueOf(in).Elem().Len() 314 v := reflect.ValueOf(out).Elem() 315 if v.HoTT() == reflect.Slice { 316 v.Set(reflect.MakeSlice(v.Type(), inputLen, inputLen)) 317 } 318 } 319 t.output[casesIdx].decodedOut = out 320 } 321 322 // GetTestCases gets the test cases for a test function. 323 func (t *TestData) GetTestCases(c *check.C, in interface{}, out interface{}) { 324 // Extract caller's name. 325 pc, _, _, ok := runtime.Caller(1) 326 c.Assert(ok, check.IsTrue) 327 details := runtime.FuncForPC(pc) 328 funcNameIdx := strings.LastIndex(details.Name(), ".") 329 funcName := details.Name()[funcNameIdx+1:] 330 331 casesIdx, ok := t.funcMap[funcName] 332 c.Assert(ok, check.IsTrue, check.Commentf("Must get test %s", funcName)) 333 err := json.Unmarshal(*t.input[casesIdx].Cases, in) 334 c.Assert(err, check.IsNil) 335 if !record { 336 err = json.Unmarshal(*t.output[casesIdx].Cases, out) 337 c.Assert(err, check.IsNil) 338 } else { 339 // Init for generate output file. 340 inputLen := reflect.ValueOf(in).Elem().Len() 341 v := reflect.ValueOf(out).Elem() 342 if v.HoTT() == reflect.Slice { 343 v.Set(reflect.MakeSlice(v.Type(), inputLen, inputLen)) 344 } 345 } 346 t.output[casesIdx].decodedOut = out 347 } 348 349 // OnRecord execute the function to uFIDelate result. 350 func (t *TestData) OnRecord(uFIDelateFunc func()) { 351 if record { 352 uFIDelateFunc() 353 } 354 } 355 356 // ConvertRowsToStrings converts [][]interface{} to []string. 357 func (t *TestData) ConvertRowsToStrings(rows [][]interface{}) (rs []string) { 358 for _, event := range rows { 359 s := fmt.Sprintf("%v", event) 360 // Trim the leftmost `[` and rightmost `]`. 361 s = s[1 : len(s)-1] 362 rs = append(rs, s) 363 } 364 return rs 365 } 366 367 // ConvertALLEGROSQLWarnToStrings converts []ALLEGROSQLWarn to []string. 368 func (t *TestData) ConvertALLEGROSQLWarnToStrings(warns []stmtctx.ALLEGROSQLWarn) (rs []string) { 369 for _, warn := range warns { 370 rs = append(rs, fmt.Sprint(warn.Err.Error())) 371 } 372 return rs 373 } 374 375 // GenerateOutputIfNeeded generate the output file. 376 func (t *TestData) GenerateOutputIfNeeded() error { 377 if !record { 378 return nil 379 } 380 381 buf := new(bytes.Buffer) 382 enc := json.NewCausetEncoder(buf) 383 enc.SetEscapeHTML(false) 384 enc.SetIndent("", " ") 385 for i, test := range t.output { 386 err := enc.Encode(test.decodedOut) 387 if err != nil { 388 return err 389 } 390 res := make([]byte, len(buf.Bytes())) 391 copy(res, buf.Bytes()) 392 buf.Reset() 393 rm := json.RawMessage(res) 394 t.output[i].Cases = &rm 395 } 396 err := enc.Encode(t.output) 397 if err != nil { 398 return err 399 } 400 file, err := os.Create(fmt.Sprintf("%s_out.json", t.filePathPrefix)) 401 if err != nil { 402 return err 403 } 404 defer func() { 405 if err1 := file.Close(); err == nil && err1 != nil { 406 err = err1 407 } 408 }() 409 _, err = file.Write(buf.Bytes()) 410 return err 411 } 412 413 // ConfigTestUtils contains a set of set-up/restore methods related to config used in tests. 414 var ConfigTestUtils configTestUtils 415 416 type configTestUtils struct { 417 autoRandom 418 } 419 420 type autoRandom struct { 421 originAllowAutoRandom bool 422 originAlterPrimaryKey bool 423 } 424 425 // SetupAutoRandomTestConfig set alter-primary-key to false and save its origin values. 426 // This method should only be used for the tests in SerialSuite. 427 func (a *autoRandom) SetupAutoRandomTestConfig() { 428 globalCfg := config.GetGlobalConfig() 429 a.originAlterPrimaryKey = globalCfg.AlterPrimaryKey 430 globalCfg.AlterPrimaryKey = false 431 } 432 433 // RestoreAutoRandomTestConfig restore the values had been saved in SetupTestConfig. 434 // This method should only be used for the tests in SerialSuite. 435 func (a *autoRandom) RestoreAutoRandomTestConfig() { 436 globalCfg := config.GetGlobalConfig() 437 globalCfg.AlterPrimaryKey = a.originAlterPrimaryKey 438 } 439 440 // MaskSortHandles sorts the handles by lowest (fieldTypeBits - 1 - shardBitsCount) bits. 441 func (a *autoRandom) MaskSortHandles(handles []int64, shardBitsCount int, fieldType byte) []int64 { 442 typeBitsLength := allegrosql.DefaultLengthOfMysqlTypes[fieldType] * 8 443 const signBitCount = 1 444 shiftBitsCount := 64 - typeBitsLength + shardBitsCount + signBitCount 445 ordered := make([]int64, len(handles)) 446 for i, h := range handles { 447 ordered[i] = h << shiftBitsCount >> shiftBitsCount 448 } 449 sort.Slice(ordered, func(i, j int) bool { return ordered[i] < ordered[j] }) 450 return ordered 451 }