gorgonia.org/gorgonia@v0.9.17/testsetup_test.go (about) 1 package gorgonia 2 3 import ( 4 "fmt" 5 "log" 6 "math/rand" 7 "reflect" 8 "runtime" 9 10 "github.com/chewxy/hm" 11 "github.com/pkg/errors" 12 "github.com/stretchr/testify/assert" 13 "gorgonia.org/dawson" 14 "gorgonia.org/tensor" 15 16 "testing" 17 ) 18 19 type errorStacker interface { 20 ErrorStack() string 21 } 22 23 func floatsEqual64(a, b []float64) bool { 24 if len(a) != len(b) { 25 return false 26 } 27 28 for i, v := range a { 29 if !dawson.CloseF64(v, b[i]) { 30 return false 31 } 32 } 33 return true 34 } 35 36 func floatsEqual32(a, b []float32) bool { 37 if len(a) != len(b) { 38 return false 39 } 40 41 for i, v := range a { 42 if !dawson.CloseF32(v, b[i]) { 43 return false 44 } 45 } 46 return true 47 } 48 49 func extractF64s(v Value) []float64 { 50 return v.Data().([]float64) 51 } 52 53 func extractF64(v Value) float64 { 54 switch vt := v.(type) { 55 case *F64: 56 return float64(*vt) 57 case tensor.Tensor: 58 if !vt.IsScalar() { 59 panic("Got a non scalar result!") 60 } 61 pc, _, _, _ := runtime.Caller(1) 62 log.Printf("Better watch it: %v called with a Scalar tensor", runtime.FuncForPC(pc).Name()) 63 return vt.ScalarValue().(float64) 64 } 65 panic(fmt.Sprintf("Unhandled types! Got %v of %T instead", v, v)) 66 } 67 68 func extractF32s(v Value) []float32 { 69 return v.Data().([]float32) 70 } 71 72 func extractF32(v Value) float32 { 73 switch vt := v.(type) { 74 case *F32: 75 return float32(*vt) 76 case tensor.Tensor: 77 if !vt.IsScalar() { 78 panic("Got a non scalar result!") 79 } 80 pc, _, _, _ := runtime.Caller(1) 81 log.Printf("Better watch it: %v called with a Scalar tensor", runtime.FuncForPC(pc).Name()) 82 return vt.ScalarValue().(float32) 83 } 84 panic(fmt.Sprintf("Unhandled types! Got %v of %T instead", v, v)) 85 } 86 87 func f64sTof32s(f []float64) []float32 { 88 retVal := make([]float32, len(f)) 89 for i, v := range f { 90 retVal[i] = float32(v) 91 } 92 return retVal 93 } 94 95 func simpleMatEqn() (g *ExprGraph, x, y, z *Node) { 96 g = NewGraph() 97 x = NewMatrix(g, Float64, WithName("x"), WithShape(2, 2)) 98 y = NewMatrix(g, Float64, WithName("y"), WithShape(2, 2)) 99 z = Must(Add(x, y)) 100 return 101 } 102 103 func simpleVecEqn() (g *ExprGraph, x, y, z *Node) { 104 g = NewGraph() 105 x = NewVector(g, Float64, WithName("x"), WithShape(2)) 106 y = NewVector(g, Float64, WithName("y"), WithShape(2)) 107 z = Must(Add(x, y)) 108 return 109 } 110 111 func simpleEqn() (g *ExprGraph, x, y, z *Node) { 112 g = NewGraph() 113 x = NewScalar(g, Float64, WithName("x")) 114 y = NewScalar(g, Float64, WithName("y")) 115 z = Must(Add(x, y)) 116 return 117 } 118 119 func simpleUnaryEqn() (g *ExprGraph, x, y *Node) { 120 g = NewGraph() 121 x = NewScalar(g, Float64, WithName("x")) 122 y = Must(Square(x)) 123 return 124 } 125 126 func simpleUnaryVecEqn() (g *ExprGraph, x, y *Node) { 127 g = NewGraph() 128 x = NewVector(g, Float64, WithName("x"), WithShape(2)) 129 y = Must(Square(x)) 130 return 131 } 132 133 type malformed struct{} 134 135 func (t malformed) Name() string { return "malformed" } 136 func (t malformed) Format(state fmt.State, c rune) { fmt.Fprintf(state, "malformed") } 137 func (t malformed) String() string { return "malformed" } 138 func (t malformed) Apply(hm.Subs) hm.Substitutable { return t } 139 func (t malformed) FreeTypeVar() hm.TypeVarSet { return nil } 140 func (t malformed) Eq(hm.Type) bool { return false } 141 func (t malformed) Types() hm.Types { return nil } 142 func (t malformed) Normalize(a, b hm.TypeVarSet) (hm.Type, error) { 143 return nil, errors.Errorf("cannot normalize malformed") 144 } 145 146 type assertState struct { 147 *assert.Assertions 148 cont bool 149 } 150 151 func newAssertState(a *assert.Assertions) *assertState { return &assertState{a, true} } 152 153 func (a *assertState) Equal(expected interface{}, actual interface{}, msgAndArgs ...interface{}) { 154 if !a.cont { 155 return 156 } 157 a.cont = a.Assertions.Equal(expected, actual, msgAndArgs...) 158 } 159 160 func (a *assertState) True(value bool, msgAndArgs ...interface{}) { 161 if !a.cont { 162 return 163 } 164 a.cont = a.Assertions.True(value, msgAndArgs...) 165 } 166 167 func checkErr(t *testing.T, expected bool, err error, name string, id interface{}) (cont bool) { 168 switch { 169 case expected: 170 if err == nil { 171 t.Errorf("Expected error in test %v (%v)", name, id) 172 } 173 return true 174 case !expected && err != nil: 175 t.Errorf("Test %v (%v) errored: %+v", name, id, err) 176 return true 177 } 178 return false 179 } 180 181 func deepNodeEq(a, b *Node) bool { 182 if a == b { 183 return true 184 } 185 186 if a.isInput() { 187 if !b.isInput() { 188 return false 189 } 190 191 if a.name != b.name { 192 return false 193 } 194 if !ValueEq(a.boundTo, b.boundTo) { 195 return false 196 } 197 return true 198 } 199 200 if b.isInput() { 201 return false 202 } 203 204 if a.name != b.name { 205 return false 206 } 207 208 if a.group != b.group { 209 return false 210 } 211 212 if a.id != b.id { 213 return false 214 } 215 216 if a.hash != b.hash { 217 return false 218 } 219 220 if a.hashed != b.hashed { 221 return false 222 } 223 224 if a.inferredShape != b.inferredShape { 225 return false 226 } 227 228 if a.unchanged != b.unchanged { 229 return false 230 } 231 232 if a.isStmt != b.isStmt { 233 return false 234 } 235 236 if a.ofInterest != b.ofInterest { 237 return false 238 } 239 240 if a.dataOn != b.dataOn { 241 return false 242 } 243 244 if !a.t.Eq(b.t) { 245 return false 246 } 247 if !a.shape.Eq(b.shape) { 248 return false 249 } 250 251 if a.op.Hashcode() != b.op.Hashcode() { 252 return false 253 } 254 255 if !ValueEq(a.boundTo, b.boundTo) { 256 return false 257 } 258 259 if len(a.children) != len(b.children) { 260 return false 261 } 262 263 if len(a.derivOf) != len(b.derivOf) { 264 return false 265 } 266 267 if a.deriv != nil { 268 if b.deriv == nil { 269 return false 270 } 271 if a.deriv.Hashcode() != b.deriv.Hashcode() { 272 return false 273 } 274 } 275 276 for i, c := range a.children { 277 if c.Hashcode() != b.children[i].Hashcode() { 278 return false 279 } 280 } 281 282 for i, c := range a.derivOf { 283 if c.Hashcode() != b.derivOf[i].Hashcode() { 284 return false 285 } 286 } 287 return true 288 } 289 290 // TensorGenerator only generates Dense tensors for now 291 type TensorGenerator struct { 292 ShapeConstraint tensor.Shape // [0, 6, 0] implies that the second dimension is the constraint. 0 is any. 293 DtypeConstraint tensor.Dtype 294 } 295 296 func (g TensorGenerator) Generate(r *rand.Rand, size int) reflect.Value { 297 // shape := g.ShapeConstraint 298 // of := g.DtypeConstraint 299 300 // if g.ShapeConstraint == nil { 301 // // generate 302 // } else { 303 // // generate for 0s in constraints 304 // } 305 306 // if g.DtypeConstraint == (tensor.Dtype{}) { 307 // of = g.DtypeConstraint 308 // } 309 var retVal Value 310 311 return reflect.ValueOf(retVal) 312 } 313 314 type ValueGenerator struct { 315 ShapeConstraint tensor.Shape // [0, 6, 0] implies that the second dimension is the constraint. 0 is any. 316 DtypeConstraint tensor.Dtype 317 } 318 319 func (g ValueGenerator) Generate(r *rand.Rand, size int) reflect.Value { 320 // generate scalar or tensor 321 ri := r.Intn(2) 322 if ri == 0 { 323 gen := TensorGenerator{ 324 ShapeConstraint: g.ShapeConstraint, 325 DtypeConstraint: g.DtypeConstraint, 326 } 327 return gen.Generate(r, size) 328 329 } 330 var retVal Value 331 // of := acceptableDtypes[r.Intn(len(acceptableDtypes))] 332 333 return reflect.ValueOf(retVal) 334 } 335 336 type NodeGenerator struct{} 337 338 func (g NodeGenerator) Generate(r *rand.Rand, size int) reflect.Value { 339 var n *Node 340 return reflect.ValueOf(n) 341 }