gorgonia.org/gorgonia@v0.9.17/benchmark_operations_test.go (about)

     1  package gorgonia
     2  
     3  import (
     4  	"testing"
     5  
     6  	"github.com/stretchr/testify/assert"
     7  	"gorgonia.org/tensor"
     8  )
     9  
    10  func BenchmarkReshape_Dense(b *testing.B) {
    11  	for _, rst := range reshapeTests {
    12  		b.Run(rst.testName, func(b *testing.B) {
    13  			g := NewGraph()
    14  			tT := tensor.New(tensor.Of(tensor.Float64), tensor.WithShape(rst.input.Clone()...))
    15  			T := NodeFromAny(g, tT)
    16  			for i := 0; i < b.N; i++ {
    17  				T2, err := Reshape(T, rst.to.Clone())
    18  				switch {
    19  				case rst.err && err == nil:
    20  					b.Fatalf("Expected Error when testing %v", rst)
    21  				case rst.err:
    22  					continue
    23  				case err != nil:
    24  					b.Fatal(err)
    25  				default:
    26  					assert.True(b, rst.output.Eq(T2.Shape()), "expected both to be the same")
    27  				}
    28  			}
    29  			m := NewTapeMachine(g)
    30  			if err := m.RunAll(); err != nil {
    31  				b.Fatal(err)
    32  			}
    33  
    34  		})
    35  	}
    36  }