github.com/wzzhu/tensor@v0.9.24/api_cmp_test.go (about) 1 package tensor 2 3 import ( 4 "testing" 5 6 "github.com/stretchr/testify/assert" 7 ) 8 9 // This file contains the tests for API functions that aren't generated by genlib 10 11 func TestLtScalarScalar(t *testing.T) { 12 // scalar-scalar 13 a := New(WithBacking([]float64{6})) 14 b := New(WithBacking([]float64{2})) 15 var correct interface{} = false 16 17 res, err := Lt(a, b) 18 if err != nil { 19 t.Fatalf("Error: %v", err) 20 } 21 assert.Equal(t, correct, res.Data()) 22 23 // scalar-tensor 24 a = New(WithBacking([]float64{1, 4})) 25 b = New(WithBacking([]float64{2})) 26 correct = []bool{true, false} 27 28 res, err = Lt(a, b) 29 if err != nil { 30 t.Fatalf("Error: %v", err) 31 } 32 assert.Equal(t, correct, res.Data()) 33 34 // tensor-scalar 35 a = New(WithBacking([]float64{3})) 36 b = New(WithBacking([]float64{6, 2})) 37 correct = []bool{true, false} 38 39 res, err = Lt(a, b) 40 if err != nil { 41 t.Fatalf("Error: %v", err) 42 } 43 assert.Equal(t, correct, res.Data()) 44 45 // tensor - tensor 46 a = New(WithBacking([]float64{21, 2})) 47 b = New(WithBacking([]float64{7, 10})) 48 correct = []bool{false, true} 49 50 res, err = Lt(a, b) 51 if err != nil { 52 t.Fatalf("Error: %v", err) 53 } 54 assert.Equal(t, correct, res.Data()) 55 } 56 57 func TestGtScalarScalar(t *testing.T) { 58 // scalar-scalar 59 a := New(WithBacking([]float64{6})) 60 b := New(WithBacking([]float64{2})) 61 var correct interface{} = true 62 63 res, err := Gt(a, b) 64 if err != nil { 65 t.Fatalf("Error: %v", err) 66 } 67 assert.Equal(t, correct, res.Data()) 68 69 // scalar-tensor 70 a = New(WithBacking([]float64{1, 4})) 71 b = New(WithBacking([]float64{2})) 72 correct = []bool{false, true} 73 74 res, err = Gt(a, b) 75 if err != nil { 76 t.Fatalf("Error: %v", err) 77 } 78 assert.Equal(t, correct, res.Data()) 79 80 // tensor-scalar 81 a = New(WithBacking([]float64{3})) 82 b = New(WithBacking([]float64{6, 2})) 83 correct = []bool{false, true} 84 85 res, err = Gt(a, b) 86 if err != nil { 87 t.Fatalf("Error: %v", err) 88 } 89 assert.Equal(t, correct, res.Data()) 90 91 // tensor - tensor 92 a = New(WithBacking([]float64{21, 2})) 93 b = New(WithBacking([]float64{7, 10})) 94 correct = []bool{true, false} 95 96 res, err = Gt(a, b) 97 if err != nil { 98 t.Fatalf("Error: %v", err) 99 } 100 assert.Equal(t, correct, res.Data()) 101 } 102 103 func TestLteScalarScalar(t *testing.T) { 104 // scalar-scalar 105 a := New(WithBacking([]float64{6})) 106 b := New(WithBacking([]float64{2})) 107 var correct interface{} = false 108 109 res, err := Lte(a, b) 110 if err != nil { 111 t.Fatalf("Error: %v", err) 112 } 113 assert.Equal(t, correct, res.Data()) 114 115 // scalar-tensor 116 a = New(WithBacking([]float64{1, 2, 4})) 117 b = New(WithBacking([]float64{2})) 118 correct = []bool{true, true, false} 119 120 res, err = Lte(a, b) 121 if err != nil { 122 t.Fatalf("Error: %v", err) 123 } 124 assert.Equal(t, correct, res.Data()) 125 126 // tensor-scalar 127 a = New(WithBacking([]float64{3})) 128 b = New(WithBacking([]float64{6, 2})) 129 correct = []bool{true, false} 130 131 res, err = Lte(a, b) 132 if err != nil { 133 t.Fatalf("Error: %v", err) 134 } 135 assert.Equal(t, correct, res.Data()) 136 137 // tensor - tensor 138 a = New(WithBacking([]float64{21, 2})) 139 b = New(WithBacking([]float64{7, 10})) 140 correct = []bool{false, true} 141 142 res, err = Lte(a, b) 143 if err != nil { 144 t.Fatalf("Error: %v", err) 145 } 146 assert.Equal(t, correct, res.Data()) 147 } 148 149 func TestGteScalarScalar(t *testing.T) { 150 // scalar-scalar 151 a := New(WithBacking([]float64{6})) 152 b := New(WithBacking([]float64{2})) 153 var correct interface{} = true 154 155 res, err := Gte(a, b) 156 if err != nil { 157 t.Fatalf("Error: %v", err) 158 } 159 assert.Equal(t, correct, res.Data()) 160 161 // scalar-tensor 162 a = New(WithBacking([]float64{1, 2, 4})) 163 b = New(WithBacking([]float64{2})) 164 correct = []bool{false, true, true} 165 166 res, err = Gte(a, b) 167 if err != nil { 168 t.Fatalf("Error: %v", err) 169 } 170 assert.Equal(t, correct, res.Data()) 171 172 // tensor-scalar 173 a = New(WithBacking([]float64{3})) 174 b = New(WithBacking([]float64{6, 3, 2})) 175 correct = []bool{false, true, true} 176 177 res, err = Gte(a, b) 178 if err != nil { 179 t.Fatalf("Error: %v", err) 180 } 181 assert.Equal(t, correct, res.Data()) 182 183 // tensor - tensor 184 a = New(WithBacking([]float64{21, 31, 2})) 185 b = New(WithBacking([]float64{7, 31, 10})) 186 correct = []bool{true, true, false} 187 188 res, err = Gte(a, b) 189 if err != nil { 190 t.Fatalf("Error: %v", err) 191 } 192 assert.Equal(t, correct, res.Data()) 193 } 194 195 func TestElEqScalarScalar(t *testing.T) { 196 // scalar-scalar 197 a := New(WithBacking([]float64{6})) 198 b := New(WithBacking([]float64{2})) 199 var correct interface{} = false 200 201 res, err := ElEq(a, b) 202 if err != nil { 203 t.Fatalf("Error: %v", err) 204 } 205 assert.Equal(t, correct, res.Data()) 206 207 // scalar-tensor 208 a = New(WithBacking([]float64{1, 2, 4})) 209 b = New(WithBacking([]float64{2})) 210 correct = []bool{false, true, false} 211 212 res, err = ElEq(a, b) 213 if err != nil { 214 t.Fatalf("Error: %v", err) 215 } 216 assert.Equal(t, correct, res.Data()) 217 218 // tensor-scalar 219 a = New(WithBacking([]float64{3})) 220 b = New(WithBacking([]float64{6, 3, 2})) 221 correct = []bool{false, true, false} 222 223 res, err = ElEq(a, b) 224 if err != nil { 225 t.Fatalf("Error: %v", err) 226 } 227 assert.Equal(t, correct, res.Data()) 228 229 // tensor - tensor 230 a = New(WithBacking([]float64{21, 10})) 231 b = New(WithBacking([]float64{7, 10})) 232 correct = []bool{false, true} 233 234 res, err = ElEq(a, b) 235 if err != nil { 236 t.Fatalf("Error: %v", err) 237 } 238 assert.Equal(t, correct, res.Data()) 239 }