gorgonia.org/gorgonia@v0.9.17/op_math_cuda_test.go (about)

     1  // +build cuda
     2  
     3  package gorgonia
     4  
     5  import (
     6  	"log"
     7  	"os"
     8  	"runtime"
     9  	"testing"
    10  
    11  	"github.com/pkg/errors"
    12  	"github.com/stretchr/testify/assert"
    13  	"gorgonia.org/tensor"
    14  )
    15  
    16  func TestCUDACube(t *testing.T) {
    17  	defer runtime.GC()
    18  
    19  	assert := assert.New(t)
    20  	xT := tensor.New(tensor.Of(tensor.Float32), tensor.WithBacking(tensor.Range(Float32, 0, 32)), tensor.WithShape(8, 4))
    21  
    22  	g := NewGraph(WithGraphName("Test"))
    23  	x := NewMatrix(g, tensor.Float32, WithName("x"), WithShape(8, 4), WithValue(xT))
    24  	x3 := Must(Cube(x))
    25  	var x3Val Value
    26  	Read(x3, &x3Val)
    27  
    28  	m := NewTapeMachine(g)
    29  	defer m.Close()
    30  	if err := m.RunAll(); err != nil {
    31  		t.Error(err)
    32  	}
    33  	correct := []float32{0, 1, 8, 27, 64, 125, 216, 343, 512, 729, 1000, 1331, 1728, 2197, 2744, 3375, 4096, 4913, 5832, 6859, 8000, 9261, 10648, 12167, 13824, 15625, 17576, 19683, 21952, 24389, 27000, 29791}
    34  	assert.Equal(correct, x3Val.Data())
    35  
    36  	t.Logf("0x%x", x3Val.Uintptr())
    37  	t.Logf("\n%v", m.cpumem[1])
    38  	t.Logf("0x%x", m.cpumem[1].Uintptr())
    39  
    40  	correct = tensor.Range(tensor.Float32, 0, 32).([]float32)
    41  	assert.Equal(correct, x.Value().Data())
    42  }
    43  
    44  func TestCUDABasicArithmetic(t *testing.T) {
    45  	for i, bot := range binOpTests {
    46  		// if i != 5 {
    47  		// 	continue
    48  		// }
    49  		// log.Printf("Test %d", i)
    50  		if err := testOneCUDABasicArithmetic(t, bot, i); err != nil {
    51  			t.Fatalf("Test %d. Err %+v", i, err)
    52  		}
    53  		runtime.GC()
    54  	}
    55  
    56  	// logger = spare
    57  }
    58  
    59  func testOneCUDABasicArithmetic(t *testing.T, bot binOpTest, i int) error {
    60  	g := NewGraph()
    61  	xV, _ := CloneValue(bot.a)
    62  	yV, _ := CloneValue(bot.b)
    63  	x := NodeFromAny(g, xV, WithName("x"))
    64  	y := NodeFromAny(g, yV, WithName("y"))
    65  
    66  	var ret *Node
    67  	var retVal Value
    68  	var err error
    69  	if ret, err = bot.binOp(x, y); err != nil {
    70  		return err
    71  	}
    72  	Read(ret, &retVal)
    73  
    74  	cost := Must(Sum(ret))
    75  	var grads Nodes
    76  	if grads, err = Grad(cost, x, y); err != nil {
    77  		return err
    78  	}
    79  
    80  	m1 := NewTapeMachine(g)
    81  	defer m1.Close()
    82  	if err = m1.RunAll(); err != nil {
    83  		t.Logf("%v", m1.Prog())
    84  		return err
    85  	}
    86  
    87  	as := newAssertState(assert.New(t))
    88  	as.Equal(bot.correct.Data(), retVal.Data(), "Test %d result", i)
    89  	as.True(bot.correctShape.Eq(ret.Shape()))
    90  	as.Equal(2, len(grads))
    91  	as.Equal(bot.correctDerivA.Data(), grads[0].Value().Data(), "Test %v xgrad", i)
    92  	as.Equal(bot.correctDerivB.Data(), grads[1].Value().Data(), "Test %v ygrad. Expected %v. Got %v", i, bot.correctDerivB, grads[1].Value())
    93  	if !as.cont {
    94  		prog := m1.Prog()
    95  		return errors.Errorf("Failed. Prog %v", prog)
    96  	}
    97  	return nil
    98  
    99  }
   100  
   101  func TestMultiDeviceArithmetic(t *testing.T) {
   102  	g := NewGraph()
   103  	x := NewMatrix(g, Float64, WithName("x"), WithShape(2, 2))
   104  	y := NewMatrix(g, Float64, WithName("y"), WithShape(2, 2))
   105  	z := Must(Add(x, y))
   106  	zpx := Must(Add(x, z)) // z would be on device
   107  	Must(Sum(zpx))
   108  
   109  	xV := tensor.New(tensor.WithBacking([]float64{0, 1, 2, 3}), tensor.WithShape(2, 2))
   110  	yV := tensor.New(tensor.WithBacking([]float64{0, 1, 2, 3}), tensor.WithShape(2, 2))
   111  
   112  	Let(x, xV)
   113  	Let(y, yV)
   114  
   115  	logger := log.New(os.Stderr, "", 0)
   116  	m := NewLispMachine(g, WithLogger(logger), WithWatchlist(), LogBothDir())
   117  	defer m.Close()
   118  	t.Logf("zpx.Device: %v", zpx.Device())
   119  	t.Logf("x.Device: %v", x.Device())
   120  	t.Logf("y.Device: %v", y.Device())
   121  
   122  	if err := m.RunAll(); err != nil {
   123  		t.Errorf("err: %+v", err)
   124  	}
   125  
   126  }