gonum.org/v1/gonum@v0.15.1-0.20240517103525-f853624cb1bb/stat/distmv/normal_test.go (about) 1 // Copyright ©2015 The Gonum Authors. All rights reserved. 2 // Use of this source code is governed by a BSD-style 3 // license that can be found in the LICENSE file. 4 5 package distmv 6 7 import ( 8 "fmt" 9 "math" 10 "testing" 11 12 "golang.org/x/exp/rand" 13 14 "gonum.org/v1/gonum/diff/fd" 15 "gonum.org/v1/gonum/floats" 16 "gonum.org/v1/gonum/mat" 17 "gonum.org/v1/gonum/stat" 18 ) 19 20 func TestNormProbs(t *testing.T) { 21 dist1, ok := NewNormal([]float64{0, 0}, mat.NewSymDense(2, []float64{1, 0, 0, 1}), nil) 22 if !ok { 23 t.Errorf("bad test") 24 } 25 dist2, ok := NewNormal([]float64{6, 7}, mat.NewSymDense(2, []float64{8, 2, 0, 4}), nil) 26 if !ok { 27 t.Errorf("bad test") 28 } 29 testProbability(t, []probCase{ 30 { 31 dist: dist1, 32 loc: []float64{0, 0}, 33 logProb: -1.837877066409345, 34 }, 35 { 36 dist: dist2, 37 loc: []float64{6, 7}, 38 logProb: -3.503979321496947, 39 }, 40 { 41 dist: dist2, 42 loc: []float64{1, 2}, 43 logProb: -7.075407892925519, 44 }, 45 }) 46 } 47 48 func TestNewNormalChol(t *testing.T) { 49 for _, test := range []struct { 50 mean []float64 51 cov *mat.SymDense 52 }{ 53 { 54 mean: []float64{2, 3}, 55 cov: mat.NewSymDense(2, []float64{1, 0.1, 0.1, 1}), 56 }, 57 } { 58 var chol mat.Cholesky 59 ok := chol.Factorize(test.cov) 60 if !ok { 61 panic("bad test") 62 } 63 n := NewNormalChol(test.mean, &chol, nil) 64 // Generate a random number and calculate probability to ensure things 65 // have been set properly. See issue #426. 66 x := n.Rand(nil) 67 _ = n.Prob(x) 68 } 69 } 70 71 func TestNormRand(t *testing.T) { 72 for _, test := range []struct { 73 mean []float64 74 cov []float64 75 }{ 76 { 77 mean: []float64{0, 0}, 78 cov: []float64{ 79 1, 0, 80 0, 1, 81 }, 82 }, 83 { 84 mean: []float64{0, 0}, 85 cov: []float64{ 86 1, 0.9, 87 0.9, 1, 88 }, 89 }, 90 { 91 mean: []float64{6, 7}, 92 cov: []float64{ 93 5, 0.9, 94 0.9, 2, 95 }, 96 }, 97 } { 98 dim := len(test.mean) 99 cov := mat.NewSymDense(dim, test.cov) 100 n, ok := NewNormal(test.mean, cov, nil) 101 if !ok { 102 t.Errorf("bad covariance matrix") 103 } 104 105 nSamples := 1000000 106 samps := mat.NewDense(nSamples, dim, nil) 107 for i := 0; i < nSamples; i++ { 108 n.Rand(samps.RawRowView(i)) 109 } 110 estMean := make([]float64, dim) 111 for i := range estMean { 112 estMean[i] = stat.Mean(mat.Col(nil, i, samps), nil) 113 } 114 if !floats.EqualApprox(estMean, test.mean, 1e-2) { 115 t.Errorf("Mean mismatch: want: %v, got %v", test.mean, estMean) 116 } 117 var estCov mat.SymDense 118 stat.CovarianceMatrix(&estCov, samps, nil) 119 if !mat.EqualApprox(&estCov, cov, 1e-2) { 120 t.Errorf("Cov mismatch: want: %v, got %v", cov, &estCov) 121 } 122 } 123 } 124 125 func TestNormalQuantile(t *testing.T) { 126 for _, test := range []struct { 127 mean []float64 128 cov []float64 129 }{ 130 { 131 mean: []float64{6, 7}, 132 cov: []float64{ 133 5, 0.9, 134 0.9, 2, 135 }, 136 }, 137 } { 138 dim := len(test.mean) 139 cov := mat.NewSymDense(dim, test.cov) 140 n, ok := NewNormal(test.mean, cov, nil) 141 if !ok { 142 t.Errorf("bad covariance matrix") 143 } 144 145 nSamples := 1000000 146 rnd := rand.New(rand.NewSource(1)) 147 samps := mat.NewDense(nSamples, dim, nil) 148 tmp := make([]float64, dim) 149 for i := 0; i < nSamples; i++ { 150 for j := range tmp { 151 tmp[j] = rnd.Float64() 152 } 153 n.Quantile(samps.RawRowView(i), tmp) 154 } 155 estMean := make([]float64, dim) 156 for i := range estMean { 157 estMean[i] = stat.Mean(mat.Col(nil, i, samps), nil) 158 } 159 if !floats.EqualApprox(estMean, test.mean, 1e-2) { 160 t.Errorf("Mean mismatch: want: %v, got %v", test.mean, estMean) 161 } 162 var estCov mat.SymDense 163 stat.CovarianceMatrix(&estCov, samps, nil) 164 if !mat.EqualApprox(&estCov, cov, 1e-2) { 165 t.Errorf("Cov mismatch: want: %v, got %v", cov, &estCov) 166 } 167 } 168 } 169 170 func TestConditionNormal(t *testing.T) { 171 // Uncorrelated values shouldn't influence the updated values. 172 for _, test := range []struct { 173 mu []float64 174 sigma *mat.SymDense 175 observed []int 176 values []float64 177 178 newMu []float64 179 newSigma *mat.SymDense 180 }{ 181 { 182 mu: []float64{2, 3}, 183 sigma: mat.NewSymDense(2, []float64{2, 0, 0, 5}), 184 observed: []int{0}, 185 values: []float64{10}, 186 187 newMu: []float64{3}, 188 newSigma: mat.NewSymDense(1, []float64{5}), 189 }, 190 { 191 mu: []float64{2, 3}, 192 sigma: mat.NewSymDense(2, []float64{2, 0, 0, 5}), 193 observed: []int{1}, 194 values: []float64{10}, 195 196 newMu: []float64{2}, 197 newSigma: mat.NewSymDense(1, []float64{2}), 198 }, 199 { 200 mu: []float64{2, 3, 4}, 201 sigma: mat.NewSymDense(3, []float64{2, 0, 0, 0, 5, 0, 0, 0, 10}), 202 observed: []int{1}, 203 values: []float64{10}, 204 205 newMu: []float64{2, 4}, 206 newSigma: mat.NewSymDense(2, []float64{2, 0, 0, 10}), 207 }, 208 { 209 mu: []float64{2, 3, 4}, 210 sigma: mat.NewSymDense(3, []float64{2, 0, 0, 0, 5, 0, 0, 0, 10}), 211 observed: []int{0, 1}, 212 values: []float64{10, 15}, 213 214 newMu: []float64{4}, 215 newSigma: mat.NewSymDense(1, []float64{10}), 216 }, 217 { 218 mu: []float64{2, 3, 4, 5}, 219 sigma: mat.NewSymDense(4, []float64{2, 0.5, 0, 0, 0.5, 5, 0, 0, 0, 0, 10, 2, 0, 0, 2, 3}), 220 observed: []int{0, 1}, 221 values: []float64{10, 15}, 222 223 newMu: []float64{4, 5}, 224 newSigma: mat.NewSymDense(2, []float64{10, 2, 2, 3}), 225 }, 226 } { 227 normal, ok := NewNormal(test.mu, test.sigma, nil) 228 if !ok { 229 t.Fatalf("Bad test, original sigma not positive definite") 230 } 231 newNormal, ok := normal.ConditionNormal(test.observed, test.values, nil) 232 if !ok { 233 t.Fatalf("Bad test, update failure") 234 } 235 236 if !floats.EqualApprox(test.newMu, newNormal.mu, 1e-12) { 237 t.Errorf("Updated mean mismatch. Want %v, got %v.", test.newMu, newNormal.mu) 238 } 239 240 var sigma mat.SymDense 241 newNormal.chol.ToSym(&sigma) 242 if !mat.EqualApprox(test.newSigma, &sigma, 1e-12) { 243 t.Errorf("Updated sigma mismatch\n.Want:\n% v\nGot:\n% v\n", test.newSigma, sigma) 244 } 245 } 246 247 // Test bivariate case where the update rule is analytic 248 for _, test := range []struct { 249 mu []float64 250 std []float64 251 rho float64 252 value float64 253 }{ 254 { 255 mu: []float64{2, 3}, 256 std: []float64{3, 5}, 257 rho: 0.9, 258 value: 1000, 259 }, 260 { 261 mu: []float64{2, 3}, 262 std: []float64{3, 5}, 263 rho: -0.9, 264 value: 1000, 265 }, 266 } { 267 std := test.std 268 rho := test.rho 269 sigma := mat.NewSymDense(2, []float64{std[0] * std[0], std[0] * std[1] * rho, std[0] * std[1] * rho, std[1] * std[1]}) 270 normal, ok := NewNormal(test.mu, sigma, nil) 271 if !ok { 272 t.Fatalf("Bad test, original sigma not positive definite") 273 } 274 newNormal, ok := normal.ConditionNormal([]int{1}, []float64{test.value}, nil) 275 if !ok { 276 t.Fatalf("Bad test, update failed") 277 } 278 var newSigma mat.SymDense 279 newNormal.chol.ToSym(&newSigma) 280 trueMean := test.mu[0] + rho*(std[0]/std[1])*(test.value-test.mu[1]) 281 if math.Abs(trueMean-newNormal.mu[0]) > 1e-14 { 282 t.Errorf("Mean mismatch. Want %v, got %v", trueMean, newNormal.mu[0]) 283 } 284 trueVar := (1 - rho*rho) * std[0] * std[0] 285 if math.Abs(trueVar-newSigma.At(0, 0)) > 1e-14 { 286 t.Errorf("Std mismatch. Want %v, got %v", trueMean, newNormal.mu[0]) 287 } 288 } 289 290 // Test via sampling. 291 for _, test := range []struct { 292 mu []float64 293 sigma *mat.SymDense 294 observed []int 295 unobserved []int 296 value []float64 297 }{ 298 // The indices in unobserved must be in ascending order for this test. 299 { 300 mu: []float64{2, 3, 4}, 301 sigma: mat.NewSymDense(3, []float64{2, 0.5, 3, 0.5, 1, 0.6, 3, 0.6, 10}), 302 303 observed: []int{0}, 304 unobserved: []int{1, 2}, 305 value: []float64{1.9}, 306 }, 307 { 308 mu: []float64{2, 3, 4, 5}, 309 sigma: mat.NewSymDense(4, []float64{2, 0.5, 3, 0.1, 0.5, 1, 0.6, 0.2, 3, 0.6, 10, 0.3, 0.1, 0.2, 0.3, 3}), 310 311 observed: []int{0, 3}, 312 unobserved: []int{1, 2}, 313 value: []float64{1.9, 2.9}, 314 }, 315 } { 316 totalSamp := 4000000 317 var nSamp int 318 samples := mat.NewDense(totalSamp, len(test.mu), nil) 319 normal, ok := NewNormal(test.mu, test.sigma, nil) 320 if !ok { 321 t.Errorf("bad test") 322 } 323 sample := make([]float64, len(test.mu)) 324 for i := 0; i < totalSamp; i++ { 325 normal.Rand(sample) 326 isClose := true 327 for i, v := range test.observed { 328 if math.Abs(sample[v]-test.value[i]) > 1e-1 { 329 isClose = false 330 break 331 } 332 } 333 if isClose { 334 samples.SetRow(nSamp, sample) 335 nSamp++ 336 } 337 } 338 339 if nSamp < 100 { 340 t.Errorf("bad test, not enough samples") 341 continue 342 } 343 samples = samples.Slice(0, nSamp, 0, len(test.mu)).(*mat.Dense) 344 345 // Compute mean and covariance matrix. 346 estMean := make([]float64, len(test.mu)) 347 for i := range estMean { 348 estMean[i] = stat.Mean(mat.Col(nil, i, samples), nil) 349 } 350 var estCov mat.SymDense 351 stat.CovarianceMatrix(&estCov, samples, nil) 352 353 // Compute update rule. 354 newNormal, ok := normal.ConditionNormal(test.observed, test.value, nil) 355 if !ok { 356 t.Fatalf("Bad test, update failure") 357 } 358 359 var subEstMean []float64 360 for _, v := range test.unobserved { 361 362 subEstMean = append(subEstMean, estMean[v]) 363 } 364 subEstCov := mat.NewSymDense(len(test.unobserved), nil) 365 for i := 0; i < len(test.unobserved); i++ { 366 for j := i; j < len(test.unobserved); j++ { 367 subEstCov.SetSym(i, j, estCov.At(test.unobserved[i], test.unobserved[j])) 368 } 369 } 370 371 for i, v := range subEstMean { 372 if math.Abs(newNormal.mu[i]-v) > 5e-2 { 373 t.Errorf("Mean mismatch. Want %v, got %v.", newNormal.mu[i], v) 374 } 375 } 376 var sigma mat.SymDense 377 newNormal.chol.ToSym(&sigma) 378 if !mat.EqualApprox(&sigma, subEstCov, 1e-1) { 379 t.Errorf("Covariance mismatch. Want:\n%0.8v\nGot:\n%0.8v\n", subEstCov, sigma) 380 } 381 } 382 } 383 384 func TestCovarianceMatrix(t *testing.T) { 385 for _, test := range []struct { 386 mu []float64 387 sigma *mat.SymDense 388 }{ 389 { 390 mu: []float64{2, 3, 4}, 391 sigma: mat.NewSymDense(3, []float64{1, 0.5, 3, 0.5, 8, -1, 3, -1, 15}), 392 }, 393 } { 394 normal, ok := NewNormal(test.mu, test.sigma, nil) 395 if !ok { 396 t.Fatalf("Bad test, covariance matrix not positive definite") 397 } 398 var cov mat.SymDense 399 normal.CovarianceMatrix(&cov) 400 if !mat.EqualApprox(&cov, test.sigma, 1e-14) { 401 t.Errorf("Covariance mismatch with nil input") 402 } 403 dim := test.sigma.SymmetricDim() 404 cov = *mat.NewSymDense(dim, nil) 405 normal.CovarianceMatrix(&cov) 406 if !mat.EqualApprox(&cov, test.sigma, 1e-14) { 407 t.Errorf("Covariance mismatch with supplied input") 408 } 409 } 410 } 411 412 func TestMarginal(t *testing.T) { 413 for _, test := range []struct { 414 mu []float64 415 sigma *mat.SymDense 416 marginal []int 417 }{ 418 { 419 mu: []float64{2, 3, 4}, 420 sigma: mat.NewSymDense(3, []float64{2, 0.5, 3, 0.5, 1, 0.6, 3, 0.6, 10}), 421 marginal: []int{0}, 422 }, 423 { 424 mu: []float64{2, 3, 4}, 425 sigma: mat.NewSymDense(3, []float64{2, 0.5, 3, 0.5, 1, 0.6, 3, 0.6, 10}), 426 marginal: []int{0, 2}, 427 }, 428 { 429 mu: []float64{2, 3, 4, 5}, 430 sigma: mat.NewSymDense(4, []float64{2, 0.5, 3, 0.1, 0.5, 1, 0.6, 0.2, 3, 0.6, 10, 0.3, 0.1, 0.2, 0.3, 3}), 431 432 marginal: []int{0, 3}, 433 }, 434 } { 435 normal, ok := NewNormal(test.mu, test.sigma, nil) 436 if !ok { 437 t.Fatalf("Bad test, covariance matrix not positive definite") 438 } 439 marginal, ok := normal.MarginalNormal(test.marginal, nil) 440 if !ok { 441 t.Fatalf("Bad test, marginal matrix not positive definite") 442 } 443 dim := normal.Dim() 444 nSamples := 1000000 445 samps := mat.NewDense(nSamples, dim, nil) 446 for i := 0; i < nSamples; i++ { 447 normal.Rand(samps.RawRowView(i)) 448 } 449 estMean := make([]float64, dim) 450 for i := range estMean { 451 estMean[i] = stat.Mean(mat.Col(nil, i, samps), nil) 452 } 453 for i, v := range test.marginal { 454 if math.Abs(marginal.mu[i]-estMean[v]) > 1e-2 { 455 t.Errorf("Mean mismatch: want: %v, got %v", estMean[v], marginal.mu[i]) 456 } 457 } 458 459 var marginalCov mat.SymDense 460 marginal.CovarianceMatrix(&marginalCov) 461 var estCov mat.SymDense 462 stat.CovarianceMatrix(&estCov, samps, nil) 463 for i, v1 := range test.marginal { 464 for j, v2 := range test.marginal { 465 c := marginalCov.At(i, j) 466 ec := estCov.At(v1, v2) 467 if math.Abs(c-ec) > 5e-2 { 468 t.Errorf("Cov mismatch element i = %d, j = %d: want: %v, got %v", i, j, c, ec) 469 } 470 } 471 } 472 } 473 } 474 475 func TestMarginalSingle(t *testing.T) { 476 for _, test := range []struct { 477 mu []float64 478 sigma *mat.SymDense 479 }{ 480 { 481 mu: []float64{2, 3, 4}, 482 sigma: mat.NewSymDense(3, []float64{2, 0.5, 3, 0.5, 1, 0.6, 3, 0.6, 10}), 483 }, 484 { 485 mu: []float64{2, 3, 4, 5}, 486 sigma: mat.NewSymDense(4, []float64{2, 0.5, 3, 0.1, 0.5, 1, 0.6, 0.2, 3, 0.6, 10, 0.3, 0.1, 0.2, 0.3, 3}), 487 }, 488 } { 489 normal, ok := NewNormal(test.mu, test.sigma, nil) 490 if !ok { 491 t.Fatalf("Bad test, covariance matrix not positive definite") 492 } 493 for i, mean := range test.mu { 494 norm := normal.MarginalNormalSingle(i, nil) 495 if norm.Mean() != mean { 496 t.Errorf("Mean mismatch nil Sigma, idx %v: want %v, got %v.", i, mean, norm.Mean()) 497 } 498 std := math.Sqrt(test.sigma.At(i, i)) 499 if math.Abs(norm.StdDev()-std) > 1e-14 { 500 t.Errorf("StdDev mismatch nil Sigma, idx %v: want %v, got %v.", i, std, norm.StdDev()) 501 } 502 } 503 } 504 505 // Test matching with TestMarginal. 506 rnd := rand.New(rand.NewSource(1)) 507 for cas := 0; cas < 10; cas++ { 508 dim := rnd.Intn(10) + 1 509 mu := make([]float64, dim) 510 for i := range mu { 511 mu[i] = rnd.Float64() 512 } 513 x := make([]float64, dim*dim) 514 for i := range x { 515 x[i] = rnd.Float64() 516 } 517 matrix := mat.NewDense(dim, dim, x) 518 var sigma mat.SymDense 519 sigma.SymOuterK(1, matrix) 520 521 normal, ok := NewNormal(mu, &sigma, nil) 522 if !ok { 523 t.Fatal("bad test") 524 } 525 for i := 0; i < dim; i++ { 526 single := normal.MarginalNormalSingle(i, nil) 527 mult, ok := normal.MarginalNormal([]int{i}, nil) 528 if !ok { 529 t.Fatal("bad test") 530 } 531 if math.Abs(single.Mean()-mult.Mean(nil)[0]) > 1e-14 { 532 t.Errorf("Mean mismatch") 533 } 534 var cov mat.SymDense 535 mult.CovarianceMatrix(&cov) 536 if math.Abs(single.Variance()-cov.At(0, 0)) > 1e-14 { 537 t.Errorf("Variance mismatch") 538 } 539 } 540 } 541 } 542 543 func TestNormalScoreInput(t *testing.T) { 544 for cas, test := range []struct { 545 mu []float64 546 sigma *mat.SymDense 547 x []float64 548 }{ 549 { 550 mu: []float64{2, 3, 4}, 551 sigma: mat.NewSymDense(3, []float64{2, 0.5, 3, 0.5, 1, 0.6, 3, 0.6, 10}), 552 x: []float64{1, 3.1, -2}, 553 }, 554 { 555 mu: []float64{2, 3, 4, 5}, 556 sigma: mat.NewSymDense(4, []float64{2, 0.5, 3, 0.1, 0.5, 1, 0.6, 0.2, 3, 0.6, 10, 0.3, 0.1, 0.2, 0.3, 3}), 557 x: []float64{1, 3.1, -2, 5}, 558 }, 559 } { 560 normal, ok := NewNormal(test.mu, test.sigma, nil) 561 if !ok { 562 t.Fatalf("Bad test, covariance matrix not positive definite") 563 } 564 x := make([]float64, len(test.x)) 565 copy(x, test.x) 566 score := normal.ScoreInput(nil, x) 567 if !floats.Equal(x, test.x) { 568 t.Errorf("x modified during call to ScoreInput") 569 } 570 scoreFD := fd.Gradient(nil, normal.LogProb, x, nil) 571 if !floats.EqualApprox(score, scoreFD, 1e-4) { 572 t.Errorf("Case %d: derivative mismatch. Got %v, want %v", cas, score, scoreFD) 573 } 574 } 575 } 576 577 func TestNormalRandCov(t *testing.T) { 578 const numSamples = 1_000_000 579 const tol = 1e-2 580 581 for cas, test := range []struct { 582 mean []float64 583 cov []float64 584 semi bool 585 }{ 586 { 587 mean: []float64{0, 0}, 588 cov: []float64{ 589 1, 0, 590 0, 1, 591 }, 592 }, 593 { 594 mean: []float64{0, 0}, 595 cov: []float64{ 596 2, 0.9, 597 0.9, 1, 598 }, 599 }, 600 { 601 mean: []float64{6, 7}, 602 cov: []float64{ 603 2, 0.9, 604 0.9, 5, 605 }, 606 }, 607 { 608 mean: []float64{-1, 2, -3}, 609 cov: []float64{ 610 2, 1, 1, 611 1, 3, 0.1, 612 1, 0.1, 5, 613 }, 614 }, 615 { 616 mean: []float64{-1, 2, -3}, 617 cov: []float64{ 618 0, 0, 0, 619 0, 1, 0.1, 620 0, 0.1, 5, 621 }, 622 semi: true, 623 }, 624 } { 625 for _, covType := range []string{"Cholesky", "PivotedCholesky", "SymDense", "EigenSym"} { 626 if covType == "Cholesky" && test.semi { 627 continue 628 } 629 630 name := fmt.Sprintf("case %d with %s", cas, covType) 631 632 n := len(test.mean) 633 a := mat.NewSymDense(n, test.cov) 634 src := rand.NewSource(1) 635 636 var cov mat.Symmetric 637 switch covType { 638 case "Cholesky": 639 var chol mat.Cholesky 640 ok := chol.Factorize(a) 641 if !ok { 642 t.Fatalf("%s: Cholesky factorization failed, input covariance not positive definite?", name) 643 } 644 cov = &chol 645 case "PivotedCholesky": 646 var chol mat.PivotedCholesky 647 chol.Factorize(a, -1) 648 cov = &chol 649 case "SymDense": 650 cov = a 651 case "EigenSym": 652 var ed mat.EigenSym 653 ok := ed.Factorize(a, true) 654 if !ok { 655 t.Fatalf("%s: eigendecomposition failed", name) 656 } 657 cov = NewPositivePartEigenSym(&ed) 658 } 659 660 samples := mat.NewDense(numSamples, n, nil) 661 for i := 0; i < numSamples; i++ { 662 NormalRandCov(samples.RawRowView(i), test.mean, cov, src) 663 } 664 665 estMean := make([]float64, n) 666 for i := range estMean { 667 estMean[i] = stat.Mean(mat.Col(nil, i, samples), nil) 668 } 669 if !floats.EqualApprox(estMean, test.mean, tol) { 670 t.Errorf("%s: mean mismatch\nwant=%v\n got=%v", name, test.mean, estMean) 671 } 672 673 var estCov mat.SymDense 674 stat.CovarianceMatrix(&estCov, samples, nil) 675 if !mat.EqualApprox(&estCov, a, tol) { 676 t.Errorf("%s: covariance mismatch\nwant=%v\n got=%v", name, 677 mat.Formatted(a, mat.Prefix(" ")), mat.Formatted(&estCov, mat.Prefix(" "))) 678 } 679 } 680 } 681 }