github.com/qiaogw/arrgo@v0.0.8/numeric_arrf.go (about)

     1  package arrgo
     2  
     3  import (
     4  	"fmt"
     5  	"math"
     6  	"strings"
     7  )
     8  
     9  type Arrf struct {
    10  	Shape   []int
    11  	Strides []int
    12  	Data    []float64
    13  }
    14  
    15  //通过[]float64,形状来创建多维数组。
    16  //输入参数1:data []float64,以·C· 顺序存储,作为多维数组的输入数据,内部复制一份新的internalData,不改变data。
    17  //输入参数2:Shape ...int,指定多维数组的形状,多维,类似numpy中的Shape。
    18  //	如果某一个(仅支持一个维度)维度为负数,则根据len(Data)推断该维度的大小。
    19  //情况1:如果不指定Shape,而且data为nil,则创建一个空的*Arrf。
    20  //情况2:如果不指定Shape,而且data不为nil,则创建一个len(Data)大小的一维*Arrf。
    21  //情况3:如果指定Shape,而且data不为nil,则根据data大小创建多维数组,如果len(Data)不等于Shape,或者len(Data)不能整除Shape,抛出异常。
    22  //情况4:如果指定Shape,而且data为nil,则创建Shape大小的全为0.0的多维数组。
    23  func Array(Data []float64, Shape ...int) *Arrf {
    24  	if len(Shape) == 0 && Data == nil {
    25  		return &Arrf{
    26  			Shape:   []int{0},
    27  			Strides: []int{0, 1},
    28  			Data:    []float64{},
    29  		}
    30  	}
    31  
    32  	if len(Shape) == 0 && Data != nil {
    33  		internalData := make([]float64, len(Data)) //复制data,不影响输入的值。
    34  		copy(internalData, Data)
    35  		return &Arrf{
    36  			Shape:   []int{len(Data)},
    37  			Strides: []int{len(Data), 1},
    38  			Data:    internalData,
    39  		}
    40  	}
    41  
    42  	if Data == nil {
    43  		for _, v := range Shape {
    44  			if v <= 0 {
    45  				fmt.Println("Shape should be positive when Data is nill")
    46  				panic(SHAPE_ERROR)
    47  			}
    48  		}
    49  		length := ProductIntSlice(Shape)
    50  		internalShape := make([]int, len(Shape))
    51  		copy(internalShape, Shape)
    52  		Strides := make([]int, len(Shape)+1)
    53  		Strides[len(Shape)] = 1
    54  		for i := len(Shape) - 1; i >= 0; i-- {
    55  			Strides[i] = Strides[i+1] * internalShape[i]
    56  		}
    57  
    58  		return &Arrf{
    59  			Shape:   internalShape,
    60  			Strides: Strides,
    61  			Data:    make([]float64, length),
    62  		}
    63  	}
    64  
    65  	var dataLength = len(Data)
    66  	negativeIndex := -1
    67  	internalShape := make([]int, len(Shape))
    68  	copy(internalShape, Shape)
    69  	for k, v := range Shape {
    70  		if v < 0 {
    71  			if negativeIndex < 0 {
    72  				negativeIndex = k
    73  				internalShape[k] = 1
    74  			} else {
    75  				fmt.Println("Shape can only have one negative demention.")
    76  				panic(SHAPE_ERROR)
    77  			}
    78  		}
    79  	}
    80  	ShapeLength := ProductIntSlice(internalShape)
    81  
    82  	if dataLength < ShapeLength {
    83  		fmt.Println("Data length is shorter than Shape length.")
    84  		panic(SHAPE_ERROR)
    85  	}
    86  	if (dataLength % ShapeLength) != 0 {
    87  		fmt.Println("Data length cannot divided by Shape length")
    88  		panic(SHAPE_ERROR)
    89  	}
    90  
    91  	if negativeIndex >= 0 {
    92  		internalShape[negativeIndex] = dataLength / ShapeLength
    93  	}
    94  
    95  	Strides := make([]int, len(internalShape)+1)
    96  	Strides[len(internalShape)] = 1
    97  	for i := len(internalShape) - 1; i >= 0; i-- {
    98  		Strides[i] = Strides[i+1] * internalShape[i]
    99  	}
   100  
   101  	internalData := make([]float64, len(Data))
   102  	copy(internalData, Data)
   103  
   104  	return &Arrf{
   105  		Shape:   internalShape,
   106  		Strides: Strides,
   107  		Data:    internalData,
   108  	}
   109  }
   110  
   111  // 通过指定起始、终止和步进量来创建一维Array。
   112  // 输入参数: vals,可以有三种情况,详见下面描述。
   113  // 情况1:Arange(stop): 以0开始的序列,创建Array [0, 0+(-)1, ..., stop),不包括stop,stop符号决定升降序。
   114  // 情况2:Arange(start, stop):创建Array [start, start +(-)1, ..., stop),如果start小于start则递增,否则递减。
   115  // 情况3:Arange(start, stop, step):创建Array [start, start + step, ..., stop),step符号决定升降序。
   116  // 输入参数多于三个的都会被忽略。
   117  // 输出序列为“整型数”序列。
   118  func Arange(vals ...int) *Arrf {
   119  	var start, stop, step int = 0, 0, 1
   120  
   121  	switch len(vals) {
   122  	case 0:
   123  		fmt.Println("range function should have range")
   124  		panic(PARAMETER_ERROR)
   125  	case 1:
   126  		if vals[0] <= 0 {
   127  			step = -1
   128  			stop = vals[0] + 1
   129  		} else {
   130  			stop = vals[0] - 1
   131  		}
   132  	case 2:
   133  		if vals[1] < vals[0] {
   134  			step = -1
   135  			stop = vals[1] + 1
   136  		} else {
   137  			stop = vals[1] - 1
   138  		}
   139  		start = vals[0]
   140  	default:
   141  		if vals[1] < vals[0] {
   142  			if vals[2] >= 0 {
   143  				fmt.Println("increment should be negative.")
   144  				panic(PARAMETER_ERROR)
   145  			}
   146  			stop = vals[1] + 1
   147  		} else {
   148  			if vals[2] <= 0 {
   149  				fmt.Println("increment should be positive.")
   150  				panic(PARAMETER_ERROR)
   151  			}
   152  			stop = vals[1] - 1
   153  		}
   154  		start, step = vals[0], vals[2]
   155  	}
   156  
   157  	a := Array(nil, int(math.Abs(float64((stop-start)/step)))+1)
   158  	for i, v := 0, start; i < len(a.Data); i, v = i+1, v+step {
   159  		a.Data[i] = float64(v)
   160  	}
   161  	return a
   162  }
   163  
   164  //判断Arrf是否为空数组。
   165  //如果内部的data长度为0或者为nil,返回true,否则位false。
   166  func (a *Arrf) IsEmpty() bool {
   167  	return len(a.Data) == 0 || a.Data == nil
   168  }
   169  
   170  //创建Shape形状的多维数组,全部填充为fillvalue。
   171  //必须指定Shape,否则抛出异常。
   172  func Fill(fillValue float64, Shape ...int) *Arrf {
   173  	if len(Shape) == 0 {
   174  		fmt.Println("Shape is empty!")
   175  		panic(SHAPE_ERROR)
   176  	}
   177  	arr := Array(nil, Shape...)
   178  	for i := range arr.Data {
   179  		arr.Data[i] = fillValue
   180  	}
   181  
   182  	return arr
   183  }
   184  
   185  //根据Shape创建全为1.0的多维数组。
   186  func Ones(Shape ...int) *Arrf {
   187  	return Fill(1, Shape...)
   188  }
   189  
   190  //根据输入的多维数组的形状创建全1的多维数组。
   191  func OnesLike(a *Arrf) *Arrf {
   192  	return Ones(a.Shape...)
   193  }
   194  
   195  //根据Shape创建全为0的多维数组。
   196  func Zeros(Shape ...int) *Arrf {
   197  	return Fill(0, Shape...)
   198  }
   199  
   200  //根据输入的多维数组的形状创建全0的多维数组。
   201  func ZerosLike(a *Arrf) *Arrf {
   202  	return Zeros(a.Shape...)
   203  }
   204  
   205  // String Satisfies the Stringer interface for fmt package
   206  func (a *Arrf) String() (s string) {
   207  	switch {
   208  	case a == nil:
   209  		return "<nil>"
   210  	case a.Data == nil || a.Shape == nil || a.Strides == nil:
   211  		return "<nil>"
   212  	case a.Strides[0] == 0:
   213  		return "[]"
   214  	case len(a.Shape) == 1:
   215  		return fmt.Sprint(a.Data)
   216  		//strs := make([]string, len(a.Data))
   217  		//for i := range a.Data {
   218  		//	strs[i] = string(strconv.FormatFloat(a.Data[i], 'f', -1, 64))
   219  		//
   220  		//}
   221  		//return strings.Join(strs, ", ")
   222  	}
   223  
   224  	stride := a.Shape[len(a.Shape)-1]
   225  
   226  	for i, k := 0, 0; i+stride <= len(a.Data); i, k = i+stride, k+1 {
   227  
   228  		t := ""
   229  		for j, v := range a.Strides {
   230  			if i%v == 0 && j < len(a.Strides)-2 {
   231  				t += "["
   232  			}
   233  		}
   234  
   235  		s += strings.Repeat(" ", len(a.Shape)-len(t)-1) + t
   236  		s += fmt.Sprint(a.Data[i : i+stride])
   237  
   238  		t = ""
   239  		for j, v := range a.Strides {
   240  			if (i+stride)%v == 0 && j < len(a.Strides)-2 {
   241  				t += "]"
   242  			}
   243  		}
   244  
   245  		s += t + strings.Repeat(" ", len(a.Shape)-len(t)-1)
   246  		if i+stride != len(a.Data) {
   247  			s += "\n"
   248  			if len(t) > 0 {
   249  				s += "\n"
   250  			}
   251  		}
   252  	}
   253  	return
   254  }
   255  
   256  //获取index指定位置的元素。
   257  //index必须在Shape规定的范围内,否则会抛出异常。
   258  //index的长度必须小于等于维度的个数,否则会抛出异常。
   259  //如果index的个数小于维度个数,则会取后面的第一个值。
   260  func (a *Arrf) At(index ...int) float64 {
   261  	idx := a.valIndex(index...)
   262  	return a.Data[idx]
   263  }
   264  
   265  //详见At函数。
   266  func (a *Arrf) Get(index ...int) float64 {
   267  	return a.At(index...)
   268  }
   269  
   270  //At函数的内部实现,返回index指定的元素在切片中的位置,如果有错误,则返回error。
   271  func (a *Arrf) valIndex(index ...int) int {
   272  	idx := 0
   273  	if len(index) > len(a.Shape) {
   274  		fmt.Println("index len should not longer than Shape.")
   275  		panic(INDEX_ERROR)
   276  	}
   277  	for i, v := range index {
   278  		if v >= a.Shape[i] || v < 0 {
   279  			fmt.Println("index value out of range.")
   280  			panic(INDEX_ERROR)
   281  		}
   282  		idx += v * a.Strides[i+1]
   283  	}
   284  	return idx
   285  }
   286  
   287  //获取多维数组元素的个数。
   288  func (a *Arrf) Length() int {
   289  	return len(a.Data)
   290  }
   291  
   292  //创建一个n X n 的2维单位矩阵(数组)。
   293  func Eye(n int) *Arrf {
   294  	arr := Zeros(n, n)
   295  	for i := 0; i < n; i++ {
   296  		arr.Set(1, i, i)
   297  	}
   298  	return arr
   299  }
   300  
   301  //Eye的另一种称呼,详见Eye函数。
   302  func Identity(n int) *Arrf {
   303  	return Eye(n)
   304  }
   305  
   306  //指定位置的元素被新值替换。
   307  //如果index的超出范围则会抛出异常。
   308  //返回当前数组的指引,方便后续的连续操作。
   309  func (a *Arrf) Set(value float64, index ...int) *Arrf {
   310  	idx := a.valIndex(index...)
   311  
   312  	a.Data[idx] = value
   313  	return a
   314  }
   315  
   316  //返回多维数组的内部数组元素。
   317  //对返回值的操作会影响多维数组,一定谨慎操作。
   318  func (a *Arrf) Values() []float64 {
   319  	return a.Data
   320  }
   321  
   322  //根据[start, stop]指定的区间,创建包含num个元素的一维数组。
   323  func Linspace(start, stop float64, num int) *Arrf {
   324  	var Data = make([]float64, num)
   325  	var startF, stopF = start, stop
   326  	if startF <= stopF {
   327  		var step = (stopF - startF) / (float64(num - 1.0))
   328  		for i := range Data {
   329  			Data[i] = startF + float64(i)*step
   330  		}
   331  		return Array(Data, num)
   332  	} else {
   333  		var step = (startF - stopF) / (float64(num - 1.0))
   334  		for i := range Data {
   335  			Data[i] = startF - float64(i)*step
   336  		}
   337  		return Array(Data, num)
   338  	}
   339  }
   340  
   341  //复制一个形状一样,但是数据被深度复制的多维数组。
   342  func (a *Arrf) Copy() *Arrf {
   343  	b := ZerosLike(a)
   344  	copy(b.Data, a.Data)
   345  	return b
   346  }
   347  
   348  //返回多维数组的维度数目。
   349  func (a *Arrf) Ndims() int {
   350  	return len(a.Shape)
   351  }
   352  
   353  //Returns ta view of the array with axes transposed.
   354  //根据指定的轴顺序,生成一个新的调整后的多维数组。
   355  //如果是1维数组,则没有任何变化。
   356  //如果是2维数组,则行列交换。
   357  //如果是n维数组,则根据指定的顺序调整,生成新的多维数组。
   358  //输入参数1:如果不指定输入参数,则轴顺序全部反序;如果指定参数则个数必须和轴个数相同,否则抛出异常。
   359  //fixme 这里的实现效率不高,后面有时间需要提升一下。
   360  func (a *Arrf) Transpose(axes ...int) *Arrf {
   361  	var n = a.Ndims()
   362  	var permutation []int
   363  	var nShape []int
   364  
   365  	switch len(axes) {
   366  	case 0:
   367  		permutation = make([]int, n)
   368  		nShape = make([]int, n)
   369  		for i := range permutation {
   370  			permutation[i] = n - i
   371  		}
   372  		for i := 0; i < n; i++ {
   373  			permutation[i] = n - 1 - i
   374  			nShape[i] = a.Shape[permutation[i]]
   375  		}
   376  
   377  	case n:
   378  		permutation = axes
   379  		nShape = make([]int, n)
   380  		for i := range nShape {
   381  			nShape[i] = a.Shape[permutation[i]]
   382  		}
   383  
   384  	default:
   385  		fmt.Println("axis number wrong.")
   386  		panic(DIMENTION_ERROR)
   387  	}
   388  
   389  	var totalIndexSize = 1
   390  	for i := range a.Shape {
   391  		totalIndexSize *= a.Shape[i]
   392  	}
   393  
   394  	var indexsSrc = make([][]int, totalIndexSize)
   395  	var indexsDst = make([][]int, totalIndexSize)
   396  
   397  	var b = Zeros(nShape...)
   398  	var index = make([]int, n)
   399  	for i := 0; i < totalIndexSize; i++ {
   400  		tindexSrc := make([]int, n)
   401  		copy(tindexSrc, index)
   402  		indexsSrc[i] = tindexSrc
   403  		var tindexDst = make([]int, n)
   404  		for j := range tindexDst {
   405  			tindexDst[j] = index[permutation[j]]
   406  		}
   407  		indexsDst[i] = tindexDst
   408  
   409  		var j = n - 1
   410  		index[j]++
   411  		for {
   412  			if j > 0 && index[j] >= a.Shape[j] {
   413  				index[j-1]++
   414  				index[j] = 0
   415  				j--
   416  			} else {
   417  				break
   418  			}
   419  		}
   420  	}
   421  	for i := range indexsSrc {
   422  		b.Set(a.Get(indexsSrc[i]...), indexsDst[i]...)
   423  	}
   424  	return b
   425  }