gonum.org/v1/gonum@v0.14.0/mat/cmatrix.go (about)

     1  // Copyright ©2013 The Gonum Authors. All rights reserved.
     2  // Use of this source code is governed by a BSD-style
     3  // license that can be found in the LICENSE file.
     4  
     5  package mat
     6  
     7  import (
     8  	"math"
     9  	"math/cmplx"
    10  
    11  	"gonum.org/v1/gonum/blas/cblas128"
    12  	"gonum.org/v1/gonum/floats/scalar"
    13  )
    14  
    15  // CMatrix is the basic matrix interface type for complex matrices.
    16  type CMatrix interface {
    17  	// Dims returns the dimensions of a CMatrix.
    18  	Dims() (r, c int)
    19  
    20  	// At returns the value of a matrix element at row i, column j.
    21  	// It will panic if i or j are out of bounds for the matrix.
    22  	At(i, j int) complex128
    23  
    24  	// H returns the conjugate transpose of the CMatrix. Whether H
    25  	// returns a copy of the underlying data is implementation dependent.
    26  	// This method may be implemented using the ConjTranspose type, which
    27  	// provides an implicit matrix conjugate transpose.
    28  	H() CMatrix
    29  
    30  	// T returns the transpose of the CMatrix. Whether T returns a copy of the
    31  	// underlying data is implementation dependent.
    32  	// This method may be implemented using the CTranspose type, which
    33  	// provides an implicit matrix transpose.
    34  	T() CMatrix
    35  }
    36  
    37  // A RawCMatrixer can return a cblas128.General representation of the receiver. Changes to the cblas128.General.Data
    38  // slice will be reflected in the original matrix, changes to the Rows, Cols and Stride fields will not.
    39  type RawCMatrixer interface {
    40  	RawCMatrix() cblas128.General
    41  }
    42  
    43  var (
    44  	_ CMatrix          = ConjTranspose{}
    45  	_ UnConjTransposer = ConjTranspose{}
    46  )
    47  
    48  // ConjTranspose is a type for performing an implicit matrix conjugate transpose.
    49  // It implements the CMatrix interface, returning values from the conjugate
    50  // transpose of the matrix within.
    51  type ConjTranspose struct {
    52  	CMatrix CMatrix
    53  }
    54  
    55  // At returns the value of the element at row i and column j of the conjugate
    56  // transposed matrix, that is, row j and column i of the CMatrix field.
    57  func (t ConjTranspose) At(i, j int) complex128 {
    58  	z := t.CMatrix.At(j, i)
    59  	return cmplx.Conj(z)
    60  }
    61  
    62  // Dims returns the dimensions of the transposed matrix. The number of rows returned
    63  // is the number of columns in the CMatrix field, and the number of columns is
    64  // the number of rows in the CMatrix field.
    65  func (t ConjTranspose) Dims() (r, c int) {
    66  	c, r = t.CMatrix.Dims()
    67  	return r, c
    68  }
    69  
    70  // H performs an implicit conjugate transpose by returning the CMatrix field.
    71  func (t ConjTranspose) H() CMatrix {
    72  	return t.CMatrix
    73  }
    74  
    75  // T performs an implicit transpose by returning the receiver inside a
    76  // CTranspose.
    77  func (t ConjTranspose) T() CMatrix {
    78  	return CTranspose{t}
    79  }
    80  
    81  // UnConjTranspose returns the CMatrix field.
    82  func (t ConjTranspose) UnConjTranspose() CMatrix {
    83  	return t.CMatrix
    84  }
    85  
    86  // CTranspose is a type for performing an implicit matrix conjugate transpose.
    87  // It implements the CMatrix interface, returning values from the conjugate
    88  // transpose of the matrix within.
    89  type CTranspose struct {
    90  	CMatrix CMatrix
    91  }
    92  
    93  // At returns the value of the element at row i and column j of the conjugate
    94  // transposed matrix, that is, row j and column i of the CMatrix field.
    95  func (t CTranspose) At(i, j int) complex128 {
    96  	return t.CMatrix.At(j, i)
    97  }
    98  
    99  // Dims returns the dimensions of the transposed matrix. The number of rows returned
   100  // is the number of columns in the CMatrix field, and the number of columns is
   101  // the number of rows in the CMatrix field.
   102  func (t CTranspose) Dims() (r, c int) {
   103  	c, r = t.CMatrix.Dims()
   104  	return r, c
   105  }
   106  
   107  // H performs an implicit transpose by returning the receiver inside a
   108  // ConjTranspose.
   109  func (t CTranspose) H() CMatrix {
   110  	return ConjTranspose{t}
   111  }
   112  
   113  // T performs an implicit conjugate transpose by returning the CMatrix field.
   114  func (t CTranspose) T() CMatrix {
   115  	return t.CMatrix
   116  }
   117  
   118  // Untranspose returns the CMatrix field.
   119  func (t CTranspose) Untranspose() CMatrix {
   120  	return t.CMatrix
   121  }
   122  
   123  // UnConjTransposer is a type that can undo an implicit conjugate transpose.
   124  type UnConjTransposer interface {
   125  	// UnConjTranspose returns the underlying CMatrix stored for the implicit
   126  	// conjugate transpose.
   127  	UnConjTranspose() CMatrix
   128  
   129  	// Note: This interface is needed to unify all of the Conjugate types. In
   130  	// the cmat128 methods, we need to test if the CMatrix has been implicitly
   131  	// transposed. If this is checked by testing for the specific Conjugate type
   132  	// then the behavior will be different if the user uses H() or HTri() for a
   133  	// triangular matrix.
   134  }
   135  
   136  // CUntransposer is a type that can undo an implicit transpose.
   137  type CUntransposer interface {
   138  	// Untranspose returns the underlying CMatrix stored for the implicit
   139  	// transpose.
   140  	Untranspose() CMatrix
   141  
   142  	// Note: This interface is needed to unify all of the CTranspose types. In
   143  	// the cmat128 methods, we need to test if the CMatrix has been implicitly
   144  	// transposed. If this is checked by testing for the specific CTranspose type
   145  	// then the behavior will be different if the user uses T() or TTri() for a
   146  	// triangular matrix.
   147  }
   148  
   149  // useC returns a complex128 slice with l elements, using c if it
   150  // has the necessary capacity, otherwise creating a new slice.
   151  func useC(c []complex128, l int) []complex128 {
   152  	if l <= cap(c) {
   153  		return c[:l]
   154  	}
   155  	return make([]complex128, l)
   156  }
   157  
   158  // useZeroedC returns a complex128 slice with l elements, using c if it
   159  // has the necessary capacity, otherwise creating a new slice. The
   160  // elements of the returned slice are guaranteed to be zero.
   161  func useZeroedC(c []complex128, l int) []complex128 {
   162  	if l <= cap(c) {
   163  		c = c[:l]
   164  		zeroC(c)
   165  		return c
   166  	}
   167  	return make([]complex128, l)
   168  }
   169  
   170  // zeroC zeros the given slice's elements.
   171  func zeroC(c []complex128) {
   172  	for i := range c {
   173  		c[i] = 0
   174  	}
   175  }
   176  
   177  // untransposeCmplx untransposes a matrix if applicable. If a is an CUntransposer
   178  // or an UnConjTransposer, then untranspose returns the underlying matrix and true for
   179  // the kind of transpose (potentially both).
   180  // If it is not, then it returns the input matrix and false for trans and conj.
   181  func untransposeCmplx(a CMatrix) (u CMatrix, trans, conj bool) {
   182  	switch ut := a.(type) {
   183  	case CUntransposer:
   184  		trans = true
   185  		u := ut.Untranspose()
   186  		if uc, ok := u.(UnConjTransposer); ok {
   187  			return uc.UnConjTranspose(), trans, true
   188  		}
   189  		return u, trans, false
   190  	case UnConjTransposer:
   191  		conj = true
   192  		u := ut.UnConjTranspose()
   193  		if ut, ok := u.(CUntransposer); ok {
   194  			return ut.Untranspose(), true, conj
   195  		}
   196  		return u, false, conj
   197  	default:
   198  		return a, false, false
   199  	}
   200  }
   201  
   202  // untransposeExtractCmplx returns an untransposed matrix in a built-in matrix type.
   203  //
   204  // The untransposed matrix is returned unaltered if it is a built-in matrix type.
   205  // Otherwise, if it implements a Raw method, an appropriate built-in type value
   206  // is returned holding the raw matrix value of the input. If neither of these
   207  // is possible, the untransposed matrix is returned.
   208  func untransposeExtractCmplx(a CMatrix) (u CMatrix, trans, conj bool) {
   209  	ut, trans, conj := untransposeCmplx(a)
   210  	switch m := ut.(type) {
   211  	case *CDense:
   212  		return m, trans, conj
   213  	case RawCMatrixer:
   214  		var d CDense
   215  		d.SetRawCMatrix(m.RawCMatrix())
   216  		return &d, trans, conj
   217  	default:
   218  		return ut, trans, conj
   219  	}
   220  }
   221  
   222  // CEqual returns whether the matrices a and b have the same size
   223  // and are element-wise equal.
   224  func CEqual(a, b CMatrix) bool {
   225  	ar, ac := a.Dims()
   226  	br, bc := b.Dims()
   227  	if ar != br || ac != bc {
   228  		return false
   229  	}
   230  	// TODO(btracey): Add in fast-paths.
   231  	for i := 0; i < ar; i++ {
   232  		for j := 0; j < ac; j++ {
   233  			if a.At(i, j) != b.At(i, j) {
   234  				return false
   235  			}
   236  		}
   237  	}
   238  	return true
   239  }
   240  
   241  // CEqualApprox returns whether the matrices a and b have the same size and contain all equal
   242  // elements with tolerance for element-wise equality specified by epsilon. Matrices
   243  // with non-equal shapes are not equal.
   244  func CEqualApprox(a, b CMatrix, epsilon float64) bool {
   245  	// TODO(btracey):
   246  	ar, ac := a.Dims()
   247  	br, bc := b.Dims()
   248  	if ar != br || ac != bc {
   249  		return false
   250  	}
   251  	for i := 0; i < ar; i++ {
   252  		for j := 0; j < ac; j++ {
   253  			if !cEqualWithinAbsOrRel(a.At(i, j), b.At(i, j), epsilon, epsilon) {
   254  				return false
   255  			}
   256  		}
   257  	}
   258  	return true
   259  }
   260  
   261  // TODO(btracey): Move these into a cmplxs if/when we have one.
   262  
   263  func cEqualWithinAbsOrRel(a, b complex128, absTol, relTol float64) bool {
   264  	if cEqualWithinAbs(a, b, absTol) {
   265  		return true
   266  	}
   267  	return cEqualWithinRel(a, b, relTol)
   268  }
   269  
   270  // cEqualWithinAbs returns true if a and b have an absolute
   271  // difference of less than tol.
   272  func cEqualWithinAbs(a, b complex128, tol float64) bool {
   273  	return a == b || cmplx.Abs(a-b) <= tol
   274  }
   275  
