github.com/wzzhu/tensor@v0.9.24/sparse_test.go (about)

     1  package tensor
     2  
     3  import (
     4  	"testing"
     5  
     6  	"github.com/stretchr/testify/assert"
     7  )
     8  
     9  func TestCS_Basics(t *testing.T) {
    10  	assert := assert.New(t)
    11  	xs0 := []int{1, 2, 6, 8}
    12  	ys0 := []int{1, 2, 1, 6}
    13  	xs1 := []int{1, 2, 6, 8}
    14  	ys1 := []int{1, 2, 1, 6}
    15  	vals0 := []float64{3, 1, 4, 1}
    16  	vals1 := []float64{3, 1, 4, 1}
    17  
    18  	var T0, T1 *CS
    19  	var d0, d1 *Dense
    20  	var dp0, dp1 *Dense
    21  	var err error
    22  	fails := func() {
    23  		CSCFromCoord(Shape{7, 6}, xs0, ys0, vals0)
    24  	}
    25  	assert.Panics(fails)
    26  
    27  	// Test CSC
    28  	T0 = CSCFromCoord(Shape{9, 7}, xs0, ys0, vals0)
    29  	d0 = T0.Dense()
    30  	T0.T()
    31  	dp0 = T0.Dense()
    32  	T0.UT() // untranspose as Materialize() will be called below
    33  
    34  	// Test CSR
    35  	fails = func() {
    36  		CSRFromCoord(Shape{7, 6}, xs1, ys1, vals1)
    37  	}
    38  	T1 = CSRFromCoord(Shape{9, 7}, xs1, ys1, vals1)
    39  	d1 = T1.Dense()
    40  	T1.T()
    41  	dp1 = T1.Dense()
    42  	T1.UT()
    43  
    44  	t.Logf("%v %v", T0.indptr, T0.indices)
    45  	t.Logf("%v %v", T1.indptr, T1.indices)
    46  
    47  	assert.True(d0.Eq(d1), "%+#v\n %+#v\n", d0, d1)
    48  	assert.True(dp0.Eq(dp1))
    49  	assert.True(T1.Eq(T1))
    50  	assert.False(T0.Eq(T1))
    51  
    52  	// At
    53  	var got interface{}
    54  	correct := float64(3.0)
    55  	if got, err = T0.At(1, 1); err != nil {
    56  		t.Error(err)
    57  	}
    58  	if got.(float64) != correct {
    59  		t.Errorf("Expected %v. Got %v - T0[1,1]", correct, got)
    60  	}
    61  	if got, err = T1.At(1, 1); err != nil {
    62  		t.Error(err)
    63  	}
    64  	if got.(float64) != correct {
    65  		t.Errorf("Expected %v. Got %v - T1[1,1]", correct, got)
    66  	}
    67  
    68  	correct = 0.0
    69  	if got, err = T0.At(3, 3); err != nil {
    70  		t.Error(err)
    71  	}
    72  	if got.(float64) != correct {
    73  		t.Errorf("Expected %v. Got %v - T0[3,3]", correct, got)
    74  	}
    75  
    76  	if got, err = T1.At(3, 3); err != nil {
    77  		t.Error(err)
    78  	}
    79  	if got.(float64) != correct {
    80  		t.Errorf("Expected %v. Got %v - T1[3,3]", correct, got)
    81  	}
    82  
    83  	// Test clone
    84  	T2 := T0.Clone()
    85  	assert.True(T0.Eq(T2))
    86  
    87  	// Scalar representation
    88  	assert.False(T0.IsScalar())
    89  	fails = func() {
    90  		T0.ScalarValue()
    91  	}
    92  	assert.Panics(fails)
    93  	assert.Equal(len(vals0), T0.NonZeroes())
    94  
    95  	// Sparse Iterator
    96  	it := T0.Iterator()
    97  	var valids []int
    98  	correctValids := []int{0, 2, 1, 3}
    99  	for i, valid, err := it.NextValidity(); err == nil; i, valid, err = it.NextValidity() {
   100  		if valid {
   101  			valids = append(valids, i)
   102  		}
   103  	}
   104  	assert.Equal(correctValids, valids)
   105  }