gorgonia.org/tensor@v0.9.24/dense.go (about)

     1  package tensor
     2  
     3  import (
     4  	"fmt"
     5  	"reflect"
     6  	"unsafe"
     7  
     8  	"github.com/pkg/errors"
     9  	"gorgonia.org/tensor/internal/storage"
    10  )
    11  
const (
	// maskCompEvery is not referenced in the visible part of this file.
	// NOTE(review): presumably an interval used by mask-related routines elsewhere in the package — confirm at the use sites.
	maskCompEvery int = 8
)
    15  
// Dense represents a dense tensor - this is the most common form of tensors. It can be used to represent vectors, matrices, etc.
//
// It embeds an AP (the access pattern: shape and strides) and an array (the backing storage together with its Dtype).
type Dense struct {
	AP
	array

	// flag holds the memory flags queried by IsManuallyManaged and IsNativelyAccessible.
	flag MemoryFlag
	e    Engine         // execution engine for the *Dense
	oe   standardEngine // optimized engine

	// backup AP. When a transpose is done, the old *AP is backed up here, for easy untransposes.
	// A non-zero old AP is what the rest of this file treats as "this tensor is transposed".
	old           AP
	transposeWith []int

	// if viewOf != 0, then this *Dense is a view.
	// The value is the address of the parent *Dense stored as an integer (see setParentTensor and parentTensor).
	viewOf uintptr

	mask       []bool // mask slice can be used to identify missing or invalid values. len(mask)<=len(v)
	maskIsSoft bool
}
    35  
    36  // NewDense creates a new *Dense. It tries its best to get from the tensor pool.
    37  func NewDense(dt Dtype, shape Shape, opts ...ConsOpt) *Dense {
    38  	return recycledDense(dt, shape, opts...)
    39  }
    40  
    41  func recycledDense(dt Dtype, shape Shape, opts ...ConsOpt) (retVal *Dense) {
    42  	retVal = recycledDenseNoFix(dt, shape, opts...)
    43  	retVal.fix()
    44  	if err := retVal.sanity(); err != nil {
    45  		panic(err)
    46  	}
    47  	return
    48  }
    49  
    50  func recycledDenseNoFix(dt Dtype, shape Shape, opts ...ConsOpt) (retVal *Dense) {
    51  	//	size := shape.TotalSize()
    52  	//if shape.IsScalar() {
    53  	//	size = 1
    54  	//}
    55  	retVal = borrowDense()
    56  	retVal.array.t = dt
    57  	retVal.AP.zeroWithDims(shape.Dims())
    58  
    59  	for _, opt := range opts {
    60  		opt(retVal)
    61  	}
    62  	retVal.setShape(shape...)
    63  	return
    64  }
    65  
// fromSlice makes x the backing data of the tensor. The current raw storage reference is
// dropped first, then the embedded array's fromSlice is invoked with x.
func (t *Dense) fromSlice(x interface{}) {
	t.array.Header.Raw = nil // GC anything else
	t.array.fromSlice(x)
}
    70  
    71  func (t *Dense) addMask(mask []bool) {
    72  	l := len(mask)
    73  	if l > 0 && l != t.len() {
    74  		panic("Mask is not same length as data")
    75  	}
    76  	t.mask = mask
    77  }
    78  
    79  func (t *Dense) makeArray(size int) {
    80  	switch te := t.e.(type) {
    81  	case NonStdEngine:
    82  		t.flag = MakeMemoryFlag(t.flag, ManuallyManaged)
    83  	case arrayMaker:
    84  		te.makeArray(&t.array, t.t, size)
    85  		return
    86  	default:
    87  	}
    88  
    89  	memsize := calcMemSize(t.t, size)
    90  	mem, err := t.e.Alloc(memsize)
    91  	if err != nil {
    92  		panic(err)
    93  	}
    94  
    95  	t.array.Raw = storage.FromMemory(mem.Uintptr(), uintptr(memsize))
    96  	return
    97  }
    98  
// Info returns the access pattern which explains how the data in the underlying array is accessed. This is mostly used for debugging.
// The returned pointer refers to the tensor's own AP, not to a copy.
func (t *Dense) Info() *AP { return &t.AP }
   101  
