github.com/wzzhu/tensor@v0.9.24/dense_io_test.go (about) 1 package tensor 2 3 import ( 4 "bytes" 5 "encoding/gob" 6 "io/ioutil" 7 "os" 8 "os/exec" 9 "regexp" 10 "testing" 11 12 "github.com/stretchr/testify/assert" 13 ) 14 15 func TestSaveLoadNumpy(t *testing.T) { 16 if os.Getenv("CI_NO_PYTHON") == "true" { 17 t.Skip("skipping test; This is being run on a CI tool that does not have Python") 18 } 19 20 assert := assert.New(t) 21 T := New(WithShape(2, 2), WithBacking([]float64{1, 5, 10, -1})) 22 // also checks the 1D Vector. 23 T1D := New(WithShape(4), WithBacking([]float64{1, 5, 10, -1})) 24 25 f, _ := os.OpenFile("test.npy", os.O_RDWR|os.O_CREATE|os.O_TRUNC, 0644) 26 f1D, _ := os.OpenFile("test1D.npy", os.O_RDWR|os.O_CREATE|os.O_TRUNC, 0644) 27 28 T.WriteNpy(f) 29 f.Close() 30 31 T1D.WriteNpy(f1D) 32 f1D.Close() 33 34 defer func() { 35 // cleanup 36 err := os.Remove("test.npy") 37 if err != nil { 38 t.Error(err) 39 } 40 41 err = os.Remove("test1D.npy") 42 if err != nil { 43 t.Error(err) 44 } 45 }() 46 47 script := "import numpy as np\nx = np.load('test.npy')\nprint(x)\nx = np.load('test1D.npy')\nprint(x)" 48 // Configurable python command, in order to be able to use python or python3 49 pythonCommand := os.Getenv("PYTHON_COMMAND") 50 if pythonCommand == "" { 51 pythonCommand = "python" 52 } 53 54 cmd := exec.Command(pythonCommand) 55 stdin, err := cmd.StdinPipe() 56 if err != nil { 57 t.Error(err) 58 } 59 stderr, err := cmd.StderrPipe() 60 if err != nil { 61 t.Error(err) 62 } 63 64 go func() { 65 defer stdin.Close() 66 stdin.Write([]byte(script)) 67 }() 68 69 buf := new(bytes.Buffer) 70 cmd.Stdout = buf 71 72 if err = cmd.Start(); err != nil { 73 t.Error(err) 74 t.Logf("Do you have a python with numpy installed? You can change the python interpreter by setting the environment variable PYTHON_COMMAND. Current value: PYTHON_COMMAND=%s", pythonCommand) 75 } 76 77 importError := `ImportError: No module named numpy` 78 slurpErr, _ := ioutil.ReadAll(stderr) 79 if ok, _ := regexp.Match(importError, slurpErr); ok { 80 t.Skipf("Skipping numpy test. It would appear that you do not have Numpy installed.") 81 } 82 83 if err := cmd.Wait(); err != nil { 84 t.Errorf("%q", err.Error()) 85 } 86 87 expected := `\[\[\s*1\.\s*5\.\]\n \[\s*10\.\s*-1\.\]\]\n` 88 if ok, _ := regexp.Match(expected, buf.Bytes()); !ok { 89 t.Errorf("Did not successfully read numpy file, \n%q\n%q", buf.String(), expected) 90 } 91 92 // ok now to test if it can read 93 T2 := new(Dense) 94 buf = new(bytes.Buffer) 95 T.WriteNpy(buf) 96 if err = T2.ReadNpy(buf); err != nil { 97 t.Fatal(err) 98 } 99 assert.Equal(T.Shape(), T2.Shape()) 100 assert.Equal(T.Strides(), T2.Strides()) 101 assert.Equal(T.Data(), T2.Data()) 102 103 // ok now to test if it can read 1D 104 T1D2 := new(Dense) 105 buf = new(bytes.Buffer) 106 T1D.WriteNpy(buf) 107 if err = T1D2.ReadNpy(buf); err != nil { 108 t.Fatal(err) 109 } 110 assert.Equal(T1D.Shape(), T1D2.Shape()) 111 assert.Equal(T1D.Strides(), T1D2.Strides()) 112 assert.Equal(T1D.Data(), T1D2.Data()) 113 114 // try with masked array. masked elements should be filled with default value 115 T.ResetMask(false) 116 T.mask[0] = true 117 T3 := new(Dense) 118 buf = new(bytes.Buffer) 119 T.WriteNpy(buf) 120 if err = T3.ReadNpy(buf); err != nil { 121 t.Fatal(err) 122 } 123 assert.Equal(T.Shape(), T3.Shape()) 124 assert.Equal(T.Strides(), T3.Strides()) 125 data := T.Float64s() 126 data[0] = T.FillValue().(float64) 127 assert.Equal(data, T3.Data()) 128 129 // try with 1D masked array. masked elements should be filled with default value 130 T1D.ResetMask(false) 131 T1D.mask[0] = true 132 T1D3 := new(Dense) 133 buf = new(bytes.Buffer) 134 T1D.WriteNpy(buf) 135 if err = T1D3.ReadNpy(buf); err != nil { 136 t.Fatal(err) 137 } 138 assert.Equal(T1D.Shape(), T1D3.Shape()) 139 assert.Equal(T1D.Strides(), T1D3.Strides()) 140 data = T1D.Float64s() 141 data[0] = T1D.FillValue().(float64) 142 assert.Equal(data, T1D3.Data()) 143 } 144 145 func TestSaveLoadCSV(t *testing.T) { 146 assert := assert.New(t) 147 for _, gtd := range serializationTestData { 148 if _, ok := gtd.([]complex64); ok { 149 continue 150 } 151 if _, ok := gtd.([]complex128); ok { 152 continue 153 } 154 155 buf := new(bytes.Buffer) 156 157 T := New(WithShape(2, 2), WithBacking(gtd)) 158 if err := T.WriteCSV(buf); err != nil { 159 t.Error(err) 160 } 161 162 T2 := new(Dense) 163 if err := T2.ReadCSV(buf, As(T.t)); err != nil { 164 t.Error(err) 165 } 166 167 assert.Equal(T.Shape(), T2.Shape(), "Test: %v", gtd) 168 assert.Equal(T.Data(), T2.Data()) 169 170 } 171 172 T := New(WithShape(2, 2), WithBacking([]float64{1, 5, 10, -1})) 173 f, _ := os.OpenFile("test.csv", os.O_RDWR|os.O_CREATE|os.O_TRUNC, 0644) 174 T.WriteCSV(f) 175 f.Close() 176 177 // cleanup 178 err := os.Remove("test.csv") 179 if err != nil { 180 t.Error(err) 181 } 182 183 // try with masked array. masked elements should be filled with default value 184 T.ResetMask(false) 185 T.mask[0] = true 186 f, _ = os.OpenFile("test.csv", os.O_RDWR|os.O_CREATE|os.O_TRUNC, 0644) 187 T.WriteCSV(f) 188 f.Close() 189 190 // cleanup again 191 err = os.Remove("test.csv") 192 if err != nil { 193 t.Error(err) 194 } 195 196 } 197 198 var serializationTestData = []interface{}{ 199 []int{1, 5, 10, -1}, 200 []int8{1, 5, 10, -1}, 201 []int16{1, 5, 10, -1}, 202 []int32{1, 5, 10, -1}, 203 []int64{1, 5, 10, -1}, 204 []uint{1, 5, 10, 255}, 205 []uint8{1, 5, 10, 255}, 206 []uint16{1, 5, 10, 255}, 207 []uint32{1, 5, 10, 255}, 208 []uint64{1, 5, 10, 255}, 209 []float32{1, 5, 10, -1}, 210 []float64{1, 5, 10, -1}, 211 []complex64{1, 5, 10, -1}, 212 []complex128{1, 5, 10, -1}, 213 []string{"hello", "world", "hello", "世界"}, 214 } 215 216 func TestDense_GobEncodeDecode(t *testing.T) { 217 assert := assert.New(t) 218 var err error 219 for _, gtd := range serializationTestData { 220 buf := new(bytes.Buffer) 221 encoder := gob.NewEncoder(buf) 222 decoder := gob.NewDecoder(buf) 223 224 T := New(WithShape(2, 2), WithBacking(gtd)) 225 if err = encoder.Encode(T); err != nil { 226 t.Errorf("Error while encoding %v: %v", gtd, err) 227 continue 228 } 229 230 T2 := new(Dense) 231 if err = decoder.Decode(T2); err != nil { 232 t.Errorf("Error while decoding %v: %v", gtd, err) 233 continue 234 } 235 236 assert.Equal(T.Shape(), T2.Shape()) 237 assert.Equal(T.Strides(), T2.Strides()) 238 assert.Equal(T.Data(), T2.Data()) 239 240 // try with masked array. masked elements should be filled with default value 241 buf = new(bytes.Buffer) 242 encoder = gob.NewEncoder(buf) 243 decoder = gob.NewDecoder(buf) 244 245 T.ResetMask(false) 246 T.mask[0] = true 247 assert.True(T.IsMasked()) 248 if err = encoder.Encode(T); err != nil { 249 t.Errorf("Error while encoding %v: %v", gtd, err) 250 continue 251 } 252 253 T3 := new(Dense) 254 if err = decoder.Decode(T3); err != nil { 255 t.Errorf("Error while decoding %v: %v", gtd, err) 256 continue 257 } 258 259 assert.Equal(T.Shape(), T3.Shape()) 260 assert.Equal(T.Strides(), T3.Strides()) 261 assert.Equal(T.Data(), T3.Data()) 262 assert.Equal(T.mask, T3.mask) 263 264 } 265 } 266 267 func TestDense_FBEncodeDecode(t *testing.T) { 268 assert := assert.New(t) 269 for _, gtd := range serializationTestData { 270 T := New(WithShape(2, 2), WithBacking(gtd)) 271 buf, err := T.FBEncode() 272 if err != nil { 273 t.Errorf("UNPOSSIBLE!: %v", err) 274 continue 275 } 276 277 T2 := new(Dense) 278 if err = T2.FBDecode(buf); err != nil { 279 t.Errorf("Error while decoding %v: %v", gtd, err) 280 continue 281 } 282 283 assert.Equal(T.Shape(), T2.Shape()) 284 assert.Equal(T.Strides(), T2.Strides()) 285 assert.Equal(T.Data(), T2.Data()) 286 287 // TODO: MASKED ARRAY 288 } 289 } 290 291 func TestDense_PBEncodeDecode(t *testing.T) { 292 assert := assert.New(t) 293 for _, gtd := range serializationTestData { 294 T := New(WithShape(2, 2), WithBacking(gtd)) 295 buf, err := T.PBEncode() 296 if err != nil { 297 t.Errorf("UNPOSSIBLE!: %v", err) 298 continue 299 } 300 301 T2 := new(Dense) 302 if err = T2.PBDecode(buf); err != nil { 303 t.Errorf("Error while decoding %v: %v", gtd, err) 304 continue 305 } 306 307 assert.Equal(T.Shape(), T2.Shape()) 308 assert.Equal(T.Strides(), T2.Strides()) 309 assert.Equal(T.Data(), T2.Data()) 310 311 // TODO: MASKED ARRAY 312 } 313 }