github.com/yaricom/goNEAT@v0.0.0-20210507221059-e2110b885482/experiments/pole/cartpole.go (about)

     1  package pole
     2  
     3  import (
     4  	"github.com/yaricom/goNEAT/neat"
     5  	"github.com/yaricom/goNEAT/neat/genetics"
     6  	"github.com/yaricom/goNEAT/experiments"
     7  	"math"
     8  	"github.com/yaricom/goNEAT/neat/network"
     9  	"math/rand"
    10  	"fmt"
    11  	"os"
    12  )
    13  
// twelve_degrees is an angle of 12 degrees expressed in radians. The cart emulation treats the
// pole as fallen once its angle leaves the range [-twelve_degrees; twelve_degrees], and the same
// value is used to shift the pole angle when it is fed into the network sensors.
const twelve_degrees = 12.0 * math.Pi / 180.0
    15  
// CartPoleGenerationEvaluator is the generation evaluator of the single pole balancing experiment.
// Its methods run every organism of a population through the cart-pole emulation in order to find
// a genome whose network is able to keep the pole balanced.
type CartPoleGenerationEvaluator struct {
	// The output path to store execution results
	OutputPath        string
	// The flag to indicate if cart emulator should be started from random position
	RandomStart       bool
	// The number of emulation steps to be done balancing pole to win. It is also the upper
	// bound on the number of steps of a single emulation run.
	WinBalancingSteps int
}
    26  
    27  // This method evaluates one epoch for given population and prints results into output directory if any.
    28  func (ex CartPoleGenerationEvaluator) GenerationEvaluate(pop *genetics.Population, epoch *experiments.Generation, context *neat.NeatContext) (err error) {
    29  	// Evaluate each organism on a test
    30  	for _, org := range pop.Organisms {
    31  		res := ex.orgEvaluate(org)
    32  
    33  		if res && (epoch.Best == nil || org.Fitness > epoch.Best.Fitness){
    34  			epoch.Solved = true
    35  			epoch.WinnerNodes = len(org.Genotype.Nodes)
    36  			epoch.WinnerGenes = org.Genotype.Extrons()
    37  			epoch.WinnerEvals = context.PopSize * epoch.Id + org.Genotype.Id
    38  			epoch.Best = org
    39  			if (epoch.WinnerNodes == 7) {
    40  				// You could dump out optimal genomes here if desired
    41  				opt_path := fmt.Sprintf("%s/%s_%d-%d", experiments.OutDirForTrial(ex.OutputPath, epoch.TrialId),
    42  					"pole1_optimal", org.Phenotype.NodeCount(), org.Phenotype.LinkCount())
    43  				file, err := os.Create(opt_path)
    44  				if err != nil {
    45  					neat.ErrorLog(fmt.Sprintf("Failed to dump optimal genome, reason: %s\n", err))
    46  				} else {
    47  					org.Genotype.Write(file)
    48  					neat.InfoLog(fmt.Sprintf("Dumped optimal genome to: %s\n", opt_path))
    49  				}
    50  			}
    51  		}
    52  	}
    53  
    54  	// Fill statistics about current epoch
    55  	epoch.FillPopulationStatistics(pop)
    56  
    57  	// Only print to file every print_every generations
    58  	if epoch.Solved || epoch.Id % context.PrintEvery == 0 {
    59  		pop_path := fmt.Sprintf("%s/gen_%d", experiments.OutDirForTrial(ex.OutputPath, epoch.TrialId), epoch.Id)
    60  		file, err := os.Create(pop_path)
    61  		if err != nil {
    62  			neat.ErrorLog(fmt.Sprintf("Failed to dump population, reason: %s\n", err))
    63  		} else {
    64  			pop.WriteBySpecies(file)
    65  		}
    66  	}
    67  
    68  	if epoch.Solved {
    69  		// print winner organism
    70  		for _, org := range pop.Organisms {
    71  			if org.IsWinner {
    72  				// Prints the winner organism to file!
    73  				org_path := fmt.Sprintf("%s/%s_%d-%d", experiments.OutDirForTrial(ex.OutputPath, epoch.TrialId),
    74  					"pole1_winner", org.Phenotype.NodeCount(), org.Phenotype.LinkCount())
    75  				file, err := os.Create(org_path)
    76  				if err != nil {
    77  					neat.ErrorLog(fmt.Sprintf("Failed to dump winner organism genome, reason: %s\n", err))
    78  				} else {
    79  					org.Genotype.Write(file)
    80  					neat.InfoLog(fmt.Sprintf("Generation #%d winner dumped to: %s\n", epoch.Id, org_path))
    81  				}
    82  				break
    83  			}
    84  		}
    85  	}
    86  
    87  	return err
    88  }
    89  
    90  // This methods evaluates provided organism for cart pole balancing task
    91  func (ex *CartPoleGenerationEvaluator) orgEvaluate(organism *genetics.Organism) bool {
    92  	// Try to balance a pole now
    93  	organism.Fitness = float64(ex.runCart(organism.Phenotype))
    94  
    95  	if neat.LogLevel == neat.LogLevelDebug {
    96  		neat.DebugLog(fmt.Sprintf("Organism #%3d\tfitness: %f", organism.Genotype.Id, organism.Fitness))
    97  	}
    98  
    99  	// Decide if its a winner
   100  	if organism.Fitness >= float64(ex.WinBalancingSteps) {
   101  		organism.IsWinner = true
   102  	}
   103  
   104  	// adjust fitness to be in range [0;1]
   105  	if organism.IsWinner {
   106  		organism.Fitness = 1.0
   107  		organism.Error = 0.0
   108  	} else if organism.Fitness == 0 {
   109  		organism.Error = 1.0
   110  	} else {
   111  		// we use logarithmic scale because most cart runs fail to early within ~100 steps, but
   112  		// we test against 500'000 balancing steps
   113  		logSteps := math.Log(float64(ex.WinBalancingSteps))
   114  		organism.Error = (logSteps - math.Log(organism.Fitness)) / logSteps
   115  		organism.Fitness = 1.0 - organism.Error
   116  	}
   117  
   118  	return organism.IsWinner
   119  }
   120  
   121  // run cart emulation and return number of emulation steps pole was balanced
   122  func (ex *CartPoleGenerationEvaluator) runCart(net *network.Network) (steps int) {
   123  	var x float64           /* cart position, meters */
   124  	var x_dot float64       /* cart velocity */
   125  	var theta float64       /* pole angle, radians */
   126  	var theta_dot float64   /* pole angular velocity */
   127  	if ex.RandomStart {
   128  		/*set up random start state*/
   129  		x = float64(rand.Int31() % 4800) / 1000.0 - 2.4
   130  		x_dot = float64(rand.Int31() % 2000) / 1000.0 - 1
   131  		theta = float64(rand.Int31() % 400) / 1000.0 - .2
   132  		theta_dot = float64(rand.Int31() % 3000) / 1000.0 - 1.5
   133  	}
   134  
   135  	in := make([]float64, 5)
   136  	for steps = 0; steps < ex.WinBalancingSteps; steps++ {
   137  		/*-- setup the input layer based on the four inputs --*/
   138  		in[0] = 1.0  // Bias
   139  		in[1] = (x + 2.4) / 4.8
   140  		in[2] = (x_dot + .75) / 1.5
   141  		in[3] = (theta + twelve_degrees) / .41
   142  		in[4] = (theta_dot + 1.0) / 2.0
   143  		net.LoadSensors(in)
   144  
   145  		/*-- activate the network based on the input --*/
   146  		if res, err := net.Activate(); !res {
   147  			//If it loops, exit returning only fitness of 1 step
   148  			neat.DebugLog(fmt.Sprintf("Failed to activate Network, reason: %s", err))
   149  			return 1
   150  		}
   151  		/*-- decide which way to push via which output unit is greater --*/
   152  		action := 1
   153  		if net.Outputs[0].Activation > net.Outputs[1].Activation {
   154  			action = 0
   155  		}
   156  		/*--- Apply action to the simulated cart-pole ---*/
   157  		x, x_dot, theta, theta_dot = ex.doAction(action, x, x_dot, theta, theta_dot)
   158  
   159  		/*--- Check for failure.  If so, return steps ---*/
   160  		if (x < -2.4 || x > 2.4 || theta < -twelve_degrees || theta > twelve_degrees) {
   161  			return steps
   162  		}
   163  	}
   164  	return steps
   165  }
   166  
   167  // cart_and_pole() was take directly from the pole simulator written by Richard Sutton and Charles Anderson.
   168  // This simulator uses normalized, continuous inputs instead of discretizing the input space.
   169  /*----------------------------------------------------------------------
   170   cart_pole:  Takes an action (0 or 1) and the current values of the
   171   four state variables and updates their values by estimating the state
   172   TAU seconds later.
   173   ----------------------------------------------------------------------*/
   174  func (ex *CartPoleGenerationEvaluator) doAction(action int, x, x_dot, theta, theta_dot float64) (x_ret, x_dot_ret, theta_ret, theta_dot_ret float64) {
   175  	// The cart pole configuration values
   176  	const GRAVITY = 9.8
   177  	const MASSCART = 1.0
   178  	const MASSPOLE = 0.5
   179  	const TOTAL_MASS = (MASSPOLE + MASSCART)
   180  	const LENGTH = 0.5      /* actually half the pole's length */
   181  	const POLEMASS_LENGTH = (MASSPOLE * LENGTH)
   182  	const FORCE_MAG = 10.0
   183  	const TAU = 0.02      /* seconds between state updates */
   184  	const FOURTHIRDS = 1.3333333333333
   185  
   186  	force := -FORCE_MAG
   187  	if action > 0 {
   188  		force = FORCE_MAG
   189  	}
   190  	cos_theta := math.Cos(theta)
   191  	sin_theta := math.Sin(theta)
   192  
   193  	temp := (force + POLEMASS_LENGTH * theta_dot * theta_dot * sin_theta) / TOTAL_MASS
   194  
   195  	theta_acc := (GRAVITY * sin_theta - cos_theta * temp) / (LENGTH * (FOURTHIRDS - MASSPOLE * cos_theta * cos_theta / TOTAL_MASS))
   196  
   197  	x_acc := temp - POLEMASS_LENGTH * theta_acc * cos_theta / TOTAL_MASS
   198  
   199  	/*** Update the four state variables, using Euler's method. ***/
   200  	x_ret = x + TAU * x_dot
   201  	x_dot_ret = x_dot + TAU * x_acc
   202  	theta_ret = theta + TAU * theta_dot
   203  	theta_dot_ret = theta_dot + TAU * theta_acc
   204  
   205  	return x_ret, x_dot_ret, theta_ret, theta_dot_ret
   206  }
   207