// Dtype returns the data type of the *Dense tensor, as recorded on the embedded array.
func (t *Dense) Dtype() Dtype { return t.t }
   104  
// Data returns the underlying array. If the *Dense represents a scalar value, the scalar value is returned instead.
//
// For non-scalars the result is a slice of type []T, where T is the Go type of the tensor's
// Dtype. The slice is built from the array's address, length and capacity, so it refers to the
// tensor's memory rather than to a copy.
func (t *Dense) Data() interface{} {
	if t.IsScalar() {
		return t.Get(0)
	}

	// build a type of []T
	// NOTE(review): the address is carried as a uintptr inside a reflect.SliceHeader; confirm the
	// backing array is kept alive by the tensor for as long as the returned slice is in use.
	shdr := reflect.SliceHeader{
		Data: t.array.Uintptr(),
		Len:  t.array.Len(),
		Cap:  t.array.Cap(),
	}
	sliceT := reflect.SliceOf(t.t.Type)
	// Reinterpret the header as a value of type []T and hand it back as an interface{}.
	ptr := unsafe.Pointer(&shdr)
	val := reflect.Indirect(reflect.NewAt(sliceT, ptr))
	return val.Interface()
}
   122  
   123  // DataSize returns the size of the underlying array. Typically t.DataSize() == t.Shape().TotalSize()
   124  func (t *Dense) DataSize() int {
   125  	if t.IsScalar() {
   126  		return 0 // DOUBLE CHECK
   127  	}
   128  	return t.array.Len()
   129  }
   130  
// Engine returns the execution engine associated with this Tensor. It may be nil if no engine has been set yet.
func (t *Dense) Engine() Engine { return t.e }
   133  
   134  // Reshape reshapes a *Dense. If the tensors need to be materialized (either it's a view or transpose), it will be materialized before the reshape happens
   135  func (t *Dense) Reshape(dims ...int) error {
   136  	if t.Shape().TotalSize() != Shape(dims).TotalSize() {
   137  		return errors.Errorf("Cannot reshape %v into %v", t.Shape(), dims)
   138  	}
   139  
   140  	if t.viewOf != 0 && t.o.IsNotContiguous() {
   141  		return errors.Errorf(methodNYI, "Reshape", "non-contiguous views")
   142  	}
   143  
   144  	if !t.old.IsZero() {
   145  		t.Transpose()
   146  	}
   147  
   148  	return t.reshape(dims...)
   149  }
   150  
// reshape sets the shape to dims without any of the checks done by Reshape, then returns the
// result of the sanity check on the reshaped tensor.
func (t *Dense) reshape(dims ...int) error {
	t.setShape(dims...)
	return t.sanity()
}
   155  
   156  func (t *Dense) unsqueeze(axis int) error {
   157  	if axis > t.shape.Dims()+1 {
   158  		return errors.Errorf("Cannot unsqueeze on axis %d when the tensor has shape %v", axis, t.shape)
   159  	}
   160  	t.shape = append(t.shape, 1)
   161  	copy(t.shape[axis+1:], t.shape[axis:])
   162  	t.shape[axis] = 1
   163  
   164  	t.strides = append(t.strides, 1)
   165  	copy(t.strides[axis+1:], t.strides[axis:])
   166  
   167  	return nil
   168  }
   169  
   170  // ScalarValue returns the scalar value of a *Tensor,
   171  // IF and ONLY IF it's a Tensor representation of a scalar value.
   172  // This is required because operations like a (vec ยท vec) would return a scalar value.
   173  // I didn't want to return interface{} for all the API methods, so the next best solution is to
   174  // wrap the scalar value in a *Tensor
   175  func (t *Dense) ScalarValue() interface{} {
   176  	if !t.IsScalar() {
   177  		panic(fmt.Sprintf("ScalarValue only works when the Tensor is a representation of a scalar value. The value of the tensor is %v", t))
   178  	}
   179  
   180  	return t.Get(0)
   181  }
   182  
// IsView indicates if the Tensor is a view of another (typically from slicing).
// A tensor is a view exactly when its viewOf field is non-zero.
func (t *Dense) IsView() bool {
	return t.viewOf != 0
}
   187  
