github.com/dolthub/go-mysql-server@v0.18.0/sql/transform/walk_test.go (about)

     1  // Copyright 2020-2021 Dolthub, Inc.
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //     http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License.
    14  
    15  package transform
    16  
    17  import (
    18  	"testing"
    19  
    20  	"github.com/stretchr/testify/require"
    21  
    22  	"github.com/dolthub/go-mysql-server/sql"
    23  )
    24  
    25  func TestWalk(t *testing.T) {
    26  	a1 := a()
    27  	b1 := b()
    28  	c1 := c(a1, b1)
    29  	a2 := a(c1)
    30  	a3 := a(a2)
    31  
    32  	var f visitor
    33  	var visited []sql.Node
    34  	f = func(node sql.Node) Visitor {
    35  		visited = append(visited, node)
    36  		return f
    37  	}
    38  
    39  	Walk(f, a3)
    40  
    41  	require.Equal(t,
    42  		[]sql.Node{a3, a2, c1, a1, nil, b1, nil, nil, nil, nil},
    43  		visited,
    44  	)
    45  
    46  	visited = nil
    47  	f = func(node sql.Node) Visitor {
    48  		visited = append(visited, node)
    49  		if _, ok := node.(*nodeC); ok {
    50  			return nil
    51  		}
    52  		return f
    53  	}
    54  
    55  	Walk(f, a3)
    56  
    57  	require.Equal(t,
    58  		[]sql.Node{a3, a2, c1, nil, nil},
    59  		visited,
    60  	)
    61  }
    62  
    63  type visitor func(sql.Node) Visitor
    64  
    65  func (f visitor) Visit(n sql.Node) Visitor {
    66  	return f(n)
    67  }
    68  
    69  func TestInspect(t *testing.T) {
    70  	a1 := a()
    71  	b1 := b()
    72  	c1 := c(a1, b1)
    73  	a2 := a(c1)
    74  	a3 := a(a2)
    75  
    76  	var f func(sql.Node) bool
    77  	var visited []sql.Node
    78  	f = func(node sql.Node) bool {
    79  		visited = append(visited, node)
    80  		return true
    81  	}
    82  
    83  	Inspect(a3, f)
    84  
    85  	require.Equal(t,
    86  		[]sql.Node{a3, a2, c1, a1, nil, b1, nil, nil, nil, nil},
    87  		visited,
    88  	)
    89  
    90  	visited = nil
    91  	f = func(node sql.Node) bool {
    92  		visited = append(visited, node)
    93  		if _, ok := node.(*nodeC); ok {
    94  			return false
    95  		}
    96  		return true
    97  	}
    98  
    99  	Inspect(a3, f)
   100  
   101  	require.Equal(t,
   102  		[]sql.Node{a3, a2, c1, nil, nil},
   103  		visited,
   104  	)
   105  }