github.com/dolthub/go-mysql-server@v0.18.0/optgen/cmd/support/agg_gen_test.go (about)

     1  package support
     2  
     3  import (
     4  	"bytes"
     5  	"fmt"
     6  	"strings"
     7  	"testing"
     8  )
     9  
    10  func TestAggGen(t *testing.T) {
    11  	test := struct {
    12  		defines  AggDefs
    13  		expected string
    14  	}{
    15  		defines: AggDefs{
    16  			[]AggDef{
    17  				{
    18  					Name:    "Test",
    19  					Desc:    "Test description",
    20  					RetType: "sql.Float64",
    21  				},
    22  			},
    23  		},
    24  		expected: `
    25           import (
    26              "fmt"
    27              "github.com/dolthub/go-mysql-server/sql/types"
    28              "github.com/dolthub/go-mysql-server/sql"
    29              "github.com/dolthub/go-mysql-server/sql/expression"
    30              "github.com/dolthub/go-mysql-server/sql/transform"
    31          )
    32  
    33          type Test struct{
    34              unaryAggBase
    35          }
    36  
    37          var _ sql.FunctionExpression = (*Test)(nil)
    38          var _ sql.Aggregation = (*Test)(nil)
    39          var _ sql.WindowAdaptableExpression = (*Test)(nil)
    40  
    41          func NewTest(e sql.Expression) *Test {
    42              return &Test{
    43                  unaryAggBase{
    44                      UnaryExpression: expression.UnaryExpression{Child: e},
    45                      functionName: "Test",
    46                      description: "Test description",
    47                  },
    48              }
    49          }
    50  
    51          func (a *Test) Type() sql.Type {
    52              return sql.Float64
    53          }
    54  
    55          func (a *Test) IsNullable() bool {
    56              return false
    57          }
    58  
    59          func (a *Test) String() string {
    60            if a.window != nil {
    61              pr := sql.NewTreePrinter()
    62              _ = pr.WriteNode("TEST")
    63          	    children := []string{a.window.String(), a.Child.String()}
    64              pr.WriteChildren(children...)
    65              return pr.String()
    66            }
    67            return fmt.Sprintf("TEST(%s)", a.Child)
    68          }
    69  
    70          func (a *Test) DebugString() string {
    71            if a.window != nil {
    72              pr := sql.NewTreePrinter()
    73              _ = pr.WriteNode("TEST")
    74          	    children := []string{sql.DebugString(a.window), sql.DebugString(a.Child)}
    75              pr.WriteChildren(children...)
    76              return pr.String()
    77            }
    78            return fmt.Sprintf("TEST(%s)", sql.DebugString(a.Child))
    79          }
    80  
    81          func (a *Test) WithWindow(window *sql.WindowDefinition) sql.WindowAdaptableExpression {
    82              res := a.unaryAggBase.WithWindow(window)
    83              return &Test{unaryAggBase: *res.(*unaryAggBase)}
    84          }
    85  
    86          func (a *Test) WithChildren(children ...sql.Expression) (sql.Expression, error) {
    87              res, err := a.unaryAggBase.WithChildren(children...)
    88              return &Test{unaryAggBase: *res.(*unaryAggBase)}, err
    89          }
    90  
    91          func (a *Test) WithId(id sql.ColumnId) sql.IdExpression {
    92              res := a.unaryAggBase.WithId(id)
    93              return &Test{unaryAggBase: *res.(*unaryAggBase)}
    94          }
    95  
    96          func (a *Test) NewBuffer() (sql.AggregationBuffer, error) {
    97              child, err := transform.Clone(a.Child)
    98              if err != nil {
    99                  return nil, err
   100              }
   101              return NewTestBuffer(child), nil
   102          }
   103  
   104          func (a *Test) NewWindowFunction() (sql.WindowFunction, error) {
   105              child, err := transform.Clone(a.Child)
   106              if err != nil {
   107                  return nil, err
   108              }
   109              return NewTestAgg(child).WithWindow(a.Window())
   110          }
   111  		`,
   112  	}
   113  
   114  	var gen AggGen
   115  	var buf bytes.Buffer
   116  	gen.Generate(test.defines, &buf)
   117  
   118  	if testing.Verbose() {
   119  		fmt.Printf("%+v\n=>\n\n%s\n", test.defines, buf.String())
   120  	}
   121  
   122  	if !strings.Contains(removeWhitespace(buf.String()), removeWhitespace(test.expected)) {
   123  		t.Fatalf("\nexpected:\n%s\nactual:\n%s", test.expected, buf.String())
   124  	}
   125  }
   126  
   127  func removeWhitespace(s string) string {
   128  	return strings.Trim(strings.Replace(strings.Replace(s, " ", "", -1), "\t", "", -1), " \t\r\n")
   129  }