github.com/dolthub/go-mysql-server@v0.18.0/optgen/cmd/support/agg_gen_test.go (about) 1 package support 2 3 import ( 4 "bytes" 5 "fmt" 6 "strings" 7 "testing" 8 ) 9 10 func TestAggGen(t *testing.T) { 11 test := struct { 12 defines AggDefs 13 expected string 14 }{ 15 defines: AggDefs{ 16 []AggDef{ 17 { 18 Name: "Test", 19 Desc: "Test description", 20 RetType: "sql.Float64", 21 }, 22 }, 23 }, 24 expected: ` 25 import ( 26 "fmt" 27 "github.com/dolthub/go-mysql-server/sql/types" 28 "github.com/dolthub/go-mysql-server/sql" 29 "github.com/dolthub/go-mysql-server/sql/expression" 30 "github.com/dolthub/go-mysql-server/sql/transform" 31 ) 32 33 type Test struct{ 34 unaryAggBase 35 } 36 37 var _ sql.FunctionExpression = (*Test)(nil) 38 var _ sql.Aggregation = (*Test)(nil) 39 var _ sql.WindowAdaptableExpression = (*Test)(nil) 40 41 func NewTest(e sql.Expression) *Test { 42 return &Test{ 43 unaryAggBase{ 44 UnaryExpression: expression.UnaryExpression{Child: e}, 45 functionName: "Test", 46 description: "Test description", 47 }, 48 } 49 } 50 51 func (a *Test) Type() sql.Type { 52 return sql.Float64 53 } 54 55 func (a *Test) IsNullable() bool { 56 return false 57 } 58 59 func (a *Test) String() string { 60 if a.window != nil { 61 pr := sql.NewTreePrinter() 62 _ = pr.WriteNode("TEST") 63 children := []string{a.window.String(), a.Child.String()} 64 pr.WriteChildren(children...) 65 return pr.String() 66 } 67 return fmt.Sprintf("TEST(%s)", a.Child) 68 } 69 70 func (a *Test) DebugString() string { 71 if a.window != nil { 72 pr := sql.NewTreePrinter() 73 _ = pr.WriteNode("TEST") 74 children := []string{sql.DebugString(a.window), sql.DebugString(a.Child)} 75 pr.WriteChildren(children...) 76 return pr.String() 77 } 78 return fmt.Sprintf("TEST(%s)", sql.DebugString(a.Child)) 79 } 80 81 func (a *Test) WithWindow(window *sql.WindowDefinition) sql.WindowAdaptableExpression { 82 res := a.unaryAggBase.WithWindow(window) 83 return &Test{unaryAggBase: *res.(*unaryAggBase)} 84 } 85 86 func (a *Test) WithChildren(children ...sql.Expression) (sql.Expression, error) { 87 res, err := a.unaryAggBase.WithChildren(children...) 88 return &Test{unaryAggBase: *res.(*unaryAggBase)}, err 89 } 90 91 func (a *Test) WithId(id sql.ColumnId) sql.IdExpression { 92 res := a.unaryAggBase.WithId(id) 93 return &Test{unaryAggBase: *res.(*unaryAggBase)} 94 } 95 96 func (a *Test) NewBuffer() (sql.AggregationBuffer, error) { 97 child, err := transform.Clone(a.Child) 98 if err != nil { 99 return nil, err 100 } 101 return NewTestBuffer(child), nil 102 } 103 104 func (a *Test) NewWindowFunction() (sql.WindowFunction, error) { 105 child, err := transform.Clone(a.Child) 106 if err != nil { 107 return nil, err 108 } 109 return NewTestAgg(child).WithWindow(a.Window()) 110 } 111 `, 112 } 113 114 var gen AggGen 115 var buf bytes.Buffer 116 gen.Generate(test.defines, &buf) 117 118 if testing.Verbose() { 119 fmt.Printf("%+v\n=>\n\n%s\n", test.defines, buf.String()) 120 } 121 122 if !strings.Contains(removeWhitespace(buf.String()), removeWhitespace(test.expected)) { 123 t.Fatalf("\nexpected:\n%s\nactual:\n%s", test.expected, buf.String()) 124 } 125 } 126 127 func removeWhitespace(s string) string { 128 return strings.Trim(strings.Replace(strings.Replace(s, " ", "", -1), "\t", "", -1), " \t\r\n") 129 }