gorgonia.org/tensor@v0.9.24/defaultengine_argmethods.go (about) 1 package tensor 2 3 import "github.com/pkg/errors" 4 5 func (e StdEng) Argmax(t Tensor, axis int) (retVal Tensor, err error) { 6 7 switch tt := t.(type) { 8 case DenseTensor: 9 return e.argmaxDenseTensor(tt, axis) 10 default: 11 return nil, errors.Errorf(typeNYI, "StdEng.Argmax", t) 12 } 13 } 14 15 func (e StdEng) argmaxDenseTensor(t DenseTensor, axis int) (retVal *Dense, err error) { 16 if err = unaryCheck(t, ordTypes); err != nil { 17 return nil, errors.Wrapf(err, opFail, "Argmax") 18 } 19 20 if axis >= len(t.Shape()) { 21 return nil, errors.Errorf(dimMismatch, len(t.Shape()), axis) 22 } 23 24 dataA := t.hdr() 25 typ := t.rtype() 26 27 // SPECIAL CASE: FLAT ARGMAX 28 if axis == AllAxes { 29 var index int 30 if mt, ok := t.(MaskedTensor); ok && mt.IsMasked() { 31 if index = e.E.ArgmaxFlatMasked(typ, dataA, mt.Mask()); index == -1 { 32 return nil, errors.Errorf("t is not supported - %T of %v", t, t.Dtype()) 33 } 34 } else { 35 if index = e.E.ArgmaxFlat(typ, dataA); index == -1 { 36 return nil, errors.Errorf("t is not supported - %T of %v", t, t.Dtype()) 37 } 38 } 39 return New(FromScalar(index)), nil 40 } 41 42 // ARGMAX ALONG AXIS 43 44 var indices []int 45 axes := make([]int, len(t.Shape())) 46 for i := range t.Shape() { 47 switch { 48 case i < axis: 49 axes[i] = i 50 case i == axis: 51 axes[len(axes)-1] = i 52 case i > axis: 53 axes[i-1] = i 54 } 55 } 56 57 // be a good citizen - borrow and return, since we're only using this AP to figure out the moves 58 newAP, _, err := t.Info().T(axes...) 59 if _, ok := err.(NoOpError); !ok && err != nil { 60 return 61 } else if ok { 62 t.Info().CloneTo(&newAP) 63 } 64 65 it := IteratorFromDense(t) 66 iteratorLoadAP(it, &newAP) 67 68 lastSize := it.Shape()[len(it.Shape())-1] 69 newShape := it.Shape().Clone() 70 newShape = newShape[:len(newShape)-1] 71 72 // cleanup 73 defer func() { 74 newAP.zero() 75 ReturnInts(newShape) 76 }() 77 78 if mt, ok := t.(MaskedTensor); ok && mt.IsMasked() { 79 mask := mt.Mask() 80 if indices, err = e.E.ArgmaxIterMasked(typ, dataA, mask, it, lastSize); err != nil { 81 return 82 } 83 } else { 84 if indices, err = e.E.ArgmaxIter(typ, dataA, it, lastSize); err != nil { 85 return 86 } 87 } 88 89 return New(WithShape(newShape...), WithBacking(indices)), nil 90 } 91 92 func (e StdEng) Argmin(t Tensor, axis int) (retVal Tensor, err error) { 93 94 switch tt := t.(type) { 95 case DenseTensor: 96 return e.argminDenseTensor(tt, axis) 97 default: 98 return nil, errors.Errorf(typeNYI, "StdEng.Argmin", t) 99 } 100 } 101 102 func (e StdEng) argminDenseTensor(t DenseTensor, axis int) (retVal *Dense, err error) { 103 if err = unaryCheck(t, ordTypes); err != nil { 104 return nil, errors.Wrapf(err, opFail, "Argmin") 105 } 106 107 if axis >= len(t.Shape()) { 108 return nil, errors.Errorf(dimMismatch, len(t.Shape()), axis) 109 } 110 111 dataA := t.hdr() 112 typ := t.rtype() 113 114 // SPECIAL CASE: FLAT ARGMAX 115 if axis == AllAxes { 116 var index int 117 if mt, ok := t.(MaskedTensor); ok && mt.IsMasked() { 118 if index = e.E.ArgminFlatMasked(typ, dataA, mt.Mask()); index == -1 { 119 return nil, errors.Errorf("t is not supported - %T of %v", t, t.Dtype()) 120 } 121 } else { 122 if index = e.E.ArgminFlat(typ, dataA); index == -1 { 123 return nil, errors.Errorf("t is not supported - %T of %v", t, t.Dtype()) 124 } 125 } 126 return New(FromScalar(index)), nil 127 } 128 129 // ARGMAX ALONG AXIS 130 131 var indices []int 132 axes := make([]int, len(t.Shape())) 133 for i := range t.Shape() { 134 switch { 135 case i < axis: 136 axes[i] = i 137 case i == axis: 138 axes[len(axes)-1] = i 139 case i > axis: 140 axes[i-1] = i 141 } 142 } 143 144 // be a good citizen - borrow and return, since we're only using this AP to figure out the moves 145 newAP, _, err := t.Info().T(axes...) 146 if _, ok := err.(NoOpError); !ok && err != nil { 147 return 148 } else if ok { 149 newAP = t.Info().Clone() 150 } 151 152 it := IteratorFromDense(t) 153 iteratorLoadAP(it, &newAP) 154 155 lastSize := it.Shape()[len(it.Shape())-1] 156 newShape := it.Shape().Clone() 157 newShape = newShape[:len(newShape)-1] 158 159 // cleanup 160 defer func() { 161 newAP.zero() 162 ReturnInts(newShape) 163 }() 164 165 if mt, ok := t.(MaskedTensor); ok && mt.IsMasked() { 166 mask := mt.Mask() 167 if indices, err = e.E.ArgminIterMasked(typ, dataA, mask, it, lastSize); err != nil { 168 return 169 } 170 } else { 171 if indices, err = e.E.ArgminIter(typ, dataA, it, lastSize); err != nil { 172 return 173 } 174 } 175 176 return New(WithShape(newShape...), WithBacking(indices)), nil 177 }