github.com/moitias/moq@v0.0.0-20240223074357-5eb0f0ba4054/pkg/moq/moq.go (about) 1 package moq 2 3 import ( 4 "bytes" 5 "errors" 6 "go/token" 7 "go/types" 8 "io" 9 "log" 10 "strings" 11 12 "github.com/moitias/moq/internal/registry" 13 "github.com/moitias/moq/internal/template" 14 ) 15 16 // Mocker can generate mock structs. 17 type Mocker struct { 18 cfg Config 19 20 registry *registry.Registry 21 tmpl template.Template 22 } 23 24 // Config specifies details about how interfaces should be mocked. 25 // SrcDir is the only field which needs be specified. 26 type Config struct { 27 SrcDir string 28 PkgName string 29 Formatter string 30 StubImpl bool 31 SkipEnsure bool 32 WithResets bool 33 } 34 35 // New makes a new Mocker for the specified package directory. 36 func New(cfg Config) (*Mocker, error) { 37 log.Println("CREATE MOCK REG") 38 reg, err := registry.New(cfg.SrcDir, cfg.PkgName) 39 if err != nil { 40 return nil, err 41 } 42 log.Println("CREATE MOCK TMPL") 43 44 tmpl, err := template.New() 45 if err != nil { 46 return nil, err 47 } 48 49 return &Mocker{ 50 cfg: cfg, 51 registry: reg, 52 tmpl: tmpl, 53 }, nil 54 } 55 56 // Mock generates a mock for the specified interface name. 57 func (m *Mocker) Mock(w io.Writer, namePairs ...string) error { 58 if len(namePairs) == 0 { 59 return errors.New("must specify one interface") 60 } 61 62 mocks := make([]template.MockData, len(namePairs)) 63 for i, np := range namePairs { 64 log.Println("NP", i, np) 65 name, mockName := parseInterfaceName(np) 66 iface, tparams, err := m.registry.LookupInterface(name) 67 if err != nil { 68 return err 69 } 70 71 methods := make([]template.MethodData, iface.NumMethods()) 72 for j := 0; j < iface.NumMethods(); j++ { 73 methods[j] = m.methodData(iface.Method(j)) 74 } 75 76 mocks[i] = template.MockData{ 77 InterfaceName: name, 78 MockName: mockName, 79 Methods: methods, 80 TypeParams: m.typeParams(tparams), 81 } 82 } 83 84 log.Println("DATA") 85 data := template.Data{ 86 PkgName: m.mockPkgName(), 87 Mocks: mocks, 88 StubImpl: m.cfg.StubImpl, 89 SkipEnsure: m.cfg.SkipEnsure, 90 WithResets: m.cfg.WithResets, 91 } 92 93 if data.MocksSomeMethod() { 94 m.registry.AddImport(types.NewPackage("sync", "sync")) 95 } 96 log.Println("PKGNAME") 97 if m.registry.SrcPkgName() != m.mockPkgName() { 98 data.SrcPkgQualifier = m.registry.SrcPkgName() + "." 99 if !m.cfg.SkipEnsure { 100 imprt := m.registry.AddImport(m.registry.SrcPkg()) 101 data.SrcPkgQualifier = imprt.Qualifier() + "." 102 } 103 } 104 105 log.Println("IMPORTS") 106 data.Imports = m.registry.Imports() 107 108 log.Println("TMPL") 109 var buf bytes.Buffer 110 if err := m.tmpl.Execute(&buf, data); err != nil { 111 return err 112 } 113 114 log.Println("FORMAT") 115 116 formatted, err := m.format(buf.Bytes()) 117 if err != nil { 118 return err 119 } 120 121 if _, err := w.Write(formatted); err != nil { 122 return err 123 } 124 return nil 125 } 126 127 func (m *Mocker) typeParams(tparams *types.TypeParamList) []template.TypeParamData { 128 var tpd []template.TypeParamData 129 if tparams == nil { 130 return tpd 131 } 132 133 tpd = make([]template.TypeParamData, tparams.Len()) 134 135 scope := m.registry.MethodScope() 136 for i := 0; i < len(tpd); i++ { 137 tp := tparams.At(i) 138 typeParam := types.NewParam(token.Pos(i), tp.Obj().Pkg(), tp.Obj().Name(), tp.Constraint()) 139 tpd[i] = template.TypeParamData{ 140 ParamData: template.ParamData{Var: scope.AddVar(typeParam, "")}, 141 Constraint: explicitConstraintType(typeParam), 142 } 143 } 144 145 return tpd 146 } 147 148 func explicitConstraintType(typeParam *types.Var) (t types.Type) { 149 underlying := typeParam.Type().Underlying().(*types.Interface) 150 // check if any of the embedded types is either a basic type or a union, 151 // because the generic type has to be an alias for one of those types then 152 for j := 0; j < underlying.NumEmbeddeds(); j++ { 153 t := underlying.EmbeddedType(j) 154 switch t := t.(type) { 155 case *types.Basic: 156 return t 157 case *types.Union: // only unions of basic types are allowed, so just take the first one as a valid type constraint 158 return t.Term(0).Type() 159 } 160 } 161 return nil 162 } 163 164 func (m *Mocker) methodData(f *types.Func) template.MethodData { 165 sig := f.Type().(*types.Signature) 166 167 scope := m.registry.MethodScope() 168 n := sig.Params().Len() 169 params := make([]template.ParamData, n) 170 for i := 0; i < n; i++ { 171 p := template.ParamData{ 172 Var: scope.AddVar(sig.Params().At(i), ""), 173 } 174 p.Variadic = sig.Variadic() && i == n-1 && p.Var.IsSlice() // check for final variadic argument 175 176 params[i] = p 177 } 178 179 n = sig.Results().Len() 180 results := make([]template.ParamData, n) 181 for i := 0; i < n; i++ { 182 results[i] = template.ParamData{ 183 Var: scope.AddVar(sig.Results().At(i), "Out"), 184 } 185 } 186 187 return template.MethodData{ 188 Name: f.Name(), 189 Params: params, 190 Returns: results, 191 } 192 } 193 194 func (m *Mocker) mockPkgName() string { 195 if m.cfg.PkgName != "" { 196 return m.cfg.PkgName 197 } 198 199 return m.registry.SrcPkgName() 200 } 201 202 func (m *Mocker) format(src []byte) ([]byte, error) { 203 switch m.cfg.Formatter { 204 case "goimports": 205 return goimports(src) 206 207 case "noop": 208 return src, nil 209 } 210 211 return gofmt(src) 212 } 213 214 func parseInterfaceName(namePair string) (ifaceName, mockName string) { 215 parts := strings.SplitN(namePair, ":", 2) 216 if len(parts) == 2 { 217 return parts[0], parts[1] 218 } 219 220 ifaceName = parts[0] 221 return ifaceName, ifaceName + "Mock" 222 }