github.com/onsi/ginkgo@v1.16.6-0.20211118180735-4e1925ba4c95/ginkgo/unfocus/unfocus_command.go (about)

     1  package unfocus
     2  
     3  import (
     4  	"bytes"
     5  	"fmt"
     6  	"go/ast"
     7  	"go/parser"
     8  	"go/token"
     9  	"io"
    10  	"os"
    11  	"path/filepath"
    12  	"strings"
    13  	"sync"
    14  
    15  	"github.com/onsi/ginkgo/ginkgo/command"
    16  )
    17  
    18  func BuildUnfocusCommand() command.Command {
    19  	return command.Command{
    20  		Name:     "unfocus",
    21  		Usage:    "ginkgo unfocus",
    22  		ShortDoc: "Recursively unfocus any focused tests under the current directory",
    23  		DocLink:  "filtering-specs",
    24  		Command: func(_ []string, _ []string) {
    25  			unfocusSpecs()
    26  		},
    27  	}
    28  }
    29  
    30  func unfocusSpecs() {
    31  	fmt.Println("Scanning for focus...")
    32  
    33  	goFiles := make(chan string)
    34  	go func() {
    35  		unfocusDir(goFiles, ".")
    36  		close(goFiles)
    37  	}()
    38  
    39  	const workers = 10
    40  	wg := sync.WaitGroup{}
    41  	wg.Add(workers)
    42  
    43  	for i := 0; i < workers; i++ {
    44  		go func() {
    45  			for path := range goFiles {
    46  				unfocusFile(path)
    47  			}
    48  			wg.Done()
    49  		}()
    50  	}
    51  
    52  	wg.Wait()
    53  }
    54  
    55  func unfocusDir(goFiles chan string, path string) {
    56  	files, err := os.ReadDir(path)
    57  	if err != nil {
    58  		fmt.Println(err.Error())
    59  		return
    60  	}
    61  
    62  	for _, f := range files {
    63  		switch {
    64  		case f.IsDir() && shouldProcessDir(f.Name()):
    65  			unfocusDir(goFiles, filepath.Join(path, f.Name()))
    66  		case !f.IsDir() && shouldProcessFile(f.Name()):
    67  			goFiles <- filepath.Join(path, f.Name())
    68  		}
    69  	}
    70  }
    71  
    72  func shouldProcessDir(basename string) bool {
    73  	return basename != "vendor" && !strings.HasPrefix(basename, ".")
    74  }
    75  
    76  func shouldProcessFile(basename string) bool {
    77  	return strings.HasSuffix(basename, ".go")
    78  }
    79  
    80  func unfocusFile(path string) {
    81  	data, err := os.ReadFile(path)
    82  	if err != nil {
    83  		fmt.Printf("error reading file '%s': %s\n", path, err.Error())
    84  		return
    85  	}
    86  
    87  	ast, err := parser.ParseFile(token.NewFileSet(), path, bytes.NewReader(data), 0)
    88  	if err != nil {
    89  		fmt.Printf("error parsing file '%s': %s\n", path, err.Error())
    90  		return
    91  	}
    92  
    93  	eliminations := scanForFocus(ast)
    94  	if len(eliminations) == 0 {
    95  		return
    96  	}
    97  
    98  	fmt.Printf("...updating %s\n", path)
    99  	backup, err := writeBackup(path, data)
   100  	if err != nil {
   101  		fmt.Printf("error creating backup file: %s\n", err.Error())
   102  		return
   103  	}
   104  
   105  	if err := updateFile(path, data, eliminations); err != nil {
   106  		fmt.Printf("error writing file '%s': %s\n", path, err.Error())
   107  		return
   108  	}
   109  
   110  	os.Remove(backup)
   111  }
   112  
   113  func writeBackup(path string, data []byte) (string, error) {
   114  	t, err := os.CreateTemp(filepath.Dir(path), filepath.Base(path))
   115  
   116  	if err != nil {
   117  		return "", fmt.Errorf("error creating temporary file: %w", err)
   118  	}
   119  	defer t.Close()
   120  
   121  	if _, err := io.Copy(t, bytes.NewReader(data)); err != nil {
   122  		return "", fmt.Errorf("error writing to temporary file: %w", err)
   123  	}
   124  
   125  	return t.Name(), nil
   126  }
   127  
   128  func updateFile(path string, data []byte, eliminations [][]int64) error {
   129  	to, err := os.Create(path)
   130  	if err != nil {
   131  		return fmt.Errorf("error opening file for writing '%s': %w\n", path, err)
   132  	}
   133  	defer to.Close()
   134  
   135  	from := bytes.NewReader(data)
   136  	var cursor int64
   137  	for _, eliminationRange := range eliminations {
   138  		positionToEliminate, lengthToEliminate := eliminationRange[0], eliminationRange[1]
   139  		if _, err := io.CopyN(to, from, positionToEliminate-cursor); err != nil {
   140  			return fmt.Errorf("error copying data: %w", err)
   141  		}
   142  
   143  		cursor = positionToEliminate + lengthToEliminate
   144  
   145  		if _, err := from.Seek(lengthToEliminate, io.SeekCurrent); err != nil {
   146  			return fmt.Errorf("error seeking to position in buffer: %w", err)
   147  		}
   148  	}
   149  
   150  	if _, err := io.Copy(to, from); err != nil {
   151  		return fmt.Errorf("error copying end data: %w", err)
   152  	}
   153  
   154  	return nil
   155  }
   156  
   157  func scanForFocus(file *ast.File) (eliminations [][]int64) {
   158  	ast.Inspect(file, func(n ast.Node) bool {
   159  		if c, ok := n.(*ast.CallExpr); ok {
   160  			if i, ok := c.Fun.(*ast.Ident); ok {
   161  				if isFocus(i.Name) {
   162  					eliminations = append(eliminations, []int64{int64(i.Pos() - file.Pos()), 1})
   163  				}
   164  			}
   165  		}
   166  
   167  		if i, ok := n.(*ast.Ident); ok {
   168  			if i.Name == "Focus" {
   169  				eliminations = append(eliminations, []int64{int64(i.Pos() - file.Pos()), 6})
   170  			}
   171  		}
   172  
   173  		return true
   174  	})
   175  
   176  	return eliminations
   177  }
   178  
   179  func isFocus(name string) bool {
   180  	switch name {
   181  	case "FDescribe", "FContext", "FIt", "FDescribeTable", "FEntry", "FSpecify", "FWhen":
   182  		return true
   183  	default:
   184  		return false
   185  	}
   186  }