gorgonia.org/gorgonia@v0.9.17/shape_test.go (about) 1 package gorgonia_test 2 3 import ( 4 "fmt" 5 6 . "gorgonia.org/gorgonia" 7 "gorgonia.org/tensor" 8 ) 9 10 func Example_keepDims() { 11 g := NewGraph() 12 a := NodeFromAny(g, tensor.New(tensor.WithShape(2, 3), tensor.WithBacking([]float64{1, 2, 3, 4, 5, 6}))) 13 m1, _ := Mean(a, 1) 14 m2, _ := KeepDims(a, false, func(a *Node) (*Node, error) { return Mean(a, 1) }) 15 m3, _ := Mean(a, 0) 16 m4, _ := KeepDims(a, true, func(a *Node) (*Node, error) { return Mean(a, 0) }) 17 m5, _ := KeepDims(a, true, func(a *Node) (*Node, error) { return Mean(a) }) 18 19 // these reads are necessary as the VM may feel free to clobber the underlying data. 20 // e.g. if m1.Value() is used in the print statement below, the answer will be wrong. 21 // This is because before the VM executes the operations, a check is done to see if unsafe 22 // operations may be done. Unsafe operations are useful in saving memory. 23 // In this example, Reshape can be unsafely done if no other node is "using" m1, 24 // so m1.Value() will have its shape clobbered. Thus if m1.Value() is read after the VM has run, 25 // there is no guarantee that the data is correct. The only way around this is to "use" m1, by the Read() function. 26 var m1v, m2v, m3v, m4v Value 27 Read(m1, &m1v) 28 Read(m2, &m2v) 29 Read(m3, &m3v) 30 Read(m4, &m4v) 31 32 vm := NewTapeMachine(g) 33 if err := vm.RunAll(); err != nil { 34 panic(err) 35 } 36 37 fmt.Printf("a:\n%v\n", a.Value()) 38 fmt.Printf("m1 (shape: %v):\n%v\n", m1.Value().Shape(), m1v) 39 fmt.Printf("m2 (shape: %v):\n%v\n", m2.Value().Shape(), m2v) 40 fmt.Printf("m3 (shape: %v):\n%v\n", m3.Value().Shape(), m3v) 41 fmt.Printf("m4 (shape: %v):\n%v\n", m4.Value().Shape(), m4v) 42 fmt.Printf("m5 (shape: %v):\n%v\n", m5.Value().Shape(), m5.Value()) 43 44 // Output: 45 // a: 46 // ⎡1 2 3⎤ 47 // ⎣4 5 6⎦ 48 // 49 // m1 (shape: (2)): 50 // [2 5] 51 // m2 (shape: (2, 1)): 52 // C[2 5] 53 // m3 (shape: (3)): 54 // [2.5 3.5 4.5] 55 // m4 (shape: (1, 3)): 56 // R[2.5 3.5 4.5] 57 // m5 (shape: (1, 1)): 58 // [[3.5]] 59 60 }