github.com/yaricom/goNEAT@v0.0.0-20210507221059-e2110b885482/experiments/pole/cartpole.go (about) 1 package pole 2 3 import ( 4 "github.com/yaricom/goNEAT/neat" 5 "github.com/yaricom/goNEAT/neat/genetics" 6 "github.com/yaricom/goNEAT/experiments" 7 "math" 8 "github.com/yaricom/goNEAT/neat/network" 9 "math/rand" 10 "fmt" 11 "os" 12 ) 13 14 const twelve_degrees = 12.0 * math.Pi / 180.0 15 16 // The single pole balancing experiment entry point. 17 // This experiment performs evolution on single pole balancing task in order to produce appropriate genome. 18 type CartPoleGenerationEvaluator struct { 19 // The output path to store execution results 20 OutputPath string 21 // The flag to indicate if cart emulator should be started from random position 22 RandomStart bool 23 // The number of emulation steps to be done balancing pole to win 24 WinBalancingSteps int 25 } 26 27 // This method evaluates one epoch for given population and prints results into output directory if any. 28 func (ex CartPoleGenerationEvaluator) GenerationEvaluate(pop *genetics.Population, epoch *experiments.Generation, context *neat.NeatContext) (err error) { 29 // Evaluate each organism on a test 30 for _, org := range pop.Organisms { 31 res := ex.orgEvaluate(org) 32 33 if res && (epoch.Best == nil || org.Fitness > epoch.Best.Fitness){ 34 epoch.Solved = true 35 epoch.WinnerNodes = len(org.Genotype.Nodes) 36 epoch.WinnerGenes = org.Genotype.Extrons() 37 epoch.WinnerEvals = context.PopSize * epoch.Id + org.Genotype.Id 38 epoch.Best = org 39 if (epoch.WinnerNodes == 7) { 40 // You could dump out optimal genomes here if desired 41 opt_path := fmt.Sprintf("%s/%s_%d-%d", experiments.OutDirForTrial(ex.OutputPath, epoch.TrialId), 42 "pole1_optimal", org.Phenotype.NodeCount(), org.Phenotype.LinkCount()) 43 file, err := os.Create(opt_path) 44 if err != nil { 45 neat.ErrorLog(fmt.Sprintf("Failed to dump optimal genome, reason: %s\n", err)) 46 } else { 47 org.Genotype.Write(file) 48 neat.InfoLog(fmt.Sprintf("Dumped optimal genome to: %s\n", opt_path)) 49 } 50 } 51 } 52 } 53 54 // Fill statistics about current epoch 55 epoch.FillPopulationStatistics(pop) 56 57 // Only print to file every print_every generations 58 if epoch.Solved || epoch.Id % context.PrintEvery == 0 { 59 pop_path := fmt.Sprintf("%s/gen_%d", experiments.OutDirForTrial(ex.OutputPath, epoch.TrialId), epoch.Id) 60 file, err := os.Create(pop_path) 61 if err != nil { 62 neat.ErrorLog(fmt.Sprintf("Failed to dump population, reason: %s\n", err)) 63 } else { 64 pop.WriteBySpecies(file) 65 } 66 } 67 68 if epoch.Solved { 69 // print winner organism 70 for _, org := range pop.Organisms { 71 if org.IsWinner { 72 // Prints the winner organism to file! 73 org_path := fmt.Sprintf("%s/%s_%d-%d", experiments.OutDirForTrial(ex.OutputPath, epoch.TrialId), 74 "pole1_winner", org.Phenotype.NodeCount(), org.Phenotype.LinkCount()) 75 file, err := os.Create(org_path) 76 if err != nil { 77 neat.ErrorLog(fmt.Sprintf("Failed to dump winner organism genome, reason: %s\n", err)) 78 } else { 79 org.Genotype.Write(file) 80 neat.InfoLog(fmt.Sprintf("Generation #%d winner dumped to: %s\n", epoch.Id, org_path)) 81 } 82 break 83 } 84 } 85 } 86 87 return err 88 } 89 90 // This methods evaluates provided organism for cart pole balancing task 91 func (ex *CartPoleGenerationEvaluator) orgEvaluate(organism *genetics.Organism) bool { 92 // Try to balance a pole now 93 organism.Fitness = float64(ex.runCart(organism.Phenotype)) 94 95 if neat.LogLevel == neat.LogLevelDebug { 96 neat.DebugLog(fmt.Sprintf("Organism #%3d\tfitness: %f", organism.Genotype.Id, organism.Fitness)) 97 } 98 99 // Decide if its a winner 100 if organism.Fitness >= float64(ex.WinBalancingSteps) { 101 organism.IsWinner = true 102 } 103 104 // adjust fitness to be in range [0;1] 105 if organism.IsWinner { 106 organism.Fitness = 1.0 107 organism.Error = 0.0 108 } else if organism.Fitness == 0 { 109 organism.Error = 1.0 110 } else { 111 // we use logarithmic scale because most cart runs fail to early within ~100 steps, but 112 // we test against 500'000 balancing steps 113 logSteps := math.Log(float64(ex.WinBalancingSteps)) 114 organism.Error = (logSteps - math.Log(organism.Fitness)) / logSteps 115 organism.Fitness = 1.0 - organism.Error 116 } 117 118 return organism.IsWinner 119 } 120 121 // run cart emulation and return number of emulation steps pole was balanced 122 func (ex *CartPoleGenerationEvaluator) runCart(net *network.Network) (steps int) { 123 var x float64 /* cart position, meters */ 124 var x_dot float64 /* cart velocity */ 125 var theta float64 /* pole angle, radians */ 126 var theta_dot float64 /* pole angular velocity */ 127 if ex.RandomStart { 128 /*set up random start state*/ 129 x = float64(rand.Int31() % 4800) / 1000.0 - 2.4 130 x_dot = float64(rand.Int31() % 2000) / 1000.0 - 1 131 theta = float64(rand.Int31() % 400) / 1000.0 - .2 132 theta_dot = float64(rand.Int31() % 3000) / 1000.0 - 1.5 133 } 134 135 in := make([]float64, 5) 136 for steps = 0; steps < ex.WinBalancingSteps; steps++ { 137 /*-- setup the input layer based on the four inputs --*/ 138 in[0] = 1.0 // Bias 139 in[1] = (x + 2.4) / 4.8 140 in[2] = (x_dot + .75) / 1.5 141 in[3] = (theta + twelve_degrees) / .41 142 in[4] = (theta_dot + 1.0) / 2.0 143 net.LoadSensors(in) 144 145 /*-- activate the network based on the input --*/ 146 if res, err := net.Activate(); !res { 147 //If it loops, exit returning only fitness of 1 step 148 neat.DebugLog(fmt.Sprintf("Failed to activate Network, reason: %s", err)) 149 return 1 150 } 151 /*-- decide which way to push via which output unit is greater --*/ 152 action := 1 153 if net.Outputs[0].Activation > net.Outputs[1].Activation { 154 action = 0 155 } 156 /*--- Apply action to the simulated cart-pole ---*/ 157 x, x_dot, theta, theta_dot = ex.doAction(action, x, x_dot, theta, theta_dot) 158 159 /*--- Check for failure. If so, return steps ---*/ 160 if (x < -2.4 || x > 2.4 || theta < -twelve_degrees || theta > twelve_degrees) { 161 return steps 162 } 163 } 164 return steps 165 } 166 167 // cart_and_pole() was take directly from the pole simulator written by Richard Sutton and Charles Anderson. 168 // This simulator uses normalized, continuous inputs instead of discretizing the input space. 169 /*---------------------------------------------------------------------- 170 cart_pole: Takes an action (0 or 1) and the current values of the 171 four state variables and updates their values by estimating the state 172 TAU seconds later. 173 ----------------------------------------------------------------------*/ 174 func (ex *CartPoleGenerationEvaluator) doAction(action int, x, x_dot, theta, theta_dot float64) (x_ret, x_dot_ret, theta_ret, theta_dot_ret float64) { 175 // The cart pole configuration values 176 const GRAVITY = 9.8 177 const MASSCART = 1.0 178 const MASSPOLE = 0.5 179 const TOTAL_MASS = (MASSPOLE + MASSCART) 180 const LENGTH = 0.5 /* actually half the pole's length */ 181 const POLEMASS_LENGTH = (MASSPOLE * LENGTH) 182 const FORCE_MAG = 10.0 183 const TAU = 0.02 /* seconds between state updates */ 184 const FOURTHIRDS = 1.3333333333333 185 186 force := -FORCE_MAG 187 if action > 0 { 188 force = FORCE_MAG 189 } 190 cos_theta := math.Cos(theta) 191 sin_theta := math.Sin(theta) 192 193 temp := (force + POLEMASS_LENGTH * theta_dot * theta_dot * sin_theta) / TOTAL_MASS 194 195 theta_acc := (GRAVITY * sin_theta - cos_theta * temp) / (LENGTH * (FOURTHIRDS - MASSPOLE * cos_theta * cos_theta / TOTAL_MASS)) 196 197 x_acc := temp - POLEMASS_LENGTH * theta_acc * cos_theta / TOTAL_MASS 198 199 /*** Update the four state variables, using Euler's method. ***/ 200 x_ret = x + TAU * x_dot 201 x_dot_ret = x_dot + TAU * x_acc 202 theta_ret = theta + TAU * theta_dot 203 theta_dot_ret = theta_dot + TAU * theta_acc 204 205 return x_ret, x_dot_ret, theta_ret, theta_dot_ret 206 } 207