gorgonia.org/gorgonia@v0.9.17/complex_test.go (about)

     1  package gorgonia
     2  
     3  import (
     4  	"runtime/debug"
     5  	"testing"
     6  )
     7  
     8  func TestWeirdNetwork(t *testing.T) {
     9  	const (
    10  		embeddingDims = 50
    11  		hiddenSize    = 200
    12  
    13  		xs     = 64
    14  		xFeats = 20
    15  
    16  		ps     = 20
    17  		pFeats = 10
    18  
    19  		qs     = 50
    20  		qFeats = 12
    21  
    22  		outSize = 10
    23  	)
    24  	var err error
    25  
    26  	g := NewGraph()
    27  	var x *Node // NewVector(g, Float64, WithShape(xFeats*embeddingDims), WithName("x"), WithInit(Zeroes()))
    28  	var p *Node
    29  	var q *Node
    30  
    31  	eX := NewMatrix(g, Float64, WithShape(xs, embeddingDims), WithName("x embeddings"), WithInit(GlorotU(1)))
    32  	eP := NewMatrix(g, Float64, WithShape(ps, embeddingDims), WithName("p embeddings"), WithInit(GlorotU(1)))
    33  	eQ := NewMatrix(g, Float64, WithShape(qs, embeddingDims), WithName("q embeddings"), WithInit(GlorotU(1)))
    34  	w0X := NewMatrix(g, Float64, WithShape(hiddenSize, xFeats*embeddingDims), WithName("layer0 weights for x"), WithInit(GlorotU(1)))
    35  	w0P := NewMatrix(g, Float64, WithShape(hiddenSize, pFeats*embeddingDims), WithName("layer0 weights for p"), WithInit(GlorotU(1)))
    36  	w0Q := NewMatrix(g, Float64, WithShape(hiddenSize, qFeats*embeddingDims), WithName("layer0 weights for q"), WithInit(GlorotU(1)))
    37  	b := NewVector(g, Float64, WithShape(hiddenSize), WithName("bias"), WithInit(Zeroes()))
    38  	w1 := NewMatrix(g, Float64, WithShape(outSize, hiddenSize), WithName("layer 1"), WithInit(GlorotU(1)))
    39  
    40  	model := Nodes{eX, eP, eQ, w0X, w0P, w0Q, b, w1}
    41  
    42  	/* SET UP NEURAL NETWORK */
    43  
    44  	slicesX := make(Nodes, xFeats)
    45  	slicesP := make(Nodes, pFeats)
    46  	slicesQ := make(Nodes, qFeats)
    47  
    48  	for i := 0; i < xFeats; i++ {
    49  		if slicesX[i], err = Slice(eX, S(i)); err != nil {
    50  			t.Fatal(err)
    51  		}
    52  	}
    53  
    54  	for i := 0; i < pFeats; i++ {
    55  		if slicesP[i], err = Slice(eP, S(i)); err != nil {
    56  			t.Fatal(err)
    57  		}
    58  	}
    59  
    60  	for i := 0; i < qFeats; i++ {
    61  		if slicesQ[i], err = Slice(eQ, S(i)); err != nil {
    62  			t.Fatal(err)
    63  		}
    64  	}
    65  
    66  	if x, err = Concat(0, slicesX...); err != nil {
    67  		t.Fatal(err)
    68  	}
    69  
    70  	if p, err = Concat(0, slicesP...); err != nil {
    71  		t.Fatal(err)
    72  	}
    73  
    74  	if q, err = Concat(0, slicesQ...); err != nil {
    75  		t.Fatal(err)
    76  	}
    77  
    78  	var wx, wp, wq *Node
    79  	if wx, err = Mul(w0X, x); err != nil {
    80  		t.Fatal(err)
    81  	}
    82  
    83  	if wp, err = Mul(w0P, p); err != nil {
    84  		t.Fatal(err)
    85  	}
    86  
    87  	if wq, err = Mul(w0Q, q); err != nil {
    88  		t.Fatal(err)
    89  	}
    90  
    91  	// add all them layers
    92  	var add0, add1, add2 *Node
    93  	if add0, err = Add(wx, wp); err != nil {
    94  		t.Fatal(err)
    95  	}
    96  	if add1, err = Add(add0, wq); err != nil {
    97  		t.Fatal(err)
    98  	}
    99  	if add2, err = Add(add1, b); err != nil {
   100  		t.Fatal(err)
   101  	}
   102  
   103  	// activate
   104  	var act0 *Node
   105  	if act0, err = Cube(add2); err != nil {
   106  		t.Fatal(err)
   107  	}
   108  
   109  	// layer 1
   110  	var layer1 *Node
   111  	if layer1, err = Mul(w1, act0); err != nil {
   112  		t.Fatal(err)
   113  	}
   114  
   115  	// activate
   116  	var logProb *Node
   117  	if logProb, err = SoftMax(layer1); err != nil {
   118  		t.Fatal(err)
   119  	}
   120  
   121  	var cost *Node
   122  	if cost, err = Slice(logProb, S(0)); err != nil { // dummy slice
   123  		t.Fatal(err)
   124  	}
   125  
   126  	// backprop
   127  	if _, err = Grad(cost, model...); err != nil {
   128  		t.Fatal(err)
   129  	}
   130  
   131  	/* SET UP COMPLETE */
   132  
   133  	m := NewTapeMachine(g, BindDualValues(model...))
   134  	defer m.Close()
   135  
   136  	// for debug purposes
   137  	// prog, locMap, err := Compile(g)
   138  	// log.Println(prog)
   139  
   140  	// for i := 0; i < 104729; i++ {
   141  	for i := 0; i < 2; i++ {
   142  		if err = m.RunAll(); err != nil {
   143  			t.Errorf("%d %v", i, err)
   144  			t.Log(string(debug.Stack()))
   145  
   146  			break
   147  		}
   148  
   149  		m.Reset()
   150  	}
   151  
   152  }