github.com/whtcorpsinc/milevadb-prod@v0.0.0-20211104133533-f57f4be3b597/soliton/testutil/testutil.go (about)

     1  // Copyright 2020 WHTCORPS INC, Inc.
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //     http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // See the License for the specific language governing permissions and
    12  // limitations under the License.
    13  
    14  //go:build !codes
    15  // +build !codes
    16  
    17  package solitonutil
    18  
    19  import (
    20  	"bytes"
    21  	"encoding/json"
    22  	"flag"
    23  	"fmt"
    24  	"io/ioutil"
    25  	"os"
    26  	"path/filepath"
    27  	"reflect"
    28  	"regexp"
    29  	"runtime"
    30  	"sort"
    31  	"strings"
    32  
    33  	"github.com/whtcorpsinc/BerolinaSQL/allegrosql"
    34  	"github.com/whtcorpsinc/check"
    35  	"github.com/whtcorpsinc/errors"
    36  	"github.com/whtcorpsinc/milevadb/config"
    37  	"github.com/whtcorpsinc/milevadb/ekv"
    38  	"github.com/whtcorpsinc/milevadb/soliton/codec"
    39  	"github.com/whtcorpsinc/milevadb/soliton/logutil"
    40  	"github.com/whtcorpsinc/milevadb/stochastikctx/stmtctx"
    41  	"github.com/whtcorpsinc/milevadb/types"
    42  	"go.uber.org/zap"
    43  )
    44  
    45  // CompareUnorderedStringSlice compare two string slices.
    46  // If a and b is exactly the same except the order, it returns true.
    47  // In otherwise return false.
    48  func CompareUnorderedStringSlice(a []string, b []string) bool {
    49  	if a == nil && b == nil {
    50  		return true
    51  	}
    52  	if a == nil || b == nil {
    53  		return false
    54  	}
    55  	if len(a) != len(b) {
    56  		return false
    57  	}
    58  	m := make(map[string]int, len(a))
    59  	for _, i := range a {
    60  		_, ok := m[i]
    61  		if !ok {
    62  			m[i] = 1
    63  		} else {
    64  			m[i]++
    65  		}
    66  	}
    67  
    68  	for _, i := range b {
    69  		_, ok := m[i]
    70  		if !ok {
    71  			return false
    72  		}
    73  		m[i]--
    74  		if m[i] == 0 {
    75  			delete(m, i)
    76  		}
    77  	}
    78  	return len(m) == 0
    79  }
    80  
    81  // datumEqualsChecker is a checker for CausetEquals.
    82  type datumEqualsChecker struct {
    83  	*check.CheckerInfo
    84  }
    85  
    86  // CausetEquals checker verifies that the obtained value is equal to
    87  // the expected value.
    88  // For example:
    89  //     c.Assert(value, CausetEquals, NewCauset(42))
    90  var CausetEquals check.Checker = &datumEqualsChecker{
    91  	&check.CheckerInfo{Name: "CausetEquals", Params: []string{"obtained", "expected"}},
    92  }
    93  
    94  func (checker *datumEqualsChecker) Check(params []interface{}, names []string) (result bool, error string) {
    95  	defer func() {
    96  		if v := recover(); v != nil {
    97  			result = false
    98  			error = fmt.Sprint(v)
    99  			logutil.BgLogger().Error("panic in datumEqualsChecker.Check",
   100  				zap.Reflect("r", v),
   101  				zap.Stack("stack trace"))
   102  		}
   103  	}()
   104  	paramFirst, ok := params[0].(types.Causet)
   105  	if !ok {
   106  		panic("the first param should be causet")
   107  	}
   108  	paramSecond, ok := params[1].(types.Causet)
   109  	if !ok {
   110  		panic("the second param should be causet")
   111  	}
   112  	sc := new(stmtctx.StatementContext)
   113  	res, err := paramFirst.CompareCauset(sc, &paramSecond)
   114  	if err != nil {
   115  		panic(err)
   116  	}
   117  	return res == 0, ""
   118  }
   119  
   120  // MustNewCommonHandle create a common handle with given values.
   121  func MustNewCommonHandle(c *check.C, values ...interface{}) ekv.Handle {
   122  	encoded, err := codec.EncodeKey(new(stmtctx.StatementContext), nil, types.MakeCausets(values...)...)
   123  	c.Assert(err, check.IsNil)
   124  	ch, err := ekv.NewCommonHandle(encoded)
   125  	c.Assert(err, check.IsNil)
   126  	return ch
   127  }
   128  
   129  // CommonHandleSuite is used to adapt ekv.CommonHandle to existing ekv.IntHandle tests.
   130  //  Usage:
   131  //   type MyTestSuite struct {
   132  //       CommonHandleSuite
   133  //   }
   134  //   func (s *MyTestSuite) TestSomething(c *C) {
   135  //       // ...
   136  //       s.RerunWithCommonHandleEnabled(c, s.TestSomething)
   137  //   }
   138  type CommonHandleSuite struct {
   139  	IsCommonHandle bool
   140  }
   141  
   142  // RerunWithCommonHandleEnabled runs a test function with IsCommonHandle enabled.
   143  func (chs *CommonHandleSuite) RerunWithCommonHandleEnabled(c *check.C, f func(*check.C)) {
   144  	if !chs.IsCommonHandle {
   145  		chs.IsCommonHandle = true
   146  		f(c)
   147  		chs.IsCommonHandle = false
   148  	}
   149  }
   150  
   151  // NewHandle create a handle according to CommonHandleSuite.IsCommonHandle.
   152  func (chs *CommonHandleSuite) NewHandle() *commonHandleSuiteNewHandleBuilder {
   153  	return &commonHandleSuiteNewHandleBuilder{isCommon: chs.IsCommonHandle}
   154  }
   155  
   156  type commonHandleSuiteNewHandleBuilder struct {
   157  	isCommon   bool
   158  	intVal     int64
   159  	commonVals []interface{}
   160  }
   161  
   162  func (c *commonHandleSuiteNewHandleBuilder) Int(v int64) *commonHandleSuiteNewHandleBuilder {
   163  	c.intVal = v
   164  	return c
   165  }
   166  
   167  func (c *commonHandleSuiteNewHandleBuilder) Common(vs ...interface{}) ekv.Handle {
   168  	c.commonVals = vs
   169  	return c.Build()
   170  }
   171  
   172  func (c *commonHandleSuiteNewHandleBuilder) Build() ekv.Handle {
   173  	if c.isCommon {
   174  		encoded, err := codec.EncodeKey(new(stmtctx.StatementContext), nil, types.MakeCausets(c.commonVals...)...)
   175  		if err != nil {
   176  			panic(err)
   177  		}
   178  		ch, err := ekv.NewCommonHandle(encoded)
   179  		if err != nil {
   180  			panic(err)
   181  		}
   182  		return ch
   183  	}
   184  	return ekv.IntHandle(c.intVal)
   185  }
   186  
   187  type handleEqualsChecker struct {
   188  	*check.CheckerInfo
   189  }
   190  
   191  // HandleEquals checker verifies that the obtained handle is equal to
   192  // the expected handle.
   193  // For example:
   194  //     c.Assert(value, HandleEquals, ekv.IntHandle(42))
   195  var HandleEquals = &handleEqualsChecker{
   196  	&check.CheckerInfo{Name: "HandleEquals", Params: []string{"obtained", "expected"}},
   197  }
   198  
   199  func (checker *handleEqualsChecker) Check(params []interface{}, names []string) (result bool, error string) {
   200  	if params[0] == nil && params[1] == nil {
   201  		return true, ""
   202  	}
   203  	param1, ok1 := params[0].(ekv.Handle)
   204  	param2, ok2 := params[1].(ekv.Handle)
   205  	if !ok1 || !ok2 {
   206  		return false, "Argument to " + checker.Name + " must be ekv.Handle"
   207  	}
   208  	if param1.IsInt() != param2.IsInt() {
   209  		return false, "Two handle types arguments to" + checker.Name + " must be same"
   210  	}
   211  
   212  	return param1.String() == param2.String(), ""
   213  }
   214  
   215  // RowsWithSep is a convenient function to wrap args to a slice of []interface.
   216  // The arg represents a event, split by sep.
   217  func RowsWithSep(sep string, args ...string) [][]interface{} {
   218  	rows := make([][]interface{}, len(args))
   219  	for i, v := range args {
   220  		strs := strings.Split(v, sep)
   221  		event := make([]interface{}, len(strs))
   222  		for j, s := range strs {
   223  			event[j] = s
   224  		}
   225  		rows[i] = event
   226  	}
   227  	return rows
   228  }
   229  
