github.com/wzzhu/tensor@v0.9.24/dense_test.go (about)

     1  package tensor
     2  
     3  import (
     4  	"math/rand"
     5  	"testing"
     6  	"testing/quick"
     7  	"time"
     8  
     9  	"github.com/stretchr/testify/assert"
    10  )
    11  
    12  func TestDense_ShallowClone(t *testing.T) {
    13  	T := New(Of(Float64), WithBacking([]float64{1, 2, 3, 4}))
    14  	T2 := T.ShallowClone()
    15  	T2.slice(0, 2)
    16  	T2.Float64s()[0] = 1000
    17  
    18  	assert.Equal(t, T.Data().([]float64)[0:2], T2.Data())
    19  	assert.Equal(t, T.Engine(), T2.Engine())
    20  	assert.Equal(t, T.oe, T2.oe)
    21  	assert.Equal(t, T.flag, T2.flag)
    22  }
    23  
    24  func TestDense_Clone(t *testing.T) {
    25  	assert := assert.New(t)
    26  	cloneChk := func(q *Dense) bool {
    27  		a := q.Clone().(*Dense)
    28  		if !q.Shape().Eq(a.Shape()) {
    29  			t.Errorf("Shape Difference: %v %v", q.Shape(), a.Shape())
    30  			return false
    31  		}
    32  		if len(q.Strides()) != len(a.Strides()) {
    33  			t.Errorf("Stride Difference: %v %v", q.Strides(), a.Strides())
    34  			return false
    35  		}
    36  		for i, s := range q.Strides() {
    37  			if a.Strides()[i] != s {
    38  				t.Errorf("Stride Difference: %v %v", q.Strides(), a.Strides())
    39  				return false
    40  			}
    41  		}
    42  		if q.o != a.o {
    43  			t.Errorf("Data Order difference : %v %v", q.o, a.o)
    44  			return false
    45  		}
    46  
    47  		if q.Δ != a.Δ {
    48  			t.Errorf("Triangle Difference: %v  %v", q.Δ, a.Δ)
    49  			return false
    50  		}
    51  		if q.flag != a.flag {
    52  			t.Errorf("Flag difference : %v %v", q.flag, a.flag)
    53  			return false
    54  		}
    55  		if q.e != a.e {
    56  			t.Errorf("Engine difference; %T %T", q.e, a.e)
    57  			return false
    58  		}
    59  		if q.oe != a.oe {
    60  			t.Errorf("Optimized Engine difference; %T %T", q.oe, a.oe)
    61  			return false
    62  		}
    63  
    64  		if len(q.transposeWith) != len(a.transposeWith) {
    65  			t.Errorf("TransposeWith difference: %v %v", q.transposeWith, a.transposeWith)
    66  			return false
    67  		}
    68  
    69  		assert.Equal(q.mask, a.mask, "mask difference")
    70  		assert.Equal(q.maskIsSoft, a.maskIsSoft, "mask is soft ")
    71  		return true
    72  	}
    73  	r := rand.New(rand.NewSource(time.Now().UnixNano()))
    74  	if err := quick.Check(cloneChk, &quick.Config{Rand: r}); err != nil {
    75  		t.Error(err)
    76  	}
    77  }
    78  
    79  func TestDenseMasked(t *testing.T) {
    80  	T := New(Of(Float64), WithShape(3, 2))
    81  	T.ResetMask()
    82  	assert.Equal(t, []bool{false, false, false, false, false, false}, T.mask)
    83  
    84  }
    85  
    86  func TestFromScalar(t *testing.T) {
    87  	T := New(FromScalar(3.14))
    88  	data := T.Float64s()
    89  	assert.Equal(t, []float64{3.14}, data)
    90  }
    91  
    92  func Test_recycledDense(t *testing.T) {
    93  	T := recycledDense(Float64, ScalarShape())
    94  	assert.Equal(t, float64(0), T.Data())
    95  	assert.Equal(t, StdEng{}, T.e)
    96  	assert.Equal(t, StdEng{}, T.oe)
    97  }
    98  
    99  func TestDense_unsqueeze(t *testing.T) {
   100  	assert := assert.New(t)
   101  	T := New(WithShape(3, 3, 2), WithBacking([]float64{
   102  		1, 2, 3, 4, 5, 6,
   103  		60, 50, 40, 30, 20, 10,
   104  		100, 200, 300, 400, 500, 600,
   105  	}))
   106  
   107  	if err := T.unsqueeze(0); err != nil {
   108  		t.Fatal(err)
   109  	}
   110  
   111  	assert.True(T.Shape().Eq(Shape{1, 3, 3, 2}))
   112  	assert.Equal([]int{6, 6, 2, 1}, T.Strides()) // if you do shapes.CalcStrides() it'd be {18,6,2,1}
   113  
   114  	// reset
   115  	T.Reshape(3, 3, 2)
   116  
   117  	if err := T.unsqueeze(1); err != nil {
   118  		t.Fatal(err)
   119  	}
   120  	assert.True(T.Shape().Eq(Shape{3, 1, 3, 2}))
   121  	assert.Equal([]int{6, 2, 2, 1}, T.Strides())
   122  
   123  	// reset
   124  	T.Reshape(3, 3, 2)
   125  	if err := T.unsqueeze(2); err != nil {
   126  		t.Fatal(err)
   127  	}
   128  	t.Logf("%v", T)
   129  	assert.True(T.Shape().Eq(Shape{3, 3, 1, 2}))
   130  	assert.Equal([]int{6, 2, 1, 1}, T.Strides())
   131  
   132  	// reset
   133  	T.Reshape(3, 3, 2)
   134  	if err := T.unsqueeze(3); err != nil {
   135  		t.Fatal(err)
   136  	}
   137  	t.Logf("%v", T)
   138  	assert.True(T.Shape().Eq(Shape{3, 3, 2, 1}))
   139  	assert.Equal([]int{6, 2, 1, 1}, T.Strides())
   140  }