github.com/qiaogw/arrgo@v0.0.8/numeric_arrb.go (about) 1 package arrgo 2 3 import ( 4 "fmt" 5 "strings" 6 ) 7 8 type Arrb struct { 9 Shape []int 10 Strides []int 11 Data []bool 12 } 13 14 //通过[]bool,形状来创建多维数组。 15 //输入参数1:Data []bool,以·C· 顺序存储,作为多维数组的输入数据,内部复制一份新的internalData,不改变data。 16 //输入参数2:Shape ...int,指定多维数组的形状,多维,类似numpy中的Shape。 17 // 如果某一个(仅支持一个维度)维度为负数,则根据len(Data)推断该维度的大小。 18 //情况1:如果不指定Shape,而且data为nil,则创建一个空的*Arrb。 19 //情况2:如果不指定Shape,而且data不为nil,则创建一个len(Data)大小的一维*Arrb。 20 //情况3:如果指定Shape,而且data不为nil,则根据data大小创建多维数组,如果len(Data)不等于Shape,或者len(Data)不能整除Shape,抛出异常。 21 //情况4:如果指定Shape,而且data为nil,则创建Shape大小的全为false的多维数组。 22 func ArrayB(Data []bool, Shape ...int) *Arrb { 23 if len(Shape) == 0 && Data == nil { 24 return &Arrb{ 25 Shape: []int{0}, 26 Strides: []int{0, 1}, 27 Data: []bool{}, 28 } 29 } 30 31 if len(Shape) == 0 && Data != nil { 32 internalData := make([]bool, len(Data)) //复制data,不影响输入的值。 33 copy(internalData, Data) 34 return &Arrb{ 35 Shape: []int{len(Data)}, 36 Strides: []int{len(Data), 1}, 37 Data: internalData, 38 } 39 } 40 41 if Data == nil { 42 for _, v := range Shape { 43 if v <= 0 { 44 fmt.Println("Shape should be positive when Data is nill") 45 panic(SHAPE_ERROR) 46 } 47 } 48 length := ProductIntSlice(Shape) 49 internalShape := make([]int, len(Shape)) 50 copy(internalShape, Shape) 51 Strides := make([]int, len(Shape)+1) 52 Strides[len(Shape)] = 1 53 for i := len(Shape) - 1; i >= 0; i-- { 54 Strides[i] = Strides[i+1] * internalShape[i] 55 } 56 57 return &Arrb{ 58 Shape: internalShape, 59 Strides: Strides, 60 Data: make([]bool, length), 61 } 62 } 63 64 var dataLength = len(Data) 65 negativeIndex := -1 66 internalShape := make([]int, len(Shape)) 67 copy(internalShape, Shape) 68 for k, v := range Shape { 69 if v < 0 { 70 if negativeIndex < 0 { 71 negativeIndex = k 72 internalShape[k] = 1 73 } else { 74 fmt.Println("Shape can only have one negative demention.") 75 panic(SHAPE_ERROR) 76 } 77 } 78 } 79 ShapeLength := ProductIntSlice(internalShape) 80 81 if dataLength < ShapeLength { 82 fmt.Println("Data length is shorter than Shape length.") 83 panic(SHAPE_ERROR) 84 } 85 if (dataLength % ShapeLength) != 0 { 86 fmt.Println("Data length cannot divided by Shape length") 87 panic(SHAPE_ERROR) 88 } 89 90 if negativeIndex >= 0 { 91 internalShape[negativeIndex] = dataLength / ShapeLength 92 } 93 94 Strides := make([]int, len(internalShape)+1) 95 Strides[len(internalShape)] = 1 96 for i := len(internalShape) - 1; i >= 0; i-- { 97 Strides[i] = Strides[i+1] * internalShape[i] 98 } 99 100 internalData := make([]bool, len(Data)) 101 copy(internalData, Data) 102 103 return &Arrb{ 104 Shape: internalShape, 105 Strides: Strides, 106 Data: internalData, 107 } 108 } 109 110 //创建Shape形状的多维布尔数组,全部填充为fillvalue。 111 //必须指定Shape,否则抛出异常。 112 func FillB(fullValue bool, Shape ...int) *Arrb { 113 if len(Shape) == 0 { 114 fmt.Println("Shape is empty!") 115 panic(SHAPE_ERROR) 116 } 117 arr := ArrayB(nil, Shape...) 118 for i := range arr.Data { 119 arr.Data[i] = fullValue 120 } 121 122 return arr 123 } 124 125 //创建全为false,形状位Shape的多维布尔数组 126 func EmptyB(Shape ...int) (a *Arrb) { 127 a = FillB(false, Shape...) 128 return 129 } 130 131 func (a *Arrb) String() (s string) { 132 switch { 133 case a == nil: 134 return "<nil>" 135 case a.Shape == nil || a.Strides == nil || a.Data == nil: 136 return "<nil>" 137 case a.Strides[0] == 0: 138 return "[]" 139 } 140 141 stride := a.Strides[len(a.Strides)-2] 142 for i, k := 0, 0; i+stride <= len(a.Data); i, k = i+stride, k+1 { 143 144 t := "" 145 for j, v := range a.Strides { 146 if i%v == 0 && j < len(a.Strides)-2 { 147 t += "[" 148 } 149 } 150 151 s += strings.Repeat(" ", len(a.Shape)-len(t)-1) + t 152 s += fmt.Sprint(a.Data[i : i+stride]) 153 154 t = "" 155 for j, v := range a.Strides { 156 if (i+stride)%v == 0 && j < len(a.Strides)-2 { 157 t += "]" 158 } 159 } 160 161 s += t + strings.Repeat(" ", len(a.Shape)-len(t)-1) 162 if i+stride != len(a.Data) { 163 s += "\n" 164 if len(t) > 0 { 165 s += "\n" 166 } 167 } 168 } 169 return 170 } 171 172 //如果多维布尔数组元素都为真,返回true,否则返回false。 173 func (ab *Arrb) AllTrues() bool { 174 if len(ab.Data) == 0 { 175 return false 176 } 177 for _, v := range ab.Data { 178 if v == false { 179 return false 180 } 181 } 182 return true 183 } 184 185 //如果多维布尔数组元素都为假,返回false,否则返回true。 186 func (ab *Arrb) AnyTrue() bool { 187 if len(ab.Data) == 0 { 188 return false 189 } 190 for _, v := range ab.Data { 191 if v == true { 192 return true 193 } 194 } 195 return false 196 } 197 198 //返回多维数组中真值的个数。 199 func (a *Arrb) Sum() int { 200 sum := 0 201 for _, v := range a.Data { 202 if v { 203 sum++ 204 } 205 } 206 return sum 207 }