vitess.io/vitess@v0.16.2/go/vt/sqlparser/ast_copy_on_rewrite_test.go (about) 1 /* 2 Copyright 2023 The Vitess Authors. 3 4 Licensed under the Apache License, Version 2.0 (the "License"); 5 you may not use this file except in compliance with the License. 6 You may obtain a copy of the License at 7 8 http://www.apache.org/licenses/LICENSE-2.0 9 10 Unless required by applicable law or agreed to in writing, software 11 distributed under the License is distributed on an "AS IS" BASIS, 12 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 See the License for the specific language governing permissions and 14 limitations under the License. 15 */ 16 17 package sqlparser 18 19 import ( 20 "testing" 21 22 "github.com/stretchr/testify/assert" 23 "github.com/stretchr/testify/require" 24 ) 25 26 func TestCopyOnRewrite(t *testing.T) { 27 // rewrite an expression without changing the original 28 expr, err := ParseExpr("a = b") 29 require.NoError(t, err) 30 out := CopyOnRewrite(expr, nil, func(cursor *CopyOnWriteCursor) { 31 col, ok := cursor.Node().(*ColName) 32 if !ok { 33 return 34 } 35 if col.Name.EqualString("a") { 36 cursor.Replace(NewIntLiteral("1")) 37 } 38 }, nil) 39 40 assert.Equal(t, "a = b", String(expr)) 41 assert.Equal(t, "1 = b", String(out)) 42 } 43 44 func TestCopyOnRewriteDeeper(t *testing.T) { 45 // rewrite an expression without changing the original. the changed happens deep in the syntax tree, 46 // here we are testing that all ancestors up to the root are cloned correctly 47 expr, err := ParseExpr("a + b * c = 12") 48 require.NoError(t, err) 49 var path []string 50 out := CopyOnRewrite(expr, nil, func(cursor *CopyOnWriteCursor) { 51 col, ok := cursor.Node().(*ColName) 52 if !ok { 53 return 54 } 55 if col.Name.EqualString("c") { 56 cursor.Replace(NewIntLiteral("1")) 57 } 58 }, func(before, _ SQLNode) { 59 path = append(path, String(before)) 60 }) 61 62 assert.Equal(t, "a + b * c = 12", String(expr)) 63 assert.Equal(t, "a + b * 1 = 12", String(out)) 64 65 expected := []string{ // this are all the nodes that we need to clone when changing the `c` node 66 "c", 67 "b * c", 68 "a + b * c", 69 "a + b * c = 12", 70 } 71 assert.Equal(t, expected, path) 72 } 73 74 func TestDontCopyWithoutRewrite(t *testing.T) { 75 // when no rewriting happens, we want the original back 76 expr, err := ParseExpr("a = b") 77 require.NoError(t, err) 78 out := CopyOnRewrite(expr, nil, func(cursor *CopyOnWriteCursor) {}, nil) 79 80 assert.Same(t, expr, out) 81 } 82 83 func TestStopTreeWalk(t *testing.T) { 84 // stop walking down part of the AST 85 original := "a = b + c" 86 expr, err := ParseExpr(original) 87 require.NoError(t, err) 88 out := CopyOnRewrite(expr, func(node, parent SQLNode) bool { 89 _, ok := node.(*BinaryExpr) 90 return !ok 91 }, func(cursor *CopyOnWriteCursor) { 92 col, ok := cursor.Node().(*ColName) 93 if !ok { 94 return 95 } 96 97 cursor.Replace(NewStrLiteral(col.Name.String())) 98 }, nil) 99 100 assert.Equal(t, original, String(expr)) 101 assert.Equal(t, "'a' = b + c", String(out)) // b + c are unchanged since they are under the + (*BinaryExpr) 102 } 103 104 func TestStopTreeWalkButStillVisit(t *testing.T) { 105 // here we are asserting that even when we stop at the binary expression, we still visit it in the post visitor 106 original := "1337 = b + c" 107 expr, err := ParseExpr(original) 108 require.NoError(t, err) 109 out := CopyOnRewrite(expr, func(node, parent SQLNode) bool { 110 _, ok := node.(*BinaryExpr) 111 return !ok 112 }, func(cursor *CopyOnWriteCursor) { 113 switch cursor.Node().(type) { 114 case *BinaryExpr: 115 cursor.Replace(NewStrLiteral("johnny was here")) 116 case *ColName: 117 t.Errorf("should not visit ColName in the post") 118 } 119 }, nil) 120 121 assert.Equal(t, original, String(expr)) 122 assert.Equal(t, "1337 = 'johnny was here'", String(out)) // b + c are replaced 123 }