github.com/yaricom/goNEAT@v0.0.0-20210507221059-e2110b885482/experiments/pole/cart2pole_test.go (about) 1 package pole 2 3 import ( 4 "testing" 5 "os" 6 "github.com/yaricom/goNEAT/neat" 7 "fmt" 8 "github.com/yaricom/goNEAT/neat/genetics" 9 "github.com/yaricom/goNEAT/experiments" 10 "math/rand" 11 ) 12 13 // Run double pole-balancing experiment with Markov environment setup 14 func TestCartDoublePoleGenerationEvaluator_GenerationEvaluateMarkov(t *testing.T) { 15 // to make sure we have predictable results 16 rand.Seed(3423) 17 18 out_dir_path, context_path, genome_path := "../../out/pole2_markov_test", "../../data/pole2_markov.neat", "../../data/pole2_markov_startgenes" 19 20 // Load context configuration 21 configFile, err := os.Open(context_path) 22 if err != nil { 23 t.Error("Failed to load context", err) 24 return 25 } 26 context := neat.LoadContext(configFile) 27 neat.LogLevel = neat.LogLevelInfo 28 29 // Load Genome 30 fmt.Println("Loading start genome for POLE2 Markov experiment") 31 genomeFile, err := os.Open(genome_path) 32 if err != nil { 33 t.Error("Failed to open genome file") 34 return 35 } 36 start_genome, err := genetics.ReadGenome(genomeFile, 1) 37 if err != nil { 38 t.Error("Failed to read start genome") 39 return 40 } 41 42 // Check if output dir exists 43 if _, err := os.Stat(out_dir_path); err == nil { 44 // clear it 45 os.RemoveAll(out_dir_path) 46 } 47 // create output dir 48 err = os.MkdirAll(out_dir_path, os.ModePerm) 49 if err != nil { 50 t.Errorf("Failed to create output directory, reason: %s", err) 51 return 52 } 53 54 // The 10 runs POLE2 Markov experiment 55 context.NumRuns = 5 56 experiment := experiments.Experiment{ 57 Id:0, 58 Trials:make(experiments.Trials, context.NumRuns), 59 } 60 err = experiment.Execute(context, start_genome, CartDoublePoleGenerationEvaluator{ 61 OutputPath:out_dir_path, 62 Markov:true, 63 ActionType:experiments.ContinuousAction, 64 }) 65 if err != nil { 66 t.Error("Failed to perform POLE2 Markov experiment:", err) 67 return 68 } 69 70 // Find winner statistics 71 avg_nodes, avg_genes, avg_evals, _ := experiment.AvgWinner() 72 73 // check results 74 if avg_nodes < 8 { 75 t.Error("avg_nodes < 8", avg_nodes) 76 } else if avg_nodes > 40 { 77 t.Error("avg_nodes > 40", avg_nodes) 78 } 79 80 if avg_genes < 7 { 81 t.Error("avg_genes < 7", avg_genes) 82 } else if avg_genes > 50 { 83 t.Error("avg_genes > 50", avg_genes) 84 } 85 86 max_evals := float64(context.PopSize * context.NumGenerations) 87 if avg_evals > max_evals { 88 t.Error("avg_evals > max_evals", avg_evals, max_evals) 89 } 90 91 t.Logf("Average nodes: %.1f, genes: %.1f, evals: %.1f\n", avg_nodes, avg_genes, avg_evals) 92 mean_complexity, mean_diversity, mean_age := 0.0, 0.0, 0.0 93 for _, t := range experiment.Trials { 94 mean_complexity += t.BestComplexity().Mean() 95 mean_diversity += t.Diversity().Mean() 96 mean_age += t.BestAge().Mean() 97 } 98 count := float64(len(experiment.Trials)) 99 mean_complexity /= count 100 mean_diversity /= count 101 mean_age /= count 102 t.Logf("Mean best organisms: complexity=%.1f, diversity=%.1f, age=%.1f\n", mean_complexity, mean_diversity, mean_age) 103 104 solved_trials := 0 105 for _, tr := range experiment.Trials { 106 if tr.Solved() { 107 solved_trials++ 108 } 109 } 110 111 t.Logf("Trials solved/run: %d/%d", solved_trials, len(experiment.Trials)) 112 113 if solved_trials == 0 { 114 t.Error("Failed to solve at least one trial. Need to be checked what was going wrong") 115 } 116 } 117 118 // Run double pole-balancing experiment with Non-Markov environment setup 119 func TestCartDoublePoleGenerationEvaluator_GenerationEvaluateNonMarkov(t *testing.T) { 120 // to make sure we have predictable results 121 rand.Seed(423) 122 123 out_dir_path, context_path, genome_path := "../../out/pole2_non-markov_test", "../../data/pole2_non-markov.neat", "../../data/pole2_non-markov_startgenes" 124 125 // Load context configuration 126 configFile, err := os.Open(context_path) 127 if err != nil { 128 t.Error("Failed to load context", err) 129 return 130 } 131 context := neat.LoadContext(configFile) 132 neat.LogLevel = neat.LogLevelInfo 133 134 // Load Genome 135 fmt.Println("Loading start genome for POLE2 Non-Markov experiment") 136 genomeFile, err := os.Open(genome_path) 137 if err != nil { 138 t.Error("Failed to open genome file") 139 return 140 } 141 start_genome, err := genetics.ReadGenome(genomeFile, 1) 142 if err != nil { 143 t.Error("Failed to read start genome") 144 return 145 } 146 147 // Check if output dir exists 148 if _, err := os.Stat(out_dir_path); err == nil { 149 // clear it 150 os.RemoveAll(out_dir_path) 151 } 152 // create output dir 153 err = os.MkdirAll(out_dir_path, os.ModePerm) 154 if err != nil { 155 t.Errorf("Failed to create output directory, reason: %s", err) 156 return 157 } 158 159 // The 10 runs POLE2 Non-Markov experiment 160 context.NumRuns = 5 161 experiment := experiments.Experiment{ 162 Id:0, 163 Trials:make(experiments.Trials, context.NumRuns), 164 } 165 err = experiment.Execute(context, start_genome, CartDoublePoleGenerationEvaluator{ 166 OutputPath:out_dir_path, 167 Markov:false, 168 ActionType:experiments.ContinuousAction, 169 }) 170 if err != nil { 171 t.Error("Failed to perform POLE2 Non-Markov experiment:", err) 172 return 173 } 174 175 // Find winner statistics 176 avg_nodes, avg_genes, avg_evals, _ := experiment.AvgWinner() 177 178 // check results 179 if avg_nodes < 5 { 180 t.Error("avg_nodes < 5", avg_nodes) 181 } else if avg_nodes > 40 { 182 t.Error("avg_nodes > 40", avg_nodes) 183 } 184 185 if avg_genes < 5 { 186 t.Error("avg_genes < 5", avg_genes) 187 } else if avg_genes > 50 { 188 t.Error("avg_genes > 50", avg_genes) 189 } 190 191 max_evals := float64(context.PopSize * context.NumGenerations) 192 if avg_evals > max_evals { 193 t.Error("avg_evals > max_evals", avg_evals, max_evals) 194 } 195 196 t.Logf("Average nodes: %.1f, genes: %.1f, evals: %.1f\n", avg_nodes, avg_genes, avg_evals) 197 mean_complexity, mean_diversity, mean_age := 0.0, 0.0, 0.0 198 for _, t := range experiment.Trials { 199 mean_complexity += t.BestComplexity().Mean() 200 mean_diversity += t.Diversity().Mean() 201 mean_age += t.BestAge().Mean() 202 } 203 count := float64(len(experiment.Trials)) 204 mean_complexity /= count 205 mean_diversity /= count 206 mean_age /= count 207 t.Logf("Mean best organisms: complexity=%.1f, diversity=%.1f, age=%.1f\n", mean_complexity, mean_diversity, mean_age) 208 209 solved_trials := 0 210 for _, tr := range experiment.Trials { 211 if tr.Solved() { 212 solved_trials++ 213 } 214 } 215 t.Logf("Trials solved/run: %d/%d\n", solved_trials, len(experiment.Trials)) 216 217 if solved_trials == 0 { 218 t.Error("Failed to solve at least one trial. Need to be checked what was going wrong") 219 } 220 221 best_g_score := 0.0 222 for _, tr := range experiment.Trials { 223 if org, found := tr.BestOrganism(true); found { 224 best_org_score := org.Fitness 225 if best_org_score > best_g_score { 226 best_g_score = best_org_score 227 } 228 } 229 } 230 t.Logf("Best Generalization Score: %.0f\n", best_g_score) 231 }