gorgonia.org/gorgonia@v0.9.17/ops/nn/conv_test.go (about) 1 package nnops_test 2 3 import ( 4 "fmt" 5 "log" 6 "runtime" 7 8 "gorgonia.org/gorgonia" 9 nnops "gorgonia.org/gorgonia/ops/nn" 10 "gorgonia.org/tensor" 11 ) 12 13 func ExampleConv2d() { 14 g := gorgonia.NewGraph() 15 x := gorgonia.NodeFromAny(g, tensor.New( 16 tensor.WithShape(1, 1, 7, 5), 17 tensor.WithBacking([]float32{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34}))) 18 filter := gorgonia.NodeFromAny(g, tensor.New( 19 tensor.WithShape(1, 1, 3, 3), 20 tensor.WithBacking([]float32{1, 1, 1, 1, 1, 1, 1, 1, 1}))) 21 y := gorgonia.Must(nnops.Conv2d(x, filter, []int{3, 3}, []int{0, 0}, []int{2, 2}, []int{1, 1})) 22 m := gorgonia.NewTapeMachine(g) 23 runtime.LockOSThread() 24 for i := 0; i < 1000; i++ { 25 if err := m.RunAll(); err != nil { 26 log.Fatalf("iteration: %d. Err: %v", i, err) 27 } 28 } 29 runtime.UnlockOSThread() 30 31 fmt.Printf("%1.1f", y.Value()) 32 // output: 33 // ⎡ 54.0 72.0⎤ 34 // ⎢144.0 162.0⎥ 35 // ⎣234.0 252.0⎦ 36 }