gorgonia.org/gorgonia@v0.9.17/operatorPointwise_binary_const.go (about)

     1  package gorgonia
     2  
import (
	"strconv"

	"gorgonia.org/tensor"
)
     4  
     5  var (
     6  	/* scalar-tensor float64 and vice versa */
     7  
     8  	// arith
     9  	tadd = denseBinOp(tensor.Add)
    10  	tsub = denseBinOp(tensor.Sub)
    11  	tmul = denseBinOp(tensor.Mul)
    12  	tdiv = denseBinOp(tensor.Div)
    13  	tpow = denseBinOp(tensor.Pow)
    14  
    15  	// cmp
    16  	tlt  = denseCmpOp(tensor.Lt)
    17  	tgt  = denseCmpOp(tensor.Gt)
    18  	tlte = denseCmpOp(tensor.Lte)
    19  	tgte = denseCmpOp(tensor.Gte)
    20  	teq  = denseCmpOp(tensor.ElEq)
    21  	tne  = denseCmpOp(tensor.ElNe)
    22  )
    23  
    24  type denseBinOp func(a, b interface{}, opts ...tensor.FuncOpt) (tensor.Tensor, error)
    25  type denseCmpOp func(a, b interface{}, opts ...tensor.FuncOpt) (tensor.Tensor, error)
    26  
    27  type ʘBinaryOperatorType byte
    28  
    29  const (
    30  	// arith
    31  	addOpType ʘBinaryOperatorType = iota
    32  	subOpType
    33  	mulOpType
    34  	divOpType
    35  	powOpType
    36  
    37  	// cmp
    38  	ltOpType
    39  	gtOpType
    40  	lteOpType
    41  	gteOpType
    42  	eqOpType
    43  	neOpType
    44  
    45  	maxʘBinaryOpType // delimits the end of all possible binOpType
    46  )
    47  
    48  func (op ʘBinaryOperatorType) String() string {
    49  	return ʘBinOpStrs[op]
    50  }
    51  
    52  // ʘBinOpStrs is the string representation for a binOpType
    53  // It should be held constant.
    54  var ʘBinOpStrs = [maxʘBinaryOpType]string{
    55  	// arith ops
    56  	"+",
    57  	"-",
    58  	"⊙",
    59  	"÷",
    60  	"^",
    61  
    62  	// cmp ops
    63  	"<",
    64  	">",
    65  	"<=",
    66  	">=",
    67  	"==",
    68  	"!=",
    69  }
    70  
    71  // ʘBinOpNames is the string representation for a binOpType
    72  // It should be held constant.
    73  var ʘBinOpNames = [maxʘBinaryOpType]string{
    74  	// arith ops
    75  	"add",
    76  	"sub",
    77  	"mul",
    78  	"div",
    79  	"pow",
    80  
    81  	// cmp ops
    82  	"lt",
    83  	"gt",
    84  	"lte",
    85  	"gte",
    86  	"eq",
    87  	"ne",
    88  }
    89  
    90  // ʘBinOpCommutative is the array that stores whether a binary operator is commutative
    91  // It should be held constant.
    92  var ʘBinOpCommutative = [maxʘBinaryOpType]bool{
    93  	true, false, true, false, false,
    94  	false, false, false, false, true, true,
    95  }
    96  
    97  var ʘBinOpDiffExprs = [maxʘBinaryOpType]func(x, y, z, gradZ *Node) (Nodes, error){
    98  	addDiffExpr, subDiffExpr, hadamardProdDiffExpr, hadamardDivDiffExpr, hadamardPowDiffExpr,
    99  	nondiffBinOpExpr, nondiffBinOpExpr, nondiffBinOpExpr, nondiffBinOpExpr, nondiffBinOpExpr, nondiffBinOpExpr,
   100  }
   101  
   102  var ʘBinOpDiffFns = [maxʘBinaryOpType]func(ctx ExecutionContext, x, y, z *Node) error{
   103  	addDiff, subDiff, hadamardProdDiff, hadamardDivDiff, hadamardPowDiff,
   104  	nondiffBinOp, nondiffBinOp, nondiffBinOp, nondiffBinOp, nondiffBinOp, nondiffBinOp,
   105  }
   106  
   107  // isCommutative gives info about whether the operator is commutative
   108  // For example:
   109  //		a + b == b + a
   110  // will ALWAYS evaluate to true. The same cannot be said about subtraction:
   111  // 		a - b != b - a
   112  // While a-b *may* be equal to b-a, it is not guaranteed. Therefore subtraction
   113  // is not commutative
   114  func (op ʘBinaryOperatorType) isCommutative() bool {
   115  	if op >= maxʘBinaryOpType {
   116  		panic("isCommutative() for unsupported BinOp undefined")
   117  	}
   118  	return ʘBinOpCommutative[op]
   119  }
   120  
   121  func (op ʘBinaryOperatorType) diffWRT(inputs int) []bool {
   122  	if inputs != 2 {
   123  		panic("binary operator only supports 2 inputs")
   124  	}
   125  
   126  	if op.isArith() {
   127  		return []bool{true, true}
   128  	}
   129  	return []bool{false, false}
   130  }
   131  
   132  // isArith indicates if the binary operator is an arithmetic type
   133  func (op ʘBinaryOperatorType) isArith() bool {
   134  	switch op {
   135  	case addOpType, subOpType, mulOpType, divOpType, powOpType:
   136  		return true
   137  	default:
   138  		return false
   139  	}
   140  }
   141  
   142  var binOps = [maxʘBinaryOpType]*denseBinOp{
   143  	&tadd,
   144  	&tsub,
   145  	&tmul,
   146  	&tdiv,
   147  	&tpow,
   148  	nil, // lt
   149  	nil, // gt
   150  	nil, // lte
   151  	nil, // gte
   152  	nil, // eq
   153  	nil, // ne
   154  }
   155  
   156  var cmpOps = [maxʘBinaryOpType]*denseCmpOp{
   157  	nil, // add
   158  	nil, // sub
   159  	nil, // mul
   160  	nil, // div
   161  	nil, // pow
   162  	&tlt,
   163  	&tgt,
   164  	&tlte,
   165  	&tgte,
   166  	&teq,
   167  	&tne,
   168  }