vitess.io/vitess@v0.16.2/go/vt/sqlparser/ast_copy_on_rewrite_test.go (about)

     1  /*
     2  Copyright 2023 The Vitess Authors.
     3  
     4  Licensed under the Apache License, Version 2.0 (the "License");
     5  you may not use this file except in compliance with the License.
     6  You may obtain a copy of the License at
     7  
     8      http://www.apache.org/licenses/LICENSE-2.0
     9  
    10  Unless required by applicable law or agreed to in writing, software
    11  distributed under the License is distributed on an "AS IS" BASIS,
    12  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    13  See the License for the specific language governing permissions and
    14  limitations under the License.
    15  */
    16  
    17  package sqlparser
    18  
    19  import (
    20  	"testing"
    21  
    22  	"github.com/stretchr/testify/assert"
    23  	"github.com/stretchr/testify/require"
    24  )
    25  
    26  func TestCopyOnRewrite(t *testing.T) {
    27  	// rewrite an expression without changing the original
    28  	expr, err := ParseExpr("a = b")
    29  	require.NoError(t, err)
    30  	out := CopyOnRewrite(expr, nil, func(cursor *CopyOnWriteCursor) {
    31  		col, ok := cursor.Node().(*ColName)
    32  		if !ok {
    33  			return
    34  		}
    35  		if col.Name.EqualString("a") {
    36  			cursor.Replace(NewIntLiteral("1"))
    37  		}
    38  	}, nil)
    39  
    40  	assert.Equal(t, "a = b", String(expr))
    41  	assert.Equal(t, "1 = b", String(out))
    42  }
    43  
    44  func TestCopyOnRewriteDeeper(t *testing.T) {
    45  	// rewrite an expression without changing the original. the changed happens deep in the syntax tree,
    46  	// here we are testing that all ancestors up to the root are cloned correctly
    47  	expr, err := ParseExpr("a + b * c = 12")
    48  	require.NoError(t, err)
    49  	var path []string
    50  	out := CopyOnRewrite(expr, nil, func(cursor *CopyOnWriteCursor) {
    51  		col, ok := cursor.Node().(*ColName)
    52  		if !ok {
    53  			return
    54  		}
    55  		if col.Name.EqualString("c") {
    56  			cursor.Replace(NewIntLiteral("1"))
    57  		}
    58  	}, func(before, _ SQLNode) {
    59  		path = append(path, String(before))
    60  	})
    61  
    62  	assert.Equal(t, "a + b * c = 12", String(expr))
    63  	assert.Equal(t, "a + b * 1 = 12", String(out))
    64  
    65  	expected := []string{ // this are all the nodes that we need to clone when changing the `c` node
    66  		"c",
    67  		"b * c",
    68  		"a + b * c",
    69  		"a + b * c = 12",
    70  	}
    71  	assert.Equal(t, expected, path)
    72  }
    73  
    74  func TestDontCopyWithoutRewrite(t *testing.T) {
    75  	// when no rewriting happens, we want the original back
    76  	expr, err := ParseExpr("a = b")
    77  	require.NoError(t, err)
    78  	out := CopyOnRewrite(expr, nil, func(cursor *CopyOnWriteCursor) {}, nil)
    79  
    80  	assert.Same(t, expr, out)
    81  }
    82  
    83  func TestStopTreeWalk(t *testing.T) {
    84  	// stop walking down part of the AST
    85  	original := "a = b + c"
    86  	expr, err := ParseExpr(original)
    87  	require.NoError(t, err)
    88  	out := CopyOnRewrite(expr, func(node, parent SQLNode) bool {
    89  		_, ok := node.(*BinaryExpr)
    90  		return !ok
    91  	}, func(cursor *CopyOnWriteCursor) {
    92  		col, ok := cursor.Node().(*ColName)
    93  		if !ok {
    94  			return
    95  		}
    96  
    97  		cursor.Replace(NewStrLiteral(col.Name.String()))
    98  	}, nil)
    99  
   100  	assert.Equal(t, original, String(expr))
   101  	assert.Equal(t, "'a' = b + c", String(out)) // b + c are unchanged since they are under the + (*BinaryExpr)
   102  }
   103  
   104  func TestStopTreeWalkButStillVisit(t *testing.T) {
   105  	// here we are asserting that even when we stop at the binary expression, we still visit it in the post visitor
   106  	original := "1337 = b + c"
   107  	expr, err := ParseExpr(original)
   108  	require.NoError(t, err)
   109  	out := CopyOnRewrite(expr, func(node, parent SQLNode) bool {
   110  		_, ok := node.(*BinaryExpr)
   111  		return !ok
   112  	}, func(cursor *CopyOnWriteCursor) {
   113  		switch cursor.Node().(type) {
   114  		case *BinaryExpr:
   115  			cursor.Replace(NewStrLiteral("johnny was here"))
   116  		case *ColName:
   117  			t.Errorf("should not visit ColName in the post")
   118  		}
   119  	}, nil)
   120  
   121  	assert.Equal(t, original, String(expr))
   122  	assert.Equal(t, "1337 = 'johnny was here'", String(out)) // b + c are replaced
   123  }