github.com/jingcheng-WU/gonum@v0.9.1-0.20210323123734-f1a2a11a8f7b/lapack/gonum/dlags2.go (about)

     1  // Copyright ©2017 The Gonum Authors. All rights reserved.
     2  // Use of this source code is governed by a BSD-style
     3  // license that can be found in the LICENSE file.
     4  
     5  package gonum
     6  
     7  import "math"
     8  
     9  // Dlags2 computes 2-by-2 orthogonal matrices U, V and Q with the
    10  // triangles of A and B specified by upper.
    11  //
    12  // If upper is true
    13  //
    14  //  Uᵀ*A*Q = Uᵀ*[ a1 a2 ]*Q = [ x  0 ]
    15  //              [ 0  a3 ]     [ x  x ]
    16  // and
    17  //  Vᵀ*B*Q = Vᵀ*[ b1 b2 ]*Q = [ x  0 ]
    18  //              [ 0  b3 ]     [ x  x ]
    19  //
    20  // otherwise
    21  //
    22  //  Uᵀ*A*Q = Uᵀ*[ a1 0  ]*Q = [ x  x ]
    23  //              [ a2 a3 ]     [ 0  x ]
    24  // and
    25  //  Vᵀ*B*Q = Vᵀ*[ b1 0  ]*Q = [ x  x ]
    26  //              [ b2 b3 ]     [ 0  x ].
    27  //
    28  // The rows of the transformed A and B are parallel, where
    29  //
    30  //  U = [  csu  snu ], V = [  csv snv ], Q = [  csq   snq ]
    31  //      [ -snu  csu ]      [ -snv csv ]      [ -snq   csq ]
    32  //
    33  // Dlags2 is an internal routine. It is exported for testing purposes.
    34  func (impl Implementation) Dlags2(upper bool, a1, a2, a3, b1, b2, b3 float64) (csu, snu, csv, snv, csq, snq float64) {
    35  	if upper {
    36  		// Input matrices A and B are upper triangular matrices.
    37  		//
    38  		// Form matrix C = A*adj(B) = [ a b ]
    39  		//                            [ 0 d ]
    40  		a := a1 * b3
    41  		d := a3 * b1
    42  		b := a2*b1 - a1*b2
    43  
    44  		// The SVD of real 2-by-2 triangular C.
    45  		//
    46  		//  [ csl -snl ]*[ a b ]*[  csr  snr ] = [ r 0 ]
    47  		//  [ snl  csl ] [ 0 d ] [ -snr  csr ]   [ 0 t ]
    48  		_, _, snr, csr, snl, csl := impl.Dlasv2(a, b, d)
    49  
    50  		if math.Abs(csl) >= math.Abs(snl) || math.Abs(csr) >= math.Abs(snr) {
    51  			// Compute the [0, 0] and [0, 1] elements of Uᵀ*A and Vᵀ*B,
    52  			// and [0, 1] element of |U|ᵀ*|A| and |V|ᵀ*|B|.
    53  
    54  			ua11r := csl * a1
    55  			ua12 := csl*a2 + snl*a3
    56  
    57  			vb11r := csr * b1
    58  			vb12 := csr*b2 + snr*b3
    59  
    60  			aua12 := math.Abs(csl)*math.Abs(a2) + math.Abs(snl)*math.Abs(a3)
    61  			avb12 := math.Abs(csr)*math.Abs(b2) + math.Abs(snr)*math.Abs(b3)
    62  
    63  			// Zero [0, 1] elements of Uᵀ*A and Vᵀ*B.
    64  			if math.Abs(ua11r)+math.Abs(ua12) != 0 {
    65  				if aua12/(math.Abs(ua11r)+math.Abs(ua12)) <= avb12/(math.Abs(vb11r)+math.Abs(vb12)) {
    66  					csq, snq, _ = impl.Dlartg(-ua11r, ua12)
    67  				} else {
    68  					csq, snq, _ = impl.Dlartg(-vb11r, vb12)
    69  				}
    70  			} else {
    71  				csq, snq, _ = impl.Dlartg(-vb11r, vb12)
    72  			}
    73  
    74  			csu = csl
    75  			snu = -snl
    76  			csv = csr
    77  			snv = -snr
    78  		} else {
    79  			// Compute the [1, 0] and [1, 1] elements of Uᵀ*A and Vᵀ*B,
    80  			// and [1, 1] element of |U|ᵀ*|A| and |V|ᵀ*|B|.
    81  
    82  			ua21 := -snl * a1
    83  			ua22 := -snl*a2 + csl*a3
    84  
    85  			vb21 := -snr * b1
    86  			vb22 := -snr*b2 + csr*b3
    87  
    88  			aua22 := math.Abs(snl)*math.Abs(a2) + math.Abs(csl)*math.Abs(a3)
    89  			avb22 := math.Abs(snr)*math.Abs(b2) + math.Abs(csr)*math.Abs(b3)
    90  
    91  			// Zero [1, 1] elements of Uᵀ*A and Vᵀ*B, and then swap.
    92  			if math.Abs(ua21)+math.Abs(ua22) != 0 {
    93  				if aua22/(math.Abs(ua21)+math.Abs(ua22)) <= avb22/(math.Abs(vb21)+math.Abs(vb22)) {
    94  					csq, snq, _ = impl.Dlartg(-ua21, ua22)
    95  				} else {
    96  					csq, snq, _ = impl.Dlartg(-vb21, vb22)
    97  				}
    98  			} else {
    99  				csq, snq, _ = impl.Dlartg(-vb21, vb22)
   100  			}
   101  
   102  			csu = snl
   103  			snu = csl
   104  			csv = snr
   105  			snv = csr
   106  		}
   107  	} else {
   108  		// Input matrices A and B are lower triangular matrices
   109  		//
   110  		// Form matrix C = A*adj(B) = [ a 0 ]
   111  		//                            [ c d ]
   112  		a := a1 * b3
   113  		d := a3 * b1
   114  		c := a2*b3 - a3*b2
   115  
   116  		// The SVD of real 2-by-2 triangular C
   117  		//
   118  		// [ csl -snl ]*[ a 0 ]*[  csr  snr ] = [ r 0 ]
   119  		// [ snl  csl ] [ c d ] [ -snr  csr ]   [ 0 t ]
   120  		_, _, snr, csr, snl, csl := impl.Dlasv2(a, c, d)
   121  
   122  		if math.Abs(csr) >= math.Abs(snr) || math.Abs(csl) >= math.Abs(snl) {
   123  			// Compute the [1, 0] and [1, 1] elements of Uᵀ*A and Vᵀ*B,
   124  			// and [1, 0] element of |U|ᵀ*|A| and |V|ᵀ*|B|.
   125  
   126  			ua21 := -snr*a1 + csr*a2
   127  			ua22r := csr * a3
   128  
   129  			vb21 := -snl*b1 + csl*b2
   130  			vb22r := csl * b3
   131  
   132  			aua21 := math.Abs(snr)*math.Abs(a1) + math.Abs(csr)*math.Abs(a2)
   133  			avb21 := math.Abs(snl)*math.Abs(b1) + math.Abs(csl)*math.Abs(b2)
   134  
   135  			// Zero [1, 0] elements of Uᵀ*A and Vᵀ*B.
   136  			if (math.Abs(ua21) + math.Abs(ua22r)) != 0 {
   137  				if aua21/(math.Abs(ua21)+math.Abs(ua22r)) <= avb21/(math.Abs(vb21)+math.Abs(vb22r)) {
   138  					csq, snq, _ = impl.Dlartg(ua22r, ua21)
   139  				} else {
   140  					csq, snq, _ = impl.Dlartg(vb22r, vb21)
   141  				}
   142  			} else {
   143  				csq, snq, _ = impl.Dlartg(vb22r, vb21)
   144  			}
   145  
   146  			csu = csr
   147  			snu = -snr
   148  			csv = csl
   149  			snv = -snl
   150  		} else {
   151  			// Compute the [0, 0] and [0, 1] elements of Uᵀ *A and Vᵀ *B,
   152  			// and [0, 0] element of |U|ᵀ*|A| and |V|ᵀ*|B|.
   153  
   154  			ua11 := csr*a1 + snr*a2
   155  			ua12 := snr * a3
   156  
   157  			vb11 := csl*b1 + snl*b2
   158  			vb12 := snl * b3
   159  
   160  			aua11 := math.Abs(csr)*math.Abs(a1) + math.Abs(snr)*math.Abs(a2)
   161  			avb11 := math.Abs(csl)*math.Abs(b1) + math.Abs(snl)*math.Abs(b2)
   162  
   163  			// Zero [0, 0] elements of Uᵀ*A and Vᵀ*B, and then swap.
   164  			if (math.Abs(ua11) + math.Abs(ua12)) != 0 {
   165  				if aua11/(math.Abs(ua11)+math.Abs(ua12)) <= avb11/(math.Abs(vb11)+math.Abs(vb12)) {
   166  					csq, snq, _ = impl.Dlartg(ua12, ua11)
   167  				} else {
   168  					csq, snq, _ = impl.Dlartg(vb12, vb11)
   169  				}
   170  			} else {
   171  				csq, snq, _ = impl.Dlartg(vb12, vb11)
   172  			}
   173  
   174  			csu = snr
   175  			snu = csr
   176  			csv = snl
   177  			snv = csl
   178  		}
   179  	}
   180  
   181  	return csu, snu, csv, snv, csq, snq
   182  }