github.com/jgbaldwinbrown/perf@v0.1.1/pkg/stats/udist_test.go (about)

     1  // Copyright 2015 The Go Authors. All rights reserved.
     2  // Use of this source code is governed by a BSD-style
     3  // license that can be found in the LICENSE file.
     4  
     5  package stats
     6  
     7  import (
     8  	"fmt"
     9  	"math"
    10  	"testing"
    11  )
    12  
    13  func aeqTable(a, b [][]float64) bool {
    14  	if len(a) != len(b) {
    15  		return false
    16  	}
    17  	for i := range a {
    18  		if len(a[i]) != len(b[i]) {
    19  			return false
    20  		}
    21  		for j := range a[i] {
    22  			// "%f" precision
    23  			if math.Abs(a[i][j]-b[i][j]) >= 0.000001 {
    24  				return false
    25  			}
    26  		}
    27  	}
    28  	return true
    29  }
    30  
// U distribution for N=3 up to U=5.
//
// This is the expected table for TestUDist: the entry in row U,
// column m-1 is the value UDist{N1: m, N2: 3}.CDF(U) must return,
// written to six decimal places.
var udist3 = [][]float64{
	//    m=1         2         3
	{0.250000, 0.100000, 0.050000}, // U=0
	{0.500000, 0.200000, 0.100000}, // U=1
	{0.750000, 0.400000, 0.200000}, // U=2
	{1.000000, 0.600000, 0.350000}, // U=3
	{1.000000, 0.800000, 0.500000}, // U=4
	{1.000000, 0.900000, 0.650000}, // U=5
}
    41  
