gonum.org/v1/gonum@v0.14.0/lapack/gonum/dlasy2.go (about)

     1  // Copyright ©2016 The Gonum Authors. All rights reserved.
     2  // Use of this source code is governed by a BSD-style
     3  // license that can be found in the LICENSE file.
     4  
     5  package gonum
     6  
     7  import (
     8  	"math"
     9  
    10  	"gonum.org/v1/gonum/blas/blas64"
    11  )
    12  
    13  // Dlasy2 solves the Sylvester matrix equation where the matrices are of order 1
    14  // or 2. It computes the unknown n1×n2 matrix X so that
    15  //
    16  //	TL*X   + sgn*X*TR  = scale*B  if tranl == false and tranr == false,
    17  //	TLᵀ*X + sgn*X*TR   = scale*B  if tranl == true  and tranr == false,
    18  //	TL*X   + sgn*X*TRᵀ = scale*B  if tranl == false and tranr == true,
    19  //	TLᵀ*X + sgn*X*TRᵀ  = scale*B  if tranl == true  and tranr == true,
    20  //
    21  // where TL is n1×n1, TR is n2×n2, B is n1×n2, and 1 <= n1,n2 <= 2.
    22  //
    23  // isgn must be 1 or -1, and n1 and n2 must be 0, 1, or 2, but these conditions
    24  // are not checked.
    25  //
    26  // Dlasy2 returns three values, a scale factor that is chosen less than or equal
    27  // to 1 to prevent the solution overflowing, the infinity norm of the solution,
    28  // and an indicator of success. If ok is false, TL and TR have eigenvalues that
    29  // are too close, so TL or TR is perturbed to get a non-singular equation.
    30  //
    31  // Dlasy2 is an internal routine. It is exported for testing purposes.
    32  func (impl Implementation) Dlasy2(tranl, tranr bool, isgn, n1, n2 int, tl []float64, ldtl int, tr []float64, ldtr int, b []float64, ldb int, x []float64, ldx int) (scale, xnorm float64, ok bool) {
    33  	// TODO(vladimir-ch): Add input validation checks conditionally skipped
    34  	// using the build tag mechanism.
    35  
    36  	ok = true
    37  	// Quick return if possible.
    38  	if n1 == 0 || n2 == 0 {
    39  		return scale, xnorm, ok
    40  	}
    41  
    42  	// Set constants to control overflow.
    43  	eps := dlamchP
    44  	smlnum := dlamchS / eps
    45  	sgn := float64(isgn)
    46  
    47  	if n1 == 1 && n2 == 1 {
    48  		// 1×1 case: TL11*X + sgn*X*TR11 = B11.
    49  		tau1 := tl[0] + sgn*tr[0]
    50  		bet := math.Abs(tau1)
    51  		if bet <= smlnum {
    52  			tau1 = smlnum
    53  			bet = smlnum
    54  			ok = false
    55  		}
    56  		scale = 1
    57  		gam := math.Abs(b[0])
    58  		if smlnum*gam > bet {
    59  			scale = 1 / gam
    60  		}
    61  		x[0] = b[0] * scale / tau1
    62  		xnorm = math.Abs(x[0])
    63  		return scale, xnorm, ok
    64  	}
    65  
    66  	if n1+n2 == 3 {
    67  		// 1×2 or 2×1 case.
    68  		var (
    69  			smin float64
    70  			tmp  [4]float64 // tmp is used as a 2×2 row-major matrix.
    71  			btmp [2]float64
    72  		)
    73  		if n1 == 1 && n2 == 2 {
    74  			// 1×2 case: TL11*[X11 X12] + sgn*[X11 X12]*op[TR11 TR12] = [B11 B12].
    75  			//                                            [TR21 TR22]
    76  			smin = math.Abs(tl[0])
    77  			smin = math.Max(smin, math.Max(math.Abs(tr[0]), math.Abs(tr[1])))
    78  			smin = math.Max(smin, math.Max(math.Abs(tr[ldtr]), math.Abs(tr[ldtr+1])))
    79  			smin = math.Max(eps*smin, smlnum)
    80  			tmp[0] = tl[0] + sgn*tr[0]
    81  			tmp[3] = tl[0] + sgn*tr[ldtr+1]
    82  			if tranr {
    83  				tmp[1] = sgn * tr[1]
    84  				tmp[2] = sgn * tr[ldtr]
    85  			} else {
    86  				tmp[1] = sgn * tr[ldtr]
    87  				tmp[2] = sgn * tr[1]
    88  			}
    89  			btmp[0] = b[0]
    90  			btmp[1] = b[1]
    91  		} else {
    92  			// 2×1 case: op[TL11 TL12]*[X11] + sgn*[X11]*TR11 = [B11].
    93  			//             [TL21 TL22]*[X21]       [X21]        [B21]
    94  			smin = math.Abs(tr[0])
    95  			smin = math.Max(smin, math.Max(math.Abs(tl[0]), math.Abs(tl[1])))
    96  			smin = math.Max(smin, math.Max(math.Abs(tl[ldtl]), math.Abs(tl[ldtl+1])))
    97  			smin = math.Max(eps*smin, smlnum)
    98  			tmp[0] = tl[0] + sgn*tr[0]
    99  			tmp[3] = tl[ldtl+1] + sgn*tr[0]
   100  			if tranl {
   101  				tmp[1] = tl[ldtl]
   102  				tmp[2] = tl[1]
   103  			} else {
   104  				tmp[1] = tl[1]
   105  				tmp[2] = tl[ldtl]
   106  			}
   107  			btmp[0] = b[0]
   108  			btmp[1] = b[ldb]
   109  		}
   110  
   111  		// Solve 2×2 system using complete pivoting.
   112  		// Set pivots less than smin to smin.
   113  
   114  		bi := blas64.Implementation()
   115  		ipiv := bi.Idamax(len(tmp), tmp[:], 1)
   116  		// Compute the upper triangular matrix [u11 u12].
   117  		//                                     [  0 u22]
   118  		u11 := tmp[ipiv]
   119  		if math.Abs(u11) <= smin {
   120  			ok = false
   121  			u11 = smin
   122  		}
   123  		locu12 := [4]int{1, 0, 3, 2} // Index in tmp of the element on the same row as the pivot.
   124  		u12 := tmp[locu12[ipiv]]
   125  		locl21 := [4]int{2, 3, 0, 1} // Index in tmp of the element on the same column as the pivot.
   126  		l21 := tmp[locl21[ipiv]] / u11
   127  		locu22 := [4]int{3, 2, 1, 0} // Index in tmp of the remaining element.
   128  		u22 := tmp[locu22[ipiv]] - l21*u12
   129  		if math.Abs(u22) <= smin {
   130  			ok = false
   131  			u22 = smin
   132  		}
   133  		if ipiv&0x2 != 0 { // true for ipiv equal to 2 and 3.
   134  			// The pivot was in the second row, swap the elements of
   135  			// the right-hand side.
   136  			btmp[0], btmp[1] = btmp[1], btmp[0]-l21*btmp[1]
   137  		} else {
   138  			btmp[1] -= l21 * btmp[0]
   139  		}
   140  		scale = 1
   141  		if 2*smlnum*math.Abs(btmp[1]) > math.Abs(u22) || 2*smlnum*math.Abs(btmp[0]) > math.Abs(u11) {
   142  			scale = 0.5 / math.Max(math.Abs(btmp[0]), math.Abs(btmp[1]))
   143  			btmp[0] *= scale
   144  			btmp[1] *= scale
   145  		}
   146  		// Solve the system [u11 u12] [x21] = [ btmp[0] ].
   147  		//                  [  0 u22] [x22]   [ btmp[1] ]
   148  		x22 := btmp[1] / u22
   149  		x21 := btmp[0]/u11 - (u12/u11)*x22
   150  		if ipiv&0x1 != 0 { // true for ipiv equal to 1 and 3.
   151  			// The pivot was in the second column, swap the elements
   152  			// of the solution.
   153  			x21, x22 = x22, x21
   154  		}
   155  		x[0] = x21
   156  		if n1 == 1 {
   157  			x[1] = x22
   158  			xnorm = math.Abs(x[0]) + math.Abs(x[1])
   159  		} else {
   160  			x[ldx] = x22
   161  			xnorm = math.Max(math.Abs(x[0]), math.Abs(x[ldx]))
   162  		}
   163  		return scale, xnorm, ok
   164  	}
   165  
   166  	// 2×2 case: op[TL11 TL12]*[X11 X12] + SGN*[X11 X12]*op[TR11 TR12] = [B11 B12].
   167  	//             [TL21 TL22] [X21 X22]       [X21 X22]   [TR21 TR22]   [B21 B22]
   168  	//
   169  	// Solve equivalent 4×4 system using complete pivoting.
   170  	// Set pivots less than smin to smin.
   171  
   172  	smin := math.Max(math.Abs(tr[0]), math.Abs(tr[1]))
   173  	smin = math.Max(smin, math.Max(math.Abs(tr[ldtr]), math.Abs(tr[ldtr+1])))
   174  	smin = math.Max(smin, math.Max(math.Abs(tl[0]), math.Abs(tl[1])))
   175  	smin = math.Max(smin, math.Max(math.Abs(tl[ldtl]), math.Abs(tl[ldtl+1])))
   176  	smin = math.Max(eps*smin, smlnum)
   177  
   178  	var t [4][4]float64
   179  	t[0][0] = tl[0] + sgn*tr[0]
   180  	t[1][1] = tl[0] + sgn*tr[ldtr+1]
   181  	t[2][2] = tl[ldtl+1] + sgn*tr[0]
   182  	t[3][3] = tl[ldtl+1] + sgn*tr[ldtr+1]
   183  	if tranl {
   184  		t[0][2] = tl[ldtl]
   185  		t[1][3] = tl[ldtl]
   186  		t[2][0] = tl[1]
   187  		t[3][1] = tl[1]
   188  	} else {
   189  		t[0][2] = tl[1]
   190  		t[1][3] = tl[1]
   191  		t[2][0] = tl[ldtl]
   192  		t[3][1] = tl[ldtl]
   193  	}
   194  	if tranr {
   195  		t[0][1] = sgn * tr[1]
   196  		t[1][0] = sgn * tr[ldtr]
   197  		t[2][3] = sgn * tr[1]
   198  		t[3][2] = sgn * tr[ldtr]
   199  	} else {
   200  		t[0][1] = sgn * tr[ldtr]
   201  		t[1][0] = sgn * tr[1]
   202  		t[2][3] = sgn * tr[ldtr]
   203  		t[3][2] = sgn * tr[1]
   204  	}
   205  
   206  	var btmp [4]float64
   207  	btmp[0] = b[0]
   208  	btmp[1] = b[1]
   209  	btmp[2] = b[ldb]
   210  	btmp[3] = b[ldb+1]
   211  
   212  	// Perform elimination.
   213  	var jpiv [4]int // jpiv records any column swaps for pivoting.
   214  	for i := 0; i < 3; i++ {
   215  		var (
   216  			xmax       float64
   217  			ipsv, jpsv int
   218  		)
   219  		for ip := i; ip < 4; ip++ {
   220  			for jp := i; jp < 4; jp++ {
   221  				if math.Abs(t[ip][jp]) >= xmax {
   222  					xmax = math.Abs(t[ip][jp])
   223  					ipsv = ip
   224  					jpsv = jp
   225  				}
   226  			}
   227  		}
   228  		if ipsv != i {
   229  			// The pivot is not in the top row of the unprocessed
   230  			// block, swap rows ipsv and i of t and btmp.
   231  			t[ipsv], t[i] = t[i], t[ipsv]
   232  			btmp[ipsv], btmp[i] = btmp[i], btmp[ipsv]
   233  		}
   234  		if jpsv != i {
   235  			// The pivot is not in the left column of the
   236  			// unprocessed block, swap columns jpsv and i of t.
   237  			for k := 0; k < 4; k++ {
   238  				t[k][jpsv], t[k][i] = t[k][i], t[k][jpsv]
   239  			}
   240  		}
   241  		jpiv[i] = jpsv
   242  		if math.Abs(t[i][i]) < smin {
   243  			ok = false
   244  			t[i][i] = smin
   245  		}
   246  		for k := i + 1; k < 4; k++ {
   247  			t[k][i] /= t[i][i]
   248  			btmp[k] -= t[k][i] * btmp[i]
   249  			for j := i + 1; j < 4; j++ {
   250  				t[k][j] -= t[k][i] * t[i][j]
   251  			}
   252  		}
   253  	}
   254  	if math.Abs(t[3][3]) < smin {
   255  		ok = false
   256  		t[3][3] = smin
   257  	}
   258  	scale = 1
   259  	if 8*smlnum*math.Abs(btmp[0]) > math.Abs(t[0][0]) ||
   260  		8*smlnum*math.Abs(btmp[1]) > math.Abs(t[1][1]) ||
   261  		8*smlnum*math.Abs(btmp[2]) > math.Abs(t[2][2]) ||
   262  		8*smlnum*math.Abs(btmp[3]) > math.Abs(t[3][3]) {
   263  
   264  		maxbtmp := math.Max(math.Abs(btmp[0]), math.Abs(btmp[1]))
   265  		maxbtmp = math.Max(maxbtmp, math.Max(math.Abs(btmp[2]), math.Abs(btmp[3])))
   266  		scale = (1.0 / 8.0) / maxbtmp
   267  		btmp[0] *= scale
   268  		btmp[1] *= scale
   269  		btmp[2] *= scale
   270  		btmp[3] *= scale
   271  	}
   272  	// Compute the solution of the upper triangular system t * tmp = btmp.
   273  	var tmp [4]float64
   274  	for i := 3; i >= 0; i-- {
   275  		temp := 1 / t[i][i]
   276  		tmp[i] = btmp[i] * temp
   277  		for j := i + 1; j < 4; j++ {
   278  			tmp[i] -= temp * t[i][j] * tmp[j]
   279  		}
   280  	}
   281  	for i := 2; i >= 0; i-- {
   282  		if jpiv[i] != i {
   283  			tmp[i], tmp[jpiv[i]] = tmp[jpiv[i]], tmp[i]
   284  		}
   285  	}
   286  	x[0] = tmp[0]
   287  	x[1] = tmp[1]
   288  	x[ldx] = tmp[2]
   289  	x[ldx+1] = tmp[3]
   290  	xnorm = math.Max(math.Abs(tmp[0])+math.Abs(tmp[1]), math.Abs(tmp[2])+math.Abs(tmp[3]))
   291  	return scale, xnorm, ok
   292  }