go.charczuk.com@v0.0.0-20240327042549-bc490516bd1a/sdk/matrix/matrix.go (about) 1 /* 2 3 Copyright (c) 2024 - Present. Will Charczuk. All rights reserved. 4 Use of this source code is governed by a MIT license that can be found in the LICENSE file at the root of the repository. 5 6 */ 7 8 package matrix 9 10 import ( 11 "bytes" 12 "errors" 13 "fmt" 14 "math" 15 ) 16 17 const ( 18 // DefaultEpsilon represents the minimum precision for matrix math operations. 19 DefaultEpsilon = 0.000001 20 ) 21 22 var ( 23 // ErrDimensionMismatch is a typical error. 24 ErrDimensionMismatch = errors.New("dimension mismatch") 25 26 // ErrSingularValue is a typical error. 27 ErrSingularValue = errors.New("singular value") 28 ) 29 30 // New returns a new matrix. 31 func New(rows, cols int, values ...float64) *Matrix { 32 if len(values) == 0 { 33 return &Matrix{ 34 stride: cols, 35 epsilon: DefaultEpsilon, 36 elements: make([]float64, rows*cols), 37 } 38 } 39 elems := make([]float64, rows*cols) 40 copy(elems, values) 41 return &Matrix{ 42 stride: cols, 43 epsilon: DefaultEpsilon, 44 elements: elems, 45 } 46 } 47 48 // Identity returns the identity matrix of a given order. 49 func Identity(order int) *Matrix { 50 m := New(order, order) 51 for i := 0; i < order; i++ { 52 m.Set(i, i, 1) 53 } 54 return m 55 } 56 57 // Zero returns a matrix of a given size zeroed. 58 func Zero(rows, cols int) *Matrix { 59 return New(rows, cols) 60 } 61 62 // Ones returns an matrix of ones. 63 func Ones(rows, cols int) *Matrix { 64 ones := make([]float64, rows*cols) 65 for i := 0; i < (rows * cols); i++ { 66 ones[i] = 1 67 } 68 69 return &Matrix{ 70 stride: cols, 71 epsilon: DefaultEpsilon, 72 elements: ones, 73 } 74 } 75 76 // Eye returns the eye matrix. 77 func Eye(n int) *Matrix { 78 m := Zero(n, n) 79 for i := 0; i < len(m.elements); i += n + 1 { 80 m.elements[i] = 1 81 } 82 return m 83 } 84 85 // NewFromArrays creates a matrix from a jagged array set. 86 func NewFromArrays(a [][]float64) *Matrix { 87 rows := len(a) 88 if rows == 0 { 89 return nil 90 } 91 cols := len(a[0]) 92 m := New(rows, cols) 93 for row := 0; row < rows; row++ { 94 for col := 0; col < cols; col++ { 95 m.Set(row, col, a[row][col]) 96 } 97 } 98 return m 99 } 100 101 // Matrix represents a 2d dense array of floats. 102 type Matrix struct { 103 epsilon float64 104 elements []float64 105 stride int 106 } 107 108 // String returns a string representation of the matrix. 109 func (m *Matrix) String() string { 110 buffer := bytes.NewBuffer(nil) 111 rows, cols := m.Size() 112 113 for row := 0; row < rows; row++ { 114 for col := 0; col < cols; col++ { 115 buffer.WriteString(f64s(m.Get(row, col))) 116 buffer.WriteRune(' ') 117 } 118 buffer.WriteRune('\n') 119 } 120 return buffer.String() 121 } 122 123 // Epsilon returns the maximum precision for math operations. 124 func (m *Matrix) Epsilon() float64 { 125 return m.epsilon 126 } 127 128 // WithEpsilon sets the epsilon on the matrix and returns a reference to the matrix. 129 func (m *Matrix) WithEpsilon(epsilon float64) *Matrix { 130 m.epsilon = epsilon 131 return m 132 } 133 134 // Each applies the action to each element of the matrix in 135 // rows => cols order. 136 func (m *Matrix) Each(action func(row, col int, value float64)) { 137 rows, cols := m.Size() 138 for row := 0; row < rows; row++ { 139 for col := 0; col < cols; col++ { 140 action(row, col, m.Get(row, col)) 141 } 142 } 143 } 144 145 // Round rounds all the values in a matrix to it epsilon, 146 // returning a reference to the original 147 func (m *Matrix) Round() *Matrix { 148 rows, cols := m.Size() 149 for row := 0; row < rows; row++ { 150 for col := 0; col < cols; col++ { 151 m.Set(row, col, roundToEpsilon(m.Get(row, col), m.epsilon)) 152 } 153 } 154 return m 155 } 156 157 // Arrays returns the matrix as a two dimensional jagged array. 158 func (m *Matrix) Arrays() [][]float64 { 159 rows, cols := m.Size() 160 a := make([][]float64, rows) 161 162 for row := 0; row < rows; row++ { 163 a[row] = make([]float64, cols) 164 165 for col := 0; col < cols; col++ { 166 a[row][col] = m.Get(row, col) 167 } 168 } 169 return a 170 } 171 172 // Size returns the dimensions of the matrix. 173 func (m *Matrix) Size() (rows, cols int) { 174 rows = len(m.elements) / m.stride 175 cols = m.stride 176 return 177 } 178 179 // IsSquare returns if the row count is equal to the column count. 180 func (m *Matrix) IsSquare() bool { 181 return m.stride == (len(m.elements) / m.stride) 182 } 183 184 // IsSymmetric returns if the matrix is symmetric about its diagonal. 185 func (m *Matrix) IsSymmetric() bool { 186 rows, cols := m.Size() 187 188 if rows != cols { 189 return false 190 } 191 192 for i := 0; i < rows; i++ { 193 for j := 0; j < i; j++ { 194 if m.Get(i, j) != m.Get(j, i) { 195 return false 196 } 197 } 198 } 199 return true 200 } 201 202 // Get returns the element at the given row, col. 203 func (m *Matrix) Get(row, col int) float64 { 204 index := (m.stride * row) + col 205 return m.elements[index] 206 } 207 208 // Set sets a value. 209 func (m *Matrix) Set(row, col int, val float64) { 210 index := (m.stride * row) + col 211 m.elements[index] = val 212 } 213 214 // Col returns a column of the matrix as a vector. 215 func (m *Matrix) Col(col int) Vector { 216 rows, _ := m.Size() 217 values := make([]float64, rows) 218 for row := 0; row < rows; row++ { 219 values[row] = m.Get(row, col) 220 } 221 return Vector(values) 222 } 223 224 // Row returns a row of the matrix as a vector. 225 func (m *Matrix) Row(row int) Vector { 226 _, cols := m.Size() 227 values := make([]float64, cols) 228 for col := 0; col < cols; col++ { 229 values[col] = m.Get(row, col) 230 } 231 return Vector(values) 232 } 233 234 // SubMatrix returns a sub matrix from a given outer matrix. 235 func (m *Matrix) SubMatrix(i, j, rows, cols int) *Matrix { 236 return &Matrix{ 237 elements: m.elements[i*m.stride+j : i*m.stride+j+(rows-1)*m.stride+cols], 238 stride: m.stride, 239 epsilon: m.epsilon, 240 } 241 } 242 243 // ScaleRow applies a scale to an entire row. 244 func (m *Matrix) ScaleRow(row int, scale float64) { 245 startIndex := row * m.stride 246 for i := startIndex; i < m.stride; i++ { 247 m.elements[i] = m.elements[i] * scale 248 } 249 } 250 251 func (m *Matrix) scaleAddRow(rd int, rs int, f float64) { 252 indexd := rd * m.stride 253 indexs := rs * m.stride 254 for col := 0; col < m.stride; col++ { 255 m.elements[indexd] += f * m.elements[indexs] 256 indexd++ 257 indexs++ 258 } 259 } 260 261 // SwapRows swaps a row in the matrix in place. 262 func (m *Matrix) SwapRows(i, j int) { 263 var vi, vj float64 264 for col := 0; col < m.stride; col++ { 265 vi = m.Get(i, col) 266 vj = m.Get(j, col) 267 m.Set(i, col, vj) 268 m.Set(j, col, vi) 269 } 270 } 271 272 // Augment concatenates two matrices about the horizontal. 273 func (m *Matrix) Augment(m2 *Matrix) (*Matrix, error) { 274 mr, mc := m.Size() 275 m2r, m2c := m2.Size() 276 if mr != m2r { 277 return nil, ErrDimensionMismatch 278 } 279 280 m3 := Zero(mr, mc+m2c) 281 for row := 0; row < mr; row++ { 282 for col := 0; col < mc; col++ { 283 m3.Set(row, col, m.Get(row, col)) 284 } 285 for col := 0; col < m2c; col++ { 286 m3.Set(row, mc+col, m2.Get(row, col)) 287 } 288 } 289 return m3, nil 290 } 291 292 // Copy returns a duplicate of a given matrix. 293 func (m *Matrix) Copy() *Matrix { 294 m2 := &Matrix{stride: m.stride, epsilon: m.epsilon, elements: make([]float64, len(m.elements))} 295 copy(m2.elements, m.elements) 296 return m2 297 } 298 299 // DiagonalVector returns a vector from the diagonal of a matrix. 300 func (m *Matrix) DiagonalVector() Vector { 301 rows, cols := m.Size() 302 rank := minInt(rows, cols) 303 values := make([]float64, rank) 304 305 for index := 0; index < rank; index++ { 306 values[index] = m.Get(index, index) 307 } 308 return Vector(values) 309 } 310 311 // Diagonal returns a matrix from the diagonal of a matrix. 312 func (m *Matrix) Diagonal() *Matrix { 313 rows, cols := m.Size() 314 rank := minInt(rows, cols) 315 m2 := New(rank, rank) 316 317 for index := 0; index < rank; index++ { 318 m2.Set(index, index, m.Get(index, index)) 319 } 320 return m2 321 } 322 323 // Equals returns if a matrix equals another matrix. 324 func (m *Matrix) Equals(other *Matrix) bool { 325 if m != nil && other == nil { 326 return false 327 } else if other == nil { 328 return true 329 } 330 331 if m.stride != other.stride { 332 return false 333 } 334 335 msize := len(m.elements) 336 m2size := len(other.elements) 337 338 if msize != m2size { 339 return false 340 } 341 342 for i := 0; i < msize; i++ { 343 if m.elements[i] != other.elements[i] { 344 return false 345 } 346 } 347 return true 348 } 349 350 // EqualsEpsilon returns if a matrix element-wise in epsilon to another matrix. 351 func (m *Matrix) EqualsEpsilon(other *Matrix) bool { 352 if m != nil && other == nil { 353 return false 354 } else if other == nil { 355 return true 356 } 357 358 if m.stride != other.stride { 359 return false 360 } 361 362 msize := len(m.elements) 363 m2size := len(other.elements) 364 365 if msize != m2size { 366 return false 367 } 368 369 for i := 0; i < msize; i++ { 370 if !inEpsilon(m.elements[i], other.elements[i], m.epsilon) { 371 return false 372 } 373 } 374 return true 375 } 376 377 func inEpsilon(a, b, epsilon float64) bool { 378 if a > b { 379 return (a - b) < epsilon 380 } 381 return (b - a) < epsilon 382 } 383 384 // L returns the matrix with zeros below the diagonal. 385 func (m *Matrix) L() *Matrix { 386 rows, cols := m.Size() 387 m2 := New(rows, cols) 388 for row := 0; row < rows; row++ { 389 for col := row; col < cols; col++ { 390 m2.Set(row, col, m.Get(row, col)) 391 } 392 } 393 return m2 394 } 395 396 // U returns the matrix with zeros above the diagonal. 397 // Does not include the diagonal. 398 func (m *Matrix) U() *Matrix { 399 rows, cols := m.Size() 400 m2 := New(rows, cols) 401 for row := 0; row < rows; row++ { 402 for col := 0; col < row && col < cols; col++ { 403 m2.Set(row, col, m.Get(row, col)) 404 } 405 } 406 return m2 407 } 408 409 // math operations 410 411 // Multiply multiplies two matrices. 412 func (m *Matrix) Multiply(m2 *Matrix) (m3 *Matrix, err error) { 413 if m.stride*m2.stride != len(m2.elements) { 414 return nil, ErrDimensionMismatch 415 } 416 417 m3 = &Matrix{epsilon: m.epsilon, stride: m2.stride, elements: make([]float64, (len(m.elements)/m.stride)*m2.stride)} 418 for m1c0, m3x := 0, 0; m1c0 < len(m.elements); m1c0 += m.stride { 419 for m2r0 := 0; m2r0 < m2.stride; m2r0++ { 420 for m1x, m2x := m1c0, m2r0; m2x < len(m2.elements); m2x += m2.stride { 421 m3.elements[m3x] += m.elements[m1x] * m2.elements[m2x] 422 m1x++ 423 } 424 m3x++ 425 } 426 } 427 return 428 } 429 430 // Pivotize does something i'm not sure what. 431 func (m *Matrix) Pivotize() *Matrix { 432 pv := make([]int, m.stride) 433 434 for i := range pv { 435 pv[i] = i 436 } 437 438 for j, dx := 0, 0; j < m.stride; j++ { 439 row := j 440 max := m.elements[dx] 441 for i, ixcj := j, dx; i < m.stride; i++ { 442 if m.elements[ixcj] > max { 443 max = m.elements[ixcj] 444 row = i 445 } 446 ixcj += m.stride 447 } 448 if j != row { 449 pv[row], pv[j] = pv[j], pv[row] 450 } 451 dx += m.stride + 1 452 } 453 p := Zero(m.stride, m.stride) 454 for r, c := range pv { 455 p.elements[r*m.stride+c] = 1 456 } 457 return p 458 } 459 460 // Times returns the product of a matrix and another. 461 func (m *Matrix) Times(m2 *Matrix) (*Matrix, error) { 462 mr, mc := m.Size() 463 m2r, m2c := m2.Size() 464 465 if mc != m2r { 466 return nil, fmt.Errorf("cannot multiply (%dx%d) and (%dx%d)", mr, mc, m2r, m2c) 467 //return nil, ErrDimensionMismatch 468 } 469 470 c := Zero(mr, m2c) 471 472 for i := 0; i < mr; i++ { 473 sums := c.elements[i*c.stride : (i+1)*c.stride] 474 for k, a := range m.elements[i*m.stride : i*m.stride+m.stride] { 475 for j, b := range m2.elements[k*m2.stride : k*m2.stride+m2.stride] { 476 sums[j] += a * b 477 } 478 } 479 } 480 481 return c, nil 482 } 483 484 // Decompositions 485 486 // LU performs the LU decomposition. 487 func (m *Matrix) LU() (l, u, p *Matrix) { 488 l = Zero(m.stride, m.stride) 489 u = Zero(m.stride, m.stride) 490 p = m.Pivotize() 491 m, _ = p.Multiply(m) 492 for j, jxc0 := 0, 0; j < m.stride; j++ { 493 l.elements[jxc0+j] = 1 494 for i, ixc0 := 0, 0; ixc0 <= jxc0; i++ { 495 sum := 0. 496 for k, kxcj := 0, j; k < i; k++ { 497 sum += u.elements[kxcj] * l.elements[ixc0+k] 498 kxcj += m.stride 499 } 500 u.elements[ixc0+j] = m.elements[ixc0+j] - sum 501 ixc0 += m.stride 502 } 503 for ixc0 := jxc0; ixc0 < len(m.elements); ixc0 += m.stride { 504 sum := 0. 505 for k, kxcj := 0, j; k < j; k++ { 506 sum += u.elements[kxcj] * l.elements[ixc0+k] 507 kxcj += m.stride 508 } 509 l.elements[ixc0+j] = (m.elements[ixc0+j] - sum) / u.elements[jxc0+j] 510 } 511 jxc0 += m.stride 512 } 513 return 514 } 515 516 // QR performs the qr decomposition. 517 func (m *Matrix) QR() (q, r *Matrix) { 518 defer func() { 519 q = q.Round() 520 r = r.Round() 521 }() 522 523 rows, cols := m.Size() 524 qr := m.Copy() 525 q = New(rows, cols) 526 r = New(rows, cols) 527 528 var i, j, k int 529 var norm, s float64 530 531 for k = 0; k < cols; k++ { 532 norm = 0 533 for i = k; i < rows; i++ { 534 norm = math.Hypot(norm, qr.Get(i, k)) 535 } 536 537 if norm != 0 { 538 if qr.Get(k, k) < 0 { 539 norm = -norm 540 } 541 542 for i = k; i < rows; i++ { 543 qr.Set(i, k, qr.Get(i, k)/norm) 544 } 545 qr.Set(k, k, qr.Get(k, k)+1.0) 546 547 for j = k + 1; j < cols; j++ { 548 s = 0 549 for i = k; i < rows; i++ { 550 s += qr.Get(i, k) * qr.Get(i, j) 551 } 552 s = -s / qr.Get(k, k) 553 for i = k; i < rows; i++ { 554 qr.Set(i, j, qr.Get(i, j)+s*qr.Get(i, k)) 555 556 if i < j { 557 r.Set(i, j, qr.Get(i, j)) 558 } 559 } 560 561 } 562 } 563 564 r.Set(k, k, -norm) 565 566 } 567 568 //Q Matrix: 569 // this assignment is ineffectual? 570 // i, j, k = 0, 0, 0 571 572 for k = cols - 1; k >= 0; k-- { 573 q.Set(k, k, 1.0) 574 for j = k; j < cols; j++ { 575 if qr.Get(k, k) != 0 { 576 s = 0 577 for i = k; i < rows; i++ { 578 s += qr.Get(i, k) * q.Get(i, j) 579 } 580 s = -s / qr.Get(k, k) 581 for i = k; i < rows; i++ { 582 q.Set(i, j, q.Get(i, j)+s*qr.Get(i, k)) 583 } 584 } 585 } 586 } 587 588 return 589 } 590 591 // Transpose flips a matrix about its diagonal, returning a new copy. 592 func (m *Matrix) Transpose() *Matrix { 593 rows, cols := m.Size() 594 m2 := Zero(cols, rows) 595 for i := 0; i < rows; i++ { 596 for j := 0; j < cols; j++ { 597 m2.Set(j, i, m.Get(i, j)) 598 } 599 } 600 return m2 601 } 602 603 // Inverse returns a matrix such that M*I==1. 604 func (m *Matrix) Inverse() (*Matrix, error) { 605 if !m.IsSymmetric() { 606 return nil, ErrDimensionMismatch 607 } 608 609 rows, cols := m.Size() 610 611 aug, _ := m.Augment(Eye(rows)) 612 for i := 0; i < rows; i++ { 613 j := i 614 for k := i; k < rows; k++ { 615 if math.Abs(aug.Get(k, i)) > math.Abs(aug.Get(j, i)) { 616 j = k 617 } 618 } 619 if j != i { 620 aug.SwapRows(i, j) 621 } 622 if aug.Get(i, i) == 0 { 623 return nil, ErrSingularValue 624 } 625 aug.ScaleRow(i, 1.0/aug.Get(i, i)) 626 for k := 0; k < rows; k++ { 627 if k == i { 628 continue 629 } 630 aug.scaleAddRow(k, i, -aug.Get(k, i)) 631 } 632 } 633 return aug.SubMatrix(0, cols, rows, cols), nil 634 }