github.com/jingcheng-WU/gonum@v0.9.1-0.20210323123734-f1a2a11a8f7b/blas/blas64/conv_symmetric_test.go (about) 1 // Copyright ©2015 The Gonum Authors. All rights reserved. 2 // Use of this source code is governed by a BSD-style 3 // license that can be found in the LICENSE file. 4 5 package blas64 6 7 import ( 8 "math" 9 "testing" 10 11 "github.com/jingcheng-WU/gonum/blas" 12 ) 13 14 func newSymmetricFrom(a SymmetricCols) Symmetric { 15 t := Symmetric{ 16 N: a.N, 17 Stride: a.N, 18 Data: make([]float64, a.N*a.N), 19 Uplo: a.Uplo, 20 } 21 t.From(a) 22 return t 23 } 24 25 func (m Symmetric) n() int { return m.N } 26 func (m Symmetric) at(i, j int) float64 { 27 if m.Uplo == blas.Lower && i < j && j < m.N { 28 i, j = j, i 29 } 30 if m.Uplo == blas.Upper && i > j { 31 i, j = j, i 32 } 33 return m.Data[i*m.Stride+j] 34 } 35 func (m Symmetric) uplo() blas.Uplo { return m.Uplo } 36 37 func newSymmetricColsFrom(a Symmetric) SymmetricCols { 38 t := SymmetricCols{ 39 N: a.N, 40 Stride: a.N, 41 Data: make([]float64, a.N*a.N), 42 Uplo: a.Uplo, 43 } 44 t.From(a) 45 return t 46 } 47 48 func (m SymmetricCols) n() int { return m.N } 49 func (m SymmetricCols) at(i, j int) float64 { 50 if m.Uplo == blas.Lower && i < j { 51 i, j = j, i 52 } 53 if m.Uplo == blas.Upper && i > j && i < m.N { 54 i, j = j, i 55 } 56 return m.Data[i+j*m.Stride] 57 } 58 func (m SymmetricCols) uplo() blas.Uplo { return m.Uplo } 59 60 type symmetric interface { 61 n() int 62 at(i, j int) float64 63 uplo() blas.Uplo 64 } 65 66 func sameSymmetric(a, b symmetric) bool { 67 an := a.n() 68 bn := b.n() 69 if an != bn { 70 return false 71 } 72 if a.uplo() != b.uplo() { 73 return false 74 } 75 for i := 0; i < an; i++ { 76 for j := 0; j < an; j++ { 77 if a.at(i, j) != b.at(i, j) || math.IsNaN(a.at(i, j)) != math.IsNaN(b.at(i, j)) { 78 return false 79 } 80 } 81 } 82 return true 83 } 84 85 var symmetricTests = []Symmetric{ 86 {N: 3, Stride: 3, Data: []float64{ 87 1, 2, 3, 88 4, 5, 6, 89 7, 8, 9, 90 }}, 91 {N: 3, Stride: 5, Data: []float64{ 92 1, 2, 3, 0, 0, 93 4, 5, 6, 0, 0, 94 7, 8, 9, 0, 0, 95 }}, 96 } 97 98 func TestConvertSymmetric(t *testing.T) { 99 for _, test := range symmetricTests { 100 for _, uplo := range []blas.Uplo{blas.Upper, blas.Lower} { 101 test.Uplo = uplo 102 colmajor := newSymmetricColsFrom(test) 103 if !sameSymmetric(colmajor, test) { 104 t.Errorf("unexpected result for row major to col major conversion:\n\tgot: %#v\n\tfrom:%#v", 105 colmajor, test) 106 } 107 rowmajor := newSymmetricFrom(colmajor) 108 if !sameSymmetric(rowmajor, test) { 109 t.Errorf("unexpected result for col major to row major conversion:\n\tgot: %#v\n\twant:%#v", 110 rowmajor, test) 111 } 112 } 113 } 114 } 115 func newSymmetricBandFrom(a SymmetricBandCols) SymmetricBand { 116 t := SymmetricBand{ 117 N: a.N, 118 K: a.K, 119 Stride: a.K + 1, 120 Data: make([]float64, a.N*(a.K+1)), 121 Uplo: a.Uplo, 122 } 123 for i := range t.Data { 124 t.Data[i] = math.NaN() 125 } 126 t.From(a) 127 return t 128 } 129 130 func (m SymmetricBand) n() (n int) { return m.N } 131 func (m SymmetricBand) at(i, j int) float64 { 132 b := Band{ 133 Rows: m.N, Cols: m.N, 134 Stride: m.Stride, 135 Data: m.Data, 136 } 137 switch m.Uplo { 138 default: 139 panic("blas64: bad BLAS uplo") 140 case blas.Upper: 141 b.KU = m.K 142 if i > j { 143 i, j = j, i 144 } 145 case blas.Lower: 146 b.KL = m.K 147 if i < j { 148 i, j = j, i 149 } 150 } 151 return b.at(i, j) 152 } 153 func (m SymmetricBand) bandwidth() (k int) { return m.K } 154 func (m SymmetricBand) uplo() blas.Uplo { return m.Uplo } 155 156 func newSymmetricBandColsFrom(a SymmetricBand) SymmetricBandCols { 157 t := SymmetricBandCols{ 158 N: a.N, 159 K: a.K, 160 Stride: a.K + 1, 161 Data: make([]float64, a.N*(a.K+1)), 162 Uplo: a.Uplo, 163 } 164 for i := range t.Data { 165 t.Data[i] = math.NaN() 166 } 167 t.From(a) 168 return t 169 } 170 171 func (m SymmetricBandCols) n() (n int) { return m.N } 172 func (m SymmetricBandCols) at(i, j int) float64 { 173 b := BandCols{ 174 Rows: m.N, Cols: m.N, 175 Stride: m.Stride, 176 Data: m.Data, 177 } 178 switch m.Uplo { 179 default: 180 panic("blas64: bad BLAS uplo") 181 case blas.Upper: 182 b.KU = m.K 183 if i > j { 184 i, j = j, i 185 } 186 case blas.Lower: 187 b.KL = m.K 188 if i < j { 189 i, j = j, i 190 } 191 } 192 return b.at(i, j) 193 } 194 func (m SymmetricBandCols) bandwidth() (k int) { return m.K } 195 func (m SymmetricBandCols) uplo() blas.Uplo { return m.Uplo } 196 197 type symmetricBand interface { 198 n() (n int) 199 at(i, j int) float64 200 bandwidth() (k int) 201 uplo() blas.Uplo 202 } 203 204 func sameSymmetricBand(a, b symmetricBand) bool { 205 an := a.n() 206 bn := b.n() 207 if an != bn { 208 return false 209 } 210 if a.uplo() != b.uplo() { 211 return false 212 } 213 ak := a.bandwidth() 214 bk := b.bandwidth() 215 if ak != bk { 216 return false 217 } 218 for i := 0; i < an; i++ { 219 for j := 0; j < an; j++ { 220 if a.at(i, j) != b.at(i, j) || math.IsNaN(a.at(i, j)) != math.IsNaN(b.at(i, j)) { 221 return false 222 } 223 } 224 } 225 return true 226 } 227 228 var symmetricBandTests = []SymmetricBand{ 229 {N: 3, K: 0, Stride: 1, Uplo: blas.Upper, Data: []float64{ 230 1, 231 2, 232 3, 233 }}, 234 {N: 3, K: 0, Stride: 1, Uplo: blas.Lower, Data: []float64{ 235 1, 236 2, 237 3, 238 }}, 239 {N: 3, K: 1, Stride: 2, Uplo: blas.Upper, Data: []float64{ 240 1, 2, 241 3, 4, 242 5, -1, 243 }}, 244 {N: 3, K: 1, Stride: 2, Uplo: blas.Lower, Data: []float64{ 245 -1, 1, 246 2, 3, 247 4, 5, 248 }}, 249 {N: 3, K: 2, Stride: 3, Uplo: blas.Upper, Data: []float64{ 250 1, 2, 3, 251 4, 5, -1, 252 6, -2, -3, 253 }}, 254 {N: 3, K: 2, Stride: 3, Uplo: blas.Lower, Data: []float64{ 255 -2, -1, 1, 256 -3, 2, 4, 257 3, 5, 6, 258 }}, 259 260 {N: 3, K: 0, Stride: 5, Uplo: blas.Upper, Data: []float64{ 261 1, 0, 0, 0, 0, 262 2, 0, 0, 0, 0, 263 3, 0, 0, 0, 0, 264 }}, 265 {N: 3, K: 0, Stride: 5, Uplo: blas.Lower, Data: []float64{ 266 1, 0, 0, 0, 0, 267 2, 0, 0, 0, 0, 268 3, 0, 0, 0, 0, 269 }}, 270 {N: 3, K: 1, Stride: 5, Uplo: blas.Upper, Data: []float64{ 271 1, 2, 0, 0, 0, 272 3, 4, 0, 0, 0, 273 5, -1, 0, 0, 0, 274 }}, 275 {N: 3, K: 1, Stride: 5, Uplo: blas.Lower, Data: []float64{ 276 -1, 1, 0, 0, 0, 277 2, 3, 0, 0, 0, 278 4, 5, 0, 0, 0, 279 }}, 280 {N: 3, K: 2, Stride: 5, Uplo: blas.Upper, Data: []float64{ 281 1, 2, 3, 0, 0, 282 4, 5, -1, 0, 0, 283 6, -2, -3, 0, 0, 284 }}, 285 {N: 3, K: 2, Stride: 5, Uplo: blas.Lower, Data: []float64{ 286 -2, -1, 1, 0, 0, 287 -3, 2, 4, 0, 0, 288 3, 5, 6, 0, 0, 289 }}, 290 } 291 292 func TestConvertSymBand(t *testing.T) { 293 for _, test := range symmetricBandTests { 294 colmajor := newSymmetricBandColsFrom(test) 295 if !sameSymmetricBand(colmajor, test) { 296 t.Errorf("unexpected result for row major to col major conversion:\n\tgot: %#v\n\tfrom:%#v", 297 colmajor, test) 298 } 299 rowmajor := newSymmetricBandFrom(colmajor) 300 if !sameSymmetricBand(rowmajor, test) { 301 t.Errorf("unexpected result for col major to row major conversion:\n\tgot: %#v\n\twant:%#v", 302 rowmajor, test) 303 } 304 } 305 }