gorgonia.org/tensor@v0.9.24/example_extension_test.go (about) 1 package tensor_test 2 3 import ( 4 //"errors" 5 "fmt" 6 "reflect" 7 8 "github.com/pkg/errors" 9 "gorgonia.org/tensor" 10 ) 11 12 // In this example, we want to create and handle a tensor of *MyType 13 14 // First, define MyType 15 16 // MyType is defined 17 type MyType struct { 18 x, y int 19 } 20 21 func (T MyType) Format(s fmt.State, c rune) { fmt.Fprintf(s, "(%d, %d)", T.x, T.y) } 22 23 // MyDtype this the dtype of MyType. This value is populated in the init() function below 24 var MyDtype tensor.Dtype 25 26 // MyEngine supports additions of MyType, as well as other Dtypes 27 type MyEngine struct { 28 tensor.StdEng 29 } 30 31 // For simplicity's sake, we'd only want to handle MyType-MyType or MyType-Int interactions 32 // Also, we only expect Dense tensors 33 // You're of course free to define your own rules 34 35 // Add adds two tensors 36 func (e MyEngine) Add(a, b tensor.Tensor, opts ...tensor.FuncOpt) (retVal tensor.Tensor, err error) { 37 switch a.Dtype() { 38 case MyDtype: 39 switch b.Dtype() { 40 case MyDtype: 41 data := a.Data().([]*MyType) 42 datb := b.Data().([]*MyType) 43 for i, v := range data { 44 v.x += datb[i].x 45 v.y += datb[i].y 46 } 47 return a, nil 48 case tensor.Int: 49 data := a.Data().([]*MyType) 50 datb := b.Data().([]int) 51 for i, v := range data { 52 v.x += datb[i] 53 v.y += datb[i] 54 } 55 return a, nil 56 } 57 case tensor.Int: 58 switch b.Dtype() { 59 case MyDtype: 60 data := a.Data().([]int) 61 datb := b.Data().([]*MyType) 62 for i, v := range datb { 63 v.x += data[i] 64 v.y += data[i] 65 } 66 default: 67 return e.StdEng.Add(a, b, opts...) 68 } 69 default: 70 return e.StdEng.Add(a, b, opts...) 71 } 72 return nil, errors.New("Unreachable") 73 } 74 75 func init() { 76 MyDtype = tensor.Dtype{reflect.TypeOf(&MyType{})} 77 } 78 79 func Example_extension() { 80 T := tensor.New(tensor.WithEngine(MyEngine{}), 81 tensor.WithShape(2, 2), 82 tensor.WithBacking([]*MyType{ 83 &MyType{0, 0}, &MyType{0, 1}, 84 &MyType{1, 0}, &MyType{1, 1}, 85 })) 86 ones := tensor.New(tensor.WithShape(2, 2), tensor.WithBacking([]int{1, 1, 1, 1}), tensor.WithEngine(MyEngine{})) 87 T2, _ := T.Add(ones) 88 89 fmt.Printf("T:\n%+v", T) 90 fmt.Printf("T2:\n%+v", T2) 91 92 // output: 93 //T: 94 // Matrix (2, 2) [2 1] 95 // ⎡(1, 1) (1, 2)⎤ 96 // ⎣(2, 1) (2, 2)⎦ 97 // T2: 98 // Matrix (2, 2) [2 1] 99 // ⎡(1, 1) (1, 2)⎤ 100 // ⎣(2, 1) (2, 2)⎦ 101 }