github.com/gonum/matrix@v0.0.0-20181209220409-c518dec07be9/mat64/triangular_test.go (about) 1 package mat64 2 3 import ( 4 "math" 5 "math/rand" 6 "reflect" 7 "testing" 8 9 "github.com/gonum/blas" 10 "github.com/gonum/blas/blas64" 11 "github.com/gonum/matrix" 12 ) 13 14 func TestNewTriangular(t *testing.T) { 15 for i, test := range []struct { 16 data []float64 17 n int 18 kind matrix.TriKind 19 mat *TriDense 20 }{ 21 { 22 data: []float64{ 23 1, 2, 3, 24 4, 5, 6, 25 7, 8, 9, 26 }, 27 n: 3, 28 kind: matrix.Upper, 29 mat: &TriDense{ 30 mat: blas64.Triangular{ 31 N: 3, 32 Stride: 3, 33 Uplo: blas.Upper, 34 Data: []float64{1, 2, 3, 4, 5, 6, 7, 8, 9}, 35 Diag: blas.NonUnit, 36 }, 37 cap: 3, 38 }, 39 }, 40 } { 41 tri := NewTriDense(test.n, test.kind, test.data) 42 rows, cols := tri.Dims() 43 44 if rows != test.n { 45 t.Errorf("unexpected number of rows for test %d: got: %d want: %d", i, rows, test.n) 46 } 47 if cols != test.n { 48 t.Errorf("unexpected number of cols for test %d: got: %d want: %d", i, cols, test.n) 49 } 50 if !reflect.DeepEqual(tri, test.mat) { 51 t.Errorf("unexpected data slice for test %d: got: %v want: %v", i, tri, test.mat) 52 } 53 } 54 55 for _, kind := range []matrix.TriKind{matrix.Lower, matrix.Upper} { 56 panicked, message := panics(func() { NewTriDense(3, kind, []float64{1, 2}) }) 57 if !panicked || message != matrix.ErrShape.Error() { 58 t.Errorf("expected panic for invalid data slice length for upper=%t", kind) 59 } 60 } 61 } 62 63 func TestTriAtSet(t *testing.T) { 64 tri := &TriDense{ 65 mat: blas64.Triangular{ 66 N: 3, 67 Stride: 3, 68 Uplo: blas.Upper, 69 Data: []float64{1, 2, 3, 4, 5, 6, 7, 8, 9}, 70 Diag: blas.NonUnit, 71 }, 72 cap: 3, 73 } 74 75 rows, cols := tri.Dims() 76 77 // Check At out of bounds 78 for _, row := range []int{-1, rows, rows + 1} { 79 panicked, message := panics(func() { tri.At(row, 0) }) 80 if !panicked || message != matrix.ErrRowAccess.Error() { 81 t.Errorf("expected panic for invalid row access N=%d r=%d", rows, row) 82 } 83 } 84 for _, col := range []int{-1, cols, cols + 1} { 85 panicked, message := panics(func() { tri.At(0, col) }) 86 if !panicked || message != matrix.ErrColAccess.Error() { 87 t.Errorf("expected panic for invalid column access N=%d c=%d", cols, col) 88 } 89 } 90 91 // Check Set out of bounds 92 for _, row := range []int{-1, rows, rows + 1} { 93 panicked, message := panics(func() { tri.SetTri(row, 0, 1.2) }) 94 if !panicked || message != matrix.ErrRowAccess.Error() { 95 t.Errorf("expected panic for invalid row access N=%d r=%d", rows, row) 96 } 97 } 98 for _, col := range []int{-1, cols, cols + 1} { 99 panicked, message := panics(func() { tri.SetTri(0, col, 1.2) }) 100 if !panicked || message != matrix.ErrColAccess.Error() { 101 t.Errorf("expected panic for invalid column access N=%d c=%d", cols, col) 102 } 103 } 104 105 for _, st := range []struct { 106 row, col int 107 uplo blas.Uplo 108 }{ 109 {row: 2, col: 1, uplo: blas.Upper}, 110 {row: 1, col: 2, uplo: blas.Lower}, 111 } { 112 tri.mat.Uplo = st.uplo 113 panicked, message := panics(func() { tri.SetTri(st.row, st.col, 1.2) }) 114 if !panicked || message != matrix.ErrTriangleSet.Error() { 115 t.Errorf("expected panic for %+v", st) 116 } 117 } 118 119 for _, st := range []struct { 120 row, col int 121 uplo blas.Uplo 122 orig, new float64 123 }{ 124 {row: 2, col: 1, uplo: blas.Lower, orig: 8, new: 15}, 125 {row: 1, col: 2, uplo: blas.Upper, orig: 6, new: 15}, 126 } { 127 tri.mat.Uplo = st.uplo 128 if e := tri.At(st.row, st.col); e != st.orig { 129 t.Errorf("unexpected value for At(%d, %d): got: %v want: %v", st.row, st.col, e, st.orig) 130 } 131 tri.SetTri(st.row, st.col, st.new) 132 if e := tri.At(st.row, st.col); e != st.new { 133 t.Errorf("unexpected value for At(%d, %d) after SetTri(%[1]d, %d, %v): got: %v want: %[3]v", st.row, st.col, st.new, e) 134 } 135 } 136 } 137 138 func TestTriDenseCopy(t *testing.T) { 139 for i := 0; i < 100; i++ { 140 size := rand.Intn(100) 141 r, err := randDense(size, 0.9, rand.NormFloat64) 142 if size == 0 { 143 if err != matrix.ErrZeroLength { 144 t.Fatalf("expected error %v: got: %v", matrix.ErrZeroLength, err) 145 } 146 continue 147 } 148 if err != nil { 149 t.Fatalf("unexpected error: %v", err) 150 } 151 152 u := NewTriDense(size, true, nil) 153 l := NewTriDense(size, false, nil) 154 155 for _, typ := range []Matrix{r, (*basicMatrix)(r)} { 156 for j := range u.mat.Data { 157 u.mat.Data[j] = math.NaN() 158 l.mat.Data[j] = math.NaN() 159 } 160 u.Copy(typ) 161 l.Copy(typ) 162 for m := 0; m < size; m++ { 163 for n := 0; n < size; n++ { 164 want := typ.At(m, n) 165 switch { 166 case m < n: // Upper triangular matrix. 167 if got := u.At(m, n); got != want { 168 t.Errorf("unexpected upper value for At(%d, %d) for test %d: got: %v want: %v", m, n, i, got, want) 169 } 170 case m == n: // Diagonal matrix. 171 if got := u.At(m, n); got != want { 172 t.Errorf("unexpected upper value for At(%d, %d) for test %d: got: %v want: %v", m, n, i, got, want) 173 } 174 if got := l.At(m, n); got != want { 175 t.Errorf("unexpected diagonal value for At(%d, %d) for test %d: got: %v want: %v", m, n, i, got, want) 176 } 177 case m < n: // Lower triangular matrix. 178 if got := l.At(m, n); got != want { 179 t.Errorf("unexpected lower value for At(%d, %d) for test %d: got: %v want: %v", m, n, i, got, want) 180 } 181 } 182 } 183 } 184 } 185 } 186 } 187 188 func TestTriTriDenseCopy(t *testing.T) { 189 for i := 0; i < 100; i++ { 190 size := rand.Intn(100) 191 r, err := randDense(size, 1, rand.NormFloat64) 192 if size == 0 { 193 if err != matrix.ErrZeroLength { 194 t.Fatalf("expected error %v: got: %v", matrix.ErrZeroLength, err) 195 } 196 continue 197 } 198 if err != nil { 199 t.Fatalf("unexpected error: %v", err) 200 } 201 202 ur := NewTriDense(size, true, nil) 203 lr := NewTriDense(size, false, nil) 204 205 ur.Copy(r) 206 lr.Copy(r) 207 208 u := NewTriDense(size, true, nil) 209 u.Copy(ur) 210 if !equal(u, ur) { 211 t.Fatal("unexpected result for U triangle copy of U triangle: not equal") 212 } 213 214 l := NewTriDense(size, false, nil) 215 l.Copy(lr) 216 if !equal(l, lr) { 217 t.Fatal("unexpected result for L triangle copy of L triangle: not equal") 218 } 219 220 zero(u.mat.Data) 221 u.Copy(lr) 222 if !isDiagonal(u) { 223 t.Fatal("unexpected result for U triangle copy of L triangle: off diagonal non-zero element") 224 } 225 if !equalDiagonal(u, lr) { 226 t.Fatal("unexpected result for U triangle copy of L triangle: diagonal not equal") 227 } 228 229 zero(l.mat.Data) 230 l.Copy(ur) 231 if !isDiagonal(l) { 232 t.Fatal("unexpected result for L triangle copy of U triangle: off diagonal non-zero element") 233 } 234 if !equalDiagonal(l, ur) { 235 t.Fatal("unexpected result for L triangle copy of U triangle: diagonal not equal") 236 } 237 } 238 } 239 240 func TestTriInverse(t *testing.T) { 241 for _, kind := range []matrix.TriKind{matrix.Upper, matrix.Lower} { 242 for _, n := range []int{1, 3, 5, 9} { 243 data := make([]float64, n*n) 244 for i := range data { 245 data[i] = rand.NormFloat64() 246 } 247 a := NewTriDense(n, kind, data) 248 var tr TriDense 249 err := tr.InverseTri(a) 250 if err != nil { 251 t.Errorf("Bad test: %s", err) 252 } 253 var d Dense 254 d.Mul(a, &tr) 255 if !equalApprox(eye(n), &d, 1e-8, false) { 256 var diff Dense 257 diff.Sub(eye(n), &d) 258 t.Errorf("Tri times inverse is not identity. Norm of difference: %v", Norm(&diff, 2)) 259 } 260 } 261 } 262 } 263 264 func TestTriMul(t *testing.T) { 265 method := func(receiver, a, b Matrix) { 266 type MulTrier interface { 267 MulTri(a, b Triangular) 268 } 269 receiver.(MulTrier).MulTri(a.(Triangular), b.(Triangular)) 270 } 271 denseComparison := func(receiver, a, b *Dense) { 272 receiver.Mul(a, b) 273 } 274 legalSizeTriMul := func(ar, ac, br, bc int) bool { 275 // Need both to be square and the sizes to be the same 276 return ar == ac && br == bc && ar == br 277 } 278 279 // The legal types are triangles with the same TriKind. 280 // legalTypesTri returns whether both input arguments are Triangular. 281 legalTypes := func(a, b Matrix) bool { 282 at, ok := a.(Triangular) 283 if !ok { 284 return false 285 } 286 bt, ok := b.(Triangular) 287 if !ok { 288 return false 289 } 290 _, ak := at.Triangle() 291 _, bk := bt.Triangle() 292 return ak == bk 293 } 294 legalTypesLower := func(a, b Matrix) bool { 295 legal := legalTypes(a, b) 296 if !legal { 297 return false 298 } 299 _, kind := a.(Triangular).Triangle() 300 r := kind == matrix.Lower 301 return r 302 } 303 receiver := NewTriDense(3, matrix.Lower, nil) 304 testTwoInput(t, "TriMul", receiver, method, denseComparison, legalTypesLower, legalSizeTriMul, 1e-14) 305 306 legalTypesUpper := func(a, b Matrix) bool { 307 legal := legalTypes(a, b) 308 if !legal { 309 return false 310 } 311 _, kind := a.(Triangular).Triangle() 312 r := kind == matrix.Upper 313 return r 314 } 315 receiver = NewTriDense(3, matrix.Upper, nil) 316 testTwoInput(t, "TriMul", receiver, method, denseComparison, legalTypesUpper, legalSizeTriMul, 1e-14) 317 }