gonum.org/v1/gonum@v0.14.0/stat/distmv/normal_test.go (about) 1 // Copyright ©2015 The Gonum Authors. All rights reserved. 2 // Use of this source code is governed by a BSD-style 3 // license that can be found in the LICENSE file. 4 5 package distmv 6 7 import ( 8 "math" 9 "testing" 10 11 "golang.org/x/exp/rand" 12 13 "gonum.org/v1/gonum/diff/fd" 14 "gonum.org/v1/gonum/floats" 15 "gonum.org/v1/gonum/mat" 16 "gonum.org/v1/gonum/stat" 17 ) 18 19 func TestNormProbs(t *testing.T) { 20 dist1, ok := NewNormal([]float64{0, 0}, mat.NewSymDense(2, []float64{1, 0, 0, 1}), nil) 21 if !ok { 22 t.Errorf("bad test") 23 } 24 dist2, ok := NewNormal([]float64{6, 7}, mat.NewSymDense(2, []float64{8, 2, 0, 4}), nil) 25 if !ok { 26 t.Errorf("bad test") 27 } 28 testProbability(t, []probCase{ 29 { 30 dist: dist1, 31 loc: []float64{0, 0}, 32 logProb: -1.837877066409345, 33 }, 34 { 35 dist: dist2, 36 loc: []float64{6, 7}, 37 logProb: -3.503979321496947, 38 }, 39 { 40 dist: dist2, 41 loc: []float64{1, 2}, 42 logProb: -7.075407892925519, 43 }, 44 }) 45 } 46 47 func TestNewNormalChol(t *testing.T) { 48 for _, test := range []struct { 49 mean []float64 50 cov *mat.SymDense 51 }{ 52 { 53 mean: []float64{2, 3}, 54 cov: mat.NewSymDense(2, []float64{1, 0.1, 0.1, 1}), 55 }, 56 } { 57 var chol mat.Cholesky 58 ok := chol.Factorize(test.cov) 59 if !ok { 60 panic("bad test") 61 } 62 n := NewNormalChol(test.mean, &chol, nil) 63 // Generate a random number and calculate probability to ensure things 64 // have been set properly. See issue #426. 65 x := n.Rand(nil) 66 _ = n.Prob(x) 67 } 68 } 69 70 func TestNormRand(t *testing.T) { 71 for _, test := range []struct { 72 mean []float64 73 cov []float64 74 }{ 75 { 76 mean: []float64{0, 0}, 77 cov: []float64{ 78 1, 0, 79 0, 1, 80 }, 81 }, 82 { 83 mean: []float64{0, 0}, 84 cov: []float64{ 85 1, 0.9, 86 0.9, 1, 87 }, 88 }, 89 { 90 mean: []float64{6, 7}, 91 cov: []float64{ 92 5, 0.9, 93 0.9, 2, 94 }, 95 }, 96 } { 97 dim := len(test.mean) 98 cov := mat.NewSymDense(dim, test.cov) 99 n, ok := NewNormal(test.mean, cov, nil) 100 if !ok { 101 t.Errorf("bad covariance matrix") 102 } 103 104 nSamples := 1000000 105 samps := mat.NewDense(nSamples, dim, nil) 106 for i := 0; i < nSamples; i++ { 107 n.Rand(samps.RawRowView(i)) 108 } 109 estMean := make([]float64, dim) 110 for i := range estMean { 111 estMean[i] = stat.Mean(mat.Col(nil, i, samps), nil) 112 } 113 if !floats.EqualApprox(estMean, test.mean, 1e-2) { 114 t.Errorf("Mean mismatch: want: %v, got %v", test.mean, estMean) 115 } 116 var estCov mat.SymDense 117 stat.CovarianceMatrix(&estCov, samps, nil) 118 if !mat.EqualApprox(&estCov, cov, 1e-2) { 119 t.Errorf("Cov mismatch: want: %v, got %v", cov, &estCov) 120 } 121 } 122 } 123 124 func TestNormalQuantile(t *testing.T) { 125 for _, test := range []struct { 126 mean []float64 127 cov []float64 128 }{ 129 { 130 mean: []float64{6, 7}, 131 cov: []float64{ 132 5, 0.9, 133 0.9, 2, 134 }, 135 }, 136 } { 137 dim := len(test.mean) 138 cov := mat.NewSymDense(dim, test.cov) 139 n, ok := NewNormal(test.mean, cov, nil) 140 if !ok { 141 t.Errorf("bad covariance matrix") 142 } 143 144 nSamples := 1000000 145 rnd := rand.New(rand.NewSource(1)) 146 samps := mat.NewDense(nSamples, dim, nil) 147 tmp := make([]float64, dim) 148 for i := 0; i < nSamples; i++ { 149 for j := range tmp { 150 tmp[j] = rnd.Float64() 151 } 152 n.Quantile(samps.RawRowView(i), tmp) 153 } 154 estMean := make([]float64, dim) 155 for i := range estMean { 156 estMean[i] = stat.Mean(mat.Col(nil, i, samps), nil) 157 } 158 if !floats.EqualApprox(estMean, test.mean, 1e-2) { 159 t.Errorf("Mean mismatch: want: %v, got %v", test.mean, estMean) 160 } 161 var estCov mat.SymDense 162 stat.CovarianceMatrix(&estCov, samps, nil) 163 if !mat.EqualApprox(&estCov, cov, 1e-2) { 164 t.Errorf("Cov mismatch: want: %v, got %v", cov, &estCov) 165 } 166 } 167 } 168 169 func TestConditionNormal(t *testing.T) { 170 // Uncorrelated values shouldn't influence the updated values. 171 for _, test := range []struct { 172 mu []float64 173 sigma *mat.SymDense 174 observed []int 175 values []float64 176 177 newMu []float64 178 newSigma *mat.SymDense 179 }{ 180 { 181 mu: []float64{2, 3}, 182 sigma: mat.NewSymDense(2, []float64{2, 0, 0, 5}), 183 observed: []int{0}, 184 values: []float64{10}, 185 186 newMu: []float64{3}, 187 newSigma: mat.NewSymDense(1, []float64{5}), 188 }, 189 { 190 mu: []float64{2, 3}, 191 sigma: mat.NewSymDense(2, []float64{2, 0, 0, 5}), 192 observed: []int{1}, 193 values: []float64{10}, 194 195 newMu: []float64{2}, 196 newSigma: mat.NewSymDense(1, []float64{2}), 197 }, 198 { 199 mu: []float64{2, 3, 4}, 200 sigma: mat.NewSymDense(3, []float64{2, 0, 0, 0, 5, 0, 0, 0, 10}), 201 observed: []int{1}, 202 values: []float64{10}, 203 204 newMu: []float64{2, 4}, 205 newSigma: mat.NewSymDense(2, []float64{2, 0, 0, 10}), 206 }, 207 { 208 mu: []float64{2, 3, 4}, 209 sigma: mat.NewSymDense(3, []float64{2, 0, 0, 0, 5, 0, 0, 0, 10}), 210 observed: []int{0, 1}, 211 values: []float64{10, 15}, 212 213 newMu: []float64{4}, 214 newSigma: mat.NewSymDense(1, []float64{10}), 215 }, 216 { 217 mu: []float64{2, 3, 4, 5}, 218 sigma: mat.NewSymDense(4, []float64{2, 0.5, 0, 0, 0.5, 5, 0, 0, 0, 0, 10, 2, 0, 0, 2, 3}), 219 observed: []int{0, 1}, 220 values: []float64{10, 15}, 221 222 newMu: []float64{4, 5}, 223 newSigma: mat.NewSymDense(2, []float64{10, 2, 2, 3}), 224 }, 225 } { 226 normal, ok := NewNormal(test.mu, test.sigma, nil) 227 if !ok { 228 t.Fatalf("Bad test, original sigma not positive definite") 229 } 230 newNormal, ok := normal.ConditionNormal(test.observed, test.values, nil) 231 if !ok { 232 t.Fatalf("Bad test, update failure") 233 } 234 235 if !floats.EqualApprox(test.newMu, newNormal.mu, 1e-12) { 236 t.Errorf("Updated mean mismatch. Want %v, got %v.", test.newMu, newNormal.mu) 237 } 238 239 var sigma mat.SymDense 240 newNormal.chol.ToSym(&sigma) 241 if !mat.EqualApprox(test.newSigma, &sigma, 1e-12) { 242 t.Errorf("Updated sigma mismatch\n.Want:\n% v\nGot:\n% v\n", test.newSigma, sigma) 243 } 244 } 245 246 // Test bivariate case where the update rule is analytic 247 for _, test := range []struct { 248 mu []float64 249 std []float64 250 rho float64 251 value float64 252 }{ 253 { 254 mu: []float64{2, 3}, 255 std: []float64{3, 5}, 256 rho: 0.9, 257 value: 1000, 258 }, 259 { 260 mu: []float64{2, 3}, 261 std: []float64{3, 5}, 262 rho: -0.9, 263 value: 1000, 264 }, 265 } { 266 std := test.std 267 rho := test.rho 268 sigma := mat.NewSymDense(2, []float64{std[0] * std[0], std[0] * std[1] * rho, std[0] * std[1] * rho, std[1] * std[1]}) 269 normal, ok := NewNormal(test.mu, sigma, nil) 270 if !ok { 271 t.Fatalf("Bad test, original sigma not positive definite") 272 } 273 newNormal, ok := normal.ConditionNormal([]int{1}, []float64{test.value}, nil) 274 if !ok { 275 t.Fatalf("Bad test, update failed") 276 } 277 var newSigma mat.SymDense 278 newNormal.chol.ToSym(&newSigma) 279 trueMean := test.mu[0] + rho*(std[0]/std[1])*(test.value-test.mu[1]) 280 if math.Abs(trueMean-newNormal.mu[0]) > 1e-14 { 281 t.Errorf("Mean mismatch. Want %v, got %v", trueMean, newNormal.mu[0]) 282 } 283 trueVar := (1 - rho*rho) * std[0] * std[0] 284 if math.Abs(trueVar-newSigma.At(0, 0)) > 1e-14 { 285 t.Errorf("Std mismatch. Want %v, got %v", trueMean, newNormal.mu[0]) 286 } 287 } 288 289 // Test via sampling. 290 for _, test := range []struct { 291 mu []float64 292 sigma *mat.SymDense 293 observed []int 294 unobserved []int 295 value []float64 296 }{ 297 // The indices in unobserved must be in ascending order for this test. 298 { 299 mu: []float64{2, 3, 4}, 300 sigma: mat.NewSymDense(3, []float64{2, 0.5, 3, 0.5, 1, 0.6, 3, 0.6, 10}), 301 302 observed: []int{0}, 303 unobserved: []int{1, 2}, 304 value: []float64{1.9}, 305 }, 306 { 307 mu: []float64{2, 3, 4, 5}, 308 sigma: mat.NewSymDense(4, []float64{2, 0.5, 3, 0.1, 0.5, 1, 0.6, 0.2, 3, 0.6, 10, 0.3, 0.1, 0.2, 0.3, 3}), 309 310 observed: []int{0, 3}, 311 unobserved: []int{1, 2}, 312 value: []float64{1.9, 2.9}, 313 }, 314 } { 315 totalSamp := 4000000 316 var nSamp int 317 samples := mat.NewDense(totalSamp, len(test.mu), nil) 318 normal, ok := NewNormal(test.mu, test.sigma, nil) 319 if !ok { 320 t.Errorf("bad test") 321 } 322 sample := make([]float64, len(test.mu)) 323 for i := 0; i < totalSamp; i++ { 324 normal.Rand(sample) 325 isClose := true 326 for i, v := range test.observed { 327 if math.Abs(sample[v]-test.value[i]) > 1e-1 { 328 isClose = false 329 break 330 } 331 } 332 if isClose { 333 samples.SetRow(nSamp, sample) 334 nSamp++ 335 } 336 } 337 338 if nSamp < 100 { 339 t.Errorf("bad test, not enough samples") 340 continue 341 } 342 samples = samples.Slice(0, nSamp, 0, len(test.mu)).(*mat.Dense) 343 344 // Compute mean and covariance matrix. 345 estMean := make([]float64, len(test.mu)) 346 for i := range estMean { 347 estMean[i] = stat.Mean(mat.Col(nil, i, samples), nil) 348 } 349 var estCov mat.SymDense 350 stat.CovarianceMatrix(&estCov, samples, nil) 351 352 // Compute update rule. 353 newNormal, ok := normal.ConditionNormal(test.observed, test.value, nil) 354 if !ok { 355 t.Fatalf("Bad test, update failure") 356 } 357 358 var subEstMean []float64 359 for _, v := range test.unobserved { 360 361 subEstMean = append(subEstMean, estMean[v]) 362 } 363 subEstCov := mat.NewSymDense(len(test.unobserved), nil) 364 for i := 0; i < len(test.unobserved); i++ { 365 for j := i; j < len(test.unobserved); j++ { 366 subEstCov.SetSym(i, j, estCov.At(test.unobserved[i], test.unobserved[j])) 367 } 368 } 369 370 for i, v := range subEstMean { 371 if math.Abs(newNormal.mu[i]-v) > 5e-2 { 372 t.Errorf("Mean mismatch. Want %v, got %v.", newNormal.mu[i], v) 373 } 374 } 375 var sigma mat.SymDense 376 newNormal.chol.ToSym(&sigma) 377 if !mat.EqualApprox(&sigma, subEstCov, 1e-1) { 378 t.Errorf("Covariance mismatch. Want:\n%0.8v\nGot:\n%0.8v\n", subEstCov, sigma) 379 } 380 } 381 } 382 383 func TestCovarianceMatrix(t *testing.T) { 384 for _, test := range []struct { 385 mu []float64 386 sigma *mat.SymDense 387 }{ 388 { 389 mu: []float64{2, 3, 4}, 390 sigma: mat.NewSymDense(3, []float64{1, 0.5, 3, 0.5, 8, -1, 3, -1, 15}), 391 }, 392 } { 393 normal, ok := NewNormal(test.mu, test.sigma, nil) 394 if !ok { 395 t.Fatalf("Bad test, covariance matrix not positive definite") 396 } 397 var cov mat.SymDense 398 normal.CovarianceMatrix(&cov) 399 if !mat.EqualApprox(&cov, test.sigma, 1e-14) { 400 t.Errorf("Covariance mismatch with nil input") 401 } 402 dim := test.sigma.SymmetricDim() 403 cov = *mat.NewSymDense(dim, nil) 404 normal.CovarianceMatrix(&cov) 405 if !mat.EqualApprox(&cov, test.sigma, 1e-14) { 406 t.Errorf("Covariance mismatch with supplied input") 407 } 408 } 409 } 410 411 func TestMarginal(t *testing.T) { 412 for _, test := range []struct { 413 mu []float64 414 sigma *mat.SymDense 415 marginal []int 416 }{ 417 { 418 mu: []float64{2, 3, 4}, 419 sigma: mat.NewSymDense(3, []float64{2, 0.5, 3, 0.5, 1, 0.6, 3, 0.6, 10}), 420 marginal: []int{0}, 421 }, 422 { 423 mu: []float64{2, 3, 4}, 424 sigma: mat.NewSymDense(3, []float64{2, 0.5, 3, 0.5, 1, 0.6, 3, 0.6, 10}), 425 marginal: []int{0, 2}, 426 }, 427 { 428 mu: []float64{2, 3, 4, 5}, 429 sigma: mat.NewSymDense(4, []float64{2, 0.5, 3, 0.1, 0.5, 1, 0.6, 0.2, 3, 0.6, 10, 0.3, 0.1, 0.2, 0.3, 3}), 430 431 marginal: []int{0, 3}, 432 }, 433 } { 434 normal, ok := NewNormal(test.mu, test.sigma, nil) 435 if !ok { 436 t.Fatalf("Bad test, covariance matrix not positive definite") 437 } 438 marginal, ok := normal.MarginalNormal(test.marginal, nil) 439 if !ok { 440 t.Fatalf("Bad test, marginal matrix not positive definite") 441 } 442 dim := normal.Dim() 443 nSamples := 1000000 444 samps := mat.NewDense(nSamples, dim, nil) 445 for i := 0; i < nSamples; i++ { 446 normal.Rand(samps.RawRowView(i)) 447 } 448 estMean := make([]float64, dim) 449 for i := range estMean { 450 estMean[i] = stat.Mean(mat.Col(nil, i, samps), nil) 451 } 452 for i, v := range test.marginal { 453 if math.Abs(marginal.mu[i]-estMean[v]) > 1e-2 { 454 t.Errorf("Mean mismatch: want: %v, got %v", estMean[v], marginal.mu[i]) 455 } 456 } 457 458 var marginalCov mat.SymDense 459 marginal.CovarianceMatrix(&marginalCov) 460 var estCov mat.SymDense 461 stat.CovarianceMatrix(&estCov, samps, nil) 462 for i, v1 := range test.marginal { 463 for j, v2 := range test.marginal { 464 c := marginalCov.At(i, j) 465 ec := estCov.At(v1, v2) 466 if math.Abs(c-ec) > 5e-2 { 467 t.Errorf("Cov mismatch element i = %d, j = %d: want: %v, got %v", i, j, c, ec) 468 } 469 } 470 } 471 } 472 } 473 474 func TestMarginalSingle(t *testing.T) { 475 for _, test := range []struct { 476 mu []float64 477 sigma *mat.SymDense 478 }{ 479 { 480 mu: []float64{2, 3, 4}, 481 sigma: mat.NewSymDense(3, []float64{2, 0.5, 3, 0.5, 1, 0.6, 3, 0.6, 10}), 482 }, 483 { 484 mu: []float64{2, 3, 4, 5}, 485 sigma: mat.NewSymDense(4, []float64{2, 0.5, 3, 0.1, 0.5, 1, 0.6, 0.2, 3, 0.6, 10, 0.3, 0.1, 0.2, 0.3, 3}), 486 }, 487 } { 488 normal, ok := NewNormal(test.mu, test.sigma, nil) 489 if !ok { 490 t.Fatalf("Bad test, covariance matrix not positive definite") 491 } 492 for i, mean := range test.mu { 493 norm := normal.MarginalNormalSingle(i, nil) 494 if norm.Mean() != mean { 495 t.Errorf("Mean mismatch nil Sigma, idx %v: want %v, got %v.", i, mean, norm.Mean()) 496 } 497 std := math.Sqrt(test.sigma.At(i, i)) 498 if math.Abs(norm.StdDev()-std) > 1e-14 { 499 t.Errorf("StdDev mismatch nil Sigma, idx %v: want %v, got %v.", i, std, norm.StdDev()) 500 } 501 } 502 } 503 504 // Test matching with TestMarginal. 505 rnd := rand.New(rand.NewSource(1)) 506 for cas := 0; cas < 10; cas++ { 507 dim := rnd.Intn(10) + 1 508 mu := make([]float64, dim) 509 for i := range mu { 510 mu[i] = rnd.Float64() 511 } 512 x := make([]float64, dim*dim) 513 for i := range x { 514 x[i] = rnd.Float64() 515 } 516 matrix := mat.NewDense(dim, dim, x) 517 var sigma mat.SymDense 518 sigma.SymOuterK(1, matrix) 519 520 normal, ok := NewNormal(mu, &sigma, nil) 521 if !ok { 522 t.Fatal("bad test") 523 } 524 for i := 0; i < dim; i++ { 525 single := normal.MarginalNormalSingle(i, nil) 526 mult, ok := normal.MarginalNormal([]int{i}, nil) 527 if !ok { 528 t.Fatal("bad test") 529 } 530 if math.Abs(single.Mean()-mult.Mean(nil)[0]) > 1e-14 { 531 t.Errorf("Mean mismatch") 532 } 533 var cov mat.SymDense 534 mult.CovarianceMatrix(&cov) 535 if math.Abs(single.Variance()-cov.At(0, 0)) > 1e-14 { 536 t.Errorf("Variance mismatch") 537 } 538 } 539 } 540 } 541 542 func TestNormalScoreInput(t *testing.T) { 543 for cas, test := range []struct { 544 mu []float64 545 sigma *mat.SymDense 546 x []float64 547 }{ 548 { 549 mu: []float64{2, 3, 4}, 550 sigma: mat.NewSymDense(3, []float64{2, 0.5, 3, 0.5, 1, 0.6, 3, 0.6, 10}), 551 x: []float64{1, 3.1, -2}, 552 }, 553 { 554 mu: []float64{2, 3, 4, 5}, 555 sigma: mat.NewSymDense(4, []float64{2, 0.5, 3, 0.1, 0.5, 1, 0.6, 0.2, 3, 0.6, 10, 0.3, 0.1, 0.2, 0.3, 3}), 556 x: []float64{1, 3.1, -2, 5}, 557 }, 558 } { 559 normal, ok := NewNormal(test.mu, test.sigma, nil) 560 if !ok { 561 t.Fatalf("Bad test, covariance matrix not positive definite") 562 } 563 x := make([]float64, len(test.x)) 564 copy(x, test.x) 565 score := normal.ScoreInput(nil, x) 566 if !floats.Equal(x, test.x) { 567 t.Errorf("x modified during call to ScoreInput") 568 } 569 scoreFD := fd.Gradient(nil, normal.LogProb, x, nil) 570 if !floats.EqualApprox(score, scoreFD, 1e-4) { 571 t.Errorf("Case %d: derivative mismatch. Got %v, want %v", cas, score, scoreFD) 572 } 573 } 574 }