go.charczuk.com@v0.0.0-20240327042549-bc490516bd1a/sdk/matrix/matrix.go (about)

     1  /*
     2  
     3  Copyright (c) 2024 - Present. Will Charczuk. All rights reserved.
     4  Use of this source code is governed by a MIT license that can be found in the LICENSE file at the root of the repository.
     5  
     6  */
     7  
     8  package matrix
     9  
    10  import (
    11  	"bytes"
    12  	"errors"
    13  	"fmt"
    14  	"math"
    15  )
    16  
const (
	// DefaultEpsilon is the tolerance given to matrices created by the
	// constructors in this file. A matrix's epsilon is the threshold used by
	// EqualsEpsilon and is the value Round passes to its rounding helper.
	DefaultEpsilon = 0.000001
)
    21  
var (
	// ErrDimensionMismatch is returned when the shapes of the operands are
	// not compatible with the requested operation.
	ErrDimensionMismatch = errors.New("dimension mismatch")

	// ErrSingularValue is returned by Inverse when elimination reaches a
	// pivot that is exactly zero, so the matrix cannot be inverted.
	ErrSingularValue = errors.New("singular value")
)
    29  
    30  // New returns a new matrix.
    31  func New(rows, cols int, values ...float64) *Matrix {
    32  	if len(values) == 0 {
    33  		return &Matrix{
    34  			stride:   cols,
    35  			epsilon:  DefaultEpsilon,
    36  			elements: make([]float64, rows*cols),
    37  		}
    38  	}
    39  	elems := make([]float64, rows*cols)
    40  	copy(elems, values)
    41  	return &Matrix{
    42  		stride:   cols,
    43  		epsilon:  DefaultEpsilon,
    44  		elements: elems,
    45  	}
    46  }
    47  
    48  // Identity returns the identity matrix of a given order.
    49  func Identity(order int) *Matrix {
    50  	m := New(order, order)
    51  	for i := 0; i < order; i++ {
    52  		m.Set(i, i, 1)
    53  	}
    54  	return m
    55  }
    56  
    57  // Zero returns a matrix of a given size zeroed.
    58  func Zero(rows, cols int) *Matrix {
    59  	return New(rows, cols)
    60  }
    61  
    62  // Ones returns an matrix of ones.
    63  func Ones(rows, cols int) *Matrix {
    64  	ones := make([]float64, rows*cols)
    65  	for i := 0; i < (rows * cols); i++ {
    66  		ones[i] = 1
    67  	}
    68  
    69  	return &Matrix{
    70  		stride:   cols,
    71  		epsilon:  DefaultEpsilon,
    72  		elements: ones,
    73  	}
    74  }
    75  
    76  // Eye returns the eye matrix.
    77  func Eye(n int) *Matrix {
    78  	m := Zero(n, n)
    79  	for i := 0; i < len(m.elements); i += n + 1 {
    80  		m.elements[i] = 1
    81  	}
    82  	return m
    83  }
    84  
    85  // NewFromArrays creates a matrix from a jagged array set.
    86  func NewFromArrays(a [][]float64) *Matrix {
    87  	rows := len(a)
    88  	if rows == 0 {
    89  		return nil
    90  	}
    91  	cols := len(a[0])
    92  	m := New(rows, cols)
    93  	for row := 0; row < rows; row++ {
    94  		for col := 0; col < cols; col++ {
    95  			m.Set(row, col, a[row][col])
    96  		}
    97  	}
    98  	return m
    99  }
   100  
// Matrix represents a 2d dense array of floats.
//
// The values live in a single flat slice, elements, in row-major order: the
// cell at (row, col) is stored at index stride*row + col, so stride is the
// number of elements per row (the column count). The row count is not stored;
// Size derives it from len(elements) and stride. epsilon is the tolerance
// used by Round and EqualsEpsilon.
type Matrix struct {
	epsilon  float64
	elements []float64
	stride   int
}
   107  
   108  // String returns a string representation of the matrix.
   109  func (m *Matrix) String() string {
   110  	buffer := bytes.NewBuffer(nil)
   111  	rows, cols := m.Size()
   112  
   113  	for row := 0; row < rows; row++ {
   114  		for col := 0; col < cols; col++ {
   115  			buffer.WriteString(f64s(m.Get(row, col)))
   116  			buffer.WriteRune(' ')
   117  		}
   118  		buffer.WriteRune('\n')
   119  	}
   120  	return buffer.String()
   121  }
   122  
