gorgonia.org/gorgonia@v0.9.17/op_by_indices.go (about)

     1  package gorgonia
     2  
     3  import (
     4  	"fmt"
     5  	"hash"
     6  
     7  	"github.com/chewxy/hm"
     8  	"github.com/pkg/errors"
     9  	"gorgonia.org/tensor"
    10  )
    11  
    12  type byIndicesOp struct {
    13  	axis int
    14  }
    15  
    16  func newByIndicesOp(axis int) *byIndicesOp {
    17  	if axis < 0 {
    18  		axis = 0
    19  	}
    20  
    21  	return &byIndicesOp{
    22  		axis: axis,
    23  	}
    24  }
    25  
    26  // ByIndices is an operation that takes the indices as input and return the selected values from those indices.
    27  // The default axis in 0
    28  func ByIndices(x *Node, indices *Node, axis int) (*Node, error) {
    29  	op := newByIndicesOp(axis)
    30  
    31  	return ApplyOp(op, x, indices)
    32  }
    33  
    34  func (op *byIndicesOp) Arity() int { return 2 }
    35  
    36  func (op *byIndicesOp) ReturnsPtr() bool { return false }
    37  
    38  func (op *byIndicesOp) CallsExtern() bool { return false }
    39  
    40  func (op *byIndicesOp) WriteHash(h hash.Hash) { fmt.Fprintf(h, op.String()) }
    41  
    42  func (op *byIndicesOp) Hashcode() uint32 { return simpleHash(op) }
    43  
    44  func (op *byIndicesOp) String() string {
    45  	return fmt.Sprintf("ByIndicesOp{axis=%d}", op.axis)
    46  }
    47  
    48  func (op *byIndicesOp) InferShape(inputs ...DimSizer) (tensor.Shape, error) {
    49  	s := inputs[0].(tensor.Shape).Clone()
    50  	i := inputs[1].(tensor.Shape).Clone()
    51  	if !i.IsVectorLike() {
    52  		return nil, errors.Errorf("Expected indices to be a vector-like. Got %v instead", i)
    53  	}
    54  
    55  	s[op.axis] = i.TotalSize()
    56  
    57  	return s, nil
    58  }
    59  
    60  func (op *byIndicesOp) Type() hm.Type {
    61  	a := hm.TypeVariable('a')
    62  	b := makeTensorType(1, tensor.Int)
    63  
    64  	return hm.NewFnType(a, b, a)
    65  }
    66  
    67  func (op *byIndicesOp) OverwritesInput() int { return -1 }
    68  
    69  func (op *byIndicesOp) checkInput(inputs ...Value) (x, indices tensor.Tensor, err error) {
    70  	if err := checkArity(op, len(inputs)); err != nil {
    71  		return nil, nil, err
    72  	}
    73  
    74  	var ok bool
    75  	if x, ok = inputs[0].(tensor.Tensor); !ok {
    76  		return nil, nil, errors.Errorf("Expected input to be a tensor, got %T", inputs[0])
    77  	}
    78  	if indices, ok = inputs[1].(tensor.Tensor); !ok {
    79  		return nil, nil, errors.Errorf("Expected indices to be a tensor. Got %T instead", inputs[1])
    80  	}
    81  
    82  	if indices.Dtype() != tensor.Int {
    83  		return nil, nil, errors.Errorf("Expected indices to have tensor.Int as a Dtype. Got %T instead", indices.Dtype())
    84  	}
    85  
    86  	return x, indices, nil
    87  }
    88  
    89  func (op *byIndicesOp) Do(inputs ...Value) (Value, error) {
    90  	inputTensor, indices, err := op.checkInput(inputs...)
    91  	if err != nil {
    92  		return nil, fmt.Errorf("Can't check ByIndicesOp input: %w", err)
    93  	}
    94  
    95  	return tensor.ByIndices(inputTensor, indices, op.axis)
    96  }
    97  
    98  // DoDiff calculates the diff and sets its value to the output node. Implementation for ADOp interface.
    99  func (op *byIndicesOp) DoDiff(ctx ExecutionContext, inputs Nodes, output *Node) error {
   100  	if len(inputs) != 2 {
   101  		return fmt.Errorf("byIndicesOp.DoDiff needs 2 arguments")
   102  	}
   103  
   104  	odv := output.boundTo.(*dualValue)
   105  	odvd := odv.Value.(tensor.Tensor)
   106  
   107  	diffOp := &byIndicesOpDiffOp{op}
   108  
   109  	result, err := diffOp.Do(inputs[0].boundTo, inputs[1].boundTo)
   110  	if err != nil {
   111  		return err
   112  	}
   113  
   114  	err = result.(*tensor.Dense).Reshape(odvd.Shape()...)
   115  	if err != nil {
   116  		return err
   117  	}
   118  
   119  	sum, err := odvd.(*tensor.Dense).Add(result.(*tensor.Dense), tensor.UseUnsafe())
   120  	if err != nil {
   121  		return err
   122  	}
   123  
   124  	odv.d = sum
   125  
   126  	return nil
   127  }
   128  
   129  // SymDiff applies the diff op. Implementation for SDOp interface.
   130  func (op *byIndicesOp) SymDiff(inputs Nodes, output, grad *Node) (Nodes, error) {
   131  	err := checkArity(op, len(inputs))
   132  	if err != nil {
   133  		return nil, err
   134  	}
   135  
   136  	x := inputs[0]
   137  	indices := inputs[1]
   138  
   139  	diffOp := &byIndicesOpDiffOp{op}
   140  	nodes := make(Nodes, op.Arity())
   141  
   142  	nodes[0], err = ApplyOp(diffOp, x, grad, indices)
   143  
   144  	return nodes, err
   145  }
   146  
   147  // DiffWRT is an implementation for the SDOp interface
   148  func (op *byIndicesOp) DiffWRT(inputs int) []bool {
   149  	if inputs != op.Arity() {
   150  		panic(fmt.Sprintf("ByIndicesOp operator needs %d inputs, got %d instead", op.Arity(), inputs))
   151  	}
   152  
   153  	return []bool{true, false}
   154  }
   155  
   156  type byIndicesOpDiffOp struct {
   157  	*byIndicesOp
   158  }
   159  
   160  func (op *byIndicesOpDiffOp) Arity() int { return 3 }
   161  
   162  func (op *byIndicesOpDiffOp) ReturnsPtr() bool { return false }
   163  
   164  func (op *byIndicesOpDiffOp) CallsExtern() bool { return false }
   165  
   166  func (op *byIndicesOpDiffOp) WriteHash(h hash.Hash) {
   167  	fmt.Fprintf(h, op.String())
   168  }
   169  
   170  func (op *byIndicesOpDiffOp) Hashcode() uint32 { return simpleHash(op) }
   171  
   172  func (op *byIndicesOpDiffOp) String() string {
   173  	return fmt.Sprintf("ByIndicesOpDiff{}(%d)", op.axis)
   174  }
   175  
   176  func (op *byIndicesOpDiffOp) InferShape(inputs ...DimSizer) (tensor.Shape, error) {
   177  	s := inputs[0].(tensor.Shape).Clone()
   178  
   179  	return s, nil
   180  }
   181  
   182  func (op *byIndicesOpDiffOp) Type() hm.Type {
   183  	a := hm.TypeVariable('a')
   184  	b := makeTensorType(1, tensor.Int)
   185  
   186  	return hm.NewFnType(a, a, b, a)
   187  }
   188  
   189  func (op *byIndicesOpDiffOp) OverwritesInput() int { return -1 }
   190  
   191  func (op *byIndicesOpDiffOp) checkInput(inputs ...Value) (in, indices, gradient *tensor.Dense, err error) {
   192  	if err := checkArity(op, len(inputs)); err != nil {
   193  		return nil, nil, nil, err
   194  	}
   195  
   196  	var (
   197  		ok bool
   198  	)
   199  
   200  	switch t := inputs[0].(type) {
   201  	case *dualValue:
   202  		if in, ok = t.Value.(*tensor.Dense); !ok {
   203  			return nil, nil, nil, errors.Errorf("input should be a tensor.Tensor, got %T", inputs[0])
   204  		}
   205  	case *tensor.Dense:
   206  		in = t
   207  	default:
   208  		return nil, nil, nil, errors.Errorf("input type is not supported, got %T", inputs[0])
   209  	}
   210  
   211  	switch t := inputs[2].(type) {
   212  	case *dualValue:
   213  		if gradient, ok = t.Value.(*tensor.Dense); !ok {
   214  			return nil, nil, nil, errors.Errorf("gradient should be a tensor, got %T", inputs[2])
   215  		}
   216  	case *tensor.Dense:
   217  		gradient = t
   218  	default:
   219  		return nil, nil, nil, errors.Errorf("gradient type is not supported, got %T", inputs[2])
   220  	}
   221  
   222  	switch t := inputs[1].(type) {
   223  	case *tensor.Dense:
   224  		indices = t
   225  	default:
   226  		return nil, nil, nil, errors.Errorf("indices type %T is not supported", inputs[1])
   227  	}
   228  
   229  	return in, indices, gradient, nil
   230  }
   231  
   232  func (op *byIndicesOpDiffOp) Do(inputs ...Value) (Value, error) {
   233  	inputTensor, gradTensor, indices, err := op.checkInput(inputs...)
   234  	if err != nil {
   235  		return nil, fmt.Errorf("Can't check ByIndicesOpDiff input: %w", err)
   236  	}
   237  
   238  	output, err := tensor.ByIndicesB(inputTensor, gradTensor, indices, op.axis)
   239  	if err != nil {
   240  		return nil, err
   241  	}
   242  
   243  	return output, nil
   244  }
   245  
   246  // ensure it complies with the Op interface
   247  var (
   248  	_ Op = &byIndicesOpDiffOp{}
   249  
   250  	_ Op   = &byIndicesOp{}
   251  	_ SDOp = &byIndicesOp{}
   252  	_ ADOp = &byIndicesOp{}
   253  )