gorgonia.org/tensor@v0.9.24/consopt_test.go (about)

     1  // +build linux
     2  
     3  package tensor
     4  
import (
	"fmt"
	"io/ioutil"
	"os"
	"runtime"
	"syscall"
	"testing"
	"testing/quick"
	"unsafe"

	"github.com/stretchr/testify/assert"
)
    16  
    17  type F64 float64
    18  
    19  func newF64(f float64) *F64 { r := F64(f); return &r }
    20  
    21  func (f *F64) Uintptr() uintptr { return uintptr(unsafe.Pointer(f)) }
    22  
    23  func (f *F64) MemSize() uintptr { return 8 }
    24  
    25  func (f *F64) Pointer() unsafe.Pointer { return unsafe.Pointer(f) }
    26  
    27  func Test_FromMemory(t *testing.T) {
    28  	fn := func(F float64) bool {
    29  		f := newF64(F)
    30  		T := New(WithShape(), Of(Float64), FromMemory(f.Uintptr(), f.MemSize()))
    31  		data := T.Data().(float64)
    32  
    33  		if data != F {
    34  			return false
    35  		}
    36  		return true
    37  	}
    38  	if err := quick.Check(fn, &quick.Config{MaxCount: 1000000}); err != nil {
    39  		t.Logf("%v", err)
    40  	}
    41  
    42  	f, err := ioutil.TempFile("", "test")
    43  	if err != nil {
    44  		t.Fatal(err)
    45  	}
    46  	// fill in with fake data
    47  	backing := make([]byte, 8*1024*1024) // 1024*1024 matrix of float64
    48  	asFloats := *(*[]float64)(unsafe.Pointer(&backing))
    49  	asFloats = asFloats[: 1024*1024 : 1024*1024]
    50  	asFloats[0] = 3.14
    51  	asFloats[2] = 6.28
    52  	asFloats[1024*1024-1] = 3.14
    53  	asFloats[1024*1024-3] = 6.28
    54  	f.Write(backing)
    55  
    56  	// defer cleanup
    57  	defer os.Remove(f.Name())
    58  
    59  	// do the mmap stuff
    60  	stat, err := f.Stat()
    61  	if err != nil {
    62  		t.Fatal(err)
    63  	}
    64  
    65  	size := int(stat.Size())
    66  	fd := int(f.Fd())
    67  	bs, err := syscall.Mmap(fd, 0, size, syscall.PROT_READ, syscall.MAP_SHARED)
    68  	if err != nil {
    69  		t.Fatal(err)
    70  	}
    71  	defer func() {
    72  		if err := syscall.Munmap(bs); err != nil {
    73  			t.Error(err)
    74  		}
    75  	}()
    76  	T := New(WithShape(1024, 1024), Of(Float64), FromMemory(uintptr(unsafe.Pointer(&bs[0])), uintptr(size)))
    77  
    78  	s := fmt.Sprintf("%v", T)
    79  	expected := `⎡3.14     0  6.28     0  ...    0     0     0     0⎤
    80  ⎢   0     0     0     0  ...    0     0     0     0⎥
    81  ⎢   0     0     0     0  ...    0     0     0     0⎥
    82  ⎢   0     0     0     0  ...    0     0     0     0⎥
    83  .
    84  .
    85  .
    86  ⎢   0     0     0     0  ...    0     0     0     0⎥
    87  ⎢   0     0     0     0  ...    0     0     0     0⎥
    88  ⎢   0     0     0     0  ...    0     0     0     0⎥
    89  ⎣   0     0     0     0  ...    0  6.28     0  3.14⎦
    90  `
    91  	if s != expected {
    92  		t.Errorf("Expected mmap'd tensor to be exactly the same.")
    93  	}
    94  
    95  	assert.True(t, T.IsManuallyManaged())
    96  }