github.com/wzzhu/tensor@v0.9.24/defaultengine_linalg.go (about) 1 package tensor 2 3 import ( 4 "reflect" 5 6 "github.com/pkg/errors" 7 "gonum.org/v1/gonum/blas" 8 "gonum.org/v1/gonum/mat" 9 ) 10 11 // Trace returns the trace of a matrix (i.e. the sum of the diagonal elements). If the Tensor provided is not a matrix, it will return an error 12 func (e StdEng) Trace(t Tensor) (retVal interface{}, err error) { 13 if t.Dims() != 2 { 14 err = errors.Errorf(dimMismatch, 2, t.Dims()) 15 return 16 } 17 18 if err = typeclassCheck(t.Dtype(), numberTypes); err != nil { 19 return nil, errors.Wrap(err, "Trace") 20 } 21 22 rstride := t.Strides()[0] 23 cstride := t.Strides()[1] 24 25 r := t.Shape()[0] 26 c := t.Shape()[1] 27 28 m := MinInt(r, c) 29 stride := rstride + cstride 30 31 switch data := t.Data().(type) { 32 case []int: 33 var trace int 34 for i := 0; i < m; i++ { 35 trace += data[i*stride] 36 } 37 retVal = trace 38 case []int8: 39 var trace int8 40 for i := 0; i < m; i++ { 41 trace += data[i*stride] 42 } 43 retVal = trace 44 case []int16: 45 var trace int16 46 for i := 0; i < m; i++ { 47 trace += data[i*stride] 48 } 49 retVal = trace 50 case []int32: 51 var trace int32 52 for i := 0; i < m; i++ { 53 trace += data[i*stride] 54 } 55 retVal = trace 56 case []int64: 57 var trace int64 58 for i := 0; i < m; i++ { 59 trace += data[i*stride] 60 } 61 retVal = trace 62 case []uint: 63 var trace uint 64 for i := 0; i < m; i++ { 65 trace += data[i*stride] 66 } 67 retVal = trace 68 case []uint8: 69 var trace uint8 70 for i := 0; i < m; i++ { 71 trace += data[i*stride] 72 } 73 retVal = trace 74 case []uint16: 75 var trace uint16 76 for i := 0; i < m; i++ { 77 trace += data[i*stride] 78 } 79 retVal = trace 80 case []uint32: 81 var trace uint32 82 for i := 0; i < m; i++ { 83 trace += data[i*stride] 84 } 85 retVal = trace 86 case []uint64: 87 var trace uint64 88 for i := 0; i < m; i++ { 89 trace += data[i*stride] 90 } 91 retVal = trace 92 case []float32: 93 var trace float32 94 for i := 0; i < m; i++ { 95 trace += data[i*stride] 96 } 97 retVal = trace 98 case []float64: 99 var trace float64 100 for i := 0; i < m; i++ { 101 trace += data[i*stride] 102 } 103 retVal = trace 104 case []complex64: 105 var trace complex64 106 for i := 0; i < m; i++ { 107 trace += data[i*stride] 108 } 109 retVal = trace 110 case []complex128: 111 var trace complex128 112 for i := 0; i < m; i++ { 113 trace += data[i*stride] 114 } 115 retVal = trace 116 } 117 return 118 } 119 120 func (e StdEng) Dot(x, y Tensor, opts ...FuncOpt) (retVal Tensor, err error) { 121 if _, ok := x.(DenseTensor); !ok { 122 err = errors.Errorf("Engine only supports working on x that is a DenseTensor. Got %T instead", x) 123 return 124 } 125 126 if _, ok := y.(DenseTensor); !ok { 127 err = errors.Errorf("Engine only supports working on y that is a DenseTensor. Got %T instead", y) 128 return 129 } 130 131 var a, b DenseTensor 132 if a, err = getFloatDenseTensor(x); err != nil { 133 err = errors.Wrapf(err, opFail, "Dot") 134 return 135 } 136 if b, err = getFloatDenseTensor(y); err != nil { 137 err = errors.Wrapf(err, opFail, "Dot") 138 return 139 } 140 141 fo := ParseFuncOpts(opts...) 142 143 var reuse, incr DenseTensor 144 if reuse, err = getFloatDenseTensor(fo.reuse); err != nil { 145 err = errors.Wrapf(err, opFail, "Dot - reuse") 146 return 147 148 } 149 150 if incr, err = getFloatDenseTensor(fo.incr); err != nil { 151 err = errors.Wrapf(err, opFail, "Dot - incr") 152 return 153 } 154 155 switch { 156 case a.IsScalar() && b.IsScalar(): 157 var res interface{} 158 switch a.Dtype().Kind() { 159 case reflect.Float64: 160 res = a.GetF64(0) * b.GetF64(0) 161 case reflect.Float32: 162 res = a.GetF32(0) * b.GetF32(0) 163 } 164 165 switch { 166 case incr != nil: 167 if !incr.IsScalar() { 168 err = errors.Errorf(shapeMismatch, ScalarShape(), incr.Shape()) 169 return 170 } 171 if err = e.E.MulIncr(a.Dtype().Type, a.hdr(), b.hdr(), incr.hdr()); err != nil { 172 err = errors.Wrapf(err, opFail, "Dot scalar incr") 173 return 174 175 } 176 retVal = incr 177 case reuse != nil: 178 reuse.Set(0, res) 179 reuse.reshape() 180 retVal = reuse 181 default: 182 retVal = New(FromScalar(res)) 183 } 184 return 185 case a.IsScalar(): 186 switch { 187 case incr != nil: 188 return Mul(a.ScalarValue(), b, WithIncr(incr)) 189 case reuse != nil: 190 return Mul(a.ScalarValue(), b, WithReuse(reuse)) 191 } 192 // default moved out 193 return Mul(a.ScalarValue(), b) 194 case b.IsScalar(): 195 switch { 196 case incr != nil: 197 return Mul(a, b.ScalarValue(), WithIncr(incr)) 198 case reuse != nil: 199 return Mul(a, b.ScalarValue(), WithReuse(reuse)) 200 } 201 return Mul(a, b.ScalarValue()) 202 } 203 204 switch { 205 case a.IsVector(): 206 switch { 207 case b.IsVector(): 208 // check size 209 if a.len() != b.len() { 210 err = errors.Errorf(shapeMismatch, a.Shape(), b.Shape()) 211 return 212 } 213 var ret interface{} 214 if ret, err = e.Inner(a, b); err != nil { 215 return nil, errors.Wrapf(err, opFail, "Dot") 216 } 217 return New(FromScalar(ret)), nil 218 case b.IsMatrix(): 219 b.T() 220 defer b.UT() 221 switch { 222 case reuse != nil && incr != nil: 223 return b.MatVecMul(a, WithReuse(reuse), WithIncr(incr)) 224 case reuse != nil: 225 return b.MatVecMul(a, WithReuse(reuse)) 226 case incr != nil: 227 return b.MatVecMul(a, WithIncr(incr)) 228 default: 229 } 230 return b.MatVecMul(a) 231 default: 232 233 } 234 case a.IsMatrix(): 235 switch { 236 case b.IsVector(): 237 switch { 238 case reuse != nil && incr != nil: 239 return a.MatVecMul(b, WithReuse(reuse), WithIncr(incr)) 240 case reuse != nil: 241 return a.MatVecMul(b, WithReuse(reuse)) 242 case incr != nil: 243 return a.MatVecMul(b, WithIncr(incr)) 244 default: 245 } 246 return a.MatVecMul(b) 247 248 case b.IsMatrix(): 249 switch { 250 case reuse != nil && incr != nil: 251 return a.MatMul(b, WithReuse(reuse), WithIncr(incr)) 252 case reuse != nil: 253 return a.MatMul(b, WithReuse(reuse)) 254 case incr != nil: 255 return a.MatMul(b, WithIncr(incr)) 256 default: 257 } 258 return a.MatMul(b) 259 default: 260 } 261 default: 262 } 263 264 as := a.Shape() 265 bs := b.Shape() 266 axesA := BorrowInts(1) 267 axesB := BorrowInts(1) 268 defer ReturnInts(axesA) 269 defer ReturnInts(axesB) 270 271 var lastA, secondLastB int 272 273 lastA = len(as) - 1 274 axesA[0] = lastA 275 if len(bs) >= 2 { 276 secondLastB = len(bs) - 2 277 } else { 278 secondLastB = 0 279 } 280 axesB[0] = secondLastB 281 282 if as[lastA] != bs[secondLastB] { 283 err = errors.Errorf(shapeMismatch, as, bs) 284 return 285 } 286 287 var rd *Dense 288 if rd, err = a.TensorMul(b, axesA, axesB); err != nil { 289 panic(err) 290 } 291 292 if reuse != nil { 293 copyDense(reuse, rd) 294 ap := rd.Info().Clone() 295 reuse.setAP(&ap) 296 defer ReturnTensor(rd) 297 // swap out the underlying data and metadata 298 // reuse.data, rd.data = rd.data, reuse.data 299 // reuse.AP, rd.AP = rd.AP, reuse.AP 300 // defer ReturnTensor(rd) 301 302 retVal = reuse 303 } else { 304 retVal = rd 305 } 306 307 return 308 } 309 310 // TODO: make it take DenseTensor 311 func (e StdEng) SVD(a Tensor, uv, full bool) (s, u, v Tensor, err error) { 312 var t *Dense 313 var ok bool 314 if err = e.checkAccessible(a); err != nil { 315 return nil, nil, nil, errors.Wrapf(err, "opFail %v", "SVD") 316 } 317 if t, ok = a.(*Dense); !ok { 318 return nil, nil, nil, errors.Errorf("StdEng only performs SVDs for DenseTensors. Got %T instead", a) 319 } 320 if err = typeclassCheck(a.Dtype(), floatTypes); err != nil { 321 return nil, nil, nil, errors.Errorf("StdEng can only perform SVDs for float64 and float32 type. Got tensor of %v instead", t.Dtype()) 322 } 323 324 if !t.IsMatrix() { 325 return nil, nil, nil, errors.Errorf(dimMismatch, 2, t.Dims()) 326 } 327 328 var m *mat.Dense 329 var svd mat.SVD 330 331 if m, err = ToMat64(t, UseUnsafe()); err != nil { 332 return 333 } 334 335 switch { 336 case full && uv: 337 ok = svd.Factorize(m, mat.SVDFull) 338 case !full && uv: 339 ok = svd.Factorize(m, mat.SVDThin) 340 case full && !uv: 341 // illogical state - if you specify "full", you WANT the UV matrices 342 // error 343 err = errors.Errorf("SVD requires computation of `u` and `v` matrices if `full` was specified.") 344 return 345 default: 346 // by default, we return only the singular values 347 ok = svd.Factorize(m, mat.SVDNone) 348 } 349 350 if !ok { 351 // error 352 err = errors.Errorf("Unable to compute SVD") 353 return 354 } 355 356 // extract values 357 var um, vm mat.Dense 358 s = recycledDense(Float64, Shape{MinInt(t.Shape()[0], t.Shape()[1])}, WithEngine(e)) 359 svd.Values(s.Data().([]float64)) 360 if uv { 361 svd.UTo(&um) 362 svd.VTo(&vm) 363 // vm.VFromSVD(&svd) 364 365 u = FromMat64(&um, UseUnsafe(), As(t.t)) 366 v = FromMat64(&vm, UseUnsafe(), As(t.t)) 367 } 368 369 return 370 } 371 372 // Inner is a thin layer over BLAS's D/Sdot. 373 // It returns a scalar value, wrapped in an interface{}, which is not quite nice. 374 func (e StdEng) Inner(a, b Tensor) (retVal interface{}, err error) { 375 var ad, bd DenseTensor 376 if ad, bd, err = e.checkTwoFloatComplexTensors(a, b); err != nil { 377 return nil, errors.Wrapf(err, opFail, "StdEng.Inner") 378 } 379 380 switch A := ad.Data().(type) { 381 case []float32: 382 B := bd.Float32s() 383 retVal = whichblas.Sdot(len(A), A, 1, B, 1) 384 case []float64: 385 B := bd.Float64s() 386 retVal = whichblas.Ddot(len(A), A, 1, B, 1) 387 case []complex64: 388 B := bd.Complex64s() 389 retVal = whichblas.Cdotu(len(A), A, 1, B, 1) 390 case []complex128: 391 B := bd.Complex128s() 392 retVal = whichblas.Zdotu(len(A), A, 1, B, 1) 393 } 394 return 395 } 396 397 // MatVecMul is a thin layer over BLAS' DGEMV 398 // Because DGEMV computes: 399 // y = αA * x + βy 400 // we set beta to 0, so we don't have to manually zero out the reused/retval tensor data 401 func (e StdEng) MatVecMul(a, b, prealloc Tensor) (err error) { 402 // check all are DenseTensors 403 var ad, bd, pd DenseTensor 404 if ad, bd, pd, err = e.checkThreeFloatComplexTensors(a, b, prealloc); err != nil { 405 return errors.Wrapf(err, opFail, "StdEng.MatVecMul") 406 } 407 408 m := ad.oshape()[0] 409 n := ad.oshape()[1] 410 411 tA := blas.NoTrans 412 do := a.DataOrder() 413 z := ad.oldAP().IsZero() 414 415 var lda int 416 switch { 417 case do.IsRowMajor() && z: 418 lda = n 419 case do.IsRowMajor() && !z: 420 tA = blas.Trans 421 lda = n 422 case do.IsColMajor() && z: 423 tA = blas.Trans 424 lda = m 425 m, n = n, m 426 case do.IsColMajor() && !z: 427 lda = m 428 m, n = n, m 429 } 430 431 incX, incY := 1, 1 // step size 432 433 // ASPIRATIONAL TODO: different incX and incY 434 // TECHNICAL DEBT. TECHDEBT. TECH DEBT 435 // Example use case: 436 // log.Printf("a %v %v", ad.Strides(), ad.ostrides()) 437 // log.Printf("b %v", b.Strides()) 438 // incX := a.Strides()[0] 439 // incY = b.Strides()[0] 440 441 switch A := ad.Data().(type) { 442 case []float64: 443 x := bd.Float64s() 444 y := pd.Float64s() 445 alpha, beta := float64(1), float64(0) 446 whichblas.Dgemv(tA, m, n, alpha, A, lda, x, incX, beta, y, incY) 447 case []float32: 448 x := bd.Float32s() 449 y := pd.Float32s() 450 alpha, beta := float32(1), float32(0) 451 whichblas.Sgemv(tA, m, n, alpha, A, lda, x, incX, beta, y, incY) 452 case []complex64: 453 x := bd.Complex64s() 454 y := pd.Complex64s() 455 var alpha, beta complex64 = complex(1, 0), complex(0, 0) 456 whichblas.Cgemv(tA, m, n, alpha, A, lda, x, incX, beta, y, incY) 457 case []complex128: 458 x := bd.Complex128s() 459 y := pd.Complex128s() 460 var alpha, beta complex128 = complex(1, 0), complex(0, 0) 461 whichblas.Zgemv(tA, m, n, alpha, A, lda, x, incX, beta, y, incY) 462 default: 463 return errors.Errorf(typeNYI, "matVecMul", bd.Data()) 464 } 465 466 return nil 467 } 468 469 // MatMul is a thin layer over DGEMM. 470 // DGEMM computes: 471 // C = αA * B + βC 472 // To prevent needless zeroing out of the slice, we just set β to 0 473 func (e StdEng) MatMul(a, b, prealloc Tensor) (err error) { 474 // check all are DenseTensors 475 var ad, bd, pd DenseTensor 476 if ad, bd, pd, err = e.checkThreeFloatComplexTensors(a, b, prealloc); err != nil { 477 return errors.Wrapf(err, opFail, "StdEng.MatMul") 478 } 479 480 ado := a.DataOrder() 481 bdo := b.DataOrder() 482 cdo := prealloc.DataOrder() 483 484 // get result shapes. k is the shared dimension 485 // a is (m, k) 486 // b is (k, n) 487 // c is (m, n) 488 var m, n, k int 489 m = ad.Shape()[0] 490 k = ad.Shape()[1] 491 n = bd.Shape()[1] 492 493 // wrt the strides, we use the original strides, because that's what BLAS needs, instead of calling .Strides() 494 // lda in colmajor = number of rows; 495 // lda in row major = number of cols 496 var lda, ldb, ldc int 497 switch { 498 case ado.IsColMajor(): 499 lda = m 500 case ado.IsRowMajor(): 501 lda = k 502 } 503 504 switch { 505 case bdo.IsColMajor(): 506 ldb = bd.Shape()[0] 507 case bdo.IsRowMajor(): 508 ldb = n 509 } 510 511 switch { 512 case cdo.IsColMajor(): 513 ldc = prealloc.Shape()[0] 514 case cdo.IsRowMajor(): 515 ldc = prealloc.Shape()[1] 516 } 517 518 // check for trans 519 tA, tB := blas.NoTrans, blas.NoTrans 520 if !ad.oldAP().IsZero() { 521 tA = blas.Trans 522 if ado.IsRowMajor() { 523 lda = m 524 } else { 525 lda = k 526 } 527 } 528 if !bd.oldAP().IsZero() { 529 tB = blas.Trans 530 if bdo.IsRowMajor() { 531 ldb = bd.Shape()[0] 532 } else { 533 ldb = bd.Shape()[1] 534 } 535 } 536 537 switch A := ad.Data().(type) { 538 case []float64: 539 B := bd.Float64s() 540 C := pd.Float64s() 541 alpha, beta := float64(1), float64(0) 542 if ado.IsColMajor() && bdo.IsColMajor() { 543 whichblas.Dgemm(tA, tB, n, m, k, alpha, B, ldb, A, lda, beta, C, ldc) 544 } else { 545 whichblas.Dgemm(tA, tB, m, n, k, alpha, A, lda, B, ldb, beta, C, ldc) 546 } 547 case []float32: 548 B := bd.Float32s() 549 C := pd.Float32s() 550 alpha, beta := float32(1), float32(0) 551 if ado.IsColMajor() && bdo.IsColMajor() { 552 whichblas.Sgemm(tA, tB, n, m, k, alpha, B, ldb, A, lda, beta, C, ldc) 553 } else { 554 whichblas.Sgemm(tA, tB, m, n, k, alpha, A, lda, B, ldb, beta, C, ldc) 555 } 556 case []complex64: 557 B := bd.Complex64s() 558 C := pd.Complex64s() 559 var alpha, beta complex64 = complex(1, 0), complex(0, 0) 560 if ado.IsColMajor() && bdo.IsColMajor() { 561 whichblas.Cgemm(tA, tB, n, m, k, alpha, B, ldb, A, lda, beta, C, ldc) 562 } else { 563 whichblas.Cgemm(tA, tB, m, n, k, alpha, A, lda, B, ldb, beta, C, ldc) 564 } 565 case []complex128: 566 B := bd.Complex128s() 567 C := pd.Complex128s() 568 var alpha, beta complex128 = complex(1, 0), complex(0, 0) 569 if ado.IsColMajor() && bdo.IsColMajor() { 570 whichblas.Zgemm(tA, tB, n, m, k, alpha, B, ldb, A, lda, beta, C, ldc) 571 } else { 572 whichblas.Zgemm(tA, tB, m, n, k, alpha, A, lda, B, ldb, beta, C, ldc) 573 } 574 default: 575 return errors.Errorf(typeNYI, "matMul", ad.Data()) 576 } 577 return 578 } 579 580 // Outer is a thin wrapper over S/Dger 581 func (e StdEng) Outer(a, b, prealloc Tensor) (err error) { 582 // check all are DenseTensors 583 var ad, bd, pd DenseTensor 584 if ad, bd, pd, err = e.checkThreeFloatComplexTensors(a, b, prealloc); err != nil { 585 return errors.Wrapf(err, opFail, "StdEng.Outer") 586 } 587 588 m := ad.Size() 589 n := bd.Size() 590 pdo := pd.DataOrder() 591 592 // the stride of a Vector is always going to be [1], 593 // incX := t.Strides()[0] 594 // incY := other.Strides()[0] 595 incX, incY := 1, 1 596 // lda := pd.Strides()[0] 597 var lda int 598 switch { 599 case pdo.IsColMajor(): 600 aShape := a.Shape().Clone() 601 bShape := b.Shape().Clone() 602 if err = a.Reshape(aShape[0], 1); err != nil { 603 return err 604 } 605 if err = b.Reshape(1, bShape[0]); err != nil { 606 return err 607 } 608 609 if err = e.MatMul(a, b, prealloc); err != nil { 610 return err 611 } 612 613 if err = b.Reshape(bShape...); err != nil { 614 return 615 } 616 if err = a.Reshape(aShape...); err != nil { 617 return 618 } 619 return nil 620 621 case pdo.IsRowMajor(): 622 lda = pd.Shape()[1] 623 } 624 625 switch x := ad.Data().(type) { 626 case []float64: 627 y := bd.Float64s() 628 A := pd.Float64s() 629 alpha := float64(1) 630 whichblas.Dger(m, n, alpha, x, incX, y, incY, A, lda) 631 case []float32: 632 y := bd.Float32s() 633 A := pd.Float32s() 634 alpha := float32(1) 635 whichblas.Sger(m, n, alpha, x, incX, y, incY, A, lda) 636 case []complex64: 637 y := bd.Complex64s() 638 A := pd.Complex64s() 639 var alpha complex64 = complex(1, 0) 640 whichblas.Cgeru(m, n, alpha, x, incX, y, incY, A, lda) 641 case []complex128: 642 y := bd.Complex128s() 643 A := pd.Complex128s() 644 var alpha complex128 = complex(1, 0) 645 whichblas.Zgeru(m, n, alpha, x, incX, y, incY, A, lda) 646 default: 647 return errors.Errorf(typeNYI, "outer", b.Data()) 648 } 649 return nil 650 } 651 652 /* UNEXPORTED UTILITY FUNCTIONS */ 653 654 func (e StdEng) checkTwoFloatTensors(a, b Tensor) (ad, bd DenseTensor, err error) { 655 if err = e.checkAccessible(a); err != nil { 656 return nil, nil, errors.Wrap(err, "checkTwoTensors: a is not accessible") 657 } 658 if err = e.checkAccessible(b); err != nil { 659 return nil, nil, errors.Wrap(err, "checkTwoTensors: a is not accessible") 660 } 661 662 if a.Dtype() != b.Dtype() { 663 return nil, nil, errors.New("Expected a and b to have the same Dtype") 664 } 665 666 if ad, err = getFloatDenseTensor(a); err != nil { 667 return nil, nil, errors.Wrap(err, "checkTwoTensors expects a to be be a DenseTensor") 668 } 669 if bd, err = getFloatDenseTensor(b); err != nil { 670 return nil, nil, errors.Wrap(err, "checkTwoTensors expects b to be be a DenseTensor") 671 } 672 return 673 } 674 675 func (e StdEng) checkThreeFloatTensors(a, b, ret Tensor) (ad, bd, retVal DenseTensor, err error) { 676 if err = e.checkAccessible(a); err != nil { 677 return nil, nil, nil, errors.Wrap(err, "checkThreeTensors: a is not accessible") 678 } 679 if err = e.checkAccessible(b); err != nil { 680 return nil, nil, nil, errors.Wrap(err, "checkThreeTensors: a is not accessible") 681 } 682 if err = e.checkAccessible(ret); err != nil { 683 return nil, nil, nil, errors.Wrap(err, "checkThreeTensors: ret is not accessible") 684 } 685 686 if a.Dtype() != b.Dtype() || b.Dtype() != ret.Dtype() { 687 return nil, nil, nil, errors.New("Expected a and b and retVal all to have the same Dtype") 688 } 689 690 if ad, err = getFloatDenseTensor(a); err != nil { 691 return nil, nil, nil, errors.Wrap(err, "checkTwoTensors expects a to be be a DenseTensor") 692 } 693 if bd, err = getFloatDenseTensor(b); err != nil { 694 return nil, nil, nil, errors.Wrap(err, "checkTwoTensors expects b to be be a DenseTensor") 695 } 696 if retVal, err = getFloatDenseTensor(ret); err != nil { 697 return nil, nil, nil, errors.Wrap(err, "checkTwoTensors expects retVal to be be a DenseTensor") 698 } 699 return 700 } 701 702 func (e StdEng) checkTwoFloatComplexTensors(a, b Tensor) (ad, bd DenseTensor, err error) { 703 if err = e.checkAccessible(a); err != nil { 704 return nil, nil, errors.Wrap(err, "checkTwoTensors: a is not accessible") 705 } 706 if err = e.checkAccessible(b); err != nil { 707 return nil, nil, errors.Wrap(err, "checkTwoTensors: a is not accessible") 708 } 709 710 if a.Dtype() != b.Dtype() { 711 return nil, nil, errors.New("Expected a and b to have the same Dtype") 712 } 713 714 if ad, err = getFloatComplexDenseTensor(a); err != nil { 715 return nil, nil, errors.Wrap(err, "checkTwoTensors expects a to be be a DenseTensor") 716 } 717 if bd, err = getFloatComplexDenseTensor(b); err != nil { 718 return nil, nil, errors.Wrap(err, "checkTwoTensors expects b to be be a DenseTensor") 719 } 720 return 721 } 722 723 func (e StdEng) checkThreeFloatComplexTensors(a, b, ret Tensor) (ad, bd, retVal DenseTensor, err error) { 724 if err = e.checkAccessible(a); err != nil { 725 return nil, nil, nil, errors.Wrap(err, "checkThreeTensors: a is not accessible") 726 } 727 if err = e.checkAccessible(b); err != nil { 728 return nil, nil, nil, errors.Wrap(err, "checkThreeTensors: a is not accessible") 729 } 730 if err = e.checkAccessible(ret); err != nil { 731 return nil, nil, nil, errors.Wrap(err, "checkThreeTensors: ret is not accessible") 732 } 733 734 if a.Dtype() != b.Dtype() || b.Dtype() != ret.Dtype() { 735 return nil, nil, nil, errors.New("Expected a and b and retVal all to have the same Dtype") 736 } 737 738 if ad, err = getFloatComplexDenseTensor(a); err != nil { 739 return nil, nil, nil, errors.Wrap(err, "checkTwoTensors expects a to be be a DenseTensor") 740 } 741 if bd, err = getFloatComplexDenseTensor(b); err != nil { 742 return nil, nil, nil, errors.Wrap(err, "checkTwoTensors expects b to be be a DenseTensor") 743 } 744 if retVal, err = getFloatComplexDenseTensor(ret); err != nil { 745 return nil, nil, nil, errors.Wrap(err, "checkTwoTensors expects retVal to be be a DenseTensor") 746 } 747 return 748 }