gorgonia.org/gorgonia@v0.9.17/ops/nn/api_cuda_test.go (about) 1 // +build cuda 2 3 package nnops 4 5 import ( 6 "io/ioutil" 7 "log" 8 "os" 9 "testing" 10 11 G "gorgonia.org/gorgonia" 12 "gorgonia.org/tensor" 13 ) 14 15 func TestDropout(t *testing.T) { 16 g := G.NewGraph() 17 x := G.NewMatrix(g, G.Float64, G.WithShape(2, 3), G.WithName("x")) 18 do, _ := Dropout(x, 0.5) 19 log.Printf("%v", do) 20 ioutil.WriteFile("foo.dot", []byte(g.ToDot()), 0644) 21 22 } 23 24 /* 25 func TestBatchNorm_F64(t *testing.T) { 26 g := G.NewGraph() 27 x := G.NewTensor(g, G.Float64, 4, G.WithShape(5, 2, 3, 4), G.WithInit(G.Gaussian(0, 1))) 28 y, op, err := BatchNorm(x, 0.9, 1e-5, true) 29 if err != nil { 30 t.Fatal(err) 31 } 32 33 var yVal G.Value 34 G.Read(y, &yVal) 35 36 cost, _ := G.Mean(y) 37 38 if _, err := G.Grad(cost, x); err != nil { 39 t.Fatal(err) 40 } 41 42 m := G.NewTapeMachine(g, G.BindDualValues(x), G.TraceExec()) 43 if err := m.RunAll(); err != nil { 44 t.Fatal(err) 45 } 46 m.Close() 47 ioutil.WriteFile("foo.dot", []byte(g.ToDot()), 0644) 48 49 shape := x.Shape() 50 n, c, h, w := shape[0], shape[1], shape[2], shape[3] 51 52 yVT := yVal.(*tensor.Dense) 53 for j := 0; j < c; j++ { 54 var sum, variance float64 55 for i := 0; i < n; i++ { 56 for k := 0; k < h; k++ { 57 for l := 0; l < w; l++ { 58 at, err := yVT.At(i, j, k, l) 59 if err != nil { 60 t.Fatal(err) 61 } 62 atf := at.(float64) 63 sum += atf 64 variance += atf * atf 65 } 66 } 67 } 68 sum /= float64(h * w * n) 69 variance /= float64(h * w * n) 70 71 if !dawson.ToleranceF64(sum, 0, 0.00001) { 72 t.Errorf("channel %d: Expected sum to be near 0. Got %v", j, sum) 73 } 74 75 if !dawson.ToleranceF64(variance, 1, 0.0001) { 76 t.Errorf("channel %d: Expected variance to be near 1. Got %v", j, variance) 77 } 78 } 79 80 op.SetTesting() 81 m = G.NewTapeMachine(g, G.BindDualValues(x)) 82 if err := m.RunAll(); err != nil { 83 t.Fatal(err) 84 } 85 m.Close() 86 yVT = yVal.(*tensor.Dense) 87 for j := 0; j < c; j++ { 88 var sum, variance float64 89 for i := 0; i < n; i++ { 90 for k := 0; k < h; k++ { 91 for l := 0; l < w; l++ { 92 at, err := yVT.At(i, j, k, l) 93 if err != nil { 94 t.Fatal(err) 95 } 96 atf := at.(float64) 97 sum += atf 98 variance += atf * atf 99 } 100 } 101 } 102 sum /= float64(h * w * n) 103 variance /= float64(h * w * n) 104 105 if !dawson.ToleranceF64(sum, 0, 0.00001) { 106 t.Errorf("channel %d: Expected sum to be near 0. Got %v", j, sum) 107 } 108 109 if !dawson.ToleranceF64(variance, 0.9833, 0.0001) { 110 t.Errorf("channel %d: Expected variance to be near 0.98. Got %v", j, variance) 111 } 112 } 113 } 114 */ 115 116 func TestDevBN(t *testing.T) { 117 g := G.NewGraph() 118 x := G.NewTensor(g, G.Float64, 4, G.WithShape(5, 2, 3, 4), G.WithInit(G.Gaussian(0, 1)), G.WithName("x")) 119 y, _, _, op, err := BatchNorm(x, nil, nil, 0.9, 1e-5) 120 if err != nil { 121 t.Fatal(err) 122 } 123 ioutil.WriteFile("foo.dot", []byte(g.ToDot()), 0644) 124 125 cost, _ := G.Mean(y) 126 127 if _, err := G.Grad(cost, x); err != nil { 128 t.Fatal(err) 129 } 130 131 log.Printf("%v | %v", y, op) 132 ioutil.WriteFile("bar.dot", []byte(g.ToDot()), 0644) 133 prog, _, _ := G.Compile(g) 134 log.Printf("%v", prog) 135 logger := log.New(os.Stderr, "", 0) 136 m := G.NewTapeMachine(g, G.BindDualValues(x), G.WithLogger(logger), G.WithWatchlist()) 137 if err := m.RunAll(); err != nil { 138 t.Fatal(err) 139 } 140 m.Close() 141 } 142 143 func TestScratch(t *testing.T) { 144 g := G.NewGraph() 145 ss := &scratchOp{tensor.Shape{1, 2, 3, 4}, tensor.Float64, "testScratch"} 146 x := G.NewTensor(g, tensor.Float64, 4, G.WithName("x")) 147 y := G.NewTensor(g, tensor.Float64, 4, G.WithOp(ss)) 148 prog, _, _ := G.Compile(g) 149 log.Printf("x: %v y: %v", x, y) 150 log.Printf("%v", prog) 151 }