gorgonia.org/tensor@v0.9.24/defaultengine_matop_stack.go (about)

     1  package tensor
     2  
     3  import (
     4  	"github.com/pkg/errors"
     5  )
     6  
     7  // This file contains code for the execution engine to stack tensors
     8  
     9  func (e StdEng) StackDense(t DenseTensor, axis int, others ...DenseTensor) (retVal DenseTensor, err error) {
    10  	opdims := t.Dims()
    11  	if axis >= opdims+1 {
    12  		err = errors.Errorf(dimMismatch, opdims+1, axis)
    13  		return
    14  	}
    15  
    16  	newShape := Shape(BorrowInts(opdims + 1))
    17  	newShape[axis] = len(others) + 1
    18  	shape := t.Shape()
    19  	var cur int
    20  	for i, s := range shape {
    21  		if i == axis {
    22  			cur++
    23  		}
    24  		newShape[cur] = s
    25  		cur++
    26  	}
    27  
    28  	info := t.Info()
    29  	var newStrides []int
    30  	if info.o.IsColMajor() {
    31  		newStrides = newShape.CalcStridesColMajor()
    32  	} else {
    33  		newStrides = newShape.CalcStrides()
    34  
    35  	}
    36  	ap := MakeAP(newShape, newStrides, info.o, info.Δ)
    37  
    38  	allNoMat := !t.RequiresIterator()
    39  	for _, ot := range others {
    40  		if allNoMat && ot.RequiresIterator() {
    41  			allNoMat = false
    42  		}
    43  	}
    44  
    45  	retVal = recycledDense(t.Dtype(), ap.Shape(), WithEngine(e))
    46  	retVal.setAP(&ap)
    47  
    48  	// the "viewStack" method is the more generalized method
    49  	// and will work for all Tensors, regardless of whether it's a view
    50  	// But the simpleStack is faster, and is an optimization
    51  
    52  	if allNoMat {
    53  		retVal = e.denseSimpleStack(t, retVal, axis, others)
    54  	} else {
    55  		retVal, err = e.denseViewStack(t, retVal, axis, others)
    56  	}
    57  	return
    58  }
    59  
    60  func (e StdEng) denseSimpleStack(t, retVal DenseTensor, axis int, others []DenseTensor) DenseTensor {
    61  	switch axis {
    62  	case 0:
    63  		copyDense(retVal, t)
    64  		next := t.len()
    65  		for _, ot := range others {
    66  			copyDenseSliced(retVal, next, retVal.len(), ot, 0, ot.len())
    67  			next += ot.len()
    68  		}
    69  	default:
    70  		axisStride := retVal.Info().Strides()[axis]
    71  		batches := retVal.len() / axisStride
    72  
    73  		destStart := 0
    74  		start := 0
    75  		end := start + axisStride
    76  
    77  		for i := 0; i < batches; i++ {
    78  			copyDenseSliced(retVal, destStart, retVal.len(), t, start, end)
    79  			for _, ot := range others {
    80  				destStart += axisStride
    81  				copyDenseSliced(retVal, destStart, retVal.len(), ot, start, end)
    82  				i++
    83  			}
    84  			destStart += axisStride
    85  			start += axisStride
    86  			end += axisStride
    87  		}
    88  	}
    89  	return retVal
    90  }
    91  
    92  func (e StdEng) denseViewStack(t, retVal DenseTensor, axis int, others []DenseTensor) (DenseTensor, error) {
    93  	axisStride := retVal.Info().Strides()[axis]
    94  	batches := retVal.len() / axisStride
    95  
    96  	it := IteratorFromDense(t)
    97  	its := make([]Iterator, 0, len(others))
    98  	for _, ot := range others {
    99  		oter := IteratorFromDense(ot)
   100  		its = append(its, oter)
   101  	}
   102  
   103  	err := e.doViewStack(t, retVal, axisStride, batches, it, others, its)
   104  	return retVal, err
   105  }
   106  
   107  func (e StdEng) doViewStack(t, retVal DenseTensor, axisStride, batches int, it Iterator, others []DenseTensor, its []Iterator) error {
   108  	switch int(t.Dtype().Size()) {
   109  	case 1:
   110  		return e.doViewStack1(t, retVal, axisStride, batches, it, others, its)
   111  	case 2:
   112  		return e.doViewStack2(t, retVal, axisStride, batches, it, others, its)
   113  	case 4:
   114  		return e.doViewStack4(t, retVal, axisStride, batches, it, others, its)
   115  	case 8:
   116  		return e.doViewStack8(t, retVal, axisStride, batches, it, others, its)
   117  	default:
   118  		return e.doViewStackArbitrary(t, retVal, axisStride, batches, it, others, its)
   119  	}
   120  }
   121  
   122  func (e StdEng) doViewStack1(t, retVal DenseTensor, axisStride, batches int, it Iterator, others []DenseTensor, its []Iterator) (err error) {
   123  	data := retVal.hdr().Uint8s()[:0]
   124  	var mask []bool
   125  	var retIsMasked bool
   126  	if mt, ok := t.(MaskedTensor); ok {
   127  		retIsMasked = mt.IsMasked()
   128  	}
   129  	for _, ot := range others {
   130  		if mt, ok := ot.(MaskedTensor); ok {
   131  			retIsMasked = retIsMasked || mt.IsMasked()
   132  		}
   133  	}
   134  
   135  	f := func(t DenseTensor, it Iterator) (last int, isMasked bool, err error) {
   136  		var tmask []bool
   137  		if mt, ok := t.(MaskedTensor); ok {
   138  			tmask = mt.Mask()
   139  			isMasked = mt.IsMasked()
   140  		}
   141  
   142  		for last = 0; last < axisStride; last++ {
   143  			id, err := it.Next()
   144  			if handleNoOp(err) != nil {
   145  				return -1, isMasked, errors.Wrap(err, "doviewStackfailed")
   146  			}
   147  			if err != nil {
   148  				break
   149  			}
   150  			data = append(data, t.hdr().Uint8s()[id])
   151  			if isMasked {
   152  				mask = append(mask, tmask[id])
   153  			}
   154  		}
   155  		return
   156  	}
   157  
   158  	for i := 0; i < batches; i++ {
   159  		var last int
   160  		var isMasked bool
   161  		if last, isMasked, err = f(t, it); err != nil {
   162  			return
   163  		}
   164  		if retIsMasked && (!isMasked) {
   165  			mask = append(mask, make([]bool, last)...)
   166  		}
   167  		for j, ot := range others {
   168  			if last, isMasked, err = f(ot, its[j]); err != nil {
   169  				return
   170  			}
   171  			if retIsMasked && (!isMasked) {
   172  				mask = append(mask, make([]bool, last)...)
   173  			}
   174  		}
   175  	}
   176  
   177  	if mt, ok := retVal.(MaskedTensor); ok {
   178  		mt.SetMask(mask)
   179  	}
   180  	return nil
   181  }
   182  
   183  func (e StdEng) doViewStack2(t, retVal DenseTensor, axisStride, batches int, it Iterator, others []DenseTensor, its []Iterator) (err error) {
   184  	data := retVal.hdr().Uint16s()[:0]
   185  	var mask []bool
   186  	var retIsMasked bool
   187  	if mt, ok := t.(MaskedTensor); ok {
   188  		retIsMasked = mt.IsMasked()
   189  	}
   190  	for _, ot := range others {
   191  		if mt, ok := ot.(MaskedTensor); ok {
   192  			retIsMasked = retIsMasked || mt.IsMasked()
   193  		}
   194  	}
   195  
   196  	f := func(t DenseTensor, it Iterator) (last int, isMasked bool, err error) {
   197  		var tmask []bool
   198  		if mt, ok := t.(MaskedTensor); ok {
   199  			tmask = mt.Mask()
   200  			isMasked = mt.IsMasked()
   201  		}
   202  
   203  		for last = 0; last < axisStride; last++ {
   204  			id, err := it.Next()
   205  			if handleNoOp(err) != nil {
   206  				return -1, isMasked, errors.Wrap(err, "doviewStackfailed")
   207  			}
   208  			if err != nil {
   209  				break
   210  			}
   211  			data = append(data, t.hdr().Uint16s()[id])
   212  			if isMasked {
   213  				mask = append(mask, tmask[id])
   214  			}
   215  		}
   216  		return
   217  	}
   218  
   219  	for i := 0; i < batches; i++ {
   220  		var last int
   221  		var isMasked bool
   222  		if last, isMasked, err = f(t, it); err != nil {
   223  			return
   224  		}
   225  		if retIsMasked && (!isMasked) {
   226  			mask = append(mask, make([]bool, last)...)
   227  		}
   228  		for j, ot := range others {
   229  			if last, isMasked, err = f(ot, its[j]); err != nil {
   230  				return
   231  			}
   232  			if retIsMasked && (!isMasked) {
   233  				mask = append(mask, make([]bool, last)...)
   234  			}
   235  		}
   236  	}
   237  
   238  	if mt, ok := retVal.(MaskedTensor); ok {
   239  		mt.SetMask(mask)
   240  	}
   241  	return nil
   242  }
   243  
   244  func (e StdEng) doViewStack4(t, retVal DenseTensor, axisStride, batches int, it Iterator, others []DenseTensor, its []Iterator) (err error) {
   245  	data := retVal.hdr().Uint32s()[:0]
   246  	var mask []bool
   247  	var retIsMasked bool
   248  	if mt, ok := t.(MaskedTensor); ok {
   249  		retIsMasked = mt.IsMasked()
   250  	}
   251  	for _, ot := range others {
   252  		if mt, ok := ot.(MaskedTensor); ok {
   253  			retIsMasked = retIsMasked || mt.IsMasked()
   254  		}
   255  	}
   256  
   257  	f := func(t DenseTensor, it Iterator) (last int, isMasked bool, err error) {
   258  		var tmask []bool
   259  		if mt, ok := t.(MaskedTensor); ok {
   260  			tmask = mt.Mask()
   261  			isMasked = mt.IsMasked()
   262  		}
   263  
   264  		for last = 0; last < axisStride; last++ {
   265  			id, err := it.Next()
   266  			if handleNoOp(err) != nil {
   267  				return -1, isMasked, errors.Wrap(err, "doviewStackfailed")
   268  			}
   269  			if err != nil {
   270  				break
   271  			}
   272  			data = append(data, t.hdr().Uint32s()[id])
   273  			if isMasked {
   274  				mask = append(mask, tmask[id])
   275  			}
   276  		}
   277  		return
   278  	}
   279  
   280  	for i := 0; i < batches; i++ {
   281  		var last int
   282  		var isMasked bool
   283  		if last, isMasked, err = f(t, it); err != nil {
   284  			return
   285  		}
   286  		if retIsMasked && (!isMasked) {
   287  			mask = append(mask, make([]bool, last)...)
   288  		}
   289  		for j, ot := range others {
   290  			if last, isMasked, err = f(ot, its[j]); err != nil {
   291  				return
   292  			}
   293  			if retIsMasked && (!isMasked) {
   294  				mask = append(mask, make([]bool, last)...)
   295  			}
   296  		}
   297  	}
   298  
   299  	if mt, ok := retVal.(MaskedTensor); ok {
   300  		mt.SetMask(mask)
   301  	}
   302  	return nil
   303  }
   304  
   305  func (e StdEng) doViewStack8(t, retVal DenseTensor, axisStride, batches int, it Iterator, others []DenseTensor, its []Iterator) (err error) {
   306  	data := retVal.hdr().Uint64s()[:0]
   307  	var mask []bool
   308  	var retIsMasked bool
   309  	if mt, ok := t.(MaskedTensor); ok {
   310  		retIsMasked = mt.IsMasked()
   311  	}
   312  	for _, ot := range others {
   313  		if mt, ok := ot.(MaskedTensor); ok {
   314  			retIsMasked = retIsMasked || mt.IsMasked()
   315  		}
   316  	}
   317  
   318  	f := func(t DenseTensor, it Iterator) (last int, isMasked bool, err error) {
   319  		var tmask []bool
   320  		if mt, ok := t.(MaskedTensor); ok {
   321  			tmask = mt.Mask()
   322  			isMasked = mt.IsMasked()
   323  		}
   324  
   325  		for last = 0; last < axisStride; last++ {
   326  			id, err := it.Next()
   327  			if handleNoOp(err) != nil {
   328  				return -1, isMasked, errors.Wrap(err, "doviewStackfailed")
   329  			}
   330  			if err != nil {
   331  				break
   332  			}
   333  			data = append(data, t.hdr().Uint64s()[id])
   334  			if isMasked {
   335  				mask = append(mask, tmask[id])
   336  			}
   337  		}
   338  		return
   339  	}
   340  
   341  	for i := 0; i < batches; i++ {
   342  		var last int
   343  		var isMasked bool
   344  		if last, isMasked, err = f(t, it); err != nil {
   345  			return
   346  		}
   347  		if retIsMasked && (!isMasked) {
   348  			mask = append(mask, make([]bool, last)...)
   349  		}
   350  		for j, ot := range others {
   351  			if last, isMasked, err = f(ot, its[j]); err != nil {
   352  				return
   353  			}
   354  			if retIsMasked && (!isMasked) {
   355  				mask = append(mask, make([]bool, last)...)
   356  			}
   357  		}
   358  	}
   359  
   360  	if mt, ok := retVal.(MaskedTensor); ok {
   361  		mt.SetMask(mask)
   362  	}
   363  	return nil
   364  }
   365  
   366  func (e StdEng) doViewStackArbitrary(t, retVal DenseTensor, axisStride, batches int, it Iterator, others []DenseTensor, its []Iterator) (err error) {
   367  	dt := t.Dtype()
   368  	data := retVal.hdr().Raw[:0] // truncate to 0
   369  	size := int(dt.Size())
   370  	var mask []bool
   371  	var retIsMasked bool
   372  	if mt, ok := t.(MaskedTensor); ok {
   373  		retIsMasked = mt.IsMasked()
   374  	}
   375  	for _, ot := range others {
   376  		if mt, ok := ot.(MaskedTensor); ok {
   377  			retIsMasked = retIsMasked || mt.IsMasked()
   378  		}
   379  	}
   380  
   381  	f := func(t DenseTensor, it Iterator) (last int, isMasked bool, err error) {
   382  		var tmask []bool
   383  		if mt, ok := t.(MaskedTensor); ok {
   384  			tmask = mt.Mask()
   385  			isMasked = mt.IsMasked()
   386  		}
   387  		bs := t.hdr().Raw
   388  
   389  		for last = 0; last < axisStride; last++ {
   390  			id, err := it.Next()
   391  			if handleNoOp(err) != nil {
   392  				return -1, isMasked, errors.Wrap(err, "doviewStackfailed")
   393  			}
   394  			if err != nil {
   395  				break
   396  			}
   397  			v := bs[id*size : id*size+size]
   398  			data = append(data, v...)
   399  			if isMasked {
   400  				mask = append(mask, tmask[id])
   401  			}
   402  		}
   403  		return
   404  	}
   405  
   406  	for i := 0; i < batches; i++ {
   407  		var last int
   408  		var isMasked bool
   409  		if last, isMasked, err = f(t, it); err != nil {
   410  			return
   411  		}
   412  		if retIsMasked && (!isMasked) {
   413  			mask = append(mask, make([]bool, last)...)
   414  		}
   415  		for j, ot := range others {
   416  			if last, isMasked, err = f(ot, its[j]); err != nil {
   417  				return
   418  			}
   419  			if retIsMasked && (!isMasked) {
   420  				mask = append(mask, make([]bool, last)...)
   421  			}
   422  		}
   423  	}
   424  
   425  	if mt, ok := retVal.(MaskedTensor); ok {
   426  		mt.SetMask(mask)
   427  	}
   428  	return nil
   429  }