github.com/Capventis/moq@v0.2.6-0.20220316100624-05dd47497214/pkg/moq/moq.go (about) 1 package moq 2 3 import ( 4 "bytes" 5 "errors" 6 "go/types" 7 "io" 8 "strings" 9 10 "github.com/Capventis/moq/internal/registry" 11 "github.com/Capventis/moq/internal/template" 12 ) 13 14 // Mocker can generate mock structs. 15 type Mocker struct { 16 cfg Config 17 18 registry *registry.Registry 19 tmpl template.Template 20 } 21 22 // Config specifies details about how interfaces should be mocked. 23 // SrcDir is the only field which needs be specified. 24 type Config struct { 25 SrcDir string 26 PkgName string 27 Formatter string 28 StubImpl bool 29 SkipEnsure bool 30 } 31 32 // New makes a new Mocker for the specified package directory. 33 func New(cfg Config) (*Mocker, error) { 34 reg, err := registry.New(cfg.SrcDir, cfg.PkgName) 35 if err != nil { 36 return nil, err 37 } 38 39 tmpl, err := template.New() 40 if err != nil { 41 return nil, err 42 } 43 44 return &Mocker{ 45 cfg: cfg, 46 registry: reg, 47 tmpl: tmpl, 48 }, nil 49 } 50 51 // Mock generates a mock for the specified interface name. 52 func (m *Mocker) Mock(w io.Writer, namePairs ...string) error { 53 if len(namePairs) == 0 { 54 return errors.New("must specify one interface") 55 } 56 57 mocks := make([]template.MockData, len(namePairs)) 58 for i, np := range namePairs { 59 name, mockName := parseInterfaceName(np) 60 iface, err := m.registry.LookupInterface(name) 61 if err != nil { 62 return err 63 } 64 65 methods := make([]template.MethodData, iface.NumMethods()) 66 for j := 0; j < iface.NumMethods(); j++ { 67 methods[j] = m.methodData(iface.Method(j)) 68 } 69 70 mocks[i] = template.MockData{ 71 InterfaceName: name, 72 MockName: mockName, 73 Methods: methods, 74 } 75 } 76 77 data := template.Data{ 78 PkgName: m.mockPkgName(), 79 Mocks: mocks, 80 StubImpl: m.cfg.StubImpl, 81 SkipEnsure: m.cfg.SkipEnsure, 82 } 83 84 if data.MocksSomeMethod() { 85 m.registry.AddImport(types.NewPackage("sync", "sync")) 86 } 87 if m.registry.SrcPkgName() != m.mockPkgName() { 88 data.SrcPkgQualifier = m.registry.SrcPkgName() + "." 89 if !m.cfg.SkipEnsure { 90 imprt := m.registry.AddImport(m.registry.SrcPkg()) 91 data.SrcPkgQualifier = imprt.Qualifier() + "." 92 } 93 } 94 95 data.Imports = m.registry.Imports() 96 97 var buf bytes.Buffer 98 if err := m.tmpl.Execute(&buf, data); err != nil { 99 return err 100 } 101 102 formatted, err := m.format(buf.Bytes()) 103 if err != nil { 104 return err 105 } 106 107 if _, err := w.Write(formatted); err != nil { 108 return err 109 } 110 return nil 111 } 112 113 func (m *Mocker) methodData(f *types.Func) template.MethodData { 114 sig := f.Type().(*types.Signature) 115 116 scope := m.registry.MethodScope() 117 n := sig.Params().Len() 118 params := make([]template.ParamData, n) 119 for i := 0; i < n; i++ { 120 p := template.ParamData{ 121 Var: scope.AddVar(sig.Params().At(i), ""), 122 } 123 p.Variadic = sig.Variadic() && i == n-1 && p.Var.IsSlice() // check for final variadic argument 124 125 params[i] = p 126 } 127 128 n = sig.Results().Len() 129 results := make([]template.ParamData, n) 130 for i := 0; i < n; i++ { 131 results[i] = template.ParamData{ 132 Var: scope.AddVar(sig.Results().At(i), "Out"), 133 } 134 } 135 136 return template.MethodData{ 137 Name: f.Name(), 138 Params: params, 139 Returns: results, 140 } 141 } 142 143 func (m *Mocker) mockPkgName() string { 144 if m.cfg.PkgName != "" { 145 return m.cfg.PkgName 146 } 147 148 return m.registry.SrcPkgName() 149 } 150 151 func (m *Mocker) format(src []byte) ([]byte, error) { 152 switch m.cfg.Formatter { 153 case "goimports": 154 return goimports(src) 155 156 case "noop": 157 return src, nil 158 } 159 160 return gofmt(src) 161 } 162 163 func parseInterfaceName(namePair string) (ifaceName, mockName string) { 164 parts := strings.SplitN(namePair, ":", 2) 165 if len(parts) == 2 { 166 return parts[0], parts[1] 167 } 168 169 ifaceName = parts[0] 170 return ifaceName, ifaceName + "Mock" 171 }