// U distribution for N=5 up to U=5.
//
// This is the expected table for TestUDist: the entry in row U,
// column m-1 is the value UDist{N1: m, N2: 5}.CDF(U) must return,
// written to six decimal places.
var udist5 = [][]float64{
	//    m=1         2         3         4         5
	{0.166667, 0.047619, 0.017857, 0.007937, 0.003968}, // U=0
	{0.333333, 0.095238, 0.035714, 0.015873, 0.007937}, // U=1
	{0.500000, 0.190476, 0.071429, 0.031746, 0.015873}, // U=2
	{0.666667, 0.285714, 0.125000, 0.055556, 0.027778}, // U=3
	{0.833333, 0.428571, 0.196429, 0.095238, 0.047619}, // U=4
	{1.000000, 0.571429, 0.285714, 0.142857, 0.075397}, // U=5
}
    52  
    53  func TestUDist(t *testing.T) {
    54  	makeTable := func(n int) [][]float64 {
    55  		out := make([][]float64, 6)
    56  		for U := 0; U < 6; U++ {
    57  			out[U] = make([]float64, n)
    58  			for m := 1; m <= n; m++ {
    59  				out[U][m-1] = UDist{N1: m, N2: n}.CDF(float64(U))
    60  			}
    61  		}
    62  		return out
    63  	}
    64  	fmtTable := func(a [][]float64) string {
    65  		out := fmt.Sprintf("%8s", "m=")
    66  		for m := 1; m <= len(a[0]); m++ {
    67  			out += fmt.Sprintf("%9d", m)
    68  		}
    69  		out += "\n"
    70  
    71  		for U, row := range a {
    72  			out += fmt.Sprintf("U=%-6d", U)
    73  			for m := 1; m <= len(a[0]); m++ {
    74  				out += fmt.Sprintf(" %f", row[m-1])
    75  			}
    76  			out += "\n"
    77  		}
    78  		return out
    79  	}
    80  
    81  	// Compare against tables given in Mann, Whitney (1947).
    82  	got3 := makeTable(3)
    83  	if !aeqTable(got3, udist3) {
    84  		t.Errorf("For n=3, want:\n%sgot:\n%s", fmtTable(udist3), fmtTable(got3))
    85  	}
    86  
    87  	got5 := makeTable(5)
    88  	if !aeqTable(got5, udist5) {
    89  		t.Errorf("For n=5, want:\n%sgot:\n%s", fmtTable(udist5), fmtTable(got5))
    90  	}
    91  }
    92  
    93  func BenchmarkUDist(b *testing.B) {
    94  	for i := 0; i < b.N; i++ {
    95  		// R uses the exact distribution up to N=50.
    96  		// N*M/2=1250 is the hardest point to get the CDF for.
    97  		UDist{N1: 50, N2: 50}.CDF(1250)
    98  	}
    99  }
   100  
   101  func TestUDistTies(t *testing.T) {
   102  	makeTable := func(m, N int, t []int, minx, maxx float64) [][]float64 {
   103  		out := [][]float64{}
   104  		dist := UDist{N1: m, N2: N - m, T: t}
   105  		for x := minx; x <= maxx; x += 0.5 {
   106  			// Convert x from uQt' to uQv'.
   107  			U := x - float64(m*m)/2
   108  			P := dist.CDF(U)
   109  			if len(out) == 0 || !aeq(out[len(out)-1][1], P) {
   110  				out = append(out, []float64{x, P})
   111  			}
   112  		}
   113  		return out
   114  	}
   115  	fmtTable := func(table [][]float64) string {
   116  		out := ""
   117  		for _, row := range table {
   118  			out += fmt.Sprintf("%5.1f %f\n", row[0], row[1])
   119  		}
   120  		return out
   121  	}
   122  
   123  	// Compare against Table 1 from Klotz (1966).
   124  	got := makeTable(5, 10, []int{1, 1, 2, 1, 1, 2, 1, 1}, 12.5, 19.5)
   125  	want := [][]float64{
   126  		{12.5, 0.003968}, {13.5, 0.007937},
   127  		{15.0, 0.023810}, {16.5, 0.047619},
   128  		{17.5, 0.071429}, {18.0, 0.087302},
   129  		{19.0, 0.134921}, {19.5, 0.138889},
   130  	}
   131  	if !aeqTable(got, want) {
   132  		t.Errorf("Want:\n%sgot:\n%s", fmtTable(want), fmtTable(got))
   133  	}
   134  
   135  	got = makeTable(10, 21, []int{6, 5, 4, 3, 2, 1}, 52, 87)
   136  	want = [][]float64{
   137  		{52.0, 0.000014}, {56.5, 0.000128},
   138  		{57.5, 0.000145}, {60.0, 0.000230},
   139  		{61.0, 0.000400}, {62.0, 0.000740},
   140  		{62.5, 0.000797}, {64.0, 0.000825},
   141  		{64.5, 0.001165}, {65.5, 0.001477},
   142  		{66.5, 0.002498}, {67.0, 0.002725},
   143  		{67.5, 0.002895}, {68.0, 0.003150},
   144  		{68.5, 0.003263}, {69.0, 0.003518},
   145  		{69.5, 0.003603}, {70.0, 0.005648},
   146  		{70.5, 0.005818}, {71.0, 0.006626},
   147  		{71.5, 0.006796}, {72.0, 0.008157},
   148  		{72.5, 0.009688}, {73.0, 0.009801},
   149  		{73.5, 0.010430}, {74.0, 0.011111},
   150  		{74.5, 0.014230}, {75.0, 0.014612},
   151  		{75.5, 0.017249}, {76.0, 0.018307},
   152  		{76.5, 0.020178}, {77.0, 0.022270},
   153  		{77.5, 0.023189}, {78.0, 0.026931},
   154  		{78.5, 0.028207}, {79.0, 0.029979},
   155  		{79.5, 0.030931}, {80.0, 0.038969},
   156  		{80.5, 0.043063}, {81.0, 0.044262},
   157  		{81.5, 0.046389}, {82.0, 0.049581},
   158  		{82.5, 0.056300}, {83.0, 0.058027},
   159  		{83.5, 0.063669}, {84.0, 0.067454},
   160  		{84.5, 0.074122}, {85.0, 0.077425},
   161  		{85.5, 0.083498}, {86.0, 0.094079},
   162  		{86.5, 0.096693}, {87.0, 0.101132},
   163  	}
   164  	if !aeqTable(got, want) {
   165  		t.Errorf("Want:\n%sgot:\n%s", fmtTable(want), fmtTable(got))
   166  	}
   167  
   168  	got = makeTable(8, 16, []int{2, 2, 2, 2, 2, 2, 2, 2}, 32, 54)
   169  	want = [][]float64{
   170  		{32.0, 0.000078}, {34.0, 0.000389},
   171  		{36.0, 0.001088}, {38.0, 0.002642},
   172  		{40.0, 0.005905}, {42.0, 0.011500},
   173  		{44.0, 0.021057}, {46.0, 0.035664},
   174  		{48.0, 0.057187}, {50.0, 0.086713},
   175  		{52.0, 0.126263}, {54.0, 0.175369},
   176  	}
   177  	if !aeqTable(got, want) {
   178  		t.Errorf("Want:\n%sgot:\n%s", fmtTable(want), fmtTable(got))
   179  	}
   180  
   181  	// Check remaining tables from Klotz against the reference
   182  	// implementation.
   183  	checkRef := func(n1 int, tie []int) {
   184  		wantPMF1, wantCDF1 := udistRef(n1, tie)
   185  
   186  		dist := UDist{N1: n1, N2: sumint(tie) - n1, T: tie}
   187  		gotPMF, wantPMF := [][]float64{}, [][]float64{}
   188  		gotCDF, wantCDF := [][]float64{}, [][]float64{}
   189  		N := sumint(tie)
   190  		for U := 0.0; U <= float64(n1*(N-n1)); U += 0.5 {
   191  			gotPMF = append(gotPMF, []float64{U, dist.PMF(U)})
   192  			gotCDF = append(gotCDF, []float64{U, dist.CDF(U)})
   193  			wantPMF = append(wantPMF, []float64{U, wantPMF1[int(U*2)]})
   194  			wantCDF = append(wantCDF, []float64{U, wantCDF1[int(U*2)]})
   195  		}
   196  		if !aeqTable(wantPMF, gotPMF) {
   197  			t.Errorf("For PMF of n1=%v, t=%v, want:\n%sgot:\n%s", n1, tie, fmtTable(wantPMF), fmtTable(gotPMF))
   198  		}
   199  		if !aeqTable(wantCDF, gotCDF) {
   200  			t.Errorf("For CDF of n1=%v, t=%v, want:\n%sgot:\n%s", n1, tie, fmtTable(wantCDF), fmtTable(gotCDF))
   201  		}
   202  	}
   203  	checkRef(5, []int{1, 1, 2, 1, 1, 2, 1, 1})
   204  	checkRef(5, []int{1, 1, 2, 1, 1, 1, 2, 1})
   205  	checkRef(5, []int{1, 3, 1, 2, 1, 1, 1})
   206  	checkRef(8, []int{1, 2, 1, 1, 1, 1, 2, 2, 1, 2})
   207  	checkRef(12, []int{3, 3, 4, 3, 4, 5})
   208  	checkRef(10, []int{1, 2, 3, 4, 5, 6})
   209  }
   210  
   211  func BenchmarkUDistTies(b *testing.B) {
   212  	// Worst case: just one tie.
   213  	n := 20
   214  	t := make([]int, 2*n-1)
   215  	for i := range t {
   216  		t[i] = 1
   217  	}
   218  	t[0] = 2
   219  
   220  	for i := 0; i < b.N; i++ {
   221  		UDist{N1: n, N2: n, T: t}.CDF(float64(n*n) / 2)
   222  	}
   223  }
   224  
   225  func XTestPrintUmemo(t *testing.T) {
   226  	// Reproduce table from Cheung, Klotz.
   227  	ties := []int{4, 5, 3, 4, 6}
   228  	printUmemo(makeUmemo(80, 10, ties), ties)
   229  }
   230  
   231  // udistRef computes the PMF and CDF of the U distribution for two
   232  // samples of sizes n1 and sum(t)-n1 with tie vector t. The returned
   233  // pmf and cdf are indexed by 2*U.
   234  //
   235  // This uses the "graphical method" of Klotz (1966). It is very slow
   236  // (Θ(∏ (t[i]+1)) = Ω(2^|t|)), but very correct, and hence useful as a
   237  // reference for testing faster implementations.
   238  func udistRef(n1 int, t []int) (pmf, cdf []float64) {
   239  	// Enumerate all u vectors for which 0 <= u_i <= t_i. Count
   240  	// the number of permutations of two samples of sizes n1 and
   241  	// sum(t)-n1 with tie vector t and accumulate these counts by
   242  	// their U statistics in count[2*U].
   243  	counts := make([]int, 1+2*n1*(sumint(t)-n1))
   244  
   245  	u := make([]int, len(t))
   246  	u[0] = -1 // Get enumeration started.
   247  enumu:
   248  	for {
   249  		// Compute the next u vector.
   250  		u[0]++
   251  		for i := 0; i < len(u) && u[i] > t[i]; i++ {
   252  			if i == len(u)-1 {
   253  				// All u vectors have been enumerated.
   254  				break enumu
   255  			}
   256  			// Carry.
   257  			u[i+1]++
   258  			u[i] = 0
   259  		}
   260  
   261  		// Is this a legal u vector?
   262  		if sumint(u) != n1 {
   263  			// Klotz (1966) has a method for directly
   264  			// enumerating legal u vectors, but the point
   265  			// of this is to be correct, not fast.
   266  			continue
   267  		}
   268  
   269  		// Compute 2*U statistic for this u vector.
   270  		twoU, vsum := 0, 0
   271  		for i, u_i := range u {
   272  			v_i := t[i] - u_i
   273  			// U = U + vsum*u_i + u_i*v_i/2
   274  			twoU += 2*vsum*u_i + u_i*v_i
   275  			vsum += v_i
   276  		}
   277  
   278  		// Compute Π choose(t_i, u_i). This is the number of
   279  		// ways of permuting the input sample under u.
   280  		prod := 1
   281  		for i, u_i := range u {
   282  			prod *= int(mathChoose(t[i], u_i) + 0.5)
   283  		}
   284  
   285  		// Accumulate the permutations on this u path.
   286  		counts[twoU] += prod
   287  
   288  		if false {
   289  			// Print a table in the form of Klotz's
   290  			// "direct enumeration" example.
   291  			//
   292  			// Convert 2U = 2UQV' to UQt' used in Klotz
   293  			// examples.
   294  			UQt := float64(twoU)/2 + float64(n1*n1)/2
   295  			fmt.Printf("%+v %f %-2d\n", u, UQt, prod)
   296  		}
   297  	}
   298  
   299  	// Convert counts into probabilities for PMF and CDF.
   300  	pmf = make([]float64, len(counts))
   301  	cdf = make([]float64, len(counts))
   302  	total := int(mathChoose(sumint(t), n1) + 0.5)
   303  	for i, count := range counts {
   304  		pmf[i] = float64(count) / float64(total)
   305  		if i > 0 {
   306  			cdf[i] = cdf[i-1]
   307  		}
   308  		cdf[i] += pmf[i]
   309  	}
   310  	return
   311  }
   312  
   313  // printUmemo prints the output of makeUmemo for debugging.
   314  func printUmemo(A []map[ukey]float64, t []int) {
   315  	fmt.Printf("K\tn1\t2*U\tpr\n")
   316  	for K := len(A) - 1; K >= 0; K-- {
   317  		for i, pr := range A[K] {
   318  			_, ref := udistRef(i.n1, t[:K])
   319  			fmt.Printf("%v\t%v\t%v\t%v\t%v\n", K, i.n1, i.twoU, pr, ref[i.twoU])
   320  		}
   321  	}
   322  }