gorgonia.org/gorgonia@v0.9.17/x/vm/machine_test.go (about) 1 package xvm 2 3 import ( 4 "context" 5 "errors" 6 "fmt" 7 "log" 8 "reflect" 9 "testing" 10 "time" 11 12 "gorgonia.org/gorgonia" 13 ) 14 15 func TestMachine_runAllNodes(t *testing.T) { 16 inputC1 := make(chan ioValue, 0) 17 outputC1 := make(chan gorgonia.Value, 1) 18 inputC2 := make(chan ioValue, 0) 19 outputC2 := make(chan gorgonia.Value, 1) 20 21 n1 := &node{ 22 op: &sumF32{}, 23 inputValues: make([]gorgonia.Value, 2), 24 outputC: outputC1, 25 inputC: inputC1, 26 } 27 n2 := &node{ 28 op: &sumF32{}, 29 inputValues: make([]gorgonia.Value, 2), 30 outputC: outputC2, 31 inputC: inputC2, 32 } 33 errNode1 := &node{ 34 op: &errorOP{}, 35 inputValues: make([]gorgonia.Value, 2), 36 outputC: outputC2, 37 inputC: inputC2, 38 } 39 type fields struct { 40 nodes []*node 41 pubsubs *pubsub 42 } 43 type args struct { 44 ctx context.Context 45 } 46 tests := []struct { 47 name string 48 fields fields 49 args args 50 wantErr bool 51 }{ 52 { 53 "simple", 54 fields{ 55 nodes: []*node{n1, n2}, 56 }, 57 args{ 58 context.Background(), 59 }, 60 false, 61 }, 62 { 63 "error", 64 fields{ 65 nodes: []*node{n1, errNode1}, 66 }, 67 args{ 68 context.Background(), 69 }, 70 true, 71 }, 72 } 73 for _, tt := range tests { 74 forty := gorgonia.F32(40.0) 75 fortyTwo := gorgonia.F32(42.0) 76 two := gorgonia.F32(2.0) 77 t.Run(tt.name, func(t *testing.T) { 78 m := &Machine{ 79 nodes: tt.fields.nodes, 80 pubsub: tt.fields.pubsubs, 81 } 82 go func() { 83 inputC1 <- struct { 84 pos int 85 v gorgonia.Value 86 }{ 87 0, 88 &forty, 89 } 90 inputC1 <- struct { 91 pos int 92 v gorgonia.Value 93 }{ 94 1, 95 &two, 96 } 97 inputC2 <- struct { 98 pos int 99 v gorgonia.Value 100 }{ 101 0, 102 &forty, 103 } 104 inputC2 <- struct { 105 pos int 106 v gorgonia.Value 107 }{ 108 1, 109 &two, 110 } 111 }() 112 if err := m.runAllNodes(tt.args.ctx); (err != nil) != tt.wantErr { 113 t.Errorf("Machine.runAllNodes() error = %v, wantErr %v", err, tt.wantErr) 114 } 115 if tt.wantErr { 116 return 117 } 118 out1 := <-outputC1 119 out2 := <-outputC2 120 if !reflect.DeepEqual(out1.Data(), fortyTwo.Data()) { 121 t.Errorf("out1: bad result, expected %v, got %v", fortyTwo, out1) 122 } 123 if !reflect.DeepEqual(out2.Data(), fortyTwo.Data()) { 124 t.Errorf("out2: bad result, expected %v, got %v", fortyTwo, out2) 125 } 126 }) 127 } 128 } 129 130 func TestNewMachine(t *testing.T) { 131 g := gorgonia.NewGraph() 132 forty := gorgonia.F32(40.0) 133 //fortyTwo := gorgonia.F32(42.0) 134 two := gorgonia.F32(2.0) 135 n1 := gorgonia.NewScalar(g, gorgonia.Float32, gorgonia.WithValue(&forty), gorgonia.WithName("n1")) 136 n2 := gorgonia.NewScalar(g, gorgonia.Float32, gorgonia.WithValue(&two), gorgonia.WithName("n2")) 137 138 added, err := gorgonia.Add(n1, n2) 139 if err != nil { 140 t.Fatal(err) 141 } 142 i1 := newInput(n1) 143 i2 := newInput(n2) 144 op := newOp(added, false) 145 gg := gorgonia.NewGraph() 146 c1 := gorgonia.NewConstant(&forty) 147 ic1 := newInput(c1) 148 ic1.id = 0 149 gg.AddNode(c1) 150 type args struct { 151 g *gorgonia.ExprGraph 152 } 153 tests := []struct { 154 name string 155 args args 156 want *Machine 157 }{ 158 { 159 "nil graph", 160 args{nil}, 161 nil, 162 }, 163 { 164 "simple graph WIP", 165 args{ 166 g, 167 }, 168 &Machine{ 169 nodes: []*node{ 170 i1, i2, op, 171 }, 172 }, 173 }, 174 { 175 "constant (arity 0)", 176 args{ 177 gg, 178 }, 179 &Machine{ 180 nodes: []*node{ 181 ic1, 182 }, 183 }, 184 }, 185 } 186 for _, tt := range tests { 187 t.Run(tt.name, func(t *testing.T) { 188 got := NewMachine(tt.args.g) 189 if got == nil && tt.want == nil { 190 return 191 } 192 if got == nil && tt.want != nil || 193 got != nil && tt.want == nil { 194 t.Fatalf("NewMachine() = %v, want %v", got, tt.want) 195 } 196 if tt.want.nodes == nil && got.nodes != nil || 197 tt.want.nodes != nil && got.nodes == nil { 198 t.Fatalf("NewMachine(nodes) = %v, want %v", got, tt.want) 199 } 200 if len(got.nodes) != len(tt.want.nodes) { 201 t.Fatalf("bad number of nodes, expecting %v, got %v", len(tt.want.nodes), len(got.nodes)) 202 } 203 for i := 0; i < len(got.nodes); i++ { 204 compareNodes(t, got.nodes[i], tt.want.nodes[i]) 205 } 206 /* 207 if tt.want.pubsubs == nil && got.pubsubs != nil || 208 tt.want.pubsubs != nil && got.pubsubs == nil { 209 t.Fatalf("NewMachine(pubsubs) = %v, want %v", got, tt.want) 210 } 211 if !reflect.DeepEqual(got.pubsubs, tt.want.pubsubs) { 212 t.Fatalf("bad pubsubs, expecting %v, got %v", tt.want.pubsubs, got.pubsubs) 213 } 214 */ 215 }) 216 } 217 } 218 219 func Test_createHub(t *testing.T) { 220 type args struct { 221 ns []*node 222 g *gorgonia.ExprGraph 223 } 224 tests := []struct { 225 name string 226 args args 227 want []*pubsub 228 }{ 229 // TODO: Add test cases. 230 } 231 for _, tt := range tests { 232 t.Run(tt.name, func(t *testing.T) { 233 if got := createNetwork(tt.args.ns, tt.args.g); !reflect.DeepEqual(got, tt.want) { 234 t.Errorf("createHub() = %v, want %v", got, tt.want) 235 } 236 }) 237 } 238 } 239 240 func TestMachine_Close(t *testing.T) { 241 c0 := make(chan gorgonia.Value, 0) 242 c1 := make(chan gorgonia.Value, 0) 243 c2 := make(chan gorgonia.Value, 0) 244 c3 := make(chan gorgonia.Value, 0) 245 c4 := make(chan gorgonia.Value, 0) 246 c5 := make(chan gorgonia.Value, 0) 247 i0 := make(chan ioValue, 0) 248 i1 := make(chan ioValue, 0) 249 ps := &pubsub{ 250 publishers: []*publisher{ 251 { 252 publisher: c0, 253 subscribers: []chan<- gorgonia.Value{c1, c2}, 254 }, 255 { 256 publisher: c3, 257 subscribers: []chan<- gorgonia.Value{c1, c2}, 258 }, 259 }, 260 subscribers: []*subscriber{ 261 { 262 subscriber: i0, 263 publishers: []<-chan gorgonia.Value{c3, c2}, 264 }, 265 { 266 subscriber: i0, 267 publishers: []<-chan gorgonia.Value{c4, c5}, 268 }, 269 { 270 subscriber: i1, 271 publishers: []<-chan gorgonia.Value{c4, c5}, 272 }, 273 }, 274 } 275 type fields struct { 276 nodes []*node 277 pubsubs *pubsub 278 } 279 tests := []struct { 280 name string 281 fields fields 282 }{ 283 { 284 "simple", 285 fields{ 286 pubsubs: ps, 287 }, 288 }, 289 } 290 for _, tt := range tests { 291 t.Run(tt.name, func(t *testing.T) { 292 m := &Machine{ 293 nodes: tt.fields.nodes, 294 pubsub: tt.fields.pubsubs, 295 } 296 m.Close() 297 }) 298 } 299 } 300 301 func Test_createNetwork(t *testing.T) { 302 g := gorgonia.NewGraph() 303 forty := gorgonia.F32(40.0) 304 //fortyTwo := gorgonia.F32(42.0) 305 two := gorgonia.F32(2.0) 306 n1 := gorgonia.NewScalar(g, gorgonia.Float32, gorgonia.WithValue(&forty), gorgonia.WithName("n1")) 307 n2 := gorgonia.NewScalar(g, gorgonia.Float32, gorgonia.WithValue(&two), gorgonia.WithName("n2")) 308 309 added, err := gorgonia.Add(n1, n2) 310 if err != nil { 311 t.Fatal(err) 312 } 313 i1 := newInput(n1) 314 i2 := newInput(n2) 315 op := newOp(added, false) 316 317 type args struct { 318 ns []*node 319 g *gorgonia.ExprGraph 320 } 321 tests := []struct { 322 name string 323 args args 324 want *pubsub 325 }{ 326 { 327 "simple add operation", 328 args{ 329 ns: []*node{i1, i2, op}, 330 g: g, 331 }, 332 &pubsub{ 333 publishers: []*publisher{ 334 { 335 id: 0, 336 }, 337 { 338 id: 1, 339 }, 340 }, 341 subscribers: []*subscriber{ 342 { 343 id: 2, 344 subscriber: make(chan ioValue, 0), 345 }, 346 }, 347 }, 348 }, 349 } 350 for _, tt := range tests { 351 t.Run(tt.name, func(t *testing.T) { 352 got := createNetwork(tt.args.ns, tt.args.g) 353 if got == nil && tt.want != nil { 354 t.Fail() 355 } 356 if got != nil && tt.want == nil { 357 t.Fail() 358 } 359 if got == nil && tt.want == nil { 360 return 361 } 362 if got.publishers != nil && tt.want.publishers != nil { 363 if len(got.publishers) != len(tt.want.publishers) { 364 t.Errorf("bad number of publishers, expected %v, got %v", len(tt.want.publishers), len(got.publishers)) 365 } 366 } 367 if got.subscribers != nil && tt.want.subscribers != nil { 368 if len(got.subscribers) != len(tt.want.subscribers) { 369 t.Errorf("bad number of subscribers, expected %v, got %v", len(tt.want.subscribers), len(got.subscribers)) 370 } 371 } 372 for i := range tt.want.publishers { 373 want := tt.want.publishers[i] 374 got := got.publishers[i] 375 if want.id != got.id { 376 t.Errorf("bad subscriber id, expected %v, got %v", want.id, got.id) 377 } 378 } 379 for i := range tt.want.subscribers { 380 want := tt.want.subscribers[i] 381 got := got.subscribers[i] 382 if want.id != got.id { 383 t.Errorf("bad subscriber id, expected %v, got %v", want.id, got.id) 384 } 385 } 386 }) 387 } 388 } 389 390 func ExampleMachine_Run() { 391 g := gorgonia.NewGraph() 392 forty := gorgonia.F32(40.0) 393 //fortyTwo := gorgonia.F32(42.0) 394 two := gorgonia.F32(2.0) 395 n1 := gorgonia.NewScalar(g, gorgonia.Float32, gorgonia.WithValue(&forty), gorgonia.WithName("n1")) 396 n2 := gorgonia.NewScalar(g, gorgonia.Float32, gorgonia.WithValue(&two), gorgonia.WithName("n2")) 397 398 added, err := gorgonia.Add(n1, n2) 399 if err != nil { 400 log.Fatal(err) 401 } 402 machine := NewMachine(g) 403 ctx, cancel := context.WithTimeout(context.Background(), 1000*time.Millisecond) 404 defer cancel() 405 defer machine.Close() 406 err = machine.Run(ctx) 407 if err != nil { 408 log.Fatal(err) 409 } 410 fmt.Println(machine.GetResult(added.ID())) 411 // output: 42 412 } 413 414 func TestMachine_Run(t *testing.T) { 415 g := gorgonia.NewGraph() 416 forty := gorgonia.F32(40.0) 417 //fortyTwo := gorgonia.F32(42.0) 418 two := gorgonia.F32(2.0) 419 n1 := gorgonia.NewScalar(g, gorgonia.Float32, gorgonia.WithValue(&forty), gorgonia.WithName("n1")) 420 n2 := gorgonia.NewScalar(g, gorgonia.Float32, gorgonia.WithValue(&two), gorgonia.WithName("n2")) 421 422 added, err := gorgonia.Add(n1, n2) 423 if err != nil { 424 t.Fatal(err) 425 } 426 i1 := newInput(n1) 427 i2 := newInput(n2) 428 op := newOp(added, false) 429 c1 := make(chan gorgonia.Value, 0) 430 c2 := make(chan gorgonia.Value, 0) 431 type fields struct { 432 nodes []*node 433 pubsubs *pubsub 434 } 435 type args struct { 436 ctx context.Context 437 } 438 tests := []struct { 439 name string 440 fields fields 441 args args 442 wantErr bool 443 }{ 444 { 445 "simple", 446 fields{ 447 nodes: []*node{i1, i2, op}, 448 pubsubs: &pubsub{ 449 publishers: []*publisher{ 450 { 451 id: i1.id, 452 publisher: i1.outputC, 453 subscribers: []chan<- gorgonia.Value{ 454 c1, 455 }, 456 }, 457 { 458 id: i2.id, 459 publisher: i2.outputC, 460 subscribers: []chan<- gorgonia.Value{ 461 c2, 462 }, 463 }, 464 }, 465 subscribers: []*subscriber{ 466 { 467 id: op.id, 468 publishers: []<-chan gorgonia.Value{ 469 c1, c2, 470 }, 471 subscriber: op.inputC, 472 }, 473 }, 474 }, 475 }, 476 args{ 477 context.Background(), 478 }, 479 false, 480 }, 481 // TODO: Add test cases. 482 } 483 for _, tt := range tests { 484 t.Run(tt.name, func(t *testing.T) { 485 m := &Machine{ 486 nodes: tt.fields.nodes, 487 pubsub: tt.fields.pubsubs, 488 } 489 err := m.Run(tt.args.ctx) 490 if (err != nil) != tt.wantErr { 491 t.Errorf("Machine.Run() error = %v, wantErr %v", err, tt.wantErr) 492 } 493 }) 494 } 495 } 496 497 func TestMachine_GetResult(t *testing.T) { 498 fortyTwo := gorgonia.F32(42.0) 499 type fields struct { 500 nodes []*node 501 pubsubs *pubsub 502 } 503 type args struct { 504 id int64 505 } 506 tests := []struct { 507 name string 508 fields fields 509 args args 510 want gorgonia.Value 511 }{ 512 { 513 "nil", 514 fields{ 515 nodes: []*node{ 516 { 517 id: 1, 518 output: &fortyTwo, 519 }, 520 }, 521 }, 522 args{ 523 2, 524 }, 525 nil, 526 }, 527 { 528 "simple", 529 fields{ 530 nodes: []*node{ 531 { 532 id: 1, 533 output: &fortyTwo, 534 }, 535 }, 536 }, 537 args{ 538 1, 539 }, 540 &fortyTwo, 541 }, 542 } 543 for _, tt := range tests { 544 t.Run(tt.name, func(t *testing.T) { 545 m := &Machine{ 546 nodes: tt.fields.nodes, 547 pubsub: tt.fields.pubsubs, 548 } 549 if got := m.GetResult(tt.args.id); !reflect.DeepEqual(got, tt.want) { 550 t.Errorf("Machine.GetResult() = %v, want %v", got, tt.want) 551 } 552 }) 553 } 554 } 555 556 func Test_nodeErrors_Error(t *testing.T) { 557 tests := []struct { 558 name string 559 e nodeErrors 560 want string 561 }{ 562 { 563 "simple", 564 []nodeError{ 565 { 566 id: 0, 567 err: errors.New("error"), 568 }, 569 }, 570 "0:error\n", 571 }, 572 } 573 for _, tt := range tests { 574 t.Run(tt.name, func(t *testing.T) { 575 if got := tt.e.Error(); got != tt.want { 576 t.Errorf("nodeErrors.Error() = %v, want %v", got, tt.want) 577 } 578 }) 579 } 580 }