github.com/qiaogw/arrgo@v0.0.8/stats.go (about)

     1  package arrgo
     2  
     3  import (
     4  	"sort"
     5  
     6  	asm "github.com/qiaogw/arrgo/internal"
     7  	//"github.com/ledao/arrgo/internal"
     8  )
     9  
    10  func (a *Arrf) Sum(axis ...int) *Arrf {
    11  	if len(axis) == 0 || len(axis) >= a.Ndims() {
    12  		tot := float64(0)
    13  		for _, v := range a.Data {
    14  			tot += v
    15  		}
    16  		return Fill(tot, 1)
    17  	}
    18  
    19  	//对axis进行排序,按照从大到小的顺序进行规约
    20  	sort.IntSlice(axis).Sort()
    21  	//规约后的数组的形状
    22  	restAxis := make([]int, len(a.Shape)-len(axis))
    23  	//对a进行复制,所有的操作都作用于临时变量ta中,最后将ta返回
    24  	ta := a.Copy()
    25  
    26  axisR:
    27  	for i, t := 0, 0; i < len(ta.Shape); i++ {
    28  		for _, w := range axis {
    29  			if i == w {
    30  				continue axisR
    31  			}
    32  		}
    33  		restAxis[t] = ta.Shape[i]
    34  		t++
    35  	}
    36  
    37  	//数组的元素的个数保存到ln中
    38  	ln := ta.Strides[0]
    39  	//对每个指定的轴,顺寻进行规约
    40  	for k := 0; k < len(axis); k++ {
    41  		//如果轴大小为1,则不需要任何操作
    42  		if ta.Shape[axis[k]] == 1 {
    43  			continue
    44  		}
    45  		//获取当前轴的大小v,当前轴的跨度wd,以及下一个轴的跨度st
    46  		v, wd, st := ta.Shape[axis[k]], ta.Strides[axis[k]], ta.Strides[axis[k]+1]
    47  		//如果下一个轴st的跨度为1,则说明当前轴为最后一个轴,只需要每wd个跨度进行一个规约即可
    48  		if st == 1 {
    49  			//每wd个数据进行一次规约,结果依次放到开始的位置
    50  			asm.Hadd(uint64(wd), ta.Data)
    51  			ln /= v
    52  			ta.Data = ta.Data[:ln]
    53  			continue
    54  		}
    55  		//如果不是最后一个轴,则在该轴上进行规约
    56  		for w := 0; w < ln; w += wd {
    57  			t := ta.Data[w/wd*st : (w/wd+1)*st]
    58  			copy(t, ta.Data[w:w+st])
    59  			for i := 1; i*st+1 < wd; i++ {
    60  				asm.Vadd(t, ta.Data[w+(i)*st:w+(i+1)*st])
    61  			}
    62  		}
    63  		ln /= v
    64  		ta.Data = ta.Data[:ln]
    65  	}
    66  	ta.Shape = restAxis
    67  
    68  	tmp := 1
    69  	for i := len(restAxis); i > 0; i-- {
    70  		ta.Strides[i] = tmp
    71  		tmp *= restAxis[i-1]
    72  	}
    73  	ta.Strides[0] = tmp
    74  	ta.Data = ta.Data[:tmp]
    75  	ta.Strides = ta.Strides[:len(restAxis)+1]
    76  	return ta
    77  }
    78  
    79  func Sum(a *Arrf, axis ...int) *Arrf {
    80  	return a.Sum(axis...)
    81  }
    82  
    83  func (a *Arrf) Mean(axis ...int) *Arrf {
    84  	if len(axis) == 0 || len(axis) >= a.Ndims() {
    85  		tot := float64(0)
    86  		for _, v := range a.Data {
    87  			tot += v
    88  		}
    89  		return Fill(tot/float64(a.Strides[0]), 1)
    90  	}
    91  
    92  	sort.IntSlice(axis).Sort()
    93  	selectShape := make([]int, len(axis))
    94  	for i := range selectShape {
    95  		selectShape[i] = a.Shape[axis[i]]
    96  	}
    97  	N := ProductIntSlice(selectShape)
    98  
    99  	ta := a.Sum(axis...)
   100  
   101  	return ta.DivC(float64(N))
   102  }
   103  
   104  func Mean(a *Arrf, axis ...int) *Arrf {
   105  	return a.Mean(axis...)
   106  }
   107  
   108  func (a *Arrf) Var(axis ...int) *Arrf {
   109  	a2 := a.Mul(a).Sum(axis...)
   110  	m := a.Mean(axis...)
   111  	var N int
   112  	if len(axis) == 0 || len(axis) >= a.Ndims() {
   113  		N = ProductIntSlice(a.Shape)
   114  	} else {
   115  		selectShape := make([]int, len(axis))
   116  		for i := range selectShape {
   117  			selectShape[i] = a.Shape[axis[i]]
   118  		}
   119  		N = ProductIntSlice(selectShape)
   120  	}
   121  
   122  	m2 := m.Mul(m).MulC(float64(N))
   123  	a_m_2 := a.Sum(axis...).Mul(m).MulC(2)
   124  	return a2.Sub(a_m_2).Add(m2).DivC(float64(N))
   125  }
   126  
   127  func Var(a *Arrf, axis ...int) *Arrf {
   128  	return a.Var(axis...)
   129  }
   130  
   131  func (a *Arrf) Std(axis ...int) *Arrf {
   132  	return Sqrt(a.Var(axis...))
   133  }
   134  
   135  func Std(a *Arrf, axis ...int) *Arrf {
   136  	return a.Std(axis...)
   137  }
   138  
   139  func (a *Arrf) Min(axis ...int) *Arrf {
   140  	if len(axis) == 0 || len(axis) >= a.Ndims() {
   141  		minValue := a.Data[0]
   142  		for _, v := range a.Data {
   143  			if minValue > v {
   144  				minValue = v
   145  			}
   146  		}
   147  		return Fill(minValue, 1)
   148  	}
   149  
   150  	sort.IntSlice(axis).Sort()
   151  	restAxis := make([]int, len(a.Shape)-len(axis))
   152  	ta := a.Copy()
   153  axisR:
   154  	for i, t := 0, 0; i < len(ta.Shape); i++ {
   155  		for _, w := range axis {
   156  			if i == w {
   157  				continue axisR
   158  			}
   159  		}
   160  		restAxis[t] = ta.Shape[i]
   161  		t++
   162  	}
   163  
   164  	//数组的元素的个数保存到ln中
   165  	ln := ta.Strides[0]
   166  	//对每个指定的轴,顺寻进行规约
   167  	for k := 0; k < len(axis); k++ {
   168  		//如果轴大小为1,则不需要任何操作
   169  		if ta.Shape[axis[k]] == 1 {
   170  			continue
   171  		}
   172  		//获取当前轴的大小v,当前轴的跨度wd,以及下一个轴的跨度st
   173  		v, wd, st := ta.Shape[axis[k]], ta.Strides[axis[k]], ta.Strides[axis[k]+1]
   174  		//如果下一个轴st的跨度为1,则说明当前轴为最后一个轴,只需要每wd个跨度进行一个规约即可
   175  		if st == 1 {
   176  			//每wd个数据进行一次规约,结果依次放到开始的位置
   177  			Hmin(wd, ta.Data)
   178  			ln /= v
   179  			ta.Data = ta.Data[:ln]
   180  			continue
   181  		}
   182  		//如果不是最后一个轴,则在该轴上进行规约
   183  		for w := 0; w < ln; w += wd {
   184  			t := ta.Data[w/wd*st : (w/wd+1)*st]
   185  			copy(t, ta.Data[w:w+st])
   186  			for i := 1; i*st+1 < wd; i++ {
   187  				Vmin(t, ta.Data[w+(i)*st:w+(i+1)*st])
   188  			}
   189  		}
   190  		ln /= v
   191  		ta.Data = ta.Data[:ln]
   192  	}
   193  
   194  	ta.Shape = restAxis
   195  
   196  	tmp := 1
   197  	for i := len(restAxis); i > 0; i-- {
   198  		ta.Strides[i] = tmp
   199  		tmp *= restAxis[i-1]
   200  	}
   201  	ta.Strides[0] = tmp
   202  	ta.Strides = ta.Strides[:len(restAxis)+1]
   203  	return ta
   204  }
   205  
   206  func Min(a *Arrf, axis ...int) *Arrf {
   207  	return a.Min(axis...)
   208  }
   209  
   210  func (a *Arrf) Max(axis ...int) *Arrf {
   211  	if len(axis) == 0 || len(axis) >= a.Ndims() {
   212  		maxValue := a.Data[0]
   213  		for _, v := range a.Data {
   214  			if maxValue < v {
   215  				maxValue = v
   216  			}
   217  		}
   218  		return Fill(maxValue, 1)
   219  	}
   220  
   221  	sort.IntSlice(axis).Sort()
   222  	restAxis := make([]int, len(a.Shape)-len(axis))
   223  	ta := a.Copy()
   224  axisR:
   225  	for i, t := 0, 0; i < len(ta.Shape); i++ {
   226  		for _, w := range axis {
   227  			if i == w {
   228  				continue axisR
   229  			}
   230  		}
   231  		restAxis[t] = ta.Shape[i]
   232  		t++
   233  	}
   234  
   235  	//数组的元素的个数保存到ln中
   236  	ln := ta.Strides[0]
   237  	//对每个指定的轴,顺寻进行规约
   238  	for k := 0; k < len(axis); k++ {
   239  		//如果轴大小为1,则不需要任何操作
   240  		if ta.Shape[axis[k]] == 1 {
   241  			continue
   242  		}
   243  		//获取当前轴的大小v,当前轴的跨度wd,以及下一个轴的跨度st
   244  		v, wd, st := ta.Shape[axis[k]], ta.Strides[axis[k]], ta.Strides[axis[k]+1]
   245  		//如果下一个轴st的跨度为1,则说明当前轴为最后一个轴,只需要每wd个跨度进行一个规约即可
   246  		if st == 1 {
   247  			//每wd个数据进行一次规约,结果依次放到开始的位置
   248  			Hmax(wd, ta.Data)
   249  			ln /= v
   250  			ta.Data = ta.Data[:ln]
   251  			continue
   252  		}
   253  		//如果不是最后一个轴,则在该轴上进行规约
   254  		for w := 0; w < ln; w += wd {
   255  			t := ta.Data[w/wd*st : (w/wd+1)*st]
   256  			copy(t, ta.Data[w:w+st])
   257  			for i := 1; i*st+1 < wd; i++ {
   258  				Vmax(t, ta.Data[w+(i)*st:w+(i+1)*st])
   259  			}
   260  		}
   261  		ln /= v
   262  		ta.Data = ta.Data[:ln]
   263  	}
   264  
   265  	ta.Shape = restAxis
   266  
   267  	tmp := 1
   268  	for i := len(restAxis); i > 0; i-- {
   269  		ta.Strides[i] = tmp
   270  		tmp *= restAxis[i-1]
   271  	}
   272  	ta.Strides[0] = tmp
   273  	ta.Strides = ta.Strides[:len(restAxis)+1]
   274  	return ta
   275  }
   276  
   277  func Max(a *Arrf, axis ...int) *Arrf {
   278  	return a.Max(axis...)
   279  }
   280  
   281  func (a *Arrf) ArgMax(axis int) *Arrf {
   282  	if axis < 0 {
   283  		axis = axis + len(a.Shape)
   284  	}
   285  	restAxis := make([]int, len(a.Shape)-1)
   286  	ta := a.Copy()
   287  	for i, t := 0, 0; i < len(ta.Shape); i++ {
   288  		if i == axis {
   289  			continue
   290  		}
   291  		restAxis[t] = ta.Shape[i]
   292  		t++
   293  	}
   294  
   295  	//数组的元素的个数保存到ln中
   296  	ln := ta.Strides[0]
   297  
   298  	//获取当前轴的大小v,当前轴的跨度wd,以及下一个轴的跨度st
   299  	v, wd, st := ta.Shape[axis], ta.Strides[axis], ta.Strides[axis+1]
   300  	//如果下一个轴st的跨度为1,则说明当前轴为最后一个轴,只需要每wd个跨度进行一个规约即可
   301  	if st == 1 {
   302  		//每wd个数据进行一次规约,结果依次放到开始的位置
   303  		Hargmax(wd, ta.Data)
   304  		ln /= v
   305  		ta.Data = ta.Data[:ln]
   306  	} else {
   307  		//如果不是最后一个轴,则在该轴上进行规约
   308  		td := make([]float64, 0, ln/wd)
   309  		for w := 0; w < ln; w += wd {
   310  			Vargmax(st, ta.Data[w:w+wd])
   311  			td = append(td, ta.Data[w : w+wd][:st]...)
   312  		}
   313  		ln /= v
   314  		ta.Data = td
   315  	}
   316  
   317  	ta.Shape = restAxis
   318  
   319  	tmp := 1
   320  	for i := len(restAxis); i > 0; i-- {
   321  		ta.Strides[i] = tmp
   322  		tmp *= restAxis[i-1]
   323  	}
   324  	ta.Strides[0] = tmp
   325  	ta.Strides = ta.Strides[:len(restAxis)+1]
   326  	return ta
   327  }
   328  
   329  func ArgMax(a *Arrf, axis int) *Arrf {
   330  	return a.ArgMax(axis)
   331  }
   332  
   333  //fixme has bug
   334  func (a *Arrf) ArgMin(axis int) *Arrf {
   335  	if axis < 0 {
   336  		axis = axis + len(a.Shape)
   337  	}
   338  	restAxis := make([]int, len(a.Shape)-1)
   339  	ta := a.Copy()
   340  	for i, t := 0, 0; i < len(ta.Shape); i++ {
   341  		if i == axis {
   342  			continue
   343  		}
   344  		restAxis[t] = ta.Shape[i]
   345  		t++
   346  	}
   347  
   348  	//数组的元素的个数保存到ln中
   349  	ln := ta.Strides[0]
   350  
   351  	//获取当前轴的大小v,当前轴的跨度wd,以及下一个轴的跨度st
   352  	v, wd, st := ta.Shape[axis], ta.Strides[axis], ta.Strides[axis+1]
   353  	//如果下一个轴st的跨度为1,则说明当前轴为最后一个轴,只需要每wd个跨度进行一个规约即可
   354  	if st == 1 {
   355  		//每wd个数据进行一次规约,结果依次放到开始的位置
   356  		Hargmin(wd, ta.Data)
   357  		ln /= v
   358  		ta.Data = ta.Data[:ln]
   359  	} else {
   360  		//如果不是最后一个轴,则在该轴上进行规约
   361  		td := make([]float64, 0, ln/wd)
   362  		for w := 0; w < ln; w += wd {
   363  			Vargmin(st, ta.Data[w:w+wd])
   364  			td = append(td, ta.Data[w : w+wd][:st]...)
   365  		}
   366  		ln /= v
   367  		ta.Data = td
   368  	}
   369  
   370  	ta.Shape = restAxis
   371  
   372  	tmp := 1
   373  	for i := len(restAxis); i > 0; i-- {
   374  		ta.Strides[i] = tmp
   375  		tmp *= restAxis[i-1]
   376  	}
   377  	ta.Strides[0] = tmp
   378  	ta.Strides = ta.Strides[:len(restAxis)+1]
   379  	return ta
   380  }
   381  
   382  func ArgMin(a *Arrf, axis int) *Arrf {
   383  	return a.ArgMin(axis)
   384  }