github.com/pingcap/tidb/parser@v0.0.0-20231013125129-93a834a6bf8d/ast/util_test.go (about)

     1  // Copyright 2017 PingCAP, Inc.
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //     http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // See the License for the specific language governing permissions and
    12  // limitations under the License.
    13  
    14  package ast_test
    15  
    16  import (
    17  	"fmt"
    18  	"strings"
    19  	"testing"
    20  
    21  	"github.com/pingcap/tidb/parser"
    22  	. "github.com/pingcap/tidb/parser/ast"
    23  	. "github.com/pingcap/tidb/parser/format"
    24  	"github.com/pingcap/tidb/parser/mysql"
    25  	"github.com/pingcap/tidb/parser/test_driver"
    26  	"github.com/stretchr/testify/require"
    27  )
    28  
    29  func TestCacheable(t *testing.T) {
    30  	// test non-SelectStmt
    31  	var stmt Node = &DeleteStmt{}
    32  	require.False(t, IsReadOnly(stmt))
    33  
    34  	stmt = &InsertStmt{}
    35  	require.False(t, IsReadOnly(stmt))
    36  
    37  	stmt = &UpdateStmt{}
    38  	require.False(t, IsReadOnly(stmt))
    39  
    40  	stmt = &ExplainStmt{}
    41  	require.True(t, IsReadOnly(stmt))
    42  
    43  	stmt = &ExplainStmt{}
    44  	require.True(t, IsReadOnly(stmt))
    45  
    46  	stmt = &DoStmt{}
    47  	require.True(t, IsReadOnly(stmt))
    48  
    49  	stmt = &ExplainStmt{
    50  		Stmt: &InsertStmt{},
    51  	}
    52  	require.True(t, IsReadOnly(stmt))
    53  
    54  	stmt = &ExplainStmt{
    55  		Analyze: true,
    56  		Stmt:    &InsertStmt{},
    57  	}
    58  	require.False(t, IsReadOnly(stmt))
    59  
    60  	stmt = &ExplainStmt{
    61  		Stmt: &SelectStmt{},
    62  	}
    63  	require.True(t, IsReadOnly(stmt))
    64  
    65  	stmt = &ExplainStmt{
    66  		Analyze: true,
    67  		Stmt:    &SelectStmt{},
    68  	}
    69  	require.True(t, IsReadOnly(stmt))
    70  
    71  	stmt = &ShowStmt{}
    72  	require.True(t, IsReadOnly(stmt))
    73  
    74  	stmt = &ShowStmt{}
    75  	require.True(t, IsReadOnly(stmt))
    76  }
    77  
    78  func TestUnionReadOnly(t *testing.T) {
    79  	selectReadOnly := &SelectStmt{}
    80  	selectForUpdate := &SelectStmt{
    81  		LockInfo: &SelectLockInfo{LockType: SelectLockForUpdate},
    82  	}
    83  	selectForUpdateNoWait := &SelectStmt{
    84  		LockInfo: &SelectLockInfo{LockType: SelectLockForUpdateNoWait},
    85  	}
    86  
    87  	setOprStmt := &SetOprStmt{
    88  		SelectList: &SetOprSelectList{
    89  			Selects: []Node{selectReadOnly, selectReadOnly},
    90  		},
    91  	}
    92  	require.True(t, IsReadOnly(setOprStmt))
    93  
    94  	setOprStmt.SelectList.Selects = []Node{selectReadOnly, selectReadOnly, selectReadOnly}
    95  	require.True(t, IsReadOnly(setOprStmt))
    96  
    97  	setOprStmt.SelectList.Selects = []Node{selectReadOnly, selectForUpdate}
    98  	require.False(t, IsReadOnly(setOprStmt))
    99  
   100  	setOprStmt.SelectList.Selects = []Node{selectReadOnly, selectForUpdateNoWait}
   101  	require.False(t, IsReadOnly(setOprStmt))
   102  
   103  	setOprStmt.SelectList.Selects = []Node{selectForUpdate, selectForUpdateNoWait}
   104  	require.False(t, IsReadOnly(setOprStmt))
   105  
   106  	setOprStmt.SelectList.Selects = []Node{selectReadOnly, selectForUpdate, selectForUpdateNoWait}
   107  	require.False(t, IsReadOnly(setOprStmt))
   108  }
   109  
   110  // CleanNodeText set the text of node and all child node empty.
   111  // For test only.
   112  func CleanNodeText(node Node) {
   113  	var cleaner nodeTextCleaner
   114  	node.Accept(&cleaner)
   115  }
   116  
   117  // nodeTextCleaner clean the text of a node and it's child node.
   118  // For test only.
   119  type nodeTextCleaner struct {
   120  }
   121  
   122  // Enter implements Visitor interface.
   123  func (checker *nodeTextCleaner) Enter(in Node) (out Node, skipChildren bool) {
   124  	in.SetText(nil, "")
   125  	in.SetOriginTextPosition(0)
   126  	if v, ok := in.(ValueExpr); ok && v != nil {
   127  		tpFlag := v.GetType().GetFlag()
   128  		if tpFlag&mysql.UnderScoreCharsetFlag != 0 {
   129  			// ignore underscore charset flag to let `'abc' = _utf8'abc'` pass
   130  			tpFlag ^= mysql.UnderScoreCharsetFlag
   131  			v.GetType().SetFlag(tpFlag)
   132  		}
   133  	}
   134  
   135  	switch node := in.(type) {
   136  	case *Constraint:
   137  		if node.Option != nil {
   138  			if node.Option.KeyBlockSize == 0x0 && node.Option.Tp == 0 && node.Option.Comment == "" {
   139  				node.Option = nil
   140  			}
   141  		}
   142  	case *FuncCallExpr:
   143  		node.FnName.O = strings.ToLower(node.FnName.O)
   144  		switch node.FnName.L {
   145  		case "convert":
   146  			node.Args[1].(*test_driver.ValueExpr).Datum.SetBytes(nil)
   147  		}
   148  	case *AggregateFuncExpr:
   149  		node.F = strings.ToLower(node.F)
   150  	case *FieldList:
   151  		for _, f := range node.Fields {
   152  			f.Offset = 0
   153  		}
   154  	case *AlterTableSpec:
   155  		for _, opt := range node.Options {
   156  			opt.StrValue = strings.ToLower(opt.StrValue)
   157  		}
   158  	case *Join:
   159  		node.ExplicitParens = false
   160  	case *ColumnDef:
   161  		node.Tp.CleanElemIsBinaryLit()
   162  	}
   163  	return in, false
   164  }
   165  
   166  // Leave implements Visitor interface.
   167  func (checker *nodeTextCleaner) Leave(in Node) (out Node, ok bool) {
   168  	return in, true
   169  }
   170  
   171  type NodeRestoreTestCase struct {
   172  	sourceSQL string
   173  	expectSQL string
   174  }
   175  
   176  func runNodeRestoreTest(t *testing.T, nodeTestCases []NodeRestoreTestCase, template string, extractNodeFunc func(node Node) Node) {
   177  	runNodeRestoreTestWithFlags(t, nodeTestCases, template, extractNodeFunc, DefaultRestoreFlags)
   178  }
   179  
   180  func runNodeRestoreTestWithFlags(t *testing.T, nodeTestCases []NodeRestoreTestCase, template string, extractNodeFunc func(node Node) Node, flags RestoreFlags) {
   181  	p := parser.New()
   182  	p.EnableWindowFunc(true)
   183  	for _, testCase := range nodeTestCases {
   184  		sourceSQL := fmt.Sprintf(template, testCase.sourceSQL)
   185  		expectSQL := fmt.Sprintf(template, testCase.expectSQL)
   186  		stmt, err := p.ParseOneStmt(sourceSQL, "", "")
   187  		comment := fmt.Sprintf("source %#v", testCase)
   188  		require.NoError(t, err, comment)
   189  		var sb strings.Builder
   190  		err = extractNodeFunc(stmt).Restore(NewRestoreCtx(flags, &sb))
   191  		require.NoError(t, err, comment)
   192  		restoreSql := fmt.Sprintf(template, sb.String())
   193  		comment = fmt.Sprintf("source %#v; restore %v", testCase, restoreSql)
   194  		require.Equal(t, expectSQL, restoreSql, comment)
   195  		stmt2, err := p.ParseOneStmt(restoreSql, "", "")
   196  		require.NoError(t, err, comment)
   197  		CleanNodeText(stmt)
   198  		CleanNodeText(stmt2)
   199  		require.Equal(t, stmt, stmt2, comment)
   200  	}
   201  }
   202  
   203  // runNodeRestoreTestWithFlagsStmtChange likes runNodeRestoreTestWithFlags but not check if the ASTs are same.
   204  // Sometimes the AST are different and it's expected.
   205  func runNodeRestoreTestWithFlagsStmtChange(t *testing.T, nodeTestCases []NodeRestoreTestCase, template string, extractNodeFunc func(node Node) Node, flags RestoreFlags) {
   206  	p := parser.New()
   207  	p.EnableWindowFunc(true)
   208  	for _, testCase := range nodeTestCases {
   209  		sourceSQL := fmt.Sprintf(template, testCase.sourceSQL)
   210  		expectSQL := fmt.Sprintf(template, testCase.expectSQL)
   211  		stmt, err := p.ParseOneStmt(sourceSQL, "", "")
   212  		comment := fmt.Sprintf("source %#v", testCase)
   213  		require.NoError(t, err, comment)
   214  		var sb strings.Builder
   215  		err = extractNodeFunc(stmt).Restore(NewRestoreCtx(flags, &sb))
   216  		require.NoError(t, err, comment)
   217  		restoreSql := fmt.Sprintf(template, sb.String())
   218  		comment = fmt.Sprintf("source %#v; restore %v", testCase, restoreSql)
   219  		require.Equal(t, expectSQL, restoreSql, comment)
   220  	}
   221  }