gorgonia.org/tensor@v0.9.24/example_byindices_test.go (about)

     1  package tensor
     2  
     3  import "fmt"
     4  
     5  func ExampleByIndices() {
     6  	a := New(WithShape(2, 2), WithBacking([]float64{
     7  		100, 200,
     8  		300, 400,
     9  	}))
    10  	indices := New(WithBacking([]int{1, 1, 1, 0, 1}))
    11  	b, err := ByIndices(a, indices, 0) // we select rows 1, 1, 1, 0, 1
    12  	if err != nil {
    13  		fmt.Println(err)
    14  		return
    15  	}
    16  
    17  	fmt.Printf("a:\n%v\nindices: %v\nb:\n%v\n", a, indices, b)
    18  
    19  	// Output:
    20  	// a:
    21  	// ⎡100  200⎤
    22  	// ⎣300  400⎦
    23  	//
    24  	// indices: [1  1  1  0  1]
    25  	// b:
    26  	// ⎡300  400⎤
    27  	// ⎢300  400⎥
    28  	// ⎢300  400⎥
    29  	// ⎢100  200⎥
    30  	// ⎣300  400⎦
    31  
    32  }
    33  
    34  func ExampleByIndicesB() {
    35  	a := New(WithShape(2, 2), WithBacking([]float64{
    36  		100, 200,
    37  		300, 400,
    38  	}))
    39  	indices := New(WithBacking([]int{1, 1, 1, 0, 1}))
    40  	b, err := ByIndices(a, indices, 0) // we select rows 1, 1, 1, 0, 1
    41  	if err != nil {
    42  		fmt.Println(err)
    43  		return
    44  	}
    45  
    46  	outGrad := b.Clone().(*Dense)
    47  	outGrad.Memset(1.0)
    48  
    49  	grad, err := ByIndicesB(a, outGrad, indices, 0)
    50  	if err != nil {
    51  		fmt.Println(err)
    52  		return
    53  	}
    54  
    55  	fmt.Printf("a:\n%v\nindices: %v\nb:\n%v\ngrad:\n%v", a, indices, b, grad)
    56  
    57  	// Output:
    58  	// a:
    59  	// ⎡100  200⎤
    60  	// ⎣300  400⎦
    61  	//
    62  	// indices: [1  1  1  0  1]
    63  	// b:
    64  	// ⎡300  400⎤
    65  	// ⎢300  400⎥
    66  	// ⎢300  400⎥
    67  	// ⎢100  200⎥
    68  	// ⎣300  400⎦
    69  	//
    70  	// grad:
    71  	// ⎡1  1⎤
    72  	// ⎣4  4⎦
    73  
    74  }