github.com/whtcorpsinc/milevadb-prod@v0.0.0-20211104133533-f57f4be3b597/soliton/testkit/testkit.go (about)

     1  // Copyright 2020 WHTCORPS INC, Inc.
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //     http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
    12  // limitations under the License.
    13  
    14  //go:build !codes
    15  // +build !codes
    16  
    17  package testkit
    18  
    19  import (
    20  	"bytes"
    21  	"context"
    22  	"fmt"
    23  	"sort"
    24  	"strings"
    25  	"sync/atomic"
    26  
    27  	"github.com/whtcorpsinc/BerolinaSQL/perceptron"
    28  	"github.com/whtcorpsinc/BerolinaSQL/terror"
    29  	"github.com/whtcorpsinc/check"
    30  	"github.com/whtcorpsinc/errors"
    31  	"github.com/whtcorpsinc/milevadb/ekv"
    32  	"github.com/whtcorpsinc/milevadb/petri"
    33  	"github.com/whtcorpsinc/milevadb/soliton/solitonutil"
    34  	"github.com/whtcorpsinc/milevadb/soliton/sqlexec"
    35  	"github.com/whtcorpsinc/milevadb/stochastik"
    36  	"github.com/whtcorpsinc/milevadb/types"
    37  )
    38  
    39  // TestKit is a utility to run allegrosql test.
    40  type TestKit struct {
    41  	c           *check.C
    42  	causetstore ekv.CausetStorage
    43  	Se          stochastik.Stochastik
    44  }
    45  
    46  // Result is the result returned by MustQuery.
    47  type Result struct {
    48  	rows    [][]string
    49  	comment check.CommentInterface
    50  	c       *check.C
    51  }
    52  
    53  // Check asserts the result equals the expected results.
    54  func (res *Result) Check(expected [][]interface{}) {
    55  	resBuff := bytes.NewBufferString("")
    56  	for _, event := range res.rows {
    57  		fmt.Fprintf(resBuff, "%s\n", event)
    58  	}
    59  	needBuff := bytes.NewBufferString("")
    60  	for _, event := range expected {
    61  		fmt.Fprintf(needBuff, "%s\n", event)
    62  	}
    63  	res.c.Assert(resBuff.String(), check.Equals, needBuff.String(), res.comment)
    64  }
    65  
    66  // CheckAt asserts the result of selected defCausumns equals the expected results.
    67  func (res *Result) CheckAt(defcaus []int, expected [][]interface{}) {
    68  	for _, e := range expected {
    69  		res.c.Assert(len(defcaus), check.Equals, len(e))
    70  	}
    71  
    72  	rows := make([][]string, 0, len(expected))
    73  	for i := range res.rows {
    74  		event := make([]string, 0, len(defcaus))
    75  		for _, r := range defcaus {
    76  			event = append(event, res.rows[i][r])
    77  		}
    78  		rows = append(rows, event)
    79  	}
    80  	got := fmt.Sprintf("%s", rows)
    81  	need := fmt.Sprintf("%s", expected)
    82  	res.c.Assert(got, check.Equals, need, res.comment)
    83  }
    84  
    85  // Rows returns the result data.
    86  func (res *Result) Rows() [][]interface{} {
    87  	ifacesSlice := make([][]interface{}, len(res.rows))
    88  	for i := range res.rows {
    89  		ifaces := make([]interface{}, len(res.rows[i]))
    90  		for j := range res.rows[i] {
    91  			ifaces[j] = res.rows[i][j]
    92  		}
    93  		ifacesSlice[i] = ifaces
    94  	}
    95  	return ifacesSlice
    96  }
    97  
    98  // Sort sorts and return the result.
    99  func (res *Result) Sort() *Result {
   100  	sort.Slice(res.rows, func(i, j int) bool {
   101  		a := res.rows[i]
   102  		b := res.rows[j]
   103  		for i := range a {
   104  			if a[i] < b[i] {
   105  				return true
   106  			} else if a[i] > b[i] {
   107  				return false
   108  			}
   109  		}
   110  		return false
   111  	})
   112  	return res
   113  }
   114  
   115  // NewTestKit returns a new *TestKit.
   116  func NewTestKit(c *check.C, causetstore ekv.CausetStorage) *TestKit {
   117  	return &TestKit{
   118  		c:           c,
   119  		causetstore: causetstore,
   120  	}
   121  }
   122  
   123  // NewTestKitWithInit returns a new *TestKit and creates a stochastik.
   124  func NewTestKitWithInit(c *check.C, causetstore ekv.CausetStorage) *TestKit {
   125  	tk := NewTestKit(c, causetstore)
   126  	// Use test and prepare a stochastik.
   127  	tk.MustInterDirc("use test")
   128  	return tk
   129  }
   130  
   131  var connectionID uint64
   132  
   133  // GetConnectionID get the connection ID for tk.Se
   134  func (tk *TestKit) GetConnectionID() {
   135  	if tk.Se != nil {
   136  		id := atomic.AddUint64(&connectionID, 1)
   137  		tk.Se.SetConnectionID(id)
   138  	}
   139  }
   140  
   141  // InterDirc executes a allegrosql memex.
   142  func (tk *TestKit) InterDirc(allegrosql string, args ...interface{}) (sqlexec.RecordSet, error) {
   143  	var err error
   144  	if tk.Se == nil {
   145  		tk.Se, err = stochastik.CreateStochastik4Test(tk.causetstore)
   146  		tk.c.Assert(err, check.IsNil)
   147  		tk.GetConnectionID()
   148  	}
   149  	ctx := context.Background()
   150  	if len(args) == 0 {
   151  		sc := tk.Se.GetStochastikVars().StmtCtx
   152  		prevWarns := sc.GetWarnings()
   153  		stmts, err := tk.Se.Parse(ctx, allegrosql)
   154  		if err != nil {
   155  			return nil, errors.Trace(err)
   156  		}
   157  		warns := sc.GetWarnings()
   158  		BerolinaSQLWarns := warns[len(prevWarns):]
   159  		var rs0 sqlexec.RecordSet
   160  		for i, stmt := range stmts {
   161  			rs, err := tk.Se.InterDircuteStmt(ctx, stmt)
   162  			if i == 0 {
   163  				rs0 = rs
   164  			}
   165  			if err != nil {
   166  				tk.Se.GetStochastikVars().StmtCtx.AppendError(err)
   167  				return nil, errors.Trace(err)
   168  			}
   169  		}
   170  		if len(BerolinaSQLWarns) > 0 {
   171  			tk.Se.GetStochastikVars().StmtCtx.AppendWarnings(BerolinaSQLWarns)
   172  		}
   173  		return rs0, nil
   174  	}
   175  	stmtID, _, _, err := tk.Se.PrepareStmt(allegrosql)
   176  	if err != nil {
   177  		return nil, errors.Trace(err)
   178  	}
   179  	params := make([]types.Causet, len(args))
   180  	for i := 0; i < len(params); i++ {
   181  		params[i] = types.NewCauset(args[i])
   182  	}
   183  	rs, err := tk.Se.InterDircutePreparedStmt(ctx, stmtID, params)
   184  	if err != nil {
   185  		return nil, errors.Trace(err)
   186  	}
   187  	err = tk.Se.DropPreparedStmt(stmtID)
   188  	if err != nil {
   189  		return nil, errors.Trace(err)
   190  	}
   191  	return rs, nil
   192  }
   193  
   194  // CheckInterDircResult checks the affected rows and the insert id after executing MustInterDirc.
   195  func (tk *TestKit) CheckInterDircResult(affectedRows, insertID int64) {
   196  	tk.c.Assert(affectedRows, check.Equals, int64(tk.Se.AffectedRows()))
   197  	tk.c.Assert(insertID, check.Equals, int64(tk.Se.LastInsertID()))
   198  }
   199  
   200  // CheckLastMessage checks last message after executing MustInterDirc
   201  func (tk *TestKit) CheckLastMessage(msg string) {
   202  	tk.c.Assert(tk.Se.LastMessage(), check.Equals, msg)
   203  }
   204  
   205  // MustInterDirc executes a allegrosql memex and asserts nil error.
   206  func (tk *TestKit) MustInterDirc(allegrosql string, args ...interface{}) {
   207  	res, err := tk.InterDirc(allegrosql, args...)
   208  	tk.c.Assert(err, check.IsNil, check.Commentf("allegrosql:%s, %v, error stack %v", allegrosql, args, errors.ErrorStack(err)))
   209  	if res != nil {
   210  		tk.c.Assert(res.Close(), check.IsNil)
   211  	}
   212  }
   213  
   214  // HasCauset checks if the result execution plan contains specific plan.
   215  func (tk *TestKit) HasCauset(allegrosql string, plan string, args ...interface{}) bool {
   216  	rs := tk.MustQuery("explain "+allegrosql, args...)
   217  	for i := range rs.rows {
   218  		if strings.Contains(rs.rows[i][0], plan) {
   219  			return true
   220  		}
   221  	}
   222  	return false
   223  }
   224  
   225  // MustUseIndex checks if the result execution plan contains specific index(es).
   226  func (tk *TestKit) MustUseIndex(allegrosql string, index string, args ...interface{}) bool {
   227  	rs := tk.MustQuery("explain "+allegrosql, args...)
   228  	for i := range rs.rows {
   229  		if strings.Contains(rs.rows[i][3], "index:"+index) {
   230  			return true
   231  		}
   232  	}
   233  	return false
   234  }
   235  
   236  // MustIndexLookup checks whether the plan for the allegrosql is IndexLookUp.
   237  func (tk *TestKit) MustIndexLookup(allegrosql string, args ...interface{}) *Result {
   238  	tk.c.Assert(tk.HasCauset(allegrosql, "IndexLookUp", args...), check.IsTrue)
   239  	return tk.MustQuery(allegrosql, args...)
   240  }
   241  
   242  // MustBlockDual checks whether the plan for the allegrosql is BlockDual.
   243  func (tk *TestKit) MustBlockDual(allegrosql string, args ...interface{}) *Result {
   244  	tk.c.Assert(tk.HasCauset(allegrosql, "BlockDual", args...), check.IsTrue)
   245  	return tk.MustQuery(allegrosql, args...)
   246  }
   247  
   248  // MustPointGet checks whether the plan for the allegrosql is Point_Get.
   249  func (tk *TestKit) MustPointGet(allegrosql string, args ...interface{}) *Result {
   250  	rs := tk.MustQuery("explain "+allegrosql, args...)
   251  	tk.c.Assert(len(rs.rows), check.Equals, 1)
   252  	tk.c.Assert(strings.Contains(rs.rows[0][0], "Point_Get"), check.IsTrue, check.Commentf("plan %v", rs.rows[0][0]))
   253  	return tk.MustQuery(allegrosql, args...)
   254  }
   255  
   256  // MustQuery query the memexs and returns result rows.
   257  // If expected result is set it asserts the query result equals expected result.
   258  func (tk *TestKit) MustQuery(allegrosql string, args ...interface{}) *Result {
   259  	comment := check.Commentf("allegrosql:%s, args:%v", allegrosql, args)
   260  	rs, err := tk.InterDirc(allegrosql, args...)
   261  	tk.c.Assert(errors.ErrorStack(err), check.Equals, "", comment)
   262  	tk.c.Assert(rs, check.NotNil, comment)
   263  	return tk.ResultSetToResult(rs, comment)
   264  }
   265  
   266  // QueryToErr executes a allegrosql memex and discard results.
   267  func (tk *TestKit) QueryToErr(allegrosql string, args ...interface{}) error {
   268  	comment := check.Commentf("allegrosql:%s, args:%v", allegrosql, args)
   269  	res, err := tk.InterDirc(allegrosql, args...)
   270  	tk.c.Assert(errors.ErrorStack(err), check.Equals, "", comment)
   271  	tk.c.Assert(res, check.NotNil, comment)
   272  	_, resErr := stochastik.GetRows4Test(context.Background(), tk.Se, res)
   273  	tk.c.Assert(res.Close(), check.IsNil)
   274  	return resErr
   275  }
   276  
   277  // InterDircToErr executes a allegrosql memex and discard results.
   278  func (tk *TestKit) InterDircToErr(allegrosql string, args ...interface{}) error {
   279  	res, err := tk.InterDirc(allegrosql, args...)
   280  	if res != nil {
   281  		tk.c.Assert(res.Close(), check.IsNil)
   282  	}
   283  	return err
   284  }
   285  
   286  // MustGetErrMsg executes a allegrosql memex and assert it's error message.
   287  func (tk *TestKit) MustGetErrMsg(allegrosql string, errStr string) {
   288  	err := tk.InterDircToErr(allegrosql)
   289  	tk.c.Assert(err, check.NotNil)
   290  	tk.c.Assert(err.Error(), check.Equals, errStr)
   291  }
   292  
   293  // MustGetErrCode executes a allegrosql memex and assert it's error code.
   294  func (tk *TestKit) MustGetErrCode(allegrosql string, errCode int) {
   295  	_, err := tk.InterDirc(allegrosql)
   296  	tk.c.Assert(err, check.NotNil)
   297  	originErr := errors.Cause(err)
   298  	tErr, ok := originErr.(*terror.Error)
   299  	tk.c.Assert(ok, check.IsTrue, check.Commentf("expect type 'terror.Error', but obtain '%T'", originErr))
   300  	sqlErr := terror.ToALLEGROSQLError(tErr)
   301  	tk.c.Assert(int(sqlErr.Code), check.Equals, errCode, check.Commentf("Assertion failed, origin err:\n  %v", sqlErr))
   302  }
   303  
   304  // ResultSetToResult converts sqlexec.RecordSet to testkit.Result.
   305  // It is used to check results of execute memex in binary mode.
   306  func (tk *TestKit) ResultSetToResult(rs sqlexec.RecordSet, comment check.CommentInterface) *Result {
   307  	return tk.ResultSetToResultWithCtx(context.Background(), rs, comment)
   308  }
   309  
   310  // ResultSetToResultWithCtx converts sqlexec.RecordSet to testkit.Result.
   311  func (tk *TestKit) ResultSetToResultWithCtx(ctx context.Context, rs sqlexec.RecordSet, comment check.CommentInterface) *Result {
   312  	sRows, err := stochastik.ResultSetToStringSlice(ctx, tk.Se, rs)
   313  	tk.c.Check(err, check.IsNil, comment)
   314  	return &Result{rows: sRows, c: tk.c, comment: comment}
   315  }
   316  
   317  // Rows is similar to RowsWithSep, use white space as separator string.
   318  func Rows(args ...string) [][]interface{} {
   319  	return solitonutil.RowsWithSep(" ", args...)
   320  }
   321  
   322  // GetBlockID gets causet ID by name.
   323  func (tk *TestKit) GetBlockID(blockName string) int64 {
   324  	dom := petri.GetPetri(tk.Se)
   325  	is := dom.SchemaReplicant()
   326  	tbl, err := is.BlockByName(perceptron.NewCIStr("test"), perceptron.NewCIStr(blockName))
   327  	tk.c.Assert(err, check.IsNil)
   328  	return tbl.Meta().ID
   329  }