gorgonia.org/tensor@v0.9.24/api_cmp.go (about)

     1  package tensor
     2  
     3  import "github.com/pkg/errors"
     4  
     5  // public API for comparison ops
     6  
     7  // Lt performs a elementwise less than comparison (a < b). a and b can either be float64 or *Dense.
     8  // It returns the same Tensor type as its input.
     9  //
    10  // If both operands are *Dense, shape is checked first.
    11  // Even though the underlying data may have the same size (say (2,2) vs (4,1)), if they have different shapes, it will error out.
    12  func Lt(a, b interface{}, opts ...FuncOpt) (retVal Tensor, err error) {
    13  	var lter Lter
    14  	var ok bool
    15  	switch at := a.(type) {
    16  	case Tensor:
    17  		lter, ok = at.Engine().(Lter)
    18  		switch bt := b.(type) {
    19  		case Tensor:
    20  			if !bt.Shape().IsScalar() && !at.Shape().IsScalar() { // non-scalar Tensor comparison
    21  				if !ok {
    22  					if lter, ok = bt.Engine().(Lter); !ok {
    23  						return nil, errors.Errorf("Neither operands have engines that support Lt")
    24  					}
    25  				}
    26  
    27  				return lter.Lt(at, bt, opts...)
    28  			} else {
    29  				var leftTensor bool
    30  				if !bt.Shape().IsScalar() {
    31  					leftTensor = false // a Scalar-Tensor * b Tensor
    32  					tmp := at
    33  					at = bt
    34  					bt = tmp
    35  				} else {
    36  					leftTensor = true // a Tensor * b Scalar-Tensor
    37  				}
    38  
    39  				if !ok {
    40  					return nil, errors.Errorf("Engine does not support Lt")
    41  				}
    42  				return lter.LtScalar(at, bt, leftTensor, opts...)
    43  			}
    44  		default:
    45  			if !ok {
    46  				return nil, errors.Errorf("Engine does not support Lt")
    47  			}
    48  			return lter.LtScalar(at, bt, true, opts...)
    49  		}
    50  	default:
    51  		switch bt := b.(type) {
    52  		case Tensor:
    53  			if lter, ok = bt.Engine().(Lter); !ok {
    54  				return nil, errors.Errorf("Engine does not support Lt")
    55  			}
    56  			return lter.LtScalar(bt, at, false, opts...)
    57  		default:
    58  			return nil, errors.Errorf("Unable to perform Lt on %T and %T", a, b)
    59  		}
    60  	}
    61  }
    62  
    63  // Gt performs a elementwise greater than comparison (a > b). a and b can either be float64 or *Dense.
    64  // It returns the same Tensor type as its input.
    65  //
    66  // If both operands are *Dense, shape is checked first.
    67  // Even though the underlying data may have the same size (say (2,2) vs (4,1)), if they have different shapes, it will error out.
    68  func Gt(a, b interface{}, opts ...FuncOpt) (retVal Tensor, err error) {
    69  	var gter Gter
    70  	var ok bool
    71  	switch at := a.(type) {
    72  	case Tensor:
    73  		gter, ok = at.Engine().(Gter)
    74  		switch bt := b.(type) {
    75  		case Tensor:
    76  			if !bt.Shape().IsScalar() && !at.Shape().IsScalar() { // non-scalar Tensor comparison
    77  				if !ok {
    78  					if gter, ok = bt.Engine().(Gter); !ok {
    79  						return nil, errors.Errorf("Neither operands have engines that support Gt")
    80  					}
    81  				}
    82  				return gter.Gt(at, bt, opts...)
    83  			} else {
    84  				var leftTensor bool
    85  				if !bt.Shape().IsScalar() {
    86  					leftTensor = false // a Scalar-Tensor * b Tensor
    87  					tmp := at
    88  					at = bt
    89  					bt = tmp
    90  				} else {
    91  					leftTensor = true // a Tensor * b Scalar-Tensor
    92  				}
    93  
    94  				if !ok {
    95  					return nil, errors.Errorf("Engine does not support Gt")
    96  				}
    97  				return gter.GtScalar(at, bt, leftTensor, opts...)
    98  			}
    99  		default:
   100  			if !ok {
   101  				return nil, errors.Errorf("Engine does not support Gt")
   102  			}
   103  			return gter.GtScalar(at, bt, true, opts...)
   104  		}
   105  	default:
   106  		switch bt := b.(type) {
   107  		case Tensor:
   108  			if gter, ok = bt.Engine().(Gter); !ok {
   109  				return nil, errors.Errorf("Engine does not support Gt")
   110  			}
   111  			return gter.GtScalar(bt, at, false, opts...)
   112  		default:
   113  			return nil, errors.Errorf("Unable to perform Gt on %T and %T", a, b)
   114  		}
   115  	}
   116  }
   117  
   118  // Lte performs a elementwise less than eq comparison (a <= b). a and b can either be float64 or *Dense.
   119  // It returns the same Tensor type as its input.
   120  //
   121  // If both operands are *Dense, shape is checked first.
   122  // Even though the underlying data may have the same size (say (2,2) vs (4,1)), if they have different shapes, it will error out.
   123  func Lte(a, b interface{}, opts ...FuncOpt) (retVal Tensor, err error) {
   124  	var lteer Lteer
   125  	var ok bool
   126  	switch at := a.(type) {
   127  	case Tensor:
   128  		lteer, ok = at.Engine().(Lteer)
   129  		switch bt := b.(type) {
   130  		case Tensor:
   131  			if !bt.Shape().IsScalar() && !at.Shape().IsScalar() { // non-scalar Tensor comparison
   132  				if !ok {
   133  					if lteer, ok = bt.Engine().(Lteer); !ok {
   134  						return nil, errors.Errorf("Neither operands have engines that support Lte")
   135  					}
   136  				}
   137  				return lteer.Lte(at, bt, opts...)
   138  			} else {
   139  				var leftTensor bool
   140  				if !bt.Shape().IsScalar() {
   141  					leftTensor = false // a Scalar-Tensor * b Tensor
   142  					tmp := at
   143  					at = bt
   144  					bt = tmp
   145  				} else {
   146  					leftTensor = true // a Tensor * b Scalar-Tensor
   147  				}
   148  
   149  				if !ok {
   150  					return nil, errors.Errorf("Engine does not support Lte")
   151  				}
   152  				return lteer.LteScalar(at, bt, leftTensor, opts...)
   153  			}
   154  
   155  		default:
   156  			if !ok {
   157  				return nil, errors.Errorf("Engine does not support Lte")
   158  			}
   159  			return lteer.LteScalar(at, bt, true, opts...)
   160  		}
   161  	default:
   162  		switch bt := b.(type) {
   163  		case Tensor:
   164  			if lteer, ok = bt.Engine().(Lteer); !ok {
   165  				return nil, errors.Errorf("Engine does not support Lte")
   166  			}
   167  			return lteer.LteScalar(bt, at, false, opts...)
   168  		default:
   169  			return nil, errors.Errorf("Unable to perform Lte on %T and %T", a, b)
   170  		}
   171  	}
   172  }
   173  
   174  // Gte performs a elementwise greater than eq comparison (a >= b). a and b can either be float64 or *Dense.
   175  // It returns the same Tensor type as its input.
   176  //
   177  // If both operands are *Dense, shape is checked first.
   178  // Even though the underlying data may have the same size (say (2,2) vs (4,1)), if they have different shapes, it will error out.
   179  func Gte(a, b interface{}, opts ...FuncOpt) (retVal Tensor, err error) {
   180  	var gteer Gteer
   181  	var ok bool
   182  	switch at := a.(type) {
   183  	case Tensor:
   184  		gteer, ok = at.Engine().(Gteer)
   185  		switch bt := b.(type) {
   186  		case Tensor:
   187  			if !bt.Shape().IsScalar() && !at.Shape().IsScalar() { // non-scalar Tensor comparison
   188  				if !ok {
   189  					if gteer, ok = bt.Engine().(Gteer); !ok {
   190  						return nil, errors.Errorf("Neither operands have engines that support Gte")
   191  					}
   192  				}
   193  				return gteer.Gte(at, bt, opts...)
   194  			} else {
   195  				var leftTensor bool
   196  				if !bt.Shape().IsScalar() {
   197  					leftTensor = false // a Scalar-Tensor * b Tensor
   198  					tmp := at
   199  					at = bt
   200  					bt = tmp
   201  				} else {
   202  					leftTensor = true // a Tensor * b Scalar-Tensor
   203  				}
   204  
   205  				if !ok {
   206  					return nil, errors.Errorf("Engine does not support Gte")
   207  				}
   208  				return gteer.GteScalar(at, bt, leftTensor, opts...)
   209  			}
   210  		default:
   211  			if !ok {
   212  				return nil, errors.Errorf("Engine does not support Gte")
   213  			}
   214  			return gteer.GteScalar(at, bt, true, opts...)
   215  		}
   216  	default:
   217  		switch bt := b.(type) {
   218  		case Tensor:
   219  			if gteer, ok = bt.Engine().(Gteer); !ok {
   220  				return nil, errors.Errorf("Engine does not support Gte")
   221  			}
   222  			return gteer.GteScalar(bt, at, false, opts...)
   223  		default:
   224  			return nil, errors.Errorf("Unable to perform Gte on %T and %T", a, b)
   225  		}
   226  	}
   227  }
   228  
   229  // ElEq performs a elementwise equality comparison (a == b). a and b can either be float64 or *Dense.
   230  // It returns the same Tensor type as its input.
   231  //
   232  // If both operands are *Dense, shape is checked first.
   233  // Even though the underlying data may have the same size (say (2,2) vs (4,1)), if they have different shapes, it will error out.
   234  func ElEq(a, b interface{}, opts ...FuncOpt) (retVal Tensor, err error) {
   235  	var eleqer ElEqer
   236  	var ok bool
   237  	switch at := a.(type) {
   238  	case Tensor:
   239  		eleqer, ok = at.Engine().(ElEqer)
   240  		switch bt := b.(type) {
   241  		case Tensor:
   242  			if !bt.Shape().IsScalar() && !at.Shape().IsScalar() { // non-scalar Tensor comparison
   243  				if !ok {
   244  					if eleqer, ok = bt.Engine().(ElEqer); !ok {
   245  						return nil, errors.Errorf("Neither operands have engines that support ElEq")
   246  					}
   247  				}
   248  				return eleqer.ElEq(at, bt, opts...)
   249  			} else {
   250  				var leftTensor bool
   251  				if !bt.Shape().IsScalar() {
   252  					leftTensor = false // a Scalar-Tensor * b Tensor
   253  					tmp := at
   254  					at = bt
   255  					bt = tmp
   256  				} else {
   257  					leftTensor = true // a Tensor * b Scalar-Tensor
   258  				}
   259  
   260  				if !ok {
   261  					return nil, errors.Errorf("Engine does not support ElEq")
   262  				}
   263  				return eleqer.EqScalar(at, bt, leftTensor, opts...)
   264  			}
   265  
   266  		default:
   267  			if !ok {
   268  				return nil, errors.Errorf("Engine does not support ElEq")
   269  			}
   270  			return eleqer.EqScalar(at, bt, true, opts...)
   271  		}
   272  	default:
   273  		switch bt := b.(type) {
   274  		case Tensor:
   275  			if eleqer, ok = bt.Engine().(ElEqer); !ok {
   276  				return nil, errors.Errorf("Engine does not support ElEq")
   277  			}
   278  			return eleqer.EqScalar(bt, at, false, opts...)
   279  		default:
   280  			return nil, errors.Errorf("Unable to perform ElEq on %T and %T", a, b)
   281  		}
   282  	}
   283  }
   284  
   285  // ElNe performs a elementwise equality comparison (a != b). a and b can either be float64 or *Dense.
   286  // It returns the same Tensor type as its input.
   287  //
   288  // If both operands are *Dense, shape is checked first.
   289  // Even though the underlying data may have the same size (say (2,2) vs (4,1)), if they have different shapes, it will error out.
   290  func ElNe(a, b interface{}, opts ...FuncOpt) (retVal Tensor, err error) {
   291  	var eleqer ElEqer
   292  	var ok bool
   293  	switch at := a.(type) {
   294  	case Tensor:
   295  		eleqer, ok = at.Engine().(ElEqer)
   296  		switch bt := b.(type) {
   297  		case Tensor:
   298  			if !ok {
   299  				if eleqer, ok = bt.Engine().(ElEqer); !ok {
   300  					return nil, errors.Errorf("Neither operands have engines that support ElEq")
   301  				}
   302  			}
   303  			return eleqer.ElNe(at, bt, opts...)
   304  		default:
   305  			if !ok {
   306  				return nil, errors.Errorf("Engine does not support ElEq")
   307  			}
   308  			return eleqer.NeScalar(at, bt, true, opts...)
   309  		}
   310  	default:
   311  		switch bt := b.(type) {
   312  		case Tensor:
   313  			if eleqer, ok = bt.Engine().(ElEqer); !ok {
   314  				return nil, errors.Errorf("Engine does not support ElEq")
   315  			}
   316  			return eleqer.NeScalar(bt, at, false, opts...)
   317  		default:
   318  			return nil, errors.Errorf("Unable to perform ElEq on %T and %T", a, b)
   319  		}
   320  	}
   321  }