github.com/cockroachdb/cockroachdb-parser@v0.23.3-0.20240213214944-911057d40c9a/pkg/sql/sem/tree/typing.go (about)

     1  // Copyright 2022 The Cockroach Authors.
     2  //
     3  // Use of this software is governed by the Business Source License
     4  // included in the file licenses/BSL.txt.
     5  //
     6  // As of the Change Date specified in that file, in accordance with
     7  // the Business Source License, use of this software will be governed
     8  // by the Apache License, Version 2.0, included in the file
     9  // licenses/APL.txt.
    10  
    11  package tree
    12  
    13  import (
    14  	"github.com/cockroachdb/cockroachdb-parser/pkg/sql/sem/tree/treebin"
    15  	"github.com/cockroachdb/cockroachdb-parser/pkg/sql/types"
    16  	"github.com/cockroachdb/cockroachdb-parser/pkg/util/iterutil"
    17  )
    18  
    19  // InferBinaryType infers the return type of a binary expression, given the type
    20  // of its inputs.
    21  func InferBinaryType(bin treebin.BinaryOperatorSymbol, leftType, rightType *types.T) *types.T {
    22  	o, ok := FindBinaryOverload(bin, leftType, rightType)
    23  	if !ok {
    24  		return nil
    25  	}
    26  	return o.ReturnType
    27  }
    28  
    29  // FindBinaryOverload finds the correct type signature overload for the
    30  // specified binary operator, given the types of its inputs. If an overload is
    31  // found, FindBinaryOverload returns true, plus a pointer to the overload.
    32  // If an overload is not found, FindBinaryOverload returns false.
    33  func FindBinaryOverload(
    34  	bin treebin.BinaryOperatorSymbol, leftType, rightType *types.T,
    35  ) (ret *BinOp, ok bool) {
    36  
    37  	// Find the binary op that matches the type of the expression's left and
    38  	// right children. No more than one match should ever be found. The
    39  	// TestTypingBinaryAssumptions test ensures this will be the case even if
    40  	// new operators or overloads are added.
    41  	_ = BinOps[bin].ForEachBinOp(func(o *BinOp) error {
    42  		if leftType.Family() == types.UnknownFamily {
    43  			ok = rightType.Equivalent(o.RightType)
    44  		} else if rightType.Family() == types.UnknownFamily {
    45  			ok = leftType.Equivalent(o.LeftType)
    46  		} else {
    47  			ok = leftType.Equivalent(o.LeftType) && rightType.Equivalent(o.RightType)
    48  		}
    49  		if !ok {
    50  			return nil
    51  		}
    52  		ret = o
    53  		return iterutil.StopIteration()
    54  	})
    55  	return ret, ok
    56  }