github.com/gonum/lapack@v0.0.0-20181123203213-e4cdc5a0bff9/testlapack/dlaqp2.go (about) 1 // Copyright ©2017 The gonum Authors. All rights reserved. 2 // Use of this source code is governed by a BSD-style 3 // license that can be found in the LICENSE file. 4 5 package testlapack 6 7 import ( 8 "fmt" 9 "math" 10 "testing" 11 12 "github.com/gonum/blas" 13 "github.com/gonum/blas/blas64" 14 ) 15 16 type Dlaqp2er interface { 17 Dlapmter 18 Dlaqp2(m, n, offset int, a []float64, lda int, jpvt []int, tau, vn1, vn2, work []float64) 19 } 20 21 func Dlaqp2Test(t *testing.T, impl Dlaqp2er) { 22 for ti, test := range []struct { 23 m, n, offset int 24 }{ 25 {m: 4, n: 3, offset: 0}, 26 {m: 4, n: 3, offset: 2}, 27 {m: 4, n: 3, offset: 4}, 28 {m: 3, n: 4, offset: 0}, 29 {m: 3, n: 4, offset: 1}, 30 {m: 3, n: 4, offset: 2}, 31 {m: 8, n: 3, offset: 0}, 32 {m: 8, n: 3, offset: 4}, 33 {m: 8, n: 3, offset: 8}, 34 {m: 3, n: 8, offset: 0}, 35 {m: 3, n: 8, offset: 1}, 36 {m: 3, n: 8, offset: 2}, 37 {m: 10, n: 10, offset: 0}, 38 {m: 10, n: 10, offset: 5}, 39 {m: 10, n: 10, offset: 10}, 40 } { 41 m := test.m 42 n := test.n 43 jpiv := make([]int, n) 44 45 for _, extra := range []int{0, 11} { 46 a := zeros(m, n, n+extra) 47 c := 1 48 for i := 0; i < m; i++ { 49 for j := 0; j < n; j++ { 50 a.Data[i*a.Stride+j] = float64(c) 51 c++ 52 } 53 } 54 aCopy := cloneGeneral(a) 55 for j := range jpiv { 56 jpiv[j] = j 57 } 58 59 tau := make([]float64, n) 60 vn1 := columnNorms(m, n, a.Data, a.Stride) 61 vn2 := columnNorms(m, n, a.Data, a.Stride) 62 work := make([]float64, n) 63 64 impl.Dlaqp2(m, n, test.offset, a.Data, a.Stride, jpiv, tau, vn1, vn2, work) 65 66 prefix := fmt.Sprintf("Case %v (offset=%t,m=%v,n=%v,extra=%v)", ti, test.offset, m, n, extra) 67 if !generalOutsideAllNaN(a) { 68 t.Errorf("%v: out-of-range write to A", prefix) 69 } 70 71 if test.offset == m { 72 continue 73 } 74 75 mo := m - test.offset 76 q := constructQ("QR", mo, n, a.Data[test.offset*a.Stride:], a.Stride, tau) 77 // Check that q is orthonormal 78 for i := 0; i < mo; i++ { 79 nrm := blas64.Nrm2(mo, blas64.Vector{Inc: 1, Data: q.Data[i*mo:]}) 80 if math.Abs(nrm-1) > 1e-13 { 81 t.Errorf("Case %v, q not normal", ti) 82 } 83 for j := 0; j < i; j++ { 84 dot := blas64.Dot(mo, blas64.Vector{Inc: 1, Data: q.Data[i*mo:]}, blas64.Vector{Inc: 1, Data: q.Data[j*mo:]}) 85 if math.Abs(dot) > 1e-14 { 86 t.Errorf("Case %v, q not orthogonal", ti) 87 } 88 } 89 } 90 91 // Check that A * P = Q * R 92 r := blas64.General{ 93 Rows: mo, 94 Cols: n, 95 Stride: n, 96 Data: make([]float64, mo*n), 97 } 98 for i := 0; i < mo; i++ { 99 for j := i; j < n; j++ { 100 r.Data[i*n+j] = a.Data[(test.offset+i)*a.Stride+j] 101 } 102 } 103 got := nanGeneral(mo, n, n) 104 blas64.Gemm(blas.NoTrans, blas.NoTrans, 1, q, r, 0, got) 105 106 want := aCopy 107 impl.Dlapmt(true, want.Rows, want.Cols, want.Data, want.Stride, jpiv) 108 want.Rows = mo 109 want.Data = want.Data[test.offset*want.Stride:] 110 if !equalApproxGeneral(got, want, 1e-12) { 111 t.Errorf("Case %v, Q*R != A*P\nQ*R=%v\nA*P=%v", ti, got, want) 112 } 113 } 114 } 115 }