gorgonia.org/tensor@v0.9.24/example_extension_test.go (about)

     1  package tensor_test
     2  
     3  import (
     4  	//"errors"
     5  	"fmt"
     6  	"reflect"
     7  
     8  	"github.com/pkg/errors"
     9  	"gorgonia.org/tensor"
    10  )
    11  
    12  // In this example, we want to create and handle a tensor of *MyType
    13  
    14  // First, define MyType
    15  
    16  // MyType is defined
    17  type MyType struct {
    18  	x, y int
    19  }
    20  
    21  func (T MyType) Format(s fmt.State, c rune) { fmt.Fprintf(s, "(%d, %d)", T.x, T.y) }
    22  
    23  // MyDtype this the dtype of MyType. This value is populated in the init() function below
    24  var MyDtype tensor.Dtype
    25  
    26  // MyEngine supports additions of MyType, as well as other Dtypes
    27  type MyEngine struct {
    28  	tensor.StdEng
    29  }
    30  
// For simplicity's sake, we only handle MyType-MyType, MyType-Int and Int-MyType
// interactions, and we only expect Dense tensors.
// You are, of course, free to define your own rules.
    34  
    35  // Add adds two tensors
    36  func (e MyEngine) Add(a, b tensor.Tensor, opts ...tensor.FuncOpt) (retVal tensor.Tensor, err error) {
    37  	switch a.Dtype() {
    38  	case MyDtype:
    39  		switch b.Dtype() {
    40  		case MyDtype:
    41  			data := a.Data().([]*MyType)
    42  			datb := b.Data().([]*MyType)
    43  			for i, v := range data {
    44  				v.x += datb[i].x
    45  				v.y += datb[i].y
    46  			}
    47  			return a, nil
    48  		case tensor.Int:
    49  			data := a.Data().([]*MyType)
    50  			datb := b.Data().([]int)
    51  			for i, v := range data {
    52  				v.x += datb[i]
    53  				v.y += datb[i]
    54  			}
    55  			return a, nil
    56  		}
    57  	case tensor.Int:
    58  		switch b.Dtype() {
    59  		case MyDtype:
    60  			data := a.Data().([]int)
    61  			datb := b.Data().([]*MyType)
    62  			for i, v := range datb {
    63  				v.x += data[i]
    64  				v.y += data[i]
    65  			}
    66  		default:
    67  			return e.StdEng.Add(a, b, opts...)
    68  		}
    69  	default:
    70  		return e.StdEng.Add(a, b, opts...)
    71  	}
    72  	return nil, errors.New("Unreachable")
    73  }
    74  
    75  func init() {
    76  	MyDtype = tensor.Dtype{reflect.TypeOf(&MyType{})}
    77  }
    78  
    79  func Example_extension() {
    80  	T := tensor.New(tensor.WithEngine(MyEngine{}),
    81  		tensor.WithShape(2, 2),
    82  		tensor.WithBacking([]*MyType{
    83  			&MyType{0, 0}, &MyType{0, 1},
    84  			&MyType{1, 0}, &MyType{1, 1},
    85  		}))
    86  	ones := tensor.New(tensor.WithShape(2, 2), tensor.WithBacking([]int{1, 1, 1, 1}), tensor.WithEngine(MyEngine{}))
    87  	T2, _ := T.Add(ones)
    88  
    89  	fmt.Printf("T:\n%+v", T)
    90  	fmt.Printf("T2:\n%+v", T2)
    91  
    92  	// output:
    93  	//T:
    94  	// Matrix (2, 2) [2 1]
    95  	// ⎡(1, 1)  (1, 2)⎤
    96  	// ⎣(2, 1)  (2, 2)⎦
    97  	// T2:
    98  	// Matrix (2, 2) [2 1]
    99  	// ⎡(1, 1)  (1, 2)⎤
   100  	// ⎣(2, 1)  (2, 2)⎦
   101  }