go.undefinedlabs.com/scopeagent@v0.4.2/instrumentation/coverage/coverage.go (about)

     1  package coverage
     2  
     3  import (
     4  	"bytes"
     5  	"encoding/json"
     6  	"errors"
     7  	"fmt"
     8  	"io"
     9  	"os/exec"
    10  	"path"
    11  	"path/filepath"
    12  	"runtime"
    13  	"sort"
    14  	"strings"
    15  	"sync"
    16  	"sync/atomic"
    17  	"testing"
    18  	_ "unsafe"
    19  
    20  	"github.com/google/uuid"
    21  	"go.undefinedlabs.com/scopeagent/instrumentation"
    22  )
    23  
    24  type (
    25  	coverage struct {
    26  		Type    string         `json:"type" msgpack:"type"`
    27  		Version string         `json:"version" msgpack:"version"`
    28  		Uuid    string         `json:"uuid" msgpack:"uuid"`
    29  		Files   []fileCoverage `json:"files" msgpack:"files"`
    30  	}
    31  	fileCoverage struct {
    32  		Filename   string  `json:"filename" msgpack:"filename"`
    33  		Boundaries [][]int `json:"boundaries" msgpack:"boundaries"`
    34  	}
    35  	pkg struct {
    36  		ImportPath string
    37  		Dir        string
    38  		Error      *struct {
    39  			Err string
    40  		}
    41  	}
    42  	blockWithCount struct {
    43  		block *testing.CoverBlock
    44  		count int
    45  	}
    46  )
    47  
    48  //go:linkname cover testing.cover
    49  var (
    50  	cover         testing.Cover
    51  	counters      map[string][]uint32
    52  	countersMutex sync.Mutex
    53  	filePathData  map[string]string
    54  	initOnce      sync.Once
    55  )
    56  
    57  // Initialize coverage
    58  func initCoverage() {
    59  	initOnce.Do(func() {
    60  		var files []string
    61  		for key := range cover.Blocks {
    62  			files = append(files, key)
    63  		}
    64  		pkgData, err := findPkgs(files)
    65  		if err != nil {
    66  			pkgData = map[string]*pkg{}
    67  			instrumentation.Logger().Printf("coverage error: %v", err)
    68  		}
    69  		filePathData = map[string]string{}
    70  		for key := range cover.Blocks {
    71  			filePath, err := findFile(pkgData, key)
    72  			if err != nil {
    73  				instrumentation.Logger().Printf("coverage error: %v", err)
    74  			} else {
    75  				filePathData[key] = filePath
    76  			}
    77  		}
    78  		counters = map[string][]uint32{}
    79  	})
    80  }
    81  
    82  // Clean the counters for a new coverage session
    83  func StartCoverage() {
    84  	countersMutex.Lock()
    85  	defer countersMutex.Unlock()
    86  	if cover.Mode == "" {
    87  		return
    88  	}
    89  	initCoverage()
    90  
    91  	for name, counts := range cover.Counters {
    92  		counters[name] = make([]uint32, len(counts))
    93  		for i := range counts {
    94  			counters[name][i] = atomic.SwapUint32(&counts[i], 0)
    95  		}
    96  	}
    97  }
    98  
    99  // Restore counters
   100  func RestoreCoverageCounters() {
   101  	countersMutex.Lock()
   102  	defer countersMutex.Unlock()
   103  	if cover.Mode == "" {
   104  		return
   105  	}
   106  	for name, counts := range cover.Counters {
   107  		for i := range counts {
   108  			atomic.StoreUint32(&counts[i], counters[name][i]+atomic.LoadUint32(&counts[i]))
   109  		}
   110  	}
   111  }
   112  
   113  // Get the counters values and extract the coverage info
   114  func EndCoverage() *coverage {
   115  	countersMutex.Lock()
   116  	defer countersMutex.Unlock()
   117  	if cover.Mode == "" {
   118  		return nil
   119  	}
   120  
   121  	var covSource = map[string][]*blockWithCount{}
   122  	for name, counts := range cover.Counters {
   123  		if file, ok := filePathData[name]; ok {
   124  			blocks := cover.Blocks[name]
   125  			for i := range counts {
   126  				count := atomic.LoadUint32(&counts[i])
   127  				atomic.StoreUint32(&counts[i], counters[name][i]+count)
   128  				covSource[file] = append(covSource[file], &blockWithCount{
   129  					block: &blocks[i],
   130  					count: int(count),
   131  				})
   132  			}
   133  			sort.SliceStable(covSource[file][:], func(i, j int) bool {
   134  				if covSource[file][i].block.Line0 == covSource[file][j].block.Line0 {
   135  					return covSource[file][i].block.Col0 < covSource[file][j].block.Col0
   136  				}
   137  				return covSource[file][i].block.Line0 < covSource[file][j].block.Line0
   138  			})
   139  		}
   140  	}
   141  
   142  	fileMap := map[string][][]int{}
   143  	for file, blockCount := range covSource {
   144  		blockStack := make([]*blockWithCount, 0)
   145  		for _, curBlock := range blockCount {
   146  			if curBlock.count > 0 {
   147  				var prvBlock *testing.CoverBlock
   148  				blockStackLen := len(blockStack)
   149  				if blockStackLen > 0 {
   150  					prvBlock = blockStack[blockStackLen-1].block
   151  				}
   152  
   153  				if prvBlock == nil {
   154  					fileMap[file] = append(fileMap[file], []int{
   155  						int(curBlock.block.Line0), int(curBlock.block.Col0), curBlock.count,
   156  					})
   157  					blockStack = append(blockStack, curBlock)
   158  				} else if contains(prvBlock, curBlock.block) {
   159  					pBoundCol := int(curBlock.block.Col0)
   160  					cBoundCol := int(curBlock.block.Col0)
   161  					if pBoundCol > 0 {
   162  						pBoundCol--
   163  					} else {
   164  						cBoundCol++
   165  					}
   166  					fileMap[file] = append(fileMap[file], []int{
   167  						int(curBlock.block.Line0), pBoundCol, -1,
   168  					})
   169  					fileMap[file] = append(fileMap[file], []int{
   170  						int(curBlock.block.Line0), cBoundCol, curBlock.count,
   171  					})
   172  					blockStack = append(blockStack, curBlock)
   173  				} else {
   174  					pBoundCol := int(prvBlock.Col1)
   175  					cBoundCol := int(curBlock.block.Col0)
   176  					if prvBlock.Line1 == curBlock.block.Line0 {
   177  						if pBoundCol > 0 {
   178  							pBoundCol--
   179  						} else {
   180  							cBoundCol++
   181  						}
   182  					}
   183  					fileMap[file] = append(fileMap[file], []int{
   184  						int(prvBlock.Line1), pBoundCol, -1,
   185  					})
   186  					fileMap[file] = append(fileMap[file], []int{
   187  						int(curBlock.block.Line0), cBoundCol, curBlock.count,
   188  					})
   189  					blockStack[blockStackLen-1] = curBlock
   190  				}
   191  			}
   192  		}
   193  
   194  		if len(blockStack) > 0 {
   195  			var prvBlock *blockWithCount
   196  			for i := len(blockStack) - 1; i >= 0; i-- {
   197  				cBlock := blockStack[i]
   198  				if prvBlock != nil {
   199  					fileMap[file] = append(fileMap[file], []int{
   200  						int(prvBlock.block.Line1), int(prvBlock.block.Col1) + 1, cBlock.count,
   201  					})
   202  				}
   203  				fileMap[file] = append(fileMap[file], []int{
   204  					int(cBlock.block.Line1), int(cBlock.block.Col1), -1,
   205  				})
   206  				prvBlock = cBlock
   207  			}
   208  		}
   209  	}
   210  	files := make([]fileCoverage, 0)
   211  	for key, value := range fileMap {
   212  		files = append(files, fileCoverage{
   213  			Filename:   key,
   214  			Boundaries: value,
   215  		})
   216  	}
   217  	uuidValue, _ := uuid.NewRandom()
   218  	coverageData := &coverage{
   219  		Type:    "com.undefinedlabs.uccf",
   220  		Version: "0.2.0",
   221  		Uuid:    uuidValue.String(),
   222  		Files:   files,
   223  	}
   224  	return coverageData
   225  }
   226  
   227  func contains(outer, inner *testing.CoverBlock) bool {
   228  	if outer != nil && inner != nil {
   229  		if outer.Line0 > inner.Line0 || (outer.Line0 == inner.Line0 && outer.Col0 > inner.Col0) {
   230  			return false
   231  		}
   232  		if outer.Line1 < inner.Line1 || (outer.Line1 == inner.Line1 && outer.Col1 < inner.Col1) {
   233  			return false
   234  		}
   235  		return true
   236  	}
   237  	return false
   238  }
   239  
// The following functions resolve the on-disk path of each file named in the coverage data.
// They are adapted from the go cover command: https://github.com/golang/go/blob/master/src/cmd/cover/func.go
   242  
   243  func findPkgs(fileNames []string) (map[string]*pkg, error) {
   244  	// Run go list to find the location of every package we care about.
   245  	pkgs := make(map[string]*pkg)
   246  	var list []string
   247  	for _, filename := range fileNames {
   248  		if strings.HasPrefix(filename, ".") || filepath.IsAbs(filename) {
   249  			// Relative or absolute path.
   250  			continue
   251  		}
   252  		pkg := path.Dir(filename)
   253  		if _, ok := pkgs[pkg]; !ok {
   254  			pkgs[pkg] = nil
   255  			list = append(list, pkg)
   256  		}
   257  	}
   258  
   259  	if len(list) == 0 {
   260  		return pkgs, nil
   261  	}
   262  
   263  	// Note: usually run as "go tool cover" in which case $GOROOT is set,
   264  	// in which case runtime.GOROOT() does exactly what we want.
   265  	goTool := filepath.Join(runtime.GOROOT(), "bin/go")
   266  	cmd := exec.Command(goTool, append([]string{"list", "-e", "-json"}, list...)...)
   267  	var stderr bytes.Buffer
   268  	cmd.Stderr = &stderr
   269  	stdout, err := cmd.Output()
   270  	if err != nil {
   271  		return nil, fmt.Errorf("cannot run go list: %v\n%s", err, stderr.Bytes())
   272  	}
   273  	dec := json.NewDecoder(bytes.NewReader(stdout))
   274  	for {
   275  		var pkg pkg
   276  		err := dec.Decode(&pkg)
   277  		if err == io.EOF {
   278  			break
   279  		}
   280  		if err != nil {
   281  			return nil, fmt.Errorf("decoding go list json: %v", err)
   282  		}
   283  		pkgs[pkg.ImportPath] = &pkg
   284  	}
   285  	return pkgs, nil
   286  }
   287  
   288  // findFile finds the location of the named file in GOROOT, GOPATH etc.
   289  func findFile(pkgs map[string]*pkg, file string) (string, error) {
   290  	if strings.HasPrefix(file, ".") || filepath.IsAbs(file) {
   291  		// Relative or absolute path.
   292  		return file, nil
   293  	}
   294  	pkg := pkgs[path.Dir(file)]
   295  	if pkg != nil {
   296  		if pkg.Dir != "" {
   297  			return filepath.Join(pkg.Dir, path.Base(file)), nil
   298  		}
   299  		if pkg.Error != nil {
   300  			return "", errors.New(pkg.Error.Err)
   301  		}
   302  	}
   303  	return "", fmt.Errorf("did not find package for %s in go list output", file)
   304  }