gorgonia.org/gorgonia@v0.9.17/op_nondiff_test.go (about) 1 package gorgonia 2 3 import ( 4 "fmt" 5 6 "gorgonia.org/tensor" 7 ) 8 9 func ExampleDiagFlat() { 10 g := NewGraph() 11 12 // 2 dimensional 13 aV := tensor.New(tensor.WithShape(2, 2), tensor.WithBacking([]float64{1, 2, 3, 4})) 14 a := NodeFromAny(g, aV) 15 b, err := DiagFlat(a) 16 if err != nil { 17 fmt.Println(err) 18 return 19 } 20 m := NewTapeMachine(g) 21 if err := m.RunAll(); err != nil { 22 fmt.Println(err) 23 return 24 } 25 fmt.Printf("a:\n%v\n", a.Value()) 26 fmt.Printf("b:\n%v\n", b.Value()) 27 28 // 3 dimensional 29 aV = tensor.New(tensor.WithShape(2, 3, 2), tensor.WithBacking([]float64{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12})) 30 a = NodeFromAny(g, aV, WithName("a'")) 31 b2, err := DiagFlat(a) 32 if err != nil { 33 fmt.Println(err) 34 return 35 } 36 m = NewTapeMachine(g) 37 if err := m.RunAll(); err != nil { 38 fmt.Println(err) 39 } 40 41 fmt.Printf("a:\n%v", a.Value()) 42 fmt.Printf("b:\n%v\n", b2.Value()) 43 44 // 1 dimensional 45 aV = tensor.New(tensor.WithShape(2), tensor.WithBacking([]float64{1, 2})) 46 a = NodeFromAny(g, aV, WithName("a''")) 47 b3, err := DiagFlat(a) 48 if err != nil { 49 fmt.Println(err) 50 return 51 } 52 m = NewTapeMachine(g) 53 if err := m.RunAll(); err != nil { 54 fmt.Println(err) 55 } 56 57 fmt.Printf("a:\n%v\n", a.Value()) 58 fmt.Printf("b:\n%v\n", b3.Value()) 59 60 // Scalars 61 62 a = NodeFromAny(g, 100.0, WithName("aScalar")) 63 _, err = DiagFlat(a) 64 fmt.Println(err) 65 66 // Output: 67 // a: 68 // ⎡1 2⎤ 69 // ⎣3 4⎦ 70 // 71 // b: 72 // ⎡1 0 0 0⎤ 73 // ⎢0 2 0 0⎥ 74 // ⎢0 0 3 0⎥ 75 // ⎣0 0 0 4⎦ 76 // 77 // a: 78 // ⎡ 1 2⎤ 79 // ⎢ 3 4⎥ 80 // ⎣ 5 6⎦ 81 // 82 // ⎡ 7 8⎤ 83 // ⎢ 9 10⎥ 84 // ⎣11 12⎦ 85 // 86 // 87 // b: 88 // ⎡ 1 0 0 0 ... 0 0 0 0⎤ 89 // ⎢ 0 2 0 0 ... 0 0 0 0⎥ 90 // ⎢ 0 0 3 0 ... 0 0 0 0⎥ 91 // ⎢ 0 0 0 4 ... 0 0 0 0⎥ 92 // . 93 // . 94 // . 95 // ⎢ 0 0 0 0 ... 9 0 0 0⎥ 96 // ⎢ 0 0 0 0 ... 0 10 0 0⎥ 97 // ⎢ 0 0 0 0 ... 0 0 11 0⎥ 98 // ⎣ 0 0 0 0 ... 0 0 0 12⎦ 99 // 100 // a: 101 // [1 2] 102 // b: 103 // ⎡1 0⎤ 104 // ⎣0 2⎦ 105 // 106 // Cannot perform DiagFlat on a scalar equivalent node 107 108 }