gorgonia.org/gorgonia@v0.9.17/ops/nn/api_cuda_test.go (about)

     1  // +build cuda
     2  
     3  package nnops
     4  
     5  import (
     6  	"io/ioutil"
     7  	"log"
     8  	"os"
     9  	"testing"
    10  
    11  	G "gorgonia.org/gorgonia"
    12  	"gorgonia.org/tensor"
    13  )
    14  
    15  func TestDropout(t *testing.T) {
    16  	g := G.NewGraph()
    17  	x := G.NewMatrix(g, G.Float64, G.WithShape(2, 3), G.WithName("x"))
    18  	do, _ := Dropout(x, 0.5)
    19  	log.Printf("%v", do)
    20  	ioutil.WriteFile("foo.dot", []byte(g.ToDot()), 0644)
    21  
    22  }
    23  
    24  /*
    25  func TestBatchNorm_F64(t *testing.T) {
    26  	g := G.NewGraph()
    27  	x := G.NewTensor(g, G.Float64, 4, G.WithShape(5, 2, 3, 4), G.WithInit(G.Gaussian(0, 1)))
    28  	y, op, err := BatchNorm(x, 0.9, 1e-5, true)
    29  	if err != nil {
    30  		t.Fatal(err)
    31  	}
    32  
    33  	var yVal G.Value
    34  	G.Read(y, &yVal)
    35  
    36  	cost, _ := G.Mean(y)
    37  
    38  	if _, err := G.Grad(cost, x); err != nil {
    39  		t.Fatal(err)
    40  	}
    41  
    42  	m := G.NewTapeMachine(g, G.BindDualValues(x), G.TraceExec())
    43  	if err := m.RunAll(); err != nil {
    44  		t.Fatal(err)
    45  	}
    46  	m.Close()
    47  	ioutil.WriteFile("foo.dot", []byte(g.ToDot()), 0644)
    48  
    49  	shape := x.Shape()
    50  	n, c, h, w := shape[0], shape[1], shape[2], shape[3]
    51  
    52  	yVT := yVal.(*tensor.Dense)
    53  	for j := 0; j < c; j++ {
    54  		var sum, variance float64
    55  		for i := 0; i < n; i++ {
    56  			for k := 0; k < h; k++ {
    57  				for l := 0; l < w; l++ {
    58  					at, err := yVT.At(i, j, k, l)
    59  					if err != nil {
    60  						t.Fatal(err)
    61  					}
    62  					atf := at.(float64)
    63  					sum += atf
    64  					variance += atf * atf
    65  				}
    66  			}
    67  		}
    68  		sum /= float64(h * w * n)
    69  		variance /= float64(h * w * n)
    70  
    71  		if !dawson.ToleranceF64(sum, 0, 0.00001) {
    72  			t.Errorf("channel %d: Expected sum to be near 0. Got %v", j, sum)
    73  		}
    74  
    75  		if !dawson.ToleranceF64(variance, 1, 0.0001) {
    76  			t.Errorf("channel %d: Expected variance to be near 1. Got %v", j, variance)
    77  		}
    78  	}
    79  
    80  	op.SetTesting()
    81  	m = G.NewTapeMachine(g, G.BindDualValues(x))
    82  	if err := m.RunAll(); err != nil {
    83  		t.Fatal(err)
    84  	}
    85  	m.Close()
    86  	yVT = yVal.(*tensor.Dense)
    87  	for j := 0; j < c; j++ {
    88  		var sum, variance float64
    89  		for i := 0; i < n; i++ {
    90  			for k := 0; k < h; k++ {
    91  				for l := 0; l < w; l++ {
    92  					at, err := yVT.At(i, j, k, l)
    93  					if err != nil {
    94  						t.Fatal(err)
    95  					}
    96  					atf := at.(float64)
    97  					sum += atf
    98  					variance += atf * atf
    99  				}
   100  			}
   101  		}
   102  		sum /= float64(h * w * n)
   103  		variance /= float64(h * w * n)
   104  
   105  		if !dawson.ToleranceF64(sum, 0, 0.00001) {
   106  			t.Errorf("channel %d: Expected sum to be near 0. Got %v", j, sum)
   107  		}
   108  
   109  		if !dawson.ToleranceF64(variance, 0.9833, 0.0001) {
   110  			t.Errorf("channel %d: Expected variance to be near 0.98. Got %v", j, variance)
   111  		}
   112  	}
   113  }
   114  */
   115  
   116  func TestDevBN(t *testing.T) {
   117  	g := G.NewGraph()
   118  	x := G.NewTensor(g, G.Float64, 4, G.WithShape(5, 2, 3, 4), G.WithInit(G.Gaussian(0, 1)), G.WithName("x"))
   119  	y, _, _, op, err := BatchNorm(x, nil, nil, 0.9, 1e-5)
   120  	if err != nil {
   121  		t.Fatal(err)
   122  	}
   123  	ioutil.WriteFile("foo.dot", []byte(g.ToDot()), 0644)
   124  
   125  	cost, _ := G.Mean(y)
   126  
   127  	if _, err := G.Grad(cost, x); err != nil {
   128  		t.Fatal(err)
   129  	}
   130  
   131  	log.Printf("%v | %v", y, op)
   132  	ioutil.WriteFile("bar.dot", []byte(g.ToDot()), 0644)
   133  	prog, _, _ := G.Compile(g)
   134  	log.Printf("%v", prog)
   135  	logger := log.New(os.Stderr, "", 0)
   136  	m := G.NewTapeMachine(g, G.BindDualValues(x), G.WithLogger(logger), G.WithWatchlist())
   137  	if err := m.RunAll(); err != nil {
   138  		t.Fatal(err)
   139  	}
   140  	m.Close()
   141  }
   142  
   143  func TestScratch(t *testing.T) {
   144  	g := G.NewGraph()
   145  	ss := &scratchOp{tensor.Shape{1, 2, 3, 4}, tensor.Float64, "testScratch"}
   146  	x := G.NewTensor(g, tensor.Float64, 4, G.WithName("x"))
   147  	y := G.NewTensor(g, tensor.Float64, 4, G.WithOp(ss))
   148  	prog, _, _ := G.Compile(g)
   149  	log.Printf("x: %v y: %v", x, y)
   150  	log.Printf("%v", prog)
   151  }