github.com/wzzhu/tensor@v0.9.24/dense_norms_test.go (about)

     1  package tensor
     2  
     3  import (
     4  	"fmt"
     5  	"math"
     6  	"testing"
     7  
     8  	"github.com/pkg/errors"
     9  	"github.com/stretchr/testify/assert"
    10  )
    11  
    12  func testNormVal(T *Dense, ord NormOrder, want float64) error {
    13  	retVal, err := T.Norm(ord)
    14  	if err != nil {
    15  		err = errors.Wrap(err, "testNormVal")
    16  		return err
    17  	}
    18  
    19  	if !retVal.IsScalar() {
    20  		return errors.New("Expected Scalar")
    21  	}
    22  
    23  	got := retVal.ScalarValue().(float64)
    24  	if !closef64(want, got) && !(math.IsNaN(want) && alikef64(want, got)) {
    25  		return errors.New(fmt.Sprintf("Norm %v, Backing %v: Want %f, got %f instead", ord, T.Data(), want, got))
    26  	}
    27  	return nil
    28  }
    29  
    30  func TestTensor_Norm(t *testing.T) {
    31  	var T *Dense
    32  	var err error
    33  	var backing, backing1, backing2 []float64
    34  	var corrects map[NormOrder]float64
    35  	var wrongs []NormOrder
    36  
    37  	// empty
    38  	backing = make([]float64, 0)
    39  	T = New(WithBacking(backing))
    40  	//TODO
    41  
    42  	// vecktor
    43  	backing = []float64{1, 2, 3, 4}
    44  	backing1 = []float64{-1, -2, -3, -4}
    45  	backing2 = []float64{-1, 2, -3, 4}
    46  
    47  	corrects = map[NormOrder]float64{
    48  		UnorderedNorm(): math.Pow(30, 0.5),               // Unordered
    49  		FrobeniusNorm(): math.NaN(),                      // Frobenius
    50  		NuclearNorm():   math.NaN(),                      // Nuclear
    51  		InfNorm():       4,                               // Inf
    52  		NegInfNorm():    1,                               // -Inf
    53  		Norm(0):         4,                               // 0
    54  		Norm(1):         10,                              // 1
    55  		Norm(-1):        12.0 / 25.0,                     // -1
    56  		Norm(2):         math.Pow(30, 0.5),               // 2
    57  		Norm(-2):        math.Pow((205.0 / 144.0), -0.5), // -2
    58  	}
    59  
    60  	backings := [][]float64{backing, backing1, backing2}
    61  	for ord, want := range corrects {
    62  		for _, b := range backings {
    63  			T = New(WithShape(len(backing)), WithBacking(b))
    64  			if err = testNormVal(T, ord, want); err != nil {
    65  				t.Error(errors.Cause(err))
    66  			}
    67  		}
    68  	}
    69  
    70  	// 2x2 mat
    71  	backing = []float64{1, 3, 5, 7}
    72  	corrects = map[NormOrder]float64{
    73  		UnorderedNorm(): math.Pow(84, 0.5),   // Unordered
    74  		FrobeniusNorm(): math.Pow(84, 0.5),   // Frobenius
    75  		NuclearNorm():   10,                  // Nuclear
    76  		InfNorm():       12,                  // Inf
    77  		NegInfNorm():    4,                   // -Inf
    78  		Norm(1):         10,                  // 1
    79  		Norm(-1):        6,                   // -1
    80  		Norm(2):         9.1231056256176615,  // 2
    81  		Norm(-2):        0.87689437438234041, // -2
    82  	}
    83  
    84  	T = New(WithShape(2, 2), WithBacking(backing))
    85  	for ord, want := range corrects {
    86  		if err = testNormVal(T, ord, want); err != nil {
    87  			t.Errorf("ORD %v: %v", ord, err)
    88  		}
    89  	}
    90  
    91  	// impossible values
    92  	wrongs = []NormOrder{
    93  		Norm(-3),
    94  		Norm(0),
    95  	}
    96  	for _, ord := range wrongs {
    97  		if err = testNormVal(T, ord, math.NaN()); err == nil {
    98  			t.Errorf("Expected an error when finding norm of order %v", ord)
    99  		}
   100  	}
   101  
   102  	// 3x3 mat
   103  	// this test is added because the 2x2 example happens to have equal nuclear norm and induced 1-norm.
   104  	// the 1/10 scaling factor accommodates the absolute tolerance used.
   105  	backing = []float64{0.1, 0.2, 0.3, 0.6, 0, 0.5, 0.3, 0.2, 0.1}
   106  	corrects = map[NormOrder]float64{
   107  		FrobeniusNorm(): (1.0 / 10.0) * math.Pow(89, 0.5),
   108  		NuclearNorm():   1.3366836911774836,
   109  		InfNorm():       1.1,
   110  		NegInfNorm():    0.6,
   111  		Norm(1):         1,
   112  		Norm(-1):        0.4,
   113  		Norm(2):         0.88722940323461277,
   114  		Norm(-2):        0.19456584790481812,
   115  	}
   116  
   117  	T = New(WithShape(3, 3), WithBacking(backing))
   118  	for ord, want := range corrects {
   119  		if err = testNormVal(T, ord, want); err != nil {
   120  			t.Error(err)
   121  		}
   122  	}
   123  }
   124  
   125  func TestTensor_Norm_Axis(t *testing.T) {
   126  	assert := assert.New(t)
   127  	var T, s, expected, retVal *Dense
   128  	var sliced Tensor
   129  	var err error
   130  	var backing []float64
   131  	var ords []NormOrder
   132  
   133  	t.Log("Vector Norm Tests: compare the use of axis with computing of each row or column separately")
   134  	ords = []NormOrder{
   135  		UnorderedNorm(),
   136  		InfNorm(),
   137  		NegInfNorm(),
   138  		Norm(-1),
   139  		Norm(0),
   140  		Norm(1),
   141  		Norm(2),
   142  		Norm(3),
   143  	}
   144  
   145  	backing = []float64{1, 2, 3, 4, 5, 6}
   146  	T = New(WithShape(2, 3), WithBacking(backing))
   147  
   148  	for _, ord := range ords {
   149  		var expecteds []*Dense
   150  		for k := 0; k < T.Shape()[1]; k++ {
   151  			sliced, _ = T.Slice(nil, ss(k))
   152  			s = sliced.(View).Materialize().(*Dense)
   153  			expected, _ = s.Norm(ord)
   154  			expecteds = append(expecteds, expected)
   155  		}
   156  
   157  		if retVal, err = T.Norm(ord, 0); err != nil {
   158  			t.Error(err)
   159  			continue
   160  		}
   161  
   162  		assert.Equal(len(expecteds), retVal.Shape()[0])
   163  		for i, e := range expecteds {
   164  			sliced, _ = retVal.Slice(ss(i))
   165  			sliced = sliced.(View).Materialize()
   166  			if !allClose(e.Data(), sliced.Data()) {
   167  				t.Errorf("Axis = 0; Ord = %v; Expected %v. Got %v instead. ret %v, i: %d", ord, e.Data(), sliced.Data(), retVal, i)
   168  			}
   169  		}
   170  
   171  		// reset and do axis = 1
   172  
   173  		expecteds = expecteds[:0]
   174  		for k := 0; k < T.Shape()[0]; k++ {
   175  			sliced, _ = T.Slice(ss(k))
   176  			s = sliced.(*Dense)
   177  			expected, _ = s.Norm(ord)
   178  			expecteds = append(expecteds, expected)
   179  		}
   180  		if retVal, err = T.Norm(ord, 1); err != nil {
   181  			t.Error(err)
   182  			continue
   183  		}
   184  
   185  		assert.Equal(len(expecteds), retVal.Shape()[0])
   186  		for i, e := range expecteds {
   187  			sliced, _ = retVal.Slice(ss(i))
   188  			sliced = sliced.(View).Materialize().(*Dense)
   189  			if !allClose(e.Data(), sliced.Data()) {
   190  				t.Errorf("Axis = 1; Ord = %v; Expected %v. Got %v instead", ord, e.Data(), sliced.Data())
   191  			}
   192  		}
   193  	}
   194  
   195  	t.Log("Matrix Norms")
   196  
   197  	ords = []NormOrder{
   198  		UnorderedNorm(),
   199  		FrobeniusNorm(),
   200  		InfNorm(),
   201  		NegInfNorm(),
   202  		Norm(-2),
   203  		Norm(-1),
   204  		Norm(1),
   205  		Norm(2),
   206  	}
   207  
   208  	axeses := [][]int{
   209  		{0, 0},
   210  		{0, 1},
   211  		{0, 2},
   212  		{1, 0},
   213  		{1, 1},
   214  		{1, 2},
   215  		{2, 0},
   216  		{2, 1},
   217  		{2, 2},
   218  	}
   219  
   220  	backing = Range(Float64, 1, 25).([]float64)
   221  	T = New(WithShape(2, 3, 4), WithBacking(backing))
   222  
   223  	dims := T.Dims()
   224  	for _, ord := range ords {
   225  		for _, axes := range axeses {
   226  			rowAxis := axes[0]
   227  			colAxis := axes[1]
   228  
   229  			if rowAxis < 0 {
   230  				rowAxis += dims
   231  			}
   232  			if colAxis < 0 {
   233  				colAxis += dims
   234  			}
   235  
   236  			if rowAxis == colAxis {
   237  
   238  			} else {
   239  				kthIndex := dims - (rowAxis + colAxis)
   240  				var expecteds []*Dense
   241  
   242  				for k := 0; k < T.Shape()[kthIndex]; k++ {
   243  					var slices []Slice
   244  					for s := 0; s < kthIndex; s++ {
   245  						slices = append(slices, nil)
   246  					}
   247  					slices = append(slices, ss(k))
   248  					sliced, _ = T.Slice(slices...)
   249  					if rowAxis > colAxis {
   250  						sliced.T()
   251  					}
   252  					sliced = sliced.(View).Materialize().(*Dense)
   253  					s = sliced.(*Dense)
   254  					expected, _ = s.Norm(ord)
   255  					expecteds = append(expecteds, expected)
   256  				}
   257  
   258  				if retVal, err = T.Norm(ord, rowAxis, colAxis); err != nil {
   259  					t.Error(err)
   260  					continue
   261  				}
   262  
   263  				for i, e := range expecteds {
   264  					sliced, _ = retVal.Slice(ss(i))
   265  					assert.Equal(e.Data(), sliced.Data(), "ord %v, rowAxis: %v, colAxis %v", ord, rowAxis, colAxis)
   266  				}
   267  			}
   268  		}
   269  	}
   270  
   271  }