gorgonia.org/gorgonia@v0.9.17/x/vm/bench_test.go (about) 1 package xvm 2 3 import ( 4 "context" 5 "testing" 6 7 "gorgonia.org/gorgonia" 8 "gorgonia.org/tensor" 9 ) 10 11 func BenchmarkMachine_Run(b *testing.B) { 12 g := gorgonia.NewGraph() 13 xV := tensor.New(tensor.WithShape(1, 1, 5, 5), tensor.WithBacking([]float32{ 14 0, 0, 0, 0, 0, 15 1, 1, 1, 1, 1, 16 2, 2, 2, 2, 2, 17 3, 3, 3, 3, 3, 18 4, 4, 4, 4, 4, 19 })) 20 kernelV := tensor.New(tensor.WithShape(1, 1, 3, 3), tensor.WithBacking([]float32{ 21 1, 1, 1, 22 1, 1, 1, 23 1, 1, 1, 24 })) 25 26 x := gorgonia.NewTensor(g, gorgonia.Float32, 4, gorgonia.WithShape(1, 1, 5, 5), gorgonia.WithValue(xV), gorgonia.WithName("x")) 27 w := gorgonia.NewTensor(g, gorgonia.Float32, 4, gorgonia.WithShape(1, 1, 3, 3), gorgonia.WithValue(kernelV), gorgonia.WithName("w")) 28 29 _, err := gorgonia.Conv2d(x, w, tensor.Shape{3, 3}, []int{1, 1}, []int{1, 1}, []int{1, 1}) 30 if err != nil { 31 b.Fatal(err) 32 } 33 // logger := log.New(os.Stderr, "", 0) 34 // vm := NewTapeMachine(g, WithLogger(logger), WithWatchlist(), WithValueFmt("%#v")) 35 36 vm := NewMachine(g) 37 ctx := context.Background() 38 b.ResetTimer() 39 for i := 0; i < b.N; i++ { 40 //ctx, cancel := context.WithTimeout(context.Background(), 900*time.Millisecond) 41 if err := vm.Run(ctx); err != nil { 42 b.Fatal(err) 43 } 44 } 45 vm.Close() 46 } 47 func BenchmarkMachine_RunTapeMachine(b *testing.B) { 48 g := gorgonia.NewGraph() 49 xV := tensor.New(tensor.WithShape(1, 1, 5, 5), tensor.WithBacking([]float32{ 50 0, 0, 0, 0, 0, 51 1, 1, 1, 1, 1, 52 2, 2, 2, 2, 2, 53 3, 3, 3, 3, 3, 54 4, 4, 4, 4, 4, 55 })) 56 kernelV := tensor.New(tensor.WithShape(1, 1, 3, 3), tensor.WithBacking([]float32{ 57 1, 1, 1, 58 1, 1, 1, 59 1, 1, 1, 60 })) 61 62 x := gorgonia.NewTensor(g, gorgonia.Float32, 4, gorgonia.WithShape(1, 1, 5, 5), gorgonia.WithValue(xV), gorgonia.WithName("x")) 63 w := gorgonia.NewTensor(g, gorgonia.Float32, 4, gorgonia.WithShape(1, 1, 3, 3), gorgonia.WithValue(kernelV), gorgonia.WithName("w")) 64 65 _, err := gorgonia.Conv2d(x, w, tensor.Shape{3, 3}, []int{1, 1}, []int{1, 1}, []int{1, 1}) 66 if err != nil { 67 b.Fatal(err) 68 } 69 // logger := log.New(os.Stderr, "", 0) 70 // vm := NewTapeMachine(g, WithLogger(logger), WithWatchlist(), WithValueFmt("%#v")) 71 72 vm := gorgonia.NewTapeMachine(g) 73 b.ResetTimer() 74 75 for i := 0; i < b.N; i++ { 76 if err := vm.RunAll(); err != nil { 77 b.Fatal(err) 78 } 79 vm.Reset() 80 } 81 vm.Close() 82 }