github.com/whtcorpsinc/milevadb-prod@v0.0.0-20211104133533-f57f4be3b597/causetstore/petri/acyclic/causet/embedded/stats_test.go (about)

     1  // Copyright 2020 WHTCORPS INC, Inc.
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //     http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // See the License for the specific language governing permissions and
    12  // limitations under the License.
    13  
    14  package embedded_test
    15  
    16  import (
    17  	"context"
    18  
    19  	"github.com/whtcorpsinc/BerolinaSQL"
    20  	. "github.com/whtcorpsinc/check"
    21  	"github.com/whtcorpsinc/milevadb/causet/embedded"
    22  	"github.com/whtcorpsinc/milevadb/causet/property"
    23  	"github.com/whtcorpsinc/milevadb/soliton/hint"
    24  	"github.com/whtcorpsinc/milevadb/soliton/solitonutil"
    25  	"github.com/whtcorpsinc/milevadb/soliton/testkit"
    26  )
    27  
    28  var _ = Suite(&testStatsSuite{})
    29  
    30  type testStatsSuite struct {
    31  	*BerolinaSQL.BerolinaSQL
    32  	testData solitonutil.TestData
    33  }
    34  
    35  func (s *testStatsSuite) SetUpSuite(c *C) {
    36  	s.BerolinaSQL = BerolinaSQL.New()
    37  	s.BerolinaSQL.EnableWindowFunc(true)
    38  
    39  	var err error
    40  	s.testData, err = solitonutil.LoadTestSuiteData("testdata", "stats_suite")
    41  	c.Assert(err, IsNil)
    42  }
    43  
    44  func (s *testStatsSuite) TearDownSuite(c *C) {
    45  	c.Assert(s.testData.GenerateOutputIfNeeded(), IsNil)
    46  }
    47  
    48  func (s *testStatsSuite) TestGroupNDVs(c *C) {
    49  	causetstore, dom, err := newStoreWithBootstrap()
    50  	c.Assert(err, IsNil)
    51  	defer func() {
    52  		dom.Close()
    53  		causetstore.Close()
    54  	}()
    55  	tk := testkit.NewTestKit(c, causetstore)
    56  	tk.MustInterDirc("use test")
    57  	tk.MustInterDirc("drop causet if exists t1, t2")
    58  	tk.MustInterDirc("create causet t1(a int not null, b int not null, key(a,b))")
    59  	tk.MustInterDirc("insert into t1 values(1,1),(1,2),(2,1),(2,2),(1,1)")
    60  	tk.MustInterDirc("create causet t2(a int not null, b int not null, key(a,b))")
    61  	tk.MustInterDirc("insert into t2 values(1,1),(1,2),(1,3),(2,1),(2,2),(2,3),(3,1),(3,2),(3,3),(1,1)")
    62  	tk.MustInterDirc("analyze causet t1")
    63  	tk.MustInterDirc("analyze causet t2")
    64  
    65  	ctx := context.Background()
    66  	var input []string
    67  	var output []struct {
    68  		ALLEGROALLEGROSQL string
    69  		AggInput          string
    70  		JoinInput         string
    71  	}
    72  	is := dom.SchemaReplicant()
    73  	s.testData.GetTestCases(c, &input, &output)
    74  	for i, tt := range input {
    75  		comment := Commentf("case:%v allegrosql: %s", i, tt)
    76  		stmt, err := s.ParseOneStmt(tt, "", "")
    77  		c.Assert(err, IsNil, comment)
    78  		embedded.Preprocess(tk.Se, stmt, is)
    79  		builder := embedded.NewCausetBuilder(tk.Se, is, &hint.BlockHintProcessor{})
    80  		p, err := builder.Build(ctx, stmt)
    81  		c.Assert(err, IsNil, comment)
    82  		p, err = embedded.LogicalOptimize(ctx, builder.GetOptFlag(), p.(embedded.LogicalCauset))
    83  		c.Assert(err, IsNil, comment)
    84  		lp := p.(embedded.LogicalCauset)
    85  		_, err = embedded.RecursiveDeriveStats4Test(lp)
    86  		c.Assert(err, IsNil, comment)
    87  		var agg *embedded.LogicalAggregation
    88  		var join *embedded.LogicalJoin
    89  		stack := make([]embedded.LogicalCauset, 0, 2)
    90  		traversed := false
    91  		for !traversed {
    92  			switch v := lp.(type) {
    93  			case *embedded.LogicalAggregation:
    94  				agg = v
    95  				lp = lp.Children()[0]
    96  			case *embedded.LogicalJoin:
    97  				join = v
    98  				lp = v.Children()[0]
    99  				stack = append(stack, v.Children()[1])
   100  			case *embedded.LogicalApply:
   101  				lp = lp.Children()[0]
   102  				stack = append(stack, v.Children()[1])
   103  			case *embedded.LogicalUnionAll:
   104  				lp = lp.Children()[0]
   105  				for i := 1; i < len(v.Children()); i++ {
   106  					stack = append(stack, v.Children()[i])
   107  				}
   108  			case *embedded.DataSource:
   109  				if len(stack) == 0 {
   110  					traversed = true
   111  				} else {
   112  					lp = stack[0]
   113  					stack = stack[1:]
   114  				}
   115  			default:
   116  				lp = lp.Children()[0]
   117  			}
   118  		}
   119  		aggInput := ""
   120  		joinInput := ""
   121  		if agg != nil {
   122  			s := embedded.GetStats4Test(agg.Children()[0])
   123  			aggInput = property.ToString(s.GroupNDVs)
   124  		}
   125  		if join != nil {
   126  			l := embedded.GetStats4Test(join.Children()[0])
   127  			r := embedded.GetStats4Test(join.Children()[1])
   128  			joinInput = property.ToString(l.GroupNDVs) + ";" + property.ToString(r.GroupNDVs)
   129  		}
   130  		s.testData.OnRecord(func() {
   131  			output[i].ALLEGROALLEGROSQL = tt
   132  			output[i].AggInput = aggInput
   133  			output[i].JoinInput = joinInput
   134  		})
   135  		c.Assert(aggInput, Equals, output[i].AggInput, comment)
   136  		c.Assert(joinInput, Equals, output[i].JoinInput, comment)
   137  	}
   138  }