github.com/wzzhu/tensor@v0.9.24/dense_assign.go (about)

     1  package tensor
     2  
     3  import (
     4  	"github.com/pkg/errors"
     5  )
     6  
     7  func overlaps(a, b DenseTensor) bool {
     8  	if a.cap() == 0 || b.cap() == 0 {
     9  		return false
    10  	}
    11  	aarr := a.arr()
    12  	barr := b.arr()
    13  	if aarr.Uintptr() == barr.Uintptr() {
    14  		return true
    15  	}
    16  	aptr := aarr.Uintptr()
    17  	bptr := barr.Uintptr()
    18  
    19  	capA := aptr + uintptr(cap(aarr.Header.Raw))
    20  	capB := bptr + uintptr(cap(barr.Header.Raw))
    21  
    22  	switch {
    23  	case aptr < bptr:
    24  		if bptr < capA {
    25  			return true
    26  		}
    27  	case aptr > bptr:
    28  		if aptr < capB {
    29  			return true
    30  		}
    31  	}
    32  	return false
    33  }
    34  
    35  func assignArray(dest, src DenseTensor) (err error) {
    36  	// var copiedSrc bool
    37  
    38  	if src.IsScalar() {
    39  		panic("HELP")
    40  	}
    41  
    42  	dd := dest.Dims()
    43  	sd := src.Dims()
    44  
    45  	dstrides := dest.Strides()
    46  	sstrides := src.Strides()
    47  
    48  	var ds, ss int
    49  	ds = dstrides[0]
    50  	if src.IsVector() {
    51  		ss = sstrides[0]
    52  	} else {
    53  		ss = sstrides[sd-1]
    54  	}
    55  
    56  	// when dd == 1, and the strides point in the same direction
    57  	// we copy to a temporary if there is an overlap of data
    58  	if ((dd == 1 && sd >= 1 && ds*ss < 0) || dd > 1) && overlaps(dest, src) {
    59  		// create temp
    60  		// copiedSrc = true
    61  	}
    62  
    63  	// broadcast src to dest for raw iteration
    64  	tmpShape := Shape(BorrowInts(sd))
    65  	tmpStrides := BorrowInts(len(src.Strides()))
    66  	copy(tmpShape, src.Shape())
    67  	copy(tmpStrides, src.Strides())
    68  	defer ReturnInts(tmpShape)
    69  	defer ReturnInts(tmpStrides)
    70  
    71  	if sd > dd {
    72  		tmpDim := sd
    73  		for tmpDim > dd && tmpShape[0] == 1 {
    74  			tmpDim--
    75  
    76  			// this is better than tmpShape = tmpShape[1:]
    77  			// because we are going to return these ints later
    78  			copy(tmpShape, tmpShape[1:])
    79  			copy(tmpStrides, tmpStrides[1:])
    80  		}
    81  	}
    82  
    83  	var newStrides []int
    84  	if newStrides, err = BroadcastStrides(dest.Shape(), tmpShape, dstrides, tmpStrides); err != nil {
    85  		err = errors.Wrapf(err, "BroadcastStrides failed")
    86  		return
    87  	}
    88  	dap := dest.Info()
    89  	sap := MakeAP(tmpShape, newStrides, src.Info().o, src.Info().Δ)
    90  
    91  	diter := newFlatIterator(dap)
    92  	siter := newFlatIterator(&sap)
    93  	_, err = copyDenseIter(dest, src, diter, siter)
    94  	sap.zeroOnly() // cleanup, but not entirely because tmpShape and tmpStrides are separately cleaned up.  Don't double free
    95  	return
    96  }