gorgonia.org/gorgonia@v0.9.17/op_by_indices.go (about) 1 package gorgonia 2 3 import ( 4 "fmt" 5 "hash" 6 7 "github.com/chewxy/hm" 8 "github.com/pkg/errors" 9 "gorgonia.org/tensor" 10 ) 11 12 type byIndicesOp struct { 13 axis int 14 } 15 16 func newByIndicesOp(axis int) *byIndicesOp { 17 if axis < 0 { 18 axis = 0 19 } 20 21 return &byIndicesOp{ 22 axis: axis, 23 } 24 } 25 26 // ByIndices is an operation that takes the indices as input and return the selected values from those indices. 27 // The default axis in 0 28 func ByIndices(x *Node, indices *Node, axis int) (*Node, error) { 29 op := newByIndicesOp(axis) 30 31 return ApplyOp(op, x, indices) 32 } 33 34 func (op *byIndicesOp) Arity() int { return 2 } 35 36 func (op *byIndicesOp) ReturnsPtr() bool { return false } 37 38 func (op *byIndicesOp) CallsExtern() bool { return false } 39 40 func (op *byIndicesOp) WriteHash(h hash.Hash) { fmt.Fprintf(h, op.String()) } 41 42 func (op *byIndicesOp) Hashcode() uint32 { return simpleHash(op) } 43 44 func (op *byIndicesOp) String() string { 45 return fmt.Sprintf("ByIndicesOp{axis=%d}", op.axis) 46 } 47 48 func (op *byIndicesOp) InferShape(inputs ...DimSizer) (tensor.Shape, error) { 49 s := inputs[0].(tensor.Shape).Clone() 50 i := inputs[1].(tensor.Shape).Clone() 51 if !i.IsVectorLike() { 52 return nil, errors.Errorf("Expected indices to be a vector-like. Got %v instead", i) 53 } 54 55 s[op.axis] = i.TotalSize() 56 57 return s, nil 58 } 59 60 func (op *byIndicesOp) Type() hm.Type { 61 a := hm.TypeVariable('a') 62 b := makeTensorType(1, tensor.Int) 63 64 return hm.NewFnType(a, b, a) 65 } 66 67 func (op *byIndicesOp) OverwritesInput() int { return -1 } 68 69 func (op *byIndicesOp) checkInput(inputs ...Value) (x, indices tensor.Tensor, err error) { 70 if err := checkArity(op, len(inputs)); err != nil { 71 return nil, nil, err 72 } 73 74 var ok bool 75 if x, ok = inputs[0].(tensor.Tensor); !ok { 76 return nil, nil, errors.Errorf("Expected input to be a tensor, got %T", inputs[0]) 77 } 78 if indices, ok = inputs[1].(tensor.Tensor); !ok { 79 return nil, nil, errors.Errorf("Expected indices to be a tensor. Got %T instead", inputs[1]) 80 } 81 82 if indices.Dtype() != tensor.Int { 83 return nil, nil, errors.Errorf("Expected indices to have tensor.Int as a Dtype. Got %T instead", indices.Dtype()) 84 } 85 86 return x, indices, nil 87 } 88 89 func (op *byIndicesOp) Do(inputs ...Value) (Value, error) { 90 inputTensor, indices, err := op.checkInput(inputs...) 91 if err != nil { 92 return nil, fmt.Errorf("Can't check ByIndicesOp input: %w", err) 93 } 94 95 return tensor.ByIndices(inputTensor, indices, op.axis) 96 } 97 98 // DoDiff calculates the diff and sets its value to the output node. Implementation for ADOp interface. 99 func (op *byIndicesOp) DoDiff(ctx ExecutionContext, inputs Nodes, output *Node) error { 100 if len(inputs) != 2 { 101 return fmt.Errorf("byIndicesOp.DoDiff needs 2 arguments") 102 } 103 104 odv := output.boundTo.(*dualValue) 105 odvd := odv.Value.(tensor.Tensor) 106 107 diffOp := &byIndicesOpDiffOp{op} 108 109 result, err := diffOp.Do(inputs[0].boundTo, inputs[1].boundTo) 110 if err != nil { 111 return err 112 } 113 114 err = result.(*tensor.Dense).Reshape(odvd.Shape()...) 115 if err != nil { 116 return err 117 } 118 119 sum, err := odvd.(*tensor.Dense).Add(result.(*tensor.Dense), tensor.UseUnsafe()) 120 if err != nil { 121 return err 122 } 123 124 odv.d = sum 125 126 return nil 127 } 128 129 // SymDiff applies the diff op. Implementation for SDOp interface. 130 func (op *byIndicesOp) SymDiff(inputs Nodes, output, grad *Node) (Nodes, error) { 131 err := checkArity(op, len(inputs)) 132 if err != nil { 133 return nil, err 134 } 135 136 x := inputs[0] 137 indices := inputs[1] 138 139 diffOp := &byIndicesOpDiffOp{op} 140 nodes := make(Nodes, op.Arity()) 141 142 nodes[0], err = ApplyOp(diffOp, x, grad, indices) 143 144 return nodes, err 145 } 146 147 // DiffWRT is an implementation for the SDOp interface 148 func (op *byIndicesOp) DiffWRT(inputs int) []bool { 149 if inputs != op.Arity() { 150 panic(fmt.Sprintf("ByIndicesOp operator needs %d inputs, got %d instead", op.Arity(), inputs)) 151 } 152 153 return []bool{true, false} 154 } 155 156 type byIndicesOpDiffOp struct { 157 *byIndicesOp 158 } 159 160 func (op *byIndicesOpDiffOp) Arity() int { return 3 } 161 162 func (op *byIndicesOpDiffOp) ReturnsPtr() bool { return false } 163 164 func (op *byIndicesOpDiffOp) CallsExtern() bool { return false } 165 166 func (op *byIndicesOpDiffOp) WriteHash(h hash.Hash) { 167 fmt.Fprintf(h, op.String()) 168 } 169 170 func (op *byIndicesOpDiffOp) Hashcode() uint32 { return simpleHash(op) } 171 172 func (op *byIndicesOpDiffOp) String() string { 173 return fmt.Sprintf("ByIndicesOpDiff{}(%d)", op.axis) 174 } 175 176 func (op *byIndicesOpDiffOp) InferShape(inputs ...DimSizer) (tensor.Shape, error) { 177 s := inputs[0].(tensor.Shape).Clone() 178 179 return s, nil 180 } 181 182 func (op *byIndicesOpDiffOp) Type() hm.Type { 183 a := hm.TypeVariable('a') 184 b := makeTensorType(1, tensor.Int) 185 186 return hm.NewFnType(a, a, b, a) 187 } 188 189 func (op *byIndicesOpDiffOp) OverwritesInput() int { return -1 } 190 191 func (op *byIndicesOpDiffOp) checkInput(inputs ...Value) (in, indices, gradient *tensor.Dense, err error) { 192 if err := checkArity(op, len(inputs)); err != nil { 193 return nil, nil, nil, err 194 } 195 196 var ( 197 ok bool 198 ) 199 200 switch t := inputs[0].(type) { 201 case *dualValue: 202 if in, ok = t.Value.(*tensor.Dense); !ok { 203 return nil, nil, nil, errors.Errorf("input should be a tensor.Tensor, got %T", inputs[0]) 204 } 205 case *tensor.Dense: 206 in = t 207 default: 208 return nil, nil, nil, errors.Errorf("input type is not supported, got %T", inputs[0]) 209 } 210 211 switch t := inputs[2].(type) { 212 case *dualValue: 213 if gradient, ok = t.Value.(*tensor.Dense); !ok { 214 return nil, nil, nil, errors.Errorf("gradient should be a tensor, got %T", inputs[2]) 215 } 216 case *tensor.Dense: 217 gradient = t 218 default: 219 return nil, nil, nil, errors.Errorf("gradient type is not supported, got %T", inputs[2]) 220 } 221 222 switch t := inputs[1].(type) { 223 case *tensor.Dense: 224 indices = t 225 default: 226 return nil, nil, nil, errors.Errorf("indices type %T is not supported", inputs[1]) 227 } 228 229 return in, indices, gradient, nil 230 } 231 232 func (op *byIndicesOpDiffOp) Do(inputs ...Value) (Value, error) { 233 inputTensor, gradTensor, indices, err := op.checkInput(inputs...) 234 if err != nil { 235 return nil, fmt.Errorf("Can't check ByIndicesOpDiff input: %w", err) 236 } 237 238 output, err := tensor.ByIndicesB(inputTensor, gradTensor, indices, op.axis) 239 if err != nil { 240 return nil, err 241 } 242 243 return output, nil 244 } 245 246 // ensure it complies with the Op interface 247 var ( 248 _ Op = &byIndicesOpDiffOp{} 249 250 _ Op = &byIndicesOp{} 251 _ SDOp = &byIndicesOp{} 252 _ ADOp = &byIndicesOp{} 253 )