// IsMaterializable indicates if the Tensor is materializable - if it has either gone through some transforms or slicing.
// Concretely: it is a view, or it has a backed-up (pre-transpose) access pattern.
func (t *Dense) IsMaterializable() bool {
	return t.viewOf != 0 || !t.old.IsZero()
}
   192  
// IsManuallyManaged returns true if the memory associated with this *Dense is manually managed (by the user).
// The answer is read from the tensor's memory flag.
func (t *Dense) IsManuallyManaged() bool { return t.flag.manuallyManaged() }
   195  
// IsNativelyAccessible checks if the pointers are accessible by Go.
// The answer is read from the tensor's memory flag.
func (t *Dense) IsNativelyAccessible() bool { return t.flag.nativelyAccessible() }
   198  
   199  // Clone clones a *Dense. It creates a copy of the data, and the underlying array will be allocated
   200  func (t *Dense) Clone() interface{} {
   201  	if t.e != nil {
   202  		retVal := new(Dense)
   203  		t.AP.CloneTo(&retVal.AP)
   204  		retVal.t = t.t
   205  		retVal.e = t.e
   206  		retVal.oe = t.oe
   207  		retVal.flag = t.flag
   208  		retVal.makeArray(t.Len())
   209  
   210  		if !t.old.IsZero() {
   211  			retVal.old = t.old.Clone()
   212  			t.old.CloneTo(&retVal.old)
   213  		}
   214  		copyDense(retVal, t)
   215  		retVal.lock()
   216  
   217  		return retVal
   218  	}
   219  	panic("Unreachable: No engine")
   220  }
   221  
// IsMasked indicates whether tensor is masked, i.e. whether the mask has exactly as many
// entries as the underlying array has elements.
func (t *Dense) IsMasked() bool { return len(t.mask) == t.len() }
   224  
   225  // MaskFromDense adds a mask slice to tensor by XORing dense arguments' masks
   226  func (t *Dense) MaskFromDense(tts ...*Dense) {
   227  	hasMask := BorrowBools(len(tts))
   228  	defer ReturnBools(hasMask)
   229  
   230  	numMasked := 0
   231  	var masked = false
   232  
   233  	for i, tt := range tts {
   234  		if tt != nil {
   235  			hasMask[i] = tt.IsMasked()
   236  			masked = masked || hasMask[i]
   237  			if hasMask[i] {
   238  				numMasked++
   239  			}
   240  		}
   241  	}
   242  	if numMasked < 1 {
   243  		return
   244  	}
   245  
   246  	//Only make mask if none already. This way one of the tts can be t itself
   247  
   248  	if len(t.mask) < t.DataSize() {
   249  		t.makeMask()
   250  	}
   251  
   252  	for i, tt := range tts {
   253  		if tt != nil {
   254  			n := len(tt.mask)
   255  			if hasMask[i] {
   256  				for j := range t.mask {
   257  					t.mask[j] = t.mask[j] || tt.mask[j%n]
   258  				}
   259  			}
   260  		}
   261  	}
   262  }
   263  
   264  // Private methods
   265  
// cap returns the capacity of the underlying array.
func (t *Dense) cap() int       { return t.array.Cap() }
// len returns the length of the underlying array. Unlike DataSize it does not special-case scalars.
func (t *Dense) len() int       { return t.array.Len() } // same as DataSize for non-scalars
// arr returns the embedded array by value.
func (t *Dense) arr() array     { return t.array }
// arrPtr returns a pointer to the embedded array, so callers can modify it in place.
func (t *Dense) arrPtr() *array { return &t.array }
   270  
   271  func (t *Dense) setShape(s ...int) {
   272  	t.unlock()
   273  	t.SetShape(s...)
   274  	t.lock()
   275  	return
   276  }
   277  
// setAP overwrites the tensor's access pattern with a copy of the AP struct that ap points to.
func (t *Dense) setAP(ap *AP) { t.AP = *ap }
   279  