// minNormalFloat64 is the smallest positive normal float64 value, 2^-1022.
// cEqualWithinRel compares differences no larger than this against
// tol*minNormalFloat64 instead of dividing by the operands' magnitudes.
const minNormalFloat64 = 2.2250738585072014e-308
   277  
   278  // cEqualWithinRel returns true if the difference between a and b
   279  // is not greater than tol times the greater value.
   280  func cEqualWithinRel(a, b complex128, tol float64) bool {
   281  	if a == b {
   282  		return true
   283  	}
   284  	if cmplx.IsNaN(a) || cmplx.IsNaN(b) {
   285  		return false
   286  	}
   287  	// Cannot play the same trick as in floats/scalar because there are multiple
   288  	// possible infinities.
   289  	if cmplx.IsInf(a) {
   290  		if !cmplx.IsInf(b) {
   291  			return false
   292  		}
   293  		ra := real(a)
   294  		if math.IsInf(ra, 0) {
   295  			if ra == real(b) {
   296  				return scalar.EqualWithinRel(imag(a), imag(b), tol)
   297  			}
   298  			return false
   299  		}
   300  		if imag(a) == imag(b) {
   301  			return scalar.EqualWithinRel(ra, real(b), tol)
   302  		}
   303  		return false
   304  	}
   305  	if cmplx.IsInf(b) {
   306  		return false
   307  	}
   308  
   309  	delta := cmplx.Abs(a - b)
   310  	if delta <= minNormalFloat64 {
   311  		return delta <= tol*minNormalFloat64
   312  	}
   313  	return delta/math.Max(cmplx.Abs(a), cmplx.Abs(b)) <= tol
   314  }