github.com/Capventis/moq@v0.2.6-0.20220316100624-05dd47497214/pkg/moq/moq.go (about)

     1  package moq
     2  
     3  import (
     4  	"bytes"
     5  	"errors"
     6  	"go/types"
     7  	"io"
     8  	"strings"
     9  
    10  	"github.com/Capventis/moq/internal/registry"
    11  	"github.com/Capventis/moq/internal/template"
    12  )
    13  
    14  // Mocker can generate mock structs.
    15  type Mocker struct {
    16  	cfg Config
    17  
    18  	registry *registry.Registry
    19  	tmpl     template.Template
    20  }
    21  
    22  // Config specifies details about how interfaces should be mocked.
    23  // SrcDir is the only field which needs be specified.
    24  type Config struct {
    25  	SrcDir     string
    26  	PkgName    string
    27  	Formatter  string
    28  	StubImpl   bool
    29  	SkipEnsure bool
    30  }
    31  
    32  // New makes a new Mocker for the specified package directory.
    33  func New(cfg Config) (*Mocker, error) {
    34  	reg, err := registry.New(cfg.SrcDir, cfg.PkgName)
    35  	if err != nil {
    36  		return nil, err
    37  	}
    38  
    39  	tmpl, err := template.New()
    40  	if err != nil {
    41  		return nil, err
    42  	}
    43  
    44  	return &Mocker{
    45  		cfg:      cfg,
    46  		registry: reg,
    47  		tmpl:     tmpl,
    48  	}, nil
    49  }
    50  
    51  // Mock generates a mock for the specified interface name.
    52  func (m *Mocker) Mock(w io.Writer, namePairs ...string) error {
    53  	if len(namePairs) == 0 {
    54  		return errors.New("must specify one interface")
    55  	}
    56  
    57  	mocks := make([]template.MockData, len(namePairs))
    58  	for i, np := range namePairs {
    59  		name, mockName := parseInterfaceName(np)
    60  		iface, err := m.registry.LookupInterface(name)
    61  		if err != nil {
    62  			return err
    63  		}
    64  
    65  		methods := make([]template.MethodData, iface.NumMethods())
    66  		for j := 0; j < iface.NumMethods(); j++ {
    67  			methods[j] = m.methodData(iface.Method(j))
    68  		}
    69  
    70  		mocks[i] = template.MockData{
    71  			InterfaceName: name,
    72  			MockName:      mockName,
    73  			Methods:       methods,
    74  		}
    75  	}
    76  
    77  	data := template.Data{
    78  		PkgName:    m.mockPkgName(),
    79  		Mocks:      mocks,
    80  		StubImpl:   m.cfg.StubImpl,
    81  		SkipEnsure: m.cfg.SkipEnsure,
    82  	}
    83  
    84  	if data.MocksSomeMethod() {
    85  		m.registry.AddImport(types.NewPackage("sync", "sync"))
    86  	}
    87  	if m.registry.SrcPkgName() != m.mockPkgName() {
    88  		data.SrcPkgQualifier = m.registry.SrcPkgName() + "."
    89  		if !m.cfg.SkipEnsure {
    90  			imprt := m.registry.AddImport(m.registry.SrcPkg())
    91  			data.SrcPkgQualifier = imprt.Qualifier() + "."
    92  		}
    93  	}
    94  
    95  	data.Imports = m.registry.Imports()
    96  
    97  	var buf bytes.Buffer
    98  	if err := m.tmpl.Execute(&buf, data); err != nil {
    99  		return err
   100  	}
   101  
   102  	formatted, err := m.format(buf.Bytes())
   103  	if err != nil {
   104  		return err
   105  	}
   106  
   107  	if _, err := w.Write(formatted); err != nil {
   108  		return err
   109  	}
   110  	return nil
   111  }
   112  
   113  func (m *Mocker) methodData(f *types.Func) template.MethodData {
   114  	sig := f.Type().(*types.Signature)
   115  
   116  	scope := m.registry.MethodScope()
   117  	n := sig.Params().Len()
   118  	params := make([]template.ParamData, n)
   119  	for i := 0; i < n; i++ {
   120  		p := template.ParamData{
   121  			Var: scope.AddVar(sig.Params().At(i), ""),
   122  		}
   123  		p.Variadic = sig.Variadic() && i == n-1 && p.Var.IsSlice() // check for final variadic argument
   124  
   125  		params[i] = p
   126  	}
   127  
   128  	n = sig.Results().Len()
   129  	results := make([]template.ParamData, n)
   130  	for i := 0; i < n; i++ {
   131  		results[i] = template.ParamData{
   132  			Var: scope.AddVar(sig.Results().At(i), "Out"),
   133  		}
   134  	}
   135  
   136  	return template.MethodData{
   137  		Name:    f.Name(),
   138  		Params:  params,
   139  		Returns: results,
   140  	}
   141  }
   142  
   143  func (m *Mocker) mockPkgName() string {
   144  	if m.cfg.PkgName != "" {
   145  		return m.cfg.PkgName
   146  	}
   147  
   148  	return m.registry.SrcPkgName()
   149  }
   150  
   151  func (m *Mocker) format(src []byte) ([]byte, error) {
   152  	switch m.cfg.Formatter {
   153  	case "goimports":
   154  		return goimports(src)
   155  
   156  	case "noop":
   157  		return src, nil
   158  	}
   159  
   160  	return gofmt(src)
   161  }
   162  
   163  func parseInterfaceName(namePair string) (ifaceName, mockName string) {
   164  	parts := strings.SplitN(namePair, ":", 2)
   165  	if len(parts) == 2 {
   166  		return parts[0], parts[1]
   167  	}
   168  
   169  	ifaceName = parts[0]
   170  	return ifaceName, ifaceName + "Mock"
   171  }