github.com/wzzhu/tensor@v0.9.24/api_arith.go

package tensor

import (
	"github.com/pkg/errors"
)

// exported API for arithmetic and the stupidly crazy amount of overloaded semantics

// Add performs elementwise addition on the Tensor(s). a and b can each be either a scalar value or a Tensor. These operations are supported:
//	Add(*Dense, scalar)
//	Add(scalar, *Dense)
//	Add(*Dense, *Dense)
//
// If both operands are Tensors, their shapes are checked first. Even though the underlying data may have the same size (say (2,2) vs (4,1)), operands with different shapes will error out.
//
// If the Unsafe flag is passed in, the data of the first tensor will be overwritten.
func Add(a, b interface{}, opts ...FuncOpt) (retVal Tensor, err error) {
	var adder Adder
	var oe standardEngine
	var ok bool
	switch at := a.(type) {
	case Tensor:
		oe = at.standardEngine()
		switch bt := b.(type) {
		case Tensor:
			if !bt.Shape().IsScalar() && !at.Shape().IsScalar() { // non-scalar Tensor addition
				if oe != nil {
					return oe.Add(at, bt, opts...)
				}
				if oe = bt.standardEngine(); oe != nil {
					return oe.Add(at, bt, opts...)
				}
				if adder, ok = at.Engine().(Adder); ok {
					return adder.Add(at, bt, opts...)
				}
				if adder, ok = bt.Engine().(Adder); ok {
					return adder.Add(at, bt, opts...)
				}
				return nil, errors.New("Neither operand's engine supports Add")

			} else { // at least one of the operands is a scalar
				var leftTensor bool
				if !bt.Shape().IsScalar() {
					leftTensor = false // a Scalar-Tensor + b Tensor
					tmp := at
					at = bt
					bt = tmp
				} else {
					leftTensor = true // a Tensor + b Scalar-Tensor
				}

				if oe != nil {
					return oe.AddScalar(at, bt, leftTensor, opts...)
				}
				if oe = bt.standardEngine(); oe != nil {
					return oe.AddScalar(at, bt, leftTensor, opts...)
				}
				if adder, ok = at.Engine().(Adder); ok {
					return adder.AddScalar(at, bt, leftTensor, opts...)
				}
				if adder, ok = bt.Engine().(Adder); ok {
					return adder.AddScalar(at, bt, leftTensor, opts...)
				}
				return nil, errors.New("Neither operand's engine supports Add")
			}

		default:
			if oe != nil {
				return oe.AddScalar(at, bt, true, opts...)
			}
			if adder, ok = at.Engine().(Adder); ok {
				return adder.AddScalar(at, bt, true, opts...)
			}
			return nil, errors.New("Operand A's engine does not support Add")
		}
	default:
		switch bt := b.(type) {
		case Tensor:
			if oe = bt.standardEngine(); oe != nil {
				return oe.AddScalar(bt, at, false, opts...)
			}
			if adder, ok = bt.Engine().(Adder); ok {
				return adder.AddScalar(bt, at, false, opts...)
			}
			return nil, errors.New("Operand B's engine does not support Add")
		default:
			return nil, errors.Errorf("Cannot perform Add of %T and %T", a, b)
		}
	}
	panic("Unreachable")
}
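
// addUsageSketch is an illustrative, non-exhaustive sketch of the call patterns Add
// supports; Sub, Mul, Div, Pow and Mod follow the same pattern. It assumes the New,
// WithShape, WithBacking and UseUnsafe helpers defined elsewhere in this package.
func addUsageSketch() error {
	a := New(WithShape(2, 2), WithBacking([]float64{1, 2, 3, 4}))
	b := New(WithShape(2, 2), WithBacking([]float64{5, 6, 7, 8}))

	if _, err := Add(a, b); err != nil { // Add(*Dense, *Dense): elementwise addition
		return err
	}
	if _, err := Add(a, 10.0); err != nil { // Add(*Dense, scalar)
		return err
	}
	if _, err := Add(10.0, a); err != nil { // Add(scalar, *Dense)
		return err
	}
	_, err := Add(a, b, UseUnsafe()) // with the Unsafe flag, a's backing data is overwritten with the result
	return err
}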

// Sub performs elementwise subtraction on the Tensor(s). These operations are supported:
//	Sub(*Dense, scalar)
//	Sub(scalar, *Dense)
//	Sub(*Dense, *Dense)
// If the Unsafe flag is passed in, the data of the first tensor will be overwritten.
func Sub(a, b interface{}, opts ...FuncOpt) (retVal Tensor, err error) {
	var suber Suber
	var oe standardEngine
	var ok bool
	switch at := a.(type) {
	case Tensor:
		oe = at.standardEngine()
		switch bt := b.(type) {
		case Tensor:
			if !bt.Shape().IsScalar() && !at.Shape().IsScalar() { // non-scalar Tensor subtraction
				if oe != nil {
					return oe.Sub(at, bt, opts...)
				}
				if oe = bt.standardEngine(); oe != nil {
					return oe.Sub(at, bt, opts...)
				}
				if suber, ok = at.Engine().(Suber); ok {
					return suber.Sub(at, bt, opts...)
				}
				if suber, ok = bt.Engine().(Suber); ok {
					return suber.Sub(at, bt, opts...)
				}
				return nil, errors.New("Neither operand's engine supports Sub")

			} else { // at least one of the operands is a scalar
				var leftTensor bool
				if !bt.Shape().IsScalar() {
					leftTensor = false // a Scalar-Tensor - b Tensor
					tmp := at
					at = bt
					bt = tmp
				} else {
					leftTensor = true // a Tensor - b Scalar-Tensor
				}

				if oe != nil {
					return oe.SubScalar(at, bt, leftTensor, opts...)
				}
				if oe = bt.standardEngine(); oe != nil {
					return oe.SubScalar(at, bt, leftTensor, opts...)
				}
				if suber, ok = at.Engine().(Suber); ok {
					return suber.SubScalar(at, bt, leftTensor, opts...)
				}
				if suber, ok = bt.Engine().(Suber); ok {
					return suber.SubScalar(at, bt, leftTensor, opts...)
				}
				return nil, errors.New("Neither operand's engine supports Sub")
			}

		default:
			if oe != nil {
				return oe.SubScalar(at, bt, true, opts...)
			}
			if suber, ok = at.Engine().(Suber); ok {
				return suber.SubScalar(at, bt, true, opts...)
			}
			return nil, errors.New("Operand A's engine does not support Sub")
		}
	default:
		switch bt := b.(type) {
		case Tensor:
			if oe = bt.standardEngine(); oe != nil {
				return oe.SubScalar(bt, at, false, opts...)
			}
			if suber, ok = bt.Engine().(Suber); ok {
				return suber.SubScalar(bt, at, false, opts...)
			}
			return nil, errors.New("Operand B's engine does not support Sub")
		default:
			return nil, errors.Errorf("Cannot perform Sub of %T and %T", a, b)
		}
	}
	panic("Unreachable")
}
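
// subOrderSketch is an illustrative sketch of how operand order matters for the
// non-commutative operations: the leftTensor flag passed to SubScalar records which
// side of the expression the tensor sat on. It assumes the New, WithShape and
// WithBacking helpers defined elsewhere in this package.
func subOrderSketch() error {
	a := New(WithShape(2), WithBacking([]float64{1, 2}))

	if _, err := Sub(a, 10.0); err != nil { // a - 10: the tensor is the left operand (leftTensor == true)
		return err
	}
	_, err := Sub(10.0, a) // 10 - a: the tensor is the right operand (leftTensor == false)
	return err
}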

// Mul performs elementwise multiplication on the Tensor(s). These operations are supported:
//	Mul(*Dense, scalar)
//	Mul(scalar, *Dense)
//	Mul(*Dense, *Dense)
// If the Unsafe flag is passed in, the data of the first tensor will be overwritten.
func Mul(a, b interface{}, opts ...FuncOpt) (retVal Tensor, err error) {
	var muler Muler
	var oe standardEngine
	var ok bool
	switch at := a.(type) {
	case Tensor:
		oe = at.standardEngine()
		switch bt := b.(type) {
		case Tensor:
			if !bt.Shape().IsScalar() && !at.Shape().IsScalar() { // non-scalar Tensor multiplication
				if oe != nil {
					return oe.Mul(at, bt, opts...)
				}
				if oe = bt.standardEngine(); oe != nil {
					return oe.Mul(at, bt, opts...)
				}
				if muler, ok = at.Engine().(Muler); ok {
					return muler.Mul(at, bt, opts...)
				}
				if muler, ok = bt.Engine().(Muler); ok {
					return muler.Mul(at, bt, opts...)
				}
				return nil, errors.New("Neither operand's engine supports Mul")

			} else { // at least one of the operands is a scalar
				var leftTensor bool
				if !bt.Shape().IsScalar() {
					leftTensor = false // a Scalar-Tensor * b Tensor
					tmp := at
					at = bt
					bt = tmp
				} else {
					leftTensor = true // a Tensor * b Scalar-Tensor
				}

				if oe != nil {
					return oe.MulScalar(at, bt, leftTensor, opts...)
				}
				if oe = bt.standardEngine(); oe != nil {
					return oe.MulScalar(at, bt, leftTensor, opts...)
				}
				if muler, ok = at.Engine().(Muler); ok {
					return muler.MulScalar(at, bt, leftTensor, opts...)
				}
				if muler, ok = bt.Engine().(Muler); ok {
					return muler.MulScalar(at, bt, leftTensor, opts...)
				}
				return nil, errors.New("Neither operand's engine supports Mul")
			}

		default: // a Tensor * b interface
			if oe != nil {
				return oe.MulScalar(at, bt, true, opts...)
			}
			if muler, ok = at.Engine().(Muler); ok {
				return muler.MulScalar(at, bt, true, opts...)
			}
			return nil, errors.New("Operand A's engine does not support Mul")
		}

	default:
		switch bt := b.(type) {
		case Tensor: // b Tensor * a interface
			if oe = bt.standardEngine(); oe != nil {
				return oe.MulScalar(bt, at, false, opts...)
			}
			if muler, ok = bt.Engine().(Muler); ok {
				return muler.MulScalar(bt, at, false, opts...)
			}
			return nil, errors.New("Operand B's engine does not support Mul")

		default: // b interface * a interface
			return nil, errors.Errorf("Cannot perform Mul of %T and %T", a, b)
		}
	}
	panic("Unreachable")
}

// Div performs elementwise division on the Tensor(s). These operations are supported:
//	Div(*Dense, scalar)
//	Div(scalar, *Dense)
//	Div(*Dense, *Dense)
// If the Unsafe flag is passed in, the data of the first tensor will be overwritten.
func Div(a, b interface{}, opts ...FuncOpt) (retVal Tensor, err error) {
	var diver Diver
	var oe standardEngine
	var ok bool
	switch at := a.(type) {
	case Tensor:
		oe = at.standardEngine()
		switch bt := b.(type) {
		case Tensor:
			if !bt.Shape().IsScalar() && !at.Shape().IsScalar() { // non-scalar Tensor division
				if oe != nil {
					return oe.Div(at, bt, opts...)
				}
				if oe = bt.standardEngine(); oe != nil {
					return oe.Div(at, bt, opts...)
				}
				if diver, ok = at.Engine().(Diver); ok {
					return diver.Div(at, bt, opts...)
				}
				if diver, ok = bt.Engine().(Diver); ok {
					return diver.Div(at, bt, opts...)
				}
				return nil, errors.New("Neither operand's engine supports Div")

			} else { // at least one of the operands is a scalar
				var leftTensor bool
				if !bt.Shape().IsScalar() {
					leftTensor = false // a Scalar-Tensor / b Tensor
					tmp := at
					at = bt
					bt = tmp
				} else {
					leftTensor = true // a Tensor / b Scalar-Tensor
				}

				if oe != nil {
					return oe.DivScalar(at, bt, leftTensor, opts...)
				}
				if oe = bt.standardEngine(); oe != nil {
					return oe.DivScalar(at, bt, leftTensor, opts...)
				}
				if diver, ok = at.Engine().(Diver); ok {
					return diver.DivScalar(at, bt, leftTensor, opts...)
				}
				if diver, ok = bt.Engine().(Diver); ok {
					return diver.DivScalar(at, bt, leftTensor, opts...)
				}
				return nil, errors.New("Neither operand's engine supports Div")
			}

		default:
			if oe != nil {
				return oe.DivScalar(at, bt, true, opts...)
			}
			if diver, ok = at.Engine().(Diver); ok {
				return diver.DivScalar(at, bt, true, opts...)
			}
			return nil, errors.New("Operand A's engine does not support Div")
		}
	default:
		switch bt := b.(type) {
		case Tensor:
			if oe = bt.standardEngine(); oe != nil {
				return oe.DivScalar(bt, at, false, opts...)
			}
			if diver, ok = bt.Engine().(Diver); ok {
				return diver.DivScalar(bt, at, false, opts...)
			}
			return nil, errors.New("Operand B's engine does not support Div")
		default:
			return nil, errors.Errorf("Cannot perform Div of %T and %T", a, b)
		}
	}
	panic("Unreachable")
}

// Pow performs elementwise exponentiation on the Tensor(s). These operations are supported:
//	Pow(*Dense, scalar)
//	Pow(scalar, *Dense)
//	Pow(*Dense, *Dense)
// If the Unsafe flag is passed in, the data of the first tensor will be overwritten.
func Pow(a, b interface{}, opts ...FuncOpt) (retVal Tensor, err error) {
	var power Power
	var oe standardEngine
	var ok bool
	switch at := a.(type) {
	case Tensor:
		oe = at.standardEngine()
		switch bt := b.(type) {
		case Tensor:
			if !bt.Shape().IsScalar() && !at.Shape().IsScalar() { // non-scalar Tensor exponentiation
				if oe != nil {
					return oe.Pow(at, bt, opts...)
				}
				if oe = bt.standardEngine(); oe != nil {
					return oe.Pow(at, bt, opts...)
				}
				if power, ok = at.Engine().(Power); ok {
					return power.Pow(at, bt, opts...)
				}
				if power, ok = bt.Engine().(Power); ok {
					return power.Pow(at, bt, opts...)
				}
				return nil, errors.New("Neither operand's engine supports Pow")

			} else { // at least one of the operands is a scalar
				var leftTensor bool
				if !bt.Shape().IsScalar() {
					leftTensor = false // a Scalar-Tensor ^ b Tensor
					tmp := at
					at = bt
					bt = tmp
				} else {
					leftTensor = true // a Tensor ^ b Scalar-Tensor
				}

				if oe != nil {
					return oe.PowScalar(at, bt, leftTensor, opts...)
				}
				if oe = bt.standardEngine(); oe != nil {
					return oe.PowScalar(at, bt, leftTensor, opts...)
				}
				if power, ok = at.Engine().(Power); ok {
					return power.PowScalar(at, bt, leftTensor, opts...)
				}
				if power, ok = bt.Engine().(Power); ok {
					return power.PowScalar(at, bt, leftTensor, opts...)
				}
				return nil, errors.New("Neither operand's engine supports Pow")
			}

		default:
			if oe != nil {
				return oe.PowScalar(at, bt, true, opts...)
			}
			if power, ok = at.Engine().(Power); ok {
				return power.PowScalar(at, bt, true, opts...)
			}
			return nil, errors.New("Operand A's engine does not support Pow")
		}
	default:
		switch bt := b.(type) {
		case Tensor:
			if oe = bt.standardEngine(); oe != nil {
				return oe.PowScalar(bt, at, false, opts...)
			}
			if power, ok = bt.Engine().(Power); ok {
				return power.PowScalar(bt, at, false, opts...)
			}
			return nil, errors.New("Operand B's engine does not support Pow")
		default:
			return nil, errors.Errorf("Cannot perform Pow of %T and %T", a, b)
		}
	}
	panic("Unreachable")
}

// Mod performs elementwise modulo on the Tensor(s). These operations are supported:
//	Mod(*Dense, scalar)
//	Mod(scalar, *Dense)
//	Mod(*Dense, *Dense)
// If the Unsafe flag is passed in, the data of the first tensor will be overwritten.
func Mod(a, b interface{}, opts ...FuncOpt) (retVal Tensor, err error) {
	var moder Moder
	var oe standardEngine
	var ok bool
	switch at := a.(type) {
	case Tensor:
		oe = at.standardEngine()
		switch bt := b.(type) {
		case Tensor:
			if !bt.Shape().IsScalar() && !at.Shape().IsScalar() { // non-scalar Tensor modulo
				if oe != nil {
					return oe.Mod(at, bt, opts...)
				}
				if oe = bt.standardEngine(); oe != nil {
					return oe.Mod(at, bt, opts...)
				}
				if moder, ok = at.Engine().(Moder); ok {
					return moder.Mod(at, bt, opts...)
				}
				if moder, ok = bt.Engine().(Moder); ok {
					return moder.Mod(at, bt, opts...)
				}
				return nil, errors.New("Neither operand's engine supports Mod")

			} else { // at least one of the operands is a scalar
				var leftTensor bool
				if !bt.Shape().IsScalar() {
					leftTensor = false // a Scalar-Tensor % b Tensor
					tmp := at
					at = bt
					bt = tmp
				} else {
					leftTensor = true // a Tensor % b Scalar-Tensor
				}

				if oe != nil {
					return oe.ModScalar(at, bt, leftTensor, opts...)
				}
				if oe = bt.standardEngine(); oe != nil {
					return oe.ModScalar(at, bt, leftTensor, opts...)
				}
				if moder, ok = at.Engine().(Moder); ok {
					return moder.ModScalar(at, bt, leftTensor, opts...)
				}
				if moder, ok = bt.Engine().(Moder); ok {
					return moder.ModScalar(at, bt, leftTensor, opts...)
				}
				return nil, errors.New("Neither operand's engine supports Mod")
			}

		default:
			if oe != nil {
				return oe.ModScalar(at, bt, true, opts...)
			}
			if moder, ok = at.Engine().(Moder); ok {
				return moder.ModScalar(at, bt, true, opts...)
			}
			return nil, errors.New("Operand A's engine does not support Mod")
		}
	default:
		switch bt := b.(type) {
		case Tensor:
			if oe = bt.standardEngine(); oe != nil {
				return oe.ModScalar(bt, at, false, opts...)
			}
			if moder, ok = bt.Engine().(Moder); ok {
				return moder.ModScalar(bt, at, false, opts...)
			}
			return nil, errors.New("Operand B's engine does not support Mod")
		default:
			return nil, errors.Errorf("Cannot perform Mod of %T and %T", a, b)
		}
	}
	panic("Unreachable")
}

// Dot is a highly opinionated API for performing dot product operations on two *Denses, a and b.
// This function is opinionated with regard to the vector operations because of how it treats operations with vectors.
// Vectors in this package come in two flavours - column or row vectors. Column vectors have shape (x, 1), while row vectors have shape (1, x).
//
// As such, it is easy to assume that performing a linalg operation on vectors would follow the same rules (i.e. shapes have to be aligned for things to work).
// For the most part in this package, this is true. This function is one of the few notable exceptions.
//
// Here I give three specific examples of how the expectations of vector operations will differ:
//	Given two vectors, a and b, with shapes (4, 1) and (4, 1), Dot() will perform an inner product as if the shapes were (1, 4) and (4, 1). This will result in a scalar value.
//	Given matrix A and vector b with shapes (2, 4) and (1, 4), Dot() will perform a matrix-vector multiplication as if the shapes were (2, 4) and (4, 1). This will result in a column vector with shape (2, 1).
//	Given vector a and matrix B with shapes (3, 1) and (3, 2), Dot() will perform a matrix-vector multiplication as if it were Bᵀ * a.
//
// The main reason this opinionated route was taken is the author's familiarity with NumPy, and general laziness in translating existing machine learning algorithms
// to fit the API of the package.
func Dot(x, y Tensor, opts ...FuncOpt) (retVal Tensor, err error) {
	if xdottir, ok := x.Engine().(Dotter); ok {
		return xdottir.Dot(x, y, opts...)
	}
	if ydottir, ok := y.Engine().(Dotter); ok {
		return ydottir.Dot(x, y, opts...)
	}
	return nil, errors.New("Neither x's nor y's engines support Dot")
}
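
// dotShapesSketch is an illustrative sketch of the opinionated vector handling
// documented on Dot above. It assumes the New, WithShape and WithBacking helpers
// defined elsewhere in this package.
func dotShapesSketch() error {
	a := New(WithShape(4, 1), WithBacking([]float64{1, 2, 3, 4})) // column vector
	b := New(WithShape(4, 1), WithBacking([]float64{5, 6, 7, 8})) // column vector
	A := New(WithShape(2, 4), WithBacking([]float64{1, 2, 3, 4, 5, 6, 7, 8}))
	r := New(WithShape(1, 4), WithBacking([]float64{1, 2, 3, 4})) // row vector

	if _, err := Dot(a, b); err != nil { // inner product, as if the shapes were (1,4) and (4,1): a scalar result
		return err
	}
	_, err := Dot(A, r) // matrix-vector multiplication, as if the shapes were (2,4) and (4,1): a (2,1) column vector
	return err
}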

// FMA performs Y = A * X + Y.
func FMA(a Tensor, x interface{}, y Tensor) (retVal Tensor, err error) {
	if xTensor, ok := x.(Tensor); ok {
		if oe := a.standardEngine(); oe != nil {
			return oe.FMA(a, xTensor, y)
		}
		if oe := xTensor.standardEngine(); oe != nil {
			return oe.FMA(a, xTensor, y)
		}
		if oe := y.standardEngine(); oe != nil {
			return oe.FMA(a, xTensor, y)
		}

		if e, ok := a.Engine().(FMAer); ok {
			return e.FMA(a, xTensor, y)
		}
		if e, ok := xTensor.Engine().(FMAer); ok {
			return e.FMA(a, xTensor, y)
		}
		if e, ok := y.Engine().(FMAer); ok {
			return e.FMA(a, xTensor, y)
		}
	} else {
		if oe := a.standardEngine(); oe != nil {
			return oe.FMAScalar(a, x, y)
		}
		if oe := y.standardEngine(); oe != nil {
			return oe.FMAScalar(a, x, y)
		}

		if e, ok := a.Engine().(FMAer); ok {
			return e.FMAScalar(a, x, y)
		}
		if e, ok := y.Engine().(FMAer); ok {
			return e.FMAScalar(a, x, y)
		}
	}
	return Mul(a, x, WithIncr(y))
}
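
// fmaSketch is an illustrative sketch of FMA: y is accumulated into with a*x + y,
// which is what the Mul(a, x, WithIncr(y)) fallback above computes when no engine
// provides a fused implementation. It assumes the New, WithShape and WithBacking
// helpers defined elsewhere in this package.
func fmaSketch() error {
	a := New(WithShape(2), WithBacking([]float64{1, 2}))
	x := New(WithShape(2), WithBacking([]float64{3, 4}))
	y := New(WithShape(2), WithBacking([]float64{10, 10}))

	_, err := FMA(a, x, y) // y now holds [13, 18]
	return err
}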

// MatMul performs matrix-matrix multiplication between two Tensors.
func MatMul(a, b Tensor, opts ...FuncOpt) (retVal Tensor, err error) {
	if a.Dtype() != b.Dtype() {
		err = errors.Errorf(dtypeMismatch, a.Dtype(), b.Dtype())
		return
	}

	switch at := a.(type) {
	case *Dense:
		bt := b.(*Dense)
		return at.MatMul(bt, opts...)
	}
	panic("Unreachable")
}

// MatVecMul performs matrix-vector multiplication between two Tensors. `a` is expected to be a matrix, and `b` is expected to be a vector.
func MatVecMul(a, b Tensor, opts ...FuncOpt) (retVal Tensor, err error) {
	if a.Dtype() != b.Dtype() {
		err = errors.Errorf(dtypeMismatch, a.Dtype(), b.Dtype())
		return
	}

	switch at := a.(type) {
	case *Dense:
		bt := b.(*Dense)
		return at.MatVecMul(bt, opts...)
	}
	panic("Unreachable")
}

// Inner finds the inner product of two vector Tensors. Both arguments to the function are expected to be vectors.
func Inner(a, b Tensor) (retVal interface{}, err error) {
	if a.Dtype() != b.Dtype() {
		err = errors.Errorf(dtypeMismatch, a.Dtype(), b.Dtype())
		return
	}

	switch at := a.(type) {
	case *Dense:
		bt := b.(*Dense)
		return at.Inner(bt)
	}
	panic("Unreachable")
}

// Outer performs the outer product of two vector Tensors. Both arguments to the function are expected to be vectors.
func Outer(a, b Tensor, opts ...FuncOpt) (retVal Tensor, err error) {
	if a.Dtype() != b.Dtype() {
		err = errors.Errorf(dtypeMismatch, a.Dtype(), b.Dtype())
		return
	}

	switch at := a.(type) {
	case *Dense:
		bt := b.(*Dense)
		return at.Outer(bt, opts...)
	}
	panic("Unreachable")
}

// Contract performs a contraction of the given tensors along the given axes.
func Contract(a, b Tensor, aAxes, bAxes []int) (retVal Tensor, err error) {
	if a.Dtype() != b.Dtype() {
		err = errors.Errorf(dtypeMismatch, a.Dtype(), b.Dtype())
		return
	}

	switch at := a.(type) {
	case *Dense:
		bt := b.(*Dense)
		return at.TensorMul(bt, aAxes, bAxes)

	default:
		panic("Unreachable")
	}
}
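
// contractSketch is an illustrative sketch of Contract: contracting a (2,3) tensor
// with a (3,4) tensor along a's axis 1 and b's axis 0 yields a (2,4) result, i.e. an
// ordinary matrix product expressed as a tensor contraction. It assumes the New,
// WithShape and WithBacking helpers defined elsewhere in this package.
func contractSketch() error {
	a := New(WithShape(2, 3), WithBacking([]float64{1, 2, 3, 4, 5, 6}))
	b := New(WithShape(3, 4), WithBacking([]float64{
		1, 2, 3, 4,
		5, 6, 7, 8,
		9, 10, 11, 12,
	}))

	_, err := Contract(a, b, []int{1}, []int{0}) // result has shape (2, 4)
	return err
}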