github.com/wzzhu/tensor@v0.9.24/defaultengine_selbyidx.go (about)

     1  package tensor
     2  
     3  import (
     4  	"github.com/pkg/errors"
     5  	"github.com/wzzhu/tensor/internal/storage"
     6  
     7  	"reflect"
     8  )
     9  
    10  // SelectByIndices selects the values given the in `indices` tensor.
    11  //
    12  // Currently SelectByIndices only supports Dense tensors that do not require the use of iterators.
    13  // Please make a pull request to support tensors that require the use of an iterator to traverse data.
    14  func (e StdEng) SelectByIndices(a, indices Tensor, axis int, opts ...FuncOpt) (retVal Tensor, err error) {
    15  	if !indices.Shape().IsVectorLike() {
    16  		return nil, errors.Errorf("Expected indices to be a vector. Got %v instead", indices.Shape())
    17  	}
    18  	if indices.Dtype() != Int {
    19  		return nil, errors.Errorf("Expected indices to be a vector of ints. Got %v instead", indices.Dtype())
    20  	}
    21  
    22  	// if b is a scalar, then use Slice
    23  	if a.Shape().IsScalarEquiv() {
    24  		slices := make([]Slice, a.Shape().Dims())
    25  		slices[axis] = ss(getInts(indices)[0])
    26  		return a.Slice(slices...)
    27  	}
    28  
    29  	expectedShape := a.Shape().Clone()
    30  	expectedShape[axis] = indices.Shape().TotalSize()
    31  
    32  	var reuse DenseTensor
    33  	var safe, toReuse, _ bool
    34  	if reuse, safe, toReuse, _, _, err = handleFuncOpts(expectedShape, a.Dtype(), a.DataOrder(), true, opts...); err != nil {
    35  		return nil, errors.Wrap(err, "Unable to handle funcOpts")
    36  	}
    37  	if safe || !toReuse && reuse == nil && safe {
    38  		// create reuse
    39  		reuse = New(WithShape(expectedShape...), Of(a.Dtype()))
    40  	}
    41  
    42  	if !safe {
    43  		if a.Shape()[axis] != indices.Shape().TotalSize() {
    44  			expected := a.Shape().Clone()
    45  			expected[axis] = indices.Shape().TotalSize()
    46  			return nil, errors.Errorf("Expected a safe resuse to have the same shape as the expected shape of the result: %v. The input a has %v ", expected, a.Shape())
    47  		}
    48  
    49  		reuse = a.(DenseTensor)
    50  	}
    51  
    52  	typ := a.Dtype().Type
    53  	var dataA, dataB, dataReuse *storage.Header
    54  	var ait, bit, iit Iterator
    55  	var useIter bool
    56  	if dataA, dataB, dataReuse, ait, bit, iit, useIter, _, err = prepDataVV(a, indices, reuse); err != nil {
    57  		return nil, errors.Wrapf(err, "StdEng.Add")
    58  	}
    59  
    60  	if useIter {
    61  		e.iterSelectByIdx(axis, dataA, dataB, dataReuse, ait, bit, iit)
    62  		//TODO
    63  		return
    64  	}
    65  
    66  	e.selectByIdx(axis, dataB.Ints(), typ, dataA, dataReuse, a.(*Dense).AP, reuse.(*Dense).AP)
    67  	return reuse, nil
    68  }
    69  
    70  func (e StdEng) iterSelectByIdx(axis int, dataA, dataB, dataReuse *storage.Header, ait, bit, iit Iterator) {
    71  	panic("iterSelectByIdx is not yet implemented")
    72  }
    73  
    74  func (e StdEng) selectByIdx(axis int, indices []int, typ reflect.Type, dataA, dataRetVal *storage.Header, apA, apRet AP) {
    75  	isInnermost := axis == apA.shape.Dims()-1
    76  
    77  	outer := ProdInts(apA.shape[:axis])
    78  
    79  	axStride := apA.strides[axis]
    80  	retStride := apRet.strides[axis]
    81  	var outerRetStride int
    82  	if axis == 0 {
    83  		// then it's the outermost
    84  		outerRetStride = apRet.strides[axis] * 2
    85  	} else {
    86  		outerRetStride = apRet.strides[axis-1]
    87  	}
    88  
    89  	srcCoord := make([]int, apA.shape.Dims())
    90  	dstCoord := make([]int, apRet.shape.Dims())
    91  
    92  	if isInnermost {
    93  		prevAxis := axis - 1
    94  		if prevAxis < 0 {
    95  			// this may be the case if input is a vector
    96  			prevAxis = 0
    97  		}
    98  		prevStride := apA.strides[prevAxis]
    99  		retPrevStride := apRet.strides[prevAxis]
   100  		for i, idx := range indices {
   101  			srcCoord[axis] = idx
   102  			dstCoord[axis] = i
   103  			start, _ := Ltoi(apA.shape, apA.strides, srcCoord...)
   104  			dstStart, _ := Ltoi(apRet.shape, apRet.strides, dstCoord...)
   105  			for o := 0; o < outer; o++ {
   106  				end := start + axStride
   107  				dstEnd := dstStart + retStride
   108  
   109  				storage.CopySliced(typ, dataRetVal, dstStart, dstEnd, dataA, start, end)
   110  
   111  				start += prevStride
   112  				dstStart += retPrevStride
   113  
   114  			}
   115  		}
   116  		return
   117  	}
   118  
   119  	for i, idx := range indices {
   120  		srcCoord[axis] = idx
   121  		dstCoord[axis] = i
   122  		start, _ := Ltoi(apA.shape, apA.strides, srcCoord...)
   123  		dstStart, _ := Ltoi(apRet.shape, apRet.strides, dstCoord...)
   124  
   125  		for o := 0; o < outer; o++ {
   126  			end := start + axStride
   127  			dstEnd := dstStart + retStride
   128  
   129  			storage.CopySliced(typ, dataRetVal, dstStart, dstEnd, dataA, start, end)
   130  
   131  			start = end + axStride
   132  			dstStart = dstEnd + (outerRetStride - retStride)
   133  		}
   134  	}
   135  }
   136  
   137  // SelectByIndicesB computes the gradient of the result of `SelectByIndices`.
   138  //
   139  // Currently SelectByIndicesB only supports Dense tensors that do not require the use of iterators.
   140  // Please make a pull request to support tensors that require the use of an iterator to traverse data.
   141  func (e StdEng) SelectByIndicesB(input, outGrad, indices Tensor, axis int, opts ...FuncOpt) (retVal Tensor, err error) {
   142  	if !indices.Shape().IsVectorLike() {
   143  		return nil, errors.Errorf("Expected indices to be a vector. Got %v instead", outGrad.Shape())
   144  	}
   145  	if indices.Dtype() != Int {
   146  		return nil, errors.Errorf("Expected indices to be a vector of ints. Got %v instead", outGrad.Dtype())
   147  	}
   148  
   149  	// if b is a scalar, then use Slice
   150  	if input.Shape().IsScalarEquiv() {
   151  		slices := make([]Slice, input.Shape().Dims())
   152  		slices[axis] = ss(outGrad.Data().([]int)[0])
   153  		return input.Slice(slices...)
   154  	}
   155  
   156  	expectedShape := input.Shape().Clone()
   157  
   158  	var reuse DenseTensor
   159  	var _, toReuse, _ bool
   160  	if reuse, _, toReuse, _, _, err = handleFuncOpts(input.Shape(), input.Dtype(), input.DataOrder(), true, opts...); err != nil {
   161  		return nil, errors.Wrap(err, "Unable to handle funcOpts")
   162  	}
   163  	if !toReuse && reuse == nil {
   164  		// create reuse
   165  		reuse = New(WithShape(expectedShape...), Of(input.Dtype()))
   166  	}
   167  
   168  	typ := input.Dtype().Type
   169  	var _, dataB, dataReuse *storage.Header
   170  	var _, bit, iit Iterator
   171  	var useIter bool
   172  	if _, dataB, dataReuse, _, bit, iit, useIter, _, err = prepDataVV(input, outGrad, reuse); err != nil {
   173  		return nil, errors.Wrapf(err, "StdEng.SelectByIndicesB")
   174  	}
   175  
   176  	if useIter {
   177  		e.iterSelectByIndicesB(axis, dataB, dataReuse, bit, iit)
   178  		//TODO
   179  		return
   180  	}
   181  
   182  	e.selectByIndicesB(axis, getInts(indices), typ, dataB, dataReuse, outGrad.(*Dense).AP, reuse.(*Dense).AP)
   183  
   184  	return reuse, nil
   185  }
   186  
   187  func (e StdEng) iterSelectByIndicesB(axis int, dataB, dataGradA *storage.Header, bit, iit Iterator) {
   188  	panic("iterSelectByIndicesB not implemented yet")
   189  }
   190  
   191  func (e StdEng) selectByIndicesB(axis int, indices []int, typ reflect.Type, dataB, dataGradA *storage.Header, apB, apRet AP) {
   192  	isInnermost := axis == apB.shape.Dims()-1
   193  
   194  	outer := ProdInts(apB.shape[:axis])
   195  
   196  	axStride := apB.strides[axis]
   197  	retStride := apRet.strides[axis]
   198  	var outerRetStride int
   199  	if axis == 0 {
   200  		outerRetStride = apRet.strides[axis] * 2
   201  	} else {
   202  		outerRetStride = apRet.strides[axis-1]
   203  	}
   204  
   205  	dstCoord := make([]int, apB.shape.Dims())
   206  	srcCoord := make([]int, apRet.shape.Dims())
   207  
   208  	if isInnermost {
   209  		prevAxis := axis - 1
   210  		if prevAxis < 0 {
   211  			// this may be the case if input is a vector
   212  			prevAxis = 0
   213  		}
   214  		retPrevStride := apB.strides[prevAxis]
   215  		prevStride := apRet.strides[prevAxis]
   216  		for i, idx := range indices {
   217  			dstCoord[axis] = idx
   218  			srcCoord[axis] = i
   219  			dstStart, _ := Ltoi(apB.shape, apB.strides, dstCoord...)
   220  			start, _ := Ltoi(apRet.shape, apRet.strides, srcCoord...)
   221  			for o := 0; o < outer; o++ {
   222  				dstEnd := dstStart + axStride
   223  				end := start + retStride
   224  
   225  				e.E.AddSliced(typ, dataGradA, dstStart, dstEnd, dataB, start, end)
   226  
   227  				dstStart += prevStride
   228  				start += retPrevStride
   229  
   230  			}
   231  		}
   232  		return
   233  	}
   234  
   235  	for i, idx := range indices {
   236  		dstCoord[axis] = idx
   237  		srcCoord[axis] = i
   238  		dstStart, _ := Ltoi(apRet.shape, apRet.strides, dstCoord...)
   239  		start, _ := Ltoi(apB.shape, apB.strides, srcCoord...)
   240  
   241  		for o := 0; o < outer; o++ {
   242  			dstEnd := dstStart + axStride
   243  			end := start + retStride
   244  
   245  			e.E.AddSliced(typ, dataGradA, dstStart, dstEnd, dataB, start, end)
   246  
   247  			dstStart = dstEnd + axStride
   248  			start = end + (outerRetStride - retStride)
   249  		}
   250  	}
   251  }