github.com/yaricom/goNEAT@v0.0.0-20210507221059-e2110b885482/neat/network/nnode_test.go (about)

     1  package network
     2  
     3  import (
     4  	"github.com/stretchr/testify/assert"
     5  	"github.com/stretchr/testify/require"
     6  	"testing"
     7  )
     8  
     9  // Tests NNode SensorLoad
    10  func TestNNode_SensorLoad(t *testing.T) {
    11  	node := NewNNode(1, InputNeuron)
    12  
    13  	load := 21.0
    14  	res := node.SensorLoad(load)
    15  	require.True(t, res, "Failed to SensorLoad")
    16  	assert.EqualValues(t, 1, node.ActivationsCount)
    17  	assert.Equal(t, load, node.Activation)
    18  	assert.Equal(t, load, node.GetActiveOut())
    19  
    20  	load2 := 36.0
    21  	res = node.SensorLoad(load2)
    22  	require.True(t, res, "Failed to SensorLoad")
    23  	assert.EqualValues(t, 2, node.ActivationsCount)
    24  	assert.Equal(t, load2, node.Activation)
    25  	// Check activation and time delayed activation
    26  	assert.Equal(t, load2, node.GetActiveOut())
    27  	assert.Equal(t, load, node.GetActiveOutTd())
    28  
    29  	// Check loading of incorrect type node
    30  	//
    31  	nodeN := NewNNode(1, HiddenNeuron)
    32  	res = nodeN.SensorLoad(load)
    33  	assert.False(t, res, "Non SENSOR node can not be loaded")
    34  }
    35  
    36  // Tests NNode AddIncoming
    37  func TestNNode_AddIncoming(t *testing.T) {
    38  	node := NewNNode(1, InputNeuron)
    39  	node2 := NewNNode(2, HiddenNeuron)
    40  
    41  	weight := 1.5
    42  	node2.addIncoming(node, weight)
    43  	assert.Len(t, node2.Incoming, 1, "Wrong number of incoming nodes")
    44  
    45  	link := node2.Incoming[0]
    46  	assert.Equal(t, weight, link.Weight, "Wrong incoming link weight")
    47  	assert.Equal(t, node, link.InNode, "Wrong InNode in Link")
    48  	assert.Equal(t, node2, link.OutNode, "Wrong OutNode in Link")
    49  }
    50  
    51  // Tests NNode Depth
    52  func TestNNode_Depth(t *testing.T) {
    53  	node := NewNNode(1, InputNeuron)
    54  	node2 := NewNNode(2, HiddenNeuron)
    55  	node3 := NewNNode(3, OutputNeuron)
    56  
    57  	node2.addIncoming(node, 15.0)
    58  	node3.addIncoming(node2, 20.0)
    59  
    60  	depth, err := node3.Depth(0)
    61  	require.NoError(t, err)
    62  	assert.Equal(t, 2, depth)
    63  }
    64  
    65  func TestNNode_DepthWithLoop(t *testing.T) {
    66  	node := NewNNode(1, InputNeuron)
    67  	node2 := NewNNode(2, HiddenNeuron)
    68  	node3 := NewNNode(3, OutputNeuron)
    69  
    70  	node2.addIncoming(node, 15.0)
    71  	node3.addIncoming(node2, 20.0)
    72  	node2.addIncoming(node3, 10.0)
    73  	depth, err := node3.Depth(0)
    74  	require.NoError(t, err)
    75  	assert.Equal(t, 2, depth)
    76  }
    77  
    78  // Tests NNode Flushback
    79  func TestNNode_Flushback(t *testing.T) {
    80  	node := NewNNode(1, InputNeuron)
    81  	load := 34.0
    82  	load2 := 14.0
    83  	node.SensorLoad(load)
    84  	node.SensorLoad(load2)
    85  
    86  	// check that node state has been updated
    87  	assert.EqualValues(t, 2, node.ActivationsCount)
    88  	assert.Equal(t, 14.0, node.Activation)
    89  
    90  	// Check activation and time delayed activation
    91  	assert.Equal(t, load2, node.GetActiveOut())
    92  	assert.Equal(t, load, node.GetActiveOutTd())
    93  
    94  	// check flush back
    95  	//
    96  	node.Flushback()
    97  
    98  	assert.Zero(t, node.ActivationsCount)
    99  	assert.Zero(t, node.Activation)
   100  
   101  	// Check activation and time delayed activation
   102  	assert.Zero(t, node.GetActiveOut())
   103  	assert.Zero(t, node.GetActiveOutTd())
   104  }