gorgonia.org/gorgonia@v0.9.17/op_types.go (about)

     1  package gorgonia
     2  
     3  import (
     4  	"fmt"
     5  	"hash"
     6  
     7  	"github.com/chewxy/hm"
     8  	"github.com/pkg/errors"
     9  	"gorgonia.org/tensor"
    10  )
    11  
    12  // ConvType converts the type of the x Node from one type to other
    13  func ConvType(x *Node, from, to tensor.Dtype) (*Node, error) {
    14  	op := &dtConvOp{
    15  		inshape: x.Shape(),
    16  		from:    from,
    17  		to:      to,
    18  	}
    19  
    20  	return ApplyOp(op, x)
    21  }
    22  
    23  type dtConvOp struct {
    24  	inshape  tensor.Shape
    25  	from, to tensor.Dtype
    26  }
    27  
    28  /* Graph Building Related Methods */
    29  
    30  // Arity returns the number of inputs the Op expects. -1 indicates that it's n-ary and will be determined at runtime
    31  func (op *dtConvOp) Arity() int { return 1 }
    32  
    33  // Informs the type of the Op (not the node). This will be used by the type system to infer the final type of the node
    34  func (op *dtConvOp) Type() hm.Type {
    35  	if op.inshape.IsScalar() {
    36  		return hm.NewFnType(op.from, op.to)
    37  	}
    38  	t := makeTensorType(op.inshape.Dims(), op.from)
    39  	u := makeTensorType(op.inshape.Dims(), op.to)
    40  	return hm.NewFnType(t, u)
    41  }
    42  
    43  // returns the output shape as a function of the inputs
    44  func (op *dtConvOp) InferShape(_ ...DimSizer) (tensor.Shape, error) {
    45  	return op.inshape.Clone(), nil
    46  }
    47  
    48  // Do executes the op
    49  func (op *dtConvOp) Do(vals ...Value) (Value, error) {
    50  	retVal := tensor.New(tensor.Of(op.to), tensor.WithShape(op.inshape.Clone()...))
    51  	return op.UsePreallocDo(retVal, vals...)
    52  }
    53  
    54  /* Analysis Related Methods */ // indicates if the Op will return a pointer (allowing possible inplace edits) or by value
    55  // if it's false, the return value of the Op will be a copy of its input
    56  func (op *dtConvOp) ReturnsPtr() bool { return false }
    57  
    58  // Does this op potentially call external (cgo or cuda) functions (thereby requiring extra overhead for Go's trampolining thing)
    59  func (op *dtConvOp) CallsExtern() bool { return false }
    60  
    61  // overwriteInput() is a method which states which input the output will be overwriting.
    62  // This allows for some efficiency gains as the underlying arrays wouldn't have to be re-allocated.
    63  // The method returns an int instead of a bool because potentially different operations may be allowed
    64  // to overwrite certain inputs. For example, consider an operation to increment a value:
    65  // the IncrementOp would be a unary operator, and assuming we would like to overwrite the input,
    66  // the retVal of overwriteInput() will be 0 (inputs[0]).
    67  // -1 is returned if overwriting of input is disallowed
    68  func (op *dtConvOp) OverwritesInput() int { return -1 }
    69  
    70  /* Other methods */
    71  func (op *dtConvOp) WriteHash(h hash.Hash) { fmt.Fprintf(h, "(%v)", op.Type()) }
    72  
    73  func (op *dtConvOp) Hashcode() uint32 { return simpleHash(op) }
    74  
    75  func (op *dtConvOp) String() string { return fmt.Sprintf("%v", op.Type()) }
    76  
    77  // DiffWRT indicates if the op is differentiable with regards to the given number of inputs
    78  // returns []bool to indicate which input it is differentiable to
    79  func (op *dtConvOp) DiffWRT(inputs int) []bool { return []bool{true} }
    80  
    81  // SymDiff symbolically differentiates the op
    82  func (op *dtConvOp) SymDiff(inputs Nodes, output *Node, grad *Node) (retVal Nodes, err error) {
    83  	diffOp := &dtConvOp{
    84  		inshape: grad.Shape().Clone(),
    85  		from:    op.to,
    86  		to:      op.from,
    87  	}
    88  	retVal = make(Nodes, op.Arity())
    89  	retVal[0], err = ApplyOp(diffOp, grad)
    90  	return retVal, err
    91  }
    92  
    93  // UsePreallocDo executes the Op with a preallocated value in the result.s
    94  func (op *dtConvOp) UsePreallocDo(prealloc Value, inputs ...Value) (Value, error) {
    95  	a := inputs[0]
    96  	retVal := prealloc
    97  	switch {
    98  	case op.from == tensor.Float64 && op.to == tensor.Int:
    99  		switch aData := a.Data().(type) {
   100  		case []float64:
   101  			retData := retVal.Data().([]int)
   102  			for i := range aData {
   103  				retData[i] = int(aData[i])
   104  			}
   105  		case float64:
   106  			retVal = tensor.New(
   107  				tensor.Of(tensor.Int),
   108  				tensor.WithShape(1),
   109  				tensor.WithBacking([]int{int(aData)}),
   110  			)
   111  		}
   112  	case op.from == tensor.Float32 && op.to == tensor.Int:
   113  		switch aData := a.Data().(type) {
   114  		case []float32:
   115  			retData := retVal.Data().([]int)
   116  			for i := range aData {
   117  				retData[i] = int(aData[i])
   118  			}
   119  		case float32:
   120  			retVal = tensor.New(
   121  				tensor.Of(tensor.Int),
   122  				tensor.WithShape(1),
   123  				tensor.WithBacking([]int{int(aData)}),
   124  			)
   125  		}
   126  	case op.from == tensor.Int && op.to == tensor.Float64:
   127  		switch aData := a.Data().(type) {
   128  		case []int:
   129  			retData := retVal.Data().([]float64)
   130  			for i := range aData {
   131  				retData[i] = float64(aData[i])
   132  			}
   133  		case int:
   134  			retVal = tensor.New(
   135  				tensor.Of(tensor.Float64),
   136  				tensor.WithShape(1),
   137  				tensor.WithBacking([]float64{float64(aData)}),
   138  			)
   139  		}
   140  	case op.from == tensor.Int && op.to == tensor.Float32:
   141  		switch aData := a.Data().(type) {
   142  		case []int:
   143  			retData := retVal.Data().([]float32)
   144  			for i := range aData {
   145  				retData[i] = float32(aData[i])
   146  			}
   147  		case int:
   148  			retVal = tensor.New(
   149  				tensor.Of(tensor.Float32),
   150  				tensor.WithShape(1),
   151  				tensor.WithBacking([]float32{float32(aData)}),
   152  			)
   153  		}
   154  	default:
   155  		return nil, errors.Errorf("Cannot do conversion %v", op.Type())
   156  		// TODO: other types
   157  	}
   158  
   159  	return retVal, nil
   160  }