github.com/moitias/moq@v0.0.0-20240223074357-5eb0f0ba4054/main.go (about)

     1  package main
     2  
     3  import (
     4  	"bytes"
     5  	"errors"
     6  	"flag"
     7  	"fmt"
     8  	"io"
     9  	"io/ioutil"
    10  	"log"
    11  	"os"
    12  	"path/filepath"
    13  
    14  	"github.com/moitias/moq/pkg/moq"
    15  )
    16  
    17  // Version is the command version, injected at build time.
    18  var Version string = "dev"
    19  
    20  type userFlags struct {
    21  	outFile    string
    22  	pkgName    string
    23  	formatter  string
    24  	stubImpl   bool
    25  	skipEnsure bool
    26  	withResets bool
    27  	remove     bool
    28  	args       []string
    29  }
    30  
    31  func main() {
    32  	var flags userFlags
    33  	flag.StringVar(&flags.outFile, "out", "", "output file (default stdout)")
    34  	flag.StringVar(&flags.pkgName, "pkg", "", "package name (default will infer)")
    35  	flag.StringVar(&flags.formatter, "fmt", "", "go pretty-printer: gofmt, goimports or noop (default gofmt)")
    36  	flag.BoolVar(&flags.stubImpl, "stub", false,
    37  		"return zero values when no mock implementation is provided, do not panic")
    38  	printVersion := flag.Bool("version", false, "show the version for moq")
    39  	flag.BoolVar(&flags.skipEnsure, "skip-ensure", false,
    40  		"suppress mock implementation check, avoid import cycle if mocks generated outside of the tested package")
    41  	flag.BoolVar(&flags.remove, "rm", false, "first remove output file, if it exists")
    42  	flag.BoolVar(&flags.withResets, "with-resets", false,
    43  		"generate functions to facilitate resetting calls made to a mock")
    44  
    45  	flag.Usage = func() {
    46  		fmt.Println(`moq [flags] source-dir interface [interface2 [interface3 [...]]]`)
    47  		flag.PrintDefaults()
    48  		fmt.Println(`Specifying an alias for the mock is also supported with the format 'interface:alias'`)
    49  		fmt.Println(`Ex: moq -pkg different . MyInterface:MyMock`)
    50  	}
    51  
    52  	flag.Parse()
    53  	flags.args = flag.Args()
    54  
    55  	if *printVersion {
    56  		fmt.Printf("moq version %s\n", Version)
    57  		os.Exit(0)
    58  	}
    59  
    60  	if err := run(flags); err != nil {
    61  		fmt.Fprintln(os.Stderr, err)
    62  		flag.Usage()
    63  		os.Exit(1)
    64  	}
    65  }
    66  
    67  func run(flags userFlags) error {
    68  	if len(flags.args) < 2 {
    69  		return errors.New("not enough arguments")
    70  	}
    71  
    72  	if flags.remove && flags.outFile != "" {
    73  		if err := os.Remove(flags.outFile); err != nil {
    74  			if !errors.Is(err, os.ErrNotExist) {
    75  				return err
    76  			}
    77  		}
    78  	}
    79  
    80  	var buf bytes.Buffer
    81  	var out io.Writer = os.Stdout
    82  	if flags.outFile != "" {
    83  		out = &buf
    84  	}
    85  
    86  	log.Println("CREATE MOCK")
    87  	srcDir, args := flags.args[0], flags.args[1:]
    88  	m, err := moq.New(moq.Config{
    89  		SrcDir:     srcDir,
    90  		PkgName:    flags.pkgName,
    91  		Formatter:  flags.formatter,
    92  		StubImpl:   flags.stubImpl,
    93  		SkipEnsure: flags.skipEnsure,
    94  		WithResets: flags.withResets,
    95  	})
    96  	if err != nil {
    97  		return err
    98  	}
    99  
   100  	log.Println("DO MOCK")
   101  	if err = m.Mock(out, args...); err != nil {
   102  		return err
   103  	}
   104  
   105  	if flags.outFile == "" {
   106  		return nil
   107  	}
   108  
   109  	// create the file
   110  	err = os.MkdirAll(filepath.Dir(flags.outFile), 0o750)
   111  	if err != nil {
   112  		return err
   113  	}
   114  
   115  	return ioutil.WriteFile(flags.outFile, buf.Bytes(), 0o600)
   116  }