gorgonia.org/tensor@v0.9.24/tensor.go (about)

     1  // Package tensor is a package that provides efficient, generic n-dimensional arrays in Go.
     2  // Also in this package are functions and methods that are used commonly in arithmetic, comparison and linear algebra operations.
     3  package tensor // import "gorgonia.org/tensor"
     4  
     5  import (
     6  	"encoding/gob"
     7  	"fmt"
     8  	"io"
     9  
    10  	"github.com/pkg/errors"
    11  )
    12  
    13  var (
    14  	_ Tensor = &Dense{}
    15  	_ Tensor = &CS{}
    16  	_ View   = &Dense{}
    17  )
    18  
    19  func init() {
    20  	gob.Register(&Dense{})
    21  	gob.Register(&CS{})
    22  }
    23  
    24  // Tensor represents a variety of n-dimensional arrays. The most commonly used tensor is the Dense tensor.
    25  // It can be used to represent a vector, matrix, 3D matrix and n-dimensional tensors.
    26  type Tensor interface {
    27  	// info about the ndarray
    28  	Shape() Shape
    29  	Strides() []int
    30  	Dtype() Dtype
    31  	Dims() int
    32  	Size() int
    33  	DataSize() int
    34  
    35  	// Data access related
    36  	RequiresIterator() bool
    37  	Iterator() Iterator
    38  	DataOrder() DataOrder
    39  
    40  	// ops
    41  	Slicer
    42  	At(...int) (interface{}, error)
    43  	SetAt(v interface{}, coord ...int) error
    44  	Reshape(...int) error
    45  	T(axes ...int) error
    46  	UT()
    47  	Transpose() error // Transpose actually moves the data
    48  	Apply(fn interface{}, opts ...FuncOpt) (Tensor, error)
    49  
    50  	// data related interface
    51  	Zeroer
    52  	MemSetter
    53  	Dataer
    54  	Eq
    55  	Cloner
    56  
    57  	// type overloading methods
    58  	IsScalar() bool
    59  	ScalarValue() interface{}
    60  
    61  	// engine/memory related stuff
    62  	// all Tensors should be able to be expressed of as a slab of memory
    63  	// Note: the size of each element can be acquired by T.Dtype().Size()
    64  	Memory                      // Tensors all implement Memory
    65  	Engine() Engine             // Engine can be nil
    66  	IsNativelyAccessible() bool // Can Go access the memory
    67  	IsManuallyManaged() bool    // Must Go manage the memory
    68  
    69  	// formatters
    70  	fmt.Formatter
    71  	fmt.Stringer
    72  
    73  	// all Tensors are serializable to these formats
    74  	WriteNpy(io.Writer) error
    75  	ReadNpy(io.Reader) error
    76  	gob.GobEncoder
    77  	gob.GobDecoder
    78  
    79  	standardEngine() standardEngine
    80  	headerer
    81  	arrayer
    82  }
    83  
    84  // New creates a new Dense Tensor. For sparse arrays use their relevant construction function
    85  func New(opts ...ConsOpt) *Dense {
    86  	d := borrowDense()
    87  	for _, opt := range opts {
    88  		opt(d)
    89  	}
    90  	d.fix()
    91  	if err := d.sanity(); err != nil {
    92  		panic(err)
    93  	}
    94  
    95  	return d
    96  }
    97  
    98  func assertDense(t Tensor) (*Dense, error) {
    99  	if t == nil {
   100  		return nil, errors.New("nil is not a *Dense")
   101  	}
   102  	if retVal, ok := t.(*Dense); ok {
   103  		return retVal, nil
   104  	}
   105  	if retVal, ok := t.(Densor); ok {
   106  		return retVal.Dense(), nil
   107  	}
   108  	return nil, errors.Errorf("%T is not *Dense", t)
   109  }
   110  
   111  func getDenseTensor(t Tensor) (DenseTensor, error) {
   112  	switch tt := t.(type) {
   113  	case DenseTensor:
   114  		return tt, nil
   115  	case Densor:
   116  		return tt.Dense(), nil
   117  	default:
   118  		return nil, errors.Errorf("Tensor %T is not a DenseTensor", t)
   119  	}
   120  }
   121  
   122  // getFloatDense extracts a *Dense from a Tensor and ensures that the .data is a Array that implements Float
   123  func getFloatDenseTensor(t Tensor) (retVal DenseTensor, err error) {
   124  	if t == nil {
   125  		return
   126  	}
   127  	if err = typeclassCheck(t.Dtype(), floatTypes); err != nil {
   128  		err = errors.Wrapf(err, "getFloatDense only handles floats. Got %v instead", t.Dtype())
   129  		return
   130  	}
   131  
   132  	if retVal, err = getDenseTensor(t); err != nil {
   133  		err = errors.Wrapf(err, opFail, "getFloatDense")
   134  		return
   135  	}
   136  	if retVal == nil {
   137  		return
   138  	}
   139  
   140  	return
   141  }
   142  
   143  // getFloatDense extracts a *Dense from a Tensor and ensures that the .data is a Array that implements Float
   144  func getFloatComplexDenseTensor(t Tensor) (retVal DenseTensor, err error) {
   145  	if t == nil {
   146  		return
   147  	}
   148  	if err = typeclassCheck(t.Dtype(), floatcmplxTypes); err != nil {
   149  		err = errors.Wrapf(err, "getFloatDense only handles floats and complex. Got %v instead", t.Dtype())
   150  		return
   151  	}
   152  
   153  	if retVal, err = getDenseTensor(t); err != nil {
   154  		err = errors.Wrapf(err, opFail, "getFloatDense")
   155  		return
   156  	}
   157  	if retVal == nil {
   158  		return
   159  	}
   160  
   161  	return
   162  }
   163  
   164  func sliceDense(t *Dense, slices ...Slice) (retVal *Dense, err error) {
   165  	var sliced Tensor
   166  	if sliced, err = t.Slice(slices...); err != nil {
   167  		return nil, err
   168  	}
   169  	return sliced.(*Dense), nil
   170  }