gorgonia.org/gorgonia@v0.9.17/ops/nn/conv_test.go (about)

     1  package nnops_test
     2  
     3  import (
     4  	"fmt"
     5  	"log"
     6  	"runtime"
     7  
     8  	"gorgonia.org/gorgonia"
     9  	nnops "gorgonia.org/gorgonia/ops/nn"
    10  	"gorgonia.org/tensor"
    11  )
    12  
    13  func ExampleConv2d() {
    14  	g := gorgonia.NewGraph()
    15  	x := gorgonia.NodeFromAny(g, tensor.New(
    16  		tensor.WithShape(1, 1, 7, 5),
    17  		tensor.WithBacking([]float32{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34})))
    18  	filter := gorgonia.NodeFromAny(g, tensor.New(
    19  		tensor.WithShape(1, 1, 3, 3),
    20  		tensor.WithBacking([]float32{1, 1, 1, 1, 1, 1, 1, 1, 1})))
    21  	y := gorgonia.Must(nnops.Conv2d(x, filter, []int{3, 3}, []int{0, 0}, []int{2, 2}, []int{1, 1}))
    22  	m := gorgonia.NewTapeMachine(g)
    23  	runtime.LockOSThread()
    24  	for i := 0; i < 1000; i++ {
    25  		if err := m.RunAll(); err != nil {
    26  			log.Fatalf("iteration: %d. Err: %v", i, err)
    27  		}
    28  	}
    29  	runtime.UnlockOSThread()
    30  
    31  	fmt.Printf("%1.1f", y.Value())
    32  	// output:
    33  	// ⎡ 54.0   72.0⎤
    34  	// ⎢144.0  162.0⎥
    35  	// ⎣234.0  252.0⎦
    36  }