gonum.org/v1/gonum@v0.14.0/mat/cdense_test.go (about)

     1  // Copyright ©2019 The Gonum Authors. All rights reserved.
     2  // Use of this source code is governed by a BSD-style
     3  // license that can be found in the LICENSE file.
     4  
     5  package mat
     6  
     7  import (
     8  	"math/cmplx"
     9  	"testing"
    10  
    11  	"golang.org/x/exp/rand"
    12  )
    13  
    14  func TestCDenseNewAtSet(t *testing.T) {
    15  	t.Parallel()
    16  	for cas, test := range []struct {
    17  		a          []complex128
    18  		rows, cols int
    19  	}{
    20  		{
    21  			a: []complex128{0, 0, 0,
    22  				0, 0, 0,
    23  				0, 0, 0},
    24  			rows: 3,
    25  			cols: 3,
    26  		},
    27  	} {
    28  		aCopy := make([]complex128, len(test.a))
    29  		copy(aCopy, test.a)
    30  		mZero := NewCDense(test.rows, test.cols, nil)
    31  		rows, cols := mZero.Dims()
    32  		if rows != test.rows {
    33  			t.Errorf("unexpected number of rows for test %d: got: %d want: %d", cas, rows, test.rows)
    34  		}
    35  		if cols != test.cols {
    36  			t.Errorf("unexpected number of cols for test %d: got: %d want: %d", cas, cols, test.cols)
    37  		}
    38  		m := NewCDense(test.rows, test.cols, aCopy)
    39  		for i := 0; i < test.rows; i++ {
    40  			for j := 0; j < test.cols; j++ {
    41  				v := m.At(i, j)
    42  				idx := i*test.rows + j
    43  				if v != test.a[idx] {
    44  					t.Errorf("unexpected get value for test %d at i=%d, j=%d: got: %v, want: %v", cas, i, j, v, test.a[idx])
    45  				}
    46  				add := complex(float64(i+1), float64(j+1))
    47  				m.Set(i, j, v+add)
    48  				if m.At(i, j) != test.a[idx]+add {
    49  					t.Errorf("unexpected set value for test %d at i=%d, j=%d: got: %v, want: %v", cas, i, j, v, test.a[idx]+add)
    50  				}
    51  			}
    52  		}
    53  	}
    54  }
    55  
    56  func TestCDenseConjElem(t *testing.T) {
    57  	t.Parallel()
    58  
    59  	rnd := rand.New(rand.NewSource(1))
    60  
    61  	for r := 1; r <= 8; r++ {
    62  		for c := 1; c <= 8; c++ {
    63  			const (
    64  				empty = iota
    65  				fit
    66  				sliced
    67  				self
    68  			)
    69  			for _, dst := range []int{empty, fit, sliced, self} {
    70  				const (
    71  					noTrans = iota
    72  					trans
    73  					conjTrans
    74  					bothHT
    75  					bothTH
    76  				)
    77  				for _, src := range []int{noTrans, trans, conjTrans, bothHT, bothTH} {
    78  					d := NewCDense(r, c, nil)
    79  					for i := 0; i < r; i++ {
    80  						for j := 0; j < c; j++ {
    81  							d.Set(i, j, complex(rnd.NormFloat64(), rnd.NormFloat64()))
    82  						}
    83  					}
    84  
    85  					var (
    86  						a  CMatrix
    87  						op string
    88  					)
    89  					switch src {
    90  					case noTrans:
    91  						a = d
    92  					case trans:
    93  						r, c = c, r
    94  						a = d.T()
    95  						op = ".T"
    96  					case conjTrans:
    97  						r, c = c, r
    98  						a = d.H()
    99  						op = ".H"
   100  					case bothHT:
   101  						a = d.H().T()
   102  						op = ".H.T"
   103  					case bothTH:
   104  						a = d.T().H()
   105  						op = ".T.H"
   106  					default:
   107  						panic("invalid src op")
   108  					}
   109  					aCopy := NewCDense(r, c, nil)
   110  					aCopy.Copy(a)
   111  
   112  					var got *CDense
   113  					switch dst {
   114  					case empty:
   115  						got = &CDense{}
   116  					case fit:
   117  						got = NewCDense(r, c, nil)
   118  					case sliced:
   119  						got = NewCDense(r*2, c*2, nil).Slice(1, r+1, 1, c+1).(*CDense)
   120  					case self:
   121  						if r != c && (src == conjTrans || src == trans) {
   122  							continue
   123  						}
   124  						got = d
   125  					default:
   126  						panic("invalid dst size")
   127  					}
   128  
   129  					got.Conj(a)
   130  
   131  					for i := 0; i < r; i++ {
   132  						for j := 0; j < c; j++ {
   133  							if got.At(i, j) != cmplx.Conj(aCopy.At(i, j)) {
   134  								t.Errorf("unexpected results a%s[%d, %d] for r=%d c=%d %v != %v",
   135  									op, i, j, r, c, got.At(i, j), cmplx.Conj(a.At(i, j)),
   136  								)
   137  							}
   138  						}
   139  					}
   140  				}
   141  			}
   142  		}
   143  	}
   144  }
   145  
   146  func TestCDenseGrow(t *testing.T) {
   147  	t.Parallel()
   148  	m := &CDense{}
   149  	m = m.Grow(10, 10).(*CDense)
   150  	rows, cols := m.Dims()
   151  	capRows, capCols := m.Caps()
   152  	if rows != 10 {
   153  		t.Errorf("unexpected value for rows: got: %d want: 10", rows)
   154  	}
   155  	if cols != 10 {
   156  		t.Errorf("unexpected value for cols: got: %d want: 10", cols)
   157  	}
   158  	if capRows != 10 {
   159  		t.Errorf("unexpected value for capRows: got: %d want: 10", capRows)
   160  	}
   161  	if capCols != 10 {
   162  		t.Errorf("unexpected value for capCols: got: %d want: 10", capCols)
   163  	}
   164  
   165  	// Test grow within caps is in-place.
   166  	m.Set(1, 1, 1)
   167  	v := m.Slice(1, 5, 1, 5).(*CDense)
   168  	if v.At(0, 0) != m.At(1, 1) {
   169  		t.Errorf("unexpected viewed element value: got: %v want: %v", v.At(0, 0), m.At(1, 1))
   170  	}
   171  	v = v.Grow(5, 5).(*CDense)
   172  	if !CEqual(v, m.Slice(1, 10, 1, 10)) {
   173  		t.Error("unexpected view value after grow")
   174  	}
   175  
   176  	// Test grow bigger than caps copies.
   177  	v = v.Grow(5, 5).(*CDense)
   178  	if !CEqual(v.Slice(0, 9, 0, 9), m.Slice(1, 10, 1, 10)) {
   179  		t.Error("unexpected mismatched common view value after grow")
   180  	}
   181  	v.Set(0, 0, 0)
   182  	if CEqual(v.Slice(0, 9, 0, 9), m.Slice(1, 10, 1, 10)) {
   183  		t.Error("unexpected matching view value after grow past capacity")
   184  	}
   185  
   186  	// Test grow uses existing data slice when matrix is zero size.
   187  	v.Reset()
   188  	p, l := &v.mat.Data[:1][0], cap(v.mat.Data)
   189  	*p = 1 // This element is at position (-1, -1) relative to v and so should not be visible.
   190  	v = v.Grow(5, 5).(*CDense)
   191  	if &v.mat.Data[:1][0] != p {
   192  		t.Error("grow unexpectedly copied slice within cap limit")
   193  	}
   194  	if cap(v.mat.Data) != l {
   195  		t.Errorf("unexpected change in data slice capacity: got: %d want: %d", cap(v.mat.Data), l)
   196  	}
   197  	if v.At(0, 0) != 0 {
   198  		t.Errorf("unexpected value for At(0, 0): got: %v want: 0", v.At(0, 0))
   199  	}
   200  }