github.com/yaricom/goNEAT@v0.0.0-20210507221059-e2110b885482/neat/network/network_test.go (about)

     1  package network
     2  
     3  import (
     4  	"github.com/stretchr/testify/assert"
     5  	"github.com/stretchr/testify/require"
     6  	"github.com/yaricom/goNEAT/neat/utils"
     7  	"testing"
     8  )
     9  
    10  func buildNetwork() *Network {
    11  	allNodes := []*NNode{
    12  		NewNNode(1, InputNeuron),
    13  		NewNNode(2, InputNeuron),
    14  		NewNNode(3, BiasNeuron),
    15  		NewNNode(4, HiddenNeuron),
    16  		NewNNode(5, HiddenNeuron),
    17  		NewNNode(6, HiddenNeuron),
    18  		NewNNode(7, OutputNeuron),
    19  		NewNNode(8, OutputNeuron),
    20  	}
    21  
    22  	// HIDDEN 4
    23  	allNodes[3].addIncoming(allNodes[0], 15.0)
    24  	allNodes[3].addIncoming(allNodes[1], 10.0)
    25  	// HIDDEN 5
    26  	allNodes[4].addIncoming(allNodes[1], 5.0)
    27  	allNodes[4].addIncoming(allNodes[2], 1.0)
    28  	// HIDDEN 6
    29  	allNodes[5].addIncoming(allNodes[4], 17.0)
    30  	// OUTPUT 7
    31  	allNodes[6].addIncoming(allNodes[3], 7.0)
    32  	allNodes[6].addIncoming(allNodes[5], 4.5)
    33  	// OUTPUT 8
    34  	allNodes[7].addIncoming(allNodes[5], 13.0)
    35  
    36  	return NewNetwork(allNodes[0:3], allNodes[6:8], allNodes, 0)
    37  }
    38  
    39  func buildModularNetwork() *Network {
    40  	allNodes := []*NNode{
    41  		NewNNode(1, InputNeuron),
    42  		NewNNode(2, InputNeuron),
    43  		NewNNode(3, BiasNeuron),
    44  		NewNNode(4, HiddenNeuron),
    45  		NewNNode(5, HiddenNeuron),
    46  		NewNNode(7, HiddenNeuron),
    47  		NewNNode(8, OutputNeuron),
    48  		NewNNode(9, OutputNeuron),
    49  	}
    50  	controlNodes := []*NNode{
    51  		NewNNode(6, HiddenNeuron),
    52  	}
    53  	// HIDDEN 6
    54  	controlNodes[0].ActivationType = utils.MultiplyModuleActivation
    55  	controlNodes[0].addIncoming(allNodes[3], 1.0)
    56  	controlNodes[0].addIncoming(allNodes[4], 1.0)
    57  	controlNodes[0].addOutgoing(allNodes[5], 1.0)
    58  
    59  	// HIDDEN 4
    60  	allNodes[3].ActivationType = utils.LinearActivation
    61  	allNodes[3].addIncoming(allNodes[0], 15.0)
    62  	allNodes[3].addIncoming(allNodes[2], 10.0)
    63  	// HIDDEN 5
    64  	allNodes[4].ActivationType = utils.LinearActivation
    65  	allNodes[4].addIncoming(allNodes[1], 5.0)
    66  	allNodes[4].addIncoming(allNodes[2], 1.0)
    67  
    68  	// HIDDEN 7
    69  	allNodes[5].ActivationType = utils.NullActivation
    70  
    71  	// OUTPUT 8
    72  	allNodes[6].addIncoming(allNodes[5], 4.5)
    73  	allNodes[6].ActivationType = utils.LinearActivation
    74  	// OUTPUT 9
    75  	allNodes[7].addIncoming(allNodes[5], 13.0)
    76  	allNodes[7].ActivationType = utils.LinearActivation
    77  
    78  	return NewModularNetwork(allNodes[0:3], allNodes[6:8], allNodes, controlNodes, 0)
    79  }
    80  
    81  func TestModularNetwork_Activate(t *testing.T) {
    82  	net := buildModularNetwork()
    83  
    84  	data := []float64{1.0, 2.0, 0.5}
    85  	err := net.LoadSensors(data)
    86  	require.NoError(t, err, "failed to load sensors")
    87  
    88  	for i := 0; i < 5; i++ {
    89  		res, err := net.Activate()
    90  		require.NoError(t, err, "error when do activation at: %d", i)
    91  		require.True(t, res, "failed to activate at: %d", i)
    92  	}
    93  	assert.Equal(t, 945.0, net.Outputs[0].Activation)
    94  	assert.Equal(t, 2730.0, net.Outputs[1].Activation)
    95  }
    96  
    97  // Tests Network MaxDepth
    98  func TestNetwork_MaxDepth(t *testing.T) {
    99  	net := buildNetwork()
   100  
   101  	depth, err := net.MaxDepth()
   102  	assert.NoError(t, err, "failed to calculate max depth")
   103  	assert.Equal(t, 3, depth)
   104  }
   105  
   106  // Tests Network OutputIsOff
   107  func TestNetwork_OutputIsOff(t *testing.T) {
   108  	net := buildNetwork()
   109  
   110  	res := net.OutputIsOff()
   111  	assert.True(t, res)
   112  }
   113  
   114  // Tests Network Activate
   115  func TestNetwork_Activate(t *testing.T) {
   116  	net := buildNetwork()
   117  
   118  	res, err := net.Activate()
   119  	require.NoError(t, err, "error when do activation at")
   120  	require.True(t, res, "failed to activate at")
   121  
   122  	// check activation
   123  	for i, node := range net.AllNodes() {
   124  		if node.IsNeuron() {
   125  			require.NotZero(t, node.ActivationsCount, "ActivationsCount not set at: %d", i)
   126  			require.NotZero(t, node.Activation, "Activation not set at: %d", i)
   127  
   128  			// Check activation and time delayed activation
   129  			require.NotZero(t, node.GetActiveOut(), "GetActiveOut not set at: %d", i)
   130  		}
   131  	}
   132  }
   133  
   134  // Test Network LoadSensors
   135  func TestNetwork_LoadSensors(t *testing.T) {
   136  	net := buildNetwork()
   137  
   138  	sensors := []float64{1.0, 3.4, 5.6}
   139  
   140  	err := net.LoadSensors(sensors)
   141  	require.NoError(t, err, "failed to load sensors")
   142  
   143  	counter := 0
   144  	for i, node := range net.AllNodes() {
   145  		if node.IsSensor() {
   146  			assert.Equal(t, sensors[counter], node.Activation, "Sensor value wrong at: %d", i)
   147  			assert.EqualValues(t, 1, node.ActivationsCount, "Sensor activations count wrong at: %d", i)
   148  			counter++
   149  		}
   150  	}
   151  }
   152  
   153  // Test Network Flush
   154  func TestNetwork_Flush(t *testing.T) {
   155  	net := buildNetwork()
   156  
   157  	// activate and check state
   158  	res, err := net.Activate()
   159  	require.NoError(t, err, "error when do activation at")
   160  	require.True(t, res, "failed to activate at")
   161  
   162  	// flush and check
   163  	res, err = net.Flush()
   164  	require.NoError(t, err, "error while trying to flush")
   165  	require.True(t, res, "Network flush failed")
   166  
   167  	for i, node := range net.AllNodes() {
   168  		assert.Zero(t, node.ActivationsCount, "at %d", i)
   169  		assert.Zero(t, node.Activation, "at %d", i)
   170  
   171  		// Check activation and time delayed activation
   172  		assert.Zero(t, node.GetActiveOut(), "at %d", i)
   173  		assert.Zero(t, node.GetActiveOutTd(), "at %d", i)
   174  	}
   175  }
   176  
   177  // Tests Network NodeCount
   178  func TestNetwork_NodeCount(t *testing.T) {
   179  	net := buildNetwork()
   180  
   181  	count := net.NodeCount()
   182  	assert.Equal(t, 8, count, "Wrong network's node count")
   183  }
   184  
   185  // Tests Network LinkCount
   186  func TestNetwork_LinkCount(t *testing.T) {
   187  	net := buildNetwork()
   188  
   189  	count := net.LinkCount()
   190  	assert.Equal(t, 8, count, "Wrong network's link count")
   191  }
   192  
   193  // Tests Network IsRecurrent
   194  func TestNetwork_IsRecurrent(t *testing.T) {
   195  	net := buildNetwork()
   196  
   197  	nodes := net.AllNodes()
   198  	visited := 0 // the count of times the node was visited
   199  	recur := net.IsRecurrent(nodes[0], nodes[7], &visited, 32)
   200  	assert.False(t, recur, "Network is not recurrent")
   201  	assert.Equal(t, 1, visited)
   202  
   203  	// Introduce recurrence
   204  	visited = 0
   205  	nodes[4].addIncoming(nodes[7], 3.0)
   206  	recur = net.IsRecurrent(nodes[5], nodes[7], &visited, 32)
   207  	assert.True(t, recur, "Network is actually recurrent now")
   208  	assert.Equal(t, 5, visited)
   209  }
   210  
   211  // test fast network solver generation
   212  func TestNetwork_FastNetworkSolver(t *testing.T) {
   213  	net := buildModularNetwork()
   214  
   215  	solver, err := net.FastNetworkSolver()
   216  	require.NoError(t, err, "failed to create fast network solver")
   217  	require.NotNil(t, solver)
   218  
   219  	// check solver structure
   220  	assert.Equal(t, net.NodeCount(), solver.NodeCount(), "wrong number of nodes")
   221  	assert.Equal(t, net.LinkCount(), solver.LinkCount(), "wrong number of links")
   222  }