github.com/qiaogw/arrgo@v0.0.8/shape.go (about)

     1  package arrgo
     2  
     3  import "fmt"
     4  
     5  //改变原始多维数组的形状,并返回改变后的多维数组的指引引用。
     6  //不会创建新的数据副本。
     7  //如果新的Shape的大小和原来多维数组的大小不同,则抛出异常。
     8  func (a *Arrf) ReShape(Shape ...int) *Arrf {
     9  	if a.Length() != ProductIntSlice(Shape) {
    10  		fmt.Println("new Shape length does not equal to original array length.")
    11  		panic(SHAPE_ERROR)
    12  	}
    13  
    14  	internalShape := make([]int, len(Shape))
    15  	copy(internalShape, Shape)
    16  	a.Shape = internalShape
    17  
    18  	a.Strides = make([]int, len(a.Shape)+1)
    19  	a.Strides[len(a.Shape)] = 1
    20  	for i := len(a.Shape) - 1; i >= 0; i-- {
    21  		a.Strides[i] = a.Strides[i+1] * a.Shape[i]
    22  	}
    23  
    24  	return a
    25  }
    26  
    27  //两个多维数组形状相同,则返回true, 否则返回false。
    28  func (a *Arrf) SameShapeTo(b *Arrf) bool {
    29  	return SameIntSlice(a.Shape, b.Shape)
    30  }
    31  
    32  //将多个两维数组在垂直方向上组合起来,形成新的多维数组。
    33  //不影响原多维数组。
    34  func Vstack(arrs ...*Arrf) *Arrf {
    35  	for i := range arrs {
    36  		if arrs[i].Ndims() > 2 {
    37  			fmt.Println("in Vstack function, array dimension cannot bigger than 2.")
    38  			panic(SHAPE_ERROR)
    39  		}
    40  	}
    41  	if len(arrs) == 0 {
    42  		return nil
    43  	}
    44  	if len(arrs) == 1 {
    45  		return arrs[0].Copy()
    46  	}
    47  
    48  	return Concat(0, arrs...)
    49  	//
    50  	//var vlenSum int = 0
    51  	//
    52  	//var hlen int
    53  	//if arrs[0].Ndims() == 1 {
    54  	//	hlen = arrs[0].Shape[0]
    55  	//	vlenSum += 1
    56  	//} else {
    57  	//	hlen = arrs[0].Shape[1]
    58  	//	vlenSum += arrs[0].Shape[0]
    59  	//}
    60  	//for i := 1; i < len(arrs); i++ {
    61  	//	var nextHen int
    62  	//	if arrs[i].Ndims() == 1 {
    63  	//		nextHen = arrs[i].Shape[0]
    64  	//		vlenSum += 1
    65  	//	} else {
    66  	//		nextHen = arrs[i].Shape[1]
    67  	//		vlenSum += arrs[i].Shape[0]
    68  	//	}
    69  	//	if hlen != nextHen {
    70  	//		panic(SHAPE_ERROR)
    71  	//	}
    72  	//}
    73  	//
    74  	//Data := make([]float64, vlenSum*hlen)
    75  	//var offset = 0
    76  	//for i := range arrs {
    77  	//	copy(Data[offset:], arrs[i].Data)
    78  	//	offset += len(arrs[i].Data)
    79  	//}
    80  	//
    81  	//return Array(Data, vlenSum, hlen)
    82  }
    83  
    84  //将多个两维数组在水平方向上组合起来,形成新的多维数组。
    85  //不影响原多维数组。
    86  func Hstack(arrs ...*Arrf) *Arrf {
    87  	for i := range arrs {
    88  		if arrs[i].Ndims() > 2 {
    89  			panic(SHAPE_ERROR)
    90  		}
    91  	}
    92  	if len(arrs) == 0 {
    93  		return nil
    94  	}
    95  	if len(arrs) == 1 {
    96  		return arrs[0].Copy()
    97  	}
    98  
    99  	return Concat(1, arrs...)
   100  
   101  	//var hlenSum int = 0
   102  	//var hBlockLens = make([]int, len(arrs))
   103  	//var vlen int
   104  	//if arrs[0].Ndims() == 1 {
   105  	//	vlen = 1
   106  	//	hlenSum += arrs[0].Shape[0]
   107  	//	hBlockLens[0] = arrs[0].Shape[0]
   108  	//} else {
   109  	//	vlen = arrs[0].Shape[0]
   110  	//	hlenSum += arrs[0].Shape[1]
   111  	//	hBlockLens[0] = arrs[0].Shape[1]
   112  	//}
   113  	//for i := 1; i < len(arrs); i++ {
   114  	//	var nextVlen int
   115  	//	if arrs[i].Ndims() == 1 {
   116  	//		nextVlen = 1
   117  	//		hlenSum += arrs[i].Shape[0]
   118  	//		hBlockLens[i] = arrs[i].Shape[0]
   119  	//	} else {
   120  	//		nextVlen = arrs[i].Shape[0]
   121  	//		hlenSum += arrs[i].Shape[1]
   122  	//		hBlockLens[i] = arrs[i].Shape[1]
   123  	//	}
   124  	//	if vlen != nextVlen {
   125  	//		panic(SHAPE_ERROR)
   126  	//	}
   127  	//}
   128  	//
   129  	//Data := make([]float64, hlenSum*vlen)
   130  	//for i := 0; i < vlen; i++ {
   131  	//	var curPos = 0
   132  	//	for j := 0; j < len(arrs); j++ {
   133  	//		copy(Data[curPos+i*hlenSum:curPos+i*hlenSum+hBlockLens[j]], arrs[j].Data[i*hBlockLens[j]:(i+1)*hBlockLens[j]])
   134  	//		curPos += hBlockLens[j]
   135  	//	}
   136  	//}
   137  	//
   138  	//return Array(Data, vlen, hlenSum)
   139  }
   140  
   141  //将多个多维数组在指定的轴上组合起来。
   142  //一维数组默认扩充为2维,参考AtLeast2D函数。
   143  func Concat(axis int, arrs ...*Arrf) *Arrf {
   144  	if len(arrs) == 0 {
   145  		return nil
   146  	}
   147  	if len(arrs) == 1 {
   148  		return arrs[0].Copy()
   149  	}
   150  
   151  	for i := range arrs {
   152  		AtLeast2D(arrs[i])
   153  	}
   154  
   155  	if axis >= arrs[0].Ndims() {
   156  		fmt.Println("axis is bigger than dimensions num.")
   157  		panic(PARAMETER_ERROR)
   158  	}
   159  
   160  	var newShape = make([]int, arrs[0].Ndims())
   161  	for index, firstL := range arrs[0].Shape {
   162  		if index == axis {
   163  			newShape[index] += firstL
   164  			for j := 1; j < len(arrs); j++ {
   165  				newShape[index] += arrs[j].Shape[index]
   166  			}
   167  		} else {
   168  			newShape[index] = firstL
   169  			for j := 1; j < len(arrs); j++ {
   170  				if firstL != arrs[j].Shape[index] {
   171  					panic(SHAPE_ERROR)
   172  				}
   173  			}
   174  		}
   175  	}
   176  
   177  	var times = 0
   178  	if axis == 0 {
   179  		times = 1
   180  	} else {
   181  		times = ProductIntSlice(arrs[0].Shape[0:axis])
   182  	}
   183  
   184  	var Data = make([]float64, ProductIntSlice(newShape))
   185  
   186  	var curPos = 0
   187  	for i := 0; i < times; i++ {
   188  		for j := 0; j < len(arrs); j++ {
   189  			var l = ProductIntSlice(arrs[j].Shape[axis:])
   190  			copy(Data[curPos:curPos+l], arrs[j].Data[i*l:(i+1)*l])
   191  			curPos += l
   192  		}
   193  	}
   194  
   195  	return Array(Data, newShape...)
   196  }
   197  
   198  //将一维数组扩充为二维
   199  func AtLeast2D(a *Arrf) *Arrf {
   200  	if a == nil {
   201  		return nil
   202  	} else if a.Ndims() >= 2 {
   203  		return a
   204  	} else {
   205  		newShpae := make([]int, 2)
   206  		newShpae[0] = 1
   207  		newShpae[1] = a.Shape[0]
   208  		a.Shape = newShpae
   209  		return a
   210  	}
   211  }
   212  
   213  //将数组内部的元素铺平返回,创建新的数据副本。
   214  func (a *Arrf) Flatten() *Arrf {
   215  	ra := make([]float64, len(a.Data))
   216  	copy(ra, a.Data)
   217  	return Array(ra, len(a.Data))
   218  }