// github.com/wzzhu/tensor@v0.9.24/dense_io_test.go

     1  package tensor
     2  
     3  import (
     4  	"bytes"
     5  	"encoding/gob"
     6  	"io/ioutil"
     7  	"os"
     8  	"os/exec"
     9  	"regexp"
    10  	"testing"
    11  
    12  	"github.com/stretchr/testify/assert"
    13  )
    14  
    15  func TestSaveLoadNumpy(t *testing.T) {
    16  	if os.Getenv("CI_NO_PYTHON") == "true" {
    17  		t.Skip("skipping test; This is being run on a CI tool that does not have Python")
    18  	}
    19  
    20  	assert := assert.New(t)
    21  	T := New(WithShape(2, 2), WithBacking([]float64{1, 5, 10, -1}))
    22  	// also checks the 1D Vector.
    23  	T1D := New(WithShape(4), WithBacking([]float64{1, 5, 10, -1}))
    24  
    25  	f, _ := os.OpenFile("test.npy", os.O_RDWR|os.O_CREATE|os.O_TRUNC, 0644)
    26  	f1D, _ := os.OpenFile("test1D.npy", os.O_RDWR|os.O_CREATE|os.O_TRUNC, 0644)
    27  
    28  	T.WriteNpy(f)
    29  	f.Close()
    30  
    31  	T1D.WriteNpy(f1D)
    32  	f1D.Close()
    33  
    34  	defer func() {
    35  		// cleanup
    36  		err := os.Remove("test.npy")
    37  		if err != nil {
    38  			t.Error(err)
    39  		}
    40  
    41  		err = os.Remove("test1D.npy")
    42  		if err != nil {
    43  			t.Error(err)
    44  		}
    45  	}()
    46  
    47  	script := "import numpy as np\nx = np.load('test.npy')\nprint(x)\nx = np.load('test1D.npy')\nprint(x)"
    48  	// Configurable python command, in order to be able to use python or python3
    49  	pythonCommand := os.Getenv("PYTHON_COMMAND")
    50  	if pythonCommand == "" {
    51  		pythonCommand = "python"
    52  	}
    53  
    54  	cmd := exec.Command(pythonCommand)
    55  	stdin, err := cmd.StdinPipe()
    56  	if err != nil {
    57  		t.Error(err)
    58  	}
    59  	stderr, err := cmd.StderrPipe()
    60  	if err != nil {
    61  		t.Error(err)
    62  	}
    63  
    64  	go func() {
    65  		defer stdin.Close()
    66  		stdin.Write([]byte(script))
    67  	}()
    68  
    69  	buf := new(bytes.Buffer)
    70  	cmd.Stdout = buf
    71  
    72  	if err = cmd.Start(); err != nil {
    73  		t.Error(err)
    74  		t.Logf("Do you have a python with numpy installed? You can change the python interpreter by setting the environment variable PYTHON_COMMAND. Current value: PYTHON_COMMAND=%s", pythonCommand)
    75  	}
    76  
    77  	importError := `ImportError: No module named numpy`
    78  	slurpErr, _ := ioutil.ReadAll(stderr)
    79  	if ok, _ := regexp.Match(importError, slurpErr); ok {
    80  		t.Skipf("Skipping numpy test. It would appear that you do not have Numpy installed.")
    81  	}
    82  
    83  	if err := cmd.Wait(); err != nil {
    84  		t.Errorf("%q", err.Error())
    85  	}
    86  
    87  	expected := `\[\[\s*1\.\s*5\.\]\n \[\s*10\.\s*-1\.\]\]\n`
    88  	if ok, _ := regexp.Match(expected, buf.Bytes()); !ok {
    89  		t.Errorf("Did not successfully read numpy file, \n%q\n%q", buf.String(), expected)
    90  	}
    91  
    92  	// ok now to test if it can read
    93  	T2 := new(Dense)
    94  	buf = new(bytes.Buffer)
    95  	T.WriteNpy(buf)
    96  	if err = T2.ReadNpy(buf); err != nil {
    97  		t.Fatal(err)
    98  	}
    99  	assert.Equal(T.Shape(), T2.Shape())
   100  	assert.Equal(T.Strides(), T2.Strides())
   101  	assert.Equal(T.Data(), T2.Data())
   102  
   103  	// ok now to test if it can read 1D
   104  	T1D2 := new(Dense)
   105  	buf = new(bytes.Buffer)
   106  	T1D.WriteNpy(buf)
   107  	if err = T1D2.ReadNpy(buf); err != nil {
   108  		t.Fatal(err)
   109  	}
   110  	assert.Equal(T1D.Shape(), T1D2.Shape())
   111  	assert.Equal(T1D.Strides(), T1D2.Strides())
   112  	assert.Equal(T1D.Data(), T1D2.Data())
   113  
   114  	// try with masked array. masked elements should be filled with default value
   115  	T.ResetMask(false)
   116  	T.mask[0] = true
   117  	T3 := new(Dense)
   118  	buf = new(bytes.Buffer)
   119  	T.WriteNpy(buf)
   120  	if err = T3.ReadNpy(buf); err != nil {
   121  		t.Fatal(err)
   122  	}
   123  	assert.Equal(T.Shape(), T3.Shape())
   124  	assert.Equal(T.Strides(), T3.Strides())
   125  	data := T.Float64s()
   126  	data[0] = T.FillValue().(float64)
   127  	assert.Equal(data, T3.Data())
   128  
   129  	// try with 1D masked array. masked elements should be filled with default value
   130  	T1D.ResetMask(false)
   131  	T1D.mask[0] = true
   132  	T1D3 := new(Dense)
   133  	buf = new(bytes.Buffer)
   134  	T1D.WriteNpy(buf)
   135  	if err = T1D3.ReadNpy(buf); err != nil {
   136  		t.Fatal(err)
   137  	}
   138  	assert.Equal(T1D.Shape(), T1D3.Shape())
   139  	assert.Equal(T1D.Strides(), T1D3.Strides())
   140  	data = T1D.Float64s()
   141  	data[0] = T1D.FillValue().(float64)
   142  	assert.Equal(data, T1D3.Data())
   143  }
   144  
   145  func TestSaveLoadCSV(t *testing.T) {
   146  	assert := assert.New(t)
   147  	for _, gtd := range serializationTestData {
   148  		if _, ok := gtd.([]complex64); ok {
   149  			continue
   150  		}
   151  		if _, ok := gtd.([]complex128); ok {
   152  			continue
   153  		}
   154  
   155  		buf := new(bytes.Buffer)
   156  
   157  		T := New(WithShape(2, 2), WithBacking(gtd))
   158  		if err := T.WriteCSV(buf); err != nil {
   159  			t.Error(err)
   160  		}
   161  
   162  		T2 := new(Dense)
   163  		if err := T2.ReadCSV(buf, As(T.t)); err != nil {
   164  			t.Error(err)
   165  		}
   166  
   167  		assert.Equal(T.Shape(), T2.Shape(), "Test: %v", gtd)
   168  		assert.Equal(T.Data(), T2.Data())
   169  
   170  	}
   171  
   172  	T := New(WithShape(2, 2), WithBacking([]float64{1, 5, 10, -1}))
   173  	f, _ := os.OpenFile("test.csv", os.O_RDWR|os.O_CREATE|os.O_TRUNC, 0644)
   174  	T.WriteCSV(f)
   175  	f.Close()
   176  
   177  	// cleanup
   178  	err := os.Remove("test.csv")
   179  	if err != nil {
   180  		t.Error(err)
   181  	}
   182  
   183  	// try with masked array. masked elements should be filled with default value
   184  	T.ResetMask(false)
   185  	T.mask[0] = true
   186  	f, _ = os.OpenFile("test.csv", os.O_RDWR|os.O_CREATE|os.O_TRUNC, 0644)
   187  	T.WriteCSV(f)
   188  	f.Close()
   189  
   190  	// cleanup again
   191  	err = os.Remove("test.csv")
   192  	if err != nil {
   193  		t.Error(err)
   194  	}
   195  
   196  }
   197  
   198  var serializationTestData = []interface{}{
   199  	[]int{1, 5, 10, -1},
   200  	[]int8{1, 5, 10, -1},
   201  	[]int16{1, 5, 10, -1},
   202  	[]int32{1, 5, 10, -1},
   203  	[]int64{1, 5, 10, -1},
   204  	[]uint{1, 5, 10, 255},
   205  	[]uint8{1, 5, 10, 255},
   206  	[]uint16{1, 5, 10, 255},
   207  	[]uint32{1, 5, 10, 255},
   208  	[]uint64{1, 5, 10, 255},
   209  	[]float32{1, 5, 10, -1},
   210  	[]float64{1, 5, 10, -1},
   211  	[]complex64{1, 5, 10, -1},
   212  	[]complex128{1, 5, 10, -1},
   213  	[]string{"hello", "world", "hello", "世界"},
   214  }
   215  
   216  func TestDense_GobEncodeDecode(t *testing.T) {
   217  	assert := assert.New(t)
   218  	var err error
   219  	for _, gtd := range serializationTestData {
   220  		buf := new(bytes.Buffer)
   221  		encoder := gob.NewEncoder(buf)
   222  		decoder := gob.NewDecoder(buf)
   223  
   224  		T := New(WithShape(2, 2), WithBacking(gtd))
   225  		if err = encoder.Encode(T); err != nil {
   226  			t.Errorf("Error while encoding %v: %v", gtd, err)
   227  			continue
   228  		}
   229  
   230  		T2 := new(Dense)
   231  		if err = decoder.Decode(T2); err != nil {
   232  			t.Errorf("Error while decoding %v: %v", gtd, err)
   233  			continue
   234  		}
   235  
   236  		assert.Equal(T.Shape(), T2.Shape())
   237  		assert.Equal(T.Strides(), T2.Strides())
   238  		assert.Equal(T.Data(), T2.Data())
   239  
   240  		// try with masked array. masked elements should be filled with default value
   241  		buf = new(bytes.Buffer)
   242  		encoder = gob.NewEncoder(buf)
   243  		decoder = gob.NewDecoder(buf)
   244  
   245  		T.ResetMask(false)
   246  		T.mask[0] = true
   247  		assert.True(T.IsMasked())
   248  		if err = encoder.Encode(T); err != nil {
   249  			t.Errorf("Error while encoding %v: %v", gtd, err)
   250  			continue
   251  		}
   252  
   253  		T3 := new(Dense)
   254  		if err = decoder.Decode(T3); err != nil {
   255  			t.Errorf("Error while decoding %v: %v", gtd, err)
   256  			continue
   257  		}
   258  
   259  		assert.Equal(T.Shape(), T3.Shape())
   260  		assert.Equal(T.Strides(), T3.Strides())
   261  		assert.Equal(T.Data(), T3.Data())
   262  		assert.Equal(T.mask, T3.mask)
   263  
   264  	}
   265  }
   266  
   267  func TestDense_FBEncodeDecode(t *testing.T) {
   268  	assert := assert.New(t)
   269  	for _, gtd := range serializationTestData {
   270  		T := New(WithShape(2, 2), WithBacking(gtd))
   271  		buf, err := T.FBEncode()
   272  		if err != nil {
   273  			t.Errorf("UNPOSSIBLE!: %v", err)
   274  			continue
   275  		}
   276  
   277  		T2 := new(Dense)
   278  		if err = T2.FBDecode(buf); err != nil {
   279  			t.Errorf("Error while decoding %v: %v", gtd, err)
   280  			continue
   281  		}
   282  
   283  		assert.Equal(T.Shape(), T2.Shape())
   284  		assert.Equal(T.Strides(), T2.Strides())
   285  		assert.Equal(T.Data(), T2.Data())
   286  
   287  		// TODO: MASKED ARRAY
   288  	}
   289  }
   290  
   291  func TestDense_PBEncodeDecode(t *testing.T) {
   292  	assert := assert.New(t)
   293  	for _, gtd := range serializationTestData {
   294  		T := New(WithShape(2, 2), WithBacking(gtd))
   295  		buf, err := T.PBEncode()
   296  		if err != nil {
   297  			t.Errorf("UNPOSSIBLE!: %v", err)
   298  			continue
   299  		}
   300  
   301  		T2 := new(Dense)
   302  		if err = T2.PBDecode(buf); err != nil {
   303  			t.Errorf("Error while decoding %v: %v", gtd, err)
   304  			continue
   305  		}
   306  
   307  		assert.Equal(T.Shape(), T2.Shape())
   308  		assert.Equal(T.Strides(), T2.Strides())
   309  		assert.Equal(T.Data(), T2.Data())
   310  
   311  		// TODO: MASKED ARRAY
   312  	}
   313  }