github.com/wzzhu/tensor@v0.9.24/dense_linalg.go (about) 1 package tensor 2 3 import ( 4 "github.com/pkg/errors" 5 ) 6 7 // Trace returns the trace of the matrix (i.e. the sum of the diagonal elements). It only works for matrices 8 func (t *Dense) Trace() (retVal interface{}, err error) { 9 e := t.e 10 11 if tracer, ok := e.(Tracer); ok { 12 return tracer.Trace(t) 13 } 14 return nil, errors.Errorf("Engine %T does not support Trace", e) 15 } 16 17 // Inner performs a dot product on two vectors. If t or other are not vectors, it will return an error. 18 func (t *Dense) Inner(other Tensor) (retVal interface{}, err error) { 19 // check that the data is a float 20 if err = typeclassCheck(t.t, floatcmplxTypes); err != nil { 21 return nil, errors.Wrapf(err, unsupportedDtype, t.t, "Inner") 22 } 23 24 // check both are vectors 25 if !t.Shape().IsVector() || !other.Shape().IsVector() { 26 return nil, errors.Errorf("Inner only works when there are two vectors. t's Shape: %v; other's Shape %v", t.Shape(), other.Shape()) 27 } 28 29 // we do this check instead of the more common t.Shape()[1] != other.Shape()[0], 30 // basically to ensure a similarity with numpy's dot and vectors. 31 if t.len() != other.DataSize() { 32 return nil, errors.Errorf(shapeMismatch, t.Shape(), other.Shape()) 33 } 34 35 e := t.e 36 switch ip := e.(type) { 37 case InnerProderF32: 38 return ip.Inner(t, other) 39 case InnerProderF64: 40 return ip.Inner(t, other) 41 case InnerProder: 42 return ip.Inner(t, other) 43 } 44 45 return nil, errors.Errorf("Engine does not support Inner()") 46 } 47 48 // MatVecMul performs a matrix-vector multiplication. 49 func (t *Dense) MatVecMul(other Tensor, opts ...FuncOpt) (retVal *Dense, err error) { 50 // check that it's a matrix x vector 51 if t.Dims() != 2 || !other.Shape().IsVector() { 52 err = errors.Errorf("MatVecMul requires t be a matrix and other to be a vector. Got t's shape: %v, other's shape: %v", t.Shape(), other.Shape()) 53 return 54 } 55 56 // checks that t is mxn matrix 57 m := t.Shape()[0] 58 n := t.Shape()[1] 59 60 // check shape 61 var odim int 62 oshape := other.Shape() 63 switch { 64 case oshape.IsColVec(): 65 odim = oshape[0] 66 case oshape.IsRowVec(): 67 odim = oshape[1] 68 case oshape.IsVector(): 69 odim = oshape[0] 70 default: 71 err = errors.Errorf(shapeMismatch, t.Shape(), other.Shape()) // should be unreachable 72 return 73 } 74 75 if odim != n { 76 err = errors.Errorf(shapeMismatch, n, other.Shape()) 77 return 78 } 79 80 expectedShape := Shape{m} 81 82 // check whether retVal has the same size as the resulting matrix would be: mx1 83 fo := ParseFuncOpts(opts...) 84 defer returnOpOpt(fo) 85 if retVal, err = handleReuse(fo.Reuse(), expectedShape, fo.Safe()); err != nil { 86 err = errors.Wrapf(err, opFail, "MatVecMul") 87 return 88 } 89 90 if retVal == nil { 91 retVal = recycledDense(t.t, expectedShape, WithEngine(t.e)) 92 if t.o.IsColMajor() { 93 AsFortran(nil)(retVal) 94 } 95 } 96 97 e := t.e 98 99 if mvm, ok := e.(MatVecMuler); ok { 100 if err = mvm.MatVecMul(t, other, retVal); err != nil { 101 return nil, errors.Wrapf(err, opFail, "MatVecMul") 102 } 103 return handleIncr(retVal, fo.Reuse(), fo.Incr(), expectedShape) 104 } 105 return nil, errors.New("engine does not support MatVecMul") 106 } 107 108 // MatMul is the basic matrix multiplication that you learned in high school. It takes an optional reuse ndarray, where the ndarray is reused as the result. 109 // If that isn't passed in, a new ndarray will be created instead. 110 func (t *Dense) MatMul(other Tensor, opts ...FuncOpt) (retVal *Dense, err error) { 111 // check that both are matrices 112 if !t.Shape().IsMatrix() || !other.Shape().IsMatrix() { 113 err = errors.Errorf("MatMul requires both operands to be matrices. Got t's shape: %v, other's shape: %v", t.Shape(), other.Shape()) 114 return 115 } 116 117 // checks that t is mxk matrix 118 var m, n, k int 119 m = t.Shape()[0] 120 k = t.Shape()[1] 121 n = other.Shape()[1] 122 123 // check shape 124 if k != other.Shape()[0] { 125 err = errors.Errorf(shapeMismatch, t.Shape(), other.Shape()) 126 return 127 } 128 129 // check whether retVal has the same size as the resulting matrix would be: mxn 130 expectedShape := Shape{m, n} 131 132 fo := ParseFuncOpts(opts...) 133 defer returnOpOpt(fo) 134 if retVal, err = handleReuse(fo.Reuse(), expectedShape, fo.Safe()); err != nil { 135 err = errors.Wrapf(err, opFail, "MatMul") 136 return 137 } 138 139 if retVal == nil { 140 retVal = recycledDense(t.t, expectedShape, WithEngine(t.e)) 141 if t.o.IsColMajor() { 142 AsFortran(nil)(retVal) 143 } 144 } 145 146 e := t.e 147 if mm, ok := e.(MatMuler); ok { 148 if err = mm.MatMul(t, other, retVal); err != nil { 149 return 150 } 151 return handleIncr(retVal, fo.Reuse(), fo.Incr(), expectedShape) 152 } 153 154 return nil, errors.New("engine does not support MatMul") 155 } 156 157 // Outer finds the outer product of two vectors 158 func (t *Dense) Outer(other Tensor, opts ...FuncOpt) (retVal *Dense, err error) { 159 // check both are vectors 160 if !t.Shape().IsVector() || !other.Shape().IsVector() { 161 err = errors.Errorf("Outer only works when there are two vectors. t's shape: %v. other's shape: %v", t.Shape(), other.Shape()) 162 return 163 } 164 165 m := t.Size() 166 n := other.Size() 167 168 // check whether retVal has the same size as the resulting matrix would be: mxn 169 expectedShape := Shape{m, n} 170 171 fo := ParseFuncOpts(opts...) 172 defer returnOpOpt(fo) 173 if retVal, err = handleReuse(fo.Reuse(), expectedShape, fo.Safe()); err != nil { 174 err = errors.Wrapf(err, opFail, "Outer") 175 return 176 } 177 178 if retVal == nil { 179 retVal = recycledDense(t.t, expectedShape, WithEngine(t.e)) 180 if t.o.IsColMajor() { 181 AsFortran(nil)(retVal) 182 } 183 } 184 185 e := t.e 186 187 // DGER does not have any beta. So the values have to be zeroed first if the tensor is to be reused 188 retVal.Zero() 189 if op, ok := e.(OuterProder); ok { 190 if err = op.Outer(t, other, retVal); err != nil { 191 return nil, errors.Wrapf(err, opFail, "engine.uter") 192 } 193 return handleIncr(retVal, fo.Reuse(), fo.Incr(), expectedShape) 194 } 195 return nil, errors.New("engine does not support Outer") 196 } 197 198 // TensorMul is for multiplying Tensors with more than 2 dimensions. 199 // 200 // The algorithm is conceptually simple (but tricky to get right): 201 // 1. Transpose and reshape the Tensors in such a way that both t and other are 2D matrices 202 // 2. Use DGEMM to multiply them 203 // 3. Reshape the results to be the new expected result 204 // 205 // This function is a Go implementation of Numpy's tensordot method. It simplifies a lot of what Numpy does. 206 func (t *Dense) TensorMul(other Tensor, axesA, axesB []int) (retVal *Dense, err error) { 207 ts := t.Shape() 208 td := len(ts) 209 210 os := other.Shape() 211 od := len(os) 212 213 na := len(axesA) 214 nb := len(axesB) 215 sameLength := na == nb 216 if sameLength { 217 for i := 0; i < na; i++ { 218 if ts[axesA[i]] != os[axesB[i]] { 219 sameLength = false 220 break 221 } 222 if axesA[i] < 0 { 223 axesA[i] += td 224 } 225 226 if axesB[i] < 0 { 227 axesB[i] += od 228 } 229 } 230 } 231 232 if !sameLength { 233 err = errors.Errorf(shapeMismatch, ts, os) 234 return 235 } 236 237 // handle shapes 238 var notins []int 239 for i := 0; i < td; i++ { 240 notin := true 241 for _, a := range axesA { 242 if i == a { 243 notin = false 244 break 245 } 246 } 247 if notin { 248 notins = append(notins, i) 249 } 250 } 251 252 newAxesA := BorrowInts(len(notins) + len(axesA)) 253 defer ReturnInts(newAxesA) 254 newAxesA = newAxesA[:0] 255 newAxesA = append(notins, axesA...) 256 n2 := 1 257 for _, a := range axesA { 258 n2 *= ts[a] 259 } 260 261 newShapeT := Shape(BorrowInts(2)) 262 defer ReturnInts(newShapeT) 263 newShapeT[0] = ts.TotalSize() / n2 264 newShapeT[1] = n2 265 266 retShape1 := BorrowInts(len(ts)) 267 defer ReturnInts(retShape1) 268 retShape1 = retShape1[:0] 269 for _, ni := range notins { 270 retShape1 = append(retShape1, ts[ni]) 271 } 272 273 // work on other now 274 notins = notins[:0] 275 for i := 0; i < od; i++ { 276 notin := true 277 for _, a := range axesB { 278 if i == a { 279 notin = false 280 break 281 } 282 } 283 if notin { 284 notins = append(notins, i) 285 } 286 } 287 288 newAxesB := BorrowInts(len(notins) + len(axesB)) 289 defer ReturnInts(newAxesB) 290 newAxesB = newAxesB[:0] 291 newAxesB = append(axesB, notins...) 292 293 newShapeO := Shape(BorrowInts(2)) 294 defer ReturnInts(newShapeO) 295 newShapeO[0] = n2 296 newShapeO[1] = os.TotalSize() / n2 297 298 retShape2 := BorrowInts(len(ts)) 299 retShape2 = retShape2[:0] 300 for _, ni := range notins { 301 retShape2 = append(retShape2, os[ni]) 302 } 303 304 // we borrowClone because we don't want to touch the original Tensors 305 doT := t.Clone().(*Dense) 306 doOther := other.Clone().(*Dense) 307 defer ReturnTensor(doT) 308 defer ReturnTensor(doOther) 309 310 if err = doT.T(newAxesA...); err != nil { 311 return 312 } 313 doT.Transpose() // we have to materialize the transpose first or the underlying data won't be changed and the reshape that follows would be meaningless 314 315 if err = doT.Reshape(newShapeT...); err != nil { 316 return 317 } 318 319 if err = doOther.T(newAxesB...); err != nil { 320 return 321 } 322 doOther.Transpose() 323 if err = doOther.Reshape(newShapeO...); err != nil { 324 return 325 } 326 327 // the magic happens here 328 var rt Tensor 329 if rt, err = Dot(doT, doOther); err != nil { 330 return 331 } 332 retVal = rt.(*Dense) 333 334 retShape := BorrowInts(len(retShape1) + len(retShape2)) 335 defer ReturnInts(retShape) 336 337 retShape = retShape[:0] 338 retShape = append(retShape, retShape1...) 339 retShape = append(retShape, retShape2...) 340 341 if len(retShape) == 0 { // In case a scalar is returned, it should be returned as shape = {1} 342 retShape = append(retShape, 1) 343 } 344 345 if err = retVal.Reshape(retShape...); err != nil { 346 return 347 } 348 349 return 350 } 351 352 // SVD does the Single Value Decomposition for the *Dense. 353 // 354 // How it works is it temporarily converts the *Dense into a gonum/mat64 matrix, and uses Gonum's SVD function to perform the SVD. 355 // In the future, when gonum/lapack fully supports float32, we'll look into rewriting this 356 func (t *Dense) SVD(uv, full bool) (s, u, v *Dense, err error) { 357 e := t.Engine() 358 359 if svder, ok := e.(SVDer); ok { 360 var sT, uT, vT Tensor 361 if sT, uT, vT, err = svder.SVD(t, uv, full); err != nil { 362 return nil, nil, nil, errors.Wrap(err, "Error while performing *Dense.SVD") 363 } 364 if s, err = assertDense(sT); err != nil { 365 return nil, nil, nil, errors.Wrapf(err, "sT is not *Dense (uv %t full %t). Got %T instead", uv, full, sT) 366 } 367 // if not uv and not full, u can be nil 368 if u, err = assertDense(uT); err != nil && !(!uv && !full) { 369 return nil, nil, nil, errors.Wrapf(err, "uT is not *Dense (uv %t full %t). Got %T instead", uv, full, uT) 370 } 371 // if not uv and not full, v can be nil 372 if v, err = assertDense(vT); err != nil && !(!uv && !full) { 373 return nil, nil, nil, errors.Wrapf(err, "vT is not *Dense (uv %t full %t). Got %T instead", uv, full, vT) 374 } 375 return s, u, v, nil 376 } 377 return nil, nil, nil, errors.New("Engine does not support SVD") 378 } 379 380 /* UTILITY FUNCTIONS */ 381 382 // handleReuse extracts a *Dense from Tensor, and checks the shape of the reuse Tensor 383 func handleReuse(reuse Tensor, expectedShape Shape, safe bool) (retVal *Dense, err error) { 384 if reuse != nil { 385 if retVal, err = assertDense(reuse); err != nil { 386 err = errors.Wrapf(err, opFail, "handling reuse") 387 return 388 } 389 if !safe { 390 return 391 } 392 if err = reuseCheckShape(retVal, expectedShape); err != nil { 393 err = errors.Wrapf(err, "Unable to process reuse *Dense Tensor. Shape error.") 394 return 395 } 396 return 397 } 398 return 399 } 400 401 // handleIncr is the cleanup step for when there is an Tensor to increment. If the result tensor is the same as the reuse Tensor, the result tensor gets returned to the pool 402 func handleIncr(res *Dense, reuse, incr Tensor, expectedShape Shape) (retVal *Dense, err error) { 403 // handle increments 404 if incr != nil { 405 if !expectedShape.Eq(incr.Shape()) { 406 err = errors.Errorf(shapeMismatch, expectedShape, incr.Shape()) 407 return 408 } 409 var incrD *Dense 410 var ok bool 411 if incrD, ok = incr.(*Dense); !ok { 412 err = errors.Errorf(extractionFail, "*Dense", incr) 413 return 414 } 415 416 if err = typeclassCheck(incrD.t, numberTypes); err != nil { 417 err = errors.Wrapf(err, "handleIncr only handles Number types. Got %v instead", incrD.t) 418 return 419 } 420 421 if _, err = incrD.Add(res, UseUnsafe()); err != nil { 422 return 423 } 424 // vecAdd(incr.data, retVal.data) 425 426 // return retVal to pool - if and only if retVal is not reuse 427 // reuse indicates that someone else also has the reference to the *Dense 428 if res != reuse { 429 ReturnTensor(res) 430 } 431 432 // then 433 retVal = incrD 434 return 435 } 436 437 return res, nil 438 }