github.com/wzzhu/tensor@v0.9.24/dense_mask_inspection_test.go (about)

     1  package tensor
     2  
     3  import (
     4  	"testing"
     5  
     6  	"github.com/stretchr/testify/assert"
     7  )
     8  
     9  func TestMaskedInspection(t *testing.T) {
    10  	assert := assert.New(t)
    11  
    12  	var retT *Dense
    13  
    14  	//vector case
    15  	T := New(Of(Bool), WithShape(1, 12))
    16  	T.ResetMask(false)
    17  	assert.False(T.MaskedAny().(bool))
    18  	for i := 0; i < 12; i += 2 {
    19  		T.mask[i] = true
    20  	}
    21  	assert.True(T.MaskedAny().(bool))
    22  	assert.True(T.MaskedAny(0).(bool))
    23  	assert.False(T.MaskedAll().(bool))
    24  	assert.False(T.MaskedAll(0).(bool))
    25  	assert.Equal(6, T.MaskedCount())
    26  	assert.Equal(6, T.MaskedCount(0))
    27  	assert.Equal(6, T.NonMaskedCount())
    28  	assert.Equal(6, T.NonMaskedCount(0))
    29  
    30  	//contiguous mask case
    31  	/*equivalent python code
    32  	  ---------
    33  	  import numpy.ma as ma
    34  	  a = ma.arange(12).reshape((2, 3, 2))
    35  	  a[0,0,0]=ma.masked
    36  	  a[0,2,0]=ma.masked
    37  	  print(ma.getmask(a).all())
    38  	  print(ma.getmask(a).any())
    39  	  print(ma.count_masked(a))
    40  	  print(ma.count(a))
    41  	  print(ma.getmask(a).all(0))
    42  	  print(ma.getmask(a).any(0))
    43  	  print(ma.count_masked(a,0))
    44  	  print(ma.count(a,0))
    45  	  print(ma.getmask(a).all(1))
    46  	  print(ma.getmask(a).any(1))
    47  	  print(ma.count_masked(a,1))
    48  	  print(ma.count(a,1))
    49  	  print(ma.getmask(a).all(2))
    50  	  print(ma.getmask(a).any(2))
    51  	  print(ma.count_masked(a,2))
    52  	  print(ma.count(a,2))
    53  	  -----------
    54  	*/
    55  	T = New(Of(Bool), WithShape(2, 3, 2))
    56  	T.ResetMask(false)
    57  
    58  	for i := 0; i < 2; i += 2 {
    59  		for j := 0; j < 3; j += 2 {
    60  			for k := 0; k < 2; k += 2 {
    61  				a, b, c := T.strides[0], T.strides[1], T.strides[2]
    62  				T.mask[i*a+b*j+c*k] = true
    63  			}
    64  		}
    65  	}
    66  
    67  	assert.Equal([]bool{true, false, false, false, true, false,
    68  		false, false, false, false, false, false}, T.mask)
    69  
    70  	assert.Equal(false, T.MaskedAll())
    71  	assert.Equal(true, T.MaskedAny())
    72  	assert.Equal(2, T.MaskedCount())
    73  	assert.Equal(10, T.NonMaskedCount())
    74  
    75  	retT = T.MaskedAll(0).(*Dense)
    76  	assert.Equal([]int{3, 2}, []int(retT.shape))
    77  	assert.Equal([]bool{false, false, false, false, false, false}, retT.Bools())
    78  	retT = T.MaskedAny(0).(*Dense)
    79  	assert.Equal([]int{3, 2}, []int(retT.shape))
    80  	assert.Equal([]bool{true, false, false, false, true, false}, retT.Bools())
    81  	retT = T.MaskedCount(0).(*Dense)
    82  	assert.Equal([]int{3, 2}, []int(retT.shape))
    83  	assert.Equal([]int{1, 0, 0, 0, 1, 0}, retT.Ints())
    84  	retT = T.NonMaskedCount(0).(*Dense)
    85  	assert.Equal([]int{1, 2, 2, 2, 1, 2}, retT.Ints())
    86  
    87  	retT = T.MaskedAll(1).(*Dense)
    88  	assert.Equal([]int{2, 2}, []int(retT.shape))
    89  	assert.Equal([]bool{false, false, false, false}, retT.Bools())
    90  	retT = T.MaskedAny(1).(*Dense)
    91  	assert.Equal([]int{2, 2}, []int(retT.shape))
    92  	assert.Equal([]bool{true, false, false, false}, retT.Bools())
    93  	retT = T.MaskedCount(1).(*Dense)
    94  	assert.Equal([]int{2, 2}, []int(retT.shape))
    95  	assert.Equal([]int{2, 0, 0, 0}, retT.Ints())
    96  	retT = T.NonMaskedCount(1).(*Dense)
    97  	assert.Equal([]int{1, 3, 3, 3}, retT.Ints())
    98  
    99  	retT = T.MaskedAll(2).(*Dense)
   100  	assert.Equal([]int{2, 3}, []int(retT.shape))
   101  	assert.Equal([]bool{false, false, false, false, false, false}, retT.Bools())
   102  	retT = T.MaskedAny(2).(*Dense)
   103  	assert.Equal([]int{2, 3}, []int(retT.shape))
   104  	assert.Equal([]bool{true, false, true, false, false, false}, retT.Bools())
   105  	retT = T.MaskedCount(2).(*Dense)
   106  	assert.Equal([]int{2, 3}, []int(retT.shape))
   107  	assert.Equal([]int{1, 0, 1, 0, 0, 0}, retT.Ints())
   108  	retT = T.NonMaskedCount(2).(*Dense)
   109  	assert.Equal([]int{1, 2, 1, 2, 2, 2}, retT.Ints())
   110  
   111  }
   112  
   113  func TestMaskedFindContiguous(t *testing.T) {
   114  	assert := assert.New(t)
   115  	T := NewDense(Int, []int{1, 100})
   116  	T.ResetMask(false)
   117  	retSL := T.FlatNotMaskedContiguous()
   118  	assert.Equal(1, len(retSL))
   119  	assert.Equal(rs{0, 100, 1}, retSL[0].(rs))
   120  
   121  	// test ability to find unmasked regions
   122  	sliceList := make([]Slice, 0, 4)
   123  	sliceList = append(sliceList, makeRS(3, 9), makeRS(14, 27), makeRS(51, 72), makeRS(93, 100))
   124  	T.ResetMask(true)
   125  	for i := range sliceList {
   126  		tt, _ := T.Slice(nil, sliceList[i])
   127  		ts := tt.(*Dense)
   128  		ts.ResetMask(false)
   129  	}
   130  	retSL = T.FlatNotMaskedContiguous()
   131  	assert.Equal(sliceList, retSL)
   132  
   133  	retSL = T.ClumpUnmasked()
   134  	assert.Equal(sliceList, retSL)
   135  
   136  	// test ability to find masked regions
   137  	T.ResetMask(false)
   138  	for i := range sliceList {
   139  		tt, _ := T.Slice(nil, sliceList[i])
   140  		ts := tt.(*Dense)
   141  		ts.ResetMask(true)
   142  	}
   143  	retSL = T.FlatMaskedContiguous()
   144  	assert.Equal(sliceList, retSL)
   145  
   146  	retSL = T.ClumpMasked()
   147  	assert.Equal(sliceList, retSL)
   148  }
   149  
   150  func TestMaskedFindEdges(t *testing.T) {
   151  	assert := assert.New(t)
   152  	T := NewDense(Int, []int{1, 100})
   153  
   154  	sliceList := make([]Slice, 0, 4)
   155  	sliceList = append(sliceList, makeRS(0, 9), makeRS(14, 27), makeRS(51, 72), makeRS(93, 100))
   156  
   157  	// test ability to find unmasked edges
   158  	T.ResetMask(false)
   159  	for i := range sliceList {
   160  		tt, _ := T.Slice(nil, sliceList[i])
   161  		ts := tt.(*Dense)
   162  		ts.ResetMask(true)
   163  	}
   164  	start, end := T.FlatNotMaskedEdges()
   165  	assert.Equal(9, start)
   166  	assert.Equal(92, end)
   167  
   168  	// test ability to find masked edges
   169  	T.ResetMask(true)
   170  	for i := range sliceList {
   171  		tt, _ := T.Slice(nil, sliceList[i])
   172  		ts := tt.(*Dense)
   173  		ts.ResetMask(false)
   174  	}
   175  	start, end = T.FlatMaskedEdges()
   176  	assert.Equal(9, start)
   177  	assert.Equal(92, end)
   178  }