gorgonia.org/gorgonia@v0.9.17/values.go (about)

     1  package gorgonia
     2  
     3  import (
     4  	"fmt"
     5  	"unsafe"
     6  
     7  	"github.com/chewxy/hm"
     8  	"github.com/pkg/errors"
     9  	"gorgonia.org/tensor"
    10  )
    11  
    12  // Value represents a value that Gorgonia accepts. At this point it is implemented by:
    13  //		- all scalar value types (F64, F32... etc)
    14  // 		- *tensor.Dense
    15  // 		- *dualValue
    16  //
    17  // A Value is essentially any thing that knows its own type and shape.
    18  // Most importantly though, a Value is a pointer - and can be converted into a tensor.Memory.
    19  // This is done for the sake of interoperability with external devices like cgo or CUDA or OpenCL.
    20  // This also means for the most part most Values will be allocated on the heap.
    21  // There are some performance tradeoffs made in this decision, but ultimately this is better than having to manually manage blocks of memory
    22  type Value interface {
    23  	Shape() tensor.Shape // Shape  returns the shape of the Value. Scalar values return ScalarShape()
    24  	Size() int           // Size represents the number of elements in the Value. Note that in cases such as a *tensor.Dense, the underlying slice MAY have more elements than the Size() reports. This is correct.
    25  	Data() interface{}   // Data returns the original representation of the Value
    26  	Dtype() tensor.Dtype // Dtype returns the Dtype of the value
    27  
    28  	tensor.Memory
    29  	fmt.Formatter
    30  }
    31  
    32  // Valuer is any type that can return a Value
    33  type Valuer interface {
    34  	Value() Value
    35  }
    36  
    37  // Zeroer is a Value that can zero itself
    38  type Zeroer interface {
    39  	Value
    40  	Zero()
    41  }
    42  
    43  // ZeroValuer is a a Value that can provide the zero-value of its type
    44  type ZeroValuer interface {
    45  	Value
    46  	ZeroValue() Value
    47  }
    48  
    49  // Dtyper represents any type (typically a Value) that knows its own Dtype
    50  type Dtyper interface {
    51  	Dtype() tensor.Dtype
    52  }
    53  
    54  // Typer represents any type (typically a Op) that knows its own Type
    55  type Typer interface {
    56  	Type() hm.Type
    57  }
    58  
    59  // ValueEqualer represents any type that can perform a equal value check
    60  type ValueEqualer interface {
    61  	ValueEq(Value) bool
    62  }
    63  
    64  // ValueCloser represents any type that can perform a close-value check
    65  type ValueCloser interface {
    66  	ValueClose(interface{}) bool
    67  }
    68  
    69  // Cloner represents any type that can clone itself.
    70  type Cloner interface {
    71  	Clone() interface{}
    72  }
    73  
    74  // CloneErrorer represents any type that can clone itself and return an error if necessary
    75  type CloneErrorer interface {
    76  	Clone() (interface{}, error)
    77  }
    78  
    79  // CopierTo represents any type that can copy data to the destination.
    80  type CopierTo interface {
    81  	CopyTo(dest interface{}) error
    82  }
    83  
    84  // CopierFrom represents any type that can copy data from the source provided.
    85  type CopierFrom interface {
    86  	CopyFrom(src interface{}) error
    87  }
    88  
