github.com/yaricom/goNEAT@v0.0.0-20210507221059-e2110b885482/experiments/pole/cart2pole_test.go (about)

     1  package pole
     2  
     3  import (
     4  	"testing"
     5  	"os"
     6  	"github.com/yaricom/goNEAT/neat"
     7  	"fmt"
     8  	"github.com/yaricom/goNEAT/neat/genetics"
     9  	"github.com/yaricom/goNEAT/experiments"
    10  	"math/rand"
    11  )
    12  
    13  // Run double pole-balancing experiment with Markov environment setup
    14  func TestCartDoublePoleGenerationEvaluator_GenerationEvaluateMarkov(t *testing.T) {
    15  	// to make sure we have predictable results
    16  	rand.Seed(3423)
    17  
    18  	out_dir_path, context_path, genome_path := "../../out/pole2_markov_test", "../../data/pole2_markov.neat", "../../data/pole2_markov_startgenes"
    19  
    20  	// Load context configuration
    21  	configFile, err := os.Open(context_path)
    22  	if err != nil {
    23  		t.Error("Failed to load context", err)
    24  		return
    25  	}
    26  	context := neat.LoadContext(configFile)
    27  	neat.LogLevel = neat.LogLevelInfo
    28  
    29  	// Load Genome
    30  	fmt.Println("Loading start genome for POLE2 Markov experiment")
    31  	genomeFile, err := os.Open(genome_path)
    32  	if err != nil {
    33  		t.Error("Failed to open genome file")
    34  		return
    35  	}
    36  	start_genome, err := genetics.ReadGenome(genomeFile, 1)
    37  	if err != nil {
    38  		t.Error("Failed to read start genome")
    39  		return
    40  	}
    41  
    42  	// Check if output dir exists
    43  	if _, err := os.Stat(out_dir_path); err == nil {
    44  		// clear it
    45  		os.RemoveAll(out_dir_path)
    46  	}
    47  	// create output dir
    48  	err = os.MkdirAll(out_dir_path, os.ModePerm)
    49  	if err != nil {
    50  		t.Errorf("Failed to create output directory, reason: %s", err)
    51  		return
    52  	}
    53  
    54  	// The 10 runs POLE2 Markov experiment
    55  	context.NumRuns = 5
    56  	experiment := experiments.Experiment{
    57  		Id:0,
    58  		Trials:make(experiments.Trials, context.NumRuns),
    59  	}
    60  	err = experiment.Execute(context, start_genome, CartDoublePoleGenerationEvaluator{
    61  		OutputPath:out_dir_path,
    62  		Markov:true,
    63  		ActionType:experiments.ContinuousAction,
    64  	})
    65  	if err != nil {
    66  		t.Error("Failed to perform POLE2 Markov experiment:", err)
    67  		return
    68  	}
    69  
    70  	// Find winner statistics
    71  	avg_nodes, avg_genes, avg_evals, _ := experiment.AvgWinner()
    72  
    73  	// check results
    74  	if avg_nodes < 8 {
    75  		t.Error("avg_nodes < 8", avg_nodes)
    76  	} else if avg_nodes > 40 {
    77  		t.Error("avg_nodes > 40", avg_nodes)
    78  	}
    79  
    80  	if avg_genes < 7 {
    81  		t.Error("avg_genes < 7", avg_genes)
    82  	} else if avg_genes > 50 {
    83  		t.Error("avg_genes > 50", avg_genes)
    84  	}
    85  
    86  	max_evals := float64(context.PopSize * context.NumGenerations)
    87  	if avg_evals > max_evals {
    88  		t.Error("avg_evals > max_evals", avg_evals, max_evals)
    89  	}
    90  
    91  	t.Logf("Average nodes: %.1f, genes: %.1f, evals: %.1f\n", avg_nodes, avg_genes, avg_evals)
    92  	mean_complexity, mean_diversity, mean_age := 0.0, 0.0, 0.0
    93  	for _, t := range experiment.Trials {
    94  		mean_complexity += t.BestComplexity().Mean()
    95  		mean_diversity += t.Diversity().Mean()
    96  		mean_age += t.BestAge().Mean()
    97  	}
    98  	count := float64(len(experiment.Trials))
    99  	mean_complexity /= count
   100  	mean_diversity /= count
   101  	mean_age /= count
   102  	t.Logf("Mean best organisms: complexity=%.1f, diversity=%.1f, age=%.1f\n", mean_complexity, mean_diversity, mean_age)
   103  
   104  	solved_trials := 0
   105  	for _, tr := range experiment.Trials {
   106  		if tr.Solved() {
   107  			solved_trials++
   108  		}
   109  	}
   110  
   111  	t.Logf("Trials solved/run: %d/%d", solved_trials, len(experiment.Trials))
   112  
   113  	if solved_trials == 0 {
   114  		t.Error("Failed to solve at least one trial. Need to be checked what was going wrong")
   115  	}
   116  }
   117  
   118  // Run double pole-balancing experiment with Non-Markov environment setup
   119  func TestCartDoublePoleGenerationEvaluator_GenerationEvaluateNonMarkov(t *testing.T) {
   120  	// to make sure we have predictable results
   121  	rand.Seed(423)
   122  
   123  	out_dir_path, context_path, genome_path := "../../out/pole2_non-markov_test", "../../data/pole2_non-markov.neat", "../../data/pole2_non-markov_startgenes"
   124  
   125  	// Load context configuration
   126  	configFile, err := os.Open(context_path)
   127  	if err != nil {
   128  		t.Error("Failed to load context", err)
   129  		return
   130  	}
   131  	context := neat.LoadContext(configFile)
   132  	neat.LogLevel = neat.LogLevelInfo
   133  
   134  	// Load Genome
   135  	fmt.Println("Loading start genome for POLE2 Non-Markov experiment")
   136  	genomeFile, err := os.Open(genome_path)
   137  	if err != nil {
   138  		t.Error("Failed to open genome file")
   139  		return
   140  	}
   141  	start_genome, err := genetics.ReadGenome(genomeFile, 1)
   142  	if err != nil {
   143  		t.Error("Failed to read start genome")
   144  		return
   145  	}
   146  
   147  	// Check if output dir exists
   148  	if _, err := os.Stat(out_dir_path); err == nil {
   149  		// clear it
   150  		os.RemoveAll(out_dir_path)
   151  	}
   152  	// create output dir
   153  	err = os.MkdirAll(out_dir_path, os.ModePerm)
   154  	if err != nil {
   155  		t.Errorf("Failed to create output directory, reason: %s", err)
   156  		return
   157  	}
   158  
   159  	// The 10 runs POLE2 Non-Markov experiment
   160  	context.NumRuns = 5
   161  	experiment := experiments.Experiment{
   162  		Id:0,
   163  		Trials:make(experiments.Trials, context.NumRuns),
   164  	}
   165  	err = experiment.Execute(context, start_genome, CartDoublePoleGenerationEvaluator{
   166  		OutputPath:out_dir_path,
   167  		Markov:false,
   168  		ActionType:experiments.ContinuousAction,
   169  	})
   170  	if err != nil {
   171  		t.Error("Failed to perform POLE2 Non-Markov experiment:", err)
   172  		return
   173  	}
   174  
   175  	// Find winner statistics
   176  	avg_nodes, avg_genes, avg_evals, _ := experiment.AvgWinner()
   177  
   178  	// check results
   179  	if avg_nodes < 5 {
   180  		t.Error("avg_nodes < 5", avg_nodes)
   181  	} else if avg_nodes > 40 {
   182  		t.Error("avg_nodes > 40", avg_nodes)
   183  	}
   184  
   185  	if avg_genes < 5 {
   186  		t.Error("avg_genes < 5", avg_genes)
   187  	} else if avg_genes > 50 {
   188  		t.Error("avg_genes > 50", avg_genes)
   189  	}
   190  
   191  	max_evals := float64(context.PopSize * context.NumGenerations)
   192  	if avg_evals > max_evals {
   193  		t.Error("avg_evals > max_evals", avg_evals, max_evals)
   194  	}
   195  
   196  	t.Logf("Average nodes: %.1f, genes: %.1f, evals: %.1f\n", avg_nodes, avg_genes, avg_evals)
   197  	mean_complexity, mean_diversity, mean_age := 0.0, 0.0, 0.0
   198  	for _, t := range experiment.Trials {
   199  		mean_complexity += t.BestComplexity().Mean()
   200  		mean_diversity += t.Diversity().Mean()
   201  		mean_age += t.BestAge().Mean()
   202  	}
   203  	count := float64(len(experiment.Trials))
   204  	mean_complexity /= count
   205  	mean_diversity /= count
   206  	mean_age /= count
   207  	t.Logf("Mean best organisms: complexity=%.1f, diversity=%.1f, age=%.1f\n", mean_complexity, mean_diversity, mean_age)
   208  
   209  	solved_trials := 0
   210  	for _, tr := range experiment.Trials {
   211  		if tr.Solved() {
   212  			solved_trials++
   213  		}
   214  	}
   215  	t.Logf("Trials solved/run: %d/%d\n", solved_trials, len(experiment.Trials))
   216  
   217  	if solved_trials == 0 {
   218  		t.Error("Failed to solve at least one trial. Need to be checked what was going wrong")
   219  	}
   220  
   221  	best_g_score := 0.0
   222  	for _, tr := range experiment.Trials {
   223  		if org, found := tr.BestOrganism(true); found {
   224  			best_org_score := org.Fitness
   225  			if best_org_score > best_g_score {
   226  				best_g_score = best_org_score
   227  			}
   228  		}
   229  	}
   230  	t.Logf("Best Generalization Score: %.0f\n", best_g_score)
   231  }