github.com/dolthub/go-mysql-server@v0.18.0/sql/expression/walk_test.go (about)

     1  // Copyright 2020-2021 Dolthub, Inc.
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //     http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License.
    14  
    15  package expression
    16  
    17  import (
    18  	"testing"
    19  
    20  	"github.com/stretchr/testify/require"
    21  
    22  	"github.com/dolthub/go-mysql-server/sql"
    23  	"github.com/dolthub/go-mysql-server/sql/types"
    24  )
    25  
    26  func TestWalk(t *testing.T) {
    27  	lit1 := NewLiteral(1, types.Int64)
    28  	lit2 := NewLiteral(2, types.Int64)
    29  	col := NewUnresolvedColumn("foo")
    30  	fn := NewUnresolvedFunction(
    31  		"bar",
    32  		false,
    33  		nil,
    34  		lit1,
    35  		lit2,
    36  	)
    37  	and := NewAnd(col, fn)
    38  	e := NewNot(and)
    39  
    40  	var f visitor
    41  	var visited []sql.Expression
    42  	f = func(node sql.Expression) sql.Visitor {
    43  		visited = append(visited, node)
    44  		return f
    45  	}
    46  
    47  	sql.Walk(f, e)
    48  
    49  	require.Equal(t,
    50  		[]sql.Expression{e, and, col, fn, lit1, lit2},
    51  		visited,
    52  	)
    53  
    54  	visited = nil
    55  	f = func(node sql.Expression) sql.Visitor {
    56  		visited = append(visited, node)
    57  		if _, ok := node.(*UnresolvedFunction); ok {
    58  			return nil
    59  		}
    60  		return f
    61  	}
    62  
    63  	sql.Walk(f, e)
    64  
    65  	require.Equal(t,
    66  		[]sql.Expression{e, and, col, fn},
    67  		visited,
    68  	)
    69  }
    70  
    71  type visitor func(sql.Expression) sql.Visitor
    72  
    73  func (f visitor) Visit(n sql.Expression) sql.Visitor {
    74  	return f(n)
    75  }
    76  
    77  func TestInspect(t *testing.T) {
    78  	lit1 := NewLiteral(1, types.Int64)
    79  	lit2 := NewLiteral(2, types.Int64)
    80  	col := NewUnresolvedColumn("foo")
    81  	fn := NewUnresolvedFunction(
    82  		"bar",
    83  		false,
    84  		nil,
    85  		lit1,
    86  		lit2,
    87  	)
    88  	and := NewAnd(col, fn)
    89  	e := NewNot(and)
    90  
    91  	var f func(sql.Expression) bool
    92  	var visited []sql.Expression
    93  	f = func(node sql.Expression) bool {
    94  		visited = append(visited, node)
    95  		return true
    96  	}
    97  
    98  	sql.Inspect(e, f)
    99  
   100  	require.Equal(t,
   101  		[]sql.Expression{e, and, col, fn, lit1, lit2},
   102  		visited,
   103  	)
   104  
   105  	visited = nil
   106  	f = func(node sql.Expression) bool {
   107  		visited = append(visited, node)
   108  		if _, ok := node.(*UnresolvedFunction); ok {
   109  			return false
   110  		}
   111  		return true
   112  	}
   113  
   114  	sql.Inspect(e, f)
   115  
   116  	require.Equal(t,
   117  		[]sql.Expression{e, and, col, fn},
   118  		visited,
   119  	)
   120  }