gorgonia.org/gorgonia@v0.9.17/x/vm/node_test.go (about) 1 package xvm 2 3 import ( 4 "context" 5 "errors" 6 "reflect" 7 "testing" 8 9 "gorgonia.org/gorgonia" 10 ) 11 12 func Test_receiveInput(t *testing.T) { 13 cancelCtx, cancel := context.WithCancel(context.Background()) 14 inputC := make(chan ioValue, 0) 15 type args struct { 16 ctx context.Context 17 o *node 18 fn func() 19 } 20 tests := []struct { 21 name string 22 args args 23 want stateFn 24 }{ 25 { 26 "context cancelation", 27 args{ 28 cancelCtx, 29 &node{ 30 inputC: make(chan ioValue, 0), 31 }, 32 nil, 33 }, 34 nil, 35 }, 36 { 37 "bad input value position", 38 args{ 39 context.Background(), 40 &node{ 41 inputC: inputC, 42 inputValues: make([]gorgonia.Value, 1), 43 }, 44 func() { 45 inputC <- struct { 46 pos int 47 v gorgonia.Value 48 }{ 49 pos: 1, 50 v: nil, 51 } 52 }, 53 }, 54 nil, 55 }, 56 { 57 "more value to receive", 58 args{ 59 context.Background(), 60 &node{ 61 inputC: inputC, 62 inputValues: make([]gorgonia.Value, 2), 63 }, 64 func() { 65 inputC <- struct { 66 pos int 67 v gorgonia.Value 68 }{ 69 pos: 0, 70 v: nil, 71 } 72 }, 73 }, 74 receiveInput, 75 }, 76 { 77 "no input chan go to conpute", 78 args{ 79 context.Background(), 80 &node{ 81 inputValues: make([]gorgonia.Value, 1), 82 }, 83 nil, 84 }, 85 computeFwd, 86 }, 87 { 88 "all done go to compute", 89 args{ 90 context.Background(), 91 &node{ 92 inputC: inputC, 93 inputValues: make([]gorgonia.Value, 1), 94 }, 95 func() { 96 inputC <- struct { 97 pos int 98 v gorgonia.Value 99 }{ 100 pos: 0, 101 v: nil, 102 } 103 }, 104 }, 105 computeFwd, 106 }, 107 // TODO: Add test cases. 108 } 109 cancel() 110 for _, tt := range tests { 111 t.Run(tt.name, func(t *testing.T) { 112 if tt.args.fn != nil { 113 go tt.args.fn() 114 } 115 got := receiveInput(tt.args.ctx, tt.args.o) 116 gotPrt := reflect.ValueOf(got).Pointer() 117 wantPtr := reflect.ValueOf(tt.want).Pointer() 118 if gotPrt != wantPtr { 119 t.Errorf("receiveInput() = %v, want %v", got, tt.want) 120 } 121 }) 122 } 123 } 124 125 func Test_computeFwd(t *testing.T) { 126 type args struct { 127 in0 context.Context 128 n *node 129 } 130 tests := []struct { 131 name string 132 args args 133 want stateFn 134 }{ 135 { 136 "simple no error", 137 args{ 138 nil, 139 &node{ 140 op: &noOpTest{}, 141 inputValues: []gorgonia.Value{nil}, 142 }, 143 }, 144 emitOutput, 145 }, 146 { 147 "simple with error", 148 args{ 149 nil, 150 &node{ 151 op: &noOpTest{err: errors.New("")}, 152 inputValues: []gorgonia.Value{nil}, 153 }, 154 }, 155 nil, 156 }, 157 // TODO: Add test cases. 158 } 159 for _, tt := range tests { 160 t.Run(tt.name, func(t *testing.T) { 161 got := computeFwd(tt.args.in0, tt.args.n) 162 gotPrt := reflect.ValueOf(got).Pointer() 163 wantPtr := reflect.ValueOf(tt.want).Pointer() 164 if gotPrt != wantPtr { 165 t.Errorf("computeFwd() = %v, want %v", got, tt.want) 166 } 167 }) 168 } 169 } 170 171 func Test_node_ComputeForward(t *testing.T) { 172 type fields struct { 173 op gorgonia.Op 174 output gorgonia.Value 175 outputC chan gorgonia.Value 176 receivedValues int 177 err error 178 inputValues []gorgonia.Value 179 inputC chan ioValue 180 } 181 type args struct { 182 ctx context.Context 183 } 184 tests := []struct { 185 name string 186 fields fields 187 args args 188 wantErr bool 189 }{ 190 { 191 "simple", 192 fields{ 193 op: nil, 194 }, 195 args{ 196 nil, 197 }, 198 false, 199 }, 200 // TODO: Add test cases. 201 } 202 for _, tt := range tests { 203 t.Run(tt.name, func(t *testing.T) { 204 n := &node{ 205 op: tt.fields.op, 206 output: tt.fields.output, 207 outputC: tt.fields.outputC, 208 receivedValues: tt.fields.receivedValues, 209 err: tt.fields.err, 210 inputValues: tt.fields.inputValues, 211 inputC: tt.fields.inputC, 212 } 213 if err := n.Compute(tt.args.ctx); (err != nil) != tt.wantErr { 214 t.Errorf("node.ComputeForward() error = %v, wantErr %v", err, tt.wantErr) 215 } 216 }) 217 } 218 } 219 220 type errorOP struct{} 221 222 func (*errorOP) Do(v ...gorgonia.Value) (gorgonia.Value, error) { 223 return nil, errors.New("error") 224 } 225 226 type sumF32 struct{} 227 228 func (*sumF32) Do(v ...gorgonia.Value) (gorgonia.Value, error) { 229 val := v[0].Data().(float32) + v[1].Data().(float32) 230 value := gorgonia.F32(val) 231 return &value, nil 232 } 233 234 func Test_emitOutput(t *testing.T) { 235 cancelCtx, cancel := context.WithCancel(context.Background()) 236 outputC1 := make(chan gorgonia.Value, 0) 237 outputC2 := make(chan gorgonia.Value, 1) 238 type args struct { 239 ctx context.Context 240 n *node 241 } 242 tests := []struct { 243 name string 244 args args 245 want stateFn 246 }{ 247 { 248 "nil node", 249 args{nil, nil}, 250 nil, 251 }, 252 { 253 "context cancelation", 254 args{ 255 cancelCtx, 256 &node{ 257 outputC: outputC1, 258 }, 259 }, 260 nil, 261 }, 262 { 263 "emit output", 264 args{ 265 context.Background(), 266 &node{ 267 outputC: outputC2, 268 }, 269 }, 270 nil, 271 }, 272 } 273 cancel() 274 for _, tt := range tests { 275 t.Run(tt.name, func(t *testing.T) { 276 got := emitOutput(tt.args.ctx, tt.args.n) 277 gotPrt := reflect.ValueOf(got).Pointer() 278 wantPtr := reflect.ValueOf(tt.want).Pointer() 279 if gotPrt != wantPtr { 280 t.Errorf("emitOutput() = %v, want %v", got, tt.want) 281 } 282 }) 283 } 284 } 285 286 func Test_computeBackward(t *testing.T) { 287 type args struct { 288 in0 context.Context 289 in1 *node 290 } 291 tests := []struct { 292 name string 293 args args 294 want stateFn 295 }{ 296 { 297 "simple", 298 args{ 299 nil, 300 nil, 301 }, 302 nil, 303 }, 304 } 305 for _, tt := range tests { 306 t.Run(tt.name, func(t *testing.T) { 307 if got := computeBackward(tt.args.in0, tt.args.in1); !reflect.DeepEqual(got, tt.want) { 308 t.Errorf("computeBackward() = %v, want %v", got, tt.want) 309 } 310 }) 311 } 312 } 313 314 func Test_newOp(t *testing.T) { 315 g := gorgonia.NewGraph() 316 fortyTwo := gorgonia.F32(42.0) 317 n1 := gorgonia.NodeFromAny(g, fortyTwo) 318 n2 := gorgonia.NodeFromAny(g, fortyTwo) 319 addOp, err := gorgonia.Add(n1, n2) 320 if err != nil { 321 t.Fatal(err) 322 } 323 type args struct { 324 n *gorgonia.Node 325 hasOutputChan bool 326 } 327 tests := []struct { 328 name string 329 args args 330 want *node 331 }{ 332 { 333 "no op", 334 args{nil, false}, 335 nil, 336 }, 337 { 338 "add with outputChan", 339 args{addOp, true}, 340 &node{ 341 id: addOp.ID(), 342 op: addOp.Op(), 343 inputC: make(chan ioValue, 0), 344 outputC: make(chan gorgonia.Value, 0), 345 inputValues: make([]gorgonia.Value, 2), 346 }, 347 }, 348 { 349 "add without outputChan", 350 args{addOp, false}, 351 &node{ 352 id: addOp.ID(), 353 op: addOp.Op(), 354 inputC: make(chan ioValue, 0), 355 outputC: nil, 356 inputValues: make([]gorgonia.Value, 2), 357 }, 358 }, 359 } 360 for _, tt := range tests { 361 t.Run(tt.name, func(t *testing.T) { 362 got := newOp(tt.args.n, tt.args.hasOutputChan) 363 if got == tt.want { 364 return 365 } 366 if got.id != tt.want.id { 367 t.Errorf("newOp() = \n%#v, want \n%#v", got, tt.want) 368 } 369 if !reflect.DeepEqual(got.op, tt.want.op) { 370 t.Errorf("newOp() = \n%#v, want \n%#v", got, tt.want) 371 } 372 if !reflect.DeepEqual(got.inputValues, tt.want.inputValues) { 373 t.Errorf("newOp() = \n%#v, want \n%#v", got, tt.want) 374 } 375 if got.receivedValues != tt.want.receivedValues { 376 t.Errorf("newOp() = \n%#v, want \n%#v", got, tt.want) 377 } 378 if got.err != tt.want.err { 379 t.Errorf("newOp() = \n%#v, want \n%#v", got, tt.want) 380 } 381 if (got.inputC == nil && tt.want.inputC != nil) || 382 (got.inputC != nil && tt.want.inputC == nil) { 383 t.Errorf("newOp() = \n%#v, want \n%#v", got, tt.want) 384 } 385 if (got.outputC == nil && tt.want.outputC != nil) || 386 (got.outputC != nil && tt.want.outputC == nil) { 387 t.Errorf("newOp() = \n%#v, want \n%#v", got, tt.want) 388 } 389 if cap(got.outputC) != cap(tt.want.outputC) { 390 t.Errorf("newOp() = \n%#v, want \n%#v", got, tt.want) 391 } 392 if len(got.outputC) != len(tt.want.outputC) { 393 t.Errorf("newOp() = \n%#v, want \n%#v", got, tt.want) 394 } 395 if cap(got.inputC) != cap(tt.want.inputC) { 396 t.Errorf("newOp() = \n%#v, want \n%#v", got, tt.want) 397 } 398 if len(got.inputC) != len(tt.want.inputC) { 399 t.Errorf("newOp() = \n%#v, want \n%#v", got, tt.want) 400 } 401 402 }) 403 } 404 } 405 406 func Test_newInput(t *testing.T) { 407 g := gorgonia.NewGraph() 408 fortyTwo := gorgonia.F32(42.0) 409 n1 := gorgonia.NodeFromAny(g, &fortyTwo) 410 type args struct { 411 n *gorgonia.Node 412 } 413 tests := []struct { 414 name string 415 args args 416 want *node 417 }{ 418 { 419 "nil", 420 args{nil}, 421 nil, 422 }, 423 { 424 "simple", 425 args{n1}, 426 &node{ 427 outputC: make(chan gorgonia.Value, 0), 428 output: &fortyTwo, 429 }, 430 }, 431 } 432 for _, tt := range tests { 433 t.Run(tt.name, func(t *testing.T) { 434 got := newInput(tt.args.n) 435 if got == tt.want { 436 return 437 } 438 compareNodes(t, got, tt.want) 439 }) 440 } 441 } 442 443 func compareNodes(t *testing.T, got, want *node) { 444 if got.id != want.id { 445 t.Errorf("nodes ID are different = \n%#v, want \n%#v", got.id, want.id) 446 } 447 if !reflect.DeepEqual(got.op, want.op) { 448 t.Errorf("nodes OP are different = \n%#v, want \n%#v", got.op, want.op) 449 } 450 if !reflect.DeepEqual(got.inputValues, want.inputValues) { 451 t.Errorf("nodes inputValues are different = \n%#v, want \n%#v", got.inputValues, want.inputValues) 452 } 453 if got.receivedValues != want.receivedValues { 454 t.Errorf("nodes receivedValues are different = \n%#v, want \n%#v", got.receivedValues, want.receivedValues) 455 } 456 if got.err != want.err { 457 t.Errorf("nodes errors are different = \n%#v, want \n%#v", got.err, want.err) 458 } 459 if (got.inputC == nil && want.inputC != nil) || 460 (got.inputC != nil && want.inputC == nil) { 461 t.Errorf("newInput() = \n%#v, want \n%#v", got, want) 462 } 463 if (got.outputC == nil && want.outputC != nil) || 464 (got.outputC != nil && want.outputC == nil) { 465 t.Errorf("newInput() = \n%#v, want \n%#v", got, want) 466 } 467 if cap(got.outputC) != cap(want.outputC) { 468 t.Errorf("newInput() = \n%#v, want \n%#v", got, want) 469 } 470 if len(got.outputC) != len(want.outputC) { 471 t.Errorf("newInput() = \n%#v, want \n%#v", got, want) 472 } 473 if cap(got.inputC) != cap(want.inputC) { 474 t.Errorf("newInput() = \n%#v, want \n%#v", got, want) 475 } 476 if len(got.inputC) != len(want.inputC) { 477 t.Errorf("newInput() = \n%#v, want \n%#v", got, want) 478 } 479 480 }