github.com/dolthub/go-mysql-server@v0.18.0/optgen/cmd/support/memo_gen_test.go (about)

     1  package support
     2  
     3  import (
     4  	"bytes"
     5  	"fmt"
     6  	"strings"
     7  	"testing"
     8  )
     9  
    10  func TestMemoGen(t *testing.T) {
    11  	test := struct {
    12  		expected string
    13  	}{
    14  		expected: `
    15          import (
    16            "fmt"
    17            "strings"
    18            "github.com/dolthub/go-mysql-server/sql"
    19            "github.com/dolthub/go-mysql-server/sql/plan"
    20          )
    21          
    22          type HashJoin struct {
    23            *JoinBase
    24            InnerAttrs []sql.Expression
    25            OuterAttrs []sql.Expression
    26          }
    27          
    28          var _ RelExpr = (*hashJoin)(nil)
    29          var _ JoinRel = (*hashJoin)(nil)
    30          
    31          func (r *hashJoin) String() string {
    32            return FormatExpr(r)
    33          }
    34          
    35          func (r *hashJoin) JoinPrivate() *JoinBase {
    36            return r.JoinBase
    37          }
    38          
    39          type TableScan struct {
    40            *sourceBase
    41            Table *plan.TableNode
    42          }
    43          
    44          var _ RelExpr = (*tableScan)(nil)
    45          var _ SourceRel = (*tableScan)(nil)
    46          
    47          func (r *tableScan) String() string {
    48            return FormatExpr(r)
    49          }
    50          
    51          func (r *tableScan) Name() string {
    52            return strings.ToLower(r.Table.Name())
    53          }
    54          
    55          func (r *tableScan) TableId() sql.TableId {
    56            return TableIdForSource(r.g.Id)
    57          }
    58          
    59          func (r *tableScan) TableIdNode() plan.TableIdNode {
    60            return r.Table
    61          }
    62          
    63          func (r *tableScan) OutputCols() sql.Schema {
    64            return r.Table.Schema()
    65          }
    66          
    67          func (r *tableScan) Children() []*ExprGroup {
    68            return nil
    69          }
    70          
    71          func FormatExpr(r exprType) string {
    72            switch r := r.(type) {
    73            case *hashJoin:
    74              return fmt.Sprintf("hashjoin %d %d", r.Left.Id, r.Right.Id)
    75            case *tableScan:
    76              return fmt.Sprintf("tablescan: %s", r.Name())
    77            default:
    78              panic(fmt.Sprintf("unknown RelExpr type: %T", r))
    79            }
    80          }
    81          
    82          func buildRelExpr(b *ExecBuilder, r RelExpr, children ...sql.Node) (sql.Node, error) {
    83            var result sql.Node
    84            var err error
    85          
    86            switch r := r.(type) {
    87            case *hashJoin:
    88            result, err = b.buildHashJoin(r, children...)
    89            case *tableScan:
    90            result, err = b.buildTableScan(r, children...)
    91            default:
    92              panic(fmt.Sprintf("unknown RelExpr type: %T", r))
    93            }
    94          
    95            if err != nil {
    96              return nil, err
    97            }
    98          
    99            if withDescribeStats, ok := result.(sql.WithDescribeStats); ok {
   100              withDescribeStats.SetDescribeStats(*DescribeStats(r))
   101            }
   102            result, err = r.Group().finalize(result)
   103            if err != nil {
   104              return nil, err
   105            }
   106            return result, nil
   107          }
   108  `,
   109  	}
   110  
   111  	defs := MemoExprs{
   112  		Exprs: []ExprDef{
   113  			{
   114  				Name: "hashJoin",
   115  				Join: true,
   116  				Attrs: [][2]string{
   117  					{"innerAttrs", "[]sql.Expression"},
   118  					{"outerAttrs", "[]sql.Expression"},
   119  				},
   120  			},
   121  			{
   122  				Name:       "tableScan",
   123  				SourceType: "*plan.TableNode",
   124  			},
   125  		},
   126  	}
   127  	gen := MemoGen{}
   128  	var buf bytes.Buffer
   129  	gen.Generate(defs, &buf)
   130  
   131  	if testing.Verbose() {
   132  		fmt.Printf("\n=>\n\n%s\n", buf.String())
   133  	}
   134  
   135  	if !strings.Contains(removeWhitespace(buf.String()), removeWhitespace(test.expected)) {
   136  		t.Fatalf("\nexpected:\n%s\nactual:\n%s", test.expected, buf.String())
   137  	}
   138  }