github.com/wzzhu/tensor@v0.9.24/dense_test.go (about) 1 package tensor 2 3 import ( 4 "math/rand" 5 "testing" 6 "testing/quick" 7 "time" 8 9 "github.com/stretchr/testify/assert" 10 ) 11 12 func TestDense_ShallowClone(t *testing.T) { 13 T := New(Of(Float64), WithBacking([]float64{1, 2, 3, 4})) 14 T2 := T.ShallowClone() 15 T2.slice(0, 2) 16 T2.Float64s()[0] = 1000 17 18 assert.Equal(t, T.Data().([]float64)[0:2], T2.Data()) 19 assert.Equal(t, T.Engine(), T2.Engine()) 20 assert.Equal(t, T.oe, T2.oe) 21 assert.Equal(t, T.flag, T2.flag) 22 } 23 24 func TestDense_Clone(t *testing.T) { 25 assert := assert.New(t) 26 cloneChk := func(q *Dense) bool { 27 a := q.Clone().(*Dense) 28 if !q.Shape().Eq(a.Shape()) { 29 t.Errorf("Shape Difference: %v %v", q.Shape(), a.Shape()) 30 return false 31 } 32 if len(q.Strides()) != len(a.Strides()) { 33 t.Errorf("Stride Difference: %v %v", q.Strides(), a.Strides()) 34 return false 35 } 36 for i, s := range q.Strides() { 37 if a.Strides()[i] != s { 38 t.Errorf("Stride Difference: %v %v", q.Strides(), a.Strides()) 39 return false 40 } 41 } 42 if q.o != a.o { 43 t.Errorf("Data Order difference : %v %v", q.o, a.o) 44 return false 45 } 46 47 if q.Δ != a.Δ { 48 t.Errorf("Triangle Difference: %v %v", q.Δ, a.Δ) 49 return false 50 } 51 if q.flag != a.flag { 52 t.Errorf("Flag difference : %v %v", q.flag, a.flag) 53 return false 54 } 55 if q.e != a.e { 56 t.Errorf("Engine difference; %T %T", q.e, a.e) 57 return false 58 } 59 if q.oe != a.oe { 60 t.Errorf("Optimized Engine difference; %T %T", q.oe, a.oe) 61 return false 62 } 63 64 if len(q.transposeWith) != len(a.transposeWith) { 65 t.Errorf("TransposeWith difference: %v %v", q.transposeWith, a.transposeWith) 66 return false 67 } 68 69 assert.Equal(q.mask, a.mask, "mask difference") 70 assert.Equal(q.maskIsSoft, a.maskIsSoft, "mask is soft ") 71 return true 72 } 73 r := rand.New(rand.NewSource(time.Now().UnixNano())) 74 if err := quick.Check(cloneChk, &quick.Config{Rand: r}); err != nil { 75 t.Error(err) 76 } 77 } 78 79 func TestDenseMasked(t *testing.T) { 80 T := New(Of(Float64), WithShape(3, 2)) 81 T.ResetMask() 82 assert.Equal(t, []bool{false, false, false, false, false, false}, T.mask) 83 84 } 85 86 func TestFromScalar(t *testing.T) { 87 T := New(FromScalar(3.14)) 88 data := T.Float64s() 89 assert.Equal(t, []float64{3.14}, data) 90 } 91 92 func Test_recycledDense(t *testing.T) { 93 T := recycledDense(Float64, ScalarShape()) 94 assert.Equal(t, float64(0), T.Data()) 95 assert.Equal(t, StdEng{}, T.e) 96 assert.Equal(t, StdEng{}, T.oe) 97 } 98 99 func TestDense_unsqueeze(t *testing.T) { 100 assert := assert.New(t) 101 T := New(WithShape(3, 3, 2), WithBacking([]float64{ 102 1, 2, 3, 4, 5, 6, 103 60, 50, 40, 30, 20, 10, 104 100, 200, 300, 400, 500, 600, 105 })) 106 107 if err := T.unsqueeze(0); err != nil { 108 t.Fatal(err) 109 } 110 111 assert.True(T.Shape().Eq(Shape{1, 3, 3, 2})) 112 assert.Equal([]int{6, 6, 2, 1}, T.Strides()) // if you do shapes.CalcStrides() it'd be {18,6,2,1} 113 114 // reset 115 T.Reshape(3, 3, 2) 116 117 if err := T.unsqueeze(1); err != nil { 118 t.Fatal(err) 119 } 120 assert.True(T.Shape().Eq(Shape{3, 1, 3, 2})) 121 assert.Equal([]int{6, 2, 2, 1}, T.Strides()) 122 123 // reset 124 T.Reshape(3, 3, 2) 125 if err := T.unsqueeze(2); err != nil { 126 t.Fatal(err) 127 } 128 t.Logf("%v", T) 129 assert.True(T.Shape().Eq(Shape{3, 3, 1, 2})) 130 assert.Equal([]int{6, 2, 1, 1}, T.Strides()) 131 132 // reset 133 T.Reshape(3, 3, 2) 134 if err := T.unsqueeze(3); err != nil { 135 t.Fatal(err) 136 } 137 t.Logf("%v", T) 138 assert.True(T.Shape().Eq(Shape{3, 3, 2, 1})) 139 assert.Equal([]int{6, 2, 1, 1}, T.Strides()) 140 }