github.com/djui/moq@v0.3.3/pkg/moq/moq.go (about) 1 package moq 2 3 import ( 4 "bytes" 5 "errors" 6 "go/token" 7 "go/types" 8 "io" 9 "strings" 10 11 "github.com/djui/moq/internal/registry" 12 "github.com/djui/moq/internal/template" 13 ) 14 15 // Mocker can generate mock structs. 16 type Mocker struct { 17 cfg Config 18 19 registry *registry.Registry 20 tmpl template.Template 21 } 22 23 // Config specifies details about how interfaces should be mocked. 24 // SrcDir is the only field which needs be specified. 25 type Config struct { 26 SrcDir string 27 PkgName string 28 Formatter string 29 StubImpl bool 30 SkipEnsure bool 31 } 32 33 // New makes a new Mocker for the specified package directory. 34 func New(cfg Config) (*Mocker, error) { 35 reg, err := registry.New(cfg.SrcDir, cfg.PkgName) 36 if err != nil { 37 return nil, err 38 } 39 40 tmpl, err := template.New() 41 if err != nil { 42 return nil, err 43 } 44 45 return &Mocker{ 46 cfg: cfg, 47 registry: reg, 48 tmpl: tmpl, 49 }, nil 50 } 51 52 // Mock generates a mock for the specified interface name. 53 func (m *Mocker) Mock(w io.Writer, namePairs ...string) error { 54 if len(namePairs) == 0 { 55 return errors.New("must specify one interface") 56 } 57 58 mocks := make([]template.MockData, len(namePairs)) 59 for i, np := range namePairs { 60 name, mockName := parseInterfaceName(np) 61 iface, tparams, err := m.registry.LookupInterface(name) 62 if err != nil { 63 return err 64 } 65 66 methods := make([]template.MethodData, iface.NumMethods()) 67 for j := 0; j < iface.NumMethods(); j++ { 68 methods[j] = m.methodData(iface.Method(j)) 69 } 70 71 mocks[i] = template.MockData{ 72 InterfaceName: name, 73 MockName: mockName, 74 Methods: methods, 75 TypeParams: m.typeParams(tparams), 76 } 77 } 78 79 data := template.Data{ 80 PkgName: m.mockPkgName(), 81 Mocks: mocks, 82 StubImpl: m.cfg.StubImpl, 83 SkipEnsure: m.cfg.SkipEnsure, 84 } 85 86 if data.MocksSomeMethod() { 87 m.registry.AddImport(types.NewPackage("sync", "sync")) 88 } 89 if m.registry.SrcPkgName() != m.mockPkgName() { 90 data.SrcPkgQualifier = m.registry.SrcPkgName() + "." 91 if !m.cfg.SkipEnsure { 92 imprt := m.registry.AddImport(m.registry.SrcPkg()) 93 data.SrcPkgQualifier = imprt.Qualifier() + "." 94 } 95 } 96 97 data.Imports = m.registry.Imports() 98 99 var buf bytes.Buffer 100 if err := m.tmpl.Execute(&buf, data); err != nil { 101 return err 102 } 103 104 formatted, err := m.format(buf.Bytes()) 105 if err != nil { 106 return err 107 } 108 109 if _, err := w.Write(formatted); err != nil { 110 return err 111 } 112 return nil 113 } 114 115 func (m *Mocker) typeParams(tparams *types.TypeParamList) []template.TypeParamData { 116 var tpd []template.TypeParamData 117 if tparams == nil { 118 return tpd 119 } 120 121 tpd = make([]template.TypeParamData, tparams.Len()) 122 123 scope := m.registry.MethodScope() 124 for i := 0; i < len(tpd); i++ { 125 tp := tparams.At(i) 126 typeParam := types.NewParam(token.Pos(i), tp.Obj().Pkg(), tp.Obj().Name(), tp.Constraint()) 127 tpd[i] = template.TypeParamData{ 128 ParamData: template.ParamData{Var: scope.AddVar(typeParam, "")}, 129 Constraint: explicitConstraintType(typeParam), 130 } 131 } 132 133 return tpd 134 } 135 136 func explicitConstraintType(typeParam *types.Var) (t types.Type) { 137 underlying := typeParam.Type().Underlying().(*types.Interface) 138 // check if any of the embedded types is either a basic type or a union, 139 // because the generic type has to be an alias for one of those types then 140 for j := 0; j < underlying.NumEmbeddeds(); j++ { 141 t := underlying.EmbeddedType(j) 142 switch t := t.(type) { 143 case *types.Basic: 144 return t 145 case *types.Union: // only unions of basic types are allowed, so just take the first one as a valid type constraint 146 return t.Term(0).Type() 147 } 148 } 149 return nil 150 } 151 152 func (m *Mocker) methodData(f *types.Func) template.MethodData { 153 sig := f.Type().(*types.Signature) 154 155 scope := m.registry.MethodScope() 156 n := sig.Params().Len() 157 params := make([]template.ParamData, n) 158 for i := 0; i < n; i++ { 159 p := template.ParamData{ 160 Var: scope.AddVar(sig.Params().At(i), ""), 161 } 162 p.Variadic = sig.Variadic() && i == n-1 && p.Var.IsSlice() // check for final variadic argument 163 164 params[i] = p 165 } 166 167 n = sig.Results().Len() 168 results := make([]template.ParamData, n) 169 for i := 0; i < n; i++ { 170 results[i] = template.ParamData{ 171 Var: scope.AddVar(sig.Results().At(i), "Out"), 172 } 173 } 174 175 return template.MethodData{ 176 Name: f.Name(), 177 Params: params, 178 Returns: results, 179 } 180 } 181 182 func (m *Mocker) mockPkgName() string { 183 if m.cfg.PkgName != "" { 184 return m.cfg.PkgName 185 } 186 187 return m.registry.SrcPkgName() 188 } 189 190 func (m *Mocker) format(src []byte) ([]byte, error) { 191 switch m.cfg.Formatter { 192 case "goimports": 193 return goimports(src) 194 195 case "noop": 196 return src, nil 197 } 198 199 return gofmt(src) 200 } 201 202 func parseInterfaceName(namePair string) (ifaceName, mockName string) { 203 parts := strings.SplitN(namePair, ":", 2) 204 if len(parts) == 2 { 205 return parts[0], parts[1] 206 } 207 208 ifaceName = parts[0] 209 return ifaceName, ifaceName + "Mock" 210 }