github.com/wzzhu/tensor@v0.9.24/engine.go

package tensor

// Memory is a representation of memory of the value.
//
// The main reason for requiring the Uintptr() method is that, even though Go currently does not have a compacting
// garbage collector, the docs of `unsafe` warn:
//		Even if a uintptr holds the address of some object, the garbage collector will not update that uintptr's value if the object moves,
//		nor will that uintptr keep the object from being reclaimed.
type Memory interface {
	Uintptr() uintptr
	MemSize() uintptr
}
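
// The type below is an illustrative sketch only (it is not part of this package's API): a
// hypothetical Memory backed by plain Go-allocated bytes. A real implementation would derive
// the address from the slice's backing array via package unsafe; here it is stored in a field
// so the sketch stays import-free.
type exampleMemory struct {
	data []byte
	addr uintptr // address of &data[0], as a real implementation would obtain via unsafe.Pointer
}

// Uintptr returns the raw address of the backing memory; MemSize returns the allocation size in bytes.
func (m *exampleMemory) Uintptr() uintptr { return m.addr }
func (m *exampleMemory) MemSize() uintptr { return uintptr(len(m.data)) }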

// Engine is a representation of an execution engine.
// While different execution engines can have different capabilities, all execution engines must be able to allocate and free memory.
type Engine interface {
	AllocAccessible() bool                    // AllocAccessible returns true if the engine returns Go-accessible memory pointers
	Alloc(size int64) (Memory, error)         // Alloc allocates memory
	Free(mem Memory, size int64) error        // Free frees memory
	Memset(mem Memory, val interface{}) error // Memset sets all values of the memory to val
	Memclr(mem Memory)                        // Memclr zeroes out the memory
	Memcpy(dst, src Memory) error             // Memcpy copies memory from src to dst
	Accessible(mem Memory) (Memory, error)    // Accessible returns Go-accessible memory, or an error if it cannot be done
	WorksWith(order DataOrder) bool           // WorksWith returns true if the data order can be directly worked with
}
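
// The function below is a minimal usage sketch (a hypothetical helper, not part of this
// package's API) showing how a caller might drive an Engine directly: allocate memory,
// initialise it, and release it again.
func exampleUseEngine(e Engine, size int64) error {
	mem, err := e.Alloc(size) // ask the engine for `size` bytes
	if err != nil {
		return err
	}
	e.Memset(mem, float64(0)) // fill the allocation with a value (error ignored for brevity)
	e.Memclr(mem)             // ...or zero it directly
	return e.Free(mem, size)  // hand the memory back to the engine
}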

type standardEngine interface {
	Engine

	Adder
	Suber
	Muler
	Diver
	Power
	Moder
	FMAer
	MatMuler
	MatVecMuler
	OuterProder
	Dotter
	SVDer
	Lter
	Lteer
	Gter
	Gteer
	ElEqer
	MinBetweener
	MaxBetweener

	// Anything that returns interface{} cannot be added here, because such capabilities will likely have additional
	// versions of their functions optimized for specific types.
	// For example: Tracer and InnerProder both have optimized interfaces for Float32 and Float64 which return those types specifically.
}

type arrayMaker interface {
	makeArray(arr *array, t Dtype, size int)
}

// NonStdEngine is any engine that does not allocate using the default built-in allocator.
type NonStdEngine interface {
	NonStdAlloc() // noop
}

/* Data Agnostic Execution Engine Methods */

// Transposer is any engine that can perform an unsafe transpose of a tensor.
type Transposer interface {
	Transpose(t Tensor, expStrides []int) error
}

// Concater is any engine that can concatenate multiple Tensors together.
type Concater interface {
	Concat(t Tensor, axis int, others ...Tensor) (Tensor, error)
}

// Stacker is any engine that can stack multiple Tensors along an axis.
type Stacker interface {
	Stack(t Tensor, axis int, others ...Tensor) (Tensor, error)
}

// DenseStacker is any engine that can stack DenseTensors along an axis. This is a specialization of Stacker.
type DenseStacker interface {
	StackDense(t DenseTensor, axis int, others ...DenseTensor) (retVal DenseTensor, err error)
}

// Repeater is any engine that can repeat values along the given axis.
type Repeater interface {
	Repeat(t Tensor, axis int, repeats ...int) (Tensor, error)
	RepeatReuse(t Tensor, reuse Tensor, axis int, repeats ...int) (Tensor, error)
}

// Diager is any engine that can return a tensor that only contains the diagonal values of the input.
type Diager interface {
	Diag(a Tensor) (Tensor, error)
}

/* NUMBER INTERFACES
All these are expected to be unsafe on the first tensor
*/

// Adder is any engine that can perform elementwise addition.
type Adder interface {
	// Add performs a + b
	Add(a, b Tensor, opts ...FuncOpt) (Tensor, error)

	// AddScalar adds a scalar to the tensor. leftTensor indicates if the tensor is the left operand.
	// Whether or not the input tensor is clobbered is left to the implementation.
	AddScalar(a Tensor, b interface{}, leftTensor bool, opts ...FuncOpt) (Tensor, error)
}
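
// The function below is an illustrative sketch (hypothetical, not part of this package's API)
// of the capability-check pattern these interfaces enable: a caller asserts that a tensor's
// engine implements Adder before performing elementwise addition. It assumes, as elsewhere in
// this package, that a Tensor exposes its engine via Engine().
func exampleAdd(a, b Tensor) (Tensor, error) {
	adder, ok := a.Engine().(Adder)
	if !ok {
		return nil, nil // a real implementation would return an "unsupported operation" error here
	}
	return adder.Add(a, b) // the same pattern applies to Suber, Muler, Diver, and the other capabilities
}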

// Suber is any engine that can perform elementwise subtraction.
type Suber interface {
	// Sub performs a - b
	Sub(a, b Tensor, opts ...FuncOpt) (Tensor, error)

	// SubScalar subtracts a scalar from/to the tensor. leftTensor indicates if the tensor is the left operand.
	// Whether or not the input tensor is clobbered is left to the implementation.
	SubScalar(a Tensor, b interface{}, leftTensor bool, opts ...FuncOpt) (Tensor, error)
}

// Muler is any engine that can perform elementwise multiplication.
// For matrix multiplication, an engine should implement MatMul() or MatVecMul() or Inner().
type Muler interface {
	// Mul performs a * b
	Mul(a, b Tensor, opts ...FuncOpt) (Tensor, error)

	// MulScalar multiplies the tensor by a scalar. leftTensor indicates if the tensor is the left operand.
	// Whether or not the input tensor is clobbered is left to the implementation.
	MulScalar(a Tensor, b interface{}, leftTensor bool, opts ...FuncOpt) (Tensor, error)
}

// Diver is any engine that can perform elementwise division.
type Diver interface {
	// Div performs a / b
	Div(a, b Tensor, opts ...FuncOpt) (Tensor, error)

	// DivScalar divides a scalar from/to the tensor. leftTensor indicates if the tensor is the left operand.
	// Whether or not the input tensor is clobbered is left to the implementation.
	DivScalar(a Tensor, b interface{}, leftTensor bool, opts ...FuncOpt) (Tensor, error)
}

// Power is any engine that can perform elementwise Pow().
type Power interface {
	// Pow performs a ^ b
	Pow(a, b Tensor, opts ...FuncOpt) (Tensor, error)

	// PowScalar performs elementwise exponentiation where one of the operands is a scalar. leftTensor indicates if the tensor is the left operand.
	// Whether or not the input tensor is clobbered is left to the implementation.
	PowScalar(a Tensor, b interface{}, leftTensor bool, opts ...FuncOpt) (Tensor, error)
}

// Moder is any engine that can perform elementwise Mod().
type Moder interface {
	// Mod performs a % b
	Mod(a, b Tensor, opts ...FuncOpt) (Tensor, error)

	// ModScalar performs a % b where one of the operands is a scalar. leftTensor indicates if the tensor is the left operand.
	// Whether or not the input tensor is clobbered is left to the implementation.
	ModScalar(a Tensor, b interface{}, leftTensor bool, opts ...FuncOpt) (Tensor, error)
}

// MinBetweener is any engine that can perform an elementwise min-between.
type MinBetweener interface {
	MinBetween(a, b Tensor, opts ...FuncOpt) (Tensor, error)

	MinBetweenScalar(a Tensor, b interface{}, leftTensor bool, opts ...FuncOpt) (Tensor, error)
}

// MaxBetweener is any engine that can perform an elementwise max-between.
type MaxBetweener interface {
	MaxBetween(a, b Tensor, opts ...FuncOpt) (Tensor, error)

	MaxBetweenScalar(a Tensor, b interface{}, leftTensor bool, opts ...FuncOpt) (Tensor, error)
}

/* LINEAR ALGEBRA INTERFACES */

// Tracer is any engine that can return the trace (aka the sum of the diagonal elements).
type Tracer interface {
	Trace(a Tensor) (interface{}, error)
}

// FMAer is any engine that can perform fused multiply-add functions: A * X + Y. Also known as Axpy.
type FMAer interface {
	FMA(a, x, y Tensor) (Tensor, error)
	FMAScalar(a Tensor, x interface{}, y Tensor) (Tensor, error)
}

// MatMuler is any engine that can perform matrix multiplication.
type MatMuler interface {
	MatMul(a, b, preallocated Tensor) error
}
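
// The function below is an illustrative sketch (hypothetical, not part of this package's API)
// of the preallocated-output convention used by the linear-algebra capabilities: the caller
// supplies the destination tensor and the engine writes the product into it.
func exampleMatMul(e Engine, a, b, preallocated Tensor) error {
	mm, ok := e.(MatMuler)
	if !ok {
		return nil // a real implementation would return an "unsupported operation" error here
	}
	// preallocated must already have the correct shape for the product of a and b.
	return mm.MatMul(a, b, preallocated)
}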

// MatVecMuler is any engine that can perform matrix-vector multiplication.
type MatVecMuler interface {
	MatVecMul(a, b, preallocated Tensor) error
}

// InnerProder is any engine that can perform inner product multiplication.
type InnerProder interface {
	Inner(a, b Tensor) (interface{}, error) // Inner always returns a scalar value
}

// InnerProderF32 is an optimization for float32 - results are returned as float32.
type InnerProderF32 interface {
	Inner(a, b Tensor) (float32, error)
}

// InnerProderF64 is an optimization for float64 - results are returned as float64.
type InnerProderF64 interface {
	Inner(a, b Tensor) (float64, error)
}

// OuterProder is any engine that can perform outer product (Kronecker) multiplication.
type OuterProder interface {
	Outer(a, b, preallocated Tensor) error
}

// Dotter is used to implement sparse matrices.
type Dotter interface {
	Dot(a, b Tensor, opts ...FuncOpt) (Tensor, error)
}

// SVDer is any engine that can perform SVD.
type SVDer interface {
	SVD(a Tensor, uv, full bool) (s, u, v Tensor, err error)
}

/* ORD INTERFACES */

// Lter is any engine that can perform the Lt (elementwise less-than) operation.
type Lter interface {
	Lt(a, b Tensor, opts ...FuncOpt) (Tensor, error)
	LtScalar(a Tensor, b interface{}, leftTensor bool, opts ...FuncOpt) (Tensor, error)
}

// Lteer is any engine that can perform the Lte (elementwise less-than-or-equal) operation.
type Lteer interface {
	Lte(a, b Tensor, opts ...FuncOpt) (Tensor, error)
	LteScalar(a Tensor, b interface{}, leftTensor bool, opts ...FuncOpt) (Tensor, error)
}

// Gter is any engine that can perform the Gt (elementwise greater-than) operation.
type Gter interface {
	Gt(a, b Tensor, opts ...FuncOpt) (Tensor, error)
	GtScalar(a Tensor, b interface{}, leftTensor bool, opts ...FuncOpt) (Tensor, error)
}

// Gteer is any engine that can perform the Gte (elementwise greater-than-or-equal) operation.
type Gteer interface {
	Gte(a, b Tensor, opts ...FuncOpt) (Tensor, error)
	GteScalar(a Tensor, b interface{}, leftTensor bool, opts ...FuncOpt) (Tensor, error)
}

/* EQ INTERFACES */

// ElEqer is any engine that can perform the elementwise equality (and inequality) comparison operations.
type ElEqer interface {
	ElEq(a, b Tensor, opts ...FuncOpt) (Tensor, error)
	EqScalar(a Tensor, b interface{}, leftTensor bool, opts ...FuncOpt) (Tensor, error)

	ElNe(a, b Tensor, opts ...FuncOpt) (Tensor, error)
	NeScalar(a Tensor, b interface{}, leftTensor bool, opts ...FuncOpt) (Tensor, error)
}

/* Unary Operators for Numbers */

// Mapper is any engine that can map a function onto the values of a tensor.
type Mapper interface {
	Map(fn interface{}, a Tensor, opts ...FuncOpt) (Tensor, error)
}

// Neger is any engine that can negate the sign of the values in the tensor.
type Neger interface {
	Neg(a Tensor, opts ...FuncOpt) (Tensor, error)
}

// Inver is any engine that can perform 1/x for each element in the Tensor.
type Inver interface {
	Inv(a Tensor, opts ...FuncOpt) (Tensor, error)
}

// Squarer is any engine that can square the values elementwise in a Tensor.
type Squarer interface {
	Square(a Tensor, opts ...FuncOpt) (Tensor, error)
}

// Cuber is any engine that can cube the values elementwise in a Tensor.
type Cuber interface {
	Cube(a Tensor, opts ...FuncOpt) (Tensor, error)
}

// Exper is any engine that can perform elementwise natural exponentiation on the values in a Tensor.
type Exper interface {
	Exp(a Tensor, opts ...FuncOpt) (Tensor, error)
}

// Tanher is any engine that can perform elementwise Tanh on the values in a Tensor.
type Tanher interface {
	Tanh(a Tensor, opts ...FuncOpt) (Tensor, error)
}

// Loger is any engine that can perform natural log on the values in a Tensor.
type Loger interface {
	Log(a Tensor, opt ...FuncOpt) (Tensor, error)
}

// Log2er is any engine that can perform base-2 logarithm on the values in a Tensor.
type Log2er interface {
	Log2(a Tensor, opt ...FuncOpt) (Tensor, error)
}

// Log10er is any engine that can perform base-10 logarithm on the values in a Tensor.
type Log10er interface {
	Log10(a Tensor, opt ...FuncOpt) (Tensor, error)
}

// Sqrter is any engine that can perform square root on the values in a Tensor.
type Sqrter interface {
	Sqrt(a Tensor, opt ...FuncOpt) (Tensor, error)
}

// Cbrter is any engine that can perform cube root on the values in a Tensor.
type Cbrter interface {
	Cbrt(a Tensor, opt ...FuncOpt) (Tensor, error)
}

// InvSqrter is any engine that can perform 1/sqrt(x) on the values of a Tensor.
type InvSqrter interface {
	InvSqrt(a Tensor, opts ...FuncOpt) (Tensor, error)
}

// Signer is any engine that can perform a sign function on the values of a Tensor.
type Signer interface {
	Sign(a Tensor, opts ...FuncOpt) (Tensor, error)
}

// Abser is any engine that can perform Abs on the values of a Tensor.
type Abser interface {
	Abs(a Tensor, opts ...FuncOpt) (Tensor, error)
}

// Clamper is any engine that can clamp the values in a tensor to between min and max.
type Clamper interface {
	Clamp(a Tensor, min, max interface{}, opts ...FuncOpt) (Tensor, error)
}

/* Reduction */

// Reducer is any engine that can perform a reduction function.
type Reducer interface {
	Reduce(fn interface{}, a Tensor, axis int, defaultValue interface{}, opts ...FuncOpt) (Tensor, error)
}
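
// The function below is an illustrative sketch (hypothetical, not part of this package's API)
// showing how a sum along axis 0 could be expressed through the generic Reducer capability:
// an ordinary Go function is passed as the reduction and 0.0 as the default (identity) value.
// It assumes, as elsewhere in this package, that a Tensor exposes its engine via Engine(), and
// that the engine accepts a func(float64, float64) float64 for float64 data.
func exampleSumViaReduce(a Tensor) (Tensor, error) {
	r, ok := a.Engine().(Reducer)
	if !ok {
		return nil, nil // a real implementation would return an "unsupported operation" error here
	}
	add := func(x, y float64) float64 { return x + y }
	return r.Reduce(add, a, 0, 0.0)
}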

// OptimizedReducer is any engine that can perform a reduction function with optimizations for the first dimension, the last dimension, and dimensions in between.
type OptimizedReducer interface {
	OptimizedReduce(a Tensor, axis int, firstFn, lastFn, defaultFn, defaultValue interface{}, opts ...FuncOpt) (Tensor, error)
}

// Sumer is any engine that can perform summation along an axis of a Tensor.
type Sumer interface {
	Sum(a Tensor, along ...int) (Tensor, error)
}

// Proder is any engine that can perform product along an axis of a Tensor.
type Proder interface {
	Prod(a Tensor, along ...int) (Tensor, error)
}

// Miner is any engine that can find the minimum value along an axis of a Tensor.
type Miner interface {
	Min(a Tensor, along ...int) (Tensor, error)
}

// Maxer is any engine that can find the maximum value along an axis of a Tensor.
type Maxer interface {
	Max(a Tensor, along ...int) (Tensor, error)
}

/* Arg methods */

// Argmaxer is any engine that can find the indices of the maximum values along an axis.
// By convention the returned Tensor has Dtype of Int.
type Argmaxer interface {
	Argmax(t Tensor, axis int) (Tensor, error)
}

// Argminer is any engine that can find the indices of the minimum values along an axis.
// By convention the returned Tensor has Dtype of Int.
type Argminer interface {
	Argmin(t Tensor, axis int) (Tensor, error)
}

// NaNChecker checks whether the tensor contains any NaN values.
// Errors are to be returned if the concept of NaN does not apply to the data type.
// Other errors may also occur. See specific implementations for details.
type NaNChecker interface {
	HasNaN(t Tensor) (bool, error)
}

// InfChecker checks whether the tensor contains any Inf values.
// Errors are to be returned if the concept of Inf does not apply to the data type.
// Other errors may also occur. See specific implementations for details.
type InfChecker interface {
	HasInf(t Tensor) (bool, error)
}

/* Advanced Indexing */

// ByIndiceser allows for values in tensor `a` to be selected by the indices listed in the `indices` tensor.
type ByIndiceser interface {
	SelectByIndices(a, indices Tensor, axis int, opts ...FuncOpt) (retVal Tensor, err error)
	SelectByIndicesB(input, outGrad, indices Tensor, axis int, opts ...FuncOpt) (retVal Tensor, err error)
}

/* Internal interfaces for faster operations */

type denseArgmaxer interface {
	argmaxDenseTensor(t DenseTensor, axis int) (*Dense, error)
}

type denseArgminer interface {
	argminDenseTensor(t DenseTensor, axis int) (*Dense, error)
}

// SoftMaxer is any engine that can perform softmax and log-softmax along an axis, as well as their backwards passes.
type SoftMaxer interface {
	LogSoftMax(x Tensor, axis int, opts ...FuncOpt) (retVal Tensor, err error)
	LogSoftMaxB(output, grad Tensor, axis int, opts ...FuncOpt) (retVal Tensor, err error)

	SoftMax(x Tensor, axis int, opts ...FuncOpt) (retVal Tensor, err error)
	SoftMaxB(output, grad Tensor, axis int, opts ...FuncOpt) (retVal Tensor, err error)
}