gorgonia.org/gorgonia@v0.9.17/testsetup_test.go (about)

     1  package gorgonia
     2  
     3  import (
     4  	"fmt"
     5  	"log"
     6  	"math/rand"
     7  	"reflect"
     8  	"runtime"
     9  
    10  	"github.com/chewxy/hm"
    11  	"github.com/pkg/errors"
    12  	"github.com/stretchr/testify/assert"
    13  	"gorgonia.org/dawson"
    14  	"gorgonia.org/tensor"
    15  
    16  	"testing"
    17  )
    18  
    19  type errorStacker interface {
    20  	ErrorStack() string
    21  }
    22  
    23  func floatsEqual64(a, b []float64) bool {
    24  	if len(a) != len(b) {
    25  		return false
    26  	}
    27  
    28  	for i, v := range a {
    29  		if !dawson.CloseF64(v, b[i]) {
    30  			return false
    31  		}
    32  	}
    33  	return true
    34  }
    35  
    36  func floatsEqual32(a, b []float32) bool {
    37  	if len(a) != len(b) {
    38  		return false
    39  	}
    40  
    41  	for i, v := range a {
    42  		if !dawson.CloseF32(v, b[i]) {
    43  			return false
    44  		}
    45  	}
    46  	return true
    47  }
    48  
    49  func extractF64s(v Value) []float64 {
    50  	return v.Data().([]float64)
    51  }
    52  
    53  func extractF64(v Value) float64 {
    54  	switch vt := v.(type) {
    55  	case *F64:
    56  		return float64(*vt)
    57  	case tensor.Tensor:
    58  		if !vt.IsScalar() {
    59  			panic("Got a non scalar result!")
    60  		}
    61  		pc, _, _, _ := runtime.Caller(1)
    62  		log.Printf("Better watch it: %v called with a Scalar tensor", runtime.FuncForPC(pc).Name())
    63  		return vt.ScalarValue().(float64)
    64  	}
    65  	panic(fmt.Sprintf("Unhandled types! Got %v of %T instead", v, v))
    66  }
    67  
    68  func extractF32s(v Value) []float32 {
    69  	return v.Data().([]float32)
    70  }
    71  
    72  func extractF32(v Value) float32 {
    73  	switch vt := v.(type) {
    74  	case *F32:
    75  		return float32(*vt)
    76  	case tensor.Tensor:
    77  		if !vt.IsScalar() {
    78  			panic("Got a non scalar result!")
    79  		}
    80  		pc, _, _, _ := runtime.Caller(1)
    81  		log.Printf("Better watch it: %v called with a Scalar tensor", runtime.FuncForPC(pc).Name())
    82  		return vt.ScalarValue().(float32)
    83  	}
    84  	panic(fmt.Sprintf("Unhandled types! Got %v of %T instead", v, v))
    85  }
    86  
    87  func f64sTof32s(f []float64) []float32 {
    88  	retVal := make([]float32, len(f))
    89  	for i, v := range f {
    90  		retVal[i] = float32(v)
    91  	}
    92  	return retVal
    93  }
    94  
    95  func simpleMatEqn() (g *ExprGraph, x, y, z *Node) {
    96  	g = NewGraph()
    97  	x = NewMatrix(g, Float64, WithName("x"), WithShape(2, 2))
    98  	y = NewMatrix(g, Float64, WithName("y"), WithShape(2, 2))
    99  	z = Must(Add(x, y))
   100  	return
   101  }
   102  
   103  func simpleVecEqn() (g *ExprGraph, x, y, z *Node) {
   104  	g = NewGraph()
   105  	x = NewVector(g, Float64, WithName("x"), WithShape(2))
   106  	y = NewVector(g, Float64, WithName("y"), WithShape(2))
   107  	z = Must(Add(x, y))
   108  	return
   109  }
   110  
   111  func simpleEqn() (g *ExprGraph, x, y, z *Node) {
   112  	g = NewGraph()
   113  	x = NewScalar(g, Float64, WithName("x"))
   114  	y = NewScalar(g, Float64, WithName("y"))
   115  	z = Must(Add(x, y))
   116  	return
   117  }
   118  
   119  func simpleUnaryEqn() (g *ExprGraph, x, y *Node) {
   120  	g = NewGraph()
   121  	x = NewScalar(g, Float64, WithName("x"))
   122  	y = Must(Square(x))
   123  	return
   124  }
   125  
   126  func simpleUnaryVecEqn() (g *ExprGraph, x, y *Node) {
   127  	g = NewGraph()
   128  	x = NewVector(g, Float64, WithName("x"), WithShape(2))
   129  	y = Must(Square(x))
   130  	return
   131  }
   132  
   133  type malformed struct{}
   134  
   135  func (t malformed) Name() string                   { return "malformed" }
   136  func (t malformed) Format(state fmt.State, c rune) { fmt.Fprintf(state, "malformed") }
   137  func (t malformed) String() string                 { return "malformed" }
   138  func (t malformed) Apply(hm.Subs) hm.Substitutable { return t }
   139  func (t malformed) FreeTypeVar() hm.TypeVarSet     { return nil }
   140  func (t malformed) Eq(hm.Type) bool                { return false }
   141  func (t malformed) Types() hm.Types                { return nil }
   142  func (t malformed) Normalize(a, b hm.TypeVarSet) (hm.Type, error) {
   143  	return nil, errors.Errorf("cannot normalize malformed")
   144  }
   145  
   146  type assertState struct {
   147  	*assert.Assertions
   148  	cont bool
   149  }
   150  
   151  func newAssertState(a *assert.Assertions) *assertState { return &assertState{a, true} }
   152  
   153  func (a *assertState) Equal(expected interface{}, actual interface{}, msgAndArgs ...interface{}) {
   154  	if !a.cont {
   155  		return
   156  	}
   157  	a.cont = a.Assertions.Equal(expected, actual, msgAndArgs...)
   158  }
   159  
   160  func (a *assertState) True(value bool, msgAndArgs ...interface{}) {
   161  	if !a.cont {
   162  		return
   163  	}
   164  	a.cont = a.Assertions.True(value, msgAndArgs...)
   165  }
   166  
   167  func checkErr(t *testing.T, expected bool, err error, name string, id interface{}) (cont bool) {
   168  	switch {
   169  	case expected:
   170  		if err == nil {
   171  			t.Errorf("Expected error in test %v (%v)", name, id)
   172  		}
   173  		return true
   174  	case !expected && err != nil:
   175  		t.Errorf("Test %v (%v) errored: %+v", name, id, err)
   176  		return true
   177  	}
   178  	return false
   179  }
   180  
   181  func deepNodeEq(a, b *Node) bool {
   182  	if a == b {
   183  		return true
   184  	}
   185  
   186  	if a.isInput() {
   187  		if !b.isInput() {
   188  			return false
   189  		}
   190  
   191  		if a.name != b.name {
   192  			return false
   193  		}
   194  		if !ValueEq(a.boundTo, b.boundTo) {
   195  			return false
   196  		}
   197  		return true
   198  	}
   199  
   200  	if b.isInput() {
   201  		return false
   202  	}
   203  
   204  	if a.name != b.name {
   205  		return false
   206  	}
   207  
   208  	if a.group != b.group {
   209  		return false
   210  	}
   211  
   212  	if a.id != b.id {
   213  		return false
   214  	}
   215  
   216  	if a.hash != b.hash {
   217  		return false
   218  	}
   219  
   220  	if a.hashed != b.hashed {
   221  		return false
   222  	}
   223  
   224  	if a.inferredShape != b.inferredShape {
   225  		return false
   226  	}
   227  
   228  	if a.unchanged != b.unchanged {
   229  		return false
   230  	}
   231  
   232  	if a.isStmt != b.isStmt {
   233  		return false
   234  	}
   235  
   236  	if a.ofInterest != b.ofInterest {
   237  		return false
   238  	}
   239  
   240  	if a.dataOn != b.dataOn {
   241  		return false
   242  	}
   243  
   244  	if !a.t.Eq(b.t) {
   245  		return false
   246  	}
   247  	if !a.shape.Eq(b.shape) {
   248  		return false
   249  	}
   250  
   251  	if a.op.Hashcode() != b.op.Hashcode() {
   252  		return false
   253  	}
   254  
   255  	if !ValueEq(a.boundTo, b.boundTo) {
   256  		return false
   257  	}
   258  
   259  	if len(a.children) != len(b.children) {
   260  		return false
   261  	}
   262  
   263  	if len(a.derivOf) != len(b.derivOf) {
   264  		return false
   265  	}
   266  
   267  	if a.deriv != nil {
   268  		if b.deriv == nil {
   269  			return false
   270  		}
   271  		if a.deriv.Hashcode() != b.deriv.Hashcode() {
   272  			return false
   273  		}
   274  	}
   275  
   276  	for i, c := range a.children {
   277  		if c.Hashcode() != b.children[i].Hashcode() {
   278  			return false
   279  		}
   280  	}
   281  
   282  	for i, c := range a.derivOf {
   283  		if c.Hashcode() != b.derivOf[i].Hashcode() {
   284  			return false
   285  		}
   286  	}
   287  	return true
   288  }
   289  
   290  // TensorGenerator only generates Dense tensors for now
   291  type TensorGenerator struct {
   292  	ShapeConstraint tensor.Shape // [0, 6, 0] implies that the second dimension is the constraint. 0 is any.
   293  	DtypeConstraint tensor.Dtype
   294  }
   295  
   296  func (g TensorGenerator) Generate(r *rand.Rand, size int) reflect.Value {
   297  	// shape := g.ShapeConstraint
   298  	// of := g.DtypeConstraint
   299  
   300  	// if g.ShapeConstraint == nil {
   301  	// 	// generate
   302  	// } else {
   303  	// 	// generate for 0s in constraints
   304  	// }
   305  
   306  	// if g.DtypeConstraint == (tensor.Dtype{}) {
   307  	// 	of = g.DtypeConstraint
   308  	// }
   309  	var retVal Value
   310  
   311  	return reflect.ValueOf(retVal)
   312  }
   313  
   314  type ValueGenerator struct {
   315  	ShapeConstraint tensor.Shape // [0, 6, 0] implies that the second dimension is the constraint. 0 is any.
   316  	DtypeConstraint tensor.Dtype
   317  }
   318  
   319  func (g ValueGenerator) Generate(r *rand.Rand, size int) reflect.Value {
   320  	// generate scalar or tensor
   321  	ri := r.Intn(2)
   322  	if ri == 0 {
   323  		gen := TensorGenerator{
   324  			ShapeConstraint: g.ShapeConstraint,
   325  			DtypeConstraint: g.DtypeConstraint,
   326  		}
   327  		return gen.Generate(r, size)
   328  
   329  	}
   330  	var retVal Value
   331  	// of := acceptableDtypes[r.Intn(len(acceptableDtypes))]
   332  
   333  	return reflect.ValueOf(retVal)
   334  }
   335  
   336  type NodeGenerator struct{}
   337  
   338  func (g NodeGenerator) Generate(r *rand.Rand, size int) reflect.Value {
   339  	var n *Node
   340  	return reflect.ValueOf(n)
   341  }