github.com/wzzhu/tensor@v0.9.24/api_cmp_test.go (about)

     1  package tensor
     2  
     3  import (
     4  	"testing"
     5  
     6  	"github.com/stretchr/testify/assert"
     7  )
     8  
     9  // This file contains the tests for API functions that aren't generated by genlib
    10  
    11  func TestLtScalarScalar(t *testing.T) {
    12  	// scalar-scalar
    13  	a := New(WithBacking([]float64{6}))
    14  	b := New(WithBacking([]float64{2}))
    15  	var correct interface{} = false
    16  
    17  	res, err := Lt(a, b)
    18  	if err != nil {
    19  		t.Fatalf("Error: %v", err)
    20  	}
    21  	assert.Equal(t, correct, res.Data())
    22  
    23  	// scalar-tensor
    24  	a = New(WithBacking([]float64{1, 4}))
    25  	b = New(WithBacking([]float64{2}))
    26  	correct = []bool{true, false}
    27  
    28  	res, err = Lt(a, b)
    29  	if err != nil {
    30  		t.Fatalf("Error: %v", err)
    31  	}
    32  	assert.Equal(t, correct, res.Data())
    33  
    34  	// tensor-scalar
    35  	a = New(WithBacking([]float64{3}))
    36  	b = New(WithBacking([]float64{6, 2}))
    37  	correct = []bool{true, false}
    38  
    39  	res, err = Lt(a, b)
    40  	if err != nil {
    41  		t.Fatalf("Error: %v", err)
    42  	}
    43  	assert.Equal(t, correct, res.Data())
    44  
    45  	// tensor - tensor
    46  	a = New(WithBacking([]float64{21, 2}))
    47  	b = New(WithBacking([]float64{7, 10}))
    48  	correct = []bool{false, true}
    49  
    50  	res, err = Lt(a, b)
    51  	if err != nil {
    52  		t.Fatalf("Error: %v", err)
    53  	}
    54  	assert.Equal(t, correct, res.Data())
    55  }
    56  
    57  func TestGtScalarScalar(t *testing.T) {
    58  	// scalar-scalar
    59  	a := New(WithBacking([]float64{6}))
    60  	b := New(WithBacking([]float64{2}))
    61  	var correct interface{} = true
    62  
    63  	res, err := Gt(a, b)
    64  	if err != nil {
    65  		t.Fatalf("Error: %v", err)
    66  	}
    67  	assert.Equal(t, correct, res.Data())
    68  
    69  	// scalar-tensor
    70  	a = New(WithBacking([]float64{1, 4}))
    71  	b = New(WithBacking([]float64{2}))
    72  	correct = []bool{false, true}
    73  
    74  	res, err = Gt(a, b)
    75  	if err != nil {
    76  		t.Fatalf("Error: %v", err)
    77  	}
    78  	assert.Equal(t, correct, res.Data())
    79  
    80  	// tensor-scalar
    81  	a = New(WithBacking([]float64{3}))
    82  	b = New(WithBacking([]float64{6, 2}))
    83  	correct = []bool{false, true}
    84  
    85  	res, err = Gt(a, b)
    86  	if err != nil {
    87  		t.Fatalf("Error: %v", err)
    88  	}
    89  	assert.Equal(t, correct, res.Data())
    90  
    91  	// tensor - tensor
    92  	a = New(WithBacking([]float64{21, 2}))
    93  	b = New(WithBacking([]float64{7, 10}))
    94  	correct = []bool{true, false}
    95  
    96  	res, err = Gt(a, b)
    97  	if err != nil {
    98  		t.Fatalf("Error: %v", err)
    99  	}
   100  	assert.Equal(t, correct, res.Data())
   101  }
   102  
   103  func TestLteScalarScalar(t *testing.T) {
   104  	// scalar-scalar
   105  	a := New(WithBacking([]float64{6}))
   106  	b := New(WithBacking([]float64{2}))
   107  	var correct interface{} = false
   108  
   109  	res, err := Lte(a, b)
   110  	if err != nil {
   111  		t.Fatalf("Error: %v", err)
   112  	}
   113  	assert.Equal(t, correct, res.Data())
   114  
   115  	// scalar-tensor
   116  	a = New(WithBacking([]float64{1, 2, 4}))
   117  	b = New(WithBacking([]float64{2}))
   118  	correct = []bool{true, true, false}
   119  
   120  	res, err = Lte(a, b)
   121  	if err != nil {
   122  		t.Fatalf("Error: %v", err)
   123  	}
   124  	assert.Equal(t, correct, res.Data())
   125  
   126  	// tensor-scalar
   127  	a = New(WithBacking([]float64{3}))
   128  	b = New(WithBacking([]float64{6, 2}))
   129  	correct = []bool{true, false}
   130  
   131  	res, err = Lte(a, b)
   132  	if err != nil {
   133  		t.Fatalf("Error: %v", err)
   134  	}
   135  	assert.Equal(t, correct, res.Data())
   136  
   137  	// tensor - tensor
   138  	a = New(WithBacking([]float64{21, 2}))
   139  	b = New(WithBacking([]float64{7, 10}))
   140  	correct = []bool{false, true}
   141  
   142  	res, err = Lte(a, b)
   143  	if err != nil {
   144  		t.Fatalf("Error: %v", err)
   145  	}
   146  	assert.Equal(t, correct, res.Data())
   147  }
   148  
   149  func TestGteScalarScalar(t *testing.T) {
   150  	// scalar-scalar
   151  	a := New(WithBacking([]float64{6}))
   152  	b := New(WithBacking([]float64{2}))
   153  	var correct interface{} = true
   154  
   155  	res, err := Gte(a, b)
   156  	if err != nil {
   157  		t.Fatalf("Error: %v", err)
   158  	}
   159  	assert.Equal(t, correct, res.Data())
   160  
   161  	// scalar-tensor
   162  	a = New(WithBacking([]float64{1, 2, 4}))
   163  	b = New(WithBacking([]float64{2}))
   164  	correct = []bool{false, true, true}
   165  
   166  	res, err = Gte(a, b)
   167  	if err != nil {
   168  		t.Fatalf("Error: %v", err)
   169  	}
   170  	assert.Equal(t, correct, res.Data())
   171  
   172  	// tensor-scalar
   173  	a = New(WithBacking([]float64{3}))
   174  	b = New(WithBacking([]float64{6, 3, 2}))
   175  	correct = []bool{false, true, true}
   176  
   177  	res, err = Gte(a, b)
   178  	if err != nil {
   179  		t.Fatalf("Error: %v", err)
   180  	}
   181  	assert.Equal(t, correct, res.Data())
   182  
   183  	// tensor - tensor
   184  	a = New(WithBacking([]float64{21, 31, 2}))
   185  	b = New(WithBacking([]float64{7, 31, 10}))
   186  	correct = []bool{true, true, false}
   187  
   188  	res, err = Gte(a, b)
   189  	if err != nil {
   190  		t.Fatalf("Error: %v", err)
   191  	}
   192  	assert.Equal(t, correct, res.Data())
   193  }
   194  
   195  func TestElEqScalarScalar(t *testing.T) {
   196  	// scalar-scalar
   197  	a := New(WithBacking([]float64{6}))
   198  	b := New(WithBacking([]float64{2}))
   199  	var correct interface{} = false
   200  
   201  	res, err := ElEq(a, b)
   202  	if err != nil {
   203  		t.Fatalf("Error: %v", err)
   204  	}
   205  	assert.Equal(t, correct, res.Data())
   206  
   207  	// scalar-tensor
   208  	a = New(WithBacking([]float64{1, 2, 4}))
   209  	b = New(WithBacking([]float64{2}))
   210  	correct = []bool{false, true, false}
   211  
   212  	res, err = ElEq(a, b)
   213  	if err != nil {
   214  		t.Fatalf("Error: %v", err)
   215  	}
   216  	assert.Equal(t, correct, res.Data())
   217  
   218  	// tensor-scalar
   219  	a = New(WithBacking([]float64{3}))
   220  	b = New(WithBacking([]float64{6, 3, 2}))
   221  	correct = []bool{false, true, false}
   222  
   223  	res, err = ElEq(a, b)
   224  	if err != nil {
   225  		t.Fatalf("Error: %v", err)
   226  	}
   227  	assert.Equal(t, correct, res.Data())
   228  
   229  	// tensor - tensor
   230  	a = New(WithBacking([]float64{21, 10}))
   231  	b = New(WithBacking([]float64{7, 10}))
   232  	correct = []bool{false, true}
   233  
   234  	res, err = ElEq(a, b)
   235  	if err != nil {
   236  		t.Fatalf("Error: %v", err)
   237  	}
   238  	assert.Equal(t, correct, res.Data())
   239  }