github.com/wzzhu/tensor@v0.9.24/dense_assign.go (about) 1 package tensor 2 3 import ( 4 "github.com/pkg/errors" 5 ) 6 7 func overlaps(a, b DenseTensor) bool { 8 if a.cap() == 0 || b.cap() == 0 { 9 return false 10 } 11 aarr := a.arr() 12 barr := b.arr() 13 if aarr.Uintptr() == barr.Uintptr() { 14 return true 15 } 16 aptr := aarr.Uintptr() 17 bptr := barr.Uintptr() 18 19 capA := aptr + uintptr(cap(aarr.Header.Raw)) 20 capB := bptr + uintptr(cap(barr.Header.Raw)) 21 22 switch { 23 case aptr < bptr: 24 if bptr < capA { 25 return true 26 } 27 case aptr > bptr: 28 if aptr < capB { 29 return true 30 } 31 } 32 return false 33 } 34 35 func assignArray(dest, src DenseTensor) (err error) { 36 // var copiedSrc bool 37 38 if src.IsScalar() { 39 panic("HELP") 40 } 41 42 dd := dest.Dims() 43 sd := src.Dims() 44 45 dstrides := dest.Strides() 46 sstrides := src.Strides() 47 48 var ds, ss int 49 ds = dstrides[0] 50 if src.IsVector() { 51 ss = sstrides[0] 52 } else { 53 ss = sstrides[sd-1] 54 } 55 56 // when dd == 1, and the strides point in the same direction 57 // we copy to a temporary if there is an overlap of data 58 if ((dd == 1 && sd >= 1 && ds*ss < 0) || dd > 1) && overlaps(dest, src) { 59 // create temp 60 // copiedSrc = true 61 } 62 63 // broadcast src to dest for raw iteration 64 tmpShape := Shape(BorrowInts(sd)) 65 tmpStrides := BorrowInts(len(src.Strides())) 66 copy(tmpShape, src.Shape()) 67 copy(tmpStrides, src.Strides()) 68 defer ReturnInts(tmpShape) 69 defer ReturnInts(tmpStrides) 70 71 if sd > dd { 72 tmpDim := sd 73 for tmpDim > dd && tmpShape[0] == 1 { 74 tmpDim-- 75 76 // this is better than tmpShape = tmpShape[1:] 77 // because we are going to return these ints later 78 copy(tmpShape, tmpShape[1:]) 79 copy(tmpStrides, tmpStrides[1:]) 80 } 81 } 82 83 var newStrides []int 84 if newStrides, err = BroadcastStrides(dest.Shape(), tmpShape, dstrides, tmpStrides); err != nil { 85 err = errors.Wrapf(err, "BroadcastStrides failed") 86 return 87 } 88 dap := dest.Info() 89 sap := MakeAP(tmpShape, newStrides, src.Info().o, src.Info().Δ) 90 91 diter := newFlatIterator(dap) 92 siter := newFlatIterator(&sap) 93 _, err = copyDenseIter(dest, src, diter, siter) 94 sap.zeroOnly() // cleanup, but not entirely because tmpShape and tmpStrides are separately cleaned up. Don't double free 95 return 96 }