gorgonia.org/gorgonia@v0.9.17/x/vm/bench_test.go (about)

     1  package xvm
     2  
     3  import (
     4  	"context"
     5  	"testing"
     6  
     7  	"gorgonia.org/gorgonia"
     8  	"gorgonia.org/tensor"
     9  )
    10  
    11  func BenchmarkMachine_Run(b *testing.B) {
    12  	g := gorgonia.NewGraph()
    13  	xV := tensor.New(tensor.WithShape(1, 1, 5, 5), tensor.WithBacking([]float32{
    14  		0, 0, 0, 0, 0,
    15  		1, 1, 1, 1, 1,
    16  		2, 2, 2, 2, 2,
    17  		3, 3, 3, 3, 3,
    18  		4, 4, 4, 4, 4,
    19  	}))
    20  	kernelV := tensor.New(tensor.WithShape(1, 1, 3, 3), tensor.WithBacking([]float32{
    21  		1, 1, 1,
    22  		1, 1, 1,
    23  		1, 1, 1,
    24  	}))
    25  
    26  	x := gorgonia.NewTensor(g, gorgonia.Float32, 4, gorgonia.WithShape(1, 1, 5, 5), gorgonia.WithValue(xV), gorgonia.WithName("x"))
    27  	w := gorgonia.NewTensor(g, gorgonia.Float32, 4, gorgonia.WithShape(1, 1, 3, 3), gorgonia.WithValue(kernelV), gorgonia.WithName("w"))
    28  
    29  	_, err := gorgonia.Conv2d(x, w, tensor.Shape{3, 3}, []int{1, 1}, []int{1, 1}, []int{1, 1})
    30  	if err != nil {
    31  		b.Fatal(err)
    32  	}
    33  	// logger := log.New(os.Stderr, "", 0)
    34  	// vm := NewTapeMachine(g, WithLogger(logger), WithWatchlist(), WithValueFmt("%#v"))
    35  
    36  	vm := NewMachine(g)
    37  	ctx := context.Background()
    38  	b.ResetTimer()
    39  	for i := 0; i < b.N; i++ {
    40  		//ctx, cancel := context.WithTimeout(context.Background(), 900*time.Millisecond)
    41  		if err := vm.Run(ctx); err != nil {
    42  			b.Fatal(err)
    43  		}
    44  	}
    45  	vm.Close()
    46  }
    47  func BenchmarkMachine_RunTapeMachine(b *testing.B) {
    48  	g := gorgonia.NewGraph()
    49  	xV := tensor.New(tensor.WithShape(1, 1, 5, 5), tensor.WithBacking([]float32{
    50  		0, 0, 0, 0, 0,
    51  		1, 1, 1, 1, 1,
    52  		2, 2, 2, 2, 2,
    53  		3, 3, 3, 3, 3,
    54  		4, 4, 4, 4, 4,
    55  	}))
    56  	kernelV := tensor.New(tensor.WithShape(1, 1, 3, 3), tensor.WithBacking([]float32{
    57  		1, 1, 1,
    58  		1, 1, 1,
    59  		1, 1, 1,
    60  	}))
    61  
    62  	x := gorgonia.NewTensor(g, gorgonia.Float32, 4, gorgonia.WithShape(1, 1, 5, 5), gorgonia.WithValue(xV), gorgonia.WithName("x"))
    63  	w := gorgonia.NewTensor(g, gorgonia.Float32, 4, gorgonia.WithShape(1, 1, 3, 3), gorgonia.WithValue(kernelV), gorgonia.WithName("w"))
    64  
    65  	_, err := gorgonia.Conv2d(x, w, tensor.Shape{3, 3}, []int{1, 1}, []int{1, 1}, []int{1, 1})
    66  	if err != nil {
    67  		b.Fatal(err)
    68  	}
    69  	// logger := log.New(os.Stderr, "", 0)
    70  	// vm := NewTapeMachine(g, WithLogger(logger), WithWatchlist(), WithValueFmt("%#v"))
    71  
    72  	vm := gorgonia.NewTapeMachine(g)
    73  	b.ResetTimer()
    74  
    75  	for i := 0; i < b.N; i++ {
    76  		if err := vm.RunAll(); err != nil {
    77  			b.Fatal(err)
    78  		}
    79  		vm.Reset()
    80  	}
    81  	vm.Close()
    82  }