github.com/jingcheng-WU/gonum@v0.9.1-0.20210323123734-f1a2a11a8f7b/blas/gonum/level3float64.go (about)

     1  // Copyright ©2014 The Gonum Authors. All rights reserved.
     2  // Use of this source code is governed by a BSD-style
     3  // license that can be found in the LICENSE file.
     4  
     5  package gonum
     6  
     7  import (
     8  	"github.com/jingcheng-WU/gonum/blas"
     9  	"github.com/jingcheng-WU/gonum/internal/asm/f64"
    10  )
    11  
// Compile-time assertion: assigning an Implementation value to a variable of
// interface type blas.Float64Level3 makes the build fail if Implementation
// stops providing any method of that interface.
var _ blas.Float64Level3 = Implementation{}
    13  
// Dtrsm solves one of the matrix equations
//  A * X = alpha * B   if tA == blas.NoTrans and side == blas.Left
//  Aᵀ * X = alpha * B  if tA == blas.Trans or blas.ConjTrans, and side == blas.Left
//  X * A = alpha * B   if tA == blas.NoTrans and side == blas.Right
//  X * Aᵀ = alpha * B  if tA == blas.Trans or blas.ConjTrans, and side == blas.Right
// where A is an n×n or m×m triangular matrix, X and B are m×n matrices, and alpha is a
// scalar.
//
// At entry to the function, X contains the values of B, and the result is
// stored in-place into X.
//
// A is m×m when s == blas.Left and n×n when s == blas.Right. Both A and B are
// addressed row-major (element (i, j) of B is b[i*ldb+j]). Only the triangle
// of A selected by ul is read. When d == blas.Unit the diagonal of A is not
// read and is taken to be all ones. If alpha is zero, B is set to zero and A
// is not read.
//
// Dtrsm panics if s, ul, tA or d is invalid, if m or n is negative, if lda or
// ldb is smaller than the corresponding row length (minimum 1), or if a or b
// is too short to hold the matrices they describe.
//
// No check is made that A is invertible.
func (Implementation) Dtrsm(s blas.Side, ul blas.Uplo, tA blas.Transpose, d blas.Diag, m, n int, alpha float64, a []float64, lda int, b []float64, ldb int) {
	if s != blas.Left && s != blas.Right {
		panic(badSide)
	}
	if ul != blas.Lower && ul != blas.Upper {
		panic(badUplo)
	}
	if tA != blas.NoTrans && tA != blas.Trans && tA != blas.ConjTrans {
		panic(badTranspose)
	}
	if d != blas.NonUnit && d != blas.Unit {
		panic(badDiag)
	}
	if m < 0 {
		panic(mLT0)
	}
	if n < 0 {
		panic(nLT0)
	}
	// k is the order of the triangular matrix A: m when A multiplies X from
	// the left, n when it multiplies from the right.
	k := n
	if s == blas.Left {
		k = m
	}
	if lda < max(1, k) {
		panic(badLdA)
	}
	if ldb < max(1, n) {
		panic(badLdB)
	}

	// Quick return if possible.
	if m == 0 || n == 0 {
		return
	}

	// Bounds checks on the backing slices. The quick return above guarantees
	// that m, n and k are all positive here.
	if len(a) < lda*(k-1)+k {
		panic(shortA)
	}
	if len(b) < ldb*(m-1)+n {
		panic(shortB)
	}

	// With alpha == 0 the solution is X = 0, so B is cleared and A is not read.
	if alpha == 0 {
		for i := 0; i < m; i++ {
			btmp := b[i*ldb : i*ldb+n]
			for j := range btmp {
				btmp[j] = 0
			}
		}
		return
	}
	nonUnit := d == blas.NonUnit
	if s == blas.Left {
		if tA == blas.NoTrans {
			if ul == blas.Upper {
				// Back substitution, bottom row first. Row i of X depends only on
				// rows i+1..m-1, which are already solved and stored in b.
				for i := m - 1; i >= 0; i-- {
					btmp := b[i*ldb : i*ldb+n]
					if alpha != 1 {
						f64.ScalUnitary(alpha, btmp)
					}
					for ka, va := range a[i*lda+i+1 : i*lda+m] {
						if va != 0 {
							k := ka + i + 1
							f64.AxpyUnitary(-va, b[k*ldb:k*ldb+n], btmp)
						}
					}
					if nonUnit {
						tmp := 1 / a[i*lda+i]
						f64.ScalUnitary(tmp, btmp)
					}
				}
				return
			}
			// Forward substitution, top row first. Row i of X depends only on
			// rows 0..i-1, which are already solved and stored in b.
			for i := 0; i < m; i++ {
				btmp := b[i*ldb : i*ldb+n]
				if alpha != 1 {
					f64.ScalUnitary(alpha, btmp)
				}
				for k, va := range a[i*lda : i*lda+i] {
					if va != 0 {
						f64.AxpyUnitary(-va, b[k*ldb:k*ldb+n], btmp)
					}
				}
				if nonUnit {
					tmp := 1 / a[i*lda+i]
					f64.ScalUnitary(tmp, btmp)
				}
			}
			return
		}
		// Cases where a is transposed
		if ul == blas.Upper {
			// Aᵀ is lower triangular, so solve top to bottom. Row k is finished
			// (divided by the diagonal), then eliminated from the rows below it.
			// The eliminations therefore operate on the solution Y of Aᵀ·Y = B.
			// Each row is multiplied by alpha only after it has been used,
			// giving X = alpha·Y.
			for k := 0; k < m; k++ {
				btmpk := b[k*ldb : k*ldb+n]
				if nonUnit {
					tmp := 1 / a[k*lda+k]
					f64.ScalUnitary(tmp, btmpk)
				}
				for ia, va := range a[k*lda+k+1 : k*lda+m] {
					if va != 0 {
						i := ia + k + 1
						f64.AxpyUnitary(-va, btmpk, b[i*ldb:i*ldb+n])
					}
				}
				if alpha != 1 {
					f64.ScalUnitary(alpha, btmpk)
				}
			}
			return
		}
		// Aᵀ is upper triangular, so solve bottom to top. This uses the same
		// "finish the row, eliminate it, then apply alpha" scheme as above.
		for k := m - 1; k >= 0; k-- {
			btmpk := b[k*ldb : k*ldb+n]
			if nonUnit {
				tmp := 1 / a[k*lda+k]
				f64.ScalUnitary(tmp, btmpk)
			}
			for i, va := range a[k*lda : k*lda+k] {
				if va != 0 {
					f64.AxpyUnitary(-va, btmpk, b[i*ldb:i*ldb+n])
				}
			}
			if alpha != 1 {
				f64.ScalUnitary(alpha, btmpk)
			}
		}
		return
	}
	// Cases where a is to the right of X.
	// Each row x of X is solved independently from the row equation
	// x·op(A) = alpha·b.
	if tA == blas.NoTrans {
		if ul == blas.Upper {
			// Solve left to right. Once entry k is final, its contribution is
			// removed from the entries to its right. Ranging over a slice reads
			// btmp[k] at each iteration, so vb sees the updates made by earlier
			// iterations.
			for i := 0; i < m; i++ {
				btmp := b[i*ldb : i*ldb+n]
				if alpha != 1 {
					f64.ScalUnitary(alpha, btmp)
				}
				for k, vb := range btmp {
					if vb == 0 {
						continue
					}
					if nonUnit {
						btmp[k] /= a[k*lda+k]
					}
					f64.AxpyUnitary(-btmp[k], a[k*lda+k+1:k*lda+n], btmp[k+1:n])
				}
			}
			return
		}
		// Lower triangular A: solve right to left, removing each finished entry
		// from the entries to its left.
		for i := 0; i < m; i++ {
			btmp := b[i*ldb : i*ldb+n]
			if alpha != 1 {
				f64.ScalUnitary(alpha, btmp)
			}
			for k := n - 1; k >= 0; k-- {
				if btmp[k] == 0 {
					continue
				}
				if nonUnit {
					btmp[k] /= a[k*lda+k]
				}
				f64.AxpyUnitary(-btmp[k], a[k*lda:k*lda+k], btmp[:k])
			}
		}
		return
	}
	// Cases where a is transposed.
	if ul == blas.Upper {
		// Entry j depends on the already-solved entries to its right through
		// row j of A, so solve right to left with one dot product per entry.
		for i := 0; i < m; i++ {
			btmp := b[i*ldb : i*ldb+n]
			for j := n - 1; j >= 0; j-- {
				tmp := alpha*btmp[j] - f64.DotUnitary(a[j*lda+j+1:j*lda+n], btmp[j+1:])
				if nonUnit {
					tmp /= a[j*lda+j]
				}
				btmp[j] = tmp
			}
		}
		return
	}
	// Lower triangular A: entry j depends on the already-solved entries to its
	// left, so solve left to right.
	for i := 0; i < m; i++ {
		btmp := b[i*ldb : i*ldb+n]
		for j := 0; j < n; j++ {
			tmp := alpha*btmp[j] - f64.DotUnitary(a[j*lda:j*lda+j], btmp[:j])
			if nonUnit {
				tmp /= a[j*lda+j]
			}
			btmp[j] = tmp
		}
	}
}
   216  
   217  // Dsymm performs one of the matrix-matrix operations
   218  //  C = alpha * A * B + beta * C  if side == blas.Left
   219  //  C = alpha * B * A + beta * C  if side == blas.Right
   220  // where A is an n×n or m×m symmetric matrix, B and C are m×n matrices, and alpha
   221  // is a scalar.
   222  func (Implementation) Dsymm(s blas.Side, ul blas.Uplo, m, n int, alpha float64, a []float64, lda int, b []float64, ldb int, beta float64, c []float64, ldc int) {
   223  	if s != blas.Right && s != blas.Left {
   224  		panic(badSide)
   225  	}
   226  	if ul != blas.Lower && ul != blas.Upper {
   227  		panic(badUplo)
   228  	}
   229  	if m < 0 {
   230  		panic(mLT0)
   231  	}
   232  	if n < 0 {
   233  		panic(nLT0)
   234  	}
   235  	k := n
   236  	if s == blas.Left {
   237  		k = m
   238  	}
   239  	if lda < max(1, k) {
   240  		panic(badLdA)
   241  	}
   242  	if ldb < max(1, n) {
   243  		panic(badLdB)
   244  	}
   245  	if ldc < max(1, n) {
   246  		panic(badLdC)
   247  	}
   248  
   249  	// Quick return if possible.
   250  	if m == 0 || n == 0 {
   251  		return
   252  	}
   253  
   254  	// For zero matrix size the following slice length checks are trivially satisfied.
   255  	if len(a) < lda*(k-1)+k {
   256  		panic(shortA)
   257  	}
   258  	if len(b) < ldb*(m-1)+n {
   259  		panic(shortB)
   260  	}
   261  	if len(c) < ldc*(m-1)+n {
   262  		panic(shortC)
   263  	}
   264  
   265  	// Quick return if possible.
   266  	if alpha == 0 && beta == 1 {
   267  		return
   268  	}
   269  
   270  	if beta == 0 {
   271  		for i := 0; i < m; i++ {
   272  			ctmp := c[i*ldc : i*ldc+n]
   273  			for j := range ctmp {
   274  				ctmp[j] = 0
   275  			}
   276  		}
   277  	}
   278  
   279  	if alpha == 0 {
   280  		if beta != 0 {
   281  			for i := 0; i < m; i++ {
   282  				ctmp := c[i*ldc : i*ldc+n]
   283  				for j := 0; j < n; j++ {
   284  					ctmp[j] *= beta
   285  				}
   286  			}
   287  		}
   288  		return
   289  	}
   290  
   291  	isUpper := ul == blas.Upper
   292  	if s == blas.Left {
   293  		for i := 0; i < m; i++ {
   294  			atmp := alpha * a[i*lda+i]
   295  			btmp := b[i*ldb : i*ldb+n]
   296  			ctmp := c[i*ldc : i*ldc+n]
   297  			for j, v := range btmp {
   298  				ctmp[j] *= beta
   299  				ctmp[j] += atmp * v
   300  			}
   301  
   302  			for k := 0; k < i; k++ {
   303  				var atmp float64
   304  				if isUpper {
   305  					atmp = a[k*lda+i]
   306  				} else {
   307  					atmp = a[i*lda+k]
   308  				}
   309  				atmp *= alpha
   310  				f64.AxpyUnitary(atmp, b[k*ldb:k*ldb+n], ctmp)
   311  			}
   312  			for k := i + 1; k < m; k++ {
   313  				var atmp float64
   314  				if isUpper {
   315  					atmp = a[i*lda+k]
   316  				} else {
   317  					atmp = a[k*lda+i]
   318  				}
   319  				atmp *= alpha
   320  				f64.AxpyUnitary(atmp, b[k*ldb:k*ldb+n], ctmp)
   321  			}
   322  		}
   323  		return
   324  	}
   325  	if isUpper {
   326  		for i := 0; i < m; i++ {
   327  			for j := n - 1; j >= 0; j-- {
   328  				tmp := alpha * b[i*ldb+j]
   329  				var tmp2 float64
   330  				atmp := a[j*lda+j+1 : j*lda+n]
   331  				btmp := b[i*ldb+j+1 : i*ldb+n]
   332  				ctmp := c[i*ldc+j+1 : i*ldc+n]
   333  				for k, v := range atmp {
   334  					ctmp[k] += tmp * v
   335  					tmp2 += btmp[k] * v
   336  				}
   337  				c[i*ldc+j] *= beta
   338  				c[i*ldc+j] += tmp*a[j*lda+j] + alpha*tmp2
   339  			}
   340  		}
   341  		return
   342  	}
   343  	for i := 0; i < m; i++ {
   344  		for j := 0; j < n; j++ {
   345  			tmp := alpha * b[i*ldb+j]
   346  			var tmp2 float64
   347  			atmp := a[j*lda : j*lda+j]
   348  			btmp := b[i*ldb : i*ldb+j]
   349  			ctmp := c[i*ldc : i*ldc+j]
   350  			for k, v := range atmp {
   351  				ctmp[k] += tmp * v
   352  				tmp2 += btmp[k] * v
   353  			}
   354  			c[i*ldc+j] *= beta
   355  			c[i*ldc+j] += tmp*a[j*lda+j] + alpha*tmp2
   356  		}
   357  	}
   358  }
   359  
   360  // Dsyrk performs one of the symmetric rank-k operations
   361  //  C = alpha * A * Aᵀ + beta * C  if tA == blas.NoTrans
   362  //  C = alpha * Aᵀ * A + beta * C  if tA == blas.Trans or tA == blas.ConjTrans
   363  // where A is an n×k or k×n matrix, C is an n×n symmetric matrix, and alpha and
   364  // beta are scalars.
   365  func (Implementation) Dsyrk(ul blas.Uplo, tA blas.Transpose, n, k int, alpha float64, a []float64, lda int, beta float64, c []float64, ldc int) {
   366  	if ul != blas.Lower && ul != blas.Upper {
   367  		panic(badUplo)
   368  	}
   369  	if tA != blas.Trans && tA != blas.NoTrans && tA != blas.ConjTrans {
   370  		panic(badTranspose)
   371  	}
   372  	if n < 0 {
   373  		panic(nLT0)
   374  	}
   375  	if k < 0 {
   376  		panic(kLT0)
   377  	}
   378  	row, col := k, n
   379  	if tA == blas.NoTrans {
   380  		row, col = n, k
   381  	}
   382  	if lda < max(1, col) {
   383  		panic(badLdA)
   384  	}
   385  	if ldc < max(1, n) {
   386  		panic(badLdC)
   387  	}
   388  
   389  	// Quick return if possible.
   390  	if n == 0 {
   391  		return
   392  	}
   393  
   394  	// For zero matrix size the following slice length checks are trivially satisfied.
   395  	if len(a) < lda*(row-1)+col {
   396  		panic(shortA)
   397  	}
   398  	if len(c) < ldc*(n-1)+n {
   399  		panic(shortC)
   400  	}
   401  
   402  	if alpha == 0 {
   403  		if beta == 0 {
   404  			if ul == blas.Upper {
   405  				for i := 0; i < n; i++ {
   406  					ctmp := c[i*ldc+i : i*ldc+n]
   407  					for j := range ctmp {
   408  						ctmp[j] = 0
   409  					}
   410  				}
   411  				return
   412  			}
   413  			for i := 0; i < n; i++ {
   414  				ctmp := c[i*ldc : i*ldc+i+1]
   415  				for j := range ctmp {
   416  					ctmp[j] = 0
   417  				}
   418  			}
   419  			return
   420  		}
   421  		if ul == blas.Upper {
   422  			for i := 0; i < n; i++ {
   423  				ctmp := c[i*ldc+i : i*ldc+n]
   424  				for j := range ctmp {
   425  					ctmp[j] *= beta
   426  				}
   427  			}
   428  			return
   429  		}
   430  		for i := 0; i < n; i++ {
   431  			ctmp := c[i*ldc : i*ldc+i+1]
   432  			for j := range ctmp {
   433  				ctmp[j] *= beta
   434  			}
   435  		}
   436  		return
   437  	}
   438  	if tA == blas.NoTrans {
   439  		if ul == blas.Upper {
   440  			for i := 0; i < n; i++ {
   441  				ctmp := c[i*ldc+i : i*ldc+n]
   442  				atmp := a[i*lda : i*lda+k]
   443  				if beta == 0 {
   444  					for jc := range ctmp {
   445  						j := jc + i
   446  						ctmp[jc] = alpha * f64.DotUnitary(atmp, a[j*lda:j*lda+k])
   447  					}
   448  				} else {
   449  					for jc, vc := range ctmp {
   450  						j := jc + i
   451  						ctmp[jc] = vc*beta + alpha*f64.DotUnitary(atmp, a[j*lda:j*lda+k])
   452  					}
   453  				}
   454  			}
   455  			return
   456  		}
   457  		for i := 0; i < n; i++ {
   458  			ctmp := c[i*ldc : i*ldc+i+1]
   459  			atmp := a[i*lda : i*lda+k]
   460  			if beta == 0 {
   461  				for j := range ctmp {
   462  					ctmp[j] = alpha * f64.DotUnitary(a[j*lda:j*lda+k], atmp)
   463  				}
   464  			} else {
   465  				for j, vc := range ctmp {
   466  					ctmp[j] = vc*beta + alpha*f64.DotUnitary(a[j*lda:j*lda+k], atmp)
   467  				}
   468  			}
   469  		}
   470  		return
   471  	}
   472  	// Cases where a is transposed.
   473  	if ul == blas.Upper {
   474  		for i := 0; i < n; i++ {
   475  			ctmp := c[i*ldc+i : i*ldc+n]
   476  			if beta == 0 {
   477  				for j := range ctmp {
   478  					ctmp[j] = 0
   479  				}
   480  			} else if beta != 1 {
   481  				for j := range ctmp {
   482  					ctmp[j] *= beta
   483  				}
   484  			}
   485  			for l := 0; l < k; l++ {
   486  				tmp := alpha * a[l*lda+i]
   487  				if tmp != 0 {
   488  					f64.AxpyUnitary(tmp, a[l*lda+i:l*lda+n], ctmp)
   489  				}
   490  			}
   491  		}
   492  		return
   493  	}
   494  	for i := 0; i < n; i++ {
   495  		ctmp := c[i*ldc : i*ldc+i+1]
   496  		if beta != 1 {
   497  			for j := range ctmp {
   498  				ctmp[j] *= beta
   499  			}
   500  		}
   501  		for l := 0; l < k; l++ {
   502  			tmp := alpha * a[l*lda+i]
   503  			if tmp != 0 {
   504  				f64.AxpyUnitary(tmp, a[l*lda:l*lda+i+1], ctmp)
   505  			}
   506  		}
   507  	}
   508  }
   509  
   510  // Dsyr2k performs one of the symmetric rank 2k operations
   511  //  C = alpha * A * Bᵀ + alpha * B * Aᵀ + beta * C  if tA == blas.NoTrans
   512  //  C = alpha * Aᵀ * B + alpha * Bᵀ * A + beta * C  if tA == blas.Trans or tA == blas.ConjTrans
   513  // where A and B are n×k or k×n matrices, C is an n×n symmetric matrix, and
   514  // alpha and beta are scalars.
   515  func (Implementation) Dsyr2k(ul blas.Uplo, tA blas.Transpose, n, k int, alpha float64, a []float64, lda int, b []float64, ldb int, beta float64, c []float64, ldc int) {
   516  	if ul != blas.Lower && ul != blas.Upper {
   517  		panic(badUplo)
   518  	}
   519  	if tA != blas.Trans && tA != blas.NoTrans && tA != blas.ConjTrans {
   520  		panic(badTranspose)
   521  	}
   522  	if n < 0 {
   523  		panic(nLT0)
   524  	}
   525  	if k < 0 {
   526  		panic(kLT0)
   527  	}
   528  	row, col := k, n
   529  	if tA == blas.NoTrans {
   530  		row, col = n, k
   531  	}
   532  	if lda < max(1, col) {
   533  		panic(badLdA)
   534  	}
   535  	if ldb < max(1, col) {
   536  		panic(badLdB)
   537  	}
   538  	if ldc < max(1, n) {
   539  		panic(badLdC)
   540  	}
   541  
   542  	// Quick return if possible.
   543  	if n == 0 {
   544  		return
   545  	}
   546  
   547  	// For zero matrix size the following slice length checks are trivially satisfied.
   548  	if len(a) < lda*(row-1)+col {
   549  		panic(shortA)
   550  	}
   551  	if len(b) < ldb*(row-1)+col {
   552  		panic(shortB)
   553  	}
   554  	if len(c) < ldc*(n-1)+n {
   555  		panic(shortC)
   556  	}
   557  
   558  	if alpha == 0 {
   559  		if beta == 0 {
   560  			if ul == blas.Upper {
   561  				for i := 0; i < n; i++ {
   562  					ctmp := c[i*ldc+i : i*ldc+n]
   563  					for j := range ctmp {
   564  						ctmp[j] = 0
   565  					}
   566  				}
   567  				return
   568  			}
   569  			for i := 0; i < n; i++ {
   570  				ctmp := c[i*ldc : i*ldc+i+1]
   571  				for j := range ctmp {
   572  					ctmp[j] = 0
   573  				}
   574  			}
   575  			return
   576  		}
   577  		if ul == blas.Upper {
   578  			for i := 0; i < n; i++ {
   579  				ctmp := c[i*ldc+i : i*ldc+n]
   580  				for j := range ctmp {
   581  					ctmp[j] *= beta
   582  				}
   583  			}
   584  			return
   585  		}
   586  		for i := 0; i < n; i++ {
   587  			ctmp := c[i*ldc : i*ldc+i+1]
   588  			for j := range ctmp {
   589  				ctmp[j] *= beta
   590  			}
   591  		}
   592  		return
   593  	}
   594  	if tA == blas.NoTrans {
   595  		if ul == blas.Upper {
   596  			for i := 0; i < n; i++ {
   597  				atmp := a[i*lda : i*lda+k]
   598  				btmp := b[i*ldb : i*ldb+k]
   599  				ctmp := c[i*ldc+i : i*ldc+n]
   600  				if beta == 0 {
   601  					for jc := range ctmp {
   602  						j := i + jc
   603  						var tmp1, tmp2 float64
   604  						binner := b[j*ldb : j*ldb+k]
   605  						for l, v := range a[j*lda : j*lda+k] {
   606  							tmp1 += v * btmp[l]
   607  							tmp2 += atmp[l] * binner[l]
   608  						}
   609  						ctmp[jc] = alpha * (tmp1 + tmp2)
   610  					}
   611  				} else {
   612  					for jc := range ctmp {
   613  						j := i + jc
   614  						var tmp1, tmp2 float64
   615  						binner := b[j*ldb : j*ldb+k]
   616  						for l, v := range a[j*lda : j*lda+k] {
   617  							tmp1 += v * btmp[l]
   618  							tmp2 += atmp[l] * binner[l]
   619  						}
   620  						ctmp[jc] *= beta
   621  						ctmp[jc] += alpha * (tmp1 + tmp2)
   622  					}
   623  				}
   624  			}
   625  			return
   626  		}
   627  		for i := 0; i < n; i++ {
   628  			atmp := a[i*lda : i*lda+k]
   629  			btmp := b[i*ldb : i*ldb+k]
   630  			ctmp := c[i*ldc : i*ldc+i+1]
   631  			if beta == 0 {
   632  				for j := 0; j <= i; j++ {
   633  					var tmp1, tmp2 float64
   634  					binner := b[j*ldb : j*ldb+k]
   635  					for l, v := range a[j*lda : j*lda+k] {
   636  						tmp1 += v * btmp[l]
   637  						tmp2 += atmp[l] * binner[l]
   638  					}
   639  					ctmp[j] = alpha * (tmp1 + tmp2)
   640  				}
   641  			} else {
   642  				for j := 0; j <= i; j++ {
   643  					var tmp1, tmp2 float64
   644  					binner := b[j*ldb : j*ldb+k]
   645  					for l, v := range a[j*lda : j*lda+k] {
   646  						tmp1 += v * btmp[l]
   647  						tmp2 += atmp[l] * binner[l]
   648  					}
   649  					ctmp[j] *= beta
   650  					ctmp[j] += alpha * (tmp1 + tmp2)
   651  				}
   652  			}
   653  		}
   654  		return
   655  	}
   656  	if ul == blas.Upper {
   657  		for i := 0; i < n; i++ {
   658  			ctmp := c[i*ldc+i : i*ldc+n]
   659  			switch beta {
   660  			case 0:
   661  				for j := range ctmp {
   662  					ctmp[j] = 0
   663  				}
   664  			case 1:
   665  			default:
   666  				for j := range ctmp {
   667  					ctmp[j] *= beta
   668  				}
   669  			}
   670  			for l := 0; l < k; l++ {
   671  				tmp1 := alpha * b[l*ldb+i]
   672  				tmp2 := alpha * a[l*lda+i]
   673  				btmp := b[l*ldb+i : l*ldb+n]
   674  				if tmp1 != 0 || tmp2 != 0 {
   675  					for j, v := range a[l*lda+i : l*lda+n] {
   676  						ctmp[j] += v*tmp1 + btmp[j]*tmp2
   677  					}
   678  				}
   679  			}
   680  		}
   681  		return
   682  	}
   683  	for i := 0; i < n; i++ {
   684  		ctmp := c[i*ldc : i*ldc+i+1]
   685  		switch beta {
   686  		case 0:
   687  			for j := range ctmp {
   688  				ctmp[j] = 0
   689  			}
   690  		case 1:
   691  		default:
   692  			for j := range ctmp {
   693  				ctmp[j] *= beta
   694  			}
   695  		}
   696  		for l := 0; l < k; l++ {
   697  			tmp1 := alpha * b[l*ldb+i]
   698  			tmp2 := alpha * a[l*lda+i]
   699  			btmp := b[l*ldb : l*ldb+i+1]
   700  			if tmp1 != 0 || tmp2 != 0 {
   701  				for j, v := range a[l*lda : l*lda+i+1] {
   702  					ctmp[j] += v*tmp1 + btmp[j]*tmp2
   703  				}
   704  			}
   705  		}
   706  	}
   707  }
   708  
   709  // Dtrmm performs one of the matrix-matrix operations
   710  //  B = alpha * A * B   if tA == blas.NoTrans and side == blas.Left
   711  //  B = alpha * Aᵀ * B  if tA == blas.Trans or blas.ConjTrans, and side == blas.Left
   712  //  B = alpha * B * A   if tA == blas.NoTrans and side == blas.Right
   713  //  B = alpha * B * Aᵀ  if tA == blas.Trans or blas.ConjTrans, and side == blas.Right
   714  // where A is an n×n or m×m triangular matrix, B is an m×n matrix, and alpha is a scalar.
   715  func (Implementation) Dtrmm(s blas.Side, ul blas.Uplo, tA blas.Transpose, d blas.Diag, m, n int, alpha float64, a []float64, lda int, b []float64, ldb int) {
   716  	if s != blas.Left && s != blas.Right {
   717  		panic(badSide)
   718  	}
   719  	if ul != blas.Lower && ul != blas.Upper {
   720  		panic(badUplo)
   721  	}
   722  	if tA != blas.NoTrans && tA != blas.Trans && tA != blas.ConjTrans {
   723  		panic(badTranspose)
   724  	}
   725  	if d != blas.NonUnit && d != blas.Unit {
   726  		panic(badDiag)
   727  	}
   728  	if m < 0 {
   729  		panic(mLT0)
   730  	}
   731  	if n < 0 {
   732  		panic(nLT0)
   733  	}
   734  	k := n
   735  	if s == blas.Left {
   736  		k = m
   737  	}
   738  	if lda < max(1, k) {
   739  		panic(badLdA)
   740  	}
   741  	if ldb < max(1, n) {
   742  		panic(badLdB)
   743  	}
   744  
   745  	// Quick return if possible.
   746  	if m == 0 || n == 0 {
   747  		return
   748  	}
   749  
   750  	// For zero matrix size the following slice length checks are trivially satisfied.
   751  	if len(a) < lda*(k-1)+k {
   752  		panic(shortA)
   753  	}
   754  	if len(b) < ldb*(m-1)+n {
   755  		panic(shortB)
   756  	}
   757  
   758  	if alpha == 0 {
   759  		for i := 0; i < m; i++ {
   760  			btmp := b[i*ldb : i*ldb+n]
   761  			for j := range btmp {
   762  				btmp[j] = 0
   763  			}
   764  		}
   765  		return
   766  	}
   767  
   768  	nonUnit := d == blas.NonUnit
   769  	if s == blas.Left {
   770  		if tA == blas.NoTrans {
   771  			if ul == blas.Upper {
   772  				for i := 0; i < m; i++ {
   773  					tmp := alpha
   774  					if nonUnit {
   775  						tmp *= a[i*lda+i]
   776  					}
   777  					btmp := b[i*ldb : i*ldb+n]
   778  					f64.ScalUnitary(tmp, btmp)
   779  					for ka, va := range a[i*lda+i+1 : i*lda+m] {
   780  						k := ka + i + 1
   781  						if va != 0 {
   782  							f64.AxpyUnitary(alpha*va, b[k*ldb:k*ldb+n], btmp)
   783  						}
   784  					}
   785  				}
   786  				return
   787  			}
   788  			for i := m - 1; i >= 0; i-- {
   789  				tmp := alpha
   790  				if nonUnit {
   791  					tmp *= a[i*lda+i]
   792  				}
   793  				btmp := b[i*ldb : i*ldb+n]
   794  				f64.ScalUnitary(tmp, btmp)
   795  				for k, va := range a[i*lda : i*lda+i] {
   796  					if va != 0 {
   797  						f64.AxpyUnitary(alpha*va, b[k*ldb:k*ldb+n], btmp)
   798  					}
   799  				}
   800  			}
   801  			return
   802  		}
   803  		// Cases where a is transposed.
   804  		if ul == blas.Upper {
   805  			for k := m - 1; k >= 0; k-- {
   806  				btmpk := b[k*ldb : k*ldb+n]
   807  				for ia, va := range a[k*lda+k+1 : k*lda+m] {
   808  					i := ia + k + 1
   809  					btmp := b[i*ldb : i*ldb+n]
   810  					if va != 0 {
   811  						f64.AxpyUnitary(alpha*va, btmpk, btmp)
   812  					}
   813  				}
   814  				tmp := alpha
   815  				if nonUnit {
   816  					tmp *= a[k*lda+k]
   817  				}
   818  				if tmp != 1 {
   819  					f64.ScalUnitary(tmp, btmpk)
   820  				}
   821  			}
   822  			return
   823  		}
   824  		for k := 0; k < m; k++ {
   825  			btmpk := b[k*ldb : k*ldb+n]
   826  			for i, va := range a[k*lda : k*lda+k] {
   827  				btmp := b[i*ldb : i*ldb+n]
   828  				if va != 0 {
   829  					f64.AxpyUnitary(alpha*va, btmpk, btmp)
   830  				}
   831  			}
   832  			tmp := alpha
   833  			if nonUnit {
   834  				tmp *= a[k*lda+k]
   835  			}
   836  			if tmp != 1 {
   837  				f64.ScalUnitary(tmp, btmpk)
   838  			}
   839  		}
   840  		return
   841  	}
   842  	// Cases where a is on the right
   843  	if tA == blas.NoTrans {
   844  		if ul == blas.Upper {
   845  			for i := 0; i < m; i++ {
   846  				btmp := b[i*ldb : i*ldb+n]
   847  				for k := n - 1; k >= 0; k-- {
   848  					tmp := alpha * btmp[k]
   849  					if tmp == 0 {
   850  						continue
   851  					}
   852  					btmp[k] = tmp
   853  					if nonUnit {
   854  						btmp[k] *= a[k*lda+k]
   855  					}
   856  					f64.AxpyUnitary(tmp, a[k*lda+k+1:k*lda+n], btmp[k+1:n])
   857  				}
   858  			}
   859  			return
   860  		}
   861  		for i := 0; i < m; i++ {
   862  			btmp := b[i*ldb : i*ldb+n]
   863  			for k := 0; k < n; k++ {
   864  				tmp := alpha * btmp[k]
   865  				if tmp == 0 {
   866  					continue
   867  				}
   868  				btmp[k] = tmp
   869  				if nonUnit {
   870  					btmp[k] *= a[k*lda+k]
   871  				}
   872  				f64.AxpyUnitary(tmp, a[k*lda:k*lda+k], btmp[:k])
   873  			}
   874  		}
   875  		return
   876  	}
   877  	// Cases where a is transposed.
   878  	if ul == blas.Upper {
   879  		for i := 0; i < m; i++ {
   880  			btmp := b[i*ldb : i*ldb+n]
   881  			for j, vb := range btmp {
   882  				tmp := vb
   883  				if nonUnit {
   884  					tmp *= a[j*lda+j]
   885  				}
   886  				tmp += f64.DotUnitary(a[j*lda+j+1:j*lda+n], btmp[j+1:n])
   887  				btmp[j] = alpha * tmp
   888  			}
   889  		}
   890  		return
   891  	}
   892  	for i := 0; i < m; i++ {
   893  		btmp := b[i*ldb : i*ldb+n]
   894  		for j := n - 1; j >= 0; j-- {
   895  			tmp := btmp[j]
   896  			if nonUnit {
   897  				tmp *= a[j*lda+j]
   898  			}
   899  			tmp += f64.DotUnitary(a[j*lda:j*lda+j], btmp[:j])
   900  			btmp[j] = alpha * tmp
   901  		}
   902  	}
   903  }