// fix brings a freshly constructed *Dense into a usable state: it fills in a default engine,
// reconciles shape and backing storage with each other, discards a mask of the wrong length,
// and locks the tensor.
func (t *Dense) fix() {
	// No engine supplied: fall back to the standard engine.
	if t.e == nil {
		t.e = StdEng{}
	}

	// Remember the engine as the optimized engine when it satisfies standardEngine.
	if oe, ok := t.e.(standardEngine); ok {
		t.oe = oe
	}

	switch {
	case t.IsScalar() && t.array.Header.Raw == nil:
		// Scalar without storage: allocate a single element.
		t.makeArray(1)
	case t.Shape() == nil && t.array.Header.Raw != nil:
		// Storage without a shape: infer the shape from the number of elements.
		size := t.Len()
		if size == 1 {
			t.SetShape() // scalar
		} else {
			t.SetShape(size) // vector
		}
	case t.array.Header.Raw == nil && t.t != Dtype{}:
		// Shape and Dtype known but no storage: allocate enough for the whole shape.
		size := t.Shape().TotalSize()
		t.makeArray(size)

	}
	// A mask whose length does not match the data is emptied (its capacity is kept).
	if len(t.mask) != t.len() {
		t.mask = t.mask[:0]
	}
	t.lock() // don't put this in a defer - if t.array.Ptr == nil and t.Shape() == nil. then leave it unlocked
}
   309  
   310  // makeMask adds a mask slice to tensor if required
   311  func (t *Dense) makeMask() {
   312  	var size int
   313  	size = t.shape.TotalSize()
   314  	if len(t.mask) >= size {
   315  		t.mask = t.mask[:size]
   316  	}
   317  	if cap(t.mask) < size {
   318  		t.mask = make([]bool, size)
   319  	}
   320  	t.mask = t.mask[:size]
   321  	memsetBools(t.mask, false)
   322  }
   323  
   324  // sanity is a function that sanity checks that a tensor is correct.
   325  func (t *Dense) sanity() error {
   326  	if !t.AP.IsZero() && t.Shape() == nil && t.array.Header.Raw == nil {
   327  		return errors.New(emptyTensor)
   328  	}
   329  
   330  	size := t.Len()
   331  	expected := t.Size()
   332  	if t.viewOf == 0 && size != expected && !t.IsScalar() {
   333  		return errors.Wrap(errors.Errorf(shapeMismatch, t.Shape(), size), "sanity check failed")
   334  	}
   335  
   336  	// TODO: sanity check for views
   337  	return nil
   338  }
   339  
   340  // isTransposed returns true if the *Dense holds a transposed array.
   341  func (t *Dense) isTransposed() bool { return t.old.IsZero() }
   342  
   343  // oshape returns the original shape
   344  func (t *Dense) oshape() Shape {
   345  	if !t.old.IsZero() {
   346  		return t.old.Shape()
   347  	}
   348  	return t.Shape()
   349  }
   350  
   351  // ostrides returns the original strides
   352  func (t *Dense) ostrides() []int {
   353  	if !t.old.IsZero() {
   354  		return t.old.Strides()
   355  	}
   356  	return t.Strides()
   357  }
   358  
   359  // ShallowClone clones the *Dense without making a copy of the underlying array
   360  func (t *Dense) ShallowClone() *Dense {
   361  	retVal := borrowDense()
   362  	retVal.e = t.e
   363  	retVal.oe = t.oe
   364  	t.AP.CloneTo(&retVal.AP)
   365  	retVal.flag = t.flag
   366  	retVal.array = t.array
   367  
   368  	retVal.old = t.old
   369  	retVal.transposeWith = t.transposeWith
   370  	retVal.viewOf = t.viewOf
   371  	retVal.mask = t.mask
   372  	retVal.maskIsSoft = t.maskIsSoft
   373  	return retVal
   374  }
   375  
// oldAP returns a pointer to the backed-up (pre-transpose) access pattern.
func (t *Dense) oldAP() *AP           { return &t.old }
// setOldAP overwrites the backed-up access pattern with a copy of the AP struct that ap points to.
func (t *Dense) setOldAP(ap *AP)      { t.old = *ap }
// transposeAxes returns the stored transpose axes. The slice is returned as-is, not copied.
func (t *Dense) transposeAxes() []int { return t.transposeWith }
   379  
