github.com/wzzhu/tensor@v0.9.24/dense_mask_inspection_test.go (about) 1 package tensor 2 3 import ( 4 "testing" 5 6 "github.com/stretchr/testify/assert" 7 ) 8 9 func TestMaskedInspection(t *testing.T) { 10 assert := assert.New(t) 11 12 var retT *Dense 13 14 //vector case 15 T := New(Of(Bool), WithShape(1, 12)) 16 T.ResetMask(false) 17 assert.False(T.MaskedAny().(bool)) 18 for i := 0; i < 12; i += 2 { 19 T.mask[i] = true 20 } 21 assert.True(T.MaskedAny().(bool)) 22 assert.True(T.MaskedAny(0).(bool)) 23 assert.False(T.MaskedAll().(bool)) 24 assert.False(T.MaskedAll(0).(bool)) 25 assert.Equal(6, T.MaskedCount()) 26 assert.Equal(6, T.MaskedCount(0)) 27 assert.Equal(6, T.NonMaskedCount()) 28 assert.Equal(6, T.NonMaskedCount(0)) 29 30 //contiguous mask case 31 /*equivalent python code 32 --------- 33 import numpy.ma as ma 34 a = ma.arange(12).reshape((2, 3, 2)) 35 a[0,0,0]=ma.masked 36 a[0,2,0]=ma.masked 37 print(ma.getmask(a).all()) 38 print(ma.getmask(a).any()) 39 print(ma.count_masked(a)) 40 print(ma.count(a)) 41 print(ma.getmask(a).all(0)) 42 print(ma.getmask(a).any(0)) 43 print(ma.count_masked(a,0)) 44 print(ma.count(a,0)) 45 print(ma.getmask(a).all(1)) 46 print(ma.getmask(a).any(1)) 47 print(ma.count_masked(a,1)) 48 print(ma.count(a,1)) 49 print(ma.getmask(a).all(2)) 50 print(ma.getmask(a).any(2)) 51 print(ma.count_masked(a,2)) 52 print(ma.count(a,2)) 53 ----------- 54 */ 55 T = New(Of(Bool), WithShape(2, 3, 2)) 56 T.ResetMask(false) 57 58 for i := 0; i < 2; i += 2 { 59 for j := 0; j < 3; j += 2 { 60 for k := 0; k < 2; k += 2 { 61 a, b, c := T.strides[0], T.strides[1], T.strides[2] 62 T.mask[i*a+b*j+c*k] = true 63 } 64 } 65 } 66 67 assert.Equal([]bool{true, false, false, false, true, false, 68 false, false, false, false, false, false}, T.mask) 69 70 assert.Equal(false, T.MaskedAll()) 71 assert.Equal(true, T.MaskedAny()) 72 assert.Equal(2, T.MaskedCount()) 73 assert.Equal(10, T.NonMaskedCount()) 74 75 retT = T.MaskedAll(0).(*Dense) 76 assert.Equal([]int{3, 2}, []int(retT.shape)) 77 assert.Equal([]bool{false, false, false, false, false, false}, retT.Bools()) 78 retT = T.MaskedAny(0).(*Dense) 79 assert.Equal([]int{3, 2}, []int(retT.shape)) 80 assert.Equal([]bool{true, false, false, false, true, false}, retT.Bools()) 81 retT = T.MaskedCount(0).(*Dense) 82 assert.Equal([]int{3, 2}, []int(retT.shape)) 83 assert.Equal([]int{1, 0, 0, 0, 1, 0}, retT.Ints()) 84 retT = T.NonMaskedCount(0).(*Dense) 85 assert.Equal([]int{1, 2, 2, 2, 1, 2}, retT.Ints()) 86 87 retT = T.MaskedAll(1).(*Dense) 88 assert.Equal([]int{2, 2}, []int(retT.shape)) 89 assert.Equal([]bool{false, false, false, false}, retT.Bools()) 90 retT = T.MaskedAny(1).(*Dense) 91 assert.Equal([]int{2, 2}, []int(retT.shape)) 92 assert.Equal([]bool{true, false, false, false}, retT.Bools()) 93 retT = T.MaskedCount(1).(*Dense) 94 assert.Equal([]int{2, 2}, []int(retT.shape)) 95 assert.Equal([]int{2, 0, 0, 0}, retT.Ints()) 96 retT = T.NonMaskedCount(1).(*Dense) 97 assert.Equal([]int{1, 3, 3, 3}, retT.Ints()) 98 99 retT = T.MaskedAll(2).(*Dense) 100 assert.Equal([]int{2, 3}, []int(retT.shape)) 101 assert.Equal([]bool{false, false, false, false, false, false}, retT.Bools()) 102 retT = T.MaskedAny(2).(*Dense) 103 assert.Equal([]int{2, 3}, []int(retT.shape)) 104 assert.Equal([]bool{true, false, true, false, false, false}, retT.Bools()) 105 retT = T.MaskedCount(2).(*Dense) 106 assert.Equal([]int{2, 3}, []int(retT.shape)) 107 assert.Equal([]int{1, 0, 1, 0, 0, 0}, retT.Ints()) 108 retT = T.NonMaskedCount(2).(*Dense) 109 assert.Equal([]int{1, 2, 1, 2, 2, 2}, retT.Ints()) 110 111 } 112 113 func TestMaskedFindContiguous(t *testing.T) { 114 assert := assert.New(t) 115 T := NewDense(Int, []int{1, 100}) 116 T.ResetMask(false) 117 retSL := T.FlatNotMaskedContiguous() 118 assert.Equal(1, len(retSL)) 119 assert.Equal(rs{0, 100, 1}, retSL[0].(rs)) 120 121 // test ability to find unmasked regions 122 sliceList := make([]Slice, 0, 4) 123 sliceList = append(sliceList, makeRS(3, 9), makeRS(14, 27), makeRS(51, 72), makeRS(93, 100)) 124 T.ResetMask(true) 125 for i := range sliceList { 126 tt, _ := T.Slice(nil, sliceList[i]) 127 ts := tt.(*Dense) 128 ts.ResetMask(false) 129 } 130 retSL = T.FlatNotMaskedContiguous() 131 assert.Equal(sliceList, retSL) 132 133 retSL = T.ClumpUnmasked() 134 assert.Equal(sliceList, retSL) 135 136 // test ability to find masked regions 137 T.ResetMask(false) 138 for i := range sliceList { 139 tt, _ := T.Slice(nil, sliceList[i]) 140 ts := tt.(*Dense) 141 ts.ResetMask(true) 142 } 143 retSL = T.FlatMaskedContiguous() 144 assert.Equal(sliceList, retSL) 145 146 retSL = T.ClumpMasked() 147 assert.Equal(sliceList, retSL) 148 } 149 150 func TestMaskedFindEdges(t *testing.T) { 151 assert := assert.New(t) 152 T := NewDense(Int, []int{1, 100}) 153 154 sliceList := make([]Slice, 0, 4) 155 sliceList = append(sliceList, makeRS(0, 9), makeRS(14, 27), makeRS(51, 72), makeRS(93, 100)) 156 157 // test ability to find unmasked edges 158 T.ResetMask(false) 159 for i := range sliceList { 160 tt, _ := T.Slice(nil, sliceList[i]) 161 ts := tt.(*Dense) 162 ts.ResetMask(true) 163 } 164 start, end := T.FlatNotMaskedEdges() 165 assert.Equal(9, start) 166 assert.Equal(92, end) 167 168 // test ability to find masked edges 169 T.ResetMask(true) 170 for i := range sliceList { 171 tt, _ := T.Slice(nil, sliceList[i]) 172 ts := tt.(*Dense) 173 ts.ResetMask(false) 174 } 175 start, end = T.FlatMaskedEdges() 176 assert.Equal(9, start) 177 assert.Equal(92, end) 178 }