github.com/wzzhu/tensor@v0.9.24/defaultengine_selbyidx.go

package tensor

import (
	"reflect"

	"github.com/pkg/errors"
	"github.com/wzzhu/tensor/internal/storage"
)

// SelectByIndices selects the values of `a` along the given axis, at the
// positions listed in the `indices` tensor.
//
// Currently SelectByIndices only supports Dense tensors that do not require the use of iterators.
// Please make a pull request to support tensors that require the use of an iterator to traverse data.
func (e StdEng) SelectByIndices(a, indices Tensor, axis int, opts ...FuncOpt) (retVal Tensor, err error) {
	if !indices.Shape().IsVectorLike() {
		return nil, errors.Errorf("Expected indices to be a vector. Got %v instead", indices.Shape())
	}
	if indices.Dtype() != Int {
		return nil, errors.Errorf("Expected indices to be a vector of ints. Got %v instead", indices.Dtype())
	}

	// if a is scalar-equivalent, then use Slice
	if a.Shape().IsScalarEquiv() {
		slices := make([]Slice, a.Shape().Dims())
		slices[axis] = ss(getInts(indices)[0])
		return a.Slice(slices...)
	}

	// the result has the same shape as a, except along the selection axis,
	// where it has as many entries as there are indices.
	expectedShape := a.Shape().Clone()
	expectedShape[axis] = indices.Shape().TotalSize()

	var reuse DenseTensor
	var safe, toReuse, _ bool
	if reuse, safe, toReuse, _, _, err = handleFuncOpts(expectedShape, a.Dtype(), a.DataOrder(), true, opts...); err != nil {
		return nil, errors.Wrap(err, "Unable to handle funcOpts")
	}
	if safe || !toReuse && reuse == nil && safe {
		// create reuse
		reuse = New(WithShape(expectedShape...), Of(a.Dtype()))
	}

	if !safe {
		// unsafe: the result is written back into a, so a must already have
		// the expected result shape along the selection axis.
		if a.Shape()[axis] != indices.Shape().TotalSize() {
			expected := a.Shape().Clone()
			expected[axis] = indices.Shape().TotalSize()
			return nil, errors.Errorf("Expected an unsafe reuse to have the same shape as the expected result: %v. The input a has shape %v", expected, a.Shape())
		}

		reuse = a.(DenseTensor)
	}

	typ := a.Dtype().Type
	var dataA, dataB, dataReuse *storage.Header
	var ait, bit, iit Iterator
	var useIter bool
	if dataA, dataB, dataReuse, ait, bit, iit, useIter, _, err = prepDataVV(a, indices, reuse); err != nil {
		return nil, errors.Wrapf(err, "StdEng.SelectByIndices")
	}

	if useIter {
		e.iterSelectByIdx(axis, dataA, dataB, dataReuse, ait, bit, iit)
		// TODO: iterator-based selection is not yet implemented (see iterSelectByIdx).
		return
	}

	e.selectByIdx(axis, dataB.Ints(), typ, dataA, dataReuse, a.(*Dense).AP, reuse.(*Dense).AP)
	return reuse, nil
}
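// selectByIndicesExample is an illustrative sketch and is not part of the
// original file. It shows the intended use of SelectByIndices, assuming the
// package-level New, WithShape and WithBacking constructors behave as they do
// elsewhere in this package: rows 0 and 2 of a 3×2 matrix are gathered along
// axis 0, yielding a 2×2 result.
func selectByIndicesExample() (Tensor, error) {
	a := New(WithShape(3, 2), WithBacking([]float64{
		1, 2,
		3, 4,
		5, 6,
	}))
	indices := New(WithShape(2), WithBacking([]int{0, 2}))

	// Expected result: [[1 2] [5 6]].
	return StdEng{}.SelectByIndices(a, indices, 0)
}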
func (e StdEng) iterSelectByIdx(axis int, dataA, dataB, dataReuse *storage.Header, ait, bit, iit Iterator) {
	panic("iterSelectByIdx is not yet implemented")
}

// selectByIdx gathers the slices of dataA named by `indices` along `axis` and
// copies them into dataRetVal, using the access patterns of the source (apA)
// and of the result (apRet) to compute flat offsets.
func (e StdEng) selectByIdx(axis int, indices []int, typ reflect.Type, dataA, dataRetVal *storage.Header, apA, apRet AP) {
	isInnermost := axis == apA.shape.Dims()-1

	outer := ProdInts(apA.shape[:axis])

	axStride := apA.strides[axis]
	retStride := apRet.strides[axis]
	var outerRetStride int
	if axis == 0 {
		// then it's the outermost axis
		outerRetStride = apRet.strides[axis] * 2
	} else {
		outerRetStride = apRet.strides[axis-1]
	}

	srcCoord := make([]int, apA.shape.Dims())
	dstCoord := make([]int, apRet.shape.Dims())

	if isInnermost {
		prevAxis := axis - 1
		if prevAxis < 0 {
			// this may be the case if the input is a vector
			prevAxis = 0
		}
		prevStride := apA.strides[prevAxis]
		retPrevStride := apRet.strides[prevAxis]
		for i, idx := range indices {
			srcCoord[axis] = idx
			dstCoord[axis] = i
			start, _ := Ltoi(apA.shape, apA.strides, srcCoord...)
			dstStart, _ := Ltoi(apRet.shape, apRet.strides, dstCoord...)
			for o := 0; o < outer; o++ {
				end := start + axStride
				dstEnd := dstStart + retStride

				storage.CopySliced(typ, dataRetVal, dstStart, dstEnd, dataA, start, end)

				start += prevStride
				dstStart += retPrevStride
			}
		}
		return
	}

	for i, idx := range indices {
		srcCoord[axis] = idx
		dstCoord[axis] = i
		start, _ := Ltoi(apA.shape, apA.strides, srcCoord...)
		dstStart, _ := Ltoi(apRet.shape, apRet.strides, dstCoord...)

		for o := 0; o < outer; o++ {
			end := start + axStride
			dstEnd := dstStart + retStride

			storage.CopySliced(typ, dataRetVal, dstStart, dstEnd, dataA, start, end)

			start = end + axStride
			dstStart = dstEnd + (outerRetStride - retStride)
		}
	}
}

// SelectByIndicesB computes the gradient of the result of `SelectByIndices`.
//
// Currently SelectByIndicesB only supports Dense tensors that do not require the use of iterators.
// Please make a pull request to support tensors that require the use of an iterator to traverse data.
func (e StdEng) SelectByIndicesB(input, outGrad, indices Tensor, axis int, opts ...FuncOpt) (retVal Tensor, err error) {
	if !indices.Shape().IsVectorLike() {
		return nil, errors.Errorf("Expected indices to be a vector. Got %v instead", indices.Shape())
	}
	if indices.Dtype() != Int {
		return nil, errors.Errorf("Expected indices to be a vector of ints. Got %v instead", indices.Dtype())
	}

	// if input is scalar-equivalent, then use Slice
	if input.Shape().IsScalarEquiv() {
		slices := make([]Slice, input.Shape().Dims())
		slices[axis] = ss(getInts(indices)[0])
		return input.Slice(slices...)
	}

	expectedShape := input.Shape().Clone()

	var reuse DenseTensor
	var _, toReuse, _ bool
	if reuse, _, toReuse, _, _, err = handleFuncOpts(input.Shape(), input.Dtype(), input.DataOrder(), true, opts...); err != nil {
		return nil, errors.Wrap(err, "Unable to handle funcOpts")
	}
	if !toReuse && reuse == nil {
		// create reuse
		reuse = New(WithShape(expectedShape...), Of(input.Dtype()))
	}

	typ := input.Dtype().Type
	var _, dataB, dataReuse *storage.Header
	var _, bit, iit Iterator
	var useIter bool
	if _, dataB, dataReuse, _, bit, iit, useIter, _, err = prepDataVV(input, outGrad, reuse); err != nil {
		return nil, errors.Wrapf(err, "StdEng.SelectByIndicesB")
	}

	if useIter {
		e.iterSelectByIndicesB(axis, dataB, dataReuse, bit, iit)
		// TODO: iterator-based gradient accumulation is not yet implemented (see iterSelectByIndicesB).
		return
	}

	e.selectByIndicesB(axis, getInts(indices), typ, dataB, dataReuse, outGrad.(*Dense).AP, reuse.(*Dense).AP)

	return reuse, nil
}
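// selectByIndicesBExample is an illustrative sketch and is not part of the
// original file. It pairs with the forward example above, assuming a gradient
// of ones flowing back from the 2×2 selection of a 3×2 input: rows 0 and 2 of
// the returned input gradient receive the corresponding rows of outGrad,
// while row 1 stays zero.
func selectByIndicesBExample() (Tensor, error) {
	input := New(WithShape(3, 2), WithBacking([]float64{1, 2, 3, 4, 5, 6}))
	indices := New(WithShape(2), WithBacking([]int{0, 2}))
	outGrad := New(WithShape(2, 2), WithBacking([]float64{1, 1, 1, 1}))

	// Expected input gradient: [[1 1] [0 0] [1 1]].
	return StdEng{}.SelectByIndicesB(input, outGrad, indices, 0)
}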
func (e StdEng) iterSelectByIndicesB(axis int, dataB, dataGradA *storage.Header, bit, iit Iterator) {
	panic("iterSelectByIndicesB not implemented yet")
}

// selectByIndicesB scatters the slices of the output gradient (dataB) back
// into the gradient of the input (dataGradA), accumulating with AddSliced so
// that repeated indices sum their contributions.
func (e StdEng) selectByIndicesB(axis int, indices []int, typ reflect.Type, dataB, dataGradA *storage.Header, apB, apRet AP) {
	isInnermost := axis == apB.shape.Dims()-1

	outer := ProdInts(apB.shape[:axis])

	axStride := apB.strides[axis]
	retStride := apRet.strides[axis]
	var outerRetStride int
	if axis == 0 {
		outerRetStride = apRet.strides[axis] * 2
	} else {
		outerRetStride = apRet.strides[axis-1]
	}

	dstCoord := make([]int, apB.shape.Dims())
	srcCoord := make([]int, apRet.shape.Dims())

	if isInnermost {
		prevAxis := axis - 1
		if prevAxis < 0 {
			// this may be the case if the input is a vector
			prevAxis = 0
		}
		retPrevStride := apB.strides[prevAxis]
		prevStride := apRet.strides[prevAxis]
		for i, idx := range indices {
			dstCoord[axis] = idx
			srcCoord[axis] = i
			dstStart, _ := Ltoi(apB.shape, apB.strides, dstCoord...)
			start, _ := Ltoi(apRet.shape, apRet.strides, srcCoord...)
			for o := 0; o < outer; o++ {
				dstEnd := dstStart + axStride
				end := start + retStride

				e.E.AddSliced(typ, dataGradA, dstStart, dstEnd, dataB, start, end)

				dstStart += prevStride
				start += retPrevStride
			}
		}
		return
	}

	for i, idx := range indices {
		dstCoord[axis] = idx
		srcCoord[axis] = i
		dstStart, _ := Ltoi(apRet.shape, apRet.strides, dstCoord...)
		start, _ := Ltoi(apB.shape, apB.strides, srcCoord...)

		for o := 0; o < outer; o++ {
			dstEnd := dstStart + axStride
			end := start + retStride

			e.E.AddSliced(typ, dataGradA, dstStart, dstEnd, dataB, start, end)

			dstStart = dstEnd + axStride
			start = end + (outerRetStride - retStride)
		}
	}
}
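// The two helpers below are illustrative references added for documentation
// purposes only; they are not part of the original file. For a contiguous
// row-major matrix and axis 0, they spell out the same operations that
// selectByIdx (gather via storage.CopySliced) and selectByIndicesB
// (scatter-add via e.E.AddSliced) perform with stride arithmetic.

// naiveGatherRows copies each selected source row into the next row of the result.
func naiveGatherRows(data []float64, cols int, indices []int) []float64 {
	out := make([]float64, len(indices)*cols)
	for i, idx := range indices {
		copy(out[i*cols:(i+1)*cols], data[idx*cols:(idx+1)*cols])
	}
	return out
}

// naiveScatterAddRows accumulates each row of outGrad into the input-gradient
// row named by the matching index; repeated indices sum their contributions.
func naiveScatterAddRows(outGrad []float64, cols, inputRows int, indices []int) []float64 {
	grad := make([]float64, inputRows*cols)
	for i, idx := range indices {
		for j := 0; j < cols; j++ {
			grad[idx*cols+j] += outGrad[i*cols+j]
		}
	}
	return grad
}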