github.com/wzzhu/tensor@v0.9.24/dense.go (about) 1 package tensor 2 3 import ( 4 "fmt" 5 "reflect" 6 "unsafe" 7 8 "github.com/pkg/errors" 9 "github.com/wzzhu/tensor/internal/storage" 10 ) 11 12 const ( 13 maskCompEvery int = 8 14 ) 15 16 // Dense represents a dense tensor - this is the most common form of tensors. It can be used to represent vectors, matrices.. etc 17 type Dense struct { 18 AP 19 array 20 21 flag MemoryFlag 22 e Engine // execution engine for the *Dense 23 oe standardEngine // optimized engine 24 25 // backup AP. When a transpose is done, the old *AP is backed up here, for easy untransposes 26 old AP 27 transposeWith []int 28 29 // if viewOf != nil, then this *Dense is a view. 30 viewOf uintptr 31 32 mask []bool // mask slice can be used to identify missing or invalid values. len(mask)<=len(v) 33 maskIsSoft bool 34 } 35 36 // NewDense creates a new *Dense. It tries its best to get from the tensor pool. 37 func NewDense(dt Dtype, shape Shape, opts ...ConsOpt) *Dense { 38 return recycledDense(dt, shape, opts...) 39 } 40 41 func recycledDense(dt Dtype, shape Shape, opts ...ConsOpt) (retVal *Dense) { 42 retVal = recycledDenseNoFix(dt, shape, opts...) 43 retVal.fix() 44 if err := retVal.sanity(); err != nil { 45 panic(err) 46 } 47 return 48 } 49 50 func recycledDenseNoFix(dt Dtype, shape Shape, opts ...ConsOpt) (retVal *Dense) { 51 // size := shape.TotalSize() 52 //if shape.IsScalar() { 53 // size = 1 54 //} 55 retVal = borrowDense() 56 retVal.array.t = dt 57 retVal.AP.zeroWithDims(shape.Dims()) 58 59 for _, opt := range opts { 60 opt(retVal) 61 } 62 retVal.setShape(shape...) 63 return 64 } 65 66 func (t *Dense) fromSlice(x interface{}) { 67 t.array.Header.Raw = nil // GC anything else 68 t.array.fromSlice(x) 69 } 70 71 func (t *Dense) addMask(mask []bool) { 72 l := len(mask) 73 if l > 0 && l != t.len() { 74 panic("Mask is not same length as data") 75 } 76 t.mask = mask 77 } 78 79 func (t *Dense) makeArray(size int) { 80 switch te := t.e.(type) { 81 case NonStdEngine: 82 t.flag = MakeMemoryFlag(t.flag, ManuallyManaged) 83 case arrayMaker: 84 te.makeArray(&t.array, t.t, size) 85 return 86 default: 87 } 88 89 memsize := calcMemSize(t.t, size) 90 mem, err := t.e.Alloc(memsize) 91 if err != nil { 92 panic(err) 93 } 94 95 t.array.Raw = storage.FromMemory(mem.Uintptr(), uintptr(memsize)) 96 return 97 } 98 99 // Info returns the access pattern which explains how the data in the underlying array is accessed. This is mostly used for debugging. 100 func (t *Dense) Info() *AP { return &t.AP } 101 102 // Dtype returns the data type of the *Dense tensor. 103 func (t *Dense) Dtype() Dtype { return t.t } 104 105 // Data returns the underlying array. If the *Dense represents a scalar value, the scalar value is returned instead 106 func (t *Dense) Data() interface{} { 107 if t.IsScalar() { 108 return t.Get(0) 109 } 110 111 // build a type of []T 112 shdr := reflect.SliceHeader{ 113 Data: t.array.Uintptr(), 114 Len: t.array.Len(), 115 Cap: t.array.Cap(), 116 } 117 sliceT := reflect.SliceOf(t.t.Type) 118 ptr := unsafe.Pointer(&shdr) 119 val := reflect.Indirect(reflect.NewAt(sliceT, ptr)) 120 return val.Interface() 121 } 122 123 // DataSize returns the size of the underlying array. Typically t.DataSize() == t.Shape().TotalSize() 124 func (t *Dense) DataSize() int { 125 if t.IsScalar() { 126 return 0 // DOUBLE CHECK 127 } 128 return t.array.Len() 129 } 130 131 // Engine returns the execution engine associated with this Tensor 132 func (t *Dense) Engine() Engine { return t.e } 133 134 // Reshape reshapes a *Dense. If the tensors need to be materialized (either it's a view or transpose), it will be materialized before the reshape happens 135 func (t *Dense) Reshape(dims ...int) error { 136 if t.Shape().TotalSize() != Shape(dims).TotalSize() { 137 return errors.Errorf("Cannot reshape %v into %v", t.Shape(), dims) 138 } 139 140 if t.viewOf != 0 && t.o.IsNotContiguous() { 141 return errors.Errorf(methodNYI, "Reshape", "non-contiguous views") 142 } 143 144 if !t.old.IsZero() { 145 t.Transpose() 146 } 147 148 return t.reshape(dims...) 149 } 150 151 func (t *Dense) reshape(dims ...int) error { 152 t.setShape(dims...) 153 return t.sanity() 154 } 155 156 func (t *Dense) unsqueeze(axis int) error { 157 if axis > t.shape.Dims()+1 { 158 return errors.Errorf("Cannot unsqueeze on axis %d when the tensor has shape %v", axis, t.shape) 159 } 160 t.shape = append(t.shape, 1) 161 copy(t.shape[axis+1:], t.shape[axis:]) 162 t.shape[axis] = 1 163 164 t.strides = append(t.strides, 1) 165 copy(t.strides[axis+1:], t.strides[axis:]) 166 167 return nil 168 } 169 170 // ScalarValue returns the scalar value of a *Tensor, 171 // IF and ONLY IF it's a Tensor representation of a scalar value. 172 // This is required because operations like a (vec ยท vec) would return a scalar value. 173 // I didn't want to return interface{} for all the API methods, so the next best solution is to 174 // wrap the scalar value in a *Tensor 175 func (t *Dense) ScalarValue() interface{} { 176 if !t.IsScalar() { 177 panic(fmt.Sprintf("ScalarValue only works when the Tensor is a representation of a scalar value. The value of the tensor is %v", t)) 178 } 179 180 return t.Get(0) 181 } 182 183 // IsView indicates if the Tensor is a view of another (typically from slicing) 184 func (t *Dense) IsView() bool { 185 return t.viewOf != 0 186 } 187 188 // IsMaterializeable indicates if the Tensor is materializable - if it has either gone through some transforms or slicing 189 func (t *Dense) IsMaterializable() bool { 190 return t.viewOf != 0 || !t.old.IsZero() 191 } 192 193 // IsManuallyManaged returns true if the memory associated with this *Dense is manually managed (by the user) 194 func (t *Dense) IsManuallyManaged() bool { return t.flag.manuallyManaged() } 195 196 // IsNativelyAccessible checks if the pointers are accessible by Go 197 func (t *Dense) IsNativelyAccessible() bool { return t.flag.nativelyAccessible() } 198 199 // Clone clones a *Dense. It creates a copy of the data, and the underlying array will be allocated 200 func (t *Dense) Clone() interface{} { 201 if t.e != nil { 202 retVal := new(Dense) 203 t.AP.CloneTo(&retVal.AP) 204 retVal.t = t.t 205 retVal.e = t.e 206 retVal.oe = t.oe 207 retVal.flag = t.flag 208 retVal.makeArray(t.Len()) 209 210 if !t.old.IsZero() { 211 retVal.old = t.old.Clone() 212 t.old.CloneTo(&retVal.old) 213 } 214 copyDense(retVal, t) 215 retVal.lock() 216 217 return retVal 218 } 219 panic("Unreachable: No engine") 220 } 221 222 // IsMasked indicates whether tensor is masked 223 func (t *Dense) IsMasked() bool { return len(t.mask) == t.len() } 224 225 // MaskFromDense adds a mask slice to tensor by XORing dense arguments' masks 226 func (t *Dense) MaskFromDense(tts ...*Dense) { 227 hasMask := BorrowBools(len(tts)) 228 defer ReturnBools(hasMask) 229 230 numMasked := 0 231 var masked = false 232 233 for i, tt := range tts { 234 if tt != nil { 235 hasMask[i] = tt.IsMasked() 236 masked = masked || hasMask[i] 237 if hasMask[i] { 238 numMasked++ 239 } 240 } 241 } 242 if numMasked < 1 { 243 return 244 } 245 246 //Only make mask if none already. This way one of the tts can be t itself 247 248 if len(t.mask) < t.DataSize() { 249 t.makeMask() 250 } 251 252 for i, tt := range tts { 253 if tt != nil { 254 n := len(tt.mask) 255 if hasMask[i] { 256 for j := range t.mask { 257 t.mask[j] = t.mask[j] || tt.mask[j%n] 258 } 259 } 260 } 261 } 262 } 263 264 // Private methods 265 266 func (t *Dense) cap() int { return t.array.Cap() } 267 func (t *Dense) len() int { return t.array.Len() } // exactly the same as DataSize 268 func (t *Dense) arr() array { return t.array } 269 func (t *Dense) arrPtr() *array { return &t.array } 270 271 func (t *Dense) setShape(s ...int) { 272 t.unlock() 273 t.SetShape(s...) 274 t.lock() 275 return 276 } 277 278 func (t *Dense) setAP(ap *AP) { t.AP = *ap } 279 280 func (t *Dense) fix() { 281 if t.e == nil { 282 t.e = StdEng{} 283 } 284 285 if oe, ok := t.e.(standardEngine); ok { 286 t.oe = oe 287 } 288 289 switch { 290 case t.IsScalar() && t.array.Header.Raw == nil: 291 t.makeArray(1) 292 case t.Shape() == nil && t.array.Header.Raw != nil: 293 size := t.Len() 294 if size == 1 { 295 t.SetShape() // scalar 296 } else { 297 t.SetShape(size) // vector 298 } 299 case t.array.Header.Raw == nil && t.t != Dtype{}: 300 size := t.Shape().TotalSize() 301 t.makeArray(size) 302 303 } 304 if len(t.mask) != t.len() { 305 t.mask = t.mask[:0] 306 } 307 t.lock() // don't put this in a defer - if t.array.Ptr == nil and t.Shape() == nil. then leave it unlocked 308 } 309 310 // makeMask adds a mask slice to tensor if required 311 func (t *Dense) makeMask() { 312 var size int 313 size = t.shape.TotalSize() 314 if len(t.mask) >= size { 315 t.mask = t.mask[:size] 316 } 317 if cap(t.mask) < size { 318 t.mask = make([]bool, size) 319 } 320 t.mask = t.mask[:size] 321 memsetBools(t.mask, false) 322 } 323 324 // sanity is a function that sanity checks that a tensor is correct. 325 func (t *Dense) sanity() error { 326 if !t.AP.IsZero() && t.Shape() == nil && t.array.Header.Raw == nil { 327 return errors.New(emptyTensor) 328 } 329 330 size := t.Len() 331 expected := t.Size() 332 if t.viewOf == 0 && size != expected && !t.IsScalar() { 333 return errors.Wrap(errors.Errorf(shapeMismatch, t.Shape(), size), "sanity check failed") 334 } 335 336 // TODO: sanity check for views 337 return nil 338 } 339 340 // isTransposed returns true if the *Dense holds a transposed array. 341 func (t *Dense) isTransposed() bool { return t.old.IsZero() } 342 343 // oshape returns the original shape 344 func (t *Dense) oshape() Shape { 345 if !t.old.IsZero() { 346 return t.old.Shape() 347 } 348 return t.Shape() 349 } 350 351 // ostrides returns the original strides 352 func (t *Dense) ostrides() []int { 353 if !t.old.IsZero() { 354 return t.old.Strides() 355 } 356 return t.Strides() 357 } 358 359 // ShallowClone clones the *Dense without making a copy of the underlying array 360 func (t *Dense) ShallowClone() *Dense { 361 retVal := borrowDense() 362 retVal.e = t.e 363 retVal.oe = t.oe 364 t.AP.CloneTo(&retVal.AP) 365 retVal.flag = t.flag 366 retVal.array = t.array 367 368 retVal.old = t.old 369 retVal.transposeWith = t.transposeWith 370 retVal.viewOf = t.viewOf 371 retVal.mask = t.mask 372 retVal.maskIsSoft = t.maskIsSoft 373 return retVal 374 } 375 376 func (t *Dense) oldAP() *AP { return &t.old } 377 func (t *Dense) setOldAP(ap *AP) { t.old = *ap } 378 func (t *Dense) transposeAxes() []int { return t.transposeWith } 379 380 //go:nocheckptr 381 func (t *Dense) parentTensor() *Dense { 382 if t.viewOf != 0 { 383 return (*Dense)(unsafe.Pointer(t.viewOf)) 384 } 385 return nil 386 } 387 388 func (t *Dense) setParentTensor(d *Dense) { 389 if d == nil { 390 t.viewOf = 0 391 return 392 } 393 t.viewOf = uintptr(unsafe.Pointer(d)) 394 } 395 396 /* ------ Mask operations */ 397 398 //ResetMask fills the mask with either false, or the provided boolean value 399 func (t *Dense) ResetMask(val ...bool) error { 400 if !t.IsMasked() { 401 t.makeMask() 402 } 403 var fillValue = false 404 if len(val) > 0 { 405 fillValue = val[0] 406 } 407 memsetBools(t.mask, fillValue) 408 return nil 409 } 410 411 // HardenMask forces the mask to hard. If mask is hard, then true mask values can not be unset 412 func (t *Dense) HardenMask() bool { 413 t.maskIsSoft = false 414 return t.maskIsSoft 415 } 416 417 // SoftenMask forces the mask to soft 418 func (t *Dense) SoftenMask() bool { 419 t.maskIsSoft = true 420 return t.maskIsSoft 421 } 422 423 // MaskFromSlice makes mask from supplied slice 424 func (t *Dense) MaskFromSlice(x interface{}) { 425 t.makeMask() 426 n := len(t.mask) 427 switch m := x.(type) { 428 case []bool: 429 copy(t.mask, m) 430 return 431 case []int: 432 for i, v := range m { 433 if v != 0 { 434 t.mask[i] = true 435 } 436 if i >= n { 437 return 438 } 439 } 440 case []int8: 441 for i, v := range m { 442 if v != 0 { 443 t.mask[i] = true 444 } 445 if i >= n { 446 return 447 } 448 } 449 case []int16: 450 for i, v := range m { 451 if v != 0 { 452 t.mask[i] = true 453 } 454 if i >= n { 455 return 456 } 457 } 458 case []int32: 459 for i, v := range m { 460 if v != 0 { 461 t.mask[i] = true 462 } 463 if i >= n { 464 return 465 } 466 } 467 case []int64: 468 for i, v := range m { 469 if v != 0 { 470 t.mask[i] = true 471 } 472 if i >= n { 473 return 474 } 475 } 476 case []uint: 477 for i, v := range m { 478 if v != 0 { 479 t.mask[i] = true 480 } 481 if i >= n { 482 return 483 } 484 } 485 case []byte: 486 for i, v := range m { 487 if v != 0 { 488 t.mask[i] = true 489 } 490 if i >= n { 491 return 492 } 493 } 494 case []uint16: 495 for i, v := range m { 496 if v != 0 { 497 t.mask[i] = true 498 } 499 if i >= n { 500 return 501 } 502 } 503 case []uint32: 504 for i, v := range m { 505 if v != 0 { 506 t.mask[i] = true 507 } 508 if i >= n { 509 return 510 } 511 } 512 case []uint64: 513 for i, v := range m { 514 if v != 0 { 515 t.mask[i] = true 516 } 517 if i >= n { 518 return 519 } 520 } 521 case []float32: 522 for i, v := range m { 523 if v != 0 { 524 t.mask[i] = true 525 } 526 if i >= n { 527 return 528 } 529 } 530 case []float64: 531 for i, v := range m { 532 if v != 0 { 533 t.mask[i] = true 534 } 535 if i >= n { 536 return 537 } 538 } 539 case []complex64: 540 for i, v := range m { 541 if v != 0 { 542 t.mask[i] = true 543 } 544 if i >= n { 545 return 546 } 547 } 548 case []complex128: 549 for i, v := range m { 550 if v != 0 { 551 t.mask[i] = true 552 } 553 if i >= n { 554 return 555 } 556 } 557 case []string: 558 for i, v := range m { 559 if v != "" { 560 t.mask[i] = true 561 } 562 if i >= n { 563 return 564 } 565 } 566 default: 567 return 568 } 569 } 570 571 // Memset sets all the values in the *Dense tensor. 572 func (t *Dense) Memset(x interface{}) error { 573 if !t.IsNativelyAccessible() { 574 return errors.Errorf(inaccessibleData, t) 575 } 576 if t.IsMaterializable() { 577 it := newFlatIterator(&t.AP) 578 return t.array.memsetIter(x, it) 579 } 580 return t.array.Memset(x) 581 } 582 583 // Eq checks that any two things are equal. If the shapes are the same, but the strides are not the same, it's will still be considered the same 584 func (t *Dense) Eq(other interface{}) bool { 585 if ot, ok := other.(*Dense); ok { 586 if ot == t { 587 return true 588 } 589 if !t.Shape().Eq(ot.Shape()) { 590 return false 591 } 592 593 return t.array.Eq(&ot.array) 594 } 595 return false 596 } 597 598 func (t *Dense) Zero() { 599 if t.IsMaterializable() { 600 it := newFlatIterator(&t.AP) 601 if err := t.zeroIter(it); err != nil { 602 panic(err) 603 } 604 } 605 if t.IsMasked() { 606 t.ResetMask() 607 } 608 t.array.Zero() 609 } 610 611 func (t *Dense) Mask() []bool { return t.mask } 612 613 func (t *Dense) SetMask(mask []bool) { 614 // if len(mask) != t.len() { 615 // panic("Cannot set mask") 616 // } 617 t.mask = mask 618 } 619 620 func (t *Dense) slice(start, end int) { 621 t.array = t.array.slice(start, end) 622 } 623 624 // RequiresIterator indicates if an iterator is required to read the data in *Dense in the correct fashion 625 func (t *Dense) RequiresIterator() bool { 626 if t.len() == 1 { 627 return false 628 } 629 // non continuous slice, transpose, or masked. If it's a slice and contiguous, then iterator is not required 630 if !t.o.IsContiguous() || !t.old.IsZero() || t.IsMasked() { 631 return true 632 } 633 return false 634 } 635 636 func (t *Dense) Iterator() Iterator { return IteratorFromDense(t) } 637 638 func (t *Dense) standardEngine() standardEngine { return t.oe }