gorgonia.org/gorgonia@v0.9.17/op_nondiff_test.go (about)

     1  package gorgonia
     2  
     3  import (
     4  	"fmt"
     5  
     6  	"gorgonia.org/tensor"
     7  )
     8  
     9  func ExampleDiagFlat() {
    10  	g := NewGraph()
    11  
    12  	// 2 dimensional
    13  	aV := tensor.New(tensor.WithShape(2, 2), tensor.WithBacking([]float64{1, 2, 3, 4}))
    14  	a := NodeFromAny(g, aV)
    15  	b, err := DiagFlat(a)
    16  	if err != nil {
    17  		fmt.Println(err)
    18  		return
    19  	}
    20  	m := NewTapeMachine(g)
    21  	if err := m.RunAll(); err != nil {
    22  		fmt.Println(err)
    23  		return
    24  	}
    25  	fmt.Printf("a:\n%v\n", a.Value())
    26  	fmt.Printf("b:\n%v\n", b.Value())
    27  
    28  	// 3 dimensional
    29  	aV = tensor.New(tensor.WithShape(2, 3, 2), tensor.WithBacking([]float64{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12}))
    30  	a = NodeFromAny(g, aV, WithName("a'"))
    31  	b2, err := DiagFlat(a)
    32  	if err != nil {
    33  		fmt.Println(err)
    34  		return
    35  	}
    36  	m = NewTapeMachine(g)
    37  	if err := m.RunAll(); err != nil {
    38  		fmt.Println(err)
    39  	}
    40  
    41  	fmt.Printf("a:\n%v", a.Value())
    42  	fmt.Printf("b:\n%v\n", b2.Value())
    43  
    44  	// 1 dimensional
    45  	aV = tensor.New(tensor.WithShape(2), tensor.WithBacking([]float64{1, 2}))
    46  	a = NodeFromAny(g, aV, WithName("a''"))
    47  	b3, err := DiagFlat(a)
    48  	if err != nil {
    49  		fmt.Println(err)
    50  		return
    51  	}
    52  	m = NewTapeMachine(g)
    53  	if err := m.RunAll(); err != nil {
    54  		fmt.Println(err)
    55  	}
    56  
    57  	fmt.Printf("a:\n%v\n", a.Value())
    58  	fmt.Printf("b:\n%v\n", b3.Value())
    59  
    60  	// Scalars
    61  
    62  	a = NodeFromAny(g, 100.0, WithName("aScalar"))
    63  	_, err = DiagFlat(a)
    64  	fmt.Println(err)
    65  
    66  	// Output:
    67  	// a:
    68  	// ⎡1  2⎤
    69  	// ⎣3  4⎦
    70  	//
    71  	// b:
    72  	// ⎡1  0  0  0⎤
    73  	// ⎢0  2  0  0⎥
    74  	// ⎢0  0  3  0⎥
    75  	// ⎣0  0  0  4⎦
    76  	//
    77  	// a:
    78  	// ⎡ 1   2⎤
    79  	// ⎢ 3   4⎥
    80  	// ⎣ 5   6⎦
    81  	//
    82  	// ⎡ 7   8⎤
    83  	// ⎢ 9  10⎥
    84  	// ⎣11  12⎦
    85  	//
    86  	//
    87  	// b:
    88  	// ⎡ 1   0   0   0  ...  0   0   0   0⎤
    89  	// ⎢ 0   2   0   0  ...  0   0   0   0⎥
    90  	// ⎢ 0   0   3   0  ...  0   0   0   0⎥
    91  	// ⎢ 0   0   0   4  ...  0   0   0   0⎥
    92  	// .
    93  	// .
    94  	// .
    95  	// ⎢ 0   0   0   0  ...  9   0   0   0⎥
    96  	// ⎢ 0   0   0   0  ...  0  10   0   0⎥
    97  	// ⎢ 0   0   0   0  ...  0   0  11   0⎥
    98  	// ⎣ 0   0   0   0  ...  0   0   0  12⎦
    99  	//
   100  	// a:
   101  	// [1  2]
   102  	// b:
   103  	// ⎡1  0⎤
   104  	// ⎣0  2⎦
   105  	//
   106  	// Cannot perform DiagFlat on a scalar equivalent node
   107  
   108  }