gorgonia.org/tensor@v0.9.24/defaultengine_matop_misc.go

package tensor

import (
	"github.com/pkg/errors"
	"gorgonia.org/tensor/internal/storage"
)

var (
	_ Diager = StdEng{}
)

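// fastcopier is implemented by engines that can perform a dense Repeat by
// copying raw bytes directly, bypassing the generic element-wise copy path.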
type fastcopier interface {
	fastCopyDenseRepeat(t DenseTensor, d *Dense, outers, size, stride, newStride int, repeats []int) error
}

// Repeat repeats the elements of t along the given axis, with the number of
// repetitions per element given by repeats. Only DenseTensors are currently supported.
func (e StdEng) Repeat(t Tensor, axis int, repeats ...int) (Tensor, error) {
	switch tt := t.(type) {
	case DenseTensor:
		newShape, newRepeats, newAxis, size, err := e.denseRepeatCheck(t, axis, repeats)
		if err != nil {
			return nil, err
		}
		rr := recycledDense(t.Dtype(), newShape, WithEngine(StdEng{}))
		return e.denseRepeat(tt, rr, newShape, newAxis, size, newRepeats)
	default:
		return nil, errors.Errorf("NYI")
	}
}
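
// Illustrative sketch (not from the original source): a typical call to Repeat,
// repeating each column of a 2×2 matrix twice along axis 1. New, WithShape and
// WithBacking are the package's usual Dense constructors.
//
//	t := New(WithShape(2, 2), WithBacking([]float64{1, 2, 3, 4}))
//	r, _ := StdEng{}.Repeat(t, 1, 2)
//	// r has shape (2, 4), with backing data [1 1 2 2 3 3 4 4]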

// RepeatReuse is like Repeat, but writes the result into the provided reuse
// Tensor. The reuse tensor must be of the same type as the input t, and must
// already have the repeated shape.
func (e StdEng) RepeatReuse(t Tensor, reuse Tensor, axis int, repeats ...int) (Tensor, error) {
	switch tt := t.(type) {
	case DenseTensor:
		newShape, newRepeats, newAxis, size, err := e.denseRepeatCheck(t, axis, repeats)
		if err != nil {
			return nil, err
		}

		rr, ok := reuse.(DenseTensor)
		if !ok {
			return nil, errors.Errorf("t is a DenseTensor but reuse is of %T", reuse)
		}
		if !reuse.Shape().Eq(newShape) {
			return nil, errors.Errorf("Reuse shape is %v. Expected shape is %v", reuse.Shape(), newShape)
		}
		return e.denseRepeat(tt, rr, newShape, newAxis, size, newRepeats)
	default:
		return nil, errors.Errorf("NYI")
	}
}
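
// Illustrative sketch (not from the original source): RepeatReuse writes into a
// preallocated tensor, which must already have the repeated shape. New, Of,
// WithShape and WithBacking are the package's usual constructors.
//
//	t := New(WithShape(2, 2), WithBacking([]float64{1, 2, 3, 4}))
//	reuse := New(Of(Float64), WithShape(2, 4)) // the repeated shape of t along axis 1
//	r, _ := StdEng{}.RepeatReuse(t, reuse, 1, 2)
//	// r is the same tensor as reuse, with backing data [1 1 2 2 3 3 4 4]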

func (StdEng) denseRepeatCheck(t Tensor, axis int, repeats []int) (newShape Shape, newRepeats []int, newAxis, size int, err error) {
	if newShape, newRepeats, size, err = t.Shape().Repeat(axis, repeats...); err != nil {
		return nil, nil, -1, -1, errors.Wrap(err, "Unable to get repeated shape")
	}
	newAxis = axis
	if axis == AllAxes {
		newAxis = 0
	}

	return
}
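
// Illustrative note (not from the original source): for a tensor of shape (2, 3),
// axis 1 and repeats [2], Shape.Repeat is expected to yield newShape (2, 6),
// newRepeats [2 2 2] (one repeat count per position along the axis, as consumed
// by denseRepeat below) and size 3 (the length of the repeated axis).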

func (StdEng) denseRepeat(t, reuse DenseTensor, newShape Shape, axis, size int, repeats []int) (retVal DenseTensor, err error) {
	d, err := assertDense(reuse)
	if err != nil {
		return nil, errors.Wrapf(err, "Repeat reuse is not a *Dense")
	}
	var outers int
	if t.IsScalar() {
		outers = 1
	} else {
		outers = ProdInts(t.Shape()[0:axis])
	}

	var stride, newStride int
	if newShape.IsVector() || t.IsVector() {
		stride = 1 // special case because CalcStrides() will return []int{1} as the strides for a vector
	} else {
		stride = t.ostrides()[axis]
	}

	if newShape.IsVector() {
		newStride = 1
	} else {
		newStride = d.ostrides()[axis]
	}
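
	// At this point:
	//   outers    - the number of independent blocks that sit above the repeat axis
	//   size      - the length of the repeat axis (repeats holds one count per position)
	//   stride    - how far, in elements, one step along that axis moves in the source
	//   newStride - the same distance in the destination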

	var destStart, srcStart int
	// fastCopy indicates whether we can bypass the generic copyDenseSliced path
	// and populate the output tensor with raw byte copies instead.
	var fastCopy bool
	var fce fastcopier
	// fast copying requires the source tensor's engine to implement fastcopier.
	e := t.Engine()
	// e can never be nil here; an error would have occurred earlier otherwise.
	var ok bool
	if fce, ok = e.(fastcopier); ok {
		fastCopy = true
	}

	// Masked tensors do not take the fast copy path, to keep the code readable.
	if ms, ok := t.(MaskedTensor); ok && ms.IsMasked() {
		fastCopy = false
	}

	// if d's engine is not a fastcopier, then we also cannot use fast copy.
	if _, ok := d.Engine().(fastcopier); !ok {
		fastCopy = false
	}
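
	// fastCopy is now true only when t is unmasked and both t's engine and d's
	// engine implement fastcopier.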

	if fastCopy {
		if err := fce.fastCopyDenseRepeat(t, d, outers, size, stride, newStride, repeats); err != nil {
			return nil, err
		}
		return d, nil
	}

	for i := 0; i < outers; i++ {
		for j := 0; j < size; j++ {
			tmp := repeats[j]

			for k := 0; k < tmp; k++ {
				if srcStart >= t.len() || destStart+stride > d.len() {
					break
				}
				copyDenseSliced(d, destStart, d.len(), t, srcStart, t.len())
				destStart += newStride
			}
			srcStart += stride
		}
	}
	return d, nil
}

func (e StdEng) fastCopyDenseRepeat(src DenseTensor, dest *Dense, outers, size, stride, newStride int, repeats []int) error {
	sarr := src.arr()
	darr := dest.arr()

	var destStart, srcStart int
	for i := 0; i < outers; i++ {
		// faster shortcut for the common case.
		//
		// Consider a case where:
		// 	a := ⎡ 1 ⎤
		//	     ⎢ 2 ⎥
		//	     ⎢ 3 ⎥
		//	     ⎣ 4 ⎦
		// a has a shape of (4, 1). It is a *Dense.
		//
		// Now assume we want to repeat it on axis 1, 3 times. We want to repeat it into `b`,
		// which is already allocated and zeroed, as shown below:
		//
		// 	b := ⎡ 0 0 0 ⎤
		//	     ⎢ 0 0 0 ⎥
		//	     ⎢ 0 0 0 ⎥
		//	     ⎣ 0 0 0 ⎦
		//
		// Now, both `a` and `b` have a stride of 1.
		//
		// The desired result is:
		// 	b := ⎡ 1 1 1 ⎤
		//	     ⎢ 2 2 2 ⎥
		//	     ⎢ 3 3 3 ⎥
		//	     ⎣ 4 4 4 ⎦
		//
		// Observe that this is simply broadcasting (copying) a[0] (a scalar value) to the row b[0], and so on.
		// This can be done without knowing the full type - we simply copy the bytes over.
		if stride == 1 && newStride == 1 {
			// we can broadcast straight away, one byte range at a time.
			for sz := 0; sz < size; sz++ {
				tmp := repeats[sz]

				// first we get the bounds of the src and the dest.
				// srcStart and destStart are indices into a flat array of []T;
				// here we compute their byte-slice equivalents.
				bSrcStart := srcStart * int(sarr.t.Size())
				bSrcEnd := (srcStart + stride) * int(sarr.t.Size())
				bDestStart := destStart * int(darr.t.Size())
				bDestEnd := (destStart + tmp) * int(darr.t.Size())

				// then we get the data as a slice of raw bytes
				sBS := sarr.Header.Raw
				dBS := darr.Header.Raw

				// recall that len(src) < len(dest).
				// defining the ranges explicitly makes this easier to follow and less error-prone.
				sRange := sBS[bSrcStart:bSrcEnd]
				dRange := dBS[bDestStart:bDestEnd]

				// finally we copy: the source range is tiled across the destination range.
				for i := 0; i < len(dRange); i += len(sRange) {
					copy(dRange[i:], sRange)
				}
				srcStart += stride
				destStart += tmp
			}

			continue
		}

		for j := 0; j < size; j++ {
			tmp := repeats[j]
			tSlice := sarr.slice(srcStart, src.len())

			for k := 0; k < tmp; k++ {
				if srcStart >= src.len() || destStart+stride > dest.len() {
					break
				}

				dSlice := darr.slice(destStart, destStart+newStride)

				// THIS IS AN OPTIMIZATION. REVISIT WHEN NEEDED.
				storage.Copy(dSlice.t.Type, &dSlice.Header, &tSlice.Header)

				destStart += newStride
			}
			srcStart += stride
		}
	}
	return nil
}

// Concat concatenates t with the other given tensors along the specified axis.
// Only DenseTensors are currently supported.
func (e StdEng) Concat(t Tensor, axis int, others ...Tensor) (retVal Tensor, err error) {
	switch tt := t.(type) {
	case DenseTensor:
		var denses []DenseTensor
		if denses, err = tensorsToDenseTensors(others); err != nil {
			return nil, errors.Wrap(err, "Concat failed")
		}
		return e.denseConcat(tt, axis, denses)
	default:
		return nil, errors.Errorf("NYI")
	}
}
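
// Illustrative sketch (not from the original source): concatenating two 2×3
// matrices along axis 0 yields a 4×3 matrix. New, WithShape and WithBacking are
// the package's usual Dense constructors.
//
//	a := New(WithShape(2, 3), WithBacking([]float64{1, 2, 3, 4, 5, 6}))
//	b := New(WithShape(2, 3), WithBacking([]float64{7, 8, 9, 10, 11, 12}))
//	c, _ := StdEng{}.Concat(a, 0, b)
//	// c has shape (4, 3); its first two rows come from a, the last two from b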

func (e StdEng) denseConcat(a DenseTensor, axis int, Ts []DenseTensor) (DenseTensor, error) {
	ss := make([]Shape, len(Ts))
	var err error
	var isMasked bool
	for i, T := range Ts {
		ss[i] = T.Shape()
		if mt, ok := T.(MaskedTensor); ok {
			isMasked = isMasked || mt.IsMasked()
		}
	}

	var newShape Shape
	if newShape, err = a.Shape().Concat(axis, ss...); err != nil {
		return nil, errors.Wrap(err, "Unable to find new shape that results from concatenation")
	}

	retVal := recycledDense(a.Dtype(), newShape, WithEngine(e))
	if isMasked {
		retVal.makeMask()
	}

	all := make([]DenseTensor, len(Ts)+1)
	all[0] = a
	copy(all[1:], Ts)

	// TODO: OPTIMIZATION
	// When (axis == 0 && a is row major && all others are row major) || (axis == the last axis of a && all tensors are col major),
	// just do a flat copy.

	// isOuter is true when the axis is the outermost axis.
	// isInner is true when the axis is the innermost axis.
	isOuter := axis == 0
	isInner := axis == (a.Shape().Dims() - 1)

	// slice out the destination region for each tensor along the axis, then assign that tensor into the slice.
	var start, end int
	for _, T := range all {
		end += T.Shape()[axis]
		slices := make([]Slice, axis+1)
		slices[axis] = makeRS(start, end)

		var v *Dense
		if v, err = sliceDense(retVal, slices...); err != nil {
			return nil, errors.Wrap(err, "Unable to slice DenseTensor while performing denseConcat")
		}

		// keep dims after slicing
		switch {
		case v.IsVector() && T.IsMatrix() && axis == 0:
			v.reshape(v.shape[0], 1)
		case T.IsRowVec() && axis == 0:
			T.reshape(T.Shape()[1])
		case v.Shape().IsScalarEquiv() && T.Shape().IsScalarEquiv():
			copyArray(v.arrPtr(), T.arrPtr())
			if mt, ok := T.(MaskedTensor); ok {
				copy(v.mask, mt.Mask())
			}
			start = end
			continue
		default:
			diff := retVal.Shape().Dims() - v.Shape().Dims()
			if diff > 0 && isOuter {
				newShape := make(Shape, v.Shape().Dims()+diff)
				for i := 0; i < diff; i++ {
					newShape[i] = 1
				}
				copy(newShape[diff:], v.Shape())
				v.reshape(newShape...)
			} else if diff > 0 && isInner {
				newShape := v.Shape().Clone()
				newStrides := v.strides
				for i := 0; i < diff; i++ {
					newShape = append(newShape, 1)
					newStrides = append(newStrides, 1)
				}
				v.shape = newShape
				v.strides = newStrides
			} else if T.Shape()[axis] == 1 {
				if err := v.unsqueeze(axis); err != nil {
					return nil, errors.Wrapf(err, "Unable to keep dims after slicing a shape %v on axis %d where the size is 1", T.Shape(), axis)
				}
			}
		}

		var vmask, Tmask []bool
		vmask = v.mask
		v.mask = nil
		if mt, ok := T.(MaskedTensor); ok && mt.IsMasked() {
			Tmask = mt.Mask()
			mt.SetMask(nil)
		}

		if err = assignArray(v, T); err != nil {
			return nil, errors.Wrap(err, "Unable to assignArray in denseConcat")
		}
		// if it's a masked tensor, we copy the mask as well
		if Tmask != nil {
			if vmask != nil {
				if cap(vmask) < len(Tmask) {
					vmask2 := make([]bool, len(Tmask))
					copy(vmask2, vmask)
					vmask = vmask2
				}
				copy(vmask, Tmask)
				v.SetMask(vmask)
			}
			// mt.SetMask(Tmask)
		}

		start = end
	}

	return retVal, nil
}
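
// Illustrative walkthrough (not from the original source) of the loop above:
// concatenating a (2, 3) tensor with another (2, 3) tensor along axis 0 gives
// retVal of shape (4, 3). For the first tensor, makeRS(0, 2) slices rows 0-1 of
// retVal and assignArray copies that tensor into the view; for the second,
// makeRS(2, 4) slices rows 2-3. The "keep dims" switch only intervenes when
// slicing collapses dimensions, e.g. when the slice is a vector or the size
// along the concatenation axis is 1.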

// Diag copies the main diagonal of the 2-D tensor t into the leading elements
// of a zeroed clone of t. Only DenseTensors with numeric element types are
// supported.
func (e StdEng) Diag(t Tensor) (retVal Tensor, err error) {
	a, ok := t.(DenseTensor)
	if !ok {
		return nil, errors.Errorf("StdEng only works with DenseTensor for Diagonal()")
	}

	if a.Dims() != 2 {
		err = errors.Errorf(dimMismatch, 2, a.Dims())
		return
	}

	if err = typeclassCheck(a.Dtype(), numberTypes); err != nil {
		return nil, errors.Wrap(err, "Diagonal")
	}

	rstride := a.Strides()[0]
	cstride := a.Strides()[1]

	r := a.Shape()[0]
	c := a.Shape()[1]

	m := MinInt(r, c)
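	// stepping one row down and one column right at a time walks the main diagonal,
	// so consecutive diagonal elements are rstride+cstride apart in the flat data.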
	stride := rstride + cstride

	b := a.Clone().(DenseTensor)
	b.Zero()

	switch a.rtype().Size() {
	case 1:
		bdata := b.hdr().Uint8s()
		adata := a.hdr().Uint8s()
		for i := 0; i < m; i++ {
			bdata[i] = adata[i*stride]
		}
	case 2:
		bdata := b.hdr().Uint16s()
		adata := a.hdr().Uint16s()
		for i := 0; i < m; i++ {
			bdata[i] = adata[i*stride]
		}
	case 4:
		bdata := b.hdr().Uint32s()
		adata := a.hdr().Uint32s()
		for i := 0; i < m; i++ {
			bdata[i] = adata[i*stride]
		}
	case 8:
		bdata := b.hdr().Uint64s()
		adata := a.hdr().Uint64s()
		for i := 0; i < m; i++ {
			bdata[i] = adata[i*stride]
		}
	default:
		return nil, errors.Errorf(typeNYI, "Arbitrary sized diag", t)
	}
	return b, nil
}
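
// Note on the result, derived from the code above: the returned tensor is a
// zeroed clone of the input, with the min(r, c) diagonal values written into
// its leading elements. For example, for a 2×2 input [[1, 2], [3, 4]], the
// result has shape (2, 2) and backing data [1 4 0 0].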