github.com/tobgu/qframe@v0.4.0/contrib/gonum/qplot/qplot_test.go (about)

     1  package qplot_test
     2  
     3  import (
     4  	"crypto/sha256"
     5  	"math"
     6  	"os"
     7  	"testing"
     8  	"time"
     9  
    10  	"gonum.org/v1/gonum/stat"
    11  
    12  	"gonum.org/v1/plot"
    13  	"gonum.org/v1/plot/plotter"
    14  	"gonum.org/v1/plot/plotutil"
    15  
    16  	"github.com/tobgu/qframe"
    17  	"github.com/tobgu/qframe/contrib/gonum/qplot"
    18  )
    19  
    20  func panicOnErr(err error) {
    21  	if err != nil {
    22  		panic(err)
    23  	}
    24  }
    25  
    26  // SlidingWindow returns a function that finds
    27  // the average of n time periods.
    28  func SlidingWindow(n int) func(float64) float64 {
    29  	var buf []float64
    30  	return func(value float64) float64 {
    31  		if len(buf) < n {
    32  			buf = append(buf, value)
    33  			return value
    34  		}
    35  		buf = append(buf[1:], value)
    36  		return stat.Mean(buf, nil)
    37  	}
    38  }
    39  
    40  func ExampleQPlot() {
    41  	fp, err := os.Open("testdata/GlobalTemperatures.csv")
    42  	panicOnErr(err)
    43  	defer fp.Close()
    44  
    45  	qf := qframe.ReadCSV(fp)
    46  	// Filter out any missing values
    47  	qf = qf.Filter(qframe.Filter{
    48  		Column:     "LandAndOceanAverageTemperature",
    49  		Comparator: func(f float64) bool { return !math.IsNaN(f) },
    50  	})
    51  	// QFrame does not yet have native support for timeseries
    52  	// data so we convert the timestamp to epoch time.
    53  	qf = qf.Apply(qframe.Instruction{
    54  		Fn: func(ts *string) int {
    55  			tm, err := time.Parse("2006-01-02", *ts)
    56  			if err != nil {
    57  				panic(err)
    58  			}
    59  			return int(tm.Unix())
    60  		},
    61  		SrcCol1: "dt",
    62  		DstCol:  "time",
    63  	})
    64  	// Compute the average of the last 2 years of temperatures.
    65  	window := SlidingWindow(24)
    66  	qf = qf.Apply(qframe.Instruction{
    67  		Fn: func(value float64) float64 {
    68  			return window(value)
    69  		},
    70  		SrcCol1: "LandAndOceanAverageTemperature",
    71  		DstCol:  "SMA",
    72  	})
    73  
    74  	// Create a new configuration
    75  	cfg := qplot.NewConfig(
    76  		// Configure the base Plot
    77  		qplot.PlotConfig(
    78  			func(plt *plot.Plot) {
    79  				plt.Add(plotter.NewGrid())
    80  				plt.Title.Text = "Global Land & Ocean Temperatures"
    81  				plt.X.Label.Text = "Time"
    82  				plt.Y.Label.Text = "Temperature"
    83  			},
    84  		),
    85  		// Plot each recorded temperature as a scatter plot
    86  		qplot.Plotter(
    87  			qplot.ScatterPlotter(
    88  				qplot.MustNewXYer("time", "LandAndOceanAverageTemperature", qf),
    89  				func(plt *plot.Plot, line *plotter.Scatter) {
    90  					plt.Legend.Add("Temperature", line)
    91  					line.Color = plotutil.Color(2)
    92  				},
    93  			)),
    94  		// Plot the SMA as a line
    95  		qplot.Plotter(
    96  			qplot.LinePlotter(
    97  				qplot.MustNewXYer("time", "SMA", qf),
    98  				func(plt *plot.Plot, line *plotter.Line) {
    99  					plt.Legend.Add("SMA", line)
   100  					line.Color = plotutil.Color(1)
   101  				},
   102  			)),
   103  	)
   104  	// Create a new QPlot
   105  	qp := qplot.NewQPlot(cfg)
   106  	// Write the plot to disk
   107  	panicOnErr(os.WriteFile("testdata/GlobalTemperatures.png", qp.MustBytes(), 0644))
   108  }
   109  
   110  func getHash(t *testing.T, path string) [32]byte {
   111  	raw, err := os.ReadFile(path)
   112  	panicOnErr(err)
   113  	return sha256.Sum256(raw)
   114  }
   115  
   116  func TestQPlot(t *testing.T) {
   117  	original := getHash(t, "testdata/GlobalTemperatures.png")
   118  	ExampleQPlot()
   119  	modified := getHash(t, "testdata/GlobalTemperatures.png")
   120  	if original != modified {
   121  		t.Errorf("output image has changed")
   122  	}
   123  }