gorgonia.org/tensor@v0.9.24/example_extension_matop_test.go (about)

     1  package tensor_test
     2  
     3  import (
     4  	"fmt"
     5  
     6  	"gorgonia.org/tensor"
     7  )
     8  
// In this example, we want to handle basic tensor operations (slicing, stacking, transposing) for arbitrary element types.
    10  
    11  // LongStruct is a type that is an arbitrarily long struct
    12  type LongStruct struct {
    13  	a, b, c, d, e uint64
    14  }
    15  
    16  // Format implements fmt.Formatter for easier-to-read output of data
    17  func (ls LongStruct) Format(s fmt.State, c rune) {
    18  	fmt.Fprintf(s, "{a: %d, b: %d, c: %d, d: %d, e: %d}", ls.a, ls.b, ls.c, ls.d, ls.e)
    19  }
    20  
    21  type s int
    22  
    23  func (ss s) Start() int { return int(ss) }
    24  func (ss s) End() int   { return int(ss) + 1 }
    25  func (ss s) Step() int  { return 1 }
    26  
    27  func ExampleTranspose_extension() {
    28  	// For documentation if you're reading this on godoc:
    29  	//
    30  	// type LongStruct struct {
    31  	// 		a, b, c, d, e uint64
    32  	// }
    33  
    34  	T := tensor.New(tensor.WithShape(2, 2),
    35  		tensor.WithBacking([]LongStruct{
    36  			LongStruct{0, 0, 0, 0, 0},
    37  			LongStruct{1, 1, 1, 1, 1},
    38  			LongStruct{2, 2, 2, 2, 2},
    39  			LongStruct{3, 3, 3, 3, 3},
    40  		}),
    41  	)
    42  
    43  	fmt.Printf("Before:\n%v\n", T)
    44  	retVal, _ := tensor.Transpose(T) // an alternative would be to use T.T(); T.Transpose()
    45  	fmt.Printf("After:\n%v\n", retVal)
    46  
    47  	// Output:
    48  	// Before:
    49  	// ⎡{a: 0, b: 0, c: 0, d: 0, e: 0}  {a: 1, b: 1, c: 1, d: 1, e: 1}⎤
    50  	// ⎣{a: 2, b: 2, c: 2, d: 2, e: 2}  {a: 3, b: 3, c: 3, d: 3, e: 3}⎦
    51  	//
    52  	// After:
    53  	// ⎡{a: 0, b: 0, c: 0, d: 0, e: 0}  {a: 2, b: 2, c: 2, d: 2, e: 2}⎤
    54  	// ⎣{a: 1, b: 1, c: 1, d: 1, e: 1}  {a: 3, b: 3, c: 3, d: 3, e: 3}⎦
    55  }
    56  
    57  func Example_stackExtension() {
    58  	// For documentation if you're reading this on godoc:
    59  	//
    60  	// type LongStruct struct {
    61  	// a, b, c, d, e uint64
    62  	// }
    63  
    64  	T := tensor.New(tensor.WithShape(2, 2),
    65  		tensor.WithBacking([]LongStruct{
    66  			LongStruct{0, 0, 0, 0, 0},
    67  			LongStruct{1, 1, 1, 1, 1},
    68  			LongStruct{2, 2, 2, 2, 2},
    69  			LongStruct{3, 3, 3, 3, 3},
    70  		}),
    71  	)
    72  	S, _ := T.Slice(nil, s(1)) // s is a type that implements tensor.Slice
    73  	T2 := tensor.New(tensor.WithShape(2, 2),
    74  		tensor.WithBacking([]LongStruct{
    75  			LongStruct{10, 10, 10, 10, 10},
    76  			LongStruct{11, 11, 11, 11, 11},
    77  			LongStruct{12, 12, 12, 12, 12},
    78  			LongStruct{13, 13, 13, 13, 13},
    79  		}),
    80  	)
    81  	S2, _ := T2.Slice(nil, s(0))
    82  
    83  	// an alternative would be something like this
    84  	// T3, _ := S.(*tensor.Dense).Stack(1, S2.(*tensor.Dense))
    85  	T3, _ := tensor.Stack(1, S, S2)
    86  	fmt.Printf("Stacked:\n%v", T3)
    87  
    88  	// Output:
    89  	// Stacked:
    90  	// ⎡     {a: 1, b: 1, c: 1, d: 1, e: 1}  {a: 10, b: 10, c: 10, d: 10, e: 10}⎤
    91  	// ⎣     {a: 3, b: 3, c: 3, d: 3, e: 3}  {a: 12, b: 12, c: 12, d: 12, e: 12}⎦
    92  }