gorgonia.org/gorgonia@v0.9.17/complex_test.go (about) 1 package gorgonia 2 3 import ( 4 "runtime/debug" 5 "testing" 6 ) 7 8 func TestWeirdNetwork(t *testing.T) { 9 const ( 10 embeddingDims = 50 11 hiddenSize = 200 12 13 xs = 64 14 xFeats = 20 15 16 ps = 20 17 pFeats = 10 18 19 qs = 50 20 qFeats = 12 21 22 outSize = 10 23 ) 24 var err error 25 26 g := NewGraph() 27 var x *Node // NewVector(g, Float64, WithShape(xFeats*embeddingDims), WithName("x"), WithInit(Zeroes())) 28 var p *Node 29 var q *Node 30 31 eX := NewMatrix(g, Float64, WithShape(xs, embeddingDims), WithName("x embeddings"), WithInit(GlorotU(1))) 32 eP := NewMatrix(g, Float64, WithShape(ps, embeddingDims), WithName("p embeddings"), WithInit(GlorotU(1))) 33 eQ := NewMatrix(g, Float64, WithShape(qs, embeddingDims), WithName("q embeddings"), WithInit(GlorotU(1))) 34 w0X := NewMatrix(g, Float64, WithShape(hiddenSize, xFeats*embeddingDims), WithName("layer0 weights for x"), WithInit(GlorotU(1))) 35 w0P := NewMatrix(g, Float64, WithShape(hiddenSize, pFeats*embeddingDims), WithName("layer0 weights for p"), WithInit(GlorotU(1))) 36 w0Q := NewMatrix(g, Float64, WithShape(hiddenSize, qFeats*embeddingDims), WithName("layer0 weights for q"), WithInit(GlorotU(1))) 37 b := NewVector(g, Float64, WithShape(hiddenSize), WithName("bias"), WithInit(Zeroes())) 38 w1 := NewMatrix(g, Float64, WithShape(outSize, hiddenSize), WithName("layer 1"), WithInit(GlorotU(1))) 39 40 model := Nodes{eX, eP, eQ, w0X, w0P, w0Q, b, w1} 41 42 /* SET UP NEURAL NETWORK */ 43 44 slicesX := make(Nodes, xFeats) 45 slicesP := make(Nodes, pFeats) 46 slicesQ := make(Nodes, qFeats) 47 48 for i := 0; i < xFeats; i++ { 49 if slicesX[i], err = Slice(eX, S(i)); err != nil { 50 t.Fatal(err) 51 } 52 } 53 54 for i := 0; i < pFeats; i++ { 55 if slicesP[i], err = Slice(eP, S(i)); err != nil { 56 t.Fatal(err) 57 } 58 } 59 60 for i := 0; i < qFeats; i++ { 61 if slicesQ[i], err = Slice(eQ, S(i)); err != nil { 62 t.Fatal(err) 63 } 64 } 65 66 if x, err = Concat(0, slicesX...); err != nil { 67 t.Fatal(err) 68 } 69 70 if p, err = Concat(0, slicesP...); err != nil { 71 t.Fatal(err) 72 } 73 74 if q, err = Concat(0, slicesQ...); err != nil { 75 t.Fatal(err) 76 } 77 78 var wx, wp, wq *Node 79 if wx, err = Mul(w0X, x); err != nil { 80 t.Fatal(err) 81 } 82 83 if wp, err = Mul(w0P, p); err != nil { 84 t.Fatal(err) 85 } 86 87 if wq, err = Mul(w0Q, q); err != nil { 88 t.Fatal(err) 89 } 90 91 // add all them layers 92 var add0, add1, add2 *Node 93 if add0, err = Add(wx, wp); err != nil { 94 t.Fatal(err) 95 } 96 if add1, err = Add(add0, wq); err != nil { 97 t.Fatal(err) 98 } 99 if add2, err = Add(add1, b); err != nil { 100 t.Fatal(err) 101 } 102 103 // activate 104 var act0 *Node 105 if act0, err = Cube(add2); err != nil { 106 t.Fatal(err) 107 } 108 109 // layer 1 110 var layer1 *Node 111 if layer1, err = Mul(w1, act0); err != nil { 112 t.Fatal(err) 113 } 114 115 // activate 116 var logProb *Node 117 if logProb, err = SoftMax(layer1); err != nil { 118 t.Fatal(err) 119 } 120 121 var cost *Node 122 if cost, err = Slice(logProb, S(0)); err != nil { // dummy slice 123 t.Fatal(err) 124 } 125 126 // backprop 127 if _, err = Grad(cost, model...); err != nil { 128 t.Fatal(err) 129 } 130 131 /* SET UP COMPLETE */ 132 133 m := NewTapeMachine(g, BindDualValues(model...)) 134 defer m.Close() 135 136 // for debug purposes 137 // prog, locMap, err := Compile(g) 138 // log.Println(prog) 139 140 // for i := 0; i < 104729; i++ { 141 for i := 0; i < 2; i++ { 142 if err = m.RunAll(); err != nil { 143 t.Errorf("%d %v", i, err) 144 t.Log(string(debug.Stack())) 145 146 break 147 } 148 149 m.Reset() 150 } 151 152 }