gorgonia.org/tensor@v0.9.24/defaultengine_argmethods.go (about)

     1  package tensor
     2  
     3  import "github.com/pkg/errors"
     4  
     5  func (e StdEng) Argmax(t Tensor, axis int) (retVal Tensor, err error) {
     6  
     7  	switch tt := t.(type) {
     8  	case DenseTensor:
     9  		return e.argmaxDenseTensor(tt, axis)
    10  	default:
    11  		return nil, errors.Errorf(typeNYI, "StdEng.Argmax", t)
    12  	}
    13  }
    14  
    15  func (e StdEng) argmaxDenseTensor(t DenseTensor, axis int) (retVal *Dense, err error) {
    16  	if err = unaryCheck(t, ordTypes); err != nil {
    17  		return nil, errors.Wrapf(err, opFail, "Argmax")
    18  	}
    19  
    20  	if axis >= len(t.Shape()) {
    21  		return nil, errors.Errorf(dimMismatch, len(t.Shape()), axis)
    22  	}
    23  
    24  	dataA := t.hdr()
    25  	typ := t.rtype()
    26  
    27  	// SPECIAL CASE: FLAT ARGMAX
    28  	if axis == AllAxes {
    29  		var index int
    30  		if mt, ok := t.(MaskedTensor); ok && mt.IsMasked() {
    31  			if index = e.E.ArgmaxFlatMasked(typ, dataA, mt.Mask()); index == -1 {
    32  				return nil, errors.Errorf("t is not supported - %T of %v", t, t.Dtype())
    33  			}
    34  		} else {
    35  			if index = e.E.ArgmaxFlat(typ, dataA); index == -1 {
    36  				return nil, errors.Errorf("t is not supported -  %T of %v", t, t.Dtype())
    37  			}
    38  		}
    39  		return New(FromScalar(index)), nil
    40  	}
    41  
    42  	// ARGMAX ALONG AXIS
    43  
    44  	var indices []int
    45  	axes := make([]int, len(t.Shape()))
    46  	for i := range t.Shape() {
    47  		switch {
    48  		case i < axis:
    49  			axes[i] = i
    50  		case i == axis:
    51  			axes[len(axes)-1] = i
    52  		case i > axis:
    53  			axes[i-1] = i
    54  		}
    55  	}
    56  
    57  	// be a good citizen - borrow and return, since we're only using this AP to figure out the moves
    58  	newAP, _, err := t.Info().T(axes...)
    59  	if _, ok := err.(NoOpError); !ok && err != nil {
    60  		return
    61  	} else if ok {
    62  		t.Info().CloneTo(&newAP)
    63  	}
    64  
    65  	it := IteratorFromDense(t)
    66  	iteratorLoadAP(it, &newAP)
    67  
    68  	lastSize := it.Shape()[len(it.Shape())-1]
    69  	newShape := it.Shape().Clone()
    70  	newShape = newShape[:len(newShape)-1]
    71  
    72  	// cleanup
    73  	defer func() {
    74  		newAP.zero()
    75  		ReturnInts(newShape)
    76  	}()
    77  
    78  	if mt, ok := t.(MaskedTensor); ok && mt.IsMasked() {
    79  		mask := mt.Mask()
    80  		if indices, err = e.E.ArgmaxIterMasked(typ, dataA, mask, it, lastSize); err != nil {
    81  			return
    82  		}
    83  	} else {
    84  		if indices, err = e.E.ArgmaxIter(typ, dataA, it, lastSize); err != nil {
    85  			return
    86  		}
    87  	}
    88  
    89  	return New(WithShape(newShape...), WithBacking(indices)), nil
    90  }
    91  
    92  func (e StdEng) Argmin(t Tensor, axis int) (retVal Tensor, err error) {
    93  
    94  	switch tt := t.(type) {
    95  	case DenseTensor:
    96  		return e.argminDenseTensor(tt, axis)
    97  	default:
    98  		return nil, errors.Errorf(typeNYI, "StdEng.Argmin", t)
    99  	}
   100  }
   101  
   102  func (e StdEng) argminDenseTensor(t DenseTensor, axis int) (retVal *Dense, err error) {
   103  	if err = unaryCheck(t, ordTypes); err != nil {
   104  		return nil, errors.Wrapf(err, opFail, "Argmin")
   105  	}
   106  
   107  	if axis >= len(t.Shape()) {
   108  		return nil, errors.Errorf(dimMismatch, len(t.Shape()), axis)
   109  	}
   110  
   111  	dataA := t.hdr()
   112  	typ := t.rtype()
   113  
   114  	// SPECIAL CASE: FLAT ARGMAX
   115  	if axis == AllAxes {
   116  		var index int
   117  		if mt, ok := t.(MaskedTensor); ok && mt.IsMasked() {
   118  			if index = e.E.ArgminFlatMasked(typ, dataA, mt.Mask()); index == -1 {
   119  				return nil, errors.Errorf("t is not supported - %T of %v", t, t.Dtype())
   120  			}
   121  		} else {
   122  			if index = e.E.ArgminFlat(typ, dataA); index == -1 {
   123  				return nil, errors.Errorf("t is not supported -  %T of %v", t, t.Dtype())
   124  			}
   125  		}
   126  		return New(FromScalar(index)), nil
   127  	}
   128  
   129  	// ARGMAX ALONG AXIS
   130  
   131  	var indices []int
   132  	axes := make([]int, len(t.Shape()))
   133  	for i := range t.Shape() {
   134  		switch {
   135  		case i < axis:
   136  			axes[i] = i
   137  		case i == axis:
   138  			axes[len(axes)-1] = i
   139  		case i > axis:
   140  			axes[i-1] = i
   141  		}
   142  	}
   143  
   144  	// be a good citizen - borrow and return, since we're only using this AP to figure out the moves
   145  	newAP, _, err := t.Info().T(axes...)
   146  	if _, ok := err.(NoOpError); !ok && err != nil {
   147  		return
   148  	} else if ok {
   149  		newAP = t.Info().Clone()
   150  	}
   151  
   152  	it := IteratorFromDense(t)
   153  	iteratorLoadAP(it, &newAP)
   154  
   155  	lastSize := it.Shape()[len(it.Shape())-1]
   156  	newShape := it.Shape().Clone()
   157  	newShape = newShape[:len(newShape)-1]
   158  
   159  	// cleanup
   160  	defer func() {
   161  		newAP.zero()
   162  		ReturnInts(newShape)
   163  	}()
   164  
   165  	if mt, ok := t.(MaskedTensor); ok && mt.IsMasked() {
   166  		mask := mt.Mask()
   167  		if indices, err = e.E.ArgminIterMasked(typ, dataA, mask, it, lastSize); err != nil {
   168  			return
   169  		}
   170  	} else {
   171  		if indices, err = e.E.ArgminIter(typ, dataA, it, lastSize); err != nil {
   172  			return
   173  		}
   174  	}
   175  
   176  	return New(WithShape(newShape...), WithBacking(indices)), nil
   177  }