gorgonia.org/gorgonia@v0.9.17/op_math_cuda_test.go (about) 1 // +build cuda 2 3 package gorgonia 4 5 import ( 6 "log" 7 "os" 8 "runtime" 9 "testing" 10 11 "github.com/pkg/errors" 12 "github.com/stretchr/testify/assert" 13 "gorgonia.org/tensor" 14 ) 15 16 func TestCUDACube(t *testing.T) { 17 defer runtime.GC() 18 19 assert := assert.New(t) 20 xT := tensor.New(tensor.Of(tensor.Float32), tensor.WithBacking(tensor.Range(Float32, 0, 32)), tensor.WithShape(8, 4)) 21 22 g := NewGraph(WithGraphName("Test")) 23 x := NewMatrix(g, tensor.Float32, WithName("x"), WithShape(8, 4), WithValue(xT)) 24 x3 := Must(Cube(x)) 25 var x3Val Value 26 Read(x3, &x3Val) 27 28 m := NewTapeMachine(g) 29 defer m.Close() 30 if err := m.RunAll(); err != nil { 31 t.Error(err) 32 } 33 correct := []float32{0, 1, 8, 27, 64, 125, 216, 343, 512, 729, 1000, 1331, 1728, 2197, 2744, 3375, 4096, 4913, 5832, 6859, 8000, 9261, 10648, 12167, 13824, 15625, 17576, 19683, 21952, 24389, 27000, 29791} 34 assert.Equal(correct, x3Val.Data()) 35 36 t.Logf("0x%x", x3Val.Uintptr()) 37 t.Logf("\n%v", m.cpumem[1]) 38 t.Logf("0x%x", m.cpumem[1].Uintptr()) 39 40 correct = tensor.Range(tensor.Float32, 0, 32).([]float32) 41 assert.Equal(correct, x.Value().Data()) 42 } 43 44 func TestCUDABasicArithmetic(t *testing.T) { 45 for i, bot := range binOpTests { 46 // if i != 5 { 47 // continue 48 // } 49 // log.Printf("Test %d", i) 50 if err := testOneCUDABasicArithmetic(t, bot, i); err != nil { 51 t.Fatalf("Test %d. Err %+v", i, err) 52 } 53 runtime.GC() 54 } 55 56 // logger = spare 57 } 58 59 func testOneCUDABasicArithmetic(t *testing.T, bot binOpTest, i int) error { 60 g := NewGraph() 61 xV, _ := CloneValue(bot.a) 62 yV, _ := CloneValue(bot.b) 63 x := NodeFromAny(g, xV, WithName("x")) 64 y := NodeFromAny(g, yV, WithName("y")) 65 66 var ret *Node 67 var retVal Value 68 var err error 69 if ret, err = bot.binOp(x, y); err != nil { 70 return err 71 } 72 Read(ret, &retVal) 73 74 cost := Must(Sum(ret)) 75 var grads Nodes 76 if grads, err = Grad(cost, x, y); err != nil { 77 return err 78 } 79 80 m1 := NewTapeMachine(g) 81 defer m1.Close() 82 if err = m1.RunAll(); err != nil { 83 t.Logf("%v", m1.Prog()) 84 return err 85 } 86 87 as := newAssertState(assert.New(t)) 88 as.Equal(bot.correct.Data(), retVal.Data(), "Test %d result", i) 89 as.True(bot.correctShape.Eq(ret.Shape())) 90 as.Equal(2, len(grads)) 91 as.Equal(bot.correctDerivA.Data(), grads[0].Value().Data(), "Test %v xgrad", i) 92 as.Equal(bot.correctDerivB.Data(), grads[1].Value().Data(), "Test %v ygrad. Expected %v. Got %v", i, bot.correctDerivB, grads[1].Value()) 93 if !as.cont { 94 prog := m1.Prog() 95 return errors.Errorf("Failed. Prog %v", prog) 96 } 97 return nil 98 99 } 100 101 func TestMultiDeviceArithmetic(t *testing.T) { 102 g := NewGraph() 103 x := NewMatrix(g, Float64, WithName("x"), WithShape(2, 2)) 104 y := NewMatrix(g, Float64, WithName("y"), WithShape(2, 2)) 105 z := Must(Add(x, y)) 106 zpx := Must(Add(x, z)) // z would be on device 107 Must(Sum(zpx)) 108 109 xV := tensor.New(tensor.WithBacking([]float64{0, 1, 2, 3}), tensor.WithShape(2, 2)) 110 yV := tensor.New(tensor.WithBacking([]float64{0, 1, 2, 3}), tensor.WithShape(2, 2)) 111 112 Let(x, xV) 113 Let(y, yV) 114 115 logger := log.New(os.Stderr, "", 0) 116 m := NewLispMachine(g, WithLogger(logger), WithWatchlist(), LogBothDir()) 117 defer m.Close() 118 t.Logf("zpx.Device: %v", zpx.Device()) 119 t.Logf("x.Device: %v", x.Device()) 120 t.Logf("y.Device: %v", y.Device()) 121 122 if err := m.RunAll(); err != nil { 123 t.Errorf("err: %+v", err) 124 } 125 126 }