gonum.org/v1/gonum@v0.14.0/mat/solve_test.go (about)

     1  // Copyright ©2015 The Gonum Authors. All rights reserved.
     2  // Use of this source code is governed by a BSD-style
     3  // license that can be found in the LICENSE file.
     4  
     5  package mat
     6  
     7  import (
     8  	"testing"
     9  
    10  	"golang.org/x/exp/rand"
    11  )
    12  
    13  func TestSolve(t *testing.T) {
    14  	t.Parallel()
    15  	rnd := rand.New(rand.NewSource(1))
    16  	// Hand-coded cases.
    17  	for _, test := range []struct {
    18  		a         [][]float64
    19  		b         [][]float64
    20  		ans       [][]float64
    21  		shouldErr bool
    22  	}{
    23  		{
    24  			a:         [][]float64{{6}},
    25  			b:         [][]float64{{3}},
    26  			ans:       [][]float64{{0.5}},
    27  			shouldErr: false,
    28  		},
    29  		{
    30  			a: [][]float64{
    31  				{1, 0, 0},
    32  				{0, 1, 0},
    33  				{0, 0, 1},
    34  			},
    35  			b: [][]float64{
    36  				{3},
    37  				{2},
    38  				{1},
    39  			},
    40  			ans: [][]float64{
    41  				{3},
    42  				{2},
    43  				{1},
    44  			},
    45  			shouldErr: false,
    46  		},
    47  		{
    48  			a: [][]float64{
    49  				{0.8147, 0.9134, 0.5528},
    50  				{0.9058, 0.6324, 0.8723},
    51  				{0.1270, 0.0975, 0.7612},
    52  			},
    53  			b: [][]float64{
    54  				{0.278},
    55  				{0.547},
    56  				{0.958},
    57  			},
    58  			ans: [][]float64{
    59  				{-0.932687281002860},
    60  				{0.303963920182067},
    61  				{1.375216503507109},
    62  			},
    63  			shouldErr: false,
    64  		},
    65  		{
    66  			a: [][]float64{
    67  				{0.8147, 0.9134, 0.5528},
    68  				{0.9058, 0.6324, 0.8723},
    69  			},
    70  			b: [][]float64{
    71  				{0.278},
    72  				{0.547},
    73  			},
    74  			ans: [][]float64{
    75  				{0.25919787248965376},
    76  				{-0.25560256266441034},
    77  				{0.5432324059702451},
    78  			},
    79  			shouldErr: false,
    80  		},
    81  		{
    82  			a: [][]float64{
    83  				{0.8147, 0.9134, 0.9},
    84  				{0.9058, 0.6324, 0.9},
    85  				{0.1270, 0.0975, 0.1},
    86  				{1.6, 2.8, -3.5},
    87  			},
    88  			b: [][]float64{
    89  				{0.278},
    90  				{0.547},
    91  				{-0.958},
    92  				{1.452},
    93  			},
    94  			ans: [][]float64{
    95  				{0.820970340787782},
    96  				{-0.218604626527306},
    97  				{-0.212938815234215},
    98  			},
    99  			shouldErr: false,
   100  		},
   101  		{
   102  			a: [][]float64{
   103  				{0.8147, 0.9134, 0.231, -1.65},
   104  				{0.9058, 0.6324, 0.9, 0.72},
   105  				{0.1270, 0.0975, 0.1, 1.723},
   106  				{1.6, 2.8, -3.5, 0.987},
   107  				{7.231, 9.154, 1.823, 0.9},
   108  			},
   109  			b: [][]float64{
   110  				{0.278, 8.635},
   111  				{0.547, 9.125},
   112  				{-0.958, -0.762},
   113  				{1.452, 1.444},
   114  				{1.999, -7.234},
   115  			},
   116  			ans: [][]float64{
   117  				{1.863006789511373, 44.467887791812750},
   118  				{-1.127270935407224, -34.073794226035126},
   119  				{-0.527926457947330, -8.032133759788573},
   120  				{-0.248621916204897, -2.366366415805275},
   121  			},
   122  			shouldErr: false,
   123  		},
   124  		{
   125  			a: [][]float64{
   126  				{0, 0},
   127  				{0, 0},
   128  			},
   129  			b: [][]float64{
   130  				{3},
   131  				{2},
   132  			},
   133  			ans:       nil,
   134  			shouldErr: true,
   135  		},
   136  		{
   137  			a: [][]float64{
   138  				{0, 0},
   139  				{0, 0},
   140  				{0, 0},
   141  			},
   142  			b: [][]float64{
   143  				{3},
   144  				{2},
   145  				{1},
   146  			},
   147  			ans:       nil,
   148  			shouldErr: true,
   149  		},
   150  		{
   151  			a: [][]float64{
   152  				{0, 0, 0},
   153  				{0, 0, 0},
   154  			},
   155  			b: [][]float64{
   156  				{3},
   157  				{2},
   158  			},
   159  			ans:       nil,
   160  			shouldErr: true,
   161  		},
   162  	} {
   163  		a := NewDense(flatten(test.a))
   164  		b := NewDense(flatten(test.b))
   165  
   166  		var ans *Dense
   167  		if test.ans != nil {
   168  			ans = NewDense(flatten(test.ans))
   169  		}
   170  
   171  		var x Dense
   172  		err := x.Solve(a, b)
   173  		if err != nil {
   174  			if !test.shouldErr {
   175  				t.Errorf("Unexpected solve error: %s", err)
   176  			}
   177  			continue
   178  		}
   179  		if err == nil && test.shouldErr {
   180  			t.Errorf("Did not error during solve.")
   181  			continue
   182  		}
   183  		if !EqualApprox(&x, ans, 1e-12) {
   184  			t.Errorf("Solve answer mismatch. Want %v, got %v", ans, x)
   185  		}
   186  	}
   187  
   188  	// Random Cases.
   189  	for _, test := range []struct {
   190  		m, n, bc int
   191  	}{
   192  		{5, 5, 1},
   193  		{5, 10, 1},
   194  		{10, 5, 1},
   195  		{5, 5, 7},
   196  		{5, 10, 7},
   197  		{10, 5, 7},
   198  		{5, 5, 12},
   199  		{5, 10, 12},
   200  		{10, 5, 12},
   201  	} {
   202  		m := test.m
   203  		n := test.n
   204  		bc := test.bc
   205  		a := NewDense(m, n, nil)
   206  		for i := 0; i < m; i++ {
   207  			for j := 0; j < n; j++ {
   208  				a.Set(i, j, rnd.Float64())
   209  			}
   210  		}
   211  		br := m
   212  		b := NewDense(br, bc, nil)
   213  		for i := 0; i < br; i++ {
   214  			for j := 0; j < bc; j++ {
   215  				b.Set(i, j, rnd.Float64())
   216  			}
   217  		}
   218  		var x Dense
   219  		err := x.Solve(a, b)
   220  		if err != nil {
   221  			t.Errorf("unexpected error from dense solve: %v", err)
   222  		}
   223  
   224  		// Test that the normal equations hold.
   225  		// Aᵀ * A * x = Aᵀ * b
   226  		var tmp, lhs, rhs Dense
   227  		tmp.Mul(a.T(), a)
   228  		lhs.Mul(&tmp, &x)
   229  		rhs.Mul(a.T(), b)
   230  		if !EqualApprox(&lhs, &rhs, 1e-10) {
   231  			t.Errorf("Normal equations do not hold.\nLHS: %v\n, RHS: %v\n", lhs, rhs)
   232  		}
   233  	}
   234  
   235  	// Use testTwoInput.
   236  	method := func(receiver, a, b Matrix) {
   237  		type Solver interface {
   238  			Solve(a, b Matrix) error
   239  		}
   240  		rd := receiver.(Solver)
   241  		_ = rd.Solve(a, b)
   242  	}
   243  	denseComparison := func(receiver, a, b *Dense) {
   244  		_ = receiver.Solve(a, b)
   245  	}
   246  	testTwoInput(t, "Solve", &Dense{}, method, denseComparison, legalTypesAll, legalSizeSolve, 1e-7)
   247  }
   248  
   249  func TestSolveVec(t *testing.T) {
   250  	t.Parallel()
   251  	rnd := rand.New(rand.NewSource(1))
   252  	for _, test := range []struct {
   253  		m, n int
   254  	}{
   255  		{5, 5},
   256  		{5, 10},
   257  		{10, 5},
   258  		{5, 5},
   259  		{5, 10},
   260  		{10, 5},
   261  		{5, 5},
   262  		{5, 10},
   263  		{10, 5},
   264  	} {
   265  		m := test.m
   266  		n := test.n
   267  		a := NewDense(m, n, nil)
   268  		for i := 0; i < m; i++ {
   269  			for j := 0; j < n; j++ {
   270  				a.Set(i, j, rnd.Float64())
   271  			}
   272  		}
   273  		br := m
   274  		b := NewVecDense(br, nil)
   275  		for i := 0; i < br; i++ {
   276  			b.SetVec(i, rnd.Float64())
   277  		}
   278  		var x VecDense
   279  		err := x.SolveVec(a, b)
   280  		if err != nil {
   281  			t.Errorf("unexpected error from dense vector solve: %v", err)
   282  		}
   283  
   284  		// Test that the normal equations hold.
   285  		// Aᵀ * A * x = Aᵀ * b
   286  		var tmp, lhs, rhs Dense
   287  		tmp.Mul(a.T(), a)
   288  		lhs.Mul(&tmp, &x)
   289  		rhs.Mul(a.T(), b)
   290  		if !EqualApprox(&lhs, &rhs, 1e-10) {
   291  			t.Errorf("Normal equations do not hold.\nLHS: %v\n, RHS: %v\n", lhs, rhs)
   292  		}
   293  	}
   294  
   295  	// Use testTwoInput
   296  	method := func(receiver, a, b Matrix) {
   297  		type SolveVecer interface {
   298  			SolveVec(a Matrix, b Vector) error
   299  		}
   300  		rd := receiver.(SolveVecer)
   301  		_ = rd.SolveVec(a, b.(Vector))
   302  	}
   303  	denseComparison := func(receiver, a, b *Dense) {
   304  		_ = receiver.Solve(a, b)
   305  	}
   306  	testTwoInput(t, "SolveVec", &VecDense{}, method, denseComparison, legalTypesMatrixVector, legalSizeSolve, 1e-12)
   307  }