// parentTensor returns the *Dense this tensor is a view of, or nil if it is not a view.
// The parent is recovered by converting the integer stored in viewOf back into a pointer.
// NOTE(review): a uintptr is not a reference as far as the garbage collector is concerned —
// confirm the parent is kept alive elsewhere for as long as the view exists.
//
//go:nocheckptr
func (t *Dense) parentTensor() *Dense {
	if t.viewOf != 0 {
		return (*Dense)(unsafe.Pointer(t.viewOf))
	}
	return nil
}
   387  
// setParentTensor records d as the tensor this one is a view of. Passing nil clears the
// marker, making the tensor a non-view. The parent is stored as the integer value of its address.
func (t *Dense) setParentTensor(d *Dense) {
	if d == nil {
		t.viewOf = 0
		return
	}
	t.viewOf = uintptr(unsafe.Pointer(d))
}
   395  
   396  /* ------ Mask operations */
   397  
   398  //ResetMask fills the mask with either false, or the provided boolean value
   399  func (t *Dense) ResetMask(val ...bool) error {
   400  	if !t.IsMasked() {
   401  		t.makeMask()
   402  	}
   403  	var fillValue = false
   404  	if len(val) > 0 {
   405  		fillValue = val[0]
   406  	}
   407  	memsetBools(t.mask, fillValue)
   408  	return nil
   409  }
   410  
// HardenMask forces the mask to hard. If mask is hard, then true mask values can not be unset.
// It returns the resulting value of the soft flag, which is always false after this call.
func (t *Dense) HardenMask() bool {
	t.maskIsSoft = false
	return t.maskIsSoft
}
   416  
