gorgonia.org/gorgonia@v0.9.17/cuda_test.go (about)

     1  // +build cuda
     2  
     3  package gorgonia
     4  
     5  import (
     6  	"io/ioutil"
     7  	"log"
     8  	"os"
     9  	"testing"
    10  
    11  	"gorgonia.org/tensor"
    12  )
    13  
    14  func TestDevCUDA(t *testing.T) {
    15  	t.SkipNow()
    16  
    17  	g := NewGraph()
    18  	x := NewMatrix(g, Float64, WithShape(1024, 100), WithName("x"), WithInit(ValuesOf(2.0)))
    19  	y := NewMatrix(g, Float64, WithShape(1024, 100), WithName("y"), WithInit(ValuesOf(8.0)))
    20  	xpy := Must(Add(x, y))
    21  	xmy := Must(Sub(x, y))
    22  	xpy2 := Must(Square(xpy))
    23  	WithName("xpy2")(xpy2)
    24  	xmy2 := Must(Square(xmy))
    25  	xpy2s := Must(Slice(xpy2, S(0)))
    26  	ioutil.WriteFile("fullgraph.dot", []byte(g.ToDot()), 0644)
    27  
    28  	var xpyV, xmyV, xpy2V, xpy2sV, xmy2V Value
    29  	Read(xpy, &xpyV)
    30  	Read(xmy, &xmyV)
    31  	Read(xpy2, &xpy2V)
    32  	Read(xpy2s, &xpy2sV)
    33  	Read(xmy2, &xmy2V)
    34  
    35  	logger := log.New(os.Stderr, "", 0)
    36  	m := NewTapeMachine(g, WithLogger(logger), WithWatchlist(), WithValueFmt("0x%x"))
    37  	defer m.Close()
    38  
    39  	prog, locMap, _ := Compile(g)
    40  	t.Logf("prog:\n%v\n", prog)
    41  	t.Logf("locMap %-v", FmtNodeMap(locMap))
    42  	if err := m.RunAll(); err != nil {
    43  		t.Errorf("%+v", err)
    44  	}
    45  
    46  	t.Logf("x: \n%v", x.Value())
    47  
    48  	t.Logf("y: \n%v", y.Value())
    49  	t.Logf("xpy \n%v", xpyV)
    50  	t.Logf("xmy \n%v", xmyV)
    51  	t.Logf("xpy2: \n%v", xpy2V)
    52  	t.Logf("xpy2s \n%v", xpy2sV)
    53  	t.Logf("xmy2 \n%v", xmy2V)
    54  
    55  	if assertGraphEngine(t, g, stdengType); t.Failed() {
    56  		t.FailNow()
    57  	}
    58  }
    59  
    60  func TestExternMetadata_Transfer(t *testing.T) {
    61  	m := new(ExternMetadata)
    62  	m.init([]int64{1024}) // allocate 1024 bytes
    63  
    64  	v := tensor.New(tensor.Of(Float64), tensor.WithShape(2, 2))
    65  	go func() {
    66  		for s := range m.WorkAvailable() {
    67  			m.DoWork()
    68  			if s {
    69  				m.syncChan <- struct{}{}
    70  			}
    71  		}
    72  	}()
    73  
    74  	//	 transfer from CPU to GPU
    75  	v2, err := m.Transfer(Device(0), CPU, v, true)
    76  	if err != nil {
    77  		t.Error(err)
    78  	}
    79  
    80  	if vt, ok := v2.(*tensor.Dense); (ok && !vt.IsManuallyManaged()) || !ok {
    81  		t.Errorf("Expected manually managed value")
    82  	}
    83  	t.Logf("v2: 0x%x", v2.Uintptr())
    84  
    85  	// transfer from GPU to CPU
    86  	v3, err := m.Transfer(CPU, Device(0), v2, true)
    87  	if err != nil {
    88  		t.Error(err)
    89  	}
    90  	if vt, ok := v3.(*tensor.Dense); (ok && vt.IsManuallyManaged()) || !ok {
    91  		t.Errorf("Expected Go managed value")
    92  	}
    93  	t.Logf("v3: 0x%x", v3.Uintptr())
    94  
    95  	// transfer from CPU to CPU
    96  	v4, err := m.Transfer(CPU, CPU, v3, true)
    97  	if err != nil {
    98  		t.Error(err)
    99  	}
   100  	if v4 != v3 {
   101  		t.Errorf("Expected the values to be returned exactly the same")
   102  	}
   103  }
   104  
   105  func BenchmarkOneMilCUDA(b *testing.B) {
   106  	xT := tensor.New(tensor.WithShape(1000000), tensor.WithBacking(tensor.Random(tensor.Float32, 1000000)))
   107  	g := NewGraph()
   108  	x := NewVector(g, Float32, WithShape(1000000), WithName("x"), WithValue(xT))
   109  	Must(Sigmoid(x))
   110  
   111  	m := NewTapeMachine(g)
   112  	defer m.Close()
   113  
   114  	// runtime.LockOSThread()
   115  	for n := 0; n < b.N; n++ {
   116  		if err := m.RunAll(); err != nil {
   117  			b.Fatalf("Failed at n: %d. Error: %v", n, err)
   118  			break
   119  		}
   120  		m.Reset()
   121  	}
   122  	// runtime.UnlockOSThread()
   123  }
   124  
   125  func BenchmarkOneMil(b *testing.B) {
   126  	xT := tensor.New(tensor.WithShape(1000000), tensor.WithBacking(tensor.Random(tensor.Float32, 1000000)))
   127  	g := NewGraph()
   128  	x := NewVector(g, Float32, WithShape(1000000), WithName("x"), WithValue(xT))
   129  	Must(Sigmoid(x))
   130  
   131  	m := NewTapeMachine(g)
   132  	defer m.Close()
   133  
   134  	for n := 0; n < b.N; n++ {
   135  		if err := m.RunAll(); err != nil {
   136  			b.Fatalf("Failed at n: %d. Error: %v", n, err)
   137  			break
   138  		}
   139  		m.Reset()
   140  	}
   141  }