github.com/yaricom/goNEAT@v0.0.0-20210507221059-e2110b885482/neat/network/network_test.go (about) 1 package network 2 3 import ( 4 "github.com/stretchr/testify/assert" 5 "github.com/stretchr/testify/require" 6 "github.com/yaricom/goNEAT/neat/utils" 7 "testing" 8 ) 9 10 func buildNetwork() *Network { 11 allNodes := []*NNode{ 12 NewNNode(1, InputNeuron), 13 NewNNode(2, InputNeuron), 14 NewNNode(3, BiasNeuron), 15 NewNNode(4, HiddenNeuron), 16 NewNNode(5, HiddenNeuron), 17 NewNNode(6, HiddenNeuron), 18 NewNNode(7, OutputNeuron), 19 NewNNode(8, OutputNeuron), 20 } 21 22 // HIDDEN 4 23 allNodes[3].addIncoming(allNodes[0], 15.0) 24 allNodes[3].addIncoming(allNodes[1], 10.0) 25 // HIDDEN 5 26 allNodes[4].addIncoming(allNodes[1], 5.0) 27 allNodes[4].addIncoming(allNodes[2], 1.0) 28 // HIDDEN 6 29 allNodes[5].addIncoming(allNodes[4], 17.0) 30 // OUTPUT 7 31 allNodes[6].addIncoming(allNodes[3], 7.0) 32 allNodes[6].addIncoming(allNodes[5], 4.5) 33 // OUTPUT 8 34 allNodes[7].addIncoming(allNodes[5], 13.0) 35 36 return NewNetwork(allNodes[0:3], allNodes[6:8], allNodes, 0) 37 } 38 39 func buildModularNetwork() *Network { 40 allNodes := []*NNode{ 41 NewNNode(1, InputNeuron), 42 NewNNode(2, InputNeuron), 43 NewNNode(3, BiasNeuron), 44 NewNNode(4, HiddenNeuron), 45 NewNNode(5, HiddenNeuron), 46 NewNNode(7, HiddenNeuron), 47 NewNNode(8, OutputNeuron), 48 NewNNode(9, OutputNeuron), 49 } 50 controlNodes := []*NNode{ 51 NewNNode(6, HiddenNeuron), 52 } 53 // HIDDEN 6 54 controlNodes[0].ActivationType = utils.MultiplyModuleActivation 55 controlNodes[0].addIncoming(allNodes[3], 1.0) 56 controlNodes[0].addIncoming(allNodes[4], 1.0) 57 controlNodes[0].addOutgoing(allNodes[5], 1.0) 58 59 // HIDDEN 4 60 allNodes[3].ActivationType = utils.LinearActivation 61 allNodes[3].addIncoming(allNodes[0], 15.0) 62 allNodes[3].addIncoming(allNodes[2], 10.0) 63 // HIDDEN 5 64 allNodes[4].ActivationType = utils.LinearActivation 65 allNodes[4].addIncoming(allNodes[1], 5.0) 66 allNodes[4].addIncoming(allNodes[2], 1.0) 67 68 // HIDDEN 7 69 allNodes[5].ActivationType = utils.NullActivation 70 71 // OUTPUT 8 72 allNodes[6].addIncoming(allNodes[5], 4.5) 73 allNodes[6].ActivationType = utils.LinearActivation 74 // OUTPUT 9 75 allNodes[7].addIncoming(allNodes[5], 13.0) 76 allNodes[7].ActivationType = utils.LinearActivation 77 78 return NewModularNetwork(allNodes[0:3], allNodes[6:8], allNodes, controlNodes, 0) 79 } 80 81 func TestModularNetwork_Activate(t *testing.T) { 82 net := buildModularNetwork() 83 84 data := []float64{1.0, 2.0, 0.5} 85 err := net.LoadSensors(data) 86 require.NoError(t, err, "failed to load sensors") 87 88 for i := 0; i < 5; i++ { 89 res, err := net.Activate() 90 require.NoError(t, err, "error when do activation at: %d", i) 91 require.True(t, res, "failed to activate at: %d", i) 92 } 93 assert.Equal(t, 945.0, net.Outputs[0].Activation) 94 assert.Equal(t, 2730.0, net.Outputs[1].Activation) 95 } 96 97 // Tests Network MaxDepth 98 func TestNetwork_MaxDepth(t *testing.T) { 99 net := buildNetwork() 100 101 depth, err := net.MaxDepth() 102 assert.NoError(t, err, "failed to calculate max depth") 103 assert.Equal(t, 3, depth) 104 } 105 106 // Tests Network OutputIsOff 107 func TestNetwork_OutputIsOff(t *testing.T) { 108 net := buildNetwork() 109 110 res := net.OutputIsOff() 111 assert.True(t, res) 112 } 113 114 // Tests Network Activate 115 func TestNetwork_Activate(t *testing.T) { 116 net := buildNetwork() 117 118 res, err := net.Activate() 119 require.NoError(t, err, "error when do activation at") 120 require.True(t, res, "failed to activate at") 121 122 // check activation 123 for i, node := range net.AllNodes() { 124 if node.IsNeuron() { 125 require.NotZero(t, node.ActivationsCount, "ActivationsCount not set at: %d", i) 126 require.NotZero(t, node.Activation, "Activation not set at: %d", i) 127 128 // Check activation and time delayed activation 129 require.NotZero(t, node.GetActiveOut(), "GetActiveOut not set at: %d", i) 130 } 131 } 132 } 133 134 // Test Network LoadSensors 135 func TestNetwork_LoadSensors(t *testing.T) { 136 net := buildNetwork() 137 138 sensors := []float64{1.0, 3.4, 5.6} 139 140 err := net.LoadSensors(sensors) 141 require.NoError(t, err, "failed to load sensors") 142 143 counter := 0 144 for i, node := range net.AllNodes() { 145 if node.IsSensor() { 146 assert.Equal(t, sensors[counter], node.Activation, "Sensor value wrong at: %d", i) 147 assert.EqualValues(t, 1, node.ActivationsCount, "Sensor activations count wrong at: %d", i) 148 counter++ 149 } 150 } 151 } 152 153 // Test Network Flush 154 func TestNetwork_Flush(t *testing.T) { 155 net := buildNetwork() 156 157 // activate and check state 158 res, err := net.Activate() 159 require.NoError(t, err, "error when do activation at") 160 require.True(t, res, "failed to activate at") 161 162 // flush and check 163 res, err = net.Flush() 164 require.NoError(t, err, "error while trying to flush") 165 require.True(t, res, "Network flush failed") 166 167 for i, node := range net.AllNodes() { 168 assert.Zero(t, node.ActivationsCount, "at %d", i) 169 assert.Zero(t, node.Activation, "at %d", i) 170 171 // Check activation and time delayed activation 172 assert.Zero(t, node.GetActiveOut(), "at %d", i) 173 assert.Zero(t, node.GetActiveOutTd(), "at %d", i) 174 } 175 } 176 177 // Tests Network NodeCount 178 func TestNetwork_NodeCount(t *testing.T) { 179 net := buildNetwork() 180 181 count := net.NodeCount() 182 assert.Equal(t, 8, count, "Wrong network's node count") 183 } 184 185 // Tests Network LinkCount 186 func TestNetwork_LinkCount(t *testing.T) { 187 net := buildNetwork() 188 189 count := net.LinkCount() 190 assert.Equal(t, 8, count, "Wrong network's link count") 191 } 192 193 // Tests Network IsRecurrent 194 func TestNetwork_IsRecurrent(t *testing.T) { 195 net := buildNetwork() 196 197 nodes := net.AllNodes() 198 visited := 0 // the count of times the node was visited 199 recur := net.IsRecurrent(nodes[0], nodes[7], &visited, 32) 200 assert.False(t, recur, "Network is not recurrent") 201 assert.Equal(t, 1, visited) 202 203 // Introduce recurrence 204 visited = 0 205 nodes[4].addIncoming(nodes[7], 3.0) 206 recur = net.IsRecurrent(nodes[5], nodes[7], &visited, 32) 207 assert.True(t, recur, "Network is actually recurrent now") 208 assert.Equal(t, 5, visited) 209 } 210 211 // test fast network solver generation 212 func TestNetwork_FastNetworkSolver(t *testing.T) { 213 net := buildModularNetwork() 214 215 solver, err := net.FastNetworkSolver() 216 require.NoError(t, err, "failed to create fast network solver") 217 require.NotNil(t, solver) 218 219 // check solver structure 220 assert.Equal(t, net.NodeCount(), solver.NodeCount(), "wrong number of nodes") 221 assert.Equal(t, net.LinkCount(), solver.LinkCount(), "wrong number of links") 222 }