github.com/qiaogw/arrgo@v0.0.8/numeric_arrb.go (about)

     1  package arrgo
     2  
     3  import (
     4  	"fmt"
     5  	"strings"
     6  )
     7  
     8  type Arrb struct {
     9  	Shape   []int
    10  	Strides []int
    11  	Data    []bool
    12  }
    13  
    14  //通过[]bool,形状来创建多维数组。
    15  //输入参数1:Data []bool,以·C· 顺序存储,作为多维数组的输入数据,内部复制一份新的internalData,不改变data。
    16  //输入参数2:Shape ...int,指定多维数组的形状,多维,类似numpy中的Shape。
    17  //	如果某一个(仅支持一个维度)维度为负数,则根据len(Data)推断该维度的大小。
    18  //情况1:如果不指定Shape,而且data为nil,则创建一个空的*Arrb。
    19  //情况2:如果不指定Shape,而且data不为nil,则创建一个len(Data)大小的一维*Arrb。
    20  //情况3:如果指定Shape,而且data不为nil,则根据data大小创建多维数组,如果len(Data)不等于Shape,或者len(Data)不能整除Shape,抛出异常。
    21  //情况4:如果指定Shape,而且data为nil,则创建Shape大小的全为false的多维数组。
    22  func ArrayB(Data []bool, Shape ...int) *Arrb {
    23  	if len(Shape) == 0 && Data == nil {
    24  		return &Arrb{
    25  			Shape:   []int{0},
    26  			Strides: []int{0, 1},
    27  			Data:    []bool{},
    28  		}
    29  	}
    30  
    31  	if len(Shape) == 0 && Data != nil {
    32  		internalData := make([]bool, len(Data)) //复制data,不影响输入的值。
    33  		copy(internalData, Data)
    34  		return &Arrb{
    35  			Shape:   []int{len(Data)},
    36  			Strides: []int{len(Data), 1},
    37  			Data:    internalData,
    38  		}
    39  	}
    40  
    41  	if Data == nil {
    42  		for _, v := range Shape {
    43  			if v <= 0 {
    44  				fmt.Println("Shape should be positive when Data is nill")
    45  				panic(SHAPE_ERROR)
    46  			}
    47  		}
    48  		length := ProductIntSlice(Shape)
    49  		internalShape := make([]int, len(Shape))
    50  		copy(internalShape, Shape)
    51  		Strides := make([]int, len(Shape)+1)
    52  		Strides[len(Shape)] = 1
    53  		for i := len(Shape) - 1; i >= 0; i-- {
    54  			Strides[i] = Strides[i+1] * internalShape[i]
    55  		}
    56  
    57  		return &Arrb{
    58  			Shape:   internalShape,
    59  			Strides: Strides,
    60  			Data:    make([]bool, length),
    61  		}
    62  	}
    63  
    64  	var dataLength = len(Data)
    65  	negativeIndex := -1
    66  	internalShape := make([]int, len(Shape))
    67  	copy(internalShape, Shape)
    68  	for k, v := range Shape {
    69  		if v < 0 {
    70  			if negativeIndex < 0 {
    71  				negativeIndex = k
    72  				internalShape[k] = 1
    73  			} else {
    74  				fmt.Println("Shape can only have one negative demention.")
    75  				panic(SHAPE_ERROR)
    76  			}
    77  		}
    78  	}
    79  	ShapeLength := ProductIntSlice(internalShape)
    80  
    81  	if dataLength < ShapeLength {
    82  		fmt.Println("Data length is shorter than Shape length.")
    83  		panic(SHAPE_ERROR)
    84  	}
    85  	if (dataLength % ShapeLength) != 0 {
    86  		fmt.Println("Data length cannot divided by Shape length")
    87  		panic(SHAPE_ERROR)
    88  	}
    89  
    90  	if negativeIndex >= 0 {
    91  		internalShape[negativeIndex] = dataLength / ShapeLength
    92  	}
    93  
    94  	Strides := make([]int, len(internalShape)+1)
    95  	Strides[len(internalShape)] = 1
    96  	for i := len(internalShape) - 1; i >= 0; i-- {
    97  		Strides[i] = Strides[i+1] * internalShape[i]
    98  	}
    99  
   100  	internalData := make([]bool, len(Data))
   101  	copy(internalData, Data)
   102  
   103  	return &Arrb{
   104  		Shape:   internalShape,
   105  		Strides: Strides,
   106  		Data:    internalData,
   107  	}
   108  }
   109  
   110  //创建Shape形状的多维布尔数组,全部填充为fillvalue。
   111  //必须指定Shape,否则抛出异常。
   112  func FillB(fullValue bool, Shape ...int) *Arrb {
   113  	if len(Shape) == 0 {
   114  		fmt.Println("Shape is empty!")
   115  		panic(SHAPE_ERROR)
   116  	}
   117  	arr := ArrayB(nil, Shape...)
   118  	for i := range arr.Data {
   119  		arr.Data[i] = fullValue
   120  	}
   121  
   122  	return arr
   123  }
   124  
   125  //创建全为false,形状位Shape的多维布尔数组
   126  func EmptyB(Shape ...int) (a *Arrb) {
   127  	a = FillB(false, Shape...)
   128  	return
   129  }
   130  
   131  func (a *Arrb) String() (s string) {
   132  	switch {
   133  	case a == nil:
   134  		return "<nil>"
   135  	case a.Shape == nil || a.Strides == nil || a.Data == nil:
   136  		return "<nil>"
   137  	case a.Strides[0] == 0:
   138  		return "[]"
   139  	}
   140  
   141  	stride := a.Strides[len(a.Strides)-2]
   142  	for i, k := 0, 0; i+stride <= len(a.Data); i, k = i+stride, k+1 {
   143  
   144  		t := ""
   145  		for j, v := range a.Strides {
   146  			if i%v == 0 && j < len(a.Strides)-2 {
   147  				t += "["
   148  			}
   149  		}
   150  
   151  		s += strings.Repeat(" ", len(a.Shape)-len(t)-1) + t
   152  		s += fmt.Sprint(a.Data[i : i+stride])
   153  
   154  		t = ""
   155  		for j, v := range a.Strides {
   156  			if (i+stride)%v == 0 && j < len(a.Strides)-2 {
   157  				t += "]"
   158  			}
   159  		}
   160  
   161  		s += t + strings.Repeat(" ", len(a.Shape)-len(t)-1)
   162  		if i+stride != len(a.Data) {
   163  			s += "\n"
   164  			if len(t) > 0 {
   165  				s += "\n"
   166  			}
   167  		}
   168  	}
   169  	return
   170  }
   171  
   172  //如果多维布尔数组元素都为真,返回true,否则返回false。
   173  func (ab *Arrb) AllTrues() bool {
   174  	if len(ab.Data) == 0 {
   175  		return false
   176  	}
   177  	for _, v := range ab.Data {
   178  		if v == false {
   179  			return false
   180  		}
   181  	}
   182  	return true
   183  }
   184  
   185  //如果多维布尔数组元素都为假,返回false,否则返回true。
   186  func (ab *Arrb) AnyTrue() bool {
   187  	if len(ab.Data) == 0 {
   188  		return false
   189  	}
   190  	for _, v := range ab.Data {
   191  		if v == true {
   192  			return true
   193  		}
   194  	}
   195  	return false
   196  }
   197  
   198  //返回多维数组中真值的个数。
   199  func (a *Arrb) Sum() int {
   200  	sum := 0
   201  	for _, v := range a.Data {
   202  		if v {
   203  			sum++
   204  		}
   205  	}
   206  	return sum
   207  }