gonum.org/v1/gonum@v0.14.0/mat/list_test.go (about)

     1  // Copyright ©2015 The Gonum Authors. All rights reserved.
     2  // Use of this source code is governed by a BSD-style
     3  // license that can be found in the LICENSE file.
     4  
     5  //lint:file-ignore U1000 A number of functions are here that may be used in future.
     6  
     7  package mat
     8  
     9  import (
    10  	"fmt"
    11  	"math"
    12  	"reflect"
    13  	"testing"
    14  
    15  	"golang.org/x/exp/rand"
    16  
    17  	"gonum.org/v1/gonum/blas"
    18  	"gonum.org/v1/gonum/blas/blas64"
    19  	"gonum.org/v1/gonum/floats"
    20  	"gonum.org/v1/gonum/floats/scalar"
    21  )
    22  
    23  // legalSizeSameRectangular returns whether the two matrices have the same rectangular shape.
    24  func legalSizeSameRectangular(ar, ac, br, bc int) bool {
    25  	if ar != br {
    26  		return false
    27  	}
    28  	if ac != bc {
    29  		return false
    30  	}
    31  	return true
    32  }
    33  
    34  // legalSizeSameSquare returns whether the two matrices have the same square shape.
    35  func legalSizeSameSquare(ar, ac, br, bc int) bool {
    36  	if ar != br {
    37  		return false
    38  	}
    39  	if ac != bc {
    40  		return false
    41  	}
    42  	if ar != ac {
    43  		return false
    44  	}
    45  	return true
    46  }
    47  
    48  // legalSizeSameHeight returns whether the two matrices have the same number of rows.
    49  func legalSizeSameHeight(ar, _, br, _ int) bool {
    50  	return ar == br
    51  }
    52  
    53  // legalSizeSameWidth returns whether the two matrices have the same number of columns.
    54  func legalSizeSameWidth(_, ac, _, bc int) bool {
    55  	return ac == bc
    56  }
    57  
    58  // legalSizeSolve returns whether the two matrices can be used in a linear solve.
    59  func legalSizeSolve(ar, ac, br, bc int) bool {
    60  	return ar == br
    61  }
    62  
    63  // legalSizeSameVec returns whether the two matrices are column vectors.
    64  func legalSizeVector(_, ac, _, bc int) bool {
    65  	return ac == 1 && bc == 1
    66  }
    67  
    68  // legalSizeSameVec returns whether the two matrices are column vectors of the
    69  // same dimension.
    70  func legalSizeSameVec(ar, ac, br, bc int) bool {
    71  	return ac == 1 && bc == 1 && ar == br
    72  }
    73  
    74  // isAnySize returns true for all matrix sizes.
    75  func isAnySize(ar, ac int) bool {
    76  	return true
    77  }
    78  
    79  // isAnySize2 returns true for all matrix sizes.
    80  func isAnySize2(ar, ac, br, bc int) bool {
    81  	return true
    82  }
    83  
    84  // isAnyColumnVector returns true for any column vector sizes.
    85  func isAnyColumnVector(ar, ac int) bool {
    86  	return ac == 1
    87  }
    88  
    89  // isSquare returns whether the input matrix is square.
    90  func isSquare(r, c int) bool {
    91  	return r == c
    92  }
    93  
    94  // sameAnswerFloat returns whether the two inputs are both NaN or are equal.
    95  func sameAnswerFloat(a, b interface{}) bool {
    96  	if math.IsNaN(a.(float64)) {
    97  		return math.IsNaN(b.(float64))
    98  	}
    99  	return a.(float64) == b.(float64)
   100  }
   101  
   102  // sameAnswerFloatApproxTol returns a function that determines whether its two
   103  // inputs are both NaN or within tol of each other.
   104  func sameAnswerFloatApproxTol(tol float64) func(a, b interface{}) bool {
   105  	return func(a, b interface{}) bool {
   106  		if math.IsNaN(a.(float64)) {
   107  			return math.IsNaN(b.(float64))
   108  		}
   109  		return scalar.EqualWithinAbsOrRel(a.(float64), b.(float64), tol, tol)
   110  	}
   111  }
   112  
   113  func sameAnswerF64SliceOfSlice(a, b interface{}) bool {
   114  	for i, v := range a.([][]float64) {
   115  		if same := floats.Same(v, b.([][]float64)[i]); !same {
   116  			return false
   117  		}
   118  	}
   119  	return true
   120  }
   121  
   122  // sameAnswerBool returns whether the two inputs have the same value.
   123  func sameAnswerBool(a, b interface{}) bool {
   124  	return a.(bool) == b.(bool)
   125  }
   126  
   127  // isAnyType returns true for all Matrix types.
   128  func isAnyType(Matrix) bool {
   129  	return true
   130  }
   131  
   132  // legalTypesAll returns true for all Matrix types.
   133  func legalTypesAll(a, b Matrix) bool {
   134  	return true
   135  }
   136  
   137  // legalTypeSym returns whether a is a Symmetric.
   138  func legalTypeSym(a Matrix) bool {
   139  	_, ok := a.(Symmetric)
   140  	return ok
   141  }
   142  
   143  // legalTypeTri returns whether a is a Triangular.
   144  func legalTypeTri(a Matrix) bool {
   145  	_, ok := a.(Triangular)
   146  	return ok
   147  }
   148  
   149  // legalTypeTriLower returns whether a is a Triangular with kind == Lower.
   150  func legalTypeTriLower(a Matrix) bool {
   151  	t, ok := a.(Triangular)
   152  	if !ok {
   153  		return false
   154  	}
   155  	_, kind := t.Triangle()
   156  	return kind == Lower
   157  }
   158  
   159  // legalTypeTriUpper returns whether a is a Triangular with kind == Upper.
   160  func legalTypeTriUpper(a Matrix) bool {
   161  	t, ok := a.(Triangular)
   162  	if !ok {
   163  		return false
   164  	}
   165  	_, kind := t.Triangle()
   166  	return kind == Upper
   167  }
   168  
   169  // legalTypesSym returns whether both input arguments are Symmetric.
   170  func legalTypesSym(a, b Matrix) bool {
   171  	if _, ok := a.(Symmetric); !ok {
   172  		return false
   173  	}
   174  	if _, ok := b.(Symmetric); !ok {
   175  		return false
   176  	}
   177  	return true
   178  }
   179  
   180  // legalTypeVector returns whether v is a Vector.
   181  func legalTypeVector(v Matrix) bool {
   182  	_, ok := v.(Vector)
   183  	return ok
   184  }
   185  
   186  // legalTypeVec returns whether v is a *VecDense.
   187  func legalTypeVecDense(v Matrix) bool {
   188  	_, ok := v.(*VecDense)
   189  	return ok
   190  }
   191  
   192  // legalTypesVectorVector returns whether both inputs are Vector
   193  func legalTypesVectorVector(a, b Matrix) bool {
   194  	if _, ok := a.(Vector); !ok {
   195  		return false
   196  	}
   197  	if _, ok := b.(Vector); !ok {
   198  		return false
   199  	}
   200  	return true
   201  }
   202  
   203  // legalTypesVecDenseVecDense returns whether both inputs are *VecDense.
   204  func legalTypesVecDenseVecDense(a, b Matrix) bool {
   205  	if _, ok := a.(*VecDense); !ok {
   206  		return false
   207  	}
   208  	if _, ok := b.(*VecDense); !ok {
   209  		return false
   210  	}
   211  	return true
   212  }
   213  
   214  // legalTypesMatrixVector returns whether the first input is an arbitrary Matrix
   215  // and the second input is a Vector.
   216  func legalTypesMatrixVector(a, b Matrix) bool {
   217  	_, ok := b.(Vector)
   218  	return ok
   219  }
   220  
   221  // legalTypesMatrixVecDense returns whether the first input is an arbitrary Matrix
   222  // and the second input is a *VecDense.
   223  func legalTypesMatrixVecDense(a, b Matrix) bool {
   224  	_, ok := b.(*VecDense)
   225  	return ok
   226  }
   227  
   228  // legalDims returns whether {m,n} is a valid dimension of the given matrix type.
   229  func legalDims(a Matrix, m, n int) bool {
   230  	switch t := a.(type) {
   231  	default:
   232  		panic("legal dims type not coded")
   233  	case Untransposer:
   234  		return legalDims(t.Untranspose(), n, m)
   235  	case *Dense, *basicMatrix, *BandDense, *basicBanded:
   236  		if m < 0 || n < 0 {
   237  			return false
   238  		}
   239  		return true
   240  	case *SymDense, *TriDense, *basicSymmetric, *basicTriangular,
   241  		*SymBandDense, *basicSymBanded, *TriBandDense, *basicTriBanded,
   242  		*basicDiagonal, *DiagDense, *Tridiag:
   243  		if m < 0 || n < 0 || m != n {
   244  			return false
   245  		}
   246  		return true
   247  	case *VecDense, *basicVector:
   248  		if m < 0 || n < 0 {
   249  			return false
   250  		}
   251  		return n == 1
   252  	}
   253  }
   254  
   255  // returnAs returns the matrix a with the type of t. Used for making a concrete
   256  // type and changing to the basic form.
   257  func returnAs(a, t Matrix) Matrix {
   258  	switch mat := a.(type) {
   259  	default:
   260  		panic("unknown type for a")
   261  	case *Dense:
   262  		switch t.(type) {
   263  		default:
   264  			panic("bad type")
   265  		case *Dense:
   266  			return mat
   267  		case *basicMatrix:
   268  			return asBasicMatrix(mat)
   269  		}
   270  	case *SymDense:
   271  		switch t.(type) {
   272  		default:
   273  			panic("bad type")
   274  		case *SymDense:
   275  			return mat
   276  		case *basicSymmetric:
   277  			return asBasicSymmetric(mat)
   278  		}
   279  	case *TriDense:
   280  		switch t.(type) {
   281  		default:
   282  			panic("bad type")
   283  		case *TriDense:
   284  			return mat
   285  		case *basicTriangular:
   286  			return asBasicTriangular(mat)
   287  		}
   288  	case *BandDense:
   289  		switch t.(type) {
   290  		default:
   291  			panic("bad type")
   292  		case *BandDense:
   293  			return mat
   294  		case *basicBanded:
   295  			return asBasicBanded(mat)
   296  		}
   297  	case *SymBandDense:
   298  		switch t.(type) {
   299  		default:
   300  			panic("bad type")
   301  		case *SymBandDense:
   302  			return mat
   303  		case *basicSymBanded:
   304  			return asBasicSymBanded(mat)
   305  		}
   306  	case *TriBandDense:
   307  		switch t.(type) {
   308  		default:
   309  			panic("bad type")
   310  		case *TriBandDense:
   311  			return mat
   312  		case *basicTriBanded:
   313  			return asBasicTriBanded(mat)
   314  		}
   315  	case *DiagDense:
   316  		switch t.(type) {
   317  		default:
   318  			panic("bad type")
   319  		case *DiagDense:
   320  			return mat
   321  		case *basicDiagonal:
   322  			return asBasicDiagonal(mat)
   323  		}
   324  	case *Tridiag:
   325  		switch t.(type) {
   326  		default:
   327  			panic("bad type")
   328  		case *Tridiag:
   329  			return mat
   330  		}
   331  	}
   332  }
   333  
   334  // retranspose returns the matrix m inside an Untransposer of the type
   335  // of a.
   336  func retranspose(a, m Matrix) Matrix {
   337  	switch a.(type) {
   338  	case TransposeTriBand:
   339  		return TransposeTriBand{m.(TriBanded)}
   340  	case TransposeBand:
   341  		return TransposeBand{m.(Banded)}
   342  	case TransposeTri:
   343  		return TransposeTri{m.(Triangular)}
   344  	case Transpose:
   345  		return Transpose{m}
   346  	case Untransposer:
   347  		panic("unknown transposer type")
   348  	default:
   349  		panic("a is not an untransposer")
   350  	}
   351  }
   352  
   353  // makeRandOf returns a new randomly filled m×n matrix of the underlying matrix type.
   354  func makeRandOf(a Matrix, m, n int, src rand.Source) Matrix {
   355  	rnd := rand.New(src)
   356  	var rMatrix Matrix
   357  	switch t := a.(type) {
   358  	default:
   359  		panic("unknown type for make rand of")
   360  	case Untransposer:
   361  		rMatrix = retranspose(a, makeRandOf(t.Untranspose(), n, m, src))
   362  	case *Dense, *basicMatrix:
   363  		var mat = &Dense{}
   364  		if m != 0 && n != 0 {
   365  			mat = NewDense(m, n, nil)
   366  		}
   367  		for i := 0; i < m; i++ {
   368  			for j := 0; j < n; j++ {
   369  				mat.Set(i, j, rnd.NormFloat64())
   370  			}
   371  		}
   372  		rMatrix = returnAs(mat, t)
   373  	case *VecDense:
   374  		if m == 0 && n == 0 {
   375  			return &VecDense{}
   376  		}
   377  		if n != 1 {
   378  			panic(fmt.Sprintf("bad vector size: m = %v, n = %v", m, n))
   379  		}
   380  		length := m
   381  		inc := 1
   382  		if t.mat.Inc != 0 {
   383  			inc = t.mat.Inc
   384  		}
   385  		mat := &VecDense{
   386  			mat: blas64.Vector{
   387  				N:    length,
   388  				Inc:  inc,
   389  				Data: make([]float64, inc*(length-1)+1),
   390  			},
   391  		}
   392  		for i := 0; i < length; i++ {
   393  			mat.SetVec(i, rnd.NormFloat64())
   394  		}
   395  		return mat
   396  	case *basicVector:
   397  		if n != 1 {
   398  			panic(fmt.Sprintf("bad vector size: m = %v, n = %v", m, n))
   399  		}
   400  		if m == 0 {
   401  			return &basicVector{}
   402  		}
   403  		mat := NewVecDense(m, nil)
   404  		for i := 0; i < m; i++ {
   405  			mat.SetVec(i, rnd.NormFloat64())
   406  		}
   407  		return asBasicVector(mat)
   408  	case *SymDense, *basicSymmetric:
   409  		if m != n {
   410  			panic("bad size")
   411  		}
   412  		mat := &SymDense{}
   413  		if n != 0 {
   414  			mat = NewSymDense(n, nil)
   415  		}
   416  		for i := 0; i < m; i++ {
   417  			for j := i; j < n; j++ {
   418  				mat.SetSym(i, j, rnd.NormFloat64())
   419  			}
   420  		}
   421  		rMatrix = returnAs(mat, t)
   422  	case *TriDense, *basicTriangular:
   423  		if m != n {
   424  			panic("bad size")
   425  		}
   426  
   427  		// This is necessary because we are making
   428  		// a triangle from the zero value, which
   429  		// always returns upper as true.
   430  		var triKind TriKind
   431  		switch t := t.(type) {
   432  		case *TriDense:
   433  			triKind = t.triKind()
   434  		case *basicTriangular:
   435  			triKind = (*TriDense)(t).triKind()
   436  		}
   437  
   438  		if n == 0 {
   439  			uplo := blas.Upper
   440  			if triKind == Lower {
   441  				uplo = blas.Lower
   442  			}
   443  			return returnAs(&TriDense{mat: blas64.Triangular{Uplo: uplo}}, t)
   444  		}
   445  
   446  		mat := NewTriDense(n, triKind, nil)
   447  		if triKind == Upper {
   448  			for i := 0; i < m; i++ {
   449  				for j := i; j < n; j++ {
   450  					mat.SetTri(i, j, rnd.NormFloat64())
   451  				}
   452  			}
   453  		} else {
   454  			for i := 0; i < m; i++ {
   455  				for j := 0; j <= i; j++ {
   456  					mat.SetTri(i, j, rnd.NormFloat64())
   457  				}
   458  			}
   459  		}
   460  		rMatrix = returnAs(mat, t)
   461  	case *BandDense, *basicBanded:
   462  		var kl, ku int
   463  		switch t := t.(type) {
   464  		case *BandDense:
   465  			kl = t.mat.KL
   466  			ku = t.mat.KU
   467  		case *basicBanded:
   468  			ku = (*BandDense)(t).mat.KU
   469  			kl = (*BandDense)(t).mat.KL
   470  		}
   471  		ku = min(ku, n-1)
   472  		kl = min(kl, m-1)
   473  		data := make([]float64, min(m, n+kl)*(kl+ku+1))
   474  		for i := range data {
   475  			data[i] = rnd.NormFloat64()
   476  		}
   477  		mat := NewBandDense(m, n, kl, ku, data)
   478  		rMatrix = returnAs(mat, t)
   479  	case *SymBandDense, *basicSymBanded:
   480  		if m != n {
   481  			panic("bad size")
   482  		}
   483  		var k int
   484  		switch t := t.(type) {
   485  		case *SymBandDense:
   486  			k = t.mat.K
   487  		case *basicSymBanded:
   488  			k = (*SymBandDense)(t).mat.K
   489  		}
   490  		k = min(k, m-1) // Special case for small sizes.
   491  		data := make([]float64, m*(k+1))
   492  		for i := range data {
   493  			data[i] = rnd.NormFloat64()
   494  		}
   495  		mat := NewSymBandDense(n, k, data)
   496  		rMatrix = returnAs(mat, t)
   497  	case *TriBandDense, *basicTriBanded:
   498  		if m != n {
   499  			panic("bad size")
   500  		}
   501  		var k int
   502  		var triKind TriKind
   503  		switch t := t.(type) {
   504  		case *TriBandDense:
   505  			k = t.mat.K
   506  			triKind = t.triKind()
   507  		case *basicTriBanded:
   508  			k = (*TriBandDense)(t).mat.K
   509  			triKind = (*TriBandDense)(t).triKind()
   510  		}
   511  		k = min(k, m-1) // Special case for small sizes.
   512  		data := make([]float64, m*(k+1))
   513  		for i := range data {
   514  			data[i] = rnd.NormFloat64()
   515  		}
   516  		mat := NewTriBandDense(n, k, triKind, data)
   517  		rMatrix = returnAs(mat, t)
   518  	case *DiagDense, *basicDiagonal:
   519  		if m != n {
   520  			panic("bad size")
   521  		}
   522  		var inc int
   523  		switch t := t.(type) {
   524  		case *DiagDense:
   525  			inc = t.mat.Inc
   526  		case *basicDiagonal:
   527  			inc = (*DiagDense)(t).mat.Inc
   528  		}
   529  		if inc == 0 {
   530  			inc = 1
   531  		}
   532  		mat := &DiagDense{
   533  			mat: blas64.Vector{
   534  				N:    n,
   535  				Inc:  inc,
   536  				Data: make([]float64, inc*(n-1)+1),
   537  			},
   538  		}
   539  		for i := 0; i < n; i++ {
   540  			mat.SetDiag(i, rnd.Float64())
   541  		}
   542  		rMatrix = returnAs(mat, t)
   543  	case *Tridiag:
   544  		if m != n {
   545  			panic("bad size")
   546  		}
   547  		mat := NewTridiag(n, nil, nil, nil)
   548  		for i := 0; i < n; i++ {
   549  			for j := max(0, i-1); j <= min(i+1, n-1); j++ {
   550  				mat.SetBand(i, j, rnd.NormFloat64())
   551  			}
   552  		}
   553  		rMatrix = returnAs(mat, t)
   554  	}
   555  	if mr, mc := rMatrix.Dims(); mr != m || mc != n {
   556  		panic(fmt.Sprintf("makeRandOf for %T returns wrong size: %d×%d != %d×%d", a, m, n, mr, mc))
   557  	}
   558  	return rMatrix
   559  }
   560  
   561  // makeNaNOf returns a new m×n matrix of the underlying matrix type filled with NaN values.
   562  func makeNaNOf(a Matrix, m, n int) Matrix {
   563  	var rMatrix Matrix
   564  	switch t := a.(type) {
   565  	default:
   566  		panic("unknown type for makeNaNOf")
   567  	case Untransposer:
   568  		rMatrix = retranspose(a, makeNaNOf(t.Untranspose(), n, m))
   569  	case *Dense, *basicMatrix:
   570  		var mat = &Dense{}
   571  		if m != 0 && n != 0 {
   572  			mat = NewDense(m, n, nil)
   573  		}
   574  		for i := 0; i < m; i++ {
   575  			for j := 0; j < n; j++ {
   576  				mat.Set(i, j, math.NaN())
   577  			}
   578  		}
   579  		rMatrix = returnAs(mat, t)
   580  	case *VecDense:
   581  		if m == 0 && n == 0 {
   582  			return &VecDense{}
   583  		}
   584  		if n != 1 {
   585  			panic(fmt.Sprintf("bad vector size: m = %v, n = %v", m, n))
   586  		}
   587  		length := m
   588  		inc := 1
   589  		if t.mat.Inc != 0 {
   590  			inc = t.mat.Inc
   591  		}
   592  		mat := &VecDense{
   593  			mat: blas64.Vector{
   594  				N:    length,
   595  				Inc:  inc,
   596  				Data: make([]float64, inc*(length-1)+1),
   597  			},
   598  		}
   599  		for i := 0; i < length; i++ {
   600  			mat.SetVec(i, math.NaN())
   601  		}
   602  		return mat
   603  	case *basicVector:
   604  		if n != 1 {
   605  			panic(fmt.Sprintf("bad vector size: m = %v, n = %v", m, n))
   606  		}
   607  		if m == 0 {
   608  			return &basicVector{}
   609  		}
   610  		mat := NewVecDense(m, nil)
   611  		for i := 0; i < m; i++ {
   612  			mat.SetVec(i, math.NaN())
   613  		}
   614  		return asBasicVector(mat)
   615  	case *SymDense, *basicSymmetric:
   616  		if m != n {
   617  			panic("bad size")
   618  		}
   619  		mat := &SymDense{}
   620  		if n != 0 {
   621  			mat = NewSymDense(n, nil)
   622  		}
   623  		for i := 0; i < m; i++ {
   624  			for j := i; j < n; j++ {
   625  				mat.SetSym(i, j, math.NaN())
   626  			}
   627  		}
   628  		rMatrix = returnAs(mat, t)
   629  	case *TriDense, *basicTriangular:
   630  		if m != n {
   631  			panic("bad size")
   632  		}
   633  
   634  		// This is necessary because we are making
   635  		// a triangle from the zero value, which
   636  		// always returns upper as true.
   637  		var triKind TriKind
   638  		switch t := t.(type) {
   639  		case *TriDense:
   640  			triKind = t.triKind()
   641  		case *basicTriangular:
   642  			triKind = (*TriDense)(t).triKind()
   643  		}
   644  
   645  		if n == 0 {
   646  			uplo := blas.Upper
   647  			if triKind == Lower {
   648  				uplo = blas.Lower
   649  			}
   650  			return returnAs(&TriDense{mat: blas64.Triangular{Uplo: uplo}}, t)
   651  		}
   652  
   653  		mat := NewTriDense(n, triKind, nil)
   654  		if triKind == Upper {
   655  			for i := 0; i < m; i++ {
   656  				for j := i; j < n; j++ {
   657  					mat.SetTri(i, j, math.NaN())
   658  				}
   659  			}
   660  		} else {
   661  			for i := 0; i < m; i++ {
   662  				for j := 0; j <= i; j++ {
   663  					mat.SetTri(i, j, math.NaN())
   664  				}
   665  			}
   666  		}
   667  		rMatrix = returnAs(mat, t)
   668  	case *BandDense, *basicBanded:
   669  		var kl, ku int
   670  		switch t := t.(type) {
   671  		case *BandDense:
   672  			kl = t.mat.KL
   673  			ku = t.mat.KU
   674  		case *basicBanded:
   675  			ku = (*BandDense)(t).mat.KU
   676  			kl = (*BandDense)(t).mat.KL
   677  		}
   678  		ku = min(ku, n-1)
   679  		kl = min(kl, m-1)
   680  		data := make([]float64, min(m, n+kl)*(kl+ku+1))
   681  		for i := range data {
   682  			data[i] = math.NaN()
   683  		}
   684  		mat := NewBandDense(m, n, kl, ku, data)
   685  		rMatrix = returnAs(mat, t)
   686  	case *SymBandDense, *basicSymBanded:
   687  		if m != n {
   688  			panic("bad size")
   689  		}
   690  		var k int
   691  		switch t := t.(type) {
   692  		case *SymBandDense:
   693  			k = t.mat.K
   694  		case *basicSymBanded:
   695  			k = (*SymBandDense)(t).mat.K
   696  		}
   697  		k = min(k, m-1) // Special case for small sizes.
   698  		data := make([]float64, m*(k+1))
   699  		for i := range data {
   700  			data[i] = math.NaN()
   701  		}
   702  		mat := NewSymBandDense(n, k, data)
   703  		rMatrix = returnAs(mat, t)
   704  	case *TriBandDense, *basicTriBanded:
   705  		if m != n {
   706  			panic("bad size")
   707  		}
   708  		var k int
   709  		var triKind TriKind
   710  		switch t := t.(type) {
   711  		case *TriBandDense:
   712  			k = t.mat.K
   713  			triKind = t.triKind()
   714  		case *basicTriBanded:
   715  			k = (*TriBandDense)(t).mat.K
   716  			triKind = (*TriBandDense)(t).triKind()
   717  		}
   718  		k = min(k, m-1) // Special case for small sizes.
   719  		data := make([]float64, m*(k+1))
   720  		for i := range data {
   721  			data[i] = math.NaN()
   722  		}
   723  		mat := NewTriBandDense(n, k, triKind, data)
   724  		rMatrix = returnAs(mat, t)
   725  	case *DiagDense, *basicDiagonal:
   726  		if m != n {
   727  			panic("bad size")
   728  		}
   729  		var inc int
   730  		switch t := t.(type) {
   731  		case *DiagDense:
   732  			inc = t.mat.Inc
   733  		case *basicDiagonal:
   734  			inc = (*DiagDense)(t).mat.Inc
   735  		}
   736  		if inc == 0 {
   737  			inc = 1
   738  		}
   739  		mat := &DiagDense{
   740  			mat: blas64.Vector{
   741  				N:    n,
   742  				Inc:  inc,
   743  				Data: make([]float64, inc*(n-1)+1),
   744  			},
   745  		}
   746  		for i := 0; i < n; i++ {
   747  			mat.SetDiag(i, math.NaN())
   748  		}
   749  		rMatrix = returnAs(mat, t)
   750  	}
   751  	if mr, mc := rMatrix.Dims(); mr != m || mc != n {
   752  		panic(fmt.Sprintf("makeNaNOf for %T returns wrong size: %d×%d != %d×%d", a, m, n, mr, mc))
   753  	}
   754  	return rMatrix
   755  }
   756  
   757  // makeCopyOf returns a copy of the matrix.
   758  func makeCopyOf(a Matrix) Matrix {
   759  	switch t := a.(type) {
   760  	default:
   761  		panic("unknown type in makeCopyOf")
   762  	case Untransposer:
   763  		return retranspose(a, makeCopyOf(t.Untranspose()))
   764  	case *Dense, *basicMatrix:
   765  		var m Dense
   766  		m.CloneFrom(a)
   767  		return returnAs(&m, t)
   768  	case *SymDense, *basicSymmetric:
   769  		n := t.(Symmetric).SymmetricDim()
   770  		m := NewSymDense(n, nil)
   771  		m.CopySym(t.(Symmetric))
   772  		return returnAs(m, t)
   773  	case *TriDense, *basicTriangular:
   774  		n, upper := t.(Triangular).Triangle()
   775  		m := NewTriDense(n, upper, nil)
   776  		if upper {
   777  			for i := 0; i < n; i++ {
   778  				for j := i; j < n; j++ {
   779  					m.SetTri(i, j, t.At(i, j))
   780  				}
   781  			}
   782  		} else {
   783  			for i := 0; i < n; i++ {
   784  				for j := 0; j <= i; j++ {
   785  					m.SetTri(i, j, t.At(i, j))
   786  				}
   787  			}
   788  		}
   789  		return returnAs(m, t)
   790  	case *BandDense, *basicBanded:
   791  		var band *BandDense
   792  		switch s := t.(type) {
   793  		case *BandDense:
   794  			band = s
   795  		case *basicBanded:
   796  			band = (*BandDense)(s)
   797  		}
   798  		m := &BandDense{
   799  			mat: blas64.Band{
   800  				Rows:   band.mat.Rows,
   801  				Cols:   band.mat.Cols,
   802  				KL:     band.mat.KL,
   803  				KU:     band.mat.KU,
   804  				Data:   make([]float64, len(band.mat.Data)),
   805  				Stride: band.mat.Stride,
   806  			},
   807  		}
   808  		copy(m.mat.Data, band.mat.Data)
   809  		return returnAs(m, t)
   810  	case *SymBandDense, *basicSymBanded:
   811  		var sym *SymBandDense
   812  		switch s := t.(type) {
   813  		case *SymBandDense:
   814  			sym = s
   815  		case *basicSymBanded:
   816  			sym = (*SymBandDense)(s)
   817  		}
   818  		m := &SymBandDense{
   819  			mat: blas64.SymmetricBand{
   820  				Uplo:   blas.Upper,
   821  				N:      sym.mat.N,
   822  				K:      sym.mat.K,
   823  				Data:   make([]float64, len(sym.mat.Data)),
   824  				Stride: sym.mat.Stride,
   825  			},
   826  		}
   827  		copy(m.mat.Data, sym.mat.Data)
   828  		return returnAs(m, t)
   829  	case *TriBandDense, *basicTriBanded:
   830  		var tri *TriBandDense
   831  		switch s := t.(type) {
   832  		case *TriBandDense:
   833  			tri = s
   834  		case *basicTriBanded:
   835  			tri = (*TriBandDense)(s)
   836  		}
   837  		m := &TriBandDense{
   838  			mat: blas64.TriangularBand{
   839  				Uplo:   tri.mat.Uplo,
   840  				Diag:   tri.mat.Diag,
   841  				N:      tri.mat.N,
   842  				K:      tri.mat.K,
   843  				Data:   make([]float64, len(tri.mat.Data)),
   844  				Stride: tri.mat.Stride,
   845  			},
   846  		}
   847  		copy(m.mat.Data, tri.mat.Data)
   848  		return returnAs(m, t)
   849  	case *VecDense:
   850  		var m VecDense
   851  		m.CloneFromVec(t)
   852  		return &m
   853  	case *basicVector:
   854  		var m VecDense
   855  		m.CloneFromVec(t)
   856  		return asBasicVector(&m)
   857  	case *DiagDense, *basicDiagonal:
   858  		var diag *DiagDense
   859  		switch s := t.(type) {
   860  		case *DiagDense:
   861  			diag = s
   862  		case *basicDiagonal:
   863  			diag = (*DiagDense)(s)
   864  		}
   865  		d := &DiagDense{
   866  			mat: blas64.Vector{N: diag.mat.N, Inc: diag.mat.Inc, Data: make([]float64, len(diag.mat.Data))},
   867  		}
   868  		copy(d.mat.Data, diag.mat.Data)
   869  		return returnAs(d, t)
   870  	case *Tridiag:
   871  		var m Tridiag
   872  		m.CloneFromTridiag(a.(*Tridiag))
   873  		return returnAs(&m, t)
   874  	}
   875  }
   876  
   877  // sameType returns true if a and b have the same underlying type.
   878  func sameType(a, b Matrix) bool {
   879  	return reflect.ValueOf(a).Type() == reflect.ValueOf(b).Type()
   880  }
   881  
   882  // maybeSame returns true if the two matrices could be represented by the same
   883  // pointer.
   884  func maybeSame(receiver, a Matrix) bool {
   885  	rr, rc := receiver.Dims()
   886  	u, trans := a.(Untransposer)
   887  	if trans {
   888  		a = u.Untranspose()
   889  	}
   890  	if !sameType(receiver, a) {
   891  		return false
   892  	}
   893  	ar, ac := a.Dims()
   894  	if rr != ar || rc != ac {
   895  		return false
   896  	}
   897  	if _, ok := a.(Triangular); ok {
   898  		// They are both triangular types. The TriType needs to match
   899  		_, aKind := a.(Triangular).Triangle()
   900  		_, rKind := receiver.(Triangular).Triangle()
   901  		if aKind != rKind {
   902  			return false
   903  		}
   904  	}
   905  	return true
   906  }
   907  
   908  // equalApprox returns whether the elements of a and b are the same to within
   909  // the tolerance. If ignoreNaN is true the test is relaxed such that NaN == NaN.
   910  func equalApprox(a, b Matrix, tol float64, ignoreNaN bool) bool {
   911  	ar, ac := a.Dims()
   912  	br, bc := b.Dims()
   913  	if ar != br {
   914  		return false
   915  	}
   916  	if ac != bc {
   917  		return false
   918  	}
   919  	for i := 0; i < ar; i++ {
   920  		for j := 0; j < ac; j++ {
   921  			if !scalar.EqualWithinAbsOrRel(a.At(i, j), b.At(i, j), tol, tol) {
   922  				if ignoreNaN && math.IsNaN(a.At(i, j)) && math.IsNaN(b.At(i, j)) {
   923  					continue
   924  				}
   925  				return false
   926  			}
   927  		}
   928  	}
   929  	return true
   930  }
   931  
   932  // equal returns true if the matrices have equal entries.
   933  func equal(a, b Matrix) bool {
   934  	ar, ac := a.Dims()
   935  	br, bc := b.Dims()
   936  	if ar != br {
   937  		return false
   938  	}
   939  	if ac != bc {
   940  		return false
   941  	}
   942  	for i := 0; i < ar; i++ {
   943  		for j := 0; j < ac; j++ {
   944  			if a.At(i, j) != b.At(i, j) {
   945  				return false
   946  			}
   947  		}
   948  	}
   949  	return true
   950  }
   951  
   952  // isDiagonal returns whether a is a diagonal matrix.
   953  func isDiagonal(a Matrix) bool {
   954  	r, c := a.Dims()
   955  	for i := 0; i < r; i++ {
   956  		for j := 0; j < c; j++ {
   957  			if a.At(i, j) != 0 && i != j {
   958  				return false
   959  			}
   960  		}
   961  	}
   962  	return true
   963  }
   964  
   965  // equalDiagonal returns whether a and b are equal on the diagonal.
   966  func equalDiagonal(a, b Matrix) bool {
   967  	ar, ac := a.Dims()
   968  	br, bc := a.Dims()
   969  	if min(ar, ac) != min(br, bc) {
   970  		return false
   971  	}
   972  	for i := 0; i < min(ar, ac); i++ {
   973  		if a.At(i, i) != b.At(i, i) {
   974  			return false
   975  		}
   976  	}
   977  	return true
   978  }
   979  
   980  // underlyingData extracts the underlying data of the matrix a.
   981  func underlyingData(a Matrix) []float64 {
   982  	switch t := a.(type) {
   983  	default:
   984  		panic("matrix type not implemented for extracting underlying data")
   985  	case Untransposer:
   986  		return underlyingData(t.Untranspose())
   987  	case *Dense:
   988  		return t.mat.Data
   989  	case *SymDense:
   990  		return t.mat.Data
   991  	case *TriDense:
   992  		return t.mat.Data
   993  	case *VecDense:
   994  		return t.mat.Data
   995  	}
   996  }
   997  
// testMatrices is a list of matrix types to test.
//
// Each entry is a template. The test helpers pass it to makeRandOf to build
// random matrices of the same concrete type (including any Transpose,
// TransposeTri, TransposeBand or TransposeTriBand wrapper) at the requested
// size. The fields set here (band widths, vector increment, triangle kind) are
// presumably read from the template by makeRandOf — confirm there.
//
// The helpers iterate this list in order while drawing from a single shared
// random source. Reordering or inserting entries therefore changes the random
// values seen by every later test case.
//
// The basic* entries are presumably minimal implementations of the matrix
// interfaces that avoid type-specific fast paths (TODO confirm against their
// definitions).
//
// This test relies on the fact that the implementations of Triangle do not
// corrupt the value of Uplo when they are empty. This test will fail
// if that changes (and some mechanism will need to be used to force the
// correct TriKind to be read).
var testMatrices = []Matrix{
	// General matrices, plain and transposed.
	&Dense{},
	&basicMatrix{},
	Transpose{&Dense{}},

	// Vectors with a unit and a non-unit (10) increment, plain and transposed.
	&VecDense{mat: blas64.Vector{Inc: 1}},
	&VecDense{mat: blas64.Vector{Inc: 10}},
	&basicVector{},
	Transpose{&VecDense{mat: blas64.Vector{Inc: 1}}},
	Transpose{&VecDense{mat: blas64.Vector{Inc: 10}}},
	Transpose{&basicVector{}},

	// Banded matrices with unequal lower (KL) and upper (KU) bandwidths.
	&BandDense{mat: blas64.Band{KL: 2, KU: 1}},
	&BandDense{mat: blas64.Band{KL: 1, KU: 2}},
	Transpose{&BandDense{mat: blas64.Band{KL: 2, KU: 1}}},
	Transpose{&BandDense{mat: blas64.Band{KL: 1, KU: 2}}},
	TransposeBand{&BandDense{mat: blas64.Band{KL: 2, KU: 1}}},
	TransposeBand{&BandDense{mat: blas64.Band{KL: 1, KU: 2}}},

	// Symmetric matrices.
	&SymDense{},
	&basicSymmetric{},
	Transpose{&basicSymmetric{}},

	// Triangular matrices of both kinds under Transpose and TransposeTri.
	&TriDense{mat: blas64.Triangular{Uplo: blas.Upper}},
	&TriDense{mat: blas64.Triangular{Uplo: blas.Lower}},
	&basicTriangular{mat: blas64.Triangular{Uplo: blas.Upper}},
	&basicTriangular{mat: blas64.Triangular{Uplo: blas.Lower}},
	Transpose{&TriDense{mat: blas64.Triangular{Uplo: blas.Upper}}},
	Transpose{&TriDense{mat: blas64.Triangular{Uplo: blas.Lower}}},
	TransposeTri{&TriDense{mat: blas64.Triangular{Uplo: blas.Upper}}},
	TransposeTri{&TriDense{mat: blas64.Triangular{Uplo: blas.Lower}}},
	Transpose{&basicTriangular{mat: blas64.Triangular{Uplo: blas.Upper}}},
	Transpose{&basicTriangular{mat: blas64.Triangular{Uplo: blas.Lower}}},
	TransposeTri{&basicTriangular{mat: blas64.Triangular{Uplo: blas.Upper}}},
	TransposeTri{&basicTriangular{mat: blas64.Triangular{Uplo: blas.Lower}}},

	// Symmetric banded matrices with the zero-value bandwidth.
	&SymBandDense{},
	&basicSymBanded{},
	Transpose{&basicSymBanded{}},

	// Symmetric banded matrices with bandwidth K = 2.
	&SymBandDense{mat: blas64.SymmetricBand{K: 2}},
	&basicSymBanded{mat: blas64.SymmetricBand{K: 2}},
	Transpose{&basicSymBanded{mat: blas64.SymmetricBand{K: 2}}},
	TransposeBand{&basicSymBanded{mat: blas64.SymmetricBand{K: 2}}},

	// Triangular banded matrices (K = 2) of both kinds under every
	// transpose wrapper.
	&TriBandDense{mat: blas64.TriangularBand{K: 2, Uplo: blas.Upper}},
	&TriBandDense{mat: blas64.TriangularBand{K: 2, Uplo: blas.Lower}},
	&basicTriBanded{mat: blas64.TriangularBand{K: 2, Uplo: blas.Upper}},
	&basicTriBanded{mat: blas64.TriangularBand{K: 2, Uplo: blas.Lower}},
	Transpose{&TriBandDense{mat: blas64.TriangularBand{K: 2, Uplo: blas.Upper}}},
	Transpose{&TriBandDense{mat: blas64.TriangularBand{K: 2, Uplo: blas.Lower}}},
	Transpose{&basicTriBanded{mat: blas64.TriangularBand{K: 2, Uplo: blas.Upper}}},
	Transpose{&basicTriBanded{mat: blas64.TriangularBand{K: 2, Uplo: blas.Lower}}},
	TransposeTri{&TriBandDense{mat: blas64.TriangularBand{K: 2, Uplo: blas.Upper}}},
	TransposeTri{&TriBandDense{mat: blas64.TriangularBand{K: 2, Uplo: blas.Lower}}},
	TransposeTri{&basicTriBanded{mat: blas64.TriangularBand{K: 2, Uplo: blas.Upper}}},
	TransposeTri{&basicTriBanded{mat: blas64.TriangularBand{K: 2, Uplo: blas.Lower}}},
	TransposeBand{&TriBandDense{mat: blas64.TriangularBand{K: 2, Uplo: blas.Upper}}},
	TransposeBand{&TriBandDense{mat: blas64.TriangularBand{K: 2, Uplo: blas.Lower}}},
	TransposeBand{&basicTriBanded{mat: blas64.TriangularBand{K: 2, Uplo: blas.Upper}}},
	TransposeBand{&basicTriBanded{mat: blas64.TriangularBand{K: 2, Uplo: blas.Lower}}},
	TransposeTriBand{&TriBandDense{mat: blas64.TriangularBand{K: 2, Uplo: blas.Upper}}},
	TransposeTriBand{&TriBandDense{mat: blas64.TriangularBand{K: 2, Uplo: blas.Lower}}},
	TransposeTriBand{&basicTriBanded{mat: blas64.TriangularBand{K: 2, Uplo: blas.Upper}}},
	TransposeTriBand{&basicTriBanded{mat: blas64.TriangularBand{K: 2, Uplo: blas.Lower}}},

	// Diagonal matrices with the zero-value and a non-unit (10) increment,
	// under every transpose wrapper.
	&DiagDense{},
	&DiagDense{mat: blas64.Vector{Inc: 10}},
	Transpose{&DiagDense{}},
	Transpose{&DiagDense{mat: blas64.Vector{Inc: 10}}},
	TransposeTri{&DiagDense{}},
	TransposeTri{&DiagDense{mat: blas64.Vector{Inc: 10}}},
	TransposeBand{&DiagDense{}},
	TransposeBand{&DiagDense{mat: blas64.Vector{Inc: 10}}},
	TransposeTriBand{&DiagDense{}},
	TransposeTriBand{&DiagDense{mat: blas64.Vector{Inc: 10}}},
	&basicDiagonal{},
	Transpose{&basicDiagonal{}},
	TransposeTri{&basicDiagonal{}},
	TransposeBand{&basicDiagonal{}},
	TransposeTriBand{&basicDiagonal{}},

	// Tridiagonal matrices.
	&Tridiag{},
	Transpose{&Tridiag{}},
	TransposeBand{&Tridiag{}},
}
  1089  
  1090  var sizes = []struct {
  1091  	ar, ac int
  1092  }{
  1093  	{1, 1},
  1094  	{1, 3},
  1095  	{3, 1},
  1096  
  1097  	{6, 6},
  1098  	{6, 11},
  1099  	{11, 6},
  1100  }
  1101  
  1102  func testOneInputFunc(t *testing.T,
  1103  	// name is the name of the function being tested.
  1104  	name string,
  1105  
  1106  	// f is the function being tested.
  1107  	f func(a Matrix) interface{},
  1108  
  1109  	// denseComparison performs the same operation, but using Dense matrices for
  1110  	// comparison.
  1111  	denseComparison func(a *Dense) interface{},
  1112  
  1113  	// sameAnswer compares the result from two different evaluations of the function
  1114  	// and returns true if they are the same. The specific function being tested
  1115  	// determines the definition of "same". It may mean identical or it may mean
  1116  	// approximately equal.
  1117  	sameAnswer func(a, b interface{}) bool,
  1118  
  1119  	// legalType returns true if the type of the input is a legal type for the
  1120  	// input of the function.
  1121  	legalType func(a Matrix) bool,
  1122  
  1123  	// legalSize returns true if the size is valid for the function.
  1124  	legalSize func(r, c int) bool,
  1125  ) {
  1126  	src := rand.NewSource(1)
  1127  	for _, aMat := range testMatrices {
  1128  		for _, test := range sizes {
  1129  			// Skip the test if the argument would not be assignable to the
  1130  			// method's corresponding input parameter or it is not possible
  1131  			// to construct an argument of the requested size.
  1132  			if !legalType(aMat) {
  1133  				continue
  1134  			}
  1135  			if !legalDims(aMat, test.ar, test.ac) {
  1136  				continue
  1137  			}
  1138  			a := makeRandOf(aMat, test.ar, test.ac, src)
  1139  
  1140  			// Compute the true answer if the sizes are legal.
  1141  			dimsOK := legalSize(test.ar, test.ac)
  1142  			var want interface{}
  1143  			if dimsOK {
  1144  				var aDense Dense
  1145  				aDense.CloneFrom(a)
  1146  				want = denseComparison(&aDense)
  1147  			}
  1148  			aCopy := makeCopyOf(a)
  1149  			// Test the method for a zero-value of the receiver.
  1150  			aType, aTrans := untranspose(a)
  1151  			errStr := fmt.Sprintf("%v(%T), size: %#v, atrans %t", name, aType, test, aTrans)
  1152  			var got interface{}
  1153  			panicked, err := panics(func() { got = f(a) })
  1154  			if !dimsOK && !panicked {
  1155  				t.Errorf("Did not panic with illegal size: %s", errStr)
  1156  				continue
  1157  			}
  1158  			if dimsOK && panicked {
  1159  				t.Errorf("Panicked with legal size: %s: %v", errStr, err)
  1160  				continue
  1161  			}
  1162  			if !equal(a, aCopy) {
  1163  				t.Errorf("First input argument changed in call: %s", errStr)
  1164  			}
  1165  			if !dimsOK {
  1166  				continue
  1167  			}
  1168  			if !sameAnswer(want, got) {
  1169  				t.Errorf("Answer mismatch: %s; got %v, want %v", errStr, got, want)
  1170  			}
  1171  		}
  1172  	}
  1173  }
  1174  
// sizePairs lists the (ar, ac) and (br, bc) shape combinations exercised by the
// two-argument test helpers. The groups below cover:
//   - matching square shapes;
//   - pairs where exactly one dimension differs from the rest;
//   - pairs with two distinct values;
//   - pairs with three distinct values.
//
// Each group arranges its values in every position, so both legal and illegal
// size combinations are generated for any size rule. The entries are visited
// in order while a shared random source is consumed, so their order is
// significant.
var sizePairs = []struct {
	ar, ac, br, bc int
}{
	// All four dimensions equal.
	{1, 1, 1, 1},
	{6, 6, 6, 6},
	{7, 7, 7, 7},

	// A single dimension of 5; the rest are 1.
	{1, 1, 1, 5},
	{1, 1, 5, 1},
	{1, 5, 1, 1},
	{5, 1, 1, 1},

	// A single dimension of 1; the rest are 5.
	{5, 5, 5, 1},
	{5, 5, 1, 5},
	{5, 1, 5, 5},
	{1, 5, 5, 5},

	// A single dimension differs (6 among 11s, or 11 among 6s).
	{6, 6, 6, 11},
	{6, 6, 11, 6},
	{6, 11, 6, 6},
	{11, 6, 6, 6},
	{11, 11, 11, 6},
	{11, 11, 6, 11},
	{11, 6, 11, 11},
	{6, 11, 11, 11},

	// Two dimensions of one value and two of another, in every arrangement.
	{1, 1, 5, 5},
	{1, 5, 1, 5},
	{1, 5, 5, 1},
	{5, 1, 1, 5},
	{5, 1, 5, 1},
	{5, 5, 1, 1},
	{6, 6, 11, 11},
	{6, 11, 6, 11},
	{6, 11, 11, 6},
	{11, 6, 6, 11},
	{11, 6, 11, 6},
	{11, 11, 6, 6},

	// Two 1s plus 11 and 17, in every arrangement.
	{1, 1, 17, 11},
	{1, 1, 11, 17},
	{1, 11, 1, 17},
	{1, 17, 1, 11},
	{1, 11, 17, 1},
	{1, 17, 11, 1},
	{11, 1, 1, 17},
	{17, 1, 1, 11},
	{11, 1, 17, 1},
	{17, 1, 11, 1},
	{11, 17, 1, 1},
	{17, 11, 1, 1},

	// Two 6s plus 1 and 11, in every arrangement.
	{6, 6, 1, 11},
	{6, 6, 11, 1},
	{6, 11, 6, 1},
	{6, 1, 6, 11},
	{6, 11, 1, 6},
	{6, 1, 11, 6},
	{11, 6, 6, 1},
	{1, 6, 6, 11},
	{11, 6, 1, 6},
	{1, 6, 11, 6},
	{11, 1, 6, 6},
	{1, 11, 6, 6},

	// Two 6s plus 17 and 1, in every arrangement.
	{6, 6, 17, 1},
	{6, 6, 1, 17},
	{6, 1, 6, 17},
	{6, 17, 6, 1},
	{6, 1, 17, 6},
	{6, 17, 1, 6},
	{1, 6, 6, 17},
	{17, 6, 6, 1},
	{1, 6, 17, 6},
	{17, 6, 1, 6},
	{1, 17, 6, 6},
	{17, 1, 6, 6},

	// Two 6s plus 17 and 11, in every arrangement.
	{6, 6, 17, 11},
	{6, 6, 11, 17},
	{6, 11, 6, 17},
	{6, 17, 6, 11},
	{6, 11, 17, 6},
	{6, 17, 11, 6},
	{11, 6, 6, 17},
	{17, 6, 6, 11},
	{11, 6, 17, 6},
	{17, 6, 11, 6},
	{11, 17, 6, 6},
	{17, 11, 6, 6},
}
  1266  
  1267  func testTwoInputFunc(t *testing.T,
  1268  	// name is the name of the function being tested.
  1269  	name string,
  1270  
  1271  	// f is the function being tested.
  1272  	f func(a, b Matrix) interface{},
  1273  
  1274  	// denseComparison performs the same operation, but using Dense matrices for
  1275  	// comparison.
  1276  	denseComparison func(a, b *Dense) interface{},
  1277  
  1278  	// sameAnswer compares the result from two different evaluations of the function
  1279  	// and returns true if they are the same. The specific function being tested
  1280  	// determines the definition of "same". It may mean identical or it may mean
  1281  	// approximately equal.
  1282  	sameAnswer func(a, b interface{}) bool,
  1283  
  1284  	// legalType returns true if the types of the inputs are legal for the
  1285  	// input of the function.
  1286  	legalType func(a, b Matrix) bool,
  1287  
  1288  	// legalSize returns true if the sizes are valid for the function.
  1289  	legalSize func(ar, ac, br, bc int) bool,
  1290  ) {
  1291  	src := rand.NewSource(1)
  1292  	for _, aMat := range testMatrices {
  1293  		for _, bMat := range testMatrices {
  1294  			// Loop over all of the size combinations (bigger, smaller, etc.).
  1295  			for _, test := range sizePairs {
  1296  				// Skip the test if the argument would not be assignable to the
  1297  				// method's corresponding input parameter or it is not possible
  1298  				// to construct an argument of the requested size.
  1299  				if !legalType(aMat, bMat) {
  1300  					continue
  1301  				}
  1302  				if !legalDims(aMat, test.ar, test.ac) {
  1303  					continue
  1304  				}
  1305  				if !legalDims(bMat, test.br, test.bc) {
  1306  					continue
  1307  				}
  1308  				a := makeRandOf(aMat, test.ar, test.ac, src)
  1309  				b := makeRandOf(bMat, test.br, test.bc, src)
  1310  
  1311  				// Compute the true answer if the sizes are legal.
  1312  				dimsOK := legalSize(test.ar, test.ac, test.br, test.bc)
  1313  				var want interface{}
  1314  				if dimsOK {
  1315  					var aDense, bDense Dense
  1316  					aDense.CloneFrom(a)
  1317  					bDense.CloneFrom(b)
  1318  					want = denseComparison(&aDense, &bDense)
  1319  				}
  1320  				aCopy := makeCopyOf(a)
  1321  				bCopy := makeCopyOf(b)
  1322  				// Test the method for a zero-value of the receiver.
  1323  				aType, aTrans := untranspose(a)
  1324  				bType, bTrans := untranspose(b)
  1325  				errStr := fmt.Sprintf("%v(%T, %T), size: %#v, atrans %t, btrans %t", name, aType, bType, test, aTrans, bTrans)
  1326  				var got interface{}
  1327  				panicked, err := panics(func() { got = f(a, b) })
  1328  				if !dimsOK && !panicked {
  1329  					t.Errorf("Did not panic with illegal size: %s", errStr)
  1330  					continue
  1331  				}
  1332  				if dimsOK && panicked {
  1333  					t.Errorf("Panicked with legal size: %s: %v", errStr, err)
  1334  					continue
  1335  				}
  1336  				if !equal(a, aCopy) {
  1337  					t.Errorf("First input argument changed in call: %s", errStr)
  1338  				}
  1339  				if !equal(b, bCopy) {
  1340  					t.Errorf("First input argument changed in call: %s", errStr)
  1341  				}
  1342  				if !dimsOK {
  1343  					continue
  1344  				}
  1345  				if !sameAnswer(want, got) {
  1346  					t.Errorf("Answer mismatch: %s", errStr)
  1347  				}
  1348  			}
  1349  		}
  1350  	}
  1351  }
  1352  
  1353  // testOneInput tests a method that has one matrix input argument
  1354  func testOneInput(t *testing.T,
  1355  	// name is the name of the method being tested.
  1356  	name string,
  1357  
  1358  	// receiver is a value of the receiver type.
  1359  	receiver Matrix,
  1360  
  1361  	// method is the generalized receiver.Method(a).
  1362  	method func(receiver, a Matrix),
  1363  
  1364  	// denseComparison performs the same operation as method, but with dense
  1365  	// matrices for comparison with the result.
  1366  	denseComparison func(receiver, a *Dense),
  1367  
  1368  	// legalTypes returns whether the concrete types in Matrix are valid for
  1369  	// the method.
  1370  	legalType func(a Matrix) bool,
  1371  
  1372  	// legalSize returns whether the matrix sizes are valid for the method.
  1373  	legalSize func(ar, ac int) bool,
  1374  
  1375  	// tol is the tolerance for equality when comparing method results.
  1376  	tol float64,
  1377  ) {
  1378  	src := rand.NewSource(1)
  1379  	for _, aMat := range testMatrices {
  1380  		for _, test := range sizes {
  1381  			// Skip the test if the argument would not be assignable to the
  1382  			// method's corresponding input parameter or it is not possible
  1383  			// to construct an argument of the requested size.
  1384  			if !legalType(aMat) {
  1385  				continue
  1386  			}
  1387  			if !legalDims(aMat, test.ar, test.ac) {
  1388  				continue
  1389  			}
  1390  			a := makeRandOf(aMat, test.ar, test.ac, src)
  1391  
  1392  			// Compute the true answer if the sizes are legal.
  1393  			dimsOK := legalSize(test.ar, test.ac)
  1394  			var want Dense
  1395  			if dimsOK {
  1396  				var aDense Dense
  1397  				aDense.CloneFrom(a)
  1398  				denseComparison(&want, &aDense)
  1399  			}
  1400  			aCopy := makeCopyOf(a)
  1401  
  1402  			// Test the method for a zero-value of the receiver.
  1403  			aType, aTrans := untranspose(a)
  1404  			errStr := fmt.Sprintf("%T.%s(%T), size: %#v, atrans %v", receiver, name, aType, test, aTrans)
  1405  			empty := makeRandOf(receiver, 0, 0, src)
  1406  			panicked, err := panics(func() { method(empty, a) })
  1407  			if !dimsOK && !panicked {
  1408  				t.Errorf("Did not panic with illegal size: %s", errStr)
  1409  				continue
  1410  			}
  1411  			if dimsOK && panicked {
  1412  				t.Errorf("Panicked with legal size: %s: %v", errStr, err)
  1413  				continue
  1414  			}
  1415  			if !equal(a, aCopy) {
  1416  				t.Errorf("First input argument changed in call: %s", errStr)
  1417  			}
  1418  			if !dimsOK {
  1419  				continue
  1420  			}
  1421  			if !equalApprox(empty, &want, tol, false) {
  1422  				t.Errorf("Answer mismatch with empty receiver: %s.\nGot:\n% v\nWant:\n% v\n", errStr, Formatted(empty), Formatted(&want))
  1423  				continue
  1424  			}
  1425  
  1426  			// Test the method with a non-empty-value of the receiver.
  1427  			// The receiver has been overwritten in place so use its size
  1428  			// to construct a new random matrix.
  1429  			rr, rc := empty.Dims()
  1430  			neverEmpty := makeRandOf(receiver, rr, rc, src)
  1431  			panicked, message := panics(func() { method(neverEmpty, a) })
  1432  			if panicked {
  1433  				t.Errorf("Panicked with non-empty receiver: %s: %s", errStr, message)
  1434  			}
  1435  			if !equalApprox(neverEmpty, &want, tol, false) {
  1436  				t.Errorf("Answer mismatch non-empty receiver: %s", errStr)
  1437  			}
  1438  
  1439  			// Test the method with a NaN-filled-value of the receiver.
  1440  			// The receiver has been overwritten in place so use its size
  1441  			// to construct a new NaN matrix.
  1442  			nanMatrix := makeNaNOf(receiver, rr, rc)
  1443  			panicked, message = panics(func() { method(nanMatrix, a) })
  1444  			if panicked {
  1445  				t.Errorf("Panicked with NaN-filled receiver: %s: %s", errStr, message)
  1446  			}
  1447  			if !equalApprox(nanMatrix, &want, tol, false) {
  1448  				t.Errorf("Answer mismatch NaN-filled receiver: %s", errStr)
  1449  			}
  1450  
  1451  			// Test with an incorrectly sized matrix.
  1452  			switch receiver.(type) {
  1453  			default:
  1454  				panic("matrix type not coded for incorrect receiver size")
  1455  			case *Dense:
  1456  				wrongSize := makeRandOf(receiver, rr+1, rc, src)
  1457  				panicked, _ = panics(func() { method(wrongSize, a) })
  1458  				if !panicked {
  1459  					t.Errorf("Did not panic with wrong number of rows: %s", errStr)
  1460  				}
  1461  				wrongSize = makeRandOf(receiver, rr, rc+1, src)
  1462  				panicked, _ = panics(func() { method(wrongSize, a) })
  1463  				if !panicked {
  1464  					t.Errorf("Did not panic with wrong number of columns: %s", errStr)
  1465  				}
  1466  			case *TriDense, *SymDense:
  1467  				// Add to the square size.
  1468  				wrongSize := makeRandOf(receiver, rr+1, rc+1, src)
  1469  				panicked, _ = panics(func() { method(wrongSize, a) })
  1470  				if !panicked {
  1471  					t.Errorf("Did not panic with wrong size: %s", errStr)
  1472  				}
  1473  			case *VecDense:
  1474  				// Add to the column length.
  1475  				wrongSize := makeRandOf(receiver, rr+1, rc, src)
  1476  				panicked, _ = panics(func() { method(wrongSize, a) })
  1477  				if !panicked {
  1478  					t.Errorf("Did not panic with wrong number of rows: %s", errStr)
  1479  				}
  1480  			}
  1481  
  1482  			// The receiver and the input may share a matrix pointer
  1483  			// if the type and size of the receiver and one of the
  1484  			// arguments match. Test the method works properly
  1485  			// when this is the case.
  1486  			aMaybeSame := maybeSame(neverEmpty, a)
  1487  			if aMaybeSame {
  1488  				aSame := makeCopyOf(a)
  1489  				receiver = aSame
  1490  				u, ok := aSame.(Untransposer)
  1491  				if ok {
  1492  					receiver = u.Untranspose()
  1493  				}
  1494  				preData := underlyingData(receiver)
  1495  				panicked, err = panics(func() { method(receiver, aSame) })
  1496  				if panicked {
  1497  					t.Errorf("Panics when a maybeSame: %s: %v", errStr, err)
  1498  				} else {
  1499  					if !equalApprox(receiver, &want, tol, false) {
  1500  						t.Errorf("Wrong answer when a maybeSame: %s", errStr)
  1501  					}
  1502  					postData := underlyingData(receiver)
  1503  					if !floats.Equal(preData, postData) {
  1504  						t.Errorf("Original data slice not modified when a maybeSame: %s", errStr)
  1505  					}
  1506  				}
  1507  			}
  1508  		}
  1509  	}
  1510  }
  1511  
  1512  // testTwoInput tests a method that has two input arguments.
  1513  func testTwoInput(t *testing.T,
  1514  	// name is the name of the method being tested.
  1515  	name string,
  1516  
  1517  	// receiver is a value of the receiver type.
  1518  	receiver Matrix,
  1519  
  1520  	// method is the generalized receiver.Method(a, b).
  1521  	method func(receiver, a, b Matrix),
  1522  
  1523  	// denseComparison performs the same operation as method, but with dense
  1524  	// matrices for comparison with the result.
  1525  	denseComparison func(receiver, a, b *Dense),
  1526  
  1527  	// legalTypes returns whether the concrete types in Matrix are valid for
  1528  	// the method.
  1529  	legalTypes func(a, b Matrix) bool,
  1530  
  1531  	// legalSize returns whether the matrix sizes are valid for the method.
  1532  	legalSize func(ar, ac, br, bc int) bool,
  1533  
  1534  	// tol is the tolerance for equality when comparing method results.
  1535  	tol float64,
  1536  ) {
  1537  	src := rand.NewSource(1)
  1538  	for _, aMat := range testMatrices {
  1539  		for _, bMat := range testMatrices {
  1540  			// Loop over all of the size combinations (bigger, smaller, etc.).
  1541  			for _, test := range sizePairs {
  1542  				// Skip the test if any argument would not be assignable to the
  1543  				// method's corresponding input parameter or it is not possible
  1544  				// to construct an argument of the requested size.
  1545  				if !legalTypes(aMat, bMat) {
  1546  					continue
  1547  				}
  1548  				if !legalDims(aMat, test.ar, test.ac) {
  1549  					continue
  1550  				}
  1551  				if !legalDims(bMat, test.br, test.bc) {
  1552  					continue
  1553  				}
  1554  				a := makeRandOf(aMat, test.ar, test.ac, src)
  1555  				b := makeRandOf(bMat, test.br, test.bc, src)
  1556  
  1557  				// Compute the true answer if the sizes are legal.
  1558  				dimsOK := legalSize(test.ar, test.ac, test.br, test.bc)
  1559  				var want Dense
  1560  				if dimsOK {
  1561  					var aDense, bDense Dense
  1562  					aDense.CloneFrom(a)
  1563  					bDense.CloneFrom(b)
  1564  					denseComparison(&want, &aDense, &bDense)
  1565  				}
  1566  				aCopy := makeCopyOf(a)
  1567  				bCopy := makeCopyOf(b)
  1568  
  1569  				// Test the method for a empty-value of the receiver.
  1570  				aType, aTrans := untranspose(a)
  1571  				bType, bTrans := untranspose(b)
  1572  				errStr := fmt.Sprintf("%T.%s(%T, %T), sizes: %#v, atrans %v, btrans %v", receiver, name, aType, bType, test, aTrans, bTrans)
  1573  				empty := makeRandOf(receiver, 0, 0, src)
  1574  				panicked, err := panics(func() { method(empty, a, b) })
  1575  				if !dimsOK && !panicked {
  1576  					t.Errorf("Did not panic with illegal size: %s", errStr)
  1577  					continue
  1578  				}
  1579  				if dimsOK && panicked {
  1580  					t.Errorf("Panicked with legal size: %s: %v", errStr, err)
  1581  					continue
  1582  				}
  1583  				if !equal(a, aCopy) {
  1584  					t.Errorf("First input argument changed in call: %s", errStr)
  1585  				}
  1586  				if !equal(b, bCopy) {
  1587  					t.Errorf("Second input argument changed in call: %s", errStr)
  1588  				}
  1589  				if !dimsOK {
  1590  					continue
  1591  				}
  1592  				wasEmpty, empty := empty, nil // Nil-out empty so we detect illegal use.
  1593  				// NaN equality is allowed because of 0/0 in DivElem test.
  1594  				if !equalApprox(wasEmpty, &want, tol, true) {
  1595  					t.Errorf("Answer mismatch with empty receiver: %s", errStr)
  1596  					continue
  1597  				}
  1598  
  1599  				// Test the method with a non-empty-value of the receiver.
  1600  				// The receiver has been overwritten in place so use its size
  1601  				// to construct a new random matrix.
  1602  				rr, rc := wasEmpty.Dims()
  1603  				neverEmpty := makeRandOf(receiver, rr, rc, src)
  1604  				panicked, message := panics(func() { method(neverEmpty, a, b) })
  1605  				if panicked {
  1606  					t.Errorf("Panicked with non-empty receiver: %s: %s", errStr, message)
  1607  				}
  1608  				// NaN equality is allowed because of 0/0 in DivElem test.
  1609  				if !equalApprox(neverEmpty, &want, tol, true) {
  1610  					t.Errorf("Answer mismatch non-empty receiver: %s", errStr)
  1611  				}
  1612  
  1613  				// Test the method with a NaN-filled value of the receiver.
  1614  				// The receiver has been overwritten in place so use its size
  1615  				// to construct a new NaN matrix.
  1616  				nanMatrix := makeNaNOf(receiver, rr, rc)
  1617  				panicked, message = panics(func() { method(nanMatrix, a, b) })
  1618  				if panicked {
  1619  					t.Errorf("Panicked with NaN-filled receiver: %s: %s", errStr, message)
  1620  				}
  1621  				// NaN equality is allowed because of 0/0 in DivElem test.
  1622  				if !equalApprox(nanMatrix, &want, tol, true) {
  1623  					t.Errorf("Answer mismatch NaN-filled receiver: %s", errStr)
  1624  				}
  1625  
  1626  				// Test with an incorrectly sized matrix.
  1627  				switch receiver.(type) {
  1628  				default:
  1629  					panic("matrix type not coded for incorrect receiver size")
  1630  				case *Dense:
  1631  					wrongSize := makeRandOf(receiver, rr+1, rc, src)
  1632  					panicked, _ = panics(func() { method(wrongSize, a, b) })
  1633  					if !panicked {
  1634  						t.Errorf("Did not panic with wrong number of rows: %s", errStr)
  1635  					}
  1636  					wrongSize = makeRandOf(receiver, rr, rc+1, src)
  1637  					panicked, _ = panics(func() { method(wrongSize, a, b) })
  1638  					if !panicked {
  1639  						t.Errorf("Did not panic with wrong number of columns: %s", errStr)
  1640  					}
  1641  				case *TriDense, *SymDense:
  1642  					// Add to the square size.
  1643  					wrongSize := makeRandOf(receiver, rr+1, rc+1, src)
  1644  					panicked, _ = panics(func() { method(wrongSize, a, b) })
  1645  					if !panicked {
  1646  						t.Errorf("Did not panic with wrong size: %s", errStr)
  1647  					}
  1648  				case *VecDense:
  1649  					// Add to the column length.
  1650  					wrongSize := makeRandOf(receiver, rr+1, rc, src)
  1651  					panicked, _ = panics(func() { method(wrongSize, a, b) })
  1652  					if !panicked {
  1653  						t.Errorf("Did not panic with wrong number of rows: %s", errStr)
  1654  					}
  1655  				}
  1656  
  1657  				// The receiver and an input may share a matrix pointer
  1658  				// if the type and size of the receiver and one of the
  1659  				// arguments match. Test the method works properly
  1660  				// when this is the case.
  1661  				aMaybeSame := maybeSame(neverEmpty, a)
  1662  				bMaybeSame := maybeSame(neverEmpty, b)
  1663  				if aMaybeSame {
  1664  					aSame := makeCopyOf(a)
  1665  					receiver = aSame
  1666  					u, ok := aSame.(Untransposer)
  1667  					if ok {
  1668  						receiver = u.Untranspose()
  1669  					}
  1670  					preData := underlyingData(receiver)
  1671  					panicked, err = panics(func() { method(receiver, aSame, b) })
  1672  					if panicked {
  1673  						t.Errorf("Panics when a maybeSame: %s: %v", errStr, err)
  1674  					} else {
  1675  						if !equalApprox(receiver, &want, tol, false) {
  1676  							t.Errorf("Wrong answer when a maybeSame: %s", errStr)
  1677  						}
  1678  						postData := underlyingData(receiver)
  1679  						if !floats.Equal(preData, postData) {
  1680  							t.Errorf("Original data slice not modified when a maybeSame: %s", errStr)
  1681  						}
  1682  					}
  1683  				}
  1684  				if bMaybeSame {
  1685  					bSame := makeCopyOf(b)
  1686  					receiver = bSame
  1687  					u, ok := bSame.(Untransposer)
  1688  					if ok {
  1689  						receiver = u.Untranspose()
  1690  					}
  1691  					preData := underlyingData(receiver)
  1692  					panicked, err = panics(func() { method(receiver, a, bSame) })
  1693  					if panicked {
  1694  						t.Errorf("Panics when b maybeSame: %s: %v", errStr, err)
  1695  					} else {
  1696  						if !equalApprox(receiver, &want, tol, false) {
  1697  							t.Errorf("Wrong answer when b maybeSame: %s", errStr)
  1698  						}
  1699  						postData := underlyingData(receiver)
  1700  						if !floats.Equal(preData, postData) {
  1701  							t.Errorf("Original data slice not modified when b maybeSame: %s", errStr)
  1702  						}
  1703  					}
  1704  				}
  1705  				if aMaybeSame && bMaybeSame {
  1706  					aSame := makeCopyOf(a)
  1707  					receiver = aSame
  1708  					u, ok := aSame.(Untransposer)
  1709  					if ok {
  1710  						receiver = u.Untranspose()
  1711  					}
  1712  					// Ensure that b is the correct transpose type if applicable.
  1713  					// The receiver is always a concrete type so use it.
  1714  					bSame := receiver
  1715  					_, ok = b.(Untransposer)
  1716  					if ok {
  1717  						bSame = retranspose(b, receiver)
  1718  					}
  1719  					// Compute the real answer for this case. It is different
  1720  					// from the initial answer since now a and b have the
  1721  					// same data.
  1722  					empty = makeRandOf(wasEmpty, 0, 0, src)
  1723  					method(empty, aSame, bSame)
  1724  					wasEmpty, empty = empty, nil // Nil-out empty so we detect illegal use.
  1725  					preData := underlyingData(receiver)
  1726  					panicked, err = panics(func() { method(receiver, aSame, bSame) })
  1727  					if panicked {
  1728  						t.Errorf("Panics when both maybeSame: %s: %v", errStr, err)
  1729  					} else {
  1730  						if !equalApprox(receiver, wasEmpty, tol, false) {
  1731  							t.Errorf("Wrong answer when both maybeSame: %s", errStr)
  1732  						}
  1733  						postData := underlyingData(receiver)
  1734  						if !floats.Equal(preData, postData) {
  1735  							t.Errorf("Original data slice not modified when both maybeSame: %s", errStr)
  1736  						}
  1737  					}
  1738  				}
  1739  			}
  1740  		}
  1741  	}
  1742  }