// Setter is any value that can Memset itself to the provided value
    90  // type Setter interface {
    91  // 	SetAll(interface{}) error
    92  // }
    93  
    94  // makeValue creates a value given a type and shape. The default value is the zero value of the type.
    95  func makeValue(t hm.Type, s tensor.Shape) (retVal Value, err error) {
    96  	var dt tensor.Dtype
    97  	if dt, err = dtypeOf(t); err != nil {
    98  		return
    99  	}
   100  
   101  	if s.IsScalar() {
   102  		switch dt {
   103  		case tensor.Float64:
   104  			return NewF64(0), nil
   105  		case tensor.Float32:
   106  			return NewF32(0), nil
   107  		case tensor.Int:
   108  			return NewI(0), nil
   109  		case tensor.Int64:
   110  			return NewI64(0), nil
   111  		case tensor.Int32:
   112  			return NewI32(0), nil
   113  		case tensor.Byte:
   114  			return NewU8(0), nil
   115  		case tensor.Bool:
   116  			return NewB(false), nil
   117  		}
   118  	}
   119  
   120  	switch tt := t.(type) {
   121  	case TensorType:
   122  		return tensor.New(tensor.Of(dt), tensor.WithShape(s...)), nil
   123  	default:
   124  		err = errors.Errorf(nyiTypeFail, "MakeValue", tt)
   125  		return
   126  	}
   127  }
   128  
   129  func makeValueFromMem(t hm.Type, s tensor.Shape, mem tensor.Memory) (retVal Value, err error) {
   130  	var dt tensor.Dtype
   131  	if dt, err = dtypeOf(t); err != nil {
   132  		return
   133  	}
   134  	if s.IsScalar() {
   135  		return makeScalarFromMem(dt, mem)
   136  	}
   137  
   138  	switch tt := t.(type) {
   139  	case TensorType:
   140  		memsize := calcMemSize(dt, s)
   141  		return tensor.New(tensor.Of(dt), tensor.WithShape(s...), tensor.FromMemory(mem.Uintptr(), uintptr(memsize))), nil
   142  	case tensor.Dtype:
   143  		return makeScalarFromMem(tt, mem)
   144  	default:
   145  		err = errors.Errorf(nyiTypeFail, "MakeValue", tt)
   146  		return
   147  	}
   148  }
   149  
   150  func makeScalarFromMem(dt tensor.Dtype, mem tensor.Memory) (retVal Value, err error) {
   151  	switch dt {
   152  	case tensor.Float64:
   153  		retVal = (*F64)(unsafe.Pointer(mem.Uintptr()))
   154  	case tensor.Float32:
   155  		retVal = (*F32)(unsafe.Pointer(mem.Uintptr()))
   156  	case tensor.Int:
   157  		retVal = (*I)(unsafe.Pointer(mem.Uintptr()))
   158  	case tensor.Int64:
   159  		retVal = (*I64)(unsafe.Pointer(mem.Uintptr()))
   160  	case tensor.Int32:
   161  		retVal = (*I32)(unsafe.Pointer(mem.Uintptr()))
   162  	case tensor.Byte:
   163  		retVal = (*U8)(unsafe.Pointer(mem.Uintptr()))
   164  	case tensor.Bool:
   165  		retVal = (*B)(unsafe.Pointer(mem.Uintptr()))
   166  	default:
   167  		err = errors.Errorf(nyiTypeFail, "makeScalarFromMem", dt)
   168  	}
   169  	return
   170  }
   171  
   172  func logicalSize(s tensor.Shape) int {
   173  	if s.IsScalar() {
   174  		return 1
   175  	}
   176  	return s.TotalSize()
   177  }
   178  
   179  func calcMemSize(dt tensor.Dtype, s tensor.Shape) int64 {
   180  	var elemSize int64
   181  	if s.IsScalar() {
   182  		elemSize = 1
   183  	} else {
   184  		elemSize = int64(s.TotalSize())
   185  	}
   186  	dtSize := int64(dt.Size())
   187  	return elemSize * dtSize
   188  }
   189  
   190  // ScalarAsTensor returns the tensor representation of a scalar. It is particularly useful as a "reshape" of tensors of sorts
   191  //
   192  // The Value passed in are either Scalar, tensor.Tensor, or *dualValue. Anything else will panic.
   193  func ScalarAsTensor(v Value, dims int, e tensor.Engine) Value {
   194  	switch a := v.(type) {
   195  	case Scalar:
   196  		sh := make(tensor.Shape, dims)
   197  		for i := range sh {
   198  			sh[i] = 1
   199  		}
   200  		return tensor.New(tensor.WithShape(sh...), tensor.Of(a.Dtype()), tensor.FromMemory(a.Uintptr(), a.MemSize()), tensor.WithEngine(e))
   201  	case tensor.Tensor:
   202  		return a
   203  	case *dualValue:
   204  		b := new(dualValue)
   205  		b.Value = ScalarAsTensor(a.Value, dims, e)
   206  		b.d = ScalarAsTensor(a.d, dims, e)
   207  		return b
   208  	case nil:
   209  		return nil
   210  	default:
   211  		panic(fmt.Sprintf("Unable to convert %v to Tensor", v))
   212  	}
   213  }