github.com/yaricom/goNEAT@v0.0.0-20210507221059-e2110b885482/experiments/xor/XOR.go (about) 1 // The XOR experiment serves to actually check that network topology actually evolves and everything works as expected. 2 // Because XOR is not linearly separable, a neural network requires hidden units to solve it. The two inputs must be 3 // combined at some hidden unit, as opposed to only at the out- put node, because there is no function over a linear 4 // combination of the inputs that can separate the inputs into the proper classes. These structural requirements make 5 // XOR suitable for testing NEAT’s ability to evolve structure. 6 package xor 7 8 import ( 9 "github.com/yaricom/goNEAT/neat" 10 "os" 11 "fmt" 12 "github.com/yaricom/goNEAT/neat/genetics" 13 "math" 14 "github.com/yaricom/goNEAT/experiments" 15 ) 16 17 // The fitness threshold value for successful solver 18 const fitness_threshold = 15.5 19 20 // XOR is very simple and does not make a very interesting scientific experiment; however, it is a good way to 21 // check whether your system works. 22 // Make sure recurrency is disabled for the XOR test. If NEAT is able to add recurrent connections, it may solve XOR by 23 // memorizing the order of the training set. (Which is why you may even want to randomize order to be most safe) All 24 // documented experiments with XOR are without recurrent connections. Interestingly, XOR can be solved by a recurrent 25 // network with no hidden nodes. 26 // 27 // This method performs evolution on XOR for specified number of generations and output results into outDirPath 28 // It also returns number of nodes, genes, and evaluations performed per each run (context.NumRuns) 29 type XORGenerationEvaluator struct { 30 // The output path to store execution results 31 OutputPath string 32 } 33 34 // This method evaluates one epoch for given population and prints results into output directory if any. 35 func (ex XORGenerationEvaluator) GenerationEvaluate(pop *genetics.Population, epoch *experiments.Generation, context *neat.NeatContext) (err error) { 36 // Evaluate each organism on a test 37 for _, org := range pop.Organisms { 38 res, err := ex.org_evaluate(org, context) 39 if err != nil { 40 return err 41 } 42 43 if res && (epoch.Best == nil || org.Fitness > epoch.Best.Fitness){ 44 epoch.Solved = true 45 epoch.WinnerNodes = len(org.Genotype.Nodes) 46 epoch.WinnerGenes = org.Genotype.Extrons() 47 epoch.WinnerEvals = context.PopSize * epoch.Id + org.Genotype.Id 48 epoch.Best = org 49 if (epoch.WinnerNodes == 5) { 50 // You could dump out optimal genomes here if desired 51 opt_path := fmt.Sprintf("%s/%s_%d-%d", experiments.OutDirForTrial(ex.OutputPath, epoch.TrialId), 52 "xor_optimal", org.Phenotype.NodeCount(), org.Phenotype.LinkCount()) 53 file, err := os.Create(opt_path) 54 if err != nil { 55 neat.ErrorLog(fmt.Sprintf("Failed to dump optimal genome, reason: %s\n", err)) 56 } else { 57 org.Genotype.Write(file) 58 neat.InfoLog(fmt.Sprintf("Dumped optimal genome to: %s\n", opt_path)) 59 } 60 } 61 } 62 } 63 64 // Fill statistics about current epoch 65 epoch.FillPopulationStatistics(pop) 66 67 // Only print to file every print_every generations 68 if epoch.Solved || epoch.Id % context.PrintEvery == 0 { 69 pop_path := fmt.Sprintf("%s/gen_%d", experiments.OutDirForTrial(ex.OutputPath, epoch.TrialId), epoch.Id) 70 file, err := os.Create(pop_path) 71 if err != nil { 72 neat.ErrorLog(fmt.Sprintf("Failed to dump population, reason: %s\n", err)) 73 } else { 74 pop.WriteBySpecies(file) 75 } 76 } 77 78 if epoch.Solved { 79 // print winner organism 80 for _, org := range pop.Organisms { 81 if org.IsWinner { 82 // Prints the winner organism to file! 83 org_path := fmt.Sprintf("%s/%s_%d-%d", experiments.OutDirForTrial(ex.OutputPath, epoch.TrialId), 84 "xor_winner", org.Phenotype.NodeCount(), org.Phenotype.LinkCount()) 85 file, err := os.Create(org_path) 86 if err != nil { 87 neat.ErrorLog(fmt.Sprintf("Failed to dump winner organism genome, reason: %s\n", err)) 88 } else { 89 org.Genotype.Write(file) 90 neat.InfoLog(fmt.Sprintf("Generation #%d winner dumped to: %s\n", epoch.Id, org_path)) 91 } 92 break 93 } 94 } 95 } 96 97 return err 98 } 99 100 // This methods evaluates provided organism 101 func (ex *XORGenerationEvaluator) org_evaluate(organism *genetics.Organism, context *neat.NeatContext) (bool, error) { 102 // The four possible input combinations to xor 103 // The first number is for biasing 104 in := [][]float64{ 105 {1.0, 0.0, 0.0}, 106 {1.0, 0.0, 1.0}, 107 {1.0, 1.0, 0.0}, 108 {1.0, 1.0, 1.0}} 109 110 net_depth, err := organism.Phenotype.MaxDepth() // The max depth of the network to be activated 111 if err != nil { 112 neat.WarnLog( 113 fmt.Sprintf("Failed to estimate maximal depth of the network with loop:\n%s\nUsing default dpeth: %d", 114 organism.Genotype, net_depth)) 115 } 116 neat.DebugLog(fmt.Sprintf("Network depth: %d for organism: %d\n", net_depth, organism.Genotype.Id)) 117 if net_depth == 0 { 118 neat.DebugLog(fmt.Sprintf("ALERT: Network depth is ZERO for Genome: %s", organism.Genotype)) 119 } 120 121 success := false // Check for successful activation 122 out := make([]float64, 4) // The four outputs 123 124 // Load and activate the network on each input 125 for count := 0; count < 4; count++ { 126 organism.Phenotype.LoadSensors(in[count]) 127 128 // Relax net and get output 129 success, err = organism.Phenotype.Activate() 130 if err != nil { 131 neat.ErrorLog("Failed to activate network") 132 return false, err 133 } 134 135 // use depth to ensure relaxation 136 for relax := 0; relax <= net_depth; relax++ { 137 success, err = organism.Phenotype.Activate() 138 if err != nil { 139 neat.ErrorLog("Failed to activate network") 140 return false, err 141 } 142 } 143 out[count] = organism.Phenotype.Outputs[0].Activation 144 145 organism.Phenotype.Flush() 146 } 147 148 if (success) { 149 // Mean Squared Error 150 error_sum := math.Abs(out[0]) + math.Abs(1.0 - out[1]) + math.Abs(1.0 - out[2]) + math.Abs(out[3]) // ideal == 0 151 target := 4.0 - error_sum // ideal == 4.0 152 organism.Fitness = math.Pow(4.0 - error_sum, 2.0) 153 organism.Error = math.Pow(4.0 - target, 2.0) 154 } else { 155 // The network is flawed (shouldn't happen) - flag as anomaly 156 organism.Error = 1.0 157 organism.Fitness = 0.0 158 } 159 160 if organism.Fitness > fitness_threshold { 161 organism.IsWinner = true 162 neat.InfoLog(fmt.Sprintf(">>>> Output activations: %e\n", out)) 163 164 } else { 165 organism.IsWinner = false 166 } 167 return organism.IsWinner, nil 168 }