gorgonia.org/tensor@v0.9.24/api_minmax.go (about)

     1  package tensor
     2  
     3  import "github.com/pkg/errors"
     4  
     5  func MinBetween(a, b interface{}, opts ...FuncOpt) (retVal Tensor, err error) {
     6  	var minbetweener MinBetweener
     7  	var oe standardEngine
     8  	var ok bool
     9  	switch at := a.(type) {
    10  	case Tensor:
    11  		oe = at.standardEngine()
    12  		switch bt := b.(type) {
    13  		case Tensor:
    14  			if !bt.Shape().IsScalar() && !at.Shape().IsScalar() { // non-scalar Tensor addition
    15  				if oe != nil {
    16  					return oe.MinBetween(at, bt, opts...)
    17  				}
    18  				if oe = bt.standardEngine(); oe != nil {
    19  					return oe.MinBetween(at, bt, opts...)
    20  				}
    21  				if minbetweener, ok = at.Engine().(MinBetweener); ok {
    22  					return minbetweener.MinBetween(at, bt, opts...)
    23  				}
    24  				if minbetweener, ok = bt.Engine().(MinBetweener); ok {
    25  					return minbetweener.MinBetween(at, bt, opts...)
    26  				}
    27  				return nil, errors.New("Neither engines of either operand support MinBetween")
    28  
    29  			} else { // at least one of the operands is a scalar
    30  				var leftTensor bool
    31  				if !bt.Shape().IsScalar() {
    32  					leftTensor = false // a Scalar-Tensor * b Tensor
    33  					tmp := at
    34  					at = bt
    35  					bt = tmp
    36  				} else {
    37  					leftTensor = true // a Tensor * b Scalar-Tensor
    38  				}
    39  
    40  				if oe != nil {
    41  					return oe.MinBetweenScalar(at, bt, leftTensor, opts...)
    42  				}
    43  				if oe = bt.standardEngine(); oe != nil {
    44  					return oe.MinBetweenScalar(at, bt, leftTensor, opts...)
    45  				}
    46  				if minbetweener, ok = at.Engine().(MinBetweener); ok {
    47  					return minbetweener.MinBetweenScalar(at, bt, leftTensor, opts...)
    48  				}
    49  				if minbetweener, ok = bt.Engine().(MinBetweener); ok {
    50  					return minbetweener.MinBetweenScalar(at, bt, leftTensor, opts...)
    51  				}
    52  				return nil, errors.New("Neither engines of either operand support MinBetween")
    53  			}
    54  
    55  		default:
    56  			if oe != nil {
    57  				return oe.MinBetweenScalar(at, bt, true, opts...)
    58  			}
    59  			if minbetweener, ok = at.Engine().(MinBetweener); ok {
    60  				return minbetweener.MinBetweenScalar(at, bt, true, opts...)
    61  			}
    62  			return nil, errors.New("Operand A's engine does not support MinBetween")
    63  		}
    64  	default:
    65  		switch bt := b.(type) {
    66  		case Tensor:
    67  			if oe = bt.standardEngine(); oe != nil {
    68  				return oe.MinBetweenScalar(bt, at, false, opts...)
    69  			}
    70  			if minbetweener, ok = bt.Engine().(MinBetweener); ok {
    71  				return minbetweener.MinBetweenScalar(bt, at, false, opts...)
    72  			}
    73  			return nil, errors.New("Operand B's engine does not support MinBetween")
    74  		default:
    75  			return nil, errors.Errorf("Cannot perform MinBetween of %T and %T", a, b)
    76  		}
    77  	}
    78  	panic("Unreachable")
    79  }
    80  
    81  func MaxBetween(a, b interface{}, opts ...FuncOpt) (retVal Tensor, err error) {
    82  	var maxbetweener MaxBetweener
    83  	var oe standardEngine
    84  	var ok bool
    85  	switch at := a.(type) {
    86  	case Tensor:
    87  		oe = at.standardEngine()
    88  		switch bt := b.(type) {
    89  		case Tensor:
    90  			if !bt.Shape().IsScalar() && !at.Shape().IsScalar() { // non-scalar Tensor addition
    91  				if oe != nil {
    92  					return oe.MaxBetween(at, bt, opts...)
    93  				}
    94  				if oe = bt.standardEngine(); oe != nil {
    95  					return oe.MaxBetween(at, bt, opts...)
    96  				}
    97  				if maxbetweener, ok = at.Engine().(MaxBetweener); ok {
    98  					return maxbetweener.MaxBetween(at, bt, opts...)
    99  				}
   100  				if maxbetweener, ok = bt.Engine().(MaxBetweener); ok {
   101  					return maxbetweener.MaxBetween(at, bt, opts...)
   102  				}
   103  				return nil, errors.New("Neither engines of either operand support MaxBetween")
   104  
   105  			} else { // at least one of the operands is a scalar
   106  				var leftTensor bool
   107  				if !bt.Shape().IsScalar() {
   108  					leftTensor = false // a Scalar-Tensor * b Tensor
   109  					tmp := at
   110  					at = bt
   111  					bt = tmp
   112  				} else {
   113  					leftTensor = true // a Tensor * b Scalar-Tensor
   114  				}
   115  
   116  				if oe != nil {
   117  					return oe.MaxBetweenScalar(at, bt, leftTensor, opts...)
   118  				}
   119  				if oe = bt.standardEngine(); oe != nil {
   120  					return oe.MaxBetweenScalar(at, bt, leftTensor, opts...)
   121  				}
   122  				if maxbetweener, ok = at.Engine().(MaxBetweener); ok {
   123  					return maxbetweener.MaxBetweenScalar(at, bt, leftTensor, opts...)
   124  				}
   125  				if maxbetweener, ok = bt.Engine().(MaxBetweener); ok {
   126  					return maxbetweener.MaxBetweenScalar(at, bt, leftTensor, opts...)
   127  				}
   128  				return nil, errors.New("Neither engines of either operand support MaxBetween")
   129  			}
   130  
   131  		default:
   132  			if oe != nil {
   133  				return oe.MaxBetweenScalar(at, bt, true, opts...)
   134  			}
   135  			if maxbetweener, ok = at.Engine().(MaxBetweener); ok {
   136  				return maxbetweener.MaxBetweenScalar(at, bt, true, opts...)
   137  			}
   138  			return nil, errors.New("Operand A's engine does not support MaxBetween")
   139  		}
   140  	default:
   141  		switch bt := b.(type) {
   142  		case Tensor:
   143  			if oe = bt.standardEngine(); oe != nil {
   144  				return oe.MaxBetweenScalar(bt, at, false, opts...)
   145  			}
   146  			if maxbetweener, ok = bt.Engine().(MaxBetweener); ok {
   147  				return maxbetweener.MaxBetweenScalar(bt, at, false, opts...)
   148  			}
   149  			return nil, errors.New("Operand B's engine does not support MaxBetween")
   150  		default:
   151  			return nil, errors.Errorf("Cannot perform MaxBetween of %T and %T", a, b)
   152  		}
   153  	}
   154  	panic("Unreachable")
   155  }