github.com/wzzhu/tensor@v0.9.24/dense_mask_inspection.go (about)

     1  package tensor
     2  
// maskedReduceFn is the signature of a reduction applied by MaskedReduce: it
// receives a tensor (the reducers in this file accept only *Dense and panic on
// any other type) and returns a single value, e.g. a bool or an int.
type maskedReduceFn func(Tensor) interface{}
     4  
     5  // MaskedReduce applies a reduction function of type maskedReduceFn to mask, and returns
     6  // either an int, or another array
     7  func MaskedReduce(t *Dense, retType Dtype, fn maskedReduceFn, axis ...int) interface{} {
     8  	if len(axis) == 0 || t.IsVector() {
     9  		return fn(t)
    10  	}
    11  	ax := axis[0]
    12  	if ax >= t.Dims() {
    13  		return -1
    14  	}
    15  	// create object to be used for slicing
    16  	slices := make([]Slice, t.Dims())
    17  
    18  	// calculate shape of tensor to be returned
    19  	slices[ax] = makeRS(0, 0)
    20  	tt, _ := t.Slice(slices...)
    21  	ts := tt.(*Dense)
    22  	retVal := NewDense(retType, ts.shape) //retVal is array to be returned
    23  
    24  	it := NewIterator(retVal.Info())
    25  
    26  	// iterate through retVal
    27  	slices[ax] = makeRS(0, t.shape[ax])
    28  	for _, err := it.Next(); err == nil; _, err = it.Next() {
    29  		coord := it.Coord()
    30  		k := 0
    31  		for d := range slices {
    32  			if d != ax {
    33  				slices[d] = makeRS(coord[k], coord[k]+1)
    34  				k++
    35  			} else {
    36  				slices[d] = nil
    37  			}
    38  		}
    39  		tt, _ = t.Slice(slices...)
    40  		ts = tt.(*Dense)
    41  		retVal.SetAt(fn(ts), coord...)
    42  
    43  	}
    44  	return retVal
    45  }
    46  
    47  // MaskedAny returns True if any mask elements evaluate to True.
    48  // If object is not masked, returns false
    49  // !!! Not the same as numpy's, which looks at data elements and not at mask
    50  // Instead, equivalent to numpy ma.getmask(t).any(axis)
    51  func (t *Dense) MaskedAny(axis ...int) interface{} {
    52  	return MaskedReduce(t, Bool, doMaskAny, axis...)
    53  }
    54  
    55  // MaskedAll returns True if all mask elements evaluate to True.
    56  // If object is not masked, returns false
    57  // !!! Not the same as numpy's, which looks at data elements and not at mask
    58  // Instead, equivalent to numpy ma.getmask(t).all(axis)
    59  func (t *Dense) MaskedAll(axis ...int) interface{} {
    60  	return MaskedReduce(t, Bool, doMaskAll, axis...)
    61  }
    62  
    63  // MaskedCount counts the masked elements of the array (optionally along the given axis)
    64  // returns -1 if axis out of bounds
    65  func (t *Dense) MaskedCount(axis ...int) interface{} {
    66  	return MaskedReduce(t, Int, doMaskCt, axis...)
    67  }
    68  
    69  // NonMaskedCount counts the non-masked elements of the array (optionally along the given axis)
    70  // returns -1 if axis out of bounds
    71  // MaskedCount counts the masked elements of the array (optionally along the given axis)
    72  // returns -1 if axis out of bounds
    73  func (t *Dense) NonMaskedCount(axis ...int) interface{} {
    74  	return MaskedReduce(t, Int, doNonMaskCt, axis...)
    75  }
    76  
    77  func doMaskAll(T Tensor) interface{} {
    78  	switch t := T.(type) {
    79  	case *Dense:
    80  		if !t.IsMasked() {
    81  			return false
    82  		}
    83  		m := t.mask
    84  		if len(t.mask) == t.Size() {
    85  			for _, v := range m {
    86  				if !v {
    87  					return false
    88  				}
    89  			}
    90  		} else {
    91  			it := IteratorFromDense(t)
    92  			i, _, _ := it.NextValid()
    93  			if i != -1 {
    94  				return false
    95  			}
    96  		}
    97  		return true
    98  
    99  	default:
   100  		panic("Incompatible type")
   101  	}
   102  }
   103  
   104  func doMaskAny(T Tensor) interface{} {
   105  	switch t := T.(type) {
   106  	case *Dense:
   107  		if !t.IsMasked() {
   108  			return false
   109  		}
   110  		m := t.mask
   111  		if len(t.mask) == t.Size() {
   112  			for _, v := range m {
   113  				if v {
   114  					return true
   115  				}
   116  			}
   117  		} else {
   118  			it := IteratorFromDense(t)
   119  			i, _, _ := it.NextInvalid()
   120  			if i != -1 {
   121  				return true
   122  			}
   123  		}
   124  		return false
   125  
   126  	default:
   127  		panic("Incompatible type")
   128  	}
   129  }
   130  
   131  func doMaskCt(T Tensor) interface{} {
   132  	switch t := T.(type) {
   133  	case *Dense:
   134  		// non masked case
   135  		if !t.IsMasked() {
   136  			return 0
   137  		}
   138  
   139  		count := 0
   140  		m := t.mask
   141  		if len(t.mask) == t.Size() {
   142  			for _, v := range m {
   143  				if v {
   144  					count++
   145  				}
   146  			}
   147  		} else {
   148  			it := IteratorFromDense(t)
   149  			for _, _, err := it.NextInvalid(); err == nil; _, _, err = it.NextInvalid() {
   150  				count++
   151  			}
   152  		}
   153  		return count
   154  	default:
   155  		panic("Incompatible type")
   156  	}
   157  }
   158  
   159  func doNonMaskCt(T Tensor) interface{} {
   160  	switch t := T.(type) {
   161  	case *Dense:
   162  		if !t.IsMasked() {
   163  			return t.Size()
   164  		}
   165  		return t.Size() - doMaskCt(t).(int)
   166  	default:
   167  		panic("Incompatible type")
   168  	}
   169  }
   170  
   171  /* -----------
   172  ************ Finding masked data
   173  ----------*/
   174  
   175  // FlatNotMaskedContiguous is used to find contiguous unmasked data in a masked array.
   176  // Applies to a flattened version of the array.
   177  // Returns:A sorted sequence of slices (start index, end index).
   178  func (t *Dense) FlatNotMaskedContiguous() []Slice {
   179  	sliceList := make([]Slice, 0, 4)
   180  
   181  	it := IteratorFromDense(t)
   182  
   183  	for start, _, err := it.NextValid(); err == nil; start, _, err = it.NextValid() {
   184  		end, _, _ := it.NextInvalid()
   185  		if end == -1 {
   186  			end = t.Size()
   187  		}
   188  		sliceList = append(sliceList, makeRS(start, end))
   189  	}
   190  
   191  	return sliceList
   192  }
   193  
   194  // FlatMaskedContiguous is used to find contiguous masked data in a masked array.
   195  // Applies to a flattened version of the array.
   196  // Returns:A sorted sequence of slices (start index, end index).
   197  func (t *Dense) FlatMaskedContiguous() []Slice {
   198  	sliceList := make([]Slice, 0, 4)
   199  
   200  	it := IteratorFromDense(t)
   201  
   202  	for start, _, err := it.NextInvalid(); err == nil; start, _, err = it.NextInvalid() {
   203  		end, _, _ := it.NextValid()
   204  		if end == -1 {
   205  			end = t.Size()
   206  		}
   207  		sliceList = append(sliceList, makeRS(start, end))
   208  	}
   209  	return sliceList
   210  }
   211  
   212  // FlatNotMaskedEdges is used to find the indices of the first and last unmasked values
   213  // Applies to a flattened version of the array.
   214  // Returns: A pair of ints. -1 if all values are masked.
   215  func (t *Dense) FlatNotMaskedEdges() (int, int) {
   216  	if !t.IsMasked() {
   217  		return 0, t.Size() - 1
   218  	}
   219  
   220  	var start, end int
   221  	it := IteratorFromDense(t)
   222  
   223  	it.SetForward()
   224  	start, _, err := it.NextValid()
   225  	if err != nil {
   226  		return -1, -1
   227  	}
   228  
   229  	it.SetReverse()
   230  	end, _, _ = it.NextValid()
   231  
   232  	return start, end
   233  }
   234  
   235  // FlatMaskedEdges is used to find the indices of the first and last masked values
   236  // Applies to a flattened version of the array.
   237  // Returns: A pair of ints. -1 if all values are unmasked.
   238  func (t *Dense) FlatMaskedEdges() (int, int) {
   239  	if !t.IsMasked() {
   240  		return 0, t.Size() - 1
   241  	}
   242  	var start, end int
   243  	it := IteratorFromDense(t)
   244  
   245  	it.SetForward()
   246  	start, _, err := it.NextInvalid()
   247  	if err != nil {
   248  		return -1, -1
   249  	}
   250  
   251  	it.SetReverse()
   252  	end, _, _ = it.NextInvalid()
   253  
   254  	return start, end
   255  }
   256  
   257  // ClumpMasked returns a list of slices corresponding to the masked clumps of a 1-D array
   258  // Added to match numpy function names
   259  func (t *Dense) ClumpMasked() []Slice {
   260  	return t.FlatMaskedContiguous()
   261  }
   262  
   263  // ClumpUnmasked returns a list of slices corresponding to the unmasked clumps of a 1-D array
   264  // Added to match numpy function names
   265  func (t *Dense) ClumpUnmasked() []Slice {
   266  	return t.FlatNotMaskedContiguous()
   267  }