gotest.tools/gotestsum@v1.11.0/cmd/tool/matrix/matrix.go (about)

     1  package matrix
     2  
     3  import (
     4  	"bufio"
     5  	"encoding/json"
     6  	"fmt"
     7  	"io"
     8  	"math"
     9  	"os"
    10  	"path/filepath"
    11  	"sort"
    12  	"strings"
    13  	"time"
    14  
    15  	"github.com/dnephin/pflag"
    16  	"gotest.tools/gotestsum/internal/log"
    17  	"gotest.tools/gotestsum/testjson"
    18  )
    19  
    20  func Run(name string, args []string) error {
    21  	flags, opts := setupFlags(name)
    22  	switch err := flags.Parse(args); {
    23  	case err == pflag.ErrHelp:
    24  		return nil
    25  	case err != nil:
    26  		usage(os.Stderr, name, flags)
    27  		return err
    28  	}
    29  	opts.stdin = os.Stdin
    30  	opts.stdout = os.Stdout
    31  	return run(*opts)
    32  }
    33  
    34  type options struct {
    35  	numPartitions      uint
    36  	timingFilesPattern string
    37  	debug              bool
    38  
    39  	// shims for testing
    40  	stdin  io.Reader
    41  	stdout io.Writer
    42  }
    43  
    44  func setupFlags(name string) (*pflag.FlagSet, *options) {
    45  	opts := &options{}
    46  	flags := pflag.NewFlagSet(name, pflag.ContinueOnError)
    47  	flags.SetInterspersed(false)
    48  	flags.Usage = func() {
    49  		usage(os.Stdout, name, flags)
    50  	}
    51  	flags.UintVar(&opts.numPartitions, "partitions", 0,
    52  		"number of parallel partitions to create in the test matrix")
    53  	flags.StringVar(&opts.timingFilesPattern, "timing-files", "",
    54  		"glob pattern to match files that contain test2json events, ex: ./logs/*.log")
    55  	flags.BoolVar(&opts.debug, "debug", false,
    56  		"enable debug logging")
    57  	return flags, opts
    58  }
    59  
    60  func usage(out io.Writer, name string, flags *pflag.FlagSet) {
    61  	fmt.Fprintf(out, `Usage:
    62      %[1]s [flags]
    63  
    64  Read a list of packages from stdin and output a GitHub Actions matrix strategy
    65  that splits the packages by previous run times to minimize overall CI runtime.
    66  
    67      echo -n "matrix=" >> $GITHUB_OUTPUT
    68      go list ./... | %[1]s --timing-files ./*.log --partitions 4 >> $GITHUB_OUTPUT
    69  
    70  The output of the command is a JSON object that can be used as the matrix
    71  strategy for a test job.
    72  
    73  
    74  Flags:
    75  `, name)
    76  	flags.SetOutput(out)
    77  	flags.PrintDefaults()
    78  }
    79  
    80  func run(opts options) error {
    81  	log.SetLevel(log.InfoLevel)
    82  	if opts.debug {
    83  		log.SetLevel(log.DebugLevel)
    84  	}
    85  	if opts.numPartitions < 2 {
    86  		return fmt.Errorf("--partitions must be atleast 2")
    87  	}
    88  	if opts.timingFilesPattern == "" {
    89  		return fmt.Errorf("--timing-files is required")
    90  	}
    91  
    92  	pkgs, err := readPackages(opts.stdin)
    93  	if err != nil {
    94  		return fmt.Errorf("failed to read packages from stdin: %v", err)
    95  	}
    96  
    97  	files, err := readTimingReports(opts)
    98  	if err != nil {
    99  		return fmt.Errorf("failed to read or delete timing files: %v", err)
   100  	}
   101  	defer closeFiles(files)
   102  
   103  	pkgTiming, err := packageTiming(files)
   104  	if err != nil {
   105  		return err
   106  	}
   107  
   108  	buckets := bucketPackages(packagePercentile(pkgTiming), pkgs, opts.numPartitions)
   109  	return writeMatrix(opts.stdout, buckets)
   110  }
   111  
   112  func readPackages(stdin io.Reader) ([]string, error) {
   113  	var packages []string
   114  	scan := bufio.NewScanner(stdin)
   115  	for scan.Scan() {
   116  		packages = append(packages, scan.Text())
   117  	}
   118  	return packages, scan.Err()
   119  }
   120  
   121  func readTimingReports(opts options) ([]*os.File, error) {
   122  	fileNames, err := filepath.Glob(opts.timingFilesPattern)
   123  	if err != nil {
   124  		return nil, err
   125  	}
   126  
   127  	files := make([]*os.File, 0, len(fileNames))
   128  	for _, fileName := range fileNames {
   129  		fh, err := os.Open(fileName)
   130  		if err != nil {
   131  			return nil, err
   132  		}
   133  		files = append(files, fh)
   134  	}
   135  
   136  	log.Infof("Found %v timing files in %v", len(files), opts.timingFilesPattern)
   137  	return files, nil
   138  }
   139  
   140  func parseEvent(reader io.Reader) (testjson.TestEvent, error) {
   141  	event := testjson.TestEvent{}
   142  	err := json.NewDecoder(reader).Decode(&event)
   143  	return event, err
   144  }
   145  
   146  func packageTiming(files []*os.File) (map[string][]time.Duration, error) {
   147  	timing := make(map[string][]time.Duration)
   148  	for _, fh := range files {
   149  		exec, err := testjson.ScanTestOutput(testjson.ScanConfig{Stdout: fh})
   150  		if err != nil {
   151  			return nil, fmt.Errorf("failed to read events from %v: %v", fh.Name(), err)
   152  		}
   153  
   154  		for _, pkg := range exec.Packages() {
   155  			timing[pkg] = append(timing[pkg], exec.Package(pkg).Elapsed())
   156  		}
   157  	}
   158  	return timing, nil
   159  }
   160  
   161  func packagePercentile(timing map[string][]time.Duration) map[string]time.Duration {
   162  	result := make(map[string]time.Duration)
   163  	for pkg, times := range timing {
   164  		lenTimes := len(times)
   165  		if lenTimes == 0 {
   166  			result[pkg] = 0
   167  			continue
   168  		}
   169  
   170  		sort.Slice(times, func(i, j int) bool {
   171  			return times[i] < times[j]
   172  		})
   173  
   174  		r := int(math.Ceil(0.85 * float64(lenTimes)))
   175  		if r == 0 {
   176  			result[pkg] = times[0]
   177  			continue
   178  		}
   179  		result[pkg] = times[r-1]
   180  	}
   181  	return result
   182  }
   183  
   184  func closeFiles(files []*os.File) {
   185  	for _, fh := range files {
   186  		_ = fh.Close()
   187  	}
   188  }
   189  
   190  func bucketPackages(timing map[string]time.Duration, packages []string, n uint) []bucket {
   191  	sort.SliceStable(packages, func(i, j int) bool {
   192  		return timing[packages[i]] >= timing[packages[j]]
   193  	})
   194  
   195  	buckets := make([]bucket, n)
   196  	for _, pkg := range packages {
   197  		i := minBucket(buckets)
   198  		buckets[i].Total += timing[pkg]
   199  		buckets[i].Packages = append(buckets[i].Packages, pkg)
   200  		log.Debugf("adding %v (%v) to bucket %v with total %v",
   201  			pkg, timing[pkg], i, buckets[i].Total)
   202  	}
   203  	return buckets
   204  }
   205  
   206  func minBucket(buckets []bucket) int {
   207  	var n int
   208  	var min time.Duration = -1
   209  	for i, b := range buckets {
   210  		switch {
   211  		case min < 0 || b.Total < min:
   212  			min = b.Total
   213  			n = i
   214  		case b.Total == min && len(buckets[i].Packages) < len(buckets[n].Packages):
   215  			n = i
   216  		}
   217  	}
   218  	return n
   219  }
   220  
   221  type bucket struct {
   222  	Total    time.Duration
   223  	Packages []string
   224  }
   225  
   226  type matrix struct {
   227  	Include []Partition `json:"include"`
   228  }
   229  
   230  type Partition struct {
   231  	ID               int    `json:"id"`
   232  	EstimatedRuntime string `json:"estimatedRuntime"`
   233  	Packages         string `json:"packages"`
   234  	Description      string `json:"description"`
   235  }
   236  
   237  func writeMatrix(out io.Writer, buckets []bucket) error {
   238  	m := matrix{Include: make([]Partition, len(buckets))}
   239  	for i, bucket := range buckets {
   240  		p := Partition{
   241  			ID:               i,
   242  			EstimatedRuntime: bucket.Total.String(),
   243  			Packages:         strings.Join(bucket.Packages, " "),
   244  		}
   245  		if len(bucket.Packages) > 0 {
   246  			var extra string
   247  			if len(bucket.Packages) > 1 {
   248  				extra = fmt.Sprintf(" and %d others", len(bucket.Packages)-1)
   249  			}
   250  			p.Description = fmt.Sprintf("%d - %v%v",
   251  				p.ID, testjson.RelativePackagePath(bucket.Packages[0]), extra)
   252  		}
   253  
   254  		m.Include[i] = p
   255  	}
   256  
   257  	log.Debugf("%v\n", debugMatrix(m))
   258  
   259  	err := json.NewEncoder(out).Encode(m)
   260  	if err != nil {
   261  		return fmt.Errorf("failed to json encode output: %v", err)
   262  	}
   263  	return nil
   264  }
   265  
   266  type debugMatrix matrix
   267  
   268  func (d debugMatrix) String() string {
   269  	raw, err := json.MarshalIndent(d, "", "  ")
   270  	if err != nil {
   271  		return fmt.Sprintf("failed to marshal: %v", err.Error())
   272  	}
   273  	return string(raw)
   274  }