github.com/dolthub/go-mysql-server@v0.18.0/sql/expression/walk_test.go (about) 1 // Copyright 2020-2021 Dolthub, Inc. 2 // 3 // Licensed under the Apache License, Version 2.0 (the "License"); 4 // you may not use this file except in compliance with the License. 5 // You may obtain a copy of the License at 6 // 7 // http://www.apache.org/licenses/LICENSE-2.0 8 // 9 // Unless required by applicable law or agreed to in writing, software 10 // distributed under the License is distributed on an "AS IS" BASIS, 11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 // See the License for the specific language governing permissions and 13 // limitations under the License. 14 15 package expression 16 17 import ( 18 "testing" 19 20 "github.com/stretchr/testify/require" 21 22 "github.com/dolthub/go-mysql-server/sql" 23 "github.com/dolthub/go-mysql-server/sql/types" 24 ) 25 26 func TestWalk(t *testing.T) { 27 lit1 := NewLiteral(1, types.Int64) 28 lit2 := NewLiteral(2, types.Int64) 29 col := NewUnresolvedColumn("foo") 30 fn := NewUnresolvedFunction( 31 "bar", 32 false, 33 nil, 34 lit1, 35 lit2, 36 ) 37 and := NewAnd(col, fn) 38 e := NewNot(and) 39 40 var f visitor 41 var visited []sql.Expression 42 f = func(node sql.Expression) sql.Visitor { 43 visited = append(visited, node) 44 return f 45 } 46 47 sql.Walk(f, e) 48 49 require.Equal(t, 50 []sql.Expression{e, and, col, fn, lit1, lit2}, 51 visited, 52 ) 53 54 visited = nil 55 f = func(node sql.Expression) sql.Visitor { 56 visited = append(visited, node) 57 if _, ok := node.(*UnresolvedFunction); ok { 58 return nil 59 } 60 return f 61 } 62 63 sql.Walk(f, e) 64 65 require.Equal(t, 66 []sql.Expression{e, and, col, fn}, 67 visited, 68 ) 69 } 70 71 type visitor func(sql.Expression) sql.Visitor 72 73 func (f visitor) Visit(n sql.Expression) sql.Visitor { 74 return f(n) 75 } 76 77 func TestInspect(t *testing.T) { 78 lit1 := NewLiteral(1, types.Int64) 79 lit2 := NewLiteral(2, types.Int64) 80 col := NewUnresolvedColumn("foo") 81 fn := NewUnresolvedFunction( 82 "bar", 83 false, 84 nil, 85 lit1, 86 lit2, 87 ) 88 and := NewAnd(col, fn) 89 e := NewNot(and) 90 91 var f func(sql.Expression) bool 92 var visited []sql.Expression 93 f = func(node sql.Expression) bool { 94 visited = append(visited, node) 95 return true 96 } 97 98 sql.Inspect(e, f) 99 100 require.Equal(t, 101 []sql.Expression{e, and, col, fn, lit1, lit2}, 102 visited, 103 ) 104 105 visited = nil 106 f = func(node sql.Expression) bool { 107 visited = append(visited, node) 108 if _, ok := node.(*UnresolvedFunction); ok { 109 return false 110 } 111 return true 112 } 113 114 sql.Inspect(e, f) 115 116 require.Equal(t, 117 []sql.Expression{e, and, col, fn}, 118 visited, 119 ) 120 }