github.com/qiaogw/arrgo@v0.0.8/numeric_arrf.go (about) 1 package arrgo 2 3 import ( 4 "fmt" 5 "math" 6 "strings" 7 ) 8 9 type Arrf struct { 10 Shape []int 11 Strides []int 12 Data []float64 13 } 14 15 //通过[]float64,形状来创建多维数组。 16 //输入参数1:data []float64,以·C· 顺序存储,作为多维数组的输入数据,内部复制一份新的internalData,不改变data。 17 //输入参数2:Shape ...int,指定多维数组的形状,多维,类似numpy中的Shape。 18 // 如果某一个(仅支持一个维度)维度为负数,则根据len(Data)推断该维度的大小。 19 //情况1:如果不指定Shape,而且data为nil,则创建一个空的*Arrf。 20 //情况2:如果不指定Shape,而且data不为nil,则创建一个len(Data)大小的一维*Arrf。 21 //情况3:如果指定Shape,而且data不为nil,则根据data大小创建多维数组,如果len(Data)不等于Shape,或者len(Data)不能整除Shape,抛出异常。 22 //情况4:如果指定Shape,而且data为nil,则创建Shape大小的全为0.0的多维数组。 23 func Array(Data []float64, Shape ...int) *Arrf { 24 if len(Shape) == 0 && Data == nil { 25 return &Arrf{ 26 Shape: []int{0}, 27 Strides: []int{0, 1}, 28 Data: []float64{}, 29 } 30 } 31 32 if len(Shape) == 0 && Data != nil { 33 internalData := make([]float64, len(Data)) //复制data,不影响输入的值。 34 copy(internalData, Data) 35 return &Arrf{ 36 Shape: []int{len(Data)}, 37 Strides: []int{len(Data), 1}, 38 Data: internalData, 39 } 40 } 41 42 if Data == nil { 43 for _, v := range Shape { 44 if v <= 0 { 45 fmt.Println("Shape should be positive when Data is nill") 46 panic(SHAPE_ERROR) 47 } 48 } 49 length := ProductIntSlice(Shape) 50 internalShape := make([]int, len(Shape)) 51 copy(internalShape, Shape) 52 Strides := make([]int, len(Shape)+1) 53 Strides[len(Shape)] = 1 54 for i := len(Shape) - 1; i >= 0; i-- { 55 Strides[i] = Strides[i+1] * internalShape[i] 56 } 57 58 return &Arrf{ 59 Shape: internalShape, 60 Strides: Strides, 61 Data: make([]float64, length), 62 } 63 } 64 65 var dataLength = len(Data) 66 negativeIndex := -1 67 internalShape := make([]int, len(Shape)) 68 copy(internalShape, Shape) 69 for k, v := range Shape { 70 if v < 0 { 71 if negativeIndex < 0 { 72 negativeIndex = k 73 internalShape[k] = 1 74 } else { 75 fmt.Println("Shape can only have one negative demention.") 76 panic(SHAPE_ERROR) 77 } 78 } 79 } 80 ShapeLength := ProductIntSlice(internalShape) 81 82 if dataLength < ShapeLength { 83 fmt.Println("Data length is shorter than Shape length.") 84 panic(SHAPE_ERROR) 85 } 86 if (dataLength % ShapeLength) != 0 { 87 fmt.Println("Data length cannot divided by Shape length") 88 panic(SHAPE_ERROR) 89 } 90 91 if negativeIndex >= 0 { 92 internalShape[negativeIndex] = dataLength / ShapeLength 93 } 94 95 Strides := make([]int, len(internalShape)+1) 96 Strides[len(internalShape)] = 1 97 for i := len(internalShape) - 1; i >= 0; i-- { 98 Strides[i] = Strides[i+1] * internalShape[i] 99 } 100 101 internalData := make([]float64, len(Data)) 102 copy(internalData, Data) 103 104 return &Arrf{ 105 Shape: internalShape, 106 Strides: Strides, 107 Data: internalData, 108 } 109 } 110 111 // 通过指定起始、终止和步进量来创建一维Array。 112 // 输入参数: vals,可以有三种情况,详见下面描述。 113 // 情况1:Arange(stop): 以0开始的序列,创建Array [0, 0+(-)1, ..., stop),不包括stop,stop符号决定升降序。 114 // 情况2:Arange(start, stop):创建Array [start, start +(-)1, ..., stop),如果start小于start则递增,否则递减。 115 // 情况3:Arange(start, stop, step):创建Array [start, start + step, ..., stop),step符号决定升降序。 116 // 输入参数多于三个的都会被忽略。 117 // 输出序列为“整型数”序列。 118 func Arange(vals ...int) *Arrf { 119 var start, stop, step int = 0, 0, 1 120 121 switch len(vals) { 122 case 0: 123 fmt.Println("range function should have range") 124 panic(PARAMETER_ERROR) 125 case 1: 126 if vals[0] <= 0 { 127 step = -1 128 stop = vals[0] + 1 129 } else { 130 stop = vals[0] - 1 131 } 132 case 2: 133 if vals[1] < vals[0] { 134 step = -1 135 stop = vals[1] + 1 136 } else { 137 stop = vals[1] - 1 138 } 139 start = vals[0] 140 default: 141 if vals[1] < vals[0] { 142 if vals[2] >= 0 { 143 fmt.Println("increment should be negative.") 144 panic(PARAMETER_ERROR) 145 } 146 stop = vals[1] + 1 147 } else { 148 if vals[2] <= 0 { 149 fmt.Println("increment should be positive.") 150 panic(PARAMETER_ERROR) 151 } 152 stop = vals[1] - 1 153 } 154 start, step = vals[0], vals[2] 155 } 156 157 a := Array(nil, int(math.Abs(float64((stop-start)/step)))+1) 158 for i, v := 0, start; i < len(a.Data); i, v = i+1, v+step { 159 a.Data[i] = float64(v) 160 } 161 return a 162 } 163 164 //判断Arrf是否为空数组。 165 //如果内部的data长度为0或者为nil,返回true,否则位false。 166 func (a *Arrf) IsEmpty() bool { 167 return len(a.Data) == 0 || a.Data == nil 168 } 169 170 //创建Shape形状的多维数组,全部填充为fillvalue。 171 //必须指定Shape,否则抛出异常。 172 func Fill(fillValue float64, Shape ...int) *Arrf { 173 if len(Shape) == 0 { 174 fmt.Println("Shape is empty!") 175 panic(SHAPE_ERROR) 176 } 177 arr := Array(nil, Shape...) 178 for i := range arr.Data { 179 arr.Data[i] = fillValue 180 } 181 182 return arr 183 } 184 185 //根据Shape创建全为1.0的多维数组。 186 func Ones(Shape ...int) *Arrf { 187 return Fill(1, Shape...) 188 } 189 190 //根据输入的多维数组的形状创建全1的多维数组。 191 func OnesLike(a *Arrf) *Arrf { 192 return Ones(a.Shape...) 193 } 194 195 //根据Shape创建全为0的多维数组。 196 func Zeros(Shape ...int) *Arrf { 197 return Fill(0, Shape...) 198 } 199 200 //根据输入的多维数组的形状创建全0的多维数组。 201 func ZerosLike(a *Arrf) *Arrf { 202 return Zeros(a.Shape...) 203 } 204 205 // String Satisfies the Stringer interface for fmt package 206 func (a *Arrf) String() (s string) { 207 switch { 208 case a == nil: 209 return "<nil>" 210 case a.Data == nil || a.Shape == nil || a.Strides == nil: 211 return "<nil>" 212 case a.Strides[0] == 0: 213 return "[]" 214 case len(a.Shape) == 1: 215 return fmt.Sprint(a.Data) 216 //strs := make([]string, len(a.Data)) 217 //for i := range a.Data { 218 // strs[i] = string(strconv.FormatFloat(a.Data[i], 'f', -1, 64)) 219 // 220 //} 221 //return strings.Join(strs, ", ") 222 } 223 224 stride := a.Shape[len(a.Shape)-1] 225 226 for i, k := 0, 0; i+stride <= len(a.Data); i, k = i+stride, k+1 { 227 228 t := "" 229 for j, v := range a.Strides { 230 if i%v == 0 && j < len(a.Strides)-2 { 231 t += "[" 232 } 233 } 234 235 s += strings.Repeat(" ", len(a.Shape)-len(t)-1) + t 236 s += fmt.Sprint(a.Data[i : i+stride]) 237 238 t = "" 239 for j, v := range a.Strides { 240 if (i+stride)%v == 0 && j < len(a.Strides)-2 { 241 t += "]" 242 } 243 } 244 245 s += t + strings.Repeat(" ", len(a.Shape)-len(t)-1) 246 if i+stride != len(a.Data) { 247 s += "\n" 248 if len(t) > 0 { 249 s += "\n" 250 } 251 } 252 } 253 return 254 } 255 256 //获取index指定位置的元素。 257 //index必须在Shape规定的范围内,否则会抛出异常。 258 //index的长度必须小于等于维度的个数,否则会抛出异常。 259 //如果index的个数小于维度个数,则会取后面的第一个值。 260 func (a *Arrf) At(index ...int) float64 { 261 idx := a.valIndex(index...) 262 return a.Data[idx] 263 } 264 265 //详见At函数。 266 func (a *Arrf) Get(index ...int) float64 { 267 return a.At(index...) 268 } 269 270 //At函数的内部实现,返回index指定的元素在切片中的位置,如果有错误,则返回error。 271 func (a *Arrf) valIndex(index ...int) int { 272 idx := 0 273 if len(index) > len(a.Shape) { 274 fmt.Println("index len should not longer than Shape.") 275 panic(INDEX_ERROR) 276 } 277 for i, v := range index { 278 if v >= a.Shape[i] || v < 0 { 279 fmt.Println("index value out of range.") 280 panic(INDEX_ERROR) 281 } 282 idx += v * a.Strides[i+1] 283 } 284 return idx 285 } 286 287 //获取多维数组元素的个数。 288 func (a *Arrf) Length() int { 289 return len(a.Data) 290 } 291 292 //创建一个n X n 的2维单位矩阵(数组)。 293 func Eye(n int) *Arrf { 294 arr := Zeros(n, n) 295 for i := 0; i < n; i++ { 296 arr.Set(1, i, i) 297 } 298 return arr 299 } 300 301 //Eye的另一种称呼,详见Eye函数。 302 func Identity(n int) *Arrf { 303 return Eye(n) 304 } 305 306 //指定位置的元素被新值替换。 307 //如果index的超出范围则会抛出异常。 308 //返回当前数组的指引,方便后续的连续操作。 309 func (a *Arrf) Set(value float64, index ...int) *Arrf { 310 idx := a.valIndex(index...) 311 312 a.Data[idx] = value 313 return a 314 } 315 316 //返回多维数组的内部数组元素。 317 //对返回值的操作会影响多维数组,一定谨慎操作。 318 func (a *Arrf) Values() []float64 { 319 return a.Data 320 } 321 322 //根据[start, stop]指定的区间,创建包含num个元素的一维数组。 323 func Linspace(start, stop float64, num int) *Arrf { 324 var Data = make([]float64, num) 325 var startF, stopF = start, stop 326 if startF <= stopF { 327 var step = (stopF - startF) / (float64(num - 1.0)) 328 for i := range Data { 329 Data[i] = startF + float64(i)*step 330 } 331 return Array(Data, num) 332 } else { 333 var step = (startF - stopF) / (float64(num - 1.0)) 334 for i := range Data { 335 Data[i] = startF - float64(i)*step 336 } 337 return Array(Data, num) 338 } 339 } 340 341 //复制一个形状一样,但是数据被深度复制的多维数组。 342 func (a *Arrf) Copy() *Arrf { 343 b := ZerosLike(a) 344 copy(b.Data, a.Data) 345 return b 346 } 347 348 //返回多维数组的维度数目。 349 func (a *Arrf) Ndims() int { 350 return len(a.Shape) 351 } 352 353 //Returns ta view of the array with axes transposed. 354 //根据指定的轴顺序,生成一个新的调整后的多维数组。 355 //如果是1维数组,则没有任何变化。 356 //如果是2维数组,则行列交换。 357 //如果是n维数组,则根据指定的顺序调整,生成新的多维数组。 358 //输入参数1:如果不指定输入参数,则轴顺序全部反序;如果指定参数则个数必须和轴个数相同,否则抛出异常。 359 //fixme 这里的实现效率不高,后面有时间需要提升一下。 360 func (a *Arrf) Transpose(axes ...int) *Arrf { 361 var n = a.Ndims() 362 var permutation []int 363 var nShape []int 364 365 switch len(axes) { 366 case 0: 367 permutation = make([]int, n) 368 nShape = make([]int, n) 369 for i := range permutation { 370 permutation[i] = n - i 371 } 372 for i := 0; i < n; i++ { 373 permutation[i] = n - 1 - i 374 nShape[i] = a.Shape[permutation[i]] 375 } 376 377 case n: 378 permutation = axes 379 nShape = make([]int, n) 380 for i := range nShape { 381 nShape[i] = a.Shape[permutation[i]] 382 } 383 384 default: 385 fmt.Println("axis number wrong.") 386 panic(DIMENTION_ERROR) 387 } 388 389 var totalIndexSize = 1 390 for i := range a.Shape { 391 totalIndexSize *= a.Shape[i] 392 } 393 394 var indexsSrc = make([][]int, totalIndexSize) 395 var indexsDst = make([][]int, totalIndexSize) 396 397 var b = Zeros(nShape...) 398 var index = make([]int, n) 399 for i := 0; i < totalIndexSize; i++ { 400 tindexSrc := make([]int, n) 401 copy(tindexSrc, index) 402 indexsSrc[i] = tindexSrc 403 var tindexDst = make([]int, n) 404 for j := range tindexDst { 405 tindexDst[j] = index[permutation[j]] 406 } 407 indexsDst[i] = tindexDst 408 409 var j = n - 1 410 index[j]++ 411 for { 412 if j > 0 && index[j] >= a.Shape[j] { 413 index[j-1]++ 414 index[j] = 0 415 j-- 416 } else { 417 break 418 } 419 } 420 } 421 for i := range indexsSrc { 422 b.Set(a.Get(indexsSrc[i]...), indexsDst[i]...) 423 } 424 return b 425 }