gorgonia.org/tensor@v0.9.24/api_cmp.go (about) 1 package tensor 2 3 import "github.com/pkg/errors" 4 5 // public API for comparison ops 6 7 // Lt performs a elementwise less than comparison (a < b). a and b can either be float64 or *Dense. 8 // It returns the same Tensor type as its input. 9 // 10 // If both operands are *Dense, shape is checked first. 11 // Even though the underlying data may have the same size (say (2,2) vs (4,1)), if they have different shapes, it will error out. 12 func Lt(a, b interface{}, opts ...FuncOpt) (retVal Tensor, err error) { 13 var lter Lter 14 var ok bool 15 switch at := a.(type) { 16 case Tensor: 17 lter, ok = at.Engine().(Lter) 18 switch bt := b.(type) { 19 case Tensor: 20 if !bt.Shape().IsScalar() && !at.Shape().IsScalar() { // non-scalar Tensor comparison 21 if !ok { 22 if lter, ok = bt.Engine().(Lter); !ok { 23 return nil, errors.Errorf("Neither operands have engines that support Lt") 24 } 25 } 26 27 return lter.Lt(at, bt, opts...) 28 } else { 29 var leftTensor bool 30 if !bt.Shape().IsScalar() { 31 leftTensor = false // a Scalar-Tensor * b Tensor 32 tmp := at 33 at = bt 34 bt = tmp 35 } else { 36 leftTensor = true // a Tensor * b Scalar-Tensor 37 } 38 39 if !ok { 40 return nil, errors.Errorf("Engine does not support Lt") 41 } 42 return lter.LtScalar(at, bt, leftTensor, opts...) 43 } 44 default: 45 if !ok { 46 return nil, errors.Errorf("Engine does not support Lt") 47 } 48 return lter.LtScalar(at, bt, true, opts...) 49 } 50 default: 51 switch bt := b.(type) { 52 case Tensor: 53 if lter, ok = bt.Engine().(Lter); !ok { 54 return nil, errors.Errorf("Engine does not support Lt") 55 } 56 return lter.LtScalar(bt, at, false, opts...) 57 default: 58 return nil, errors.Errorf("Unable to perform Lt on %T and %T", a, b) 59 } 60 } 61 } 62 63 // Gt performs a elementwise greater than comparison (a > b). a and b can either be float64 or *Dense. 64 // It returns the same Tensor type as its input. 65 // 66 // If both operands are *Dense, shape is checked first. 67 // Even though the underlying data may have the same size (say (2,2) vs (4,1)), if they have different shapes, it will error out. 68 func Gt(a, b interface{}, opts ...FuncOpt) (retVal Tensor, err error) { 69 var gter Gter 70 var ok bool 71 switch at := a.(type) { 72 case Tensor: 73 gter, ok = at.Engine().(Gter) 74 switch bt := b.(type) { 75 case Tensor: 76 if !bt.Shape().IsScalar() && !at.Shape().IsScalar() { // non-scalar Tensor comparison 77 if !ok { 78 if gter, ok = bt.Engine().(Gter); !ok { 79 return nil, errors.Errorf("Neither operands have engines that support Gt") 80 } 81 } 82 return gter.Gt(at, bt, opts...) 83 } else { 84 var leftTensor bool 85 if !bt.Shape().IsScalar() { 86 leftTensor = false // a Scalar-Tensor * b Tensor 87 tmp := at 88 at = bt 89 bt = tmp 90 } else { 91 leftTensor = true // a Tensor * b Scalar-Tensor 92 } 93 94 if !ok { 95 return nil, errors.Errorf("Engine does not support Gt") 96 } 97 return gter.GtScalar(at, bt, leftTensor, opts...) 98 } 99 default: 100 if !ok { 101 return nil, errors.Errorf("Engine does not support Gt") 102 } 103 return gter.GtScalar(at, bt, true, opts...) 104 } 105 default: 106 switch bt := b.(type) { 107 case Tensor: 108 if gter, ok = bt.Engine().(Gter); !ok { 109 return nil, errors.Errorf("Engine does not support Gt") 110 } 111 return gter.GtScalar(bt, at, false, opts...) 112 default: 113 return nil, errors.Errorf("Unable to perform Gt on %T and %T", a, b) 114 } 115 } 116 } 117 118 // Lte performs a elementwise less than eq comparison (a <= b). a and b can either be float64 or *Dense. 119 // It returns the same Tensor type as its input. 120 // 121 // If both operands are *Dense, shape is checked first. 122 // Even though the underlying data may have the same size (say (2,2) vs (4,1)), if they have different shapes, it will error out. 123 func Lte(a, b interface{}, opts ...FuncOpt) (retVal Tensor, err error) { 124 var lteer Lteer 125 var ok bool 126 switch at := a.(type) { 127 case Tensor: 128 lteer, ok = at.Engine().(Lteer) 129 switch bt := b.(type) { 130 case Tensor: 131 if !bt.Shape().IsScalar() && !at.Shape().IsScalar() { // non-scalar Tensor comparison 132 if !ok { 133 if lteer, ok = bt.Engine().(Lteer); !ok { 134 return nil, errors.Errorf("Neither operands have engines that support Lte") 135 } 136 } 137 return lteer.Lte(at, bt, opts...) 138 } else { 139 var leftTensor bool 140 if !bt.Shape().IsScalar() { 141 leftTensor = false // a Scalar-Tensor * b Tensor 142 tmp := at 143 at = bt 144 bt = tmp 145 } else { 146 leftTensor = true // a Tensor * b Scalar-Tensor 147 } 148 149 if !ok { 150 return nil, errors.Errorf("Engine does not support Lte") 151 } 152 return lteer.LteScalar(at, bt, leftTensor, opts...) 153 } 154 155 default: 156 if !ok { 157 return nil, errors.Errorf("Engine does not support Lte") 158 } 159 return lteer.LteScalar(at, bt, true, opts...) 160 } 161 default: 162 switch bt := b.(type) { 163 case Tensor: 164 if lteer, ok = bt.Engine().(Lteer); !ok { 165 return nil, errors.Errorf("Engine does not support Lte") 166 } 167 return lteer.LteScalar(bt, at, false, opts...) 168 default: 169 return nil, errors.Errorf("Unable to perform Lte on %T and %T", a, b) 170 } 171 } 172 } 173 174 // Gte performs a elementwise greater than eq comparison (a >= b). a and b can either be float64 or *Dense. 175 // It returns the same Tensor type as its input. 176 // 177 // If both operands are *Dense, shape is checked first. 178 // Even though the underlying data may have the same size (say (2,2) vs (4,1)), if they have different shapes, it will error out. 179 func Gte(a, b interface{}, opts ...FuncOpt) (retVal Tensor, err error) { 180 var gteer Gteer 181 var ok bool 182 switch at := a.(type) { 183 case Tensor: 184 gteer, ok = at.Engine().(Gteer) 185 switch bt := b.(type) { 186 case Tensor: 187 if !bt.Shape().IsScalar() && !at.Shape().IsScalar() { // non-scalar Tensor comparison 188 if !ok { 189 if gteer, ok = bt.Engine().(Gteer); !ok { 190 return nil, errors.Errorf("Neither operands have engines that support Gte") 191 } 192 } 193 return gteer.Gte(at, bt, opts...) 194 } else { 195 var leftTensor bool 196 if !bt.Shape().IsScalar() { 197 leftTensor = false // a Scalar-Tensor * b Tensor 198 tmp := at 199 at = bt 200 bt = tmp 201 } else { 202 leftTensor = true // a Tensor * b Scalar-Tensor 203 } 204 205 if !ok { 206 return nil, errors.Errorf("Engine does not support Gte") 207 } 208 return gteer.GteScalar(at, bt, leftTensor, opts...) 209 } 210 default: 211 if !ok { 212 return nil, errors.Errorf("Engine does not support Gte") 213 } 214 return gteer.GteScalar(at, bt, true, opts...) 215 } 216 default: 217 switch bt := b.(type) { 218 case Tensor: 219 if gteer, ok = bt.Engine().(Gteer); !ok { 220 return nil, errors.Errorf("Engine does not support Gte") 221 } 222 return gteer.GteScalar(bt, at, false, opts...) 223 default: 224 return nil, errors.Errorf("Unable to perform Gte on %T and %T", a, b) 225 } 226 } 227 } 228 229 // ElEq performs a elementwise equality comparison (a == b). a and b can either be float64 or *Dense. 230 // It returns the same Tensor type as its input. 231 // 232 // If both operands are *Dense, shape is checked first. 233 // Even though the underlying data may have the same size (say (2,2) vs (4,1)), if they have different shapes, it will error out. 234 func ElEq(a, b interface{}, opts ...FuncOpt) (retVal Tensor, err error) { 235 var eleqer ElEqer 236 var ok bool 237 switch at := a.(type) { 238 case Tensor: 239 eleqer, ok = at.Engine().(ElEqer) 240 switch bt := b.(type) { 241 case Tensor: 242 if !bt.Shape().IsScalar() && !at.Shape().IsScalar() { // non-scalar Tensor comparison 243 if !ok { 244 if eleqer, ok = bt.Engine().(ElEqer); !ok { 245 return nil, errors.Errorf("Neither operands have engines that support ElEq") 246 } 247 } 248 return eleqer.ElEq(at, bt, opts...) 249 } else { 250 var leftTensor bool 251 if !bt.Shape().IsScalar() { 252 leftTensor = false // a Scalar-Tensor * b Tensor 253 tmp := at 254 at = bt 255 bt = tmp 256 } else { 257 leftTensor = true // a Tensor * b Scalar-Tensor 258 } 259 260 if !ok { 261 return nil, errors.Errorf("Engine does not support ElEq") 262 } 263 return eleqer.EqScalar(at, bt, leftTensor, opts...) 264 } 265 266 default: 267 if !ok { 268 return nil, errors.Errorf("Engine does not support ElEq") 269 } 270 return eleqer.EqScalar(at, bt, true, opts...) 271 } 272 default: 273 switch bt := b.(type) { 274 case Tensor: 275 if eleqer, ok = bt.Engine().(ElEqer); !ok { 276 return nil, errors.Errorf("Engine does not support ElEq") 277 } 278 return eleqer.EqScalar(bt, at, false, opts...) 279 default: 280 return nil, errors.Errorf("Unable to perform ElEq on %T and %T", a, b) 281 } 282 } 283 } 284 285 // ElNe performs a elementwise equality comparison (a != b). a and b can either be float64 or *Dense. 286 // It returns the same Tensor type as its input. 287 // 288 // If both operands are *Dense, shape is checked first. 289 // Even though the underlying data may have the same size (say (2,2) vs (4,1)), if they have different shapes, it will error out. 290 func ElNe(a, b interface{}, opts ...FuncOpt) (retVal Tensor, err error) { 291 var eleqer ElEqer 292 var ok bool 293 switch at := a.(type) { 294 case Tensor: 295 eleqer, ok = at.Engine().(ElEqer) 296 switch bt := b.(type) { 297 case Tensor: 298 if !ok { 299 if eleqer, ok = bt.Engine().(ElEqer); !ok { 300 return nil, errors.Errorf("Neither operands have engines that support ElEq") 301 } 302 } 303 return eleqer.ElNe(at, bt, opts...) 304 default: 305 if !ok { 306 return nil, errors.Errorf("Engine does not support ElEq") 307 } 308 return eleqer.NeScalar(at, bt, true, opts...) 309 } 310 default: 311 switch bt := b.(type) { 312 case Tensor: 313 if eleqer, ok = bt.Engine().(ElEqer); !ok { 314 return nil, errors.Errorf("Engine does not support ElEq") 315 } 316 return eleqer.NeScalar(bt, at, false, opts...) 317 default: 318 return nil, errors.Errorf("Unable to perform ElEq on %T and %T", a, b) 319 } 320 } 321 }