// record reports whether the tests run in record mode, in which the expected
// output file is generated instead of being read and compared. It is bound to
// the "record" command-line flag by init and defaults to false.
var record bool

// init registers the "record" boolean flag on the default flag set.
func init() {
	flag.BoolVar(&record, "record", false, "to generate test result")
}
   236  
// testCases is one named entry of a test suite JSON file. Name and Cases are
// the JSON-visible fields; decodedOut is never serialized directly.
type testCases struct {
	Name       string
	Cases      *json.RawMessage // Raw JSON kept undecoded until a test asks for it.
	decodedOut interface{}      // Output holder that GenerateOutputIfNeeded encodes into Cases.
}
   242  
// TestData stores all the data of a test suite.
type TestData struct {
	// input holds the cases loaded from "<filePathPrefix>_in.json".
	input          []testCases
	// output holds the cases loaded from "<filePathPrefix>_out.json", or, in
	// record mode, empty entries named after input that are filled in later.
	output         []testCases
	// filePathPrefix is the suite directory joined with the suite name.
	filePathPrefix string
	// funcMap maps a case name to its index in input and output.
	funcMap        map[string]int
}
   250  
   251  // LoadTestSuiteData loads test suite data from file.
   252  func LoadTestSuiteData(dir, suiteName string) (res TestData, err error) {
   253  	res.filePathPrefix = filepath.Join(dir, suiteName)
   254  	res.input, err = loadTestSuiteCases(fmt.Sprintf("%s_in.json", res.filePathPrefix))
   255  	if err != nil {
   256  		return res, err
   257  	}
   258  	if record {
   259  		res.output = make([]testCases, len(res.input))
   260  		for i := range res.input {
   261  			res.output[i].Name = res.input[i].Name
   262  		}
   263  	} else {
   264  		res.output, err = loadTestSuiteCases(fmt.Sprintf("%s_out.json", res.filePathPrefix))
   265  		if err != nil {
   266  			return res, err
   267  		}
   268  		if len(res.input) != len(res.output) {
   269  			return res, errors.New(fmt.Sprintf("Number of test input cases %d does not match test output cases %d", len(res.input), len(res.output)))
   270  		}
   271  	}
   272  	res.funcMap = make(map[string]int, len(res.input))
   273  	for i, test := range res.input {
   274  		res.funcMap[test.Name] = i
   275  		if test.Name != res.output[i].Name {
   276  			return res, errors.New(fmt.Sprintf("Input name of the %d-case %s does not match output %s", i, test.Name, res.output[i].Name))
   277  		}
   278  	}
   279  	return res, nil
   280  }
   281  
   282  func loadTestSuiteCases(filePath string) (res []testCases, err error) {
   283  	jsonFile, err := os.Open(filePath)
   284  	if err != nil {
   285  		return res, err
   286  	}
   287  	defer func() {
   288  		if err1 := jsonFile.Close(); err == nil && err1 != nil {
   289  			err = err1
   290  		}
   291  	}()
   292  	byteValue, err := ioutil.ReadAll(jsonFile)
   293  	if err != nil {
   294  		return res, err
   295  	}
   296  	// Remove comments, since they are not allowed in json.
   297  	re := regexp.MustCompile("(?s)//.*?\n")
   298  	err = json.Unmarshal(re.ReplaceAll(byteValue, nil), &res)
   299  	return res, err
   300  }
   301  
   302  // GetTestCasesByName gets the test cases for a test function by its name.
   303  func (t *TestData) GetTestCasesByName(caseName string, c *check.C, in interface{}, out interface{}) {
   304  	casesIdx, ok := t.funcMap[caseName]
   305  	c.Assert(ok, check.IsTrue, check.Commentf("Must get test %s", caseName))
   306  	err := json.Unmarshal(*t.input[casesIdx].Cases, in)
   307  	c.Assert(err, check.IsNil)
   308  	if !record {
   309  		err = json.Unmarshal(*t.output[casesIdx].Cases, out)
   310  		c.Assert(err, check.IsNil)
   311  	} else {
   312  		// Init for generate output file.
   313  		inputLen := reflect.ValueOf(in).Elem().Len()
   314  		v := reflect.ValueOf(out).Elem()
   315  		if v.HoTT() == reflect.Slice {
   316  			v.Set(reflect.MakeSlice(v.Type(), inputLen, inputLen))
   317  		}
   318  	}
   319  	t.output[casesIdx].decodedOut = out
   320  }
   321  
   322  // GetTestCases gets the test cases for a test function.
   323  func (t *TestData) GetTestCases(c *check.C, in interface{}, out interface{}) {
   324  	// Extract caller's name.
   325  	pc, _, _, ok := runtime.Caller(1)
   326  	c.Assert(ok, check.IsTrue)
   327  	details := runtime.FuncForPC(pc)
   328  	funcNameIdx := strings.LastIndex(details.Name(), ".")
   329  	funcName := details.Name()[funcNameIdx+1:]
   330  
   331  	casesIdx, ok := t.funcMap[funcName]
   332  	c.Assert(ok, check.IsTrue, check.Commentf("Must get test %s", funcName))
   333  	err := json.Unmarshal(*t.input[casesIdx].Cases, in)
   334  	c.Assert(err, check.IsNil)
   335  	if !record {
   336  		err = json.Unmarshal(*t.output[casesIdx].Cases, out)
   337  		c.Assert(err, check.IsNil)
   338  	} else {
   339  		// Init for generate output file.
   340  		inputLen := reflect.ValueOf(in).Elem().Len()
   341  		v := reflect.ValueOf(out).Elem()
   342  		if v.HoTT() == reflect.Slice {
   343  			v.Set(reflect.MakeSlice(v.Type(), inputLen, inputLen))
   344  		}
   345  	}
   346  	t.output[casesIdx].decodedOut = out
   347  }
   348  
   349  // OnRecord execute the function to uFIDelate result.
   350  func (t *TestData) OnRecord(uFIDelateFunc func()) {
   351  	if record {
   352  		uFIDelateFunc()
   353  	}
   354  }
   355  
   356  // ConvertRowsToStrings converts [][]interface{} to []string.
   357  func (t *TestData) ConvertRowsToStrings(rows [][]interface{}) (rs []string) {
   358  	for _, event := range rows {
   359  		s := fmt.Sprintf("%v", event)
   360  		// Trim the leftmost `[` and rightmost `]`.
   361  		s = s[1 : len(s)-1]
   362  		rs = append(rs, s)
   363  	}
   364  	return rs
   365  }
   366  
   367  // ConvertALLEGROSQLWarnToStrings converts []ALLEGROSQLWarn to []string.
   368  func (t *TestData) ConvertALLEGROSQLWarnToStrings(warns []stmtctx.ALLEGROSQLWarn) (rs []string) {
   369  	for _, warn := range warns {
   370  		rs = append(rs, fmt.Sprint(warn.Err.Error()))
   371  	}
   372  	return rs
   373  }
   374  
   375  // GenerateOutputIfNeeded generate the output file.
   376  func (t *TestData) GenerateOutputIfNeeded() error {
   377  	if !record {
   378  		return nil
   379  	}
   380  
   381  	buf := new(bytes.Buffer)
   382  	enc := json.NewCausetEncoder(buf)
   383  	enc.SetEscapeHTML(false)
   384  	enc.SetIndent("", "  ")
   385  	for i, test := range t.output {
   386  		err := enc.Encode(test.decodedOut)
   387  		if err != nil {
   388  			return err
   389  		}
   390  		res := make([]byte, len(buf.Bytes()))
   391  		copy(res, buf.Bytes())
   392  		buf.Reset()
   393  		rm := json.RawMessage(res)
   394  		t.output[i].Cases = &rm
   395  	}
   396  	err := enc.Encode(t.output)
   397  	if err != nil {
   398  		return err
   399  	}
   400  	file, err := os.Create(fmt.Sprintf("%s_out.json", t.filePathPrefix))
   401  	if err != nil {
   402  		return err
   403  	}
   404  	defer func() {
   405  		if err1 := file.Close(); err == nil && err1 != nil {
   406  			err = err1
   407  		}
   408  	}()
   409  	_, err = file.Write(buf.Bytes())
   410  	return err
   411  }
   412  
   413  // ConfigTestUtils contains a set of set-up/restore methods related to config used in tests.
   414  var ConfigTestUtils configTestUtils
   415  
   416  type configTestUtils struct {
   417  	autoRandom
   418  }
   419  
   420  type autoRandom struct {
   421  	originAllowAutoRandom bool
   422  	originAlterPrimaryKey bool
   423  }
   424  
   425  // SetupAutoRandomTestConfig set alter-primary-key to false and save its origin values.
   426  // This method should only be used for the tests in SerialSuite.
   427  func (a *autoRandom) SetupAutoRandomTestConfig() {
   428  	globalCfg := config.GetGlobalConfig()
   429  	a.originAlterPrimaryKey = globalCfg.AlterPrimaryKey
   430  	globalCfg.AlterPrimaryKey = false
   431  }
   432  
   433  // RestoreAutoRandomTestConfig restore the values had been saved in SetupTestConfig.
   434  // This method should only be used for the tests in SerialSuite.
   435  func (a *autoRandom) RestoreAutoRandomTestConfig() {
   436  	globalCfg := config.GetGlobalConfig()
   437  	globalCfg.AlterPrimaryKey = a.originAlterPrimaryKey
   438  }
   439  
   440  // MaskSortHandles sorts the handles by lowest (fieldTypeBits - 1 - shardBitsCount) bits.
   441  func (a *autoRandom) MaskSortHandles(handles []int64, shardBitsCount int, fieldType byte) []int64 {
   442  	typeBitsLength := allegrosql.DefaultLengthOfMysqlTypes[fieldType] * 8
   443  	const signBitCount = 1
   444  	shiftBitsCount := 64 - typeBitsLength + shardBitsCount + signBitCount
   445  	ordered := make([]int64, len(handles))
   446  	for i, h := range handles {
   447  		ordered[i] = h << shiftBitsCount >> shiftBitsCount
   448  	}
   449  	sort.Slice(ordered, func(i, j int) bool { return ordered[i] < ordered[j] })
   450  	return ordered
   451  }