github.com/qiaogw/arrgo@v0.0.8/shape.go (about) 1 package arrgo 2 3 import "fmt" 4 5 //改变原始多维数组的形状,并返回改变后的多维数组的指引引用。 6 //不会创建新的数据副本。 7 //如果新的Shape的大小和原来多维数组的大小不同,则抛出异常。 8 func (a *Arrf) ReShape(Shape ...int) *Arrf { 9 if a.Length() != ProductIntSlice(Shape) { 10 fmt.Println("new Shape length does not equal to original array length.") 11 panic(SHAPE_ERROR) 12 } 13 14 internalShape := make([]int, len(Shape)) 15 copy(internalShape, Shape) 16 a.Shape = internalShape 17 18 a.Strides = make([]int, len(a.Shape)+1) 19 a.Strides[len(a.Shape)] = 1 20 for i := len(a.Shape) - 1; i >= 0; i-- { 21 a.Strides[i] = a.Strides[i+1] * a.Shape[i] 22 } 23 24 return a 25 } 26 27 //两个多维数组形状相同,则返回true, 否则返回false。 28 func (a *Arrf) SameShapeTo(b *Arrf) bool { 29 return SameIntSlice(a.Shape, b.Shape) 30 } 31 32 //将多个两维数组在垂直方向上组合起来,形成新的多维数组。 33 //不影响原多维数组。 34 func Vstack(arrs ...*Arrf) *Arrf { 35 for i := range arrs { 36 if arrs[i].Ndims() > 2 { 37 fmt.Println("in Vstack function, array dimension cannot bigger than 2.") 38 panic(SHAPE_ERROR) 39 } 40 } 41 if len(arrs) == 0 { 42 return nil 43 } 44 if len(arrs) == 1 { 45 return arrs[0].Copy() 46 } 47 48 return Concat(0, arrs...) 49 // 50 //var vlenSum int = 0 51 // 52 //var hlen int 53 //if arrs[0].Ndims() == 1 { 54 // hlen = arrs[0].Shape[0] 55 // vlenSum += 1 56 //} else { 57 // hlen = arrs[0].Shape[1] 58 // vlenSum += arrs[0].Shape[0] 59 //} 60 //for i := 1; i < len(arrs); i++ { 61 // var nextHen int 62 // if arrs[i].Ndims() == 1 { 63 // nextHen = arrs[i].Shape[0] 64 // vlenSum += 1 65 // } else { 66 // nextHen = arrs[i].Shape[1] 67 // vlenSum += arrs[i].Shape[0] 68 // } 69 // if hlen != nextHen { 70 // panic(SHAPE_ERROR) 71 // } 72 //} 73 // 74 //Data := make([]float64, vlenSum*hlen) 75 //var offset = 0 76 //for i := range arrs { 77 // copy(Data[offset:], arrs[i].Data) 78 // offset += len(arrs[i].Data) 79 //} 80 // 81 //return Array(Data, vlenSum, hlen) 82 } 83 84 //将多个两维数组在水平方向上组合起来,形成新的多维数组。 85 //不影响原多维数组。 86 func Hstack(arrs ...*Arrf) *Arrf { 87 for i := range arrs { 88 if arrs[i].Ndims() > 2 { 89 panic(SHAPE_ERROR) 90 } 91 } 92 if len(arrs) == 0 { 93 return nil 94 } 95 if len(arrs) == 1 { 96 return arrs[0].Copy() 97 } 98 99 return Concat(1, arrs...) 100 101 //var hlenSum int = 0 102 //var hBlockLens = make([]int, len(arrs)) 103 //var vlen int 104 //if arrs[0].Ndims() == 1 { 105 // vlen = 1 106 // hlenSum += arrs[0].Shape[0] 107 // hBlockLens[0] = arrs[0].Shape[0] 108 //} else { 109 // vlen = arrs[0].Shape[0] 110 // hlenSum += arrs[0].Shape[1] 111 // hBlockLens[0] = arrs[0].Shape[1] 112 //} 113 //for i := 1; i < len(arrs); i++ { 114 // var nextVlen int 115 // if arrs[i].Ndims() == 1 { 116 // nextVlen = 1 117 // hlenSum += arrs[i].Shape[0] 118 // hBlockLens[i] = arrs[i].Shape[0] 119 // } else { 120 // nextVlen = arrs[i].Shape[0] 121 // hlenSum += arrs[i].Shape[1] 122 // hBlockLens[i] = arrs[i].Shape[1] 123 // } 124 // if vlen != nextVlen { 125 // panic(SHAPE_ERROR) 126 // } 127 //} 128 // 129 //Data := make([]float64, hlenSum*vlen) 130 //for i := 0; i < vlen; i++ { 131 // var curPos = 0 132 // for j := 0; j < len(arrs); j++ { 133 // copy(Data[curPos+i*hlenSum:curPos+i*hlenSum+hBlockLens[j]], arrs[j].Data[i*hBlockLens[j]:(i+1)*hBlockLens[j]]) 134 // curPos += hBlockLens[j] 135 // } 136 //} 137 // 138 //return Array(Data, vlen, hlenSum) 139 } 140 141 //将多个多维数组在指定的轴上组合起来。 142 //一维数组默认扩充为2维,参考AtLeast2D函数。 143 func Concat(axis int, arrs ...*Arrf) *Arrf { 144 if len(arrs) == 0 { 145 return nil 146 } 147 if len(arrs) == 1 { 148 return arrs[0].Copy() 149 } 150 151 for i := range arrs { 152 AtLeast2D(arrs[i]) 153 } 154 155 if axis >= arrs[0].Ndims() { 156 fmt.Println("axis is bigger than dimensions num.") 157 panic(PARAMETER_ERROR) 158 } 159 160 var newShape = make([]int, arrs[0].Ndims()) 161 for index, firstL := range arrs[0].Shape { 162 if index == axis { 163 newShape[index] += firstL 164 for j := 1; j < len(arrs); j++ { 165 newShape[index] += arrs[j].Shape[index] 166 } 167 } else { 168 newShape[index] = firstL 169 for j := 1; j < len(arrs); j++ { 170 if firstL != arrs[j].Shape[index] { 171 panic(SHAPE_ERROR) 172 } 173 } 174 } 175 } 176 177 var times = 0 178 if axis == 0 { 179 times = 1 180 } else { 181 times = ProductIntSlice(arrs[0].Shape[0:axis]) 182 } 183 184 var Data = make([]float64, ProductIntSlice(newShape)) 185 186 var curPos = 0 187 for i := 0; i < times; i++ { 188 for j := 0; j < len(arrs); j++ { 189 var l = ProductIntSlice(arrs[j].Shape[axis:]) 190 copy(Data[curPos:curPos+l], arrs[j].Data[i*l:(i+1)*l]) 191 curPos += l 192 } 193 } 194 195 return Array(Data, newShape...) 196 } 197 198 //将一维数组扩充为二维 199 func AtLeast2D(a *Arrf) *Arrf { 200 if a == nil { 201 return nil 202 } else if a.Ndims() >= 2 { 203 return a 204 } else { 205 newShpae := make([]int, 2) 206 newShpae[0] = 1 207 newShpae[1] = a.Shape[0] 208 a.Shape = newShpae 209 return a 210 } 211 } 212 213 //将数组内部的元素铺平返回,创建新的数据副本。 214 func (a *Arrf) Flatten() *Arrf { 215 ra := make([]float64, len(a.Data)) 216 copy(ra, a.Data) 217 return Array(ra, len(a.Data)) 218 }