// SoftenMask forces the mask to soft.
// It returns the resulting value of the soft flag, which is always true after this call.
func (t *Dense) SoftenMask() bool {
	t.maskIsSoft = true
	return t.maskIsSoft
}
   422  
   423  // MaskFromSlice makes mask from supplied slice
   424  func (t *Dense) MaskFromSlice(x interface{}) {
   425  	t.makeMask()
   426  	n := len(t.mask)
   427  	switch m := x.(type) {
   428  	case []bool:
   429  		copy(t.mask, m)
   430  		return
   431  	case []int:
   432  		for i, v := range m {
   433  			if v != 0 {
   434  				t.mask[i] = true
   435  			}
   436  			if i >= n {
   437  				return
   438  			}
   439  		}
   440  	case []int8:
   441  		for i, v := range m {
   442  			if v != 0 {
   443  				t.mask[i] = true
   444  			}
   445  			if i >= n {
   446  				return
   447  			}
   448  		}
   449  	case []int16:
   450  		for i, v := range m {
   451  			if v != 0 {
   452  				t.mask[i] = true
   453  			}
   454  			if i >= n {
   455  				return
   456  			}
   457  		}
   458  	case []int32:
   459  		for i, v := range m {
   460  			if v != 0 {
   461  				t.mask[i] = true
   462  			}
   463  			if i >= n {
   464  				return
   465  			}
   466  		}
   467  	case []int64:
   468  		for i, v := range m {
   469  			if v != 0 {
   470  				t.mask[i] = true
   471  			}
   472  			if i >= n {
   473  				return
   474  			}
   475  		}
   476  	case []uint:
   477  		for i, v := range m {
   478  			if v != 0 {
   479  				t.mask[i] = true
   480  			}
   481  			if i >= n {
   482  				return
   483  			}
   484  		}
   485  	case []byte:
   486  		for i, v := range m {
   487  			if v != 0 {
   488  				t.mask[i] = true
   489  			}
   490  			if i >= n {
   491  				return
   492  			}
   493  		}
   494  	case []uint16:
   495  		for i, v := range m {
   496  			if v != 0 {
   497  				t.mask[i] = true
   498  			}
   499  			if i >= n {
   500  				return
   501  			}
   502  		}
   503  	case []uint32:
   504  		for i, v := range m {
   505  			if v != 0 {
   506  				t.mask[i] = true
   507  			}
   508  			if i >= n {
   509  				return
   510  			}
   511  		}
   512  	case []uint64:
   513  		for i, v := range m {
   514  			if v != 0 {
   515  				t.mask[i] = true
   516  			}
   517  			if i >= n {
   518  				return
   519  			}
   520  		}
   521  	case []float32:
   522  		for i, v := range m {
   523  			if v != 0 {
   524  				t.mask[i] = true
   525  			}
   526  			if i >= n {
   527  				return
   528  			}
   529  		}
   530  	case []float64:
   531  		for i, v := range m {
   532  			if v != 0 {
   533  				t.mask[i] = true
   534  			}
   535  			if i >= n {
   536  				return
   537  			}
   538  		}
   539  	case []complex64:
   540  		for i, v := range m {
   541  			if v != 0 {
   542  				t.mask[i] = true
   543  			}
   544  			if i >= n {
   545  				return
   546  			}
   547  		}
   548  	case []complex128:
   549  		for i, v := range m {
   550  			if v != 0 {
   551  				t.mask[i] = true
   552  			}
   553  			if i >= n {
   554  				return
   555  			}
   556  		}
   557  	case []string:
   558  		for i, v := range m {
   559  			if v != "" {
   560  				t.mask[i] = true
   561  			}
   562  			if i >= n {
   563  				return
   564  			}
   565  		}
   566  	default:
   567  		return
   568  	}
   569  }
   570  
   571  // Memset sets all the values in the *Dense tensor.
   572  func (t *Dense) Memset(x interface{}) error {
   573  	if !t.IsNativelyAccessible() {
   574  		return errors.Errorf(inaccessibleData, t)
   575  	}
   576  	if t.IsMaterializable() {
   577  		it := newFlatIterator(&t.AP)
   578  		return t.array.memsetIter(x, it)
   579  	}
   580  	return t.array.Memset(x)
   581  }
   582  
   583  // Eq checks that any two things are equal. If the shapes are the same, but the strides are not the same, it's will still be considered the same
   584  func (t *Dense) Eq(other interface{}) bool {
   585  	if ot, ok := other.(*Dense); ok {
   586  		if ot == t {
   587  			return true
   588  		}
   589  		if !t.Shape().Eq(ot.Shape()) {
   590  			return false
   591  		}
   592  
   593  		return t.array.Eq(&ot.array)
   594  	}
   595  	return false
   596  }
   597  
   598  func (t *Dense) Zero() {
   599  	if t.IsMaterializable() {
   600  		it := newFlatIterator(&t.AP)
   601  		if err := t.zeroIter(it); err != nil {
   602  			panic(err)
   603  		}
   604  	}
   605  	if t.IsMasked() {
   606  		t.ResetMask()
   607  	}
   608  	t.array.Zero()
   609  }
   610  
// Mask returns the tensor's mask slice. The slice is returned as-is, not copied.
func (t *Dense) Mask() []bool { return t.mask }
   612  
// SetMask replaces the tensor's mask with the given slice. No length check is performed
// (the check below is commented out), so the caller is responsible for supplying a mask
// whose length matches the data.
func (t *Dense) SetMask(mask []bool) {
	// if len(mask) != t.len() {
	// 	panic("Cannot set mask")
	// }
	t.mask = mask
}
   619  
// slice replaces the tensor's array with the result of slicing it with start and end.
// The access pattern is not modified here.
func (t *Dense) slice(start, end int) {
	t.array = t.array.slice(start, end)
}
   623  
   624  // RequiresIterator indicates if an iterator is required to read the data in *Dense in the correct fashion
   625  func (t *Dense) RequiresIterator() bool {
   626  	if t.len() == 1 {
   627  		return false
   628  	}
   629  	// non continuous slice, transpose, or masked. If it's a slice and contiguous, then iterator is not required
   630  	if !t.o.IsContiguous() || !t.old.IsZero() || t.IsMasked() {
   631  		return true
   632  	}
   633  	return false
   634  }
   635  
// Iterator returns an Iterator for this tensor, as built by IteratorFromDense.
func (t *Dense) Iterator() Iterator { return IteratorFromDense(t) }
   637  
// standardEngine returns the optimized engine recorded for this tensor. It may be nil when
// the tensor's engine does not satisfy the standardEngine interface.
func (t *Dense) standardEngine() standardEngine { return t.oe }