gorgonia.org/gorgonia@v0.9.17/x/vm/node_test.go (about)

     1  package xvm
     2  
     3  import (
     4  	"context"
     5  	"errors"
     6  	"reflect"
     7  	"testing"
     8  
     9  	"gorgonia.org/gorgonia"
    10  )
    11  
    12  func Test_receiveInput(t *testing.T) {
    13  	cancelCtx, cancel := context.WithCancel(context.Background())
    14  	inputC := make(chan ioValue, 0)
    15  	type args struct {
    16  		ctx context.Context
    17  		o   *node
    18  		fn  func()
    19  	}
    20  	tests := []struct {
    21  		name string
    22  		args args
    23  		want stateFn
    24  	}{
    25  		{
    26  			"context cancelation",
    27  			args{
    28  				cancelCtx,
    29  				&node{
    30  					inputC: make(chan ioValue, 0),
    31  				},
    32  				nil,
    33  			},
    34  			nil,
    35  		},
    36  		{
    37  			"bad input value position",
    38  			args{
    39  				context.Background(),
    40  				&node{
    41  					inputC:      inputC,
    42  					inputValues: make([]gorgonia.Value, 1),
    43  				},
    44  				func() {
    45  					inputC <- struct {
    46  						pos int
    47  						v   gorgonia.Value
    48  					}{
    49  						pos: 1,
    50  						v:   nil,
    51  					}
    52  				},
    53  			},
    54  			nil,
    55  		},
    56  		{
    57  			"more value to receive",
    58  			args{
    59  				context.Background(),
    60  				&node{
    61  					inputC:      inputC,
    62  					inputValues: make([]gorgonia.Value, 2),
    63  				},
    64  				func() {
    65  					inputC <- struct {
    66  						pos int
    67  						v   gorgonia.Value
    68  					}{
    69  						pos: 0,
    70  						v:   nil,
    71  					}
    72  				},
    73  			},
    74  			receiveInput,
    75  		},
    76  		{
    77  			"no input chan go to conpute",
    78  			args{
    79  				context.Background(),
    80  				&node{
    81  					inputValues: make([]gorgonia.Value, 1),
    82  				},
    83  				nil,
    84  			},
    85  			computeFwd,
    86  		},
    87  		{
    88  			"all done go to compute",
    89  			args{
    90  				context.Background(),
    91  				&node{
    92  					inputC:      inputC,
    93  					inputValues: make([]gorgonia.Value, 1),
    94  				},
    95  				func() {
    96  					inputC <- struct {
    97  						pos int
    98  						v   gorgonia.Value
    99  					}{
   100  						pos: 0,
   101  						v:   nil,
   102  					}
   103  				},
   104  			},
   105  			computeFwd,
   106  		},
   107  		// TODO: Add test cases.
   108  	}
   109  	cancel()
   110  	for _, tt := range tests {
   111  		t.Run(tt.name, func(t *testing.T) {
   112  			if tt.args.fn != nil {
   113  				go tt.args.fn()
   114  			}
   115  			got := receiveInput(tt.args.ctx, tt.args.o)
   116  			gotPrt := reflect.ValueOf(got).Pointer()
   117  			wantPtr := reflect.ValueOf(tt.want).Pointer()
   118  			if gotPrt != wantPtr {
   119  				t.Errorf("receiveInput() = %v, want %v", got, tt.want)
   120  			}
   121  		})
   122  	}
   123  }
   124  
   125  func Test_computeFwd(t *testing.T) {
   126  	type args struct {
   127  		in0 context.Context
   128  		n   *node
   129  	}
   130  	tests := []struct {
   131  		name string
   132  		args args
   133  		want stateFn
   134  	}{
   135  		{
   136  			"simple no error",
   137  			args{
   138  				nil,
   139  				&node{
   140  					op:          &noOpTest{},
   141  					inputValues: []gorgonia.Value{nil},
   142  				},
   143  			},
   144  			emitOutput,
   145  		},
   146  		{
   147  			"simple with error",
   148  			args{
   149  				nil,
   150  				&node{
   151  					op:          &noOpTest{err: errors.New("")},
   152  					inputValues: []gorgonia.Value{nil},
   153  				},
   154  			},
   155  			nil,
   156  		},
   157  		// TODO: Add test cases.
   158  	}
   159  	for _, tt := range tests {
   160  		t.Run(tt.name, func(t *testing.T) {
   161  			got := computeFwd(tt.args.in0, tt.args.n)
   162  			gotPrt := reflect.ValueOf(got).Pointer()
   163  			wantPtr := reflect.ValueOf(tt.want).Pointer()
   164  			if gotPrt != wantPtr {
   165  				t.Errorf("computeFwd() = %v, want %v", got, tt.want)
   166  			}
   167  		})
   168  	}
   169  }
   170  
   171  func Test_node_ComputeForward(t *testing.T) {
   172  	type fields struct {
   173  		op             gorgonia.Op
   174  		output         gorgonia.Value
   175  		outputC        chan gorgonia.Value
   176  		receivedValues int
   177  		err            error
   178  		inputValues    []gorgonia.Value
   179  		inputC         chan ioValue
   180  	}
   181  	type args struct {
   182  		ctx context.Context
   183  	}
   184  	tests := []struct {
   185  		name    string
   186  		fields  fields
   187  		args    args
   188  		wantErr bool
   189  	}{
   190  		{
   191  			"simple",
   192  			fields{
   193  				op: nil,
   194  			},
   195  			args{
   196  				nil,
   197  			},
   198  			false,
   199  		},
   200  		// TODO: Add test cases.
   201  	}
   202  	for _, tt := range tests {
   203  		t.Run(tt.name, func(t *testing.T) {
   204  			n := &node{
   205  				op:             tt.fields.op,
   206  				output:         tt.fields.output,
   207  				outputC:        tt.fields.outputC,
   208  				receivedValues: tt.fields.receivedValues,
   209  				err:            tt.fields.err,
   210  				inputValues:    tt.fields.inputValues,
   211  				inputC:         tt.fields.inputC,
   212  			}
   213  			if err := n.Compute(tt.args.ctx); (err != nil) != tt.wantErr {
   214  				t.Errorf("node.ComputeForward() error = %v, wantErr %v", err, tt.wantErr)
   215  			}
   216  		})
   217  	}
   218  }
   219  
   220  type errorOP struct{}
   221  
   222  func (*errorOP) Do(v ...gorgonia.Value) (gorgonia.Value, error) {
   223  	return nil, errors.New("error")
   224  }
   225  
   226  type sumF32 struct{}
   227  
   228  func (*sumF32) Do(v ...gorgonia.Value) (gorgonia.Value, error) {
   229  	val := v[0].Data().(float32) + v[1].Data().(float32)
   230  	value := gorgonia.F32(val)
   231  	return &value, nil
   232  }
   233  
   234  func Test_emitOutput(t *testing.T) {
   235  	cancelCtx, cancel := context.WithCancel(context.Background())
   236  	outputC1 := make(chan gorgonia.Value, 0)
   237  	outputC2 := make(chan gorgonia.Value, 1)
   238  	type args struct {
   239  		ctx context.Context
   240  		n   *node
   241  	}
   242  	tests := []struct {
   243  		name string
   244  		args args
   245  		want stateFn
   246  	}{
   247  		{
   248  			"nil node",
   249  			args{nil, nil},
   250  			nil,
   251  		},
   252  		{
   253  			"context cancelation",
   254  			args{
   255  				cancelCtx,
   256  				&node{
   257  					outputC: outputC1,
   258  				},
   259  			},
   260  			nil,
   261  		},
   262  		{
   263  			"emit output",
   264  			args{
   265  				context.Background(),
   266  				&node{
   267  					outputC: outputC2,
   268  				},
   269  			},
   270  			nil,
   271  		},
   272  	}
   273  	cancel()
   274  	for _, tt := range tests {
   275  		t.Run(tt.name, func(t *testing.T) {
   276  			got := emitOutput(tt.args.ctx, tt.args.n)
   277  			gotPrt := reflect.ValueOf(got).Pointer()
   278  			wantPtr := reflect.ValueOf(tt.want).Pointer()
   279  			if gotPrt != wantPtr {
   280  				t.Errorf("emitOutput() = %v, want %v", got, tt.want)
   281  			}
   282  		})
   283  	}
   284  }
   285  
   286  func Test_computeBackward(t *testing.T) {
   287  	type args struct {
   288  		in0 context.Context
   289  		in1 *node
   290  	}
   291  	tests := []struct {
   292  		name string
   293  		args args
   294  		want stateFn
   295  	}{
   296  		{
   297  			"simple",
   298  			args{
   299  				nil,
   300  				nil,
   301  			},
   302  			nil,
   303  		},
   304  	}
   305  	for _, tt := range tests {
   306  		t.Run(tt.name, func(t *testing.T) {
   307  			if got := computeBackward(tt.args.in0, tt.args.in1); !reflect.DeepEqual(got, tt.want) {
   308  				t.Errorf("computeBackward() = %v, want %v", got, tt.want)
   309  			}
   310  		})
   311  	}
   312  }
   313  
   314  func Test_newOp(t *testing.T) {
   315  	g := gorgonia.NewGraph()
   316  	fortyTwo := gorgonia.F32(42.0)
   317  	n1 := gorgonia.NodeFromAny(g, fortyTwo)
   318  	n2 := gorgonia.NodeFromAny(g, fortyTwo)
   319  	addOp, err := gorgonia.Add(n1, n2)
   320  	if err != nil {
   321  		t.Fatal(err)
   322  	}
   323  	type args struct {
   324  		n             *gorgonia.Node
   325  		hasOutputChan bool
   326  	}
   327  	tests := []struct {
   328  		name string
   329  		args args
   330  		want *node
   331  	}{
   332  		{
   333  			"no op",
   334  			args{nil, false},
   335  			nil,
   336  		},
   337  		{
   338  			"add with outputChan",
   339  			args{addOp, true},
   340  			&node{
   341  				id:          addOp.ID(),
   342  				op:          addOp.Op(),
   343  				inputC:      make(chan ioValue, 0),
   344  				outputC:     make(chan gorgonia.Value, 0),
   345  				inputValues: make([]gorgonia.Value, 2),
   346  			},
   347  		},
   348  		{
   349  			"add without outputChan",
   350  			args{addOp, false},
   351  			&node{
   352  				id:          addOp.ID(),
   353  				op:          addOp.Op(),
   354  				inputC:      make(chan ioValue, 0),
   355  				outputC:     nil,
   356  				inputValues: make([]gorgonia.Value, 2),
   357  			},
   358  		},
   359  	}
   360  	for _, tt := range tests {
   361  		t.Run(tt.name, func(t *testing.T) {
   362  			got := newOp(tt.args.n, tt.args.hasOutputChan)
   363  			if got == tt.want {
   364  				return
   365  			}
   366  			if got.id != tt.want.id {
   367  				t.Errorf("newOp() = \n%#v, want \n%#v", got, tt.want)
   368  			}
   369  			if !reflect.DeepEqual(got.op, tt.want.op) {
   370  				t.Errorf("newOp() = \n%#v, want \n%#v", got, tt.want)
   371  			}
   372  			if !reflect.DeepEqual(got.inputValues, tt.want.inputValues) {
   373  				t.Errorf("newOp() = \n%#v, want \n%#v", got, tt.want)
   374  			}
   375  			if got.receivedValues != tt.want.receivedValues {
   376  				t.Errorf("newOp() = \n%#v, want \n%#v", got, tt.want)
   377  			}
   378  			if got.err != tt.want.err {
   379  				t.Errorf("newOp() = \n%#v, want \n%#v", got, tt.want)
   380  			}
   381  			if (got.inputC == nil && tt.want.inputC != nil) ||
   382  				(got.inputC != nil && tt.want.inputC == nil) {
   383  				t.Errorf("newOp() = \n%#v, want \n%#v", got, tt.want)
   384  			}
   385  			if (got.outputC == nil && tt.want.outputC != nil) ||
   386  				(got.outputC != nil && tt.want.outputC == nil) {
   387  				t.Errorf("newOp() = \n%#v, want \n%#v", got, tt.want)
   388  			}
   389  			if cap(got.outputC) != cap(tt.want.outputC) {
   390  				t.Errorf("newOp() = \n%#v, want \n%#v", got, tt.want)
   391  			}
   392  			if len(got.outputC) != len(tt.want.outputC) {
   393  				t.Errorf("newOp() = \n%#v, want \n%#v", got, tt.want)
   394  			}
   395  			if cap(got.inputC) != cap(tt.want.inputC) {
   396  				t.Errorf("newOp() = \n%#v, want \n%#v", got, tt.want)
   397  			}
   398  			if len(got.inputC) != len(tt.want.inputC) {
   399  				t.Errorf("newOp() = \n%#v, want \n%#v", got, tt.want)
   400  			}
   401  
   402  		})
   403  	}
   404  }
   405  
   406  func Test_newInput(t *testing.T) {
   407  	g := gorgonia.NewGraph()
   408  	fortyTwo := gorgonia.F32(42.0)
   409  	n1 := gorgonia.NodeFromAny(g, &fortyTwo)
   410  	type args struct {
   411  		n *gorgonia.Node
   412  	}
   413  	tests := []struct {
   414  		name string
   415  		args args
   416  		want *node
   417  	}{
   418  		{
   419  			"nil",
   420  			args{nil},
   421  			nil,
   422  		},
   423  		{
   424  			"simple",
   425  			args{n1},
   426  			&node{
   427  				outputC: make(chan gorgonia.Value, 0),
   428  				output:  &fortyTwo,
   429  			},
   430  		},
   431  	}
   432  	for _, tt := range tests {
   433  		t.Run(tt.name, func(t *testing.T) {
   434  			got := newInput(tt.args.n)
   435  			if got == tt.want {
   436  				return
   437  			}
   438  			compareNodes(t, got, tt.want)
   439  		})
   440  	}
   441  }
   442  
   443  func compareNodes(t *testing.T, got, want *node) {
   444  	if got.id != want.id {
   445  		t.Errorf("nodes ID are different = \n%#v, want \n%#v", got.id, want.id)
   446  	}
   447  	if !reflect.DeepEqual(got.op, want.op) {
   448  		t.Errorf("nodes OP are different = \n%#v, want \n%#v", got.op, want.op)
   449  	}
   450  	if !reflect.DeepEqual(got.inputValues, want.inputValues) {
   451  		t.Errorf("nodes inputValues are different = \n%#v, want \n%#v", got.inputValues, want.inputValues)
   452  	}
   453  	if got.receivedValues != want.receivedValues {
   454  		t.Errorf("nodes receivedValues are different = \n%#v, want \n%#v", got.receivedValues, want.receivedValues)
   455  	}
   456  	if got.err != want.err {
   457  		t.Errorf("nodes errors are different = \n%#v, want \n%#v", got.err, want.err)
   458  	}
   459  	if (got.inputC == nil && want.inputC != nil) ||
   460  		(got.inputC != nil && want.inputC == nil) {
   461  		t.Errorf("newInput() = \n%#v, want \n%#v", got, want)
   462  	}
   463  	if (got.outputC == nil && want.outputC != nil) ||
   464  		(got.outputC != nil && want.outputC == nil) {
   465  		t.Errorf("newInput() = \n%#v, want \n%#v", got, want)
   466  	}
   467  	if cap(got.outputC) != cap(want.outputC) {
   468  		t.Errorf("newInput() = \n%#v, want \n%#v", got, want)
   469  	}
   470  	if len(got.outputC) != len(want.outputC) {
   471  		t.Errorf("newInput() = \n%#v, want \n%#v", got, want)
   472  	}
   473  	if cap(got.inputC) != cap(want.inputC) {
   474  		t.Errorf("newInput() = \n%#v, want \n%#v", got, want)
   475  	}
   476  	if len(got.inputC) != len(want.inputC) {
   477  		t.Errorf("newInput() = \n%#v, want \n%#v", got, want)
   478  	}
   479  
   480  }