github.com/moitias/moq@v0.0.0-20240223074357-5eb0f0ba4054/pkg/moq/moq.go (about)

     1  package moq
     2  
     3  import (
     4  	"bytes"
     5  	"errors"
     6  	"go/token"
     7  	"go/types"
     8  	"io"
     9  	"log"
    10  	"strings"
    11  
    12  	"github.com/moitias/moq/internal/registry"
    13  	"github.com/moitias/moq/internal/template"
    14  )
    15  
// Mocker can generate mock structs.
// Use New to construct one from a Config.
type Mocker struct {
	// cfg is the configuration that was passed to New.
	cfg Config

	// registry looks up interfaces in the source package and collects the
	// imports needed by the generated file.
	registry *registry.Registry
	// tmpl renders the collected template.Data into Go source.
	tmpl     template.Template
}
    23  
// Config specifies details about how interfaces should be mocked.
// SrcDir is the only field which needs to be specified.
type Config struct {
	// SrcDir is the source package directory; it is passed to registry.New.
	SrcDir     string
	// PkgName, when non-empty, is the package name of the generated file;
	// otherwise the source package's name is used. It is also passed to
	// registry.New.
	PkgName    string
	// Formatter selects post-processing of the generated source:
	// "goimports", "noop" (leave unformatted), or anything else for gofmt.
	Formatter  string
	// StubImpl is passed through to the template. Presumably it makes mocks
	// return zero values instead of failing when no func is set — TODO confirm.
	StubImpl   bool
	// SkipEnsure is passed through to the template and, when the mock package
	// differs from the source package, stops the source package from being
	// added as an import.
	SkipEnsure bool
	// WithResets is passed through to the template; presumably it adds
	// helpers that reset recorded calls — TODO confirm against the template.
	WithResets bool
}
    34  
    35  // New makes a new Mocker for the specified package directory.
    36  func New(cfg Config) (*Mocker, error) {
    37  	log.Println("CREATE MOCK REG")
    38  	reg, err := registry.New(cfg.SrcDir, cfg.PkgName)
    39  	if err != nil {
    40  		return nil, err
    41  	}
    42  	log.Println("CREATE MOCK TMPL")
    43  
    44  	tmpl, err := template.New()
    45  	if err != nil {
    46  		return nil, err
    47  	}
    48  
    49  	return &Mocker{
    50  		cfg:      cfg,
    51  		registry: reg,
    52  		tmpl:     tmpl,
    53  	}, nil
    54  }
    55  
    56  // Mock generates a mock for the specified interface name.
    57  func (m *Mocker) Mock(w io.Writer, namePairs ...string) error {
    58  	if len(namePairs) == 0 {
    59  		return errors.New("must specify one interface")
    60  	}
    61  
    62  	mocks := make([]template.MockData, len(namePairs))
    63  	for i, np := range namePairs {
    64  		log.Println("NP", i, np)
    65  		name, mockName := parseInterfaceName(np)
    66  		iface, tparams, err := m.registry.LookupInterface(name)
    67  		if err != nil {
    68  			return err
    69  		}
    70  
    71  		methods := make([]template.MethodData, iface.NumMethods())
    72  		for j := 0; j < iface.NumMethods(); j++ {
    73  			methods[j] = m.methodData(iface.Method(j))
    74  		}
    75  
    76  		mocks[i] = template.MockData{
    77  			InterfaceName: name,
    78  			MockName:      mockName,
    79  			Methods:       methods,
    80  			TypeParams:    m.typeParams(tparams),
    81  		}
    82  	}
    83  
    84  	log.Println("DATA")
    85  	data := template.Data{
    86  		PkgName:    m.mockPkgName(),
    87  		Mocks:      mocks,
    88  		StubImpl:   m.cfg.StubImpl,
    89  		SkipEnsure: m.cfg.SkipEnsure,
    90  		WithResets: m.cfg.WithResets,
    91  	}
    92  
    93  	if data.MocksSomeMethod() {
    94  		m.registry.AddImport(types.NewPackage("sync", "sync"))
    95  	}
    96  	log.Println("PKGNAME")
    97  	if m.registry.SrcPkgName() != m.mockPkgName() {
    98  		data.SrcPkgQualifier = m.registry.SrcPkgName() + "."
    99  		if !m.cfg.SkipEnsure {
   100  			imprt := m.registry.AddImport(m.registry.SrcPkg())
   101  			data.SrcPkgQualifier = imprt.Qualifier() + "."
   102  		}
   103  	}
   104  
   105  	log.Println("IMPORTS")
   106  	data.Imports = m.registry.Imports()
   107  
   108  	log.Println("TMPL")
   109  	var buf bytes.Buffer
   110  	if err := m.tmpl.Execute(&buf, data); err != nil {
   111  		return err
   112  	}
   113  
   114  	log.Println("FORMAT")
   115  
   116  	formatted, err := m.format(buf.Bytes())
   117  	if err != nil {
   118  		return err
   119  	}
   120  
   121  	if _, err := w.Write(formatted); err != nil {
   122  		return err
   123  	}
   124  	return nil
   125  }
   126  
   127  func (m *Mocker) typeParams(tparams *types.TypeParamList) []template.TypeParamData {
   128  	var tpd []template.TypeParamData
   129  	if tparams == nil {
   130  		return tpd
   131  	}
   132  
   133  	tpd = make([]template.TypeParamData, tparams.Len())
   134  
   135  	scope := m.registry.MethodScope()
   136  	for i := 0; i < len(tpd); i++ {
   137  		tp := tparams.At(i)
   138  		typeParam := types.NewParam(token.Pos(i), tp.Obj().Pkg(), tp.Obj().Name(), tp.Constraint())
   139  		tpd[i] = template.TypeParamData{
   140  			ParamData:  template.ParamData{Var: scope.AddVar(typeParam, "")},
   141  			Constraint: explicitConstraintType(typeParam),
   142  		}
   143  	}
   144  
   145  	return tpd
   146  }
   147  
   148  func explicitConstraintType(typeParam *types.Var) (t types.Type) {
   149  	underlying := typeParam.Type().Underlying().(*types.Interface)
   150  	// check if any of the embedded types is either a basic type or a union,
   151  	// because the generic type has to be an alias for one of those types then
   152  	for j := 0; j < underlying.NumEmbeddeds(); j++ {
   153  		t := underlying.EmbeddedType(j)
   154  		switch t := t.(type) {
   155  		case *types.Basic:
   156  			return t
   157  		case *types.Union: // only unions of basic types are allowed, so just take the first one as a valid type constraint
   158  			return t.Term(0).Type()
   159  		}
   160  	}
   161  	return nil
   162  }
   163  
   164  func (m *Mocker) methodData(f *types.Func) template.MethodData {
   165  	sig := f.Type().(*types.Signature)
   166  
   167  	scope := m.registry.MethodScope()
   168  	n := sig.Params().Len()
   169  	params := make([]template.ParamData, n)
   170  	for i := 0; i < n; i++ {
   171  		p := template.ParamData{
   172  			Var: scope.AddVar(sig.Params().At(i), ""),
   173  		}
   174  		p.Variadic = sig.Variadic() && i == n-1 && p.Var.IsSlice() // check for final variadic argument
   175  
   176  		params[i] = p
   177  	}
   178  
   179  	n = sig.Results().Len()
   180  	results := make([]template.ParamData, n)
   181  	for i := 0; i < n; i++ {
   182  		results[i] = template.ParamData{
   183  			Var: scope.AddVar(sig.Results().At(i), "Out"),
   184  		}
   185  	}
   186  
   187  	return template.MethodData{
   188  		Name:    f.Name(),
   189  		Params:  params,
   190  		Returns: results,
   191  	}
   192  }
   193  
   194  func (m *Mocker) mockPkgName() string {
   195  	if m.cfg.PkgName != "" {
   196  		return m.cfg.PkgName
   197  	}
   198  
   199  	return m.registry.SrcPkgName()
   200  }
   201  
   202  func (m *Mocker) format(src []byte) ([]byte, error) {
   203  	switch m.cfg.Formatter {
   204  	case "goimports":
   205  		return goimports(src)
   206  
   207  	case "noop":
   208  		return src, nil
   209  	}
   210  
   211  	return gofmt(src)
   212  }
   213  
   214  func parseInterfaceName(namePair string) (ifaceName, mockName string) {
   215  	parts := strings.SplitN(namePair, ":", 2)
   216  	if len(parts) == 2 {
   217  		return parts[0], parts[1]
   218  	}
   219  
   220  	ifaceName = parts[0]
   221  	return ifaceName, ifaceName + "Mock"
   222  }