gorgonia.org/tensor@v0.9.24/example_byindices_test.go (about) 1 package tensor 2 3 import "fmt" 4 5 func ExampleByIndices() { 6 a := New(WithShape(2, 2), WithBacking([]float64{ 7 100, 200, 8 300, 400, 9 })) 10 indices := New(WithBacking([]int{1, 1, 1, 0, 1})) 11 b, err := ByIndices(a, indices, 0) // we select rows 1, 1, 1, 0, 1 12 if err != nil { 13 fmt.Println(err) 14 return 15 } 16 17 fmt.Printf("a:\n%v\nindices: %v\nb:\n%v\n", a, indices, b) 18 19 // Output: 20 // a: 21 // ⎡100 200⎤ 22 // ⎣300 400⎦ 23 // 24 // indices: [1 1 1 0 1] 25 // b: 26 // ⎡300 400⎤ 27 // ⎢300 400⎥ 28 // ⎢300 400⎥ 29 // ⎢100 200⎥ 30 // ⎣300 400⎦ 31 32 } 33 34 func ExampleByIndicesB() { 35 a := New(WithShape(2, 2), WithBacking([]float64{ 36 100, 200, 37 300, 400, 38 })) 39 indices := New(WithBacking([]int{1, 1, 1, 0, 1})) 40 b, err := ByIndices(a, indices, 0) // we select rows 1, 1, 1, 0, 1 41 if err != nil { 42 fmt.Println(err) 43 return 44 } 45 46 outGrad := b.Clone().(*Dense) 47 outGrad.Memset(1.0) 48 49 grad, err := ByIndicesB(a, outGrad, indices, 0) 50 if err != nil { 51 fmt.Println(err) 52 return 53 } 54 55 fmt.Printf("a:\n%v\nindices: %v\nb:\n%v\ngrad:\n%v", a, indices, b, grad) 56 57 // Output: 58 // a: 59 // ⎡100 200⎤ 60 // ⎣300 400⎦ 61 // 62 // indices: [1 1 1 0 1] 63 // b: 64 // ⎡300 400⎤ 65 // ⎢300 400⎥ 66 // ⎢300 400⎥ 67 // ⎢100 200⎥ 68 // ⎣300 400⎦ 69 // 70 // grad: 71 // ⎡1 1⎤ 72 // ⎣4 4⎦ 73 74 }