github.com/jingcheng-WU/gonum@v0.9.1-0.20210323123734-f1a2a11a8f7b/blas/testblas/dger.go (about)

     1  // Copyright ©2014 The Gonum Authors. All rights reserved.
     2  // Use of this source code is governed by a BSD-style
     3  // license that can be found in the LICENSE file.
     4  
     5  package testblas
     6  
     7  import (
     8  	"math"
     9  	"testing"
    10  )
    11  
    12  type Dgerer interface {
    13  	Dger(m, n int, alpha float64, x []float64, incX int, y []float64, incY int, a []float64, lda int)
    14  }
    15  
    16  func DgerTest(t *testing.T, blasser Dgerer) {
    17  	for _, test := range []struct {
    18  		name string
    19  		a    [][]float64
    20  		m    int
    21  		n    int
    22  		x    []float64
    23  		y    []float64
    24  		incX int
    25  		incY int
    26  
    27  		want [][]float64
    28  	}{
    29  		{
    30  			name: "M gt N inc 1",
    31  			m:    5,
    32  			n:    3,
    33  			a: [][]float64{
    34  				{1.3, 2.4, 3.5},
    35  				{2.6, 2.8, 3.3},
    36  				{-1.3, -4.3, -9.7},
    37  				{8, 9, -10},
    38  				{-12, -14, -6},
    39  			},
    40  			x:    []float64{-2, -3, 0, 1, 2},
    41  			y:    []float64{-1.1, 5, 0},
    42  			incX: 1,
    43  			incY: 1,
    44  			want: [][]float64{{3.5, -7.6, 3.5}, {5.9, -12.2, 3.3}, {-1.3, -4.3, -9.7}, {6.9, 14, -10}, {-14.2, -4, -6}},
    45  		},
    46  		{
    47  			name: "M eq N inc 1",
    48  			m:    3,
    49  			n:    3,
    50  			a: [][]float64{
    51  				{1.3, 2.4, 3.5},
    52  				{2.6, 2.8, 3.3},
    53  				{-1.3, -4.3, -9.7},
    54  			},
    55  			x:    []float64{-2, -3, 0},
    56  			y:    []float64{-1.1, 5, 0},
    57  			incX: 1,
    58  			incY: 1,
    59  			want: [][]float64{{3.5, -7.6, 3.5}, {5.9, -12.2, 3.3}, {-1.3, -4.3, -9.7}},
    60  		},
    61  
    62  		{
    63  			name: "M lt N inc 1",
    64  			m:    3,
    65  			n:    6,
    66  			a: [][]float64{
    67  				{1.3, 2.4, 3.5, 4.8, 1.11, -9},
    68  				{2.6, 2.8, 3.3, -3.4, 6.2, -8.7},
    69  				{-1.3, -4.3, -9.7, -3.1, 8.9, 8.9},
    70  			},
    71  			x:    []float64{-2, -3, 0},
    72  			y:    []float64{-1.1, 5, 0, 9, 19, 22},
    73  			incX: 1,
    74  			incY: 1,
    75  			want: [][]float64{{3.5, -7.6, 3.5, -13.2, -36.89, -53}, {5.9, -12.2, 3.3, -30.4, -50.8, -74.7}, {-1.3, -4.3, -9.7, -3.1, 8.9, 8.9}},
    76  		},
    77  		{
    78  			name: "M gt N inc not 1",
    79  			m:    5,
    80  			n:    3,
    81  			a: [][]float64{
    82  				{1.3, 2.4, 3.5},
    83  				{2.6, 2.8, 3.3},
    84  				{-1.3, -4.3, -9.7},
    85  				{8, 9, -10},
    86  				{-12, -14, -6},
    87  			},
    88  			x:    []float64{-2, -3, 0, 1, 2, 6, 0, 9, 7},
    89  			y:    []float64{-1.1, 5, 0, 8, 7, -5, 7},
    90  			incX: 2,
    91  			incY: 3,
    92  			want: [][]float64{{3.5, -13.6, -10.5}, {2.6, 2.8, 3.3}, {-3.5, 11.7, 4.3}, {8, 9, -10}, {-19.700000000000003, 42, 43}},
    93  		},
    94  		{
    95  			name: "M eq N inc not 1",
    96  			m:    3,
    97  			n:    3,
    98  			a: [][]float64{
    99  				{1.3, 2.4, 3.5},
   100  				{2.6, 2.8, 3.3},
   101  				{-1.3, -4.3, -9.7},
   102  			},
   103  			x:    []float64{-2, -3, 0, 8, 7, -9, 7, -6, 12, 6, 6, 6, -11},
   104  			y:    []float64{-1.1, 5, 0, 0, 9, 8, 6},
   105  			incX: 4,
   106  			incY: 3,
   107  			want: [][]float64{{3.5, 2.4, -8.5}, {-5.1, 2.8, 45.3}, {-14.5, -4.3, 62.3}},
   108  		},
   109  		{
   110  			name: "M lt N inc not 1",
   111  			m:    3,
   112  			n:    6,
   113  			a: [][]float64{
   114  				{1.3, 2.4, 3.5, 4.8, 1.11, -9},
   115  				{2.6, 2.8, 3.3, -3.4, 6.2, -8.7},
   116  				{-1.3, -4.3, -9.7, -3.1, 8.9, 8.9},
   117  			},
   118  			x:    []float64{-2, -3, 0, 0, 8, 0, 9, -3},
   119  			y:    []float64{-1.1, 5, 0, 9, 19, 22, 11, -8.11, -9.22, 9.87, 7},
   120  			incX: 3,
   121  			incY: 2,
   122  			want: [][]float64{{3.5, 2.4, -34.5, -17.2, 19.55, -23}, {2.6, 2.8, 3.3, -3.4, 6.2, -8.7}, {-11.2, -4.3, 161.3, 95.9, -74.08, 71.9}},
   123  		},
   124  		{
   125  			name: "Y NaN element",
   126  			m:    1,
   127  			n:    1,
   128  			a:    [][]float64{{1.3}},
   129  			x:    []float64{1.3},
   130  			y:    []float64{math.NaN()},
   131  			incX: 1,
   132  			incY: 1,
   133  			want: [][]float64{{math.NaN()}},
   134  		},
   135  		{
   136  			name: "M eq N large inc 1",
   137  			m:    7,
   138  			n:    7,
   139  			x:    []float64{6.2, -5, 88.68, 43.4, -30.5, -40.2, 19.9},
   140  			y:    []float64{1.5, 21.7, -28.7, -11.9, 18.1, 3.1, 21},
   141  			a: [][]float64{
   142  				{-20.5, 17.1, -8.4, -23.8, 3.9, 7.7, 6.25},
   143  				{2.9, -0.29, 25.6, -9.4, 36.5, 9.7, 2.3},
   144  				{4.1, -34.1, 10.3, 4.5, -42.05, 9.4, 4},
   145  				{19.2, 9.8, -32.7, 4.1, 4.4, -22.5, -7.8},
   146  				{3.6, -24.5, 21.7, 8.6, -13.82, 38.05, -2.29},
   147  				{39.4, -40.5, 7.9, -2.5, -7.7, 18.1, -25.5},
   148  				{-18.5, 43.2, 2.1, 30.1, 3.02, -31.1, -7.6},
   149  			},
   150  			incX: 1,
   151  			incY: 1,
   152  			want: [][]float64{
   153  				{-11.2, 151.64, -186.34, -97.58, 116.12, 26.92, 136.45},
   154  				{-4.6, -108.79, 169.1, 50.1, -54, -5.8, -102.7},
   155  				{137.12, 1890.256, -2534.816, -1050.792, 1563.058, 284.308, 1866.28},
   156  				{84.3, 951.58, -1278.28, -512.36, 789.94, 112.04, 903.6},
   157  				{-42.15, -686.35, 897.05, 371.55, -565.87, -56.5, -642.79},
   158  				{-20.9, -912.84, 1161.64, 475.88, -735.32, -106.52, -869.7},
   159  				{11.35, 475.03, -569.03, -206.71, 363.21, 30.59, 410.3},
   160  			},
   161  		},
   162  		{
   163  			name: "M eq N large inc not 1",
   164  			m:    7,
   165  			n:    7,
   166  			x:    []float64{6.2, 100, 200, -5, 300, 400, 88.68, 100, 200, 43.4, 300, 400, -30.5, 100, 200, -40.2, 300, 400, 19.9},
   167  			y:    []float64{1.5, 100, 200, 300, 21.7, 100, 200, 300, -28.7, 100, 200, 300, -11.9, 100, 200, 300, 18.1, 100, 200, 300, 3.1, 100, 200, 300, 21},
   168  			a: [][]float64{
   169  				{-20.5, 17.1, -8.4, -23.8, 3.9, 7.7, 6.25},
   170  				{2.9, -0.29, 25.6, -9.4, 36.5, 9.7, 2.3},
   171  				{4.1, -34.1, 10.3, 4.5, -42.05, 9.4, 4},
   172  				{19.2, 9.8, -32.7, 4.1, 4.4, -22.5, -7.8},
   173  				{3.6, -24.5, 21.7, 8.6, -13.82, 38.05, -2.29},
   174  				{39.4, -40.5, 7.9, -2.5, -7.7, 18.1, -25.5},
   175  				{-18.5, 43.2, 2.1, 30.1, 3.02, -31.1, -7.6},
   176  			},
   177  			incX: 3,
   178  			incY: 4,
   179  			want: [][]float64{
   180  				{-11.2, 151.64, -186.34, -97.58, 116.12, 26.92, 136.45},
   181  				{-4.6, -108.79, 169.1, 50.1, -54, -5.8, -102.7},
   182  				{137.12, 1890.256, -2534.816, -1050.792, 1563.058, 284.308, 1866.28},
   183  				{84.3, 951.58, -1278.28, -512.36, 789.94, 112.04, 903.6},
   184  				{-42.15, -686.35, 897.05, 371.55, -565.87, -56.5, -642.79},
   185  				{-20.9, -912.84, 1161.64, 475.88, -735.32, -106.52, -869.7},
   186  				{11.35, 475.03, -569.03, -206.71, 363.21, 30.59, 410.3},
   187  			},
   188  		},
   189  	} {
   190  		// TODO: Add tests where a is longer
   191  		// TODO: Add panic tests
   192  		// TODO: Add negative increment tests
   193  
   194  		x := sliceCopy(test.x)
   195  		y := sliceCopy(test.y)
   196  
   197  		a := sliceOfSliceCopy(test.a)
   198  
   199  		// Test with row major
   200  		alpha := 1.0
   201  		aFlat := flatten(a)
   202  		blasser.Dger(test.m, test.n, alpha, x, test.incX, y, test.incY, aFlat, test.n)
   203  		ans := unflatten(aFlat, test.m, test.n)
   204  		dgercomp(t, x, test.x, y, test.y, ans, test.want, test.name+" row maj")
   205  
   206  		// Test with different alpha
   207  		alpha = 4.0
   208  		aFlat = flatten(a)
   209  		blasser.Dger(test.m, test.n, alpha, x, test.incX, y, test.incY, aFlat, test.n)
   210  		ans = unflatten(aFlat, test.m, test.n)
   211  		trueCopy := sliceOfSliceCopy(test.want)
   212  		for i := range trueCopy {
   213  			for j := range trueCopy[i] {
   214  				trueCopy[i][j] = alpha*(trueCopy[i][j]-a[i][j]) + a[i][j]
   215  			}
   216  		}
   217  		dgercomp(t, x, test.x, y, test.y, ans, trueCopy, test.name+" row maj alpha")
   218  	}
   219  }
   220  
   221  func dgercomp(t *testing.T, x, xCopy, y, yCopy []float64, ans [][]float64, trueAns [][]float64, name string) {
   222  	if !dSliceEqual(x, xCopy) {
   223  		t.Errorf("case %v: x modified during call to dger\n%v\n%v", name, x, xCopy)
   224  	}
   225  	if !dSliceEqual(y, yCopy) {
   226  		t.Errorf("case %v: y modified during call to dger\n%v\n%v", name, y, yCopy)
   227  	}
   228  
   229  	for i := range ans {
   230  		if !dSliceTolEqual(ans[i], trueAns[i]) {
   231  			t.Errorf("case %v: answer mismatch at %v.\nExpected %v,\nFound %v", name, i, trueAns, ans)
   232  			break
   233  		}
   234  	}
   235  }