github.com/pingcap/tidb/parser@v0.0.0-20231013125129-93a834a6bf8d/ast/util_test.go (about) 1 // Copyright 2017 PingCAP, Inc. 2 // 3 // Licensed under the Apache License, Version 2.0 (the "License"); 4 // you may not use this file except in compliance with the License. 5 // You may obtain a copy of the License at 6 // 7 // http://www.apache.org/licenses/LICENSE-2.0 8 // 9 // Unless required by applicable law or agreed to in writing, software 10 // distributed under the License is distributed on an "AS IS" BASIS, 11 // See the License for the specific language governing permissions and 12 // limitations under the License. 13 14 package ast_test 15 16 import ( 17 "fmt" 18 "strings" 19 "testing" 20 21 "github.com/pingcap/tidb/parser" 22 . "github.com/pingcap/tidb/parser/ast" 23 . "github.com/pingcap/tidb/parser/format" 24 "github.com/pingcap/tidb/parser/mysql" 25 "github.com/pingcap/tidb/parser/test_driver" 26 "github.com/stretchr/testify/require" 27 ) 28 29 func TestCacheable(t *testing.T) { 30 // test non-SelectStmt 31 var stmt Node = &DeleteStmt{} 32 require.False(t, IsReadOnly(stmt)) 33 34 stmt = &InsertStmt{} 35 require.False(t, IsReadOnly(stmt)) 36 37 stmt = &UpdateStmt{} 38 require.False(t, IsReadOnly(stmt)) 39 40 stmt = &ExplainStmt{} 41 require.True(t, IsReadOnly(stmt)) 42 43 stmt = &ExplainStmt{} 44 require.True(t, IsReadOnly(stmt)) 45 46 stmt = &DoStmt{} 47 require.True(t, IsReadOnly(stmt)) 48 49 stmt = &ExplainStmt{ 50 Stmt: &InsertStmt{}, 51 } 52 require.True(t, IsReadOnly(stmt)) 53 54 stmt = &ExplainStmt{ 55 Analyze: true, 56 Stmt: &InsertStmt{}, 57 } 58 require.False(t, IsReadOnly(stmt)) 59 60 stmt = &ExplainStmt{ 61 Stmt: &SelectStmt{}, 62 } 63 require.True(t, IsReadOnly(stmt)) 64 65 stmt = &ExplainStmt{ 66 Analyze: true, 67 Stmt: &SelectStmt{}, 68 } 69 require.True(t, IsReadOnly(stmt)) 70 71 stmt = &ShowStmt{} 72 require.True(t, IsReadOnly(stmt)) 73 74 stmt = &ShowStmt{} 75 require.True(t, IsReadOnly(stmt)) 76 } 77 78 func TestUnionReadOnly(t *testing.T) { 79 selectReadOnly := &SelectStmt{} 80 selectForUpdate := &SelectStmt{ 81 LockInfo: &SelectLockInfo{LockType: SelectLockForUpdate}, 82 } 83 selectForUpdateNoWait := &SelectStmt{ 84 LockInfo: &SelectLockInfo{LockType: SelectLockForUpdateNoWait}, 85 } 86 87 setOprStmt := &SetOprStmt{ 88 SelectList: &SetOprSelectList{ 89 Selects: []Node{selectReadOnly, selectReadOnly}, 90 }, 91 } 92 require.True(t, IsReadOnly(setOprStmt)) 93 94 setOprStmt.SelectList.Selects = []Node{selectReadOnly, selectReadOnly, selectReadOnly} 95 require.True(t, IsReadOnly(setOprStmt)) 96 97 setOprStmt.SelectList.Selects = []Node{selectReadOnly, selectForUpdate} 98 require.False(t, IsReadOnly(setOprStmt)) 99 100 setOprStmt.SelectList.Selects = []Node{selectReadOnly, selectForUpdateNoWait} 101 require.False(t, IsReadOnly(setOprStmt)) 102 103 setOprStmt.SelectList.Selects = []Node{selectForUpdate, selectForUpdateNoWait} 104 require.False(t, IsReadOnly(setOprStmt)) 105 106 setOprStmt.SelectList.Selects = []Node{selectReadOnly, selectForUpdate, selectForUpdateNoWait} 107 require.False(t, IsReadOnly(setOprStmt)) 108 } 109 110 // CleanNodeText set the text of node and all child node empty. 111 // For test only. 112 func CleanNodeText(node Node) { 113 var cleaner nodeTextCleaner 114 node.Accept(&cleaner) 115 } 116 117 // nodeTextCleaner clean the text of a node and it's child node. 118 // For test only. 119 type nodeTextCleaner struct { 120 } 121 122 // Enter implements Visitor interface. 123 func (checker *nodeTextCleaner) Enter(in Node) (out Node, skipChildren bool) { 124 in.SetText(nil, "") 125 in.SetOriginTextPosition(0) 126 if v, ok := in.(ValueExpr); ok && v != nil { 127 tpFlag := v.GetType().GetFlag() 128 if tpFlag&mysql.UnderScoreCharsetFlag != 0 { 129 // ignore underscore charset flag to let `'abc' = _utf8'abc'` pass 130 tpFlag ^= mysql.UnderScoreCharsetFlag 131 v.GetType().SetFlag(tpFlag) 132 } 133 } 134 135 switch node := in.(type) { 136 case *Constraint: 137 if node.Option != nil { 138 if node.Option.KeyBlockSize == 0x0 && node.Option.Tp == 0 && node.Option.Comment == "" { 139 node.Option = nil 140 } 141 } 142 case *FuncCallExpr: 143 node.FnName.O = strings.ToLower(node.FnName.O) 144 switch node.FnName.L { 145 case "convert": 146 node.Args[1].(*test_driver.ValueExpr).Datum.SetBytes(nil) 147 } 148 case *AggregateFuncExpr: 149 node.F = strings.ToLower(node.F) 150 case *FieldList: 151 for _, f := range node.Fields { 152 f.Offset = 0 153 } 154 case *AlterTableSpec: 155 for _, opt := range node.Options { 156 opt.StrValue = strings.ToLower(opt.StrValue) 157 } 158 case *Join: 159 node.ExplicitParens = false 160 case *ColumnDef: 161 node.Tp.CleanElemIsBinaryLit() 162 } 163 return in, false 164 } 165 166 // Leave implements Visitor interface. 167 func (checker *nodeTextCleaner) Leave(in Node) (out Node, ok bool) { 168 return in, true 169 } 170 171 type NodeRestoreTestCase struct { 172 sourceSQL string 173 expectSQL string 174 } 175 176 func runNodeRestoreTest(t *testing.T, nodeTestCases []NodeRestoreTestCase, template string, extractNodeFunc func(node Node) Node) { 177 runNodeRestoreTestWithFlags(t, nodeTestCases, template, extractNodeFunc, DefaultRestoreFlags) 178 } 179 180 func runNodeRestoreTestWithFlags(t *testing.T, nodeTestCases []NodeRestoreTestCase, template string, extractNodeFunc func(node Node) Node, flags RestoreFlags) { 181 p := parser.New() 182 p.EnableWindowFunc(true) 183 for _, testCase := range nodeTestCases { 184 sourceSQL := fmt.Sprintf(template, testCase.sourceSQL) 185 expectSQL := fmt.Sprintf(template, testCase.expectSQL) 186 stmt, err := p.ParseOneStmt(sourceSQL, "", "") 187 comment := fmt.Sprintf("source %#v", testCase) 188 require.NoError(t, err, comment) 189 var sb strings.Builder 190 err = extractNodeFunc(stmt).Restore(NewRestoreCtx(flags, &sb)) 191 require.NoError(t, err, comment) 192 restoreSql := fmt.Sprintf(template, sb.String()) 193 comment = fmt.Sprintf("source %#v; restore %v", testCase, restoreSql) 194 require.Equal(t, expectSQL, restoreSql, comment) 195 stmt2, err := p.ParseOneStmt(restoreSql, "", "") 196 require.NoError(t, err, comment) 197 CleanNodeText(stmt) 198 CleanNodeText(stmt2) 199 require.Equal(t, stmt, stmt2, comment) 200 } 201 } 202 203 // runNodeRestoreTestWithFlagsStmtChange likes runNodeRestoreTestWithFlags but not check if the ASTs are same. 204 // Sometimes the AST are different and it's expected. 205 func runNodeRestoreTestWithFlagsStmtChange(t *testing.T, nodeTestCases []NodeRestoreTestCase, template string, extractNodeFunc func(node Node) Node, flags RestoreFlags) { 206 p := parser.New() 207 p.EnableWindowFunc(true) 208 for _, testCase := range nodeTestCases { 209 sourceSQL := fmt.Sprintf(template, testCase.sourceSQL) 210 expectSQL := fmt.Sprintf(template, testCase.expectSQL) 211 stmt, err := p.ParseOneStmt(sourceSQL, "", "") 212 comment := fmt.Sprintf("source %#v", testCase) 213 require.NoError(t, err, comment) 214 var sb strings.Builder 215 err = extractNodeFunc(stmt).Restore(NewRestoreCtx(flags, &sb)) 216 require.NoError(t, err, comment) 217 restoreSql := fmt.Sprintf(template, sb.String()) 218 comment = fmt.Sprintf("source %#v; restore %v", testCase, restoreSql) 219 require.Equal(t, expectSQL, restoreSql, comment) 220 } 221 }