github.com/wzzhu/tensor@v0.9.24/api_arith_test.go (about)

     1  package tensor
     2  
     3  import (
     4  	"log"
     5  	"math/rand"
     6  	"testing"
     7  	"testing/quick"
     8  	"time"
     9  
    10  	"github.com/stretchr/testify/assert"
    11  )
    12  
    13  // This file contains the tests for API functions that aren't generated by genlib
    14  
    15  func TestMod(t *testing.T) {
    16  	a := New(WithBacking([]float64{1, 2, 3, 4}))
    17  	b := New(WithBacking([]float64{1, 1, 1, 1}))
    18  	var correct interface{} = []float64{0, 0, 0, 0}
    19  
    20  	// vec-vec
    21  	res, err := Mod(a, b)
    22  	if err != nil {
    23  		t.Fatalf("Error: %v", err)
    24  	}
    25  	assert.Equal(t, correct, res.Data())
    26  
    27  	// scalar
    28  	if res, err = Mod(a, 1.0); err != nil {
    29  		t.Fatalf("Error: %v", err)
    30  	}
    31  	assert.Equal(t, correct, res.Data())
    32  }
    33  
    34  func TestFMA(t *testing.T) {
    35  	same := func(q *Dense) bool {
    36  		a := q.Clone().(*Dense)
    37  		x := q.Clone().(*Dense)
    38  		y := New(Of(q.Dtype()), WithShape(q.Shape().Clone()...))
    39  		y.Memset(identityVal(100, q.Dtype()))
    40  		WithEngine(q.Engine())(y)
    41  		y2 := y.Clone().(*Dense)
    42  
    43  		we, willFailEq := willerr(a, numberTypes, nil)
    44  		_, ok1 := q.Engine().(FMAer)
    45  		_, ok2 := q.Engine().(Muler)
    46  		_, ok3 := q.Engine().(Adder)
    47  		we = we || (!ok1 && (!ok2 || !ok3))
    48  
    49  		f, err := FMA(a, x, y)
    50  		if err, retEarly := qcErrCheck(t, "FMA#1", a, x, we, err); retEarly {
    51  			if err != nil {
    52  				log.Printf("q.Engine() %T", q.Engine())
    53  				return false
    54  			}
    55  			return true
    56  		}
    57  
    58  		we, _ = willerr(a, numberTypes, nil)
    59  		_, ok := a.Engine().(Muler)
    60  		we = we || !ok
    61  		wi, err := Mul(a, x, WithIncr(y2))
    62  		if err, retEarly := qcErrCheck(t, "FMA#2", a, x, we, err); retEarly {
    63  			if err != nil {
    64  				return false
    65  			}
    66  			return true
    67  		}
    68  		return qcEqCheck(t, q.Dtype(), willFailEq, wi, f)
    69  	}
    70  	r := rand.New(rand.NewSource(time.Now().UnixNano()))
    71  	if err := quick.Check(same, &quick.Config{Rand: r}); err != nil {
    72  		t.Error(err)
    73  	}
    74  
    75  	// specific engines
    76  	var eng Engine
    77  
    78  	// FLOAT64 ENGINE
    79  
    80  	// vec-vec
    81  	eng = Float64Engine{}
    82  	a := New(WithBacking(Range(Float64, 0, 100)), WithEngine(eng))
    83  	x := New(WithBacking(Range(Float64, 1, 101)), WithEngine(eng))
    84  	y := New(Of(Float64), WithShape(100), WithEngine(eng))
    85  
    86  	f, err := FMA(a, x, y)
    87  	if err != nil {
    88  		t.Fatal(err)
    89  	}
    90  
    91  	a2 := New(WithBacking(Range(Float64, 0, 100)))
    92  	x2 := New(WithBacking(Range(Float64, 1, 101)))
    93  	y2 := New(Of(Float64), WithShape(100))
    94  	f2, err := Mul(a2, x2, WithIncr(y2))
    95  	if err != nil {
    96  		t.Fatal(err)
    97  	}
    98  
    99  	assert.Equal(t, f.Data(), f2.Data())
   100  
   101  	// vec-scalar
   102  	a = New(WithBacking(Range(Float64, 0, 100)), WithEngine(eng))
   103  	y = New(Of(Float64), WithShape(100))
   104  
   105  	if f, err = FMA(a, 2.0, y); err != nil {
   106  		t.Fatal(err)
   107  	}
   108  
   109  	a2 = New(WithBacking(Range(Float64, 0, 100)))
   110  	y2 = New(Of(Float64), WithShape(100))
   111  	if f2, err = Mul(a2, 2.0, WithIncr(y2)); err != nil {
   112  		t.Fatal(err)
   113  	}
   114  
   115  	assert.Equal(t, f.Data(), f2.Data())
   116  
   117  	// FLOAT32 engine
   118  	eng = Float32Engine{}
   119  	a = New(WithBacking(Range(Float32, 0, 100)), WithEngine(eng))
   120  	x = New(WithBacking(Range(Float32, 1, 101)), WithEngine(eng))
   121  	y = New(Of(Float32), WithShape(100), WithEngine(eng))
   122  
   123  	f, err = FMA(a, x, y)
   124  	if err != nil {
   125  		t.Fatal(err)
   126  	}
   127  
   128  	a2 = New(WithBacking(Range(Float32, 0, 100)))
   129  	x2 = New(WithBacking(Range(Float32, 1, 101)))
   130  	y2 = New(Of(Float32), WithShape(100))
   131  	f2, err = Mul(a2, x2, WithIncr(y2))
   132  	if err != nil {
   133  		t.Fatal(err)
   134  	}
   135  
   136  	assert.Equal(t, f.Data(), f2.Data())
   137  
   138  	// vec-scalar
   139  	a = New(WithBacking(Range(Float32, 0, 100)), WithEngine(eng))
   140  	y = New(Of(Float32), WithShape(100))
   141  
   142  	if f, err = FMA(a, float32(2), y); err != nil {
   143  		t.Fatal(err)
   144  	}
   145  
   146  	a2 = New(WithBacking(Range(Float32, 0, 100)))
   147  	y2 = New(Of(Float32), WithShape(100))
   148  	if f2, err = Mul(a2, float32(2), WithIncr(y2)); err != nil {
   149  		t.Fatal(err)
   150  	}
   151  
   152  	assert.Equal(t, f.Data(), f2.Data())
   153  
   154  }
   155  
   156  func TestMulScalarScalar(t *testing.T) {
   157  	// scalar-scalar
   158  	a := New(WithBacking([]float64{2}))
   159  	b := New(WithBacking([]float64{3}))
   160  	var correct interface{} = 6.0
   161  
   162  	res, err := Mul(a, b)
   163  	if err != nil {
   164  		t.Fatalf("Error: %v", err)
   165  	}
   166  	assert.Equal(t, correct, res.Data())
   167  
   168  	// Test commutativity
   169  	res, err = Mul(b, a)
   170  	if err != nil {
   171  		t.Fatalf("Error: %v", err)
   172  	}
   173  	assert.Equal(t, correct, res.Data())
   174  
   175  	// scalar-tensor
   176  	a = New(WithBacking([]float64{3, 2}))
   177  	b = New(WithBacking([]float64{2}))
   178  	correct = []float64{6, 4}
   179  
   180  	res, err = Mul(a, b)
   181  	if err != nil {
   182  		t.Fatalf("Error: %v", err)
   183  	}
   184  	assert.Equal(t, correct, res.Data())
   185  
   186  	// Test commutativity
   187  	res, err = Mul(b, a)
   188  	if err != nil {
   189  		t.Fatalf("Error: %v", err)
   190  	}
   191  	assert.Equal(t, correct, res.Data())
   192  
   193  	// tensor - tensor
   194  	a = New(WithBacking([]float64{3, 5}))
   195  	b = New(WithBacking([]float64{7, 2}))
   196  	correct = []float64{21, 10}
   197  
   198  	res, err = Mul(a, b)
   199  	if err != nil {
   200  		t.Fatalf("Error: %v", err)
   201  	}
   202  	assert.Equal(t, correct, res.Data())
   203  
   204  	// Test commutativity
   205  	res, err = Mul(b, a)
   206  	if err != nil {
   207  		t.Fatalf("Error: %v", err)
   208  	}
   209  	assert.Equal(t, correct, res.Data())
   210  
   211  	// Interface - tensor
   212  	ai := 2.0
   213  	b = NewDense(Float64, Shape{1, 1}, WithBacking([]float64{3}))
   214  	correct = []float64{6.0}
   215  
   216  	res, err = Mul(ai, b)
   217  	if err != nil {
   218  		t.Fatalf("Error: %v", err)
   219  	}
   220  	assert.Equal(t, correct, res.Data())
   221  
   222  	// Commutativity
   223  	res, err = Mul(b, ai)
   224  	if err != nil {
   225  		t.Fatalf("Error: %v", err)
   226  	}
   227  	assert.Equal(t, correct, res.Data())
   228  }
   229  
   230  func TestDivScalarScalar(t *testing.T) {
   231  	// scalar-scalar
   232  	a := New(WithBacking([]float64{6}))
   233  	b := New(WithBacking([]float64{2}))
   234  	var correct interface{} = 3.0
   235  
   236  	res, err := Div(a, b)
   237  	if err != nil {
   238  		t.Fatalf("Error: %v", err)
   239  	}
   240  	assert.Equal(t, correct, res.Data())
   241  
   242  	// scalar-tensor
   243  	a = New(WithBacking([]float64{6, 4}))
   244  	b = New(WithBacking([]float64{2}))
   245  	correct = []float64{3, 2}
   246  
   247  	res, err = Div(a, b)
   248  	if err != nil {
   249  		t.Fatalf("Error: %v", err)
   250  	}
   251  	assert.Equal(t, correct, res.Data())
   252  
   253  	// tensor-scalar
   254  	a = New(WithBacking([]float64{6}))
   255  	b = New(WithBacking([]float64{3, 2}))
   256  	correct = []float64{2, 3}
   257  
   258  	res, err = Div(a, b)
   259  	if err != nil {
   260  		t.Fatalf("Error: %v", err)
   261  	}
   262  	assert.Equal(t, correct, res.Data())
   263  
   264  	// tensor - tensor
   265  	a = New(WithBacking([]float64{21, 10}))
   266  	b = New(WithBacking([]float64{7, 2}))
   267  	correct = []float64{3, 5}
   268  
   269  	res, err = Div(a, b)
   270  	if err != nil {
   271  		t.Fatalf("Error: %v", err)
   272  	}
   273  	assert.Equal(t, correct, res.Data())
   274  
   275  	// interface-scalar
   276  	ai := 6.0
   277  	b = New(WithBacking([]float64{2}))
   278  	correct = 3.0
   279  
   280  	res, err = Div(ai, b)
   281  	if err != nil {
   282  		t.Fatalf("Error: %v", err)
   283  	}
   284  	assert.Equal(t, correct, res.Data())
   285  
   286  	// scalar-interface
   287  	a = New(WithBacking([]float64{6}))
   288  	bi := 2.0
   289  	correct = 3.0
   290  
   291  	res, err = Div(a, bi)
   292  	if err != nil {
   293  		t.Fatalf("Error: %v", err)
   294  	}
   295  	assert.Equal(t, correct, res.Data())
   296  }
   297  
   298  func TestAddScalarScalar(t *testing.T) {
   299  	// scalar-scalar
   300  	a := New(WithBacking([]float64{2}))
   301  	b := New(WithBacking([]float64{3}))
   302  	var correct interface{} = 5.0
   303  
   304  	res, err := Add(a, b)
   305  	if err != nil {
   306  		t.Fatalf("Error: %v", err)
   307  	}
   308  	assert.Equal(t, correct, res.Data())
   309  
   310  	// Test commutativity
   311  	res, err = Add(b, a)
   312  	if err != nil {
   313  		t.Fatalf("Error: %v", err)
   314  	}
   315  	assert.Equal(t, correct, res.Data())
   316  
   317  	// scalar-tensor
   318  	a = New(WithBacking([]float64{3, 2}))
   319  	b = New(WithBacking([]float64{2}))
   320  	correct = []float64{5, 4}
   321  
   322  	res, err = Add(a, b)
   323  	if err != nil {
   324  		t.Fatalf("Error: %v", err)
   325  	}
   326  	assert.Equal(t, correct, res.Data())
   327  
   328  	// Test commutativity
   329  	res, err = Add(b, a)
   330  	if err != nil {
   331  		t.Fatalf("Error: %v", err)
   332  	}
   333  	assert.Equal(t, correct, res.Data())
   334  
   335  	// tensor - tensor
   336  	a = New(WithBacking([]float64{3, 5}))
   337  	b = New(WithBacking([]float64{7, 2}))
   338  	correct = []float64{10, 7}
   339  
   340  	res, err = Add(a, b)
   341  	if err != nil {
   342  		t.Fatalf("Error: %v", err)
   343  	}
   344  	assert.Equal(t, correct, res.Data())
   345  
   346  	// Test commutativity
   347  	res, err = Add(b, a)
   348  	if err != nil {
   349  		t.Fatalf("Error: %v", err)
   350  	}
   351  	assert.Equal(t, correct, res.Data())
   352  
   353  	// interface-scalar
   354  	ai := 2.0
   355  	b = New(WithBacking([]float64{3}))
   356  	correct = 5.0
   357  
   358  	res, err = Add(ai, b)
   359  	if err != nil {
   360  		t.Fatalf("Error: %v", err)
   361  	}
   362  	assert.Equal(t, correct, res.Data())
   363  
   364  	// Test commutativity
   365  	res, err = Add(b, ai)
   366  	if err != nil {
   367  		t.Fatalf("Error: %v", err)
   368  	}
   369  	assert.Equal(t, correct, res.Data())
   370  }
   371  
   372  func TestSubScalarScalar(t *testing.T) {
   373  	// scalar-scalar
   374  	a := New(WithBacking([]float64{6}))
   375  	b := New(WithBacking([]float64{2}))
   376  	var correct interface{} = 4.0
   377  
   378  	res, err := Sub(a, b)
   379  	if err != nil {
   380  		t.Fatalf("Error: %v", err)
   381  	}
   382  	assert.Equal(t, correct, res.Data())
   383  
   384  	// scalar-tensor
   385  	a = New(WithBacking([]float64{6, 4}))
   386  	b = New(WithBacking([]float64{2}))
   387  	correct = []float64{4, 2}
   388  
   389  	res, err = Sub(a, b)
   390  	if err != nil {
   391  		t.Fatalf("Error: %v", err)
   392  	}
   393  	assert.Equal(t, correct, res.Data())
   394  
   395  	// tensor-scalar
   396  	a = New(WithBacking([]float64{6}))
   397  	b = New(WithBacking([]float64{3, 2}))
   398  	correct = []float64{3, 4}
   399  
   400  	res, err = Sub(a, b)
   401  	if err != nil {
   402  		t.Fatalf("Error: %v", err)
   403  	}
   404  	assert.Equal(t, correct, res.Data())
   405  
   406  	// tensor - tensor
   407  	a = New(WithBacking([]float64{21, 10}))
   408  	b = New(WithBacking([]float64{7, 2}))
   409  	correct = []float64{14, 8}
   410  
   411  	res, err = Sub(a, b)
   412  	if err != nil {
   413  		t.Fatalf("Error: %v", err)
   414  	}
   415  	assert.Equal(t, correct, res.Data())
   416  
   417  	// interface-scalar
   418  	ai := 6.0
   419  	b = New(WithBacking([]float64{2}))
   420  	correct = 4.0
   421  
   422  	res, err = Sub(ai, b)
   423  	if err != nil {
   424  		t.Fatalf("Error: %v", err)
   425  	}
   426  	assert.Equal(t, correct, res.Data())
   427  
   428  	// scalar-interface
   429  	a = New(WithBacking([]float64{6}))
   430  	bi := 2.0
   431  	correct = 4.0
   432  
   433  	res, err = Sub(a, bi)
   434  	if err != nil {
   435  		t.Fatalf("Error: %v", err)
   436  	}
   437  	assert.Equal(t, correct, res.Data())
   438  }
   439  
   440  func TestModScalarScalar(t *testing.T) {
   441  	// scalar-scalar
   442  	a := New(WithBacking([]float64{5}))
   443  	b := New(WithBacking([]float64{2}))
   444  	var correct interface{} = 1.0
   445  
   446  	res, err := Mod(a, b)
   447  	if err != nil {
   448  		t.Fatalf("Error: %v", err)
   449  	}
   450  	assert.Equal(t, correct, res.Data())
   451  
   452  	// scalar-tensor
   453  	a = New(WithBacking([]float64{5, 4}))
   454  	b = New(WithBacking([]float64{2}))
   455  	correct = []float64{1, 0}
   456  
   457  	res, err = Mod(a, b)
   458  	if err != nil {
   459  		t.Fatalf("Error: %v", err)
   460  	}
   461  	assert.Equal(t, correct, res.Data())
   462  
   463  	// tensor-scalar
   464  	a = New(WithBacking([]float64{5}))
   465  	b = New(WithBacking([]float64{3, 2}))
   466  	correct = []float64{2, 1}
   467  
   468  	res, err = Mod(a, b)
   469  	if err != nil {
   470  		t.Fatalf("Error: %v", err)
   471  	}
   472  	assert.Equal(t, correct, res.Data())
   473  
   474  	// tensor - tensor
   475  	a = New(WithBacking([]float64{22, 10}))
   476  	b = New(WithBacking([]float64{7, 2}))
   477  	correct = []float64{1, 0}
   478  
   479  	res, err = Mod(a, b)
   480  	if err != nil {
   481  		t.Fatalf("Error: %v", err)
   482  	}
   483  	assert.Equal(t, correct, res.Data())
   484  
   485  	// interface-scalar
   486  	ai := 5.0
   487  	b = New(WithBacking([]float64{2}))
   488  	correct = 1.0
   489  
   490  	res, err = Mod(ai, b)
   491  	if err != nil {
   492  		t.Fatalf("Error: %v", err)
   493  	}
   494  	assert.Equal(t, correct, res.Data())
   495  
   496  	// scalar-interface
   497  	a = New(WithBacking([]float64{5}))
   498  	bi := 2.0
   499  	correct = 1.0
   500  
   501  	res, err = Mod(a, bi)
   502  	if err != nil {
   503  		t.Fatalf("Error: %v", err)
   504  	}
   505  	assert.Equal(t, correct, res.Data())
   506  }
   507  
   508  func TestPowScalarScalar(t *testing.T) {
   509  	// scalar-scalar
   510  	a := New(WithBacking([]float64{6}))
   511  	b := New(WithBacking([]float64{2}))
   512  	var correct interface{} = 36.0
   513  
   514  	res, err := Pow(a, b)
   515  	if err != nil {
   516  		t.Fatalf("Error: %v", err)
   517  	}
   518  	assert.Equal(t, correct, res.Data())
   519  
   520  	// scalar-tensor
   521  	a = New(WithBacking([]float64{6, 4}))
   522  	b = New(WithBacking([]float64{2}))
   523  	correct = []float64{36, 16}
   524  
   525  	res, err = Pow(a, b)
   526  	if err != nil {
   527  		t.Fatalf("Error: %v", err)
   528  	}
   529  	assert.Equal(t, correct, res.Data())
   530  
   531  	// tensor-scalar
   532  	a = New(WithBacking([]float64{6}))
   533  	b = New(WithBacking([]float64{3, 2}))
   534  	correct = []float64{216, 36}
   535  
   536  	res, err = Pow(a, b)
   537  	if err != nil {
   538  		t.Fatalf("Error: %v", err)
   539  	}
   540  	assert.Equal(t, correct, res.Data())
   541  
   542  	// tensor - tensor
   543  	a = New(WithBacking([]float64{3, 10}))
   544  	b = New(WithBacking([]float64{7, 2}))
   545  	correct = []float64{2187, 100}
   546  
   547  	res, err = Pow(a, b)
   548  	if err != nil {
   549  		t.Fatalf("Error: %v", err)
   550  	}
   551  	assert.Equal(t, correct, res.Data())
   552  
   553  	// interface-scalar
   554  	ai := 6.0
   555  	b = New(WithBacking([]float64{2}))
   556  	correct = 36.0
   557  
   558  	res, err = Pow(ai, b)
   559  	if err != nil {
   560  		t.Fatalf("Error: %v", err)
   561  	}
   562  	assert.Equal(t, correct, res.Data())
   563  
   564  	// scalar-interface
   565  	a = New(WithBacking([]float64{6}))
   566  	bi := 2.0
   567  	correct = 36.0
   568  
   569  	res, err = Pow(a, bi)
   570  	if err != nil {
   571  		t.Fatalf("Error: %v", err)
   572  	}
   573  	assert.Equal(t, correct, res.Data())
   574  }