github.com/gorgonia/agogo@v0.1.1/statistics.go (about)

     1  package agogo
     2  
     3  import (
     4  	"encoding/csv"
     5  	"fmt"
     6  	"os"
     7  	"strconv"
     8  )
     9  
    10  type Statistics struct {
    11  	Creation []string
    12  	Wins     map[string][]float32
    13  	Losses   map[string][]float32
    14  	Draws    map[string][]float32
    15  }
    16  
    17  func makeStatistics() Statistics {
    18  	return Statistics{
    19  		Creation: make([]string, 0, 64),
    20  		Wins:     make(map[string][]float32),
    21  		Losses:   make(map[string][]float32),
    22  		Draws:    make(map[string][]float32),
    23  	}
    24  }
    25  
    26  func (s *Statistics) update(A *Agent) {
    27  	aname := fmt.Sprintf("%p", A.NN)
    28  
    29  	if _, ok := s.Wins[aname]; !ok {
    30  		s.Creation = append(s.Creation, aname)
    31  	}
    32  
    33  	s.Wins[aname] = append(s.Wins[aname], A.Wins)
    34  	s.Losses[aname] = append(s.Losses[aname], A.Loss)
    35  	s.Draws[aname] = append(s.Draws[aname], A.Draw)
    36  }
    37  
    38  // Dump the statistics in filename using a CSV format
    39  func (s *Statistics) Dump(filename string) error {
    40  	f, err := os.OpenFile(filename, os.O_CREATE|os.O_TRUNC|os.O_WRONLY, 0644)
    41  	if err != nil {
    42  		return err
    43  	}
    44  	defer f.Close()
    45  	w := csv.NewWriter(f)
    46  	if err := w.Write(s.Creation); err != nil {
    47  		return err
    48  	}
    49  	var records [][]string
    50  	for i, agent := range s.Creation {
    51  		for j, win := range s.Wins[agent] {
    52  			record := make([]string, len(s.Creation))
    53  			winRate := win / (win + s.Losses[agent][j] + s.Draws[agent][j])
    54  
    55  			record[i] = strconv.FormatFloat(float64(winRate), 'f', 3, 32)
    56  			records = append(records, record)
    57  		}
    58  	}
    59  	if err := w.WriteAll(records); err != nil {
    60  		return err
    61  	}
    62  	w.Flush()
    63  	return nil
    64  }