gorgonia.org/tensor@v0.9.24/tensor.go (about) 1 // Package tensor is a package that provides efficient, generic n-dimensional arrays in Go. 2 // Also in this package are functions and methods that are used commonly in arithmetic, comparison and linear algebra operations. 3 package tensor // import "gorgonia.org/tensor" 4 5 import ( 6 "encoding/gob" 7 "fmt" 8 "io" 9 10 "github.com/pkg/errors" 11 ) 12 13 var ( 14 _ Tensor = &Dense{} 15 _ Tensor = &CS{} 16 _ View = &Dense{} 17 ) 18 19 func init() { 20 gob.Register(&Dense{}) 21 gob.Register(&CS{}) 22 } 23 24 // Tensor represents a variety of n-dimensional arrays. The most commonly used tensor is the Dense tensor. 25 // It can be used to represent a vector, matrix, 3D matrix and n-dimensional tensors. 26 type Tensor interface { 27 // info about the ndarray 28 Shape() Shape 29 Strides() []int 30 Dtype() Dtype 31 Dims() int 32 Size() int 33 DataSize() int 34 35 // Data access related 36 RequiresIterator() bool 37 Iterator() Iterator 38 DataOrder() DataOrder 39 40 // ops 41 Slicer 42 At(...int) (interface{}, error) 43 SetAt(v interface{}, coord ...int) error 44 Reshape(...int) error 45 T(axes ...int) error 46 UT() 47 Transpose() error // Transpose actually moves the data 48 Apply(fn interface{}, opts ...FuncOpt) (Tensor, error) 49 50 // data related interface 51 Zeroer 52 MemSetter 53 Dataer 54 Eq 55 Cloner 56 57 // type overloading methods 58 IsScalar() bool 59 ScalarValue() interface{} 60 61 // engine/memory related stuff 62 // all Tensors should be able to be expressed of as a slab of memory 63 // Note: the size of each element can be acquired by T.Dtype().Size() 64 Memory // Tensors all implement Memory 65 Engine() Engine // Engine can be nil 66 IsNativelyAccessible() bool // Can Go access the memory 67 IsManuallyManaged() bool // Must Go manage the memory 68 69 // formatters 70 fmt.Formatter 71 fmt.Stringer 72 73 // all Tensors are serializable to these formats 74 WriteNpy(io.Writer) error 75 ReadNpy(io.Reader) error 76 gob.GobEncoder 77 gob.GobDecoder 78 79 standardEngine() standardEngine 80 headerer 81 arrayer 82 } 83 84 // New creates a new Dense Tensor. For sparse arrays use their relevant construction function 85 func New(opts ...ConsOpt) *Dense { 86 d := borrowDense() 87 for _, opt := range opts { 88 opt(d) 89 } 90 d.fix() 91 if err := d.sanity(); err != nil { 92 panic(err) 93 } 94 95 return d 96 } 97 98 func assertDense(t Tensor) (*Dense, error) { 99 if t == nil { 100 return nil, errors.New("nil is not a *Dense") 101 } 102 if retVal, ok := t.(*Dense); ok { 103 return retVal, nil 104 } 105 if retVal, ok := t.(Densor); ok { 106 return retVal.Dense(), nil 107 } 108 return nil, errors.Errorf("%T is not *Dense", t) 109 } 110 111 func getDenseTensor(t Tensor) (DenseTensor, error) { 112 switch tt := t.(type) { 113 case DenseTensor: 114 return tt, nil 115 case Densor: 116 return tt.Dense(), nil 117 default: 118 return nil, errors.Errorf("Tensor %T is not a DenseTensor", t) 119 } 120 } 121 122 // getFloatDense extracts a *Dense from a Tensor and ensures that the .data is a Array that implements Float 123 func getFloatDenseTensor(t Tensor) (retVal DenseTensor, err error) { 124 if t == nil { 125 return 126 } 127 if err = typeclassCheck(t.Dtype(), floatTypes); err != nil { 128 err = errors.Wrapf(err, "getFloatDense only handles floats. Got %v instead", t.Dtype()) 129 return 130 } 131 132 if retVal, err = getDenseTensor(t); err != nil { 133 err = errors.Wrapf(err, opFail, "getFloatDense") 134 return 135 } 136 if retVal == nil { 137 return 138 } 139 140 return 141 } 142 143 // getFloatDense extracts a *Dense from a Tensor and ensures that the .data is a Array that implements Float 144 func getFloatComplexDenseTensor(t Tensor) (retVal DenseTensor, err error) { 145 if t == nil { 146 return 147 } 148 if err = typeclassCheck(t.Dtype(), floatcmplxTypes); err != nil { 149 err = errors.Wrapf(err, "getFloatDense only handles floats and complex. Got %v instead", t.Dtype()) 150 return 151 } 152 153 if retVal, err = getDenseTensor(t); err != nil { 154 err = errors.Wrapf(err, opFail, "getFloatDense") 155 return 156 } 157 if retVal == nil { 158 return 159 } 160 161 return 162 } 163 164 func sliceDense(t *Dense, slices ...Slice) (retVal *Dense, err error) { 165 var sliced Tensor 166 if sliced, err = t.Slice(slices...); err != nil { 167 return nil, err 168 } 169 return sliced.(*Dense), nil 170 }