gorgonia.org/gorgonia@v0.9.17/x/vm/machine_test.go (about)

     1  package xvm
     2  
     3  import (
     4  	"context"
     5  	"errors"
     6  	"fmt"
     7  	"log"
     8  	"reflect"
     9  	"testing"
    10  	"time"
    11  
    12  	"gorgonia.org/gorgonia"
    13  )
    14  
    15  func TestMachine_runAllNodes(t *testing.T) {
    16  	inputC1 := make(chan ioValue, 0)
    17  	outputC1 := make(chan gorgonia.Value, 1)
    18  	inputC2 := make(chan ioValue, 0)
    19  	outputC2 := make(chan gorgonia.Value, 1)
    20  
    21  	n1 := &node{
    22  		op:          &sumF32{},
    23  		inputValues: make([]gorgonia.Value, 2),
    24  		outputC:     outputC1,
    25  		inputC:      inputC1,
    26  	}
    27  	n2 := &node{
    28  		op:          &sumF32{},
    29  		inputValues: make([]gorgonia.Value, 2),
    30  		outputC:     outputC2,
    31  		inputC:      inputC2,
    32  	}
    33  	errNode1 := &node{
    34  		op:          &errorOP{},
    35  		inputValues: make([]gorgonia.Value, 2),
    36  		outputC:     outputC2,
    37  		inputC:      inputC2,
    38  	}
    39  	type fields struct {
    40  		nodes   []*node
    41  		pubsubs *pubsub
    42  	}
    43  	type args struct {
    44  		ctx context.Context
    45  	}
    46  	tests := []struct {
    47  		name    string
    48  		fields  fields
    49  		args    args
    50  		wantErr bool
    51  	}{
    52  		{
    53  			"simple",
    54  			fields{
    55  				nodes: []*node{n1, n2},
    56  			},
    57  			args{
    58  				context.Background(),
    59  			},
    60  			false,
    61  		},
    62  		{
    63  			"error",
    64  			fields{
    65  				nodes: []*node{n1, errNode1},
    66  			},
    67  			args{
    68  				context.Background(),
    69  			},
    70  			true,
    71  		},
    72  	}
    73  	for _, tt := range tests {
    74  		forty := gorgonia.F32(40.0)
    75  		fortyTwo := gorgonia.F32(42.0)
    76  		two := gorgonia.F32(2.0)
    77  		t.Run(tt.name, func(t *testing.T) {
    78  			m := &Machine{
    79  				nodes:  tt.fields.nodes,
    80  				pubsub: tt.fields.pubsubs,
    81  			}
    82  			go func() {
    83  				inputC1 <- struct {
    84  					pos int
    85  					v   gorgonia.Value
    86  				}{
    87  					0,
    88  					&forty,
    89  				}
    90  				inputC1 <- struct {
    91  					pos int
    92  					v   gorgonia.Value
    93  				}{
    94  					1,
    95  					&two,
    96  				}
    97  				inputC2 <- struct {
    98  					pos int
    99  					v   gorgonia.Value
   100  				}{
   101  					0,
   102  					&forty,
   103  				}
   104  				inputC2 <- struct {
   105  					pos int
   106  					v   gorgonia.Value
   107  				}{
   108  					1,
   109  					&two,
   110  				}
   111  			}()
   112  			if err := m.runAllNodes(tt.args.ctx); (err != nil) != tt.wantErr {
   113  				t.Errorf("Machine.runAllNodes() error = %v, wantErr %v", err, tt.wantErr)
   114  			}
   115  			if tt.wantErr {
   116  				return
   117  			}
   118  			out1 := <-outputC1
   119  			out2 := <-outputC2
   120  			if !reflect.DeepEqual(out1.Data(), fortyTwo.Data()) {
   121  				t.Errorf("out1: bad result, expected %v, got %v", fortyTwo, out1)
   122  			}
   123  			if !reflect.DeepEqual(out2.Data(), fortyTwo.Data()) {
   124  				t.Errorf("out2: bad result, expected %v, got %v", fortyTwo, out2)
   125  			}
   126  		})
   127  	}
   128  }
   129  
   130  func TestNewMachine(t *testing.T) {
   131  	g := gorgonia.NewGraph()
   132  	forty := gorgonia.F32(40.0)
   133  	//fortyTwo := gorgonia.F32(42.0)
   134  	two := gorgonia.F32(2.0)
   135  	n1 := gorgonia.NewScalar(g, gorgonia.Float32, gorgonia.WithValue(&forty), gorgonia.WithName("n1"))
   136  	n2 := gorgonia.NewScalar(g, gorgonia.Float32, gorgonia.WithValue(&two), gorgonia.WithName("n2"))
   137  
   138  	added, err := gorgonia.Add(n1, n2)
   139  	if err != nil {
   140  		t.Fatal(err)
   141  	}
   142  	i1 := newInput(n1)
   143  	i2 := newInput(n2)
   144  	op := newOp(added, false)
   145  	gg := gorgonia.NewGraph()
   146  	c1 := gorgonia.NewConstant(&forty)
   147  	ic1 := newInput(c1)
   148  	ic1.id = 0
   149  	gg.AddNode(c1)
   150  	type args struct {
   151  		g *gorgonia.ExprGraph
   152  	}
   153  	tests := []struct {
   154  		name string
   155  		args args
   156  		want *Machine
   157  	}{
   158  		{
   159  			"nil graph",
   160  			args{nil},
   161  			nil,
   162  		},
   163  		{
   164  			"simple graph WIP",
   165  			args{
   166  				g,
   167  			},
   168  			&Machine{
   169  				nodes: []*node{
   170  					i1, i2, op,
   171  				},
   172  			},
   173  		},
   174  		{
   175  			"constant (arity 0)",
   176  			args{
   177  				gg,
   178  			},
   179  			&Machine{
   180  				nodes: []*node{
   181  					ic1,
   182  				},
   183  			},
   184  		},
   185  	}
   186  	for _, tt := range tests {
   187  		t.Run(tt.name, func(t *testing.T) {
   188  			got := NewMachine(tt.args.g)
   189  			if got == nil && tt.want == nil {
   190  				return
   191  			}
   192  			if got == nil && tt.want != nil ||
   193  				got != nil && tt.want == nil {
   194  				t.Fatalf("NewMachine() = %v, want %v", got, tt.want)
   195  			}
   196  			if tt.want.nodes == nil && got.nodes != nil ||
   197  				tt.want.nodes != nil && got.nodes == nil {
   198  				t.Fatalf("NewMachine(nodes) = %v, want %v", got, tt.want)
   199  			}
   200  			if len(got.nodes) != len(tt.want.nodes) {
   201  				t.Fatalf("bad number of nodes, expecting %v, got %v", len(tt.want.nodes), len(got.nodes))
   202  			}
   203  			for i := 0; i < len(got.nodes); i++ {
   204  				compareNodes(t, got.nodes[i], tt.want.nodes[i])
   205  			}
   206  			/*
   207  				if tt.want.pubsubs == nil && got.pubsubs != nil ||
   208  					tt.want.pubsubs != nil && got.pubsubs == nil {
   209  					t.Fatalf("NewMachine(pubsubs) = %v, want %v", got, tt.want)
   210  				}
   211  				if !reflect.DeepEqual(got.pubsubs, tt.want.pubsubs) {
   212  					t.Fatalf("bad pubsubs, expecting %v, got %v", tt.want.pubsubs, got.pubsubs)
   213  				}
   214  			*/
   215  		})
   216  	}
   217  }
   218  
   219  func Test_createHub(t *testing.T) {
   220  	type args struct {
   221  		ns []*node
   222  		g  *gorgonia.ExprGraph
   223  	}
   224  	tests := []struct {
   225  		name string
   226  		args args
   227  		want []*pubsub
   228  	}{
   229  		// TODO: Add test cases.
   230  	}
   231  	for _, tt := range tests {
   232  		t.Run(tt.name, func(t *testing.T) {
   233  			if got := createNetwork(tt.args.ns, tt.args.g); !reflect.DeepEqual(got, tt.want) {
   234  				t.Errorf("createHub() = %v, want %v", got, tt.want)
   235  			}
   236  		})
   237  	}
   238  }
   239  
   240  func TestMachine_Close(t *testing.T) {
   241  	c0 := make(chan gorgonia.Value, 0)
   242  	c1 := make(chan gorgonia.Value, 0)
   243  	c2 := make(chan gorgonia.Value, 0)
   244  	c3 := make(chan gorgonia.Value, 0)
   245  	c4 := make(chan gorgonia.Value, 0)
   246  	c5 := make(chan gorgonia.Value, 0)
   247  	i0 := make(chan ioValue, 0)
   248  	i1 := make(chan ioValue, 0)
   249  	ps := &pubsub{
   250  		publishers: []*publisher{
   251  			{
   252  				publisher:   c0,
   253  				subscribers: []chan<- gorgonia.Value{c1, c2},
   254  			},
   255  			{
   256  				publisher:   c3,
   257  				subscribers: []chan<- gorgonia.Value{c1, c2},
   258  			},
   259  		},
   260  		subscribers: []*subscriber{
   261  			{
   262  				subscriber: i0,
   263  				publishers: []<-chan gorgonia.Value{c3, c2},
   264  			},
   265  			{
   266  				subscriber: i0,
   267  				publishers: []<-chan gorgonia.Value{c4, c5},
   268  			},
   269  			{
   270  				subscriber: i1,
   271  				publishers: []<-chan gorgonia.Value{c4, c5},
   272  			},
   273  		},
   274  	}
   275  	type fields struct {
   276  		nodes   []*node
   277  		pubsubs *pubsub
   278  	}
   279  	tests := []struct {
   280  		name   string
   281  		fields fields
   282  	}{
   283  		{
   284  			"simple",
   285  			fields{
   286  				pubsubs: ps,
   287  			},
   288  		},
   289  	}
   290  	for _, tt := range tests {
   291  		t.Run(tt.name, func(t *testing.T) {
   292  			m := &Machine{
   293  				nodes:  tt.fields.nodes,
   294  				pubsub: tt.fields.pubsubs,
   295  			}
   296  			m.Close()
   297  		})
   298  	}
   299  }
   300  
   301  func Test_createNetwork(t *testing.T) {
   302  	g := gorgonia.NewGraph()
   303  	forty := gorgonia.F32(40.0)
   304  	//fortyTwo := gorgonia.F32(42.0)
   305  	two := gorgonia.F32(2.0)
   306  	n1 := gorgonia.NewScalar(g, gorgonia.Float32, gorgonia.WithValue(&forty), gorgonia.WithName("n1"))
   307  	n2 := gorgonia.NewScalar(g, gorgonia.Float32, gorgonia.WithValue(&two), gorgonia.WithName("n2"))
   308  
   309  	added, err := gorgonia.Add(n1, n2)
   310  	if err != nil {
   311  		t.Fatal(err)
   312  	}
   313  	i1 := newInput(n1)
   314  	i2 := newInput(n2)
   315  	op := newOp(added, false)
   316  
   317  	type args struct {
   318  		ns []*node
   319  		g  *gorgonia.ExprGraph
   320  	}
   321  	tests := []struct {
   322  		name string
   323  		args args
   324  		want *pubsub
   325  	}{
   326  		{
   327  			"simple add operation",
   328  			args{
   329  				ns: []*node{i1, i2, op},
   330  				g:  g,
   331  			},
   332  			&pubsub{
   333  				publishers: []*publisher{
   334  					{
   335  						id: 0,
   336  					},
   337  					{
   338  						id: 1,
   339  					},
   340  				},
   341  				subscribers: []*subscriber{
   342  					{
   343  						id:         2,
   344  						subscriber: make(chan ioValue, 0),
   345  					},
   346  				},
   347  			},
   348  		},
   349  	}
   350  	for _, tt := range tests {
   351  		t.Run(tt.name, func(t *testing.T) {
   352  			got := createNetwork(tt.args.ns, tt.args.g)
   353  			if got == nil && tt.want != nil {
   354  				t.Fail()
   355  			}
   356  			if got != nil && tt.want == nil {
   357  				t.Fail()
   358  			}
   359  			if got == nil && tt.want == nil {
   360  				return
   361  			}
   362  			if got.publishers != nil && tt.want.publishers != nil {
   363  				if len(got.publishers) != len(tt.want.publishers) {
   364  					t.Errorf("bad number of publishers, expected %v, got %v", len(tt.want.publishers), len(got.publishers))
   365  				}
   366  			}
   367  			if got.subscribers != nil && tt.want.subscribers != nil {
   368  				if len(got.subscribers) != len(tt.want.subscribers) {
   369  					t.Errorf("bad number of subscribers, expected %v, got %v", len(tt.want.subscribers), len(got.subscribers))
   370  				}
   371  			}
   372  			for i := range tt.want.publishers {
   373  				want := tt.want.publishers[i]
   374  				got := got.publishers[i]
   375  				if want.id != got.id {
   376  					t.Errorf("bad subscriber id, expected %v, got %v", want.id, got.id)
   377  				}
   378  			}
   379  			for i := range tt.want.subscribers {
   380  				want := tt.want.subscribers[i]
   381  				got := got.subscribers[i]
   382  				if want.id != got.id {
   383  					t.Errorf("bad subscriber id, expected %v, got %v", want.id, got.id)
   384  				}
   385  			}
   386  		})
   387  	}
   388  }
   389  
   390  func ExampleMachine_Run() {
   391  	g := gorgonia.NewGraph()
   392  	forty := gorgonia.F32(40.0)
   393  	//fortyTwo := gorgonia.F32(42.0)
   394  	two := gorgonia.F32(2.0)
   395  	n1 := gorgonia.NewScalar(g, gorgonia.Float32, gorgonia.WithValue(&forty), gorgonia.WithName("n1"))
   396  	n2 := gorgonia.NewScalar(g, gorgonia.Float32, gorgonia.WithValue(&two), gorgonia.WithName("n2"))
   397  
   398  	added, err := gorgonia.Add(n1, n2)
   399  	if err != nil {
   400  		log.Fatal(err)
   401  	}
   402  	machine := NewMachine(g)
   403  	ctx, cancel := context.WithTimeout(context.Background(), 1000*time.Millisecond)
   404  	defer cancel()
   405  	defer machine.Close()
   406  	err = machine.Run(ctx)
   407  	if err != nil {
   408  		log.Fatal(err)
   409  	}
   410  	fmt.Println(machine.GetResult(added.ID()))
   411  	// output: 42
   412  }
   413  
   414  func TestMachine_Run(t *testing.T) {
   415  	g := gorgonia.NewGraph()
   416  	forty := gorgonia.F32(40.0)
   417  	//fortyTwo := gorgonia.F32(42.0)
   418  	two := gorgonia.F32(2.0)
   419  	n1 := gorgonia.NewScalar(g, gorgonia.Float32, gorgonia.WithValue(&forty), gorgonia.WithName("n1"))
   420  	n2 := gorgonia.NewScalar(g, gorgonia.Float32, gorgonia.WithValue(&two), gorgonia.WithName("n2"))
   421  
   422  	added, err := gorgonia.Add(n1, n2)
   423  	if err != nil {
   424  		t.Fatal(err)
   425  	}
   426  	i1 := newInput(n1)
   427  	i2 := newInput(n2)
   428  	op := newOp(added, false)
   429  	c1 := make(chan gorgonia.Value, 0)
   430  	c2 := make(chan gorgonia.Value, 0)
   431  	type fields struct {
   432  		nodes   []*node
   433  		pubsubs *pubsub
   434  	}
   435  	type args struct {
   436  		ctx context.Context
   437  	}
   438  	tests := []struct {
   439  		name    string
   440  		fields  fields
   441  		args    args
   442  		wantErr bool
   443  	}{
   444  		{
   445  			"simple",
   446  			fields{
   447  				nodes: []*node{i1, i2, op},
   448  				pubsubs: &pubsub{
   449  					publishers: []*publisher{
   450  						{
   451  							id:        i1.id,
   452  							publisher: i1.outputC,
   453  							subscribers: []chan<- gorgonia.Value{
   454  								c1,
   455  							},
   456  						},
   457  						{
   458  							id:        i2.id,
   459  							publisher: i2.outputC,
   460  							subscribers: []chan<- gorgonia.Value{
   461  								c2,
   462  							},
   463  						},
   464  					},
   465  					subscribers: []*subscriber{
   466  						{
   467  							id: op.id,
   468  							publishers: []<-chan gorgonia.Value{
   469  								c1, c2,
   470  							},
   471  							subscriber: op.inputC,
   472  						},
   473  					},
   474  				},
   475  			},
   476  			args{
   477  				context.Background(),
   478  			},
   479  			false,
   480  		},
   481  		// TODO: Add test cases.
   482  	}
   483  	for _, tt := range tests {
   484  		t.Run(tt.name, func(t *testing.T) {
   485  			m := &Machine{
   486  				nodes:  tt.fields.nodes,
   487  				pubsub: tt.fields.pubsubs,
   488  			}
   489  			err := m.Run(tt.args.ctx)
   490  			if (err != nil) != tt.wantErr {
   491  				t.Errorf("Machine.Run() error = %v, wantErr %v", err, tt.wantErr)
   492  			}
   493  		})
   494  	}
   495  }
   496  
   497  func TestMachine_GetResult(t *testing.T) {
   498  	fortyTwo := gorgonia.F32(42.0)
   499  	type fields struct {
   500  		nodes   []*node
   501  		pubsubs *pubsub
   502  	}
   503  	type args struct {
   504  		id int64
   505  	}
   506  	tests := []struct {
   507  		name   string
   508  		fields fields
   509  		args   args
   510  		want   gorgonia.Value
   511  	}{
   512  		{
   513  			"nil",
   514  			fields{
   515  				nodes: []*node{
   516  					{
   517  						id:     1,
   518  						output: &fortyTwo,
   519  					},
   520  				},
   521  			},
   522  			args{
   523  				2,
   524  			},
   525  			nil,
   526  		},
   527  		{
   528  			"simple",
   529  			fields{
   530  				nodes: []*node{
   531  					{
   532  						id:     1,
   533  						output: &fortyTwo,
   534  					},
   535  				},
   536  			},
   537  			args{
   538  				1,
   539  			},
   540  			&fortyTwo,
   541  		},
   542  	}
   543  	for _, tt := range tests {
   544  		t.Run(tt.name, func(t *testing.T) {
   545  			m := &Machine{
   546  				nodes:  tt.fields.nodes,
   547  				pubsub: tt.fields.pubsubs,
   548  			}
   549  			if got := m.GetResult(tt.args.id); !reflect.DeepEqual(got, tt.want) {
   550  				t.Errorf("Machine.GetResult() = %v, want %v", got, tt.want)
   551  			}
   552  		})
   553  	}
   554  }
   555  
   556  func Test_nodeErrors_Error(t *testing.T) {
   557  	tests := []struct {
   558  		name string
   559  		e    nodeErrors
   560  		want string
   561  	}{
   562  		{
   563  			"simple",
   564  			[]nodeError{
   565  				{
   566  					id:  0,
   567  					err: errors.New("error"),
   568  				},
   569  			},
   570  			"0:error\n",
   571  		},
   572  	}
   573  	for _, tt := range tests {
   574  		t.Run(tt.name, func(t *testing.T) {
   575  			if got := tt.e.Error(); got != tt.want {
   576  				t.Errorf("nodeErrors.Error() = %v, want %v", got, tt.want)
   577  			}
   578  		})
   579  	}
   580  }