github.com/dolthub/go-mysql-server@v0.18.0/optgen/cmd/support/memo_gen_test.go (about) 1 package support 2 3 import ( 4 "bytes" 5 "fmt" 6 "strings" 7 "testing" 8 ) 9 10 func TestMemoGen(t *testing.T) { 11 test := struct { 12 expected string 13 }{ 14 expected: ` 15 import ( 16 "fmt" 17 "strings" 18 "github.com/dolthub/go-mysql-server/sql" 19 "github.com/dolthub/go-mysql-server/sql/plan" 20 ) 21 22 type HashJoin struct { 23 *JoinBase 24 InnerAttrs []sql.Expression 25 OuterAttrs []sql.Expression 26 } 27 28 var _ RelExpr = (*hashJoin)(nil) 29 var _ JoinRel = (*hashJoin)(nil) 30 31 func (r *hashJoin) String() string { 32 return FormatExpr(r) 33 } 34 35 func (r *hashJoin) JoinPrivate() *JoinBase { 36 return r.JoinBase 37 } 38 39 type TableScan struct { 40 *sourceBase 41 Table *plan.TableNode 42 } 43 44 var _ RelExpr = (*tableScan)(nil) 45 var _ SourceRel = (*tableScan)(nil) 46 47 func (r *tableScan) String() string { 48 return FormatExpr(r) 49 } 50 51 func (r *tableScan) Name() string { 52 return strings.ToLower(r.Table.Name()) 53 } 54 55 func (r *tableScan) TableId() sql.TableId { 56 return TableIdForSource(r.g.Id) 57 } 58 59 func (r *tableScan) TableIdNode() plan.TableIdNode { 60 return r.Table 61 } 62 63 func (r *tableScan) OutputCols() sql.Schema { 64 return r.Table.Schema() 65 } 66 67 func (r *tableScan) Children() []*ExprGroup { 68 return nil 69 } 70 71 func FormatExpr(r exprType) string { 72 switch r := r.(type) { 73 case *hashJoin: 74 return fmt.Sprintf("hashjoin %d %d", r.Left.Id, r.Right.Id) 75 case *tableScan: 76 return fmt.Sprintf("tablescan: %s", r.Name()) 77 default: 78 panic(fmt.Sprintf("unknown RelExpr type: %T", r)) 79 } 80 } 81 82 func buildRelExpr(b *ExecBuilder, r RelExpr, children ...sql.Node) (sql.Node, error) { 83 var result sql.Node 84 var err error 85 86 switch r := r.(type) { 87 case *hashJoin: 88 result, err = b.buildHashJoin(r, children...) 89 case *tableScan: 90 result, err = b.buildTableScan(r, children...) 91 default: 92 panic(fmt.Sprintf("unknown RelExpr type: %T", r)) 93 } 94 95 if err != nil { 96 return nil, err 97 } 98 99 if withDescribeStats, ok := result.(sql.WithDescribeStats); ok { 100 withDescribeStats.SetDescribeStats(*DescribeStats(r)) 101 } 102 result, err = r.Group().finalize(result) 103 if err != nil { 104 return nil, err 105 } 106 return result, nil 107 } 108 `, 109 } 110 111 defs := MemoExprs{ 112 Exprs: []ExprDef{ 113 { 114 Name: "hashJoin", 115 Join: true, 116 Attrs: [][2]string{ 117 {"innerAttrs", "[]sql.Expression"}, 118 {"outerAttrs", "[]sql.Expression"}, 119 }, 120 }, 121 { 122 Name: "tableScan", 123 SourceType: "*plan.TableNode", 124 }, 125 }, 126 } 127 gen := MemoGen{} 128 var buf bytes.Buffer 129 gen.Generate(defs, &buf) 130 131 if testing.Verbose() { 132 fmt.Printf("\n=>\n\n%s\n", buf.String()) 133 } 134 135 if !strings.Contains(removeWhitespace(buf.String()), removeWhitespace(test.expected)) { 136 t.Fatalf("\nexpected:\n%s\nactual:\n%s", test.expected, buf.String()) 137 } 138 }