github.com/yaricom/goNEAT@v0.0.0-20210507221059-e2110b885482/experiments/xor/XOR.go (about)

// The XOR experiment serves to check that the network topology actually evolves and that everything works as expected.
// Because XOR is not linearly separable, a neural network requires hidden units to solve it. The two inputs must be
// combined at some hidden unit, as opposed to only at the output node, because there is no function over a linear
// combination of the inputs that can separate the inputs into the proper classes. These structural requirements make
// XOR suitable for testing NEAT's ability to evolve structure.
     6  package xor
     7  
     8  import (
     9  	"github.com/yaricom/goNEAT/neat"
    10  	"os"
    11  	"fmt"
    12  	"github.com/yaricom/goNEAT/neat/genetics"
    13  	"math"
    14  	"github.com/yaricom/goNEAT/experiments"
    15  )
    16  
    17  // The fitness threshold value for successful solver
    18  const fitness_threshold = 15.5
    19  
    20  // XOR is very simple and does not make a very interesting scientific experiment; however, it is a good way to
    21  // check whether your system works.
    22  // Make sure recurrency is disabled for the XOR test. If NEAT is able to add recurrent connections, it may solve XOR by
    23  // memorizing the order of the training set. (Which is why you may even want to randomize order to be most safe) All
    24  // documented experiments with XOR are without recurrent connections. Interestingly, XOR can be solved by a recurrent
    25  // network with no hidden nodes.
    26  //
    27  // This method performs evolution on XOR for specified number of generations and output results into outDirPath
    28  // It also returns number of nodes, genes, and evaluations performed per each run (context.NumRuns)
    29  type XORGenerationEvaluator struct {
    30  	// The output path to store execution results
    31  	OutputPath string
    32  }
    33  
    34  // This method evaluates one epoch for given population and prints results into output directory if any.
    35  func (ex XORGenerationEvaluator) GenerationEvaluate(pop *genetics.Population, epoch *experiments.Generation, context *neat.NeatContext) (err error) {
    36  	// Evaluate each organism on a test
    37  	for _, org := range pop.Organisms {
    38  		res, err := ex.org_evaluate(org, context)
    39  		if err != nil {
    40  			return err
    41  		}
    42  
    43  		if res && (epoch.Best == nil || org.Fitness > epoch.Best.Fitness){
    44  			epoch.Solved = true
    45  			epoch.WinnerNodes = len(org.Genotype.Nodes)
    46  			epoch.WinnerGenes = org.Genotype.Extrons()
    47  			epoch.WinnerEvals = context.PopSize * epoch.Id + org.Genotype.Id
    48  			epoch.Best = org
    49  			if (epoch.WinnerNodes == 5) {
    50  				// You could dump out optimal genomes here if desired
    51  				opt_path := fmt.Sprintf("%s/%s_%d-%d", experiments.OutDirForTrial(ex.OutputPath, epoch.TrialId),
    52  					"xor_optimal", org.Phenotype.NodeCount(), org.Phenotype.LinkCount())
    53  				file, err := os.Create(opt_path)
    54  				if err != nil {
    55  					neat.ErrorLog(fmt.Sprintf("Failed to dump optimal genome, reason: %s\n", err))
    56  				} else {
    57  					org.Genotype.Write(file)
    58  					neat.InfoLog(fmt.Sprintf("Dumped optimal genome to: %s\n", opt_path))
    59  				}
    60  			}
    61  		}
    62  	}
    63  
    64  	// Fill statistics about current epoch
    65  	epoch.FillPopulationStatistics(pop)
    66  
    67  	// Only print to file every print_every generations
    68  	if epoch.Solved || epoch.Id % context.PrintEvery == 0 {
    69  		pop_path := fmt.Sprintf("%s/gen_%d", experiments.OutDirForTrial(ex.OutputPath, epoch.TrialId), epoch.Id)
    70  		file, err := os.Create(pop_path)
    71  		if err != nil {
    72  			neat.ErrorLog(fmt.Sprintf("Failed to dump population, reason: %s\n", err))
    73  		} else {
    74  			pop.WriteBySpecies(file)
    75  		}
    76  	}
    77  
    78  	if epoch.Solved {
    79  		// print winner organism
    80  		for _, org := range pop.Organisms {
    81  			if org.IsWinner {
    82  				// Prints the winner organism to file!
    83  				org_path := fmt.Sprintf("%s/%s_%d-%d", experiments.OutDirForTrial(ex.OutputPath, epoch.TrialId),
    84  					"xor_winner", org.Phenotype.NodeCount(), org.Phenotype.LinkCount())
    85  				file, err := os.Create(org_path)
    86  				if err != nil {
    87  					neat.ErrorLog(fmt.Sprintf("Failed to dump winner organism genome, reason: %s\n", err))
    88  				} else {
    89  					org.Genotype.Write(file)
    90  					neat.InfoLog(fmt.Sprintf("Generation #%d winner dumped to: %s\n", epoch.Id, org_path))
    91  				}
    92  				break
    93  			}
    94  		}
    95  	}
    96  
    97  	return err
    98  }
    99  
   100  // This methods evaluates provided organism
   101  func (ex *XORGenerationEvaluator) org_evaluate(organism *genetics.Organism, context *neat.NeatContext) (bool, error) {
   102  	// The four possible input combinations to xor
   103  	// The first number is for biasing
   104  	in := [][]float64{
   105  		{1.0, 0.0, 0.0},
   106  		{1.0, 0.0, 1.0},
   107  		{1.0, 1.0, 0.0},
   108  		{1.0, 1.0, 1.0}}
   109  
   110  	net_depth, err := organism.Phenotype.MaxDepth() // The max depth of the network to be activated
   111  	if err != nil {
   112  		neat.WarnLog(
   113  			fmt.Sprintf("Failed to estimate maximal depth of the network with loop:\n%s\nUsing default dpeth: %d",
   114  				organism.Genotype, net_depth))
   115  	}
   116  	neat.DebugLog(fmt.Sprintf("Network depth: %d for organism: %d\n", net_depth, organism.Genotype.Id))
   117  	if net_depth == 0 {
   118  		neat.DebugLog(fmt.Sprintf("ALERT: Network depth is ZERO for Genome: %s", organism.Genotype))
   119  	}
   120  
   121  	success := false  // Check for successful activation
   122  	out := make([]float64, 4) // The four outputs
   123  
   124  	// Load and activate the network on each input
   125  	for count := 0; count < 4; count++ {
   126  		organism.Phenotype.LoadSensors(in[count])
   127  
   128  		// Relax net and get output
   129  		success, err = organism.Phenotype.Activate()
   130  		if err != nil {
   131  			neat.ErrorLog("Failed to activate network")
   132  			return false, err
   133  		}
   134  
   135  		// use depth to ensure relaxation
   136  		for relax := 0; relax <= net_depth; relax++ {
   137  			success, err = organism.Phenotype.Activate()
   138  			if err != nil {
   139  				neat.ErrorLog("Failed to activate network")
   140  				return false, err
   141  			}
   142  		}
   143  		out[count] = organism.Phenotype.Outputs[0].Activation
   144  
   145  		organism.Phenotype.Flush()
   146  	}
   147  
   148  	if (success) {
   149  		// Mean Squared Error
   150  		error_sum := math.Abs(out[0]) + math.Abs(1.0 - out[1]) + math.Abs(1.0 - out[2]) + math.Abs(out[3]) // ideal == 0
   151  		target := 4.0 - error_sum // ideal == 4.0
   152  		organism.Fitness = math.Pow(4.0 - error_sum, 2.0)
   153  		organism.Error = math.Pow(4.0 - target, 2.0)
   154  	} else {
   155  		// The network is flawed (shouldn't happen) - flag as anomaly
   156  		organism.Error = 1.0
   157  		organism.Fitness = 0.0
   158  	}
   159  
   160  	if organism.Fitness > fitness_threshold {
   161  		organism.IsWinner = true
   162  		neat.InfoLog(fmt.Sprintf(">>>> Output activations: %e\n", out))
   163  
   164  	} else {
   165  		organism.IsWinner = false
   166  	}
   167  	return organism.IsWinner, nil
   168  }