// Epsilon returns the tolerance currently configured on the matrix, i.e. the
// value used by Round and EqualsEpsilon.
func (m *Matrix) Epsilon() float64 {
	return m.epsilon
}
   127  
// WithEpsilon replaces the tolerance of the matrix in place and returns the
// same matrix (not a copy) so that calls can be chained.
func (m *Matrix) WithEpsilon(epsilon float64) *Matrix {
	m.epsilon = epsilon
	return m
}
   133  
   134  // Each applies the action to each element of the matrix in
   135  // rows => cols order.
   136  func (m *Matrix) Each(action func(row, col int, value float64)) {
   137  	rows, cols := m.Size()
   138  	for row := 0; row < rows; row++ {
   139  		for col := 0; col < cols; col++ {
   140  			action(row, col, m.Get(row, col))
   141  		}
   142  	}
   143  }
   144  
   145  // Round rounds all the values in a matrix to it epsilon,
   146  // returning a reference to the original
   147  func (m *Matrix) Round() *Matrix {
   148  	rows, cols := m.Size()
   149  	for row := 0; row < rows; row++ {
   150  		for col := 0; col < cols; col++ {
   151  			m.Set(row, col, roundToEpsilon(m.Get(row, col), m.epsilon))
   152  		}
   153  	}
   154  	return m
   155  }
   156  
   157  // Arrays returns the matrix as a two dimensional jagged array.
   158  func (m *Matrix) Arrays() [][]float64 {
   159  	rows, cols := m.Size()
   160  	a := make([][]float64, rows)
   161  
   162  	for row := 0; row < rows; row++ {
   163  		a[row] = make([]float64, cols)
   164  
   165  		for col := 0; col < cols; col++ {
   166  			a[row][col] = m.Get(row, col)
   167  		}
   168  	}
   169  	return a
   170  }
   171  
   172  // Size returns the dimensions of the matrix.
   173  func (m *Matrix) Size() (rows, cols int) {
   174  	rows = len(m.elements) / m.stride
   175  	cols = m.stride
   176  	return
   177  }
   178  
   179  // IsSquare returns if the row count is equal to the column count.
   180  func (m *Matrix) IsSquare() bool {
   181  	return m.stride == (len(m.elements) / m.stride)
   182  }
   183  
   184  // IsSymmetric returns if the matrix is symmetric about its diagonal.
   185  func (m *Matrix) IsSymmetric() bool {
   186  	rows, cols := m.Size()
   187  
   188  	if rows != cols {
   189  		return false
   190  	}
   191  
   192  	for i := 0; i < rows; i++ {
   193  		for j := 0; j < i; j++ {
   194  			if m.Get(i, j) != m.Get(j, i) {
   195  				return false
   196  			}
   197  		}
   198  	}
   199  	return true
   200  }
   201  
   202  // Get returns the element at the given row, col.
   203  func (m *Matrix) Get(row, col int) float64 {
   204  	index := (m.stride * row) + col
   205  	return m.elements[index]
   206  }
   207  
   208  // Set sets a value.
   209  func (m *Matrix) Set(row, col int, val float64) {
   210  	index := (m.stride * row) + col
   211  	m.elements[index] = val
   212  }
   213  
   214  // Col returns a column of the matrix as a vector.
   215  func (m *Matrix) Col(col int) Vector {
   216  	rows, _ := m.Size()
   217  	values := make([]float64, rows)
   218  	for row := 0; row < rows; row++ {
   219  		values[row] = m.Get(row, col)
   220  	}
   221  	return Vector(values)
   222  }
   223  
   224  // Row returns a row of the matrix as a vector.
   225  func (m *Matrix) Row(row int) Vector {
   226  	_, cols := m.Size()
   227  	values := make([]float64, cols)
   228  	for col := 0; col < cols; col++ {
   229  		values[col] = m.Get(row, col)
   230  	}
   231  	return Vector(values)
   232  }
   233  
   234  // SubMatrix returns a sub matrix from a given outer matrix.
   235  func (m *Matrix) SubMatrix(i, j, rows, cols int) *Matrix {
   236  	return &Matrix{
   237  		elements: m.elements[i*m.stride+j : i*m.stride+j+(rows-1)*m.stride+cols],
   238  		stride:   m.stride,
   239  		epsilon:  m.epsilon,
   240  	}
   241  }
   242  
   243  // ScaleRow applies a scale to an entire row.
   244  func (m *Matrix) ScaleRow(row int, scale float64) {
   245  	startIndex := row * m.stride
   246  	for i := startIndex; i < m.stride; i++ {
   247  		m.elements[i] = m.elements[i] * scale
   248  	}
   249  }
   250  
   251  func (m *Matrix) scaleAddRow(rd int, rs int, f float64) {
   252  	indexd := rd * m.stride
   253  	indexs := rs * m.stride
   254  	for col := 0; col < m.stride; col++ {
   255  		m.elements[indexd] += f * m.elements[indexs]
   256  		indexd++
   257  		indexs++
   258  	}
   259  }
   260  
   261  // SwapRows swaps a row in the matrix in place.
   262  func (m *Matrix) SwapRows(i, j int) {
   263  	var vi, vj float64
   264  	for col := 0; col < m.stride; col++ {
   265  		vi = m.Get(i, col)
   266  		vj = m.Get(j, col)
   267  		m.Set(i, col, vj)
   268  		m.Set(j, col, vi)
   269  	}
   270  }
   271  
   272  // Augment concatenates two matrices about the horizontal.
   273  func (m *Matrix) Augment(m2 *Matrix) (*Matrix, error) {
   274  	mr, mc := m.Size()
   275  	m2r, m2c := m2.Size()
   276  	if mr != m2r {
   277  		return nil, ErrDimensionMismatch
   278  	}
   279  
   280  	m3 := Zero(mr, mc+m2c)
   281  	for row := 0; row < mr; row++ {
   282  		for col := 0; col < mc; col++ {
   283  			m3.Set(row, col, m.Get(row, col))
   284  		}
   285  		for col := 0; col < m2c; col++ {
   286  			m3.Set(row, mc+col, m2.Get(row, col))
   287  		}
   288  	}
   289  	return m3, nil
   290  }
   291  
   292  // Copy returns a duplicate of a given matrix.
   293  func (m *Matrix) Copy() *Matrix {
   294  	m2 := &Matrix{stride: m.stride, epsilon: m.epsilon, elements: make([]float64, len(m.elements))}
   295  	copy(m2.elements, m.elements)
   296  	return m2
   297  }
   298  
   299  // DiagonalVector returns a vector from the diagonal of a matrix.
   300  func (m *Matrix) DiagonalVector() Vector {
   301  	rows, cols := m.Size()
   302  	rank := minInt(rows, cols)
   303  	values := make([]float64, rank)
   304  
   305  	for index := 0; index < rank; index++ {
   306  		values[index] = m.Get(index, index)
   307  	}
   308  	return Vector(values)
   309  }
   310  
   311  // Diagonal returns a matrix from the diagonal of a matrix.
   312  func (m *Matrix) Diagonal() *Matrix {
   313  	rows, cols := m.Size()
   314  	rank := minInt(rows, cols)
   315  	m2 := New(rank, rank)
   316  
   317  	for index := 0; index < rank; index++ {
   318  		m2.Set(index, index, m.Get(index, index))
   319  	}
   320  	return m2
   321  }
   322  
   323  // Equals returns if a matrix equals another matrix.
   324  func (m *Matrix) Equals(other *Matrix) bool {
   325  	if m != nil && other == nil {
   326  		return false
   327  	} else if other == nil {
   328  		return true
   329  	}
   330  
   331  	if m.stride != other.stride {
   332  		return false
   333  	}
   334  
   335  	msize := len(m.elements)
   336  	m2size := len(other.elements)
   337  
   338  	if msize != m2size {
   339  		return false
   340  	}
   341  
   342  	for i := 0; i < msize; i++ {
   343  		if m.elements[i] != other.elements[i] {
   344  			return false
   345  		}
   346  	}
   347  	return true
   348  }
   349  
   350  // EqualsEpsilon returns if a matrix element-wise in epsilon to another matrix.
   351  func (m *Matrix) EqualsEpsilon(other *Matrix) bool {
   352  	if m != nil && other == nil {
   353  		return false
   354  	} else if other == nil {
   355  		return true
   356  	}
   357  
   358  	if m.stride != other.stride {
   359  		return false
   360  	}
   361  
   362  	msize := len(m.elements)
   363  	m2size := len(other.elements)
   364  
   365  	if msize != m2size {
   366  		return false
   367  	}
   368  
   369  	for i := 0; i < msize; i++ {
   370  		if !inEpsilon(m.elements[i], other.elements[i], m.epsilon) {
   371  			return false
   372  		}
   373  	}
   374  	return true
   375  }
   376  
   377  func inEpsilon(a, b, epsilon float64) bool {
   378  	if a > b {
   379  		return (a - b) < epsilon
   380  	}
   381  	return (b - a) < epsilon
   382  }
   383  
   384  // L returns the matrix with zeros below the diagonal.
   385  func (m *Matrix) L() *Matrix {
   386  	rows, cols := m.Size()
   387  	m2 := New(rows, cols)
   388  	for row := 0; row < rows; row++ {
   389  		for col := row; col < cols; col++ {
   390  			m2.Set(row, col, m.Get(row, col))
   391  		}
   392  	}
   393  	return m2
   394  }
   395  
   396  // U returns the matrix with zeros above the diagonal.
   397  // Does not include the diagonal.
   398  func (m *Matrix) U() *Matrix {
   399  	rows, cols := m.Size()
   400  	m2 := New(rows, cols)
   401  	for row := 0; row < rows; row++ {
   402  		for col := 0; col < row && col < cols; col++ {
   403  			m2.Set(row, col, m.Get(row, col))
   404  		}
   405  	}
   406  	return m2
   407  }
   408  
   409  // math operations
   410  
   411  // Multiply multiplies two matrices.
   412  func (m *Matrix) Multiply(m2 *Matrix) (m3 *Matrix, err error) {
   413  	if m.stride*m2.stride != len(m2.elements) {
   414  		return nil, ErrDimensionMismatch
   415  	}
   416  
   417  	m3 = &Matrix{epsilon: m.epsilon, stride: m2.stride, elements: make([]float64, (len(m.elements)/m.stride)*m2.stride)}
   418  	for m1c0, m3x := 0, 0; m1c0 < len(m.elements); m1c0 += m.stride {
   419  		for m2r0 := 0; m2r0 < m2.stride; m2r0++ {
   420  			for m1x, m2x := m1c0, m2r0; m2x < len(m2.elements); m2x += m2.stride {
   421  				m3.elements[m3x] += m.elements[m1x] * m2.elements[m2x]
   422  				m1x++
   423  			}
   424  			m3x++
   425  		}
   426  	}
   427  	return
   428  }
   429  
// Pivotize builds a permutation matrix from the columns of m.
//
// A permutation vector pv starts out as the identity. For each column j the
// rows from j downwards are scanned for the largest entry in that column
// (a plain signed comparison, not a comparison of absolute values); when that
// entry is not already on row j, positions j and row of pv are swapped. The
// matrix itself is never reordered, so later columns are scanned in the
// original row order. The result is a stride x stride matrix with a single 1
// in each row r, at column pv[r].
//
// NOTE(review): both dimensions are taken from stride, so this presumably
// expects a square matrix; confirm against callers.
func (m *Matrix) Pivotize() *Matrix {
	pv := make([]int, m.stride)

	for i := range pv {
		pv[i] = i
	}

	// dx is the flat index of the diagonal cell (j, j).
	for j, dx := 0, 0; j < m.stride; j++ {
		row := j
		max := m.elements[dx]
		// ixcj is the flat index of cell (i, j); it steps down column j.
		for i, ixcj := j, dx; i < m.stride; i++ {
			if m.elements[ixcj] > max {
				max = m.elements[ixcj]
				row = i
			}
			ixcj += m.stride
		}
		if j != row {
			pv[row], pv[j] = pv[j], pv[row]
		}
		// The next diagonal cell is one row down and one column across.
		dx += m.stride + 1
	}
	p := Zero(m.stride, m.stride)
	for r, c := range pv {
		p.elements[r*m.stride+c] = 1
	}
	return p
}
   459  
   460  // Times returns the product of a matrix and another.
   461  func (m *Matrix) Times(m2 *Matrix) (*Matrix, error) {
   462  	mr, mc := m.Size()
   463  	m2r, m2c := m2.Size()
   464  
   465  	if mc != m2r {
   466  		return nil, fmt.Errorf("cannot multiply (%dx%d) and (%dx%d)", mr, mc, m2r, m2c)
   467  		//return nil, ErrDimensionMismatch
   468  	}
   469  
   470  	c := Zero(mr, m2c)
   471  
   472  	for i := 0; i < mr; i++ {
   473  		sums := c.elements[i*c.stride : (i+1)*c.stride]
   474  		for k, a := range m.elements[i*m.stride : i*m.stride+m.stride] {
   475  			for j, b := range m2.elements[k*m2.stride : k*m2.stride+m2.stride] {
   476  				sums[j] += a * b
   477  			}
   478  		}
   479  	}
   480  
   481  	return c, nil
   482  }
   483  
   484  // Decompositions
   485  
// LU performs the LU decomposition.
//
// It returns three stride x stride matrices: p is the permutation matrix
// produced by Pivotize, and l and u are filled in column by column from the
// row-permuted matrix p*m. l receives ones on its diagonal.
//
// NOTE(review): presumably the result satisfies p*m == l*u with l lower and u
// upper triangular; confirm with a test before relying on it.
//
// NOTE(review): the error from Multiply is discarded. Multiply returns a nil
// matrix on a dimension mismatch, in which case the m.stride read below would
// dereference nil. This looks like it expects a square matrix; confirm
// against callers.
//
// NOTE(review): the entries of l are divided by the diagonal entry of u
// without a zero check.
func (m *Matrix) LU() (l, u, p *Matrix) {
	l = Zero(m.stride, m.stride)
	u = Zero(m.stride, m.stride)
	p = m.Pivotize()
	// From here on m refers to the row-permuted product, not the receiver.
	m, _ = p.Multiply(m)
	// jxc0 is the flat index of the first cell of row j.
	for j, jxc0 := 0, 0; j < m.stride; j++ {
		l.elements[jxc0+j] = 1
		// Column j of u, for rows 0 through j.
		for i, ixc0 := 0, 0; ixc0 <= jxc0; i++ {
			sum := 0.
			for k, kxcj := 0, j; k < i; k++ {
				sum += u.elements[kxcj] * l.elements[ixc0+k]
				kxcj += m.stride
			}
			u.elements[ixc0+j] = m.elements[ixc0+j] - sum
			ixc0 += m.stride
		}
		// Column j of l, for rows j through the last row.
		for ixc0 := jxc0; ixc0 < len(m.elements); ixc0 += m.stride {
			sum := 0.
			for k, kxcj := 0, j; k < j; k++ {
				sum += u.elements[kxcj] * l.elements[ixc0+k]
				kxcj += m.stride
			}
			l.elements[ixc0+j] = (m.elements[ixc0+j] - sum) / u.elements[jxc0+j]
		}
		jxc0 += m.stride
	}
	return
}
   515  
// QR performs the qr decomposition.
//
// It works on a copy of m, so the receiver is not modified, and returns two
// new matrices q and r, each with the same shape as m. Both results are
// passed through Round (using the default epsilon they were created with)
// by a deferred function just before returning.
//
// NOTE(review): the structure looks like a Householder-reflection scheme, and
// presumably q*r reproduces m; confirm with a test before relying on it.
//
// NOTE(review): r.Set(k, k, ...) and q.Set(k, k, ...) run for every column k
// of a rows x cols result, so this appears to assume rows >= cols; confirm.
func (m *Matrix) QR() (q, r *Matrix) {
	// Round the named results on the way out.
	defer func() {
		q = q.Round()
		r = r.Round()
	}()

	rows, cols := m.Size()
	qr := m.Copy()
	q = New(rows, cols)
	r = New(rows, cols)

	var i, j, k int
	var norm, s float64

	// First pass: transform the working copy one column at a time and fill r.
	for k = 0; k < cols; k++ {
		// Norm of column k from the diagonal down, accumulated with
		// math.Hypot.
		norm = 0
		for i = k; i < rows; i++ {
			norm = math.Hypot(norm, qr.Get(i, k))
		}

		if norm != 0 {
			// Give norm the sign of the diagonal entry.
			if qr.Get(k, k) < 0 {
				norm = -norm
			}

			// Scale the column by the norm, then add one to its diagonal
			// entry.
			for i = k; i < rows; i++ {
				qr.Set(i, k, qr.Get(i, k)/norm)
			}
			qr.Set(k, k, qr.Get(k, k)+1.0)

			// Apply the transformation to the remaining columns.
			for j = k + 1; j < cols; j++ {
				s = 0
				for i = k; i < rows; i++ {
					s += qr.Get(i, k) * qr.Get(i, j)
				}
				s = -s / qr.Get(k, k)
				for i = k; i < rows; i++ {
					qr.Set(i, j, qr.Get(i, j)+s*qr.Get(i, k))

					// Entries above the diagonal are recorded in r.
					if i < j {
						r.Set(i, j, qr.Get(i, j))
					}
				}

			}
		}

		// The diagonal of r is the negated (signed) column norm.
		r.Set(k, k, -norm)

	}

	// Second pass: build q from the columns stored in qr, walking the
	// columns from last to first. i, j and k are all reassigned by the loops
	// below, so they need no reset here.

	for k = cols - 1; k >= 0; k-- {
		q.Set(k, k, 1.0)
		for j = k; j < cols; j++ {
			if qr.Get(k, k) != 0 {
				s = 0
				for i = k; i < rows; i++ {
					s += qr.Get(i, k) * q.Get(i, j)
				}
				s = -s / qr.Get(k, k)
				for i = k; i < rows; i++ {
					q.Set(i, j, q.Get(i, j)+s*qr.Get(i, k))
				}
			}
		}
	}

	return
}
   590  
   591  // Transpose flips a matrix about its diagonal, returning a new copy.
   592  func (m *Matrix) Transpose() *Matrix {
   593  	rows, cols := m.Size()
   594  	m2 := Zero(cols, rows)
   595  	for i := 0; i < rows; i++ {
   596  		for j := 0; j < cols; j++ {
   597  			m2.Set(j, i, m.Get(i, j))
   598  		}
   599  	}
   600  	return m2
   601  }
   602  
   603  // Inverse returns a matrix such that M*I==1.
   604  func (m *Matrix) Inverse() (*Matrix, error) {
   605  	if !m.IsSymmetric() {
   606  		return nil, ErrDimensionMismatch
   607  	}
   608  
   609  	rows, cols := m.Size()
   610  
   611  	aug, _ := m.Augment(Eye(rows))
   612  	for i := 0; i < rows; i++ {
   613  		j := i
   614  		for k := i; k < rows; k++ {
   615  			if math.Abs(aug.Get(k, i)) > math.Abs(aug.Get(j, i)) {
   616  				j = k
   617  			}
   618  		}
   619  		if j != i {
   620  			aug.SwapRows(i, j)
   621  		}
   622  		if aug.Get(i, i) == 0 {
   623  			return nil, ErrSingularValue
   624  		}
   625  		aug.ScaleRow(i, 1.0/aug.Get(i, i))
   626  		for k := 0; k < rows; k++ {
   627  			if k == i {
   628  				continue
   629  			}
   630  			aug.scaleAddRow(k, i, -aug.Get(k, i))
   631  		}
   632  	}
   633  	return aug.SubMatrix(0, cols, rows, cols), nil
   634  }