github.com/yaricom/goNEAT@v0.0.0-20210507221059-e2110b885482/neat/network/fast_network_test.go (about)

     1  package network
     2  
     3  import (
     4  	"github.com/stretchr/testify/assert"
     5  	"github.com/stretchr/testify/require"
     6  	"testing"
     7  )
     8  
     9  func TestFastModularNetworkSolver_RecursiveSteps(t *testing.T) {
    10  	net := buildNetwork()
    11  
    12  	// Create network solver
    13  	data := []float64{0.5, 1.1} // BIAS is 1.0 by definition
    14  	fmm, err := net.FastNetworkSolver()
    15  	require.NoError(t, err, "failed to create fast network solver")
    16  	err = fmm.LoadSensors(data)
    17  	require.NoError(t, err, "failed to load sensors")
    18  
    19  	// Activate objective network
    20  	//
    21  	data = append(data, 1.0) // BIAS is a third object
    22  	err = net.LoadSensors(data)
    23  	require.NoError(t, err, "failed to load sensors")
    24  	depth, err := net.MaxDepth()
    25  	require.NoError(t, err, "failed to calculate max depth")
    26  	for i := 0; i < depth; i++ {
    27  		res, err := net.Activate()
    28  		require.NoError(t, err, "error when trying to activate at: %d", i)
    29  		require.True(t, res, "failed to activate at: %d", i)
    30  	}
    31  
    32  	// Do recursive activation of the Fast Network Solver
    33  	//
    34  	res, err := fmm.RecursiveSteps()
    35  	require.NoError(t, err, "error when trying to activate Fast Network Solver")
    36  	require.True(t, res, "recursive activation failed")
    37  
    38  	// Compare activations of objective network and Fast Network Solver
    39  	//
    40  	fmmOutputs := fmm.ReadOutputs()
    41  	require.Equal(t, len(net.Outputs), len(fmmOutputs))
    42  
    43  	for i, out := range fmmOutputs {
    44  		assert.Equal(t, net.Outputs[i].Activation, out, "wrong activation at: %d", i)
    45  	}
    46  }
    47  
    48  func TestFastModularNetworkSolver_ForwardSteps(t *testing.T) {
    49  	net := buildModularNetwork()
    50  
    51  	// create network solver
    52  	data := []float64{1.0, 2.0} // bias inherent
    53  	fmm, err := net.FastNetworkSolver()
    54  	require.NoError(t, err, "failed to create fast network solver")
    55  	err = fmm.LoadSensors(data)
    56  	require.NoError(t, err, "failed to load sensors")
    57  
    58  	steps := 5
    59  
    60  	// activate objective network
    61  	//
    62  	data = append(data, 1.0) // BIAS is third object
    63  	err = net.LoadSensors(data)
    64  	require.NoError(t, err, "failed to load sensors")
    65  	for i := 0; i < steps; i++ {
    66  		res, err := net.Activate()
    67  		require.NoError(t, err, "error when trying to activate at: %d", i)
    68  		require.True(t, res, "failed to activate at: %d", i)
    69  	}
    70  
    71  	// do forward steps through the solver and test results
    72  	//
    73  	res, err := fmm.ForwardSteps(steps)
    74  	require.NoError(t, err, "error while do forward steps")
    75  	require.True(t, res, "forward steps returned false")
    76  
    77  	// check results by comparing activations of objective network and fast network solver
    78  	//
    79  	outputs := fmm.ReadOutputs()
    80  	for i, out := range outputs {
    81  		assert.Equal(t, net.Outputs[i].Activation, out, "wrong activation at: %d", i)
    82  	}
    83  }
    84  
    85  func TestFastModularNetworkSolver_Relax(t *testing.T) {
    86  	net := buildModularNetwork()
    87  
    88  	// create network solver
    89  	data := []float64{1.5, 2.0} // bias inherent
    90  	fmm, err := net.FastNetworkSolver()
    91  	require.NoError(t, err, "failed to create fast network solver")
    92  	err = fmm.LoadSensors(data)
    93  	require.NoError(t, err, "failed to load sensors")
    94  
    95  	steps := 5
    96  
    97  	// activate objective network
    98  	//
    99  	data = append(data, 1.0) // BIAS is third object
   100  	err = net.LoadSensors(data)
   101  	require.NoError(t, err, "failed to load sensors")
   102  	for i := 0; i < steps; i++ {
   103  		res, err := net.Activate()
   104  		require.NoError(t, err, "error when trying to activate at: %d", i)
   105  		require.True(t, res, "failed to activate at: %d", i)
   106  	}
   107  
   108  	// do relaxation of fast network solver
   109  	//
   110  	res, err := fmm.Relax(steps, 1)
   111  	require.NoError(t, err)
   112  	require.True(t, res, "failed to relax within given maximal steps number")
   113  
   114  	// check results by comparing activations of objective network and fast network solver
   115  	//
   116  	outputs := fmm.ReadOutputs()
   117  	for i, out := range outputs {
   118  		assert.Equal(t, net.Outputs[i].Activation, out, "wrong activation at: %d", i)
   119  	}
   120  }
   121  
   122  func TestFastModularNetworkSolver_Flush(t *testing.T) {
   123  	net := buildModularNetwork()
   124  
   125  	// create network solver
   126  	data := []float64{1.5, 2.0} // bias inherent
   127  	fmm, err := net.FastNetworkSolver()
   128  	require.NoError(t, err, "failed to create fast network solver")
   129  	err = fmm.LoadSensors(data)
   130  	require.NoError(t, err, "failed to load sensors")
   131  
   132  	fmmImpl := fmm.(*FastModularNetworkSolver)
   133  	// test that network has active signals
   134  	active := countActiveSignals(fmmImpl)
   135  	assert.NotZero(t, active, "no active signal found")
   136  
   137  	// flush and test
   138  	res, err := fmm.Flush()
   139  	require.NoError(t, err)
   140  	require.True(t, res, "failed to flush network")
   141  
   142  	active = countActiveSignals(fmmImpl)
   143  	assert.Zero(t, active, "after flush the active signal still present")
   144  }
   145  
   146  func TestFastModularNetworkSolver_NodeCount(t *testing.T) {
   147  	net := buildModularNetwork()
   148  
   149  	fmm, err := net.FastNetworkSolver()
   150  	require.NoError(t, err, "failed to create fast network solver")
   151  	assert.Equal(t, 9, fmm.NodeCount())
   152  }
   153  
   154  func TestFastModularNetworkSolver_LinkCount(t *testing.T) {
   155  	net := buildModularNetwork()
   156  
   157  	fmm, err := net.FastNetworkSolver()
   158  	require.NoError(t, err, "failed to create fast network solver")
   159  	assert.Equal(t, 9, fmm.LinkCount())
   160  }
   161  
   162  func countActiveSignals(impl *FastModularNetworkSolver) int {
   163  	active := 0
   164  	for i := impl.biasNeuronCount; i < impl.totalNeuronCount; i++ {
   165  		if impl.neuronSignals[i] != 0.0 {
   166  			active++
   167  		}
   168  	}
   169  	return active
   170  }