github.com/XiaoMi/Gaea@v1.2.5/proxy/plan/plan_tidb_test.go (about)

     1  // Copyright 2019 The Gaea Authors. All Rights Reserved.
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //     http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License.
    14  
    15  package plan
    16  
    17  import (
    18  	"strings"
    19  	"testing"
    20  
    21  	"github.com/XiaoMi/Gaea/parser"
    22  	"github.com/XiaoMi/Gaea/parser/ast"
    23  	"github.com/XiaoMi/Gaea/parser/format"
    24  )
    25  
    26  func BenchmarkSelectStmtCheckShard(b *testing.B) {
    27  	r, err := prepareShardKingshardRouter()
    28  	if err != nil {
    29  		b.Fatal(err)
    30  	}
    31  	benchmarks := []struct {
    32  		sql     string
    33  		isShard bool
    34  	}{
    35  		{"select * from test_hash_0, test_hash_1 where test_hash_0.id = test_hash_1.id", true},
    36  		{"select * from test_a, test_b left join test_c on test_a.id=test_c.id where test_a.id in (1,2,3) or test_b.k = 0 order by test_a.id desc limit 10", false},
    37  	}
    38  	for _, bm := range benchmarks {
    39  		b.Run(bm.sql, func(b *testing.B) {
    40  			for i := 0; i < b.N; i++ {
    41  				//for i := 0; i < 1; i++ {
    42  				stmt, err := parser.ParseSQL(bm.sql)
    43  				if err != nil {
    44  					b.Fatal(err)
    45  				}
    46  				selectStmt, ok := stmt.(*ast.SelectStmt)
    47  				if !ok {
    48  					b.Fatal("not a select stmt")
    49  				}
    50  				visitor := NewChecker("test", r)
    51  				selectStmt.Accept(visitor)
    52  				if visitor.IsShard() != bm.isShard {
    53  					b.Errorf("isShard not equal, expect: %v, actual: %v", bm.isShard, visitor.IsShard())
    54  				}
    55  			}
    56  		})
    57  	}
    58  }
    59  
    60  func TestSelectStmtCheckShard(t *testing.T) {
    61  	r, err := prepareShardKingshardRouter()
    62  	if err != nil {
    63  		t.Fatal(err)
    64  	}
    65  	tests := []struct {
    66  		sql     string
    67  		isShard bool
    68  	}{
    69  		{"select * from test_hash_0, test_hash_1 where test_hash_0.id = test_hash_1.id", true},
    70  		{"select * from test_a, test_b left join test_c on test_a.id=test_c.id where test_a.id in (1,2,3) or test_b.k = 0 order by test_a.id desc limit 10", false},
    71  	}
    72  	for _, test := range tests {
    73  		t.Run(test.sql, func(t *testing.T) {
    74  			//for i := 0; i < 1; i++ {
    75  			stmt, err := parser.ParseSQL(test.sql)
    76  			if err != nil {
    77  				t.Fatal(err)
    78  			}
    79  			selectStmt, ok := stmt.(*ast.SelectStmt)
    80  			if !ok {
    81  				t.Fatal("not a select stmt")
    82  			}
    83  			visitor := NewChecker("test", r)
    84  			selectStmt.Accept(visitor)
    85  			if visitor.IsShard() != test.isShard {
    86  				t.Errorf("isShard not equal, expect: %v, actual: %v", test.isShard, visitor.IsShard())
    87  			}
    88  		})
    89  	}
    90  }
    91  
    92  // TODO: no router, panic, change to table test function
    93  func _TestGroupByRewriting(t *testing.T) {
    94  	tests := []struct {
    95  		sql        string
    96  		rewrite    string
    97  		groupByCol []int
    98  		count      int
    99  	}{
   100  		{"select * from tbl1 group by a, b", "SELECT *,a,b FROM tbl1",
   101  			[]int{1, 2}, 2},
   102  	}
   103  	for _, test := range tests {
   104  		t.Run(test.sql, func(t *testing.T) {
   105  			stmt, err := parser.ParseSQL(test.sql)
   106  			if err != nil {
   107  				t.Fatal(err)
   108  			}
   109  			selectStmt, ok := stmt.(*ast.SelectStmt)
   110  			if !ok {
   111  				t.Fatal("not a select stmt")
   112  			}
   113  
   114  			info := NewSelectPlan("test", test.sql, nil)
   115  			if err := HandleSelectStmt(info, selectStmt); err != nil {
   116  				t.Fatal(err)
   117  			}
   118  			s := &strings.Builder{}
   119  			selectStmt.Restore(format.NewRestoreCtx(0, s))
   120  			rewriteSQL := s.String()
   121  			if rewriteSQL != test.rewrite {
   122  				t.Errorf("rewrite sql not equal, expect: %v, actual: %v", test.rewrite, rewriteSQL)
   123  			}
   124  			if len(info.GetGroupByColumnInfo()) != test.count {
   125  				t.Errorf("rewrite sql not equal, expect: %v, actual: %v", test.rewrite, rewriteSQL)
   126  			}
   127  			for i, columnsIndex := range info.GetGroupByColumnInfo() {
   128  				if test.groupByCol[i] != columnsIndex {
   129  					t.Errorf("groupByColumnStart not equal, expect: %v, actual: %v", test.groupByCol[i], columnsIndex)
   130  				}
   131  			}
   132  		})
   133  	}
   134  }