github.com/jingcheng-WU/gonum@v0.9.1-0.20210323123734-f1a2a11a8f7b/mat/cmatrix.go (about) 1 // Copyright ©2013 The Gonum Authors. All rights reserved. 2 // Use of this source code is governed by a BSD-style 3 // license that can be found in the LICENSE file. 4 5 package mat 6 7 import ( 8 "math" 9 "math/cmplx" 10 11 "github.com/jingcheng-WU/gonum/blas/cblas128" 12 "github.com/jingcheng-WU/gonum/floats/scalar" 13 ) 14 15 // CMatrix is the basic matrix interface type for complex matrices. 16 type CMatrix interface { 17 // Dims returns the dimensions of a CMatrix. 18 Dims() (r, c int) 19 20 // At returns the value of a matrix element at row i, column j. 21 // It will panic if i or j are out of bounds for the matrix. 22 At(i, j int) complex128 23 24 // H returns the conjugate transpose of the CMatrix. Whether H 25 // returns a copy of the underlying data is implementation dependent. 26 // This method may be implemented using the ConjTranspose type, which 27 // provides an implicit matrix conjugate transpose. 28 H() CMatrix 29 30 // T returns the transpose of the CMatrix. Whether T returns a copy of the 31 // underlying data is implementation dependent. 32 // This method may be implemented using the CTranspose type, which 33 // provides an implicit matrix transpose. 34 T() CMatrix 35 } 36 37 // A RawCMatrixer can return a cblas128.General representation of the receiver. Changes to the cblas128.General.Data 38 // slice will be reflected in the original matrix, changes to the Rows, Cols and Stride fields will not. 39 type RawCMatrixer interface { 40 RawCMatrix() cblas128.General 41 } 42 43 var ( 44 _ CMatrix = ConjTranspose{} 45 _ UnConjTransposer = ConjTranspose{} 46 ) 47 48 // ConjTranspose is a type for performing an implicit matrix conjugate transpose. 49 // It implements the CMatrix interface, returning values from the conjugate 50 // transpose of the matrix within. 51 type ConjTranspose struct { 52 CMatrix CMatrix 53 } 54 55 // At returns the value of the element at row i and column j of the conjugate 56 // transposed matrix, that is, row j and column i of the CMatrix field. 57 func (t ConjTranspose) At(i, j int) complex128 { 58 z := t.CMatrix.At(j, i) 59 return cmplx.Conj(z) 60 } 61 62 // Dims returns the dimensions of the transposed matrix. The number of rows returned 63 // is the number of columns in the CMatrix field, and the number of columns is 64 // the number of rows in the CMatrix field. 65 func (t ConjTranspose) Dims() (r, c int) { 66 c, r = t.CMatrix.Dims() 67 return r, c 68 } 69 70 // H performs an implicit conjugate transpose by returning the CMatrix field. 71 func (t ConjTranspose) H() CMatrix { 72 return t.CMatrix 73 } 74 75 // T performs an implicit transpose by returning the receiver inside a 76 // CTranspose. 77 func (t ConjTranspose) T() CMatrix { 78 return CTranspose{t} 79 } 80 81 // UnConjTranspose returns the CMatrix field. 82 func (t ConjTranspose) UnConjTranspose() CMatrix { 83 return t.CMatrix 84 } 85 86 // CTranspose is a type for performing an implicit matrix conjugate transpose. 87 // It implements the CMatrix interface, returning values from the conjugate 88 // transpose of the matrix within. 89 type CTranspose struct { 90 CMatrix CMatrix 91 } 92 93 // At returns the value of the element at row i and column j of the conjugate 94 // transposed matrix, that is, row j and column i of the CMatrix field. 95 func (t CTranspose) At(i, j int) complex128 { 96 return t.CMatrix.At(j, i) 97 } 98 99 // Dims returns the dimensions of the transposed matrix. The number of rows returned 100 // is the number of columns in the CMatrix field, and the number of columns is 101 // the number of rows in the CMatrix field. 102 func (t CTranspose) Dims() (r, c int) { 103 c, r = t.CMatrix.Dims() 104 return r, c 105 } 106 107 // H performs an implicit transpose by returning the receiver inside a 108 // ConjTranspose. 109 func (t CTranspose) H() CMatrix { 110 return ConjTranspose{t} 111 } 112 113 // T performs an implicit conjugate transpose by returning the CMatrix field. 114 func (t CTranspose) T() CMatrix { 115 return t.CMatrix 116 } 117 118 // Untranspose returns the CMatrix field. 119 func (t CTranspose) Untranspose() CMatrix { 120 return t.CMatrix 121 } 122 123 // UnConjTransposer is a type that can undo an implicit conjugate transpose. 124 type UnConjTransposer interface { 125 // UnConjTranspose returns the underlying CMatrix stored for the implicit 126 // conjugate transpose. 127 UnConjTranspose() CMatrix 128 129 // Note: This interface is needed to unify all of the Conjugate types. In 130 // the cmat128 methods, we need to test if the CMatrix has been implicitly 131 // transposed. If this is checked by testing for the specific Conjugate type 132 // then the behavior will be different if the user uses H() or HTri() for a 133 // triangular matrix. 134 } 135 136 // CUntransposer is a type that can undo an implicit transpose. 137 type CUntransposer interface { 138 // Untranspose returns the underlying CMatrix stored for the implicit 139 // transpose. 140 Untranspose() CMatrix 141 142 // Note: This interface is needed to unify all of the CTranspose types. In 143 // the cmat128 methods, we need to test if the CMatrix has been implicitly 144 // transposed. If this is checked by testing for the specific CTranspose type 145 // then the behavior will be different if the user uses T() or TTri() for a 146 // triangular matrix. 147 } 148 149 // useC returns a complex128 slice with l elements, using c if it 150 // has the necessary capacity, otherwise creating a new slice. 151 func useC(c []complex128, l int) []complex128 { 152 if l <= cap(c) { 153 return c[:l] 154 } 155 return make([]complex128, l) 156 } 157 158 // useZeroedC returns a complex128 slice with l elements, using c if it 159 // has the necessary capacity, otherwise creating a new slice. The 160 // elements of the returned slice are guaranteed to be zero. 161 func useZeroedC(c []complex128, l int) []complex128 { 162 if l <= cap(c) { 163 c = c[:l] 164 zeroC(c) 165 return c 166 } 167 return make([]complex128, l) 168 } 169 170 // zeroC zeros the given slice's elements. 171 func zeroC(c []complex128) { 172 for i := range c { 173 c[i] = 0 174 } 175 } 176 177 // untransposeCmplx untransposes a matrix if applicable. If a is an CUntransposer 178 // or an UnConjTransposer, then untranspose returns the underlying matrix and true for 179 // the kind of transpose (potentially both). 180 // If it is not, then it returns the input matrix and false for trans and conj. 181 func untransposeCmplx(a CMatrix) (u CMatrix, trans, conj bool) { 182 switch ut := a.(type) { 183 case CUntransposer: 184 trans = true 185 u := ut.Untranspose() 186 if uc, ok := u.(UnConjTransposer); ok { 187 return uc.UnConjTranspose(), trans, true 188 } 189 return u, trans, false 190 case UnConjTransposer: 191 conj = true 192 u := ut.UnConjTranspose() 193 if ut, ok := u.(CUntransposer); ok { 194 return ut.Untranspose(), true, conj 195 } 196 return u, false, conj 197 default: 198 return a, false, false 199 } 200 } 201 202 // untransposeExtractCmplx returns an untransposed matrix in a built-in matrix type. 203 // 204 // The untransposed matrix is returned unaltered if it is a built-in matrix type. 205 // Otherwise, if it implements a Raw method, an appropriate built-in type value 206 // is returned holding the raw matrix value of the input. If neither of these 207 // is possible, the untransposed matrix is returned. 208 func untransposeExtractCmplx(a CMatrix) (u CMatrix, trans, conj bool) { 209 ut, trans, conj := untransposeCmplx(a) 210 switch m := ut.(type) { 211 case *CDense: 212 return m, trans, conj 213 case RawCMatrixer: 214 var d CDense 215 d.SetRawCMatrix(m.RawCMatrix()) 216 return &d, trans, conj 217 default: 218 return ut, trans, conj 219 } 220 } 221 222 // CEqual returns whether the matrices a and b have the same size 223 // and are element-wise equal. 224 func CEqual(a, b CMatrix) bool { 225 ar, ac := a.Dims() 226 br, bc := b.Dims() 227 if ar != br || ac != bc { 228 return false 229 } 230 // TODO(btracey): Add in fast-paths. 231 for i := 0; i < ar; i++ { 232 for j := 0; j < ac; j++ { 233 if a.At(i, j) != b.At(i, j) { 234 return false 235 } 236 } 237 } 238 return true 239 } 240 241 // CEqualApprox returns whether the matrices a and b have the same size and contain all equal 242 // elements with tolerance for element-wise equality specified by epsilon. Matrices 243 // with non-equal shapes are not equal. 244 func CEqualApprox(a, b CMatrix, epsilon float64) bool { 245 // TODO(btracey): 246 ar, ac := a.Dims() 247 br, bc := b.Dims() 248 if ar != br || ac != bc { 249 return false 250 } 251 for i := 0; i < ar; i++ { 252 for j := 0; j < ac; j++ { 253 if !cEqualWithinAbsOrRel(a.At(i, j), b.At(i, j), epsilon, epsilon) { 254 return false 255 } 256 } 257 } 258 return true 259 } 260 261 // TODO(btracey): Move these into a cmplxs if/when we have one. 262 263 func cEqualWithinAbsOrRel(a, b complex128, absTol, relTol float64) bool { 264 if cEqualWithinAbs(a, b, absTol) { 265 return true 266 } 267 return cEqualWithinRel(a, b, relTol) 268 } 269 270 // cEqualWithinAbs returns true if a and b have an absolute 271 // difference of less than tol. 272 func cEqualWithinAbs(a, b complex128, tol float64) bool { 273 return a == b || cmplx.Abs(a-b) <= tol 274 } 275 276 const minNormalFloat64 = 2.2250738585072014e-308 277 278 // cEqualWithinRel returns true if the difference between a and b 279 // is not greater than tol times the greater value. 280 func cEqualWithinRel(a, b complex128, tol float64) bool { 281 if a == b { 282 return true 283 } 284 if cmplx.IsNaN(a) || cmplx.IsNaN(b) { 285 return false 286 } 287 // Cannot play the same trick as in floats/scalar because there are multiple 288 // possible infinities. 289 if cmplx.IsInf(a) { 290 if !cmplx.IsInf(b) { 291 return false 292 } 293 ra := real(a) 294 if math.IsInf(ra, 0) { 295 if ra == real(b) { 296 return scalar.EqualWithinRel(imag(a), imag(b), tol) 297 } 298 return false 299 } 300 if imag(a) == imag(b) { 301 return scalar.EqualWithinRel(ra, real(b), tol) 302 } 303 return false 304 } 305 if cmplx.IsInf(b) { 306 return false 307 } 308 309 delta := cmplx.Abs(a - b) 310 if delta <= minNormalFloat64 { 311 return delta <= tol*minNormalFloat64 312 } 313 return delta/math.Max(cmplx.Abs(a), cmplx.Abs(b)) <= tol 314 }