github.com/cockroachdb/cockroach@v20.2.0-alpha.1+incompatible/pkg/sql/sem/tree/overload_test.go (about)

     1  // Copyright 2016 The Cockroach Authors.
     2  //
     3  // Use of this software is governed by the Business Source License
     4  // included in the file licenses/BSL.txt.
     5  //
     6  // As of the Change Date specified in that file, in accordance with
     7  // the Business Source License, use of this software will be governed
     8  // by the Apache License, Version 2.0, included in the file
     9  // licenses/APL.txt.
    10  
    11  package tree
    12  
    13  import (
    14  	"context"
    15  	"fmt"
    16  	"go/constant"
    17  	"go/token"
    18  	"strings"
    19  	"testing"
    20  
    21  	"github.com/cockroachdb/cockroach/pkg/sql/types"
    22  	"github.com/cockroachdb/cockroach/pkg/util/leaktest"
    23  )
    24  
    25  type variadicTestCase struct {
    26  	args    []*types.T
    27  	matches bool
    28  }
    29  
    30  type variadicTestData struct {
    31  	name  string
    32  	cases []variadicTestCase
    33  }
    34  
    35  func TestVariadicFunctions(t *testing.T) {
    36  	defer leaktest.AfterTest(t)()
    37  	testData := map[*VariadicType]variadicTestData{
    38  		{VarType: types.String}: {
    39  			"string...", []variadicTestCase{
    40  				{[]*types.T{types.String}, true},
    41  				{[]*types.T{types.String, types.String}, true},
    42  				{[]*types.T{types.String, types.Unknown}, true},
    43  				{[]*types.T{types.String, types.Unknown, types.String}, true},
    44  				{[]*types.T{types.Int}, false},
    45  			}},
    46  		{FixedTypes: []*types.T{types.Int}, VarType: types.String}: {
    47  			"int, string...", []variadicTestCase{
    48  				{[]*types.T{types.Int}, true},
    49  				{[]*types.T{types.Int, types.String}, true},
    50  				{[]*types.T{types.Int, types.String, types.String}, true},
    51  				{[]*types.T{types.Int, types.Unknown, types.String}, true},
    52  				{[]*types.T{types.String}, false},
    53  			}},
    54  		{FixedTypes: []*types.T{types.Int, types.Bool}, VarType: types.String}: {
    55  			"int, bool, string...", []variadicTestCase{
    56  				{[]*types.T{types.Int}, false},
    57  				{[]*types.T{types.Int, types.Bool}, true},
    58  				{[]*types.T{types.Int, types.Bool, types.String}, true},
    59  				{[]*types.T{types.Int, types.Unknown, types.String}, true},
    60  				{[]*types.T{types.Int, types.Bool, types.String, types.Bool}, false},
    61  				{[]*types.T{types.Int, types.String}, false},
    62  				{[]*types.T{types.Int, types.String, types.String}, false},
    63  				{[]*types.T{types.String}, false},
    64  			}},
    65  	}
    66  
    67  	for fn, data := range testData {
    68  		t.Run(fmt.Sprintf("%v", fn), func(t *testing.T) {
    69  			if data.name != fn.String() {
    70  				t.Fatalf("expected name %v, got %v", data.name, fn.String())
    71  			}
    72  		})
    73  
    74  		for _, v := range data.cases {
    75  			t.Run(fmt.Sprintf("%v/%v", fn, v), func(t *testing.T) {
    76  				if v.matches {
    77  					if !fn.MatchLen(len(v.args)) {
    78  						t.Fatalf("expected fn %v to matchLen %v", fn, v.args)
    79  					}
    80  
    81  					if !fn.Match(v.args) {
    82  						t.Fatalf("expected fn %v to match %v", fn, v.args)
    83  					}
    84  				} else if fn.MatchLen(len(v.args)) && fn.Match(v.args) {
    85  					t.Fatalf("expected fn %v to not match %v", fn, v.args)
    86  				}
    87  			})
    88  		}
    89  	}
    90  }
    91  
    92  type testOverload struct {
    93  	paramTypes ArgTypes
    94  	retType    *types.T
    95  	pref       bool
    96  }
    97  
    98  func (to *testOverload) params() TypeList {
    99  	return to.paramTypes
   100  }
   101  
   102  func (to *testOverload) returnType() ReturnTyper {
   103  	return FixedReturnType(to.retType)
   104  }
   105  
   106  func (to testOverload) preferred() bool {
   107  	return to.pref
   108  }
   109  
   110  func (to *testOverload) String() string {
   111  	typeNames := make([]string, len(to.paramTypes))
   112  	for i, param := range to.paramTypes {
   113  		typeNames[i] = param.Typ.String()
   114  	}
   115  	return fmt.Sprintf("func(%s) %s", strings.Join(typeNames, ","), to.retType)
   116  }
   117  
   118  func makeTestOverload(retType *types.T, params ...*types.T) overloadImpl {
   119  	t := make(ArgTypes, len(params))
   120  	for i := range params {
   121  		t[i].Typ = params[i]
   122  	}
   123  	return &testOverload{
   124  		paramTypes: t,
   125  		retType:    retType,
   126  	}
   127  }
   128  
   129  func TestTypeCheckOverloadedExprs(t *testing.T) {
   130  	defer leaktest.AfterTest(t)()
   131  	intConst := func(s string) Expr {
   132  		return NewNumVal(constant.MakeFromLiteral(s, token.INT, 0), s, false /* negative */)
   133  	}
   134  	decConst := func(s string) Expr {
   135  		return NewNumVal(constant.MakeFromLiteral(s, token.FLOAT, 0), s, false /* negative */)
   136  	}
   137  	strConst := func(s string) Expr {
   138  		return &StrVal{s: s}
   139  	}
   140  	plus := func(left, right Expr) Expr {
   141  		return &BinaryExpr{Operator: Plus, Left: left, Right: right}
   142  	}
   143  	placeholder := func(id int) *Placeholder {
   144  		return &Placeholder{Idx: PlaceholderIdx(id)}
   145  	}
   146  
   147  	unaryIntFn := makeTestOverload(types.Int, types.Int)
   148  	unaryIntFnPref := &testOverload{retType: types.Int, paramTypes: ArgTypes{}, pref: true}
   149  	unaryFloatFn := makeTestOverload(types.Float, types.Float)
   150  	unaryDecimalFn := makeTestOverload(types.Decimal, types.Decimal)
   151  	unaryStringFn := makeTestOverload(types.String, types.String)
   152  	unaryIntervalFn := makeTestOverload(types.Interval, types.Interval)
   153  	unaryTimestampFn := makeTestOverload(types.Timestamp, types.Timestamp)
   154  	binaryIntFn := makeTestOverload(types.Int, types.Int, types.Int)
   155  	binaryFloatFn := makeTestOverload(types.Float, types.Float, types.Float)
   156  	binaryDecimalFn := makeTestOverload(types.Decimal, types.Decimal, types.Decimal)
   157  	binaryStringFn := makeTestOverload(types.String, types.String, types.String)
   158  	binaryTimestampFn := makeTestOverload(types.Timestamp, types.Timestamp, types.Timestamp)
   159  	binaryStringFloatFn1 := makeTestOverload(types.Int, types.String, types.Float)
   160  	binaryStringFloatFn2 := makeTestOverload(types.Float, types.String, types.Float)
   161  	binaryIntDateFn := makeTestOverload(types.Date, types.Int, types.Date)
   162  	binaryArrayIntFn := makeTestOverload(types.Int, types.AnyArray, types.Int)
   163  
   164  	// Out-of-band values used below to distinguish error cases.
   165  	unsupported := &testOverload{}
   166  	ambiguous := &testOverload{}
   167  	shouldError := &testOverload{}
   168  
   169  	testData := []struct {
   170  		desired          *types.T
   171  		exprs            []Expr
   172  		overloads        []overloadImpl
   173  		expectedOverload overloadImpl
   174  		inBinOp          bool
   175  	}{
   176  		// Unary constants.
   177  		{nil, []Expr{intConst("1")}, []overloadImpl{unaryIntFn, unaryFloatFn}, unaryIntFn, false},
   178  		{nil, []Expr{decConst("1.0")}, []overloadImpl{unaryIntFn, unaryDecimalFn}, unaryDecimalFn, false},
   179  		{nil, []Expr{decConst("1.0")}, []overloadImpl{unaryIntFn, unaryFloatFn}, unsupported, false},
   180  		{nil, []Expr{intConst("1")}, []overloadImpl{unaryIntFn, binaryIntFn}, unaryIntFn, false},
   181  		{nil, []Expr{intConst("1")}, []overloadImpl{unaryFloatFn, unaryStringFn}, unaryFloatFn, false},
   182  		{nil, []Expr{intConst("1")}, []overloadImpl{unaryStringFn, binaryIntFn}, unsupported, false},
   183  		{nil, []Expr{strConst("PT12H2M")}, []overloadImpl{unaryIntervalFn}, unaryIntervalFn, false},
   184  		{nil, []Expr{strConst("PT12H2M")}, []overloadImpl{unaryIntervalFn, unaryStringFn}, unaryStringFn, false},
   185  		{nil, []Expr{strConst("PT12H2M")}, []overloadImpl{unaryIntervalFn, unaryTimestampFn}, unaryIntervalFn, false},
   186  		{nil, []Expr{}, []overloadImpl{unaryIntFn, unaryIntFnPref}, unaryIntFnPref, false},
   187  		{nil, []Expr{}, []overloadImpl{unaryIntFnPref, unaryIntFnPref}, ambiguous, false},
   188  		{nil, []Expr{strConst("PT12H2M")}, []overloadImpl{unaryIntervalFn, unaryIntFn}, unaryIntervalFn, false},
   189  		// Unary unresolved Placeholders.
   190  		{nil, []Expr{placeholder(0)}, []overloadImpl{unaryStringFn, unaryIntFn}, shouldError, false},
   191  		{nil, []Expr{placeholder(0)}, []overloadImpl{unaryStringFn, binaryIntFn}, unaryStringFn, false},
   192  		// Unary values (not constants).
   193  		{nil, []Expr{NewDInt(1)}, []overloadImpl{unaryIntFn, unaryFloatFn}, unaryIntFn, false},
   194  		{nil, []Expr{NewDFloat(1)}, []overloadImpl{unaryIntFn, unaryFloatFn}, unaryFloatFn, false},
   195  		{nil, []Expr{NewDInt(1)}, []overloadImpl{unaryIntFn, binaryIntFn}, unaryIntFn, false},
   196  		{nil, []Expr{NewDInt(1)}, []overloadImpl{unaryFloatFn, unaryStringFn}, unsupported, false},
   197  		{nil, []Expr{NewDString("a")}, []overloadImpl{unaryIntFn, unaryFloatFn}, unsupported, false},
   198  		{nil, []Expr{NewDString("a")}, []overloadImpl{unaryIntFn, unaryStringFn}, unaryStringFn, false},
   199  		// Binary constants.
   200  		{nil, []Expr{intConst("1"), intConst("1")}, []overloadImpl{binaryIntFn, binaryFloatFn, unaryIntFn}, binaryIntFn, false},
   201  		{nil, []Expr{intConst("1"), decConst("1.0")}, []overloadImpl{binaryIntFn, binaryDecimalFn, unaryDecimalFn}, binaryDecimalFn, false},
   202  		{nil, []Expr{strConst("2010-09-28"), strConst("2010-09-29")}, []overloadImpl{binaryTimestampFn}, binaryTimestampFn, false},
   203  		{nil, []Expr{strConst("2010-09-28"), strConst("2010-09-29")}, []overloadImpl{binaryTimestampFn, binaryStringFn}, binaryStringFn, false},
   204  		{nil, []Expr{strConst("2010-09-28"), strConst("2010-09-29")}, []overloadImpl{binaryTimestampFn, binaryIntFn}, binaryTimestampFn, false},
   205  		// Binary unresolved Placeholders.
   206  		{nil, []Expr{placeholder(0), placeholder(1)}, []overloadImpl{binaryIntFn, binaryFloatFn}, shouldError, false},
   207  		{nil, []Expr{placeholder(0), placeholder(1)}, []overloadImpl{binaryIntFn, unaryStringFn}, binaryIntFn, false},
   208  		{nil, []Expr{placeholder(0), NewDString("a")}, []overloadImpl{binaryIntFn, binaryStringFn}, binaryStringFn, false},
   209  		{nil, []Expr{placeholder(0), intConst("1")}, []overloadImpl{binaryIntFn, binaryFloatFn}, binaryIntFn, false},
   210  		{nil, []Expr{placeholder(0), intConst("1")}, []overloadImpl{binaryStringFn, binaryFloatFn}, binaryFloatFn, false},
   211  		// Binary values.
   212  		{nil, []Expr{NewDString("a"), NewDString("b")}, []overloadImpl{binaryStringFn, binaryFloatFn, unaryFloatFn}, binaryStringFn, false},
   213  		{nil, []Expr{NewDString("a"), intConst("1")}, []overloadImpl{binaryStringFn, binaryFloatFn, binaryStringFloatFn1}, binaryStringFloatFn1, false},
   214  		{nil, []Expr{NewDString("a"), NewDInt(1)}, []overloadImpl{binaryStringFn, binaryFloatFn, binaryStringFloatFn1}, unsupported, false},
   215  		{nil, []Expr{NewDString("a"), NewDFloat(1)}, []overloadImpl{binaryStringFn, binaryFloatFn, binaryStringFloatFn1}, binaryStringFloatFn1, false},
   216  		{nil, []Expr{NewDString("a"), NewDFloat(1)}, []overloadImpl{binaryStringFn, binaryFloatFn, binaryStringFloatFn2}, binaryStringFloatFn2, false},
   217  		{nil, []Expr{NewDFloat(1), NewDString("a")}, []overloadImpl{binaryStringFn, binaryFloatFn, binaryStringFloatFn1}, unsupported, false},
   218  		{nil, []Expr{NewDString("a"), NewDFloat(1)}, []overloadImpl{binaryStringFn, binaryFloatFn, binaryStringFloatFn1, binaryStringFloatFn2}, ambiguous, false},
   219  		// Desired type with ambiguity.
   220  		{types.Int, []Expr{intConst("1"), decConst("1.0")}, []overloadImpl{binaryIntFn, binaryDecimalFn, unaryDecimalFn}, binaryIntFn, false},
   221  		{types.Int, []Expr{intConst("1"), NewDFloat(1)}, []overloadImpl{binaryIntFn, binaryFloatFn, unaryFloatFn}, binaryFloatFn, false},
   222  		{types.Int, []Expr{NewDString("a"), NewDFloat(1)}, []overloadImpl{binaryStringFn, binaryFloatFn, binaryStringFloatFn1, binaryStringFloatFn2}, binaryStringFloatFn1, false},
   223  		{types.Float, []Expr{NewDString("a"), NewDFloat(1)}, []overloadImpl{binaryStringFn, binaryFloatFn, binaryStringFloatFn1, binaryStringFloatFn2}, binaryStringFloatFn2, false},
   224  		{types.Float, []Expr{placeholder(0), placeholder(1)}, []overloadImpl{binaryIntFn, binaryFloatFn}, binaryFloatFn, false},
   225  		// Sub-expressions.
   226  		{nil, []Expr{decConst("1.0"), plus(intConst("1"), intConst("2"))}, []overloadImpl{binaryIntFn, binaryDecimalFn}, binaryIntFn, false},
   227  		{nil, []Expr{decConst("1.1"), plus(intConst("1"), intConst("2"))}, []overloadImpl{binaryIntFn, binaryDecimalFn}, shouldError, false},
   228  		{nil, []Expr{NewDFloat(1.1), plus(intConst("1"), intConst("2"))}, []overloadImpl{binaryIntFn, binaryDecimalFn, binaryFloatFn}, binaryFloatFn, false},
   229  		{types.Decimal, []Expr{decConst("1.0"), plus(intConst("1"), intConst("2"))}, []overloadImpl{binaryIntFn, binaryDecimalFn}, binaryIntFn, false},              // Limitation.
   230  		{nil, []Expr{plus(intConst("1"), intConst("2")), plus(decConst("1.1"), decConst("2.2"))}, []overloadImpl{binaryIntFn, binaryDecimalFn}, shouldError, false}, // Limitation.
   231  		{nil, []Expr{plus(decConst("1.1"), decConst("2.2")), plus(intConst("1"), intConst("2"))}, []overloadImpl{binaryIntFn, binaryDecimalFn}, shouldError, false},
   232  		{nil, []Expr{plus(NewDFloat(1.1), NewDFloat(2.2)), plus(intConst("1"), intConst("2"))}, []overloadImpl{binaryIntFn, binaryFloatFn}, binaryFloatFn, false},
   233  		// Homogenous preference.
   234  		{nil, []Expr{NewDInt(1), placeholder(1)}, []overloadImpl{binaryIntFn, binaryIntDateFn}, binaryIntFn, false},
   235  		{nil, []Expr{NewDFloat(1), placeholder(1)}, []overloadImpl{binaryIntFn, binaryIntDateFn}, unsupported, false},
   236  		{nil, []Expr{intConst("1"), placeholder(1)}, []overloadImpl{binaryIntFn, binaryIntDateFn}, binaryIntFn, false},
   237  		{nil, []Expr{decConst("1.0"), placeholder(1)}, []overloadImpl{binaryIntFn, binaryIntDateFn}, unsupported, false}, // Limitation.
   238  		{types.Date, []Expr{NewDInt(1), placeholder(1)}, []overloadImpl{binaryIntFn, binaryIntDateFn}, binaryIntDateFn, false},
   239  		{types.Date, []Expr{NewDFloat(1), placeholder(1)}, []overloadImpl{binaryIntFn, binaryIntDateFn}, unsupported, false},
   240  		{types.Date, []Expr{intConst("1"), placeholder(1)}, []overloadImpl{binaryIntFn, binaryIntDateFn}, binaryIntDateFn, false},
   241  		{types.Date, []Expr{decConst("1.0"), placeholder(1)}, []overloadImpl{binaryIntFn, binaryIntDateFn}, binaryIntDateFn, false},
   242  		// BinOps
   243  		{nil, []Expr{NewDInt(1), DNull}, []overloadImpl{binaryIntFn, binaryIntDateFn}, ambiguous, false},
   244  		{nil, []Expr{NewDInt(1), DNull}, []overloadImpl{binaryIntFn, binaryIntDateFn}, binaryIntFn, true},
   245  		// Verify that we don't return uninitialized typedExprs for a function like
   246  		// array_length where the array argument is a placeholder (#36153).
   247  		{nil, []Expr{placeholder(0), intConst("1")}, []overloadImpl{binaryArrayIntFn}, unsupported, false},
   248  		{nil, []Expr{placeholder(0), intConst("1")}, []overloadImpl{binaryArrayIntFn}, unsupported, true},
   249  	}
   250  	ctx := context.Background()
   251  	for i, d := range testData {
   252  		t.Run(fmt.Sprintf("%v/%v", d.exprs, d.overloads), func(t *testing.T) {
   253  			semaCtx := MakeSemaContext()
   254  			if err := semaCtx.Placeholders.Init(2 /* numPlaceholders */, nil /* typeHints */); err != nil {
   255  				t.Fatal(err)
   256  			}
   257  			desired := types.Any
   258  			if d.desired != nil {
   259  				desired = d.desired
   260  			}
   261  			typedExprs, fns, err := typeCheckOverloadedExprs(
   262  				ctx, &semaCtx, desired, d.overloads, d.inBinOp, d.exprs...,
   263  			)
   264  			assertNoErr := func() {
   265  				if err != nil {
   266  					t.Fatalf("%d: unexpected error returned from overload resolution for exprs %s: %v",
   267  						i, d.exprs, err)
   268  				}
   269  			}
   270  			for _, e := range typedExprs {
   271  				if e == nil {
   272  					t.Errorf("%d: returned uninitialized TypedExpr", i)
   273  				}
   274  			}
   275  			switch d.expectedOverload {
   276  			case shouldError:
   277  				if err == nil {
   278  					t.Errorf("%d: expecting error to be returned from overload resolution for exprs %s",
   279  						i, d.exprs)
   280  				}
   281  			case unsupported:
   282  				assertNoErr()
   283  				if len(fns) > 0 {
   284  					t.Errorf("%d: expected unsupported overload resolution for exprs %s, found %v",
   285  						i, d.exprs, fns)
   286  				}
   287  			case ambiguous:
   288  				assertNoErr()
   289  				if len(fns) < 2 {
   290  					t.Errorf("%d: expected ambiguous overload resolution for exprs %s, found %v",
   291  						i, d.exprs, fns)
   292  				}
   293  			default:
   294  				assertNoErr()
   295  				if len(fns) != 1 || fns[0] != d.expectedOverload {
   296  					t.Errorf("%d: expected overload %s to be chosen when type checking %s, found %v",
   297  						i, d.expectedOverload, d.exprs, fns)
   298  				}
   299  			}
   300  		})
   301  	}
   302  }