github.com/wzzhu/tensor@v0.9.24/dense_format_test.go (about)

     1  package tensor
     2  
     3  import (
     4  	"fmt"
     5  	"testing"
     6  
     7  	"github.com/stretchr/testify/assert"
     8  )
     9  
    10  func TestDense_Format(t *testing.T) {
    11  	// if os.Getenv("TRAVISTEST") == "true" {
    12  	// 	t.Skip("skipping format test; This is being run on TravisCI")
    13  	// }
    14  
    15  	assert := assert.New(t)
    16  	var T *Dense
    17  	var res, expected string
    18  
    19  	// Scalar
    20  	T = New(Of(Float64), FromScalar(3.14))
    21  	res = fmt.Sprintf("%3.3f", T)
    22  	assert.Equal("3.140", res)
    23  
    24  	// Scalar-equiv (vector)
    25  	T = New(WithBacking([]float64{3.14}), WithShape(1))
    26  	res = fmt.Sprintf("%3.3f", T)
    27  	assert.Equal("[3.140]", res)
    28  
    29  	// Scalar-equiv (n-dimensional)
    30  	T = New(WithBacking([]float64{3.14}), WithShape(1, 1, 1, 1))
    31  	res = fmt.Sprintf("%3.3f", T)
    32  	assert.Equal("[[[[3.140]]]]", res)
    33  
    34  	// short vector
    35  	T = New(Of(Float64), WithShape(4))
    36  	res = fmt.Sprintf("%v", T)
    37  	expected = "[0  0  0  0]"
    38  	assert.Equal(expected, res)
    39  	T = New(WithShape(2, 2), WithBacking([]float64{3.141515163242, 20, 5.15, 6.28}))
    40  
    41  	res = fmt.Sprintf("\n%v", T)
    42  	expected = `
    43  ⎡3.141515163242              20⎤
    44  ⎣          5.15            6.28⎦
    45  `
    46  	assert.Equal(expected, res, res)
    47  
    48  	// precision
    49  	res = fmt.Sprintf("\n%0.2v", T)
    50  	expected = `
    51  ⎡3.1   20⎤
    52  ⎣5.2  6.3⎦
    53  `
    54  	assert.Equal(expected, res, res)
    55  
    56  	// with metadata
    57  	res = fmt.Sprintf("\n%+0.2v", T)
    58  	expected = `
    59  Matrix (2, 2) [2 1]
    60  ⎡3.1   20⎤
    61  ⎣5.2  6.3⎦
    62  `
    63  	assert.Equal(expected, res, res)
    64  
    65  	// many columns
    66  	T = New(WithShape(16, 14), WithBacking(Range(Float32, 0, 16*14)))
    67  	res = fmt.Sprintf("\n%v", T)
    68  	expected = `
    69  ⎡  0    1    2    3  ...  10   11   12   13⎤
    70  ⎢ 14   15   16   17  ...  24   25   26   27⎥
    71  ⎢ 28   29   30   31  ...  38   39   40   41⎥
    72  ⎢ 42   43   44   45  ...  52   53   54   55⎥
    73  .
    74  .
    75  .
    76  ⎢168  169  170  171  ... 178  179  180  181⎥
    77  ⎢182  183  184  185  ... 192  193  194  195⎥
    78  ⎢196  197  198  199  ... 206  207  208  209⎥
    79  ⎣210  211  212  213  ... 220  221  222  223⎦
    80  `
    81  	assert.Equal(expected, res, "expected %v. Got %v", expected, res)
    82  
    83  	// many cols, rows, compressed
    84  	T = New(WithShape(16, 14), WithBacking(Range(Float64, 0, 16*14)))
    85  	res = fmt.Sprintf("\n%s", T)
    86  	// this clunky string addition thing is because some editors like to trim whitespace.
    87  	// There should be two spaces after `  ⋮` .
    88  	expected = `
    89  ⎡  0    1  ⋯  12   13⎤
    90  ⎢ 14   15  ⋯  26   27⎥
    91  ` + `  ⋮  ` + `
    92  ` + `⎢196  197  ⋯ 208  209⎥
    93  ⎣210  211  ⋯ 222  223⎦
    94  `
    95  	assert.Equal(expected, res, "expected %v. Got %v", expected, res)
    96  
    97  	// many cols, full
    98  	T = New(WithShape(8, 9), WithBacking(Range(Float64, 0, 8*9)))
    99  	res = fmt.Sprintf("\n%#v", T)
   100  	expected = `
   101  ⎡ 0   1   2   3   4   5   6   7   8⎤
   102  ⎢ 9  10  11  12  13  14  15  16  17⎥
   103  ⎢18  19  20  21  22  23  24  25  26⎥
   104  ⎢27  28  29  30  31  32  33  34  35⎥
   105  ⎢36  37  38  39  40  41  42  43  44⎥
   106  ⎢45  46  47  48  49  50  51  52  53⎥
   107  ⎢54  55  56  57  58  59  60  61  62⎥
   108  ⎣63  64  65  66  67  68  69  70  71⎦
   109  `
   110  	assert.Equal(expected, res, res)
   111  
   112  	// vectors
   113  	T = New(Of(Int), WithShape(3, 1))
   114  	res = fmt.Sprintf("%v", T)
   115  	expected = `C[0  0  0]`
   116  	assert.Equal(expected, res)
   117  
   118  	T = New(Of(Int32), WithShape(1, 3))
   119  	res = fmt.Sprintf("%v", T)
   120  	expected = `R[0  0  0]`
   121  	assert.Equal(expected, res)
   122  
   123  	// 3+ Dimensional Tensors - super janky for now
   124  	T = New(WithShape(2, 3, 2), WithBacking(Range(Float64, 0, 2*3*2)))
   125  	res = fmt.Sprintf("\n%v", T)
   126  	expected = `
   127  ⎡ 0   1⎤
   128  ⎢ 2   3⎥
   129  ⎣ 4   5⎦
   130  
   131  ⎡ 6   7⎤
   132  ⎢ 8   9⎥
   133  ⎣10  11⎦
   134  
   135  `
   136  
   137  	assert.Equal(expected, res, res)
   138  
   139  	// checking metadata + compression
   140  	res = fmt.Sprintf("\n%+s", T)
   141  	expected = `
   142  Tensor-3 (2, 3, 2) [6 2 1]
   143  ⎡ 0   1⎤
   144  ⎢ 2   3⎥
   145  ⎣ 4   5⎦
   146  
   147  ⎡ 6   7⎤
   148  ⎢ 8   9⎥
   149  ⎣10  11⎦
   150  
   151  `
   152  	assert.Equal(expected, res, res)
   153  
   154  	// check flat + compress
   155  	res = fmt.Sprintf("%-s", T)
   156  	expected = `[0 1 2 3 4 ⋯ ]`
   157  	assert.Equal(expected, res, res)
   158  
   159  	// check flat
   160  	res = fmt.Sprintf("%-3.3f", T)
   161  	expected = `[0.000 1.000 2.000 3.000 4.000 5.000 6.000 7.000 8.000 9.000 ... ]`
   162  	assert.Equal(expected, res, res)
   163  
   164  	// check flat + extended
   165  	res = fmt.Sprintf("%-#v", T)
   166  	expected = `[0 1 2 3 4 5 6 7 8 9 10 11]`
   167  	assert.Equal(expected, res, res)
   168  
   169  	/* Test Views and Sliced Tensors */
   170  
   171  	var V Tensor
   172  	var err error
   173  
   174  	V, err = T.Slice(makeRS(1, 2))
   175  	if err != nil {
   176  		t.Error(err)
   177  	}
   178  
   179  	// flat mode for view
   180  	res = fmt.Sprintf("\n%-s", V)
   181  	expected = "\n[6 7 8 9 10 ⋯ ]"
   182  	assert.Equal(expected, res, res)
   183  
   184  	// standard
   185  	res = fmt.Sprintf("\n%+s", V)
   186  	expected = `
   187  Matrix (3, 2) [2 1]
   188  ⎡ 6   7⎤
   189  ⎢ 8   9⎥
   190  ⎣10  11⎦
   191  `
   192  	assert.Equal(expected, res, res)
   193  
   194  	// T[:, 1]
   195  	V, err = T.Slice(nil, ss(1))
   196  	res = fmt.Sprintf("\n%+s", V)
   197  	expected = `
   198  Matrix (2, 2) [6 1]
   199  ⎡2  3⎤
   200  ⎣8  9⎦
   201  `
   202  	assert.Equal(expected, res, res)
   203  
   204  	// transpose a view
   205  	V.T()
   206  	expected = `
   207  Matrix (2, 2) [1 6]
   208  ⎡2  8⎤
   209  ⎣3  9⎦
   210  `
   211  
   212  	res = fmt.Sprintf("\n%+s", V)
   213  	assert.Equal(expected, res, res)
   214  
   215  	// T[1, :, 1]
   216  	V, err = T.Slice(ss(1), nil, ss(1))
   217  	if err != nil {
   218  		t.Error(err)
   219  	}
   220  	expected = `Vector (3) [2]
   221  [7881299347898368p-50  5066549580791808p-49  6192449487634432p-49]`
   222  	res = fmt.Sprintf("%+b", V)
   223  	assert.Equal(expected, res)
   224  
   225  	// T[1, 1, 1] - will result in a scalar
   226  	V, err = T.Slice(ss(1), ss(1), ss(1))
   227  	if err != nil {
   228  		t.Error(err)
   229  	}
   230  	res = fmt.Sprintf("%#3.3E", V)
   231  	expected = `9.000E+00`
   232  	assert.Equal(expected, res)
   233  
   234  	// on regular matrices
   235  	T = New(WithShape(3, 5), WithBacking(Range(Float64, 0, 3*5)))
   236  	V, err = T.Slice(ss(1))
   237  	if err != nil {
   238  		t.Error(err)
   239  	}
   240  	expected = `[5  6  7  8  9]`
   241  	res = fmt.Sprintf("%v", V)
   242  	assert.Equal(expected, res)
   243  }
   244  
   245  var basicFmtTests = []struct {
   246  	a      interface{}
   247  	format string
   248  
   249  	correct string
   250  }{
   251  	{Range(Float64, 0, 4), "%1.1f", "[0.0  1.0  2.0  3.0]"},
   252  	{Range(Float32, 0, 4), "%1.1f", "[0.0  1.0  2.0  3.0]"},
   253  	{Range(Int, 0, 4), "%b", "[ 0   1  10  11]"},
   254  	{Range(Int, 0, 4), "%d", "[0  1  2  3]"},
   255  	{Range(Int, 6, 10), "%o", "[ 6   7  10  11]"},
   256  	{Range(Int, 14, 18), "%x", "[ e   f  10  11]"},
   257  	{Range(Int, 0, 4), "%f", "[0  1  2  3]"},
   258  
   259  	{Range(Int32, 0, 4), "%b", "[ 0   1  10  11]"},
   260  	{Range(Int32, 0, 4), "%d", "[0  1  2  3]"},
   261  	{Range(Int32, 6, 10), "%o", "[ 6   7  10  11]"},
   262  	{Range(Int32, 14, 18), "%x", "[ e   f  10  11]"},
   263  	{Range(Int32, 0, 4), "%f", "[0  1  2  3]"},
   264  
   265  	{Range(Int64, 0, 4), "%b", "[ 0   1  10  11]"},
   266  	{Range(Int64, 0, 4), "%d", "[0  1  2  3]"},
   267  	{Range(Int64, 6, 10), "%o", "[ 6   7  10  11]"},
   268  	{Range(Int64, 14, 18), "%x", "[ e   f  10  11]"},
   269  	{Range(Int64, 0, 4), "%f", "[0  1  2  3]"},
   270  
   271  	{Range(Byte, 0, 4), "%b", "[ 0   1  10  11]"},
   272  	{Range(Byte, 0, 4), "%d", "[0  1  2  3]"},
   273  	{Range(Byte, 6, 10), "%o", "[ 6   7  10  11]"},
   274  	{Range(Byte, 14, 18), "%x", "[ e   f  10  11]"},
   275  	{Range(Byte, 0, 4), "%f", "[0  1  2  3]"},
   276  
   277  	{[]bool{true, false, true, false}, "%f", "[ true  false   true  false]"},
   278  	{[]bool{true, false, true, false}, "%s", "[ true  false   true  false]"},
   279  }
   280  
   281  func TestDense_Format_basics(t *testing.T) {
   282  	for _, v := range basicFmtTests {
   283  		T := New(WithBacking(v.a))
   284  		s := fmt.Sprintf(v.format, T)
   285  
   286  		if s != v.correct {
   287  			t.Errorf("Expected %q. Got %q", v.correct, s)
   288  		}
   289  	}
   290  }
   291  
   292  func TestDense_Format_Masked(t *testing.T) {
   293  	assert := assert.New(t)
   294  	T := New(Of(Int), WithShape(1, 12))
   295  	data := T.Ints()
   296  	for i := 0; i < len(data); i++ {
   297  		data[i] = i
   298  	}
   299  	T.ResetMask(false)
   300  	for i := 0; i < 12; i += 2 {
   301  		T.mask[i] = true
   302  	}
   303  
   304  	s := fmt.Sprintf("%d", T)
   305  	assert.Equal(`R[--   1  --   3  ... --   9  --  11]`, s)
   306  
   307  	T = New(Of(Int), WithShape(2, 4, 16))
   308  	data = T.Ints()
   309  	for i := 0; i < len(data); i++ {
   310  		data[i] = i
   311  	}
   312  	T.ResetMask(false)
   313  	for i := 0; i < len(data); i += 2 {
   314  		T.mask[i] = true
   315  	}
   316  
   317  	s = fmt.Sprintf("%d", T)
   318  	assert.Equal(`⎡ --    1   --    3  ...  --   13   --   15⎤
   319  ⎢ --   17   --   19  ...  --   29   --   31⎥
   320  ⎢ --   33   --   35  ...  --   45   --   47⎥
   321  ⎣ --   49   --   51  ...  --   61   --   63⎦
   322  
   323  ⎡ --   65   --   67  ...  --   77   --   79⎤
   324  ⎢ --   81   --   83  ...  --   93   --   95⎥
   325  ⎢ --   97   --   99  ...  --  109   --  111⎥
   326  ⎣ --  113   --  115  ...  --  125   --  127⎦
   327  
   328  `, s)
   329  
   330  }