github.com/wzzhu/tensor@v0.9.24/sparse_test.go (about) 1 package tensor 2 3 import ( 4 "testing" 5 6 "github.com/stretchr/testify/assert" 7 ) 8 9 func TestCS_Basics(t *testing.T) { 10 assert := assert.New(t) 11 xs0 := []int{1, 2, 6, 8} 12 ys0 := []int{1, 2, 1, 6} 13 xs1 := []int{1, 2, 6, 8} 14 ys1 := []int{1, 2, 1, 6} 15 vals0 := []float64{3, 1, 4, 1} 16 vals1 := []float64{3, 1, 4, 1} 17 18 var T0, T1 *CS 19 var d0, d1 *Dense 20 var dp0, dp1 *Dense 21 var err error 22 fails := func() { 23 CSCFromCoord(Shape{7, 6}, xs0, ys0, vals0) 24 } 25 assert.Panics(fails) 26 27 // Test CSC 28 T0 = CSCFromCoord(Shape{9, 7}, xs0, ys0, vals0) 29 d0 = T0.Dense() 30 T0.T() 31 dp0 = T0.Dense() 32 T0.UT() // untranspose as Materialize() will be called below 33 34 // Test CSR 35 fails = func() { 36 CSRFromCoord(Shape{7, 6}, xs1, ys1, vals1) 37 } 38 T1 = CSRFromCoord(Shape{9, 7}, xs1, ys1, vals1) 39 d1 = T1.Dense() 40 T1.T() 41 dp1 = T1.Dense() 42 T1.UT() 43 44 t.Logf("%v %v", T0.indptr, T0.indices) 45 t.Logf("%v %v", T1.indptr, T1.indices) 46 47 assert.True(d0.Eq(d1), "%+#v\n %+#v\n", d0, d1) 48 assert.True(dp0.Eq(dp1)) 49 assert.True(T1.Eq(T1)) 50 assert.False(T0.Eq(T1)) 51 52 // At 53 var got interface{} 54 correct := float64(3.0) 55 if got, err = T0.At(1, 1); err != nil { 56 t.Error(err) 57 } 58 if got.(float64) != correct { 59 t.Errorf("Expected %v. Got %v - T0[1,1]", correct, got) 60 } 61 if got, err = T1.At(1, 1); err != nil { 62 t.Error(err) 63 } 64 if got.(float64) != correct { 65 t.Errorf("Expected %v. Got %v - T1[1,1]", correct, got) 66 } 67 68 correct = 0.0 69 if got, err = T0.At(3, 3); err != nil { 70 t.Error(err) 71 } 72 if got.(float64) != correct { 73 t.Errorf("Expected %v. Got %v - T0[3,3]", correct, got) 74 } 75 76 if got, err = T1.At(3, 3); err != nil { 77 t.Error(err) 78 } 79 if got.(float64) != correct { 80 t.Errorf("Expected %v. Got %v - T1[3,3]", correct, got) 81 } 82 83 // Test clone 84 T2 := T0.Clone() 85 assert.True(T0.Eq(T2)) 86 87 // Scalar representation 88 assert.False(T0.IsScalar()) 89 fails = func() { 90 T0.ScalarValue() 91 } 92 assert.Panics(fails) 93 assert.Equal(len(vals0), T0.NonZeroes()) 94 95 // Sparse Iterator 96 it := T0.Iterator() 97 var valids []int 98 correctValids := []int{0, 2, 1, 3} 99 for i, valid, err := it.NextValidity(); err == nil; i, valid, err = it.NextValidity() { 100 if valid { 101 valids = append(valids, i) 102 } 103 } 104 assert.Equal(correctValids, valids) 105 }