github.com/djui/moq@v0.3.3/pkg/moq/moq.go (about)

     1  package moq
     2  
     3  import (
     4  	"bytes"
     5  	"errors"
     6  	"go/token"
     7  	"go/types"
     8  	"io"
     9  	"strings"
    10  
    11  	"github.com/djui/moq/internal/registry"
    12  	"github.com/djui/moq/internal/template"
    13  )
    14  
    15  // Mocker can generate mock structs.
    16  type Mocker struct {
    17  	cfg Config
    18  
    19  	registry *registry.Registry
    20  	tmpl     template.Template
    21  }
    22  
    23  // Config specifies details about how interfaces should be mocked.
    24  // SrcDir is the only field which needs be specified.
    25  type Config struct {
    26  	SrcDir     string
    27  	PkgName    string
    28  	Formatter  string
    29  	StubImpl   bool
    30  	SkipEnsure bool
    31  }
    32  
    33  // New makes a new Mocker for the specified package directory.
    34  func New(cfg Config) (*Mocker, error) {
    35  	reg, err := registry.New(cfg.SrcDir, cfg.PkgName)
    36  	if err != nil {
    37  		return nil, err
    38  	}
    39  
    40  	tmpl, err := template.New()
    41  	if err != nil {
    42  		return nil, err
    43  	}
    44  
    45  	return &Mocker{
    46  		cfg:      cfg,
    47  		registry: reg,
    48  		tmpl:     tmpl,
    49  	}, nil
    50  }
    51  
    52  // Mock generates a mock for the specified interface name.
    53  func (m *Mocker) Mock(w io.Writer, namePairs ...string) error {
    54  	if len(namePairs) == 0 {
    55  		return errors.New("must specify one interface")
    56  	}
    57  
    58  	mocks := make([]template.MockData, len(namePairs))
    59  	for i, np := range namePairs {
    60  		name, mockName := parseInterfaceName(np)
    61  		iface, tparams, err := m.registry.LookupInterface(name)
    62  		if err != nil {
    63  			return err
    64  		}
    65  
    66  		methods := make([]template.MethodData, iface.NumMethods())
    67  		for j := 0; j < iface.NumMethods(); j++ {
    68  			methods[j] = m.methodData(iface.Method(j))
    69  		}
    70  
    71  		mocks[i] = template.MockData{
    72  			InterfaceName: name,
    73  			MockName:      mockName,
    74  			Methods:       methods,
    75  			TypeParams:    m.typeParams(tparams),
    76  		}
    77  	}
    78  
    79  	data := template.Data{
    80  		PkgName:    m.mockPkgName(),
    81  		Mocks:      mocks,
    82  		StubImpl:   m.cfg.StubImpl,
    83  		SkipEnsure: m.cfg.SkipEnsure,
    84  	}
    85  
    86  	if data.MocksSomeMethod() {
    87  		m.registry.AddImport(types.NewPackage("sync", "sync"))
    88  	}
    89  	if m.registry.SrcPkgName() != m.mockPkgName() {
    90  		data.SrcPkgQualifier = m.registry.SrcPkgName() + "."
    91  		if !m.cfg.SkipEnsure {
    92  			imprt := m.registry.AddImport(m.registry.SrcPkg())
    93  			data.SrcPkgQualifier = imprt.Qualifier() + "."
    94  		}
    95  	}
    96  
    97  	data.Imports = m.registry.Imports()
    98  
    99  	var buf bytes.Buffer
   100  	if err := m.tmpl.Execute(&buf, data); err != nil {
   101  		return err
   102  	}
   103  
   104  	formatted, err := m.format(buf.Bytes())
   105  	if err != nil {
   106  		return err
   107  	}
   108  
   109  	if _, err := w.Write(formatted); err != nil {
   110  		return err
   111  	}
   112  	return nil
   113  }
   114  
   115  func (m *Mocker) typeParams(tparams *types.TypeParamList) []template.TypeParamData {
   116  	var tpd []template.TypeParamData
   117  	if tparams == nil {
   118  		return tpd
   119  	}
   120  
   121  	tpd = make([]template.TypeParamData, tparams.Len())
   122  
   123  	scope := m.registry.MethodScope()
   124  	for i := 0; i < len(tpd); i++ {
   125  		tp := tparams.At(i)
   126  		typeParam := types.NewParam(token.Pos(i), tp.Obj().Pkg(), tp.Obj().Name(), tp.Constraint())
   127  		tpd[i] = template.TypeParamData{
   128  			ParamData:  template.ParamData{Var: scope.AddVar(typeParam, "")},
   129  			Constraint: explicitConstraintType(typeParam),
   130  		}
   131  	}
   132  
   133  	return tpd
   134  }
   135  
   136  func explicitConstraintType(typeParam *types.Var) (t types.Type) {
   137  	underlying := typeParam.Type().Underlying().(*types.Interface)
   138  	// check if any of the embedded types is either a basic type or a union,
   139  	// because the generic type has to be an alias for one of those types then
   140  	for j := 0; j < underlying.NumEmbeddeds(); j++ {
   141  		t := underlying.EmbeddedType(j)
   142  		switch t := t.(type) {
   143  		case *types.Basic:
   144  			return t
   145  		case *types.Union: // only unions of basic types are allowed, so just take the first one as a valid type constraint
   146  			return t.Term(0).Type()
   147  		}
   148  	}
   149  	return nil
   150  }
   151  
   152  func (m *Mocker) methodData(f *types.Func) template.MethodData {
   153  	sig := f.Type().(*types.Signature)
   154  
   155  	scope := m.registry.MethodScope()
   156  	n := sig.Params().Len()
   157  	params := make([]template.ParamData, n)
   158  	for i := 0; i < n; i++ {
   159  		p := template.ParamData{
   160  			Var: scope.AddVar(sig.Params().At(i), ""),
   161  		}
   162  		p.Variadic = sig.Variadic() && i == n-1 && p.Var.IsSlice() // check for final variadic argument
   163  
   164  		params[i] = p
   165  	}
   166  
   167  	n = sig.Results().Len()
   168  	results := make([]template.ParamData, n)
   169  	for i := 0; i < n; i++ {
   170  		results[i] = template.ParamData{
   171  			Var: scope.AddVar(sig.Results().At(i), "Out"),
   172  		}
   173  	}
   174  
   175  	return template.MethodData{
   176  		Name:    f.Name(),
   177  		Params:  params,
   178  		Returns: results,
   179  	}
   180  }
   181  
   182  func (m *Mocker) mockPkgName() string {
   183  	if m.cfg.PkgName != "" {
   184  		return m.cfg.PkgName
   185  	}
   186  
   187  	return m.registry.SrcPkgName()
   188  }
   189  
   190  func (m *Mocker) format(src []byte) ([]byte, error) {
   191  	switch m.cfg.Formatter {
   192  	case "goimports":
   193  		return goimports(src)
   194  
   195  	case "noop":
   196  		return src, nil
   197  	}
   198  
   199  	return gofmt(src)
   200  }
   201  
   202  func parseInterfaceName(namePair string) (ifaceName, mockName string) {
   203  	parts := strings.SplitN(namePair, ":", 2)
   204  	if len(parts) == 2 {
   205  		return parts[0], parts[1]
   206  	}
   207  
   208  	ifaceName = parts[0]
   209  	return ifaceName, ifaceName + "Mock"
   210  }