github.com/golang/mock@v1.6.0/mockgen/model/model.go (about)

     1  // Copyright 2012 Google Inc.
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //     http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License.
    14  
    15  // Package model contains the data model necessary for generating mock implementations.
    16  package model
    17  
    18  import (
    19  	"encoding/gob"
    20  	"fmt"
    21  	"io"
    22  	"reflect"
    23  	"strings"
    24  )
    25  
    26  // pkgPath is the importable path for package model
    27  const pkgPath = "github.com/golang/mock/mockgen/model"
    28  
    29  // Package is a Go package. It may be a subset.
    30  type Package struct {
    31  	Name       string
    32  	PkgPath    string
    33  	Interfaces []*Interface
    34  	DotImports []string
    35  }
    36  
    37  // Print writes the package name and its exported interfaces.
    38  func (pkg *Package) Print(w io.Writer) {
    39  	_, _ = fmt.Fprintf(w, "package %s\n", pkg.Name)
    40  	for _, intf := range pkg.Interfaces {
    41  		intf.Print(w)
    42  	}
    43  }
    44  
    45  // Imports returns the imports needed by the Package as a set of import paths.
    46  func (pkg *Package) Imports() map[string]bool {
    47  	im := make(map[string]bool)
    48  	for _, intf := range pkg.Interfaces {
    49  		intf.addImports(im)
    50  	}
    51  	return im
    52  }
    53  
    54  // Interface is a Go interface.
    55  type Interface struct {
    56  	Name    string
    57  	Methods []*Method
    58  }
    59  
    60  // Print writes the interface name and its methods.
    61  func (intf *Interface) Print(w io.Writer) {
    62  	_, _ = fmt.Fprintf(w, "interface %s\n", intf.Name)
    63  	for _, m := range intf.Methods {
    64  		m.Print(w)
    65  	}
    66  }
    67  
    68  func (intf *Interface) addImports(im map[string]bool) {
    69  	for _, m := range intf.Methods {
    70  		m.addImports(im)
    71  	}
    72  }
    73  
    74  // AddMethod adds a new method, de-duplicating by method name.
    75  func (intf *Interface) AddMethod(m *Method) {
    76  	for _, me := range intf.Methods {
    77  		if me.Name == m.Name {
    78  			return
    79  		}
    80  	}
    81  	intf.Methods = append(intf.Methods, m)
    82  }
    83  
    84  // Method is a single method of an interface.
    85  type Method struct {
    86  	Name     string
    87  	In, Out  []*Parameter
    88  	Variadic *Parameter // may be nil
    89  }
    90  
    91  // Print writes the method name and its signature.
    92  func (m *Method) Print(w io.Writer) {
    93  	_, _ = fmt.Fprintf(w, "  - method %s\n", m.Name)
    94  	if len(m.In) > 0 {
    95  		_, _ = fmt.Fprintf(w, "    in:\n")
    96  		for _, p := range m.In {
    97  			p.Print(w)
    98  		}
    99  	}
   100  	if m.Variadic != nil {
   101  		_, _ = fmt.Fprintf(w, "    ...:\n")
   102  		m.Variadic.Print(w)
   103  	}
   104  	if len(m.Out) > 0 {
   105  		_, _ = fmt.Fprintf(w, "    out:\n")
   106  		for _, p := range m.Out {
   107  			p.Print(w)
   108  		}
   109  	}
   110  }
   111  
   112  func (m *Method) addImports(im map[string]bool) {
   113  	for _, p := range m.In {
   114  		p.Type.addImports(im)
   115  	}
   116  	if m.Variadic != nil {
   117  		m.Variadic.Type.addImports(im)
   118  	}
   119  	for _, p := range m.Out {
   120  		p.Type.addImports(im)
   121  	}
   122  }
   123  
   124  // Parameter is an argument or return parameter of a method.
   125  type Parameter struct {
   126  	Name string // may be empty
   127  	Type Type
   128  }
   129  
   130  // Print writes a method parameter.
   131  func (p *Parameter) Print(w io.Writer) {
   132  	n := p.Name
   133  	if n == "" {
   134  		n = `""`
   135  	}
   136  	_, _ = fmt.Fprintf(w, "    - %v: %v\n", n, p.Type.String(nil, ""))
   137  }
   138  
   139  // Type is a Go type.
   140  type Type interface {
   141  	String(pm map[string]string, pkgOverride string) string
   142  	addImports(im map[string]bool)
   143  }
   144  
   145  func init() {
   146  	gob.Register(&ArrayType{})
   147  	gob.Register(&ChanType{})
   148  	gob.Register(&FuncType{})
   149  	gob.Register(&MapType{})
   150  	gob.Register(&NamedType{})
   151  	gob.Register(&PointerType{})
   152  
   153  	// Call gob.RegisterName to make sure it has the consistent name registered
   154  	// for both gob decoder and encoder.
   155  	//
   156  	// For a non-pointer type, gob.Register will try to get package full path by
   157  	// calling rt.PkgPath() for a name to register. If your project has vendor
   158  	// directory, it is possible that PkgPath will get a path like this:
   159  	//     ../../../vendor/github.com/golang/mock/mockgen/model
   160  	gob.RegisterName(pkgPath+".PredeclaredType", PredeclaredType(""))
   161  }
   162  
   163  // ArrayType is an array or slice type.
   164  type ArrayType struct {
   165  	Len  int // -1 for slices, >= 0 for arrays
   166  	Type Type
   167  }
   168  
   169  func (at *ArrayType) String(pm map[string]string, pkgOverride string) string {
   170  	s := "[]"
   171  	if at.Len > -1 {
   172  		s = fmt.Sprintf("[%d]", at.Len)
   173  	}
   174  	return s + at.Type.String(pm, pkgOverride)
   175  }
   176  
   177  func (at *ArrayType) addImports(im map[string]bool) { at.Type.addImports(im) }
   178  
   179  // ChanType is a channel type.
   180  type ChanType struct {
   181  	Dir  ChanDir // 0, 1 or 2
   182  	Type Type
   183  }
   184  
   185  func (ct *ChanType) String(pm map[string]string, pkgOverride string) string {
   186  	s := ct.Type.String(pm, pkgOverride)
   187  	if ct.Dir == RecvDir {
   188  		return "<-chan " + s
   189  	}
   190  	if ct.Dir == SendDir {
   191  		return "chan<- " + s
   192  	}
   193  	return "chan " + s
   194  }
   195  
   196  func (ct *ChanType) addImports(im map[string]bool) { ct.Type.addImports(im) }
   197  
   198  // ChanDir is a channel direction.
   199  type ChanDir int
   200  
   201  // Constants for channel directions.
   202  const (
   203  	RecvDir ChanDir = 1
   204  	SendDir ChanDir = 2
   205  )
   206  
   207  // FuncType is a function type.
   208  type FuncType struct {
   209  	In, Out  []*Parameter
   210  	Variadic *Parameter // may be nil
   211  }
   212  
   213  func (ft *FuncType) String(pm map[string]string, pkgOverride string) string {
   214  	args := make([]string, len(ft.In))
   215  	for i, p := range ft.In {
   216  		args[i] = p.Type.String(pm, pkgOverride)
   217  	}
   218  	if ft.Variadic != nil {
   219  		args = append(args, "..."+ft.Variadic.Type.String(pm, pkgOverride))
   220  	}
   221  	rets := make([]string, len(ft.Out))
   222  	for i, p := range ft.Out {
   223  		rets[i] = p.Type.String(pm, pkgOverride)
   224  	}
   225  	retString := strings.Join(rets, ", ")
   226  	if nOut := len(ft.Out); nOut == 1 {
   227  		retString = " " + retString
   228  	} else if nOut > 1 {
   229  		retString = " (" + retString + ")"
   230  	}
   231  	return "func(" + strings.Join(args, ", ") + ")" + retString
   232  }
   233  
   234  func (ft *FuncType) addImports(im map[string]bool) {
   235  	for _, p := range ft.In {
   236  		p.Type.addImports(im)
   237  	}
   238  	if ft.Variadic != nil {
   239  		ft.Variadic.Type.addImports(im)
   240  	}
   241  	for _, p := range ft.Out {
   242  		p.Type.addImports(im)
   243  	}
   244  }
   245  
   246  // MapType is a map type.
   247  type MapType struct {
   248  	Key, Value Type
   249  }
   250  
   251  func (mt *MapType) String(pm map[string]string, pkgOverride string) string {
   252  	return "map[" + mt.Key.String(pm, pkgOverride) + "]" + mt.Value.String(pm, pkgOverride)
   253  }
   254  
   255  func (mt *MapType) addImports(im map[string]bool) {
   256  	mt.Key.addImports(im)
   257  	mt.Value.addImports(im)
   258  }
   259  
   260  // NamedType is an exported type in a package.
   261  type NamedType struct {
   262  	Package string // may be empty
   263  	Type    string
   264  }
   265  
   266  func (nt *NamedType) String(pm map[string]string, pkgOverride string) string {
   267  	if pkgOverride == nt.Package {
   268  		return nt.Type
   269  	}
   270  	prefix := pm[nt.Package]
   271  	if prefix != "" {
   272  		return prefix + "." + nt.Type
   273  	}
   274  
   275  	return nt.Type
   276  }
   277  
   278  func (nt *NamedType) addImports(im map[string]bool) {
   279  	if nt.Package != "" {
   280  		im[nt.Package] = true
   281  	}
   282  }
   283  
   284  // PointerType is a pointer to another type.
   285  type PointerType struct {
   286  	Type Type
   287  }
   288  
   289  func (pt *PointerType) String(pm map[string]string, pkgOverride string) string {
   290  	return "*" + pt.Type.String(pm, pkgOverride)
   291  }
   292  func (pt *PointerType) addImports(im map[string]bool) { pt.Type.addImports(im) }
   293  
   294  // PredeclaredType is a predeclared type such as "int".
   295  type PredeclaredType string
   296  
   297  func (pt PredeclaredType) String(map[string]string, string) string { return string(pt) }
   298  func (pt PredeclaredType) addImports(map[string]bool)              {}
   299  
   300  // The following code is intended to be called by the program generated by ../reflect.go.
   301  
   302  // InterfaceFromInterfaceType returns a pointer to an interface for the
   303  // given reflection interface type.
   304  func InterfaceFromInterfaceType(it reflect.Type) (*Interface, error) {
   305  	if it.Kind() != reflect.Interface {
   306  		return nil, fmt.Errorf("%v is not an interface", it)
   307  	}
   308  	intf := &Interface{}
   309  
   310  	for i := 0; i < it.NumMethod(); i++ {
   311  		mt := it.Method(i)
   312  		// TODO: need to skip unexported methods? or just raise an error?
   313  		m := &Method{
   314  			Name: mt.Name,
   315  		}
   316  
   317  		var err error
   318  		m.In, m.Variadic, m.Out, err = funcArgsFromType(mt.Type)
   319  		if err != nil {
   320  			return nil, err
   321  		}
   322  
   323  		intf.AddMethod(m)
   324  	}
   325  
   326  	return intf, nil
   327  }
   328  
   329  // t's Kind must be a reflect.Func.
   330  func funcArgsFromType(t reflect.Type) (in []*Parameter, variadic *Parameter, out []*Parameter, err error) {
   331  	nin := t.NumIn()
   332  	if t.IsVariadic() {
   333  		nin--
   334  	}
   335  	var p *Parameter
   336  	for i := 0; i < nin; i++ {
   337  		p, err = parameterFromType(t.In(i))
   338  		if err != nil {
   339  			return
   340  		}
   341  		in = append(in, p)
   342  	}
   343  	if t.IsVariadic() {
   344  		p, err = parameterFromType(t.In(nin).Elem())
   345  		if err != nil {
   346  			return
   347  		}
   348  		variadic = p
   349  	}
   350  	for i := 0; i < t.NumOut(); i++ {
   351  		p, err = parameterFromType(t.Out(i))
   352  		if err != nil {
   353  			return
   354  		}
   355  		out = append(out, p)
   356  	}
   357  	return
   358  }
   359  
   360  func parameterFromType(t reflect.Type) (*Parameter, error) {
   361  	tt, err := typeFromType(t)
   362  	if err != nil {
   363  		return nil, err
   364  	}
   365  	return &Parameter{Type: tt}, nil
   366  }
   367  
   368  var errorType = reflect.TypeOf((*error)(nil)).Elem()
   369  
   370  var byteType = reflect.TypeOf(byte(0))
   371  
   372  func typeFromType(t reflect.Type) (Type, error) {
   373  	// Hack workaround for https://golang.org/issue/3853.
   374  	// This explicit check should not be necessary.
   375  	if t == byteType {
   376  		return PredeclaredType("byte"), nil
   377  	}
   378  
   379  	if imp := t.PkgPath(); imp != "" {
   380  		return &NamedType{
   381  			Package: impPath(imp),
   382  			Type:    t.Name(),
   383  		}, nil
   384  	}
   385  
   386  	// only unnamed or predeclared types after here
   387  
   388  	// Lots of types have element types. Let's do the parsing and error checking for all of them.
   389  	var elemType Type
   390  	switch t.Kind() {
   391  	case reflect.Array, reflect.Chan, reflect.Map, reflect.Ptr, reflect.Slice:
   392  		var err error
   393  		elemType, err = typeFromType(t.Elem())
   394  		if err != nil {
   395  			return nil, err
   396  		}
   397  	}
   398  
   399  	switch t.Kind() {
   400  	case reflect.Array:
   401  		return &ArrayType{
   402  			Len:  t.Len(),
   403  			Type: elemType,
   404  		}, nil
   405  	case reflect.Bool, reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64,
   406  		reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uintptr,
   407  		reflect.Float32, reflect.Float64, reflect.Complex64, reflect.Complex128, reflect.String:
   408  		return PredeclaredType(t.Kind().String()), nil
   409  	case reflect.Chan:
   410  		var dir ChanDir
   411  		switch t.ChanDir() {
   412  		case reflect.RecvDir:
   413  			dir = RecvDir
   414  		case reflect.SendDir:
   415  			dir = SendDir
   416  		}
   417  		return &ChanType{
   418  			Dir:  dir,
   419  			Type: elemType,
   420  		}, nil
   421  	case reflect.Func:
   422  		in, variadic, out, err := funcArgsFromType(t)
   423  		if err != nil {
   424  			return nil, err
   425  		}
   426  		return &FuncType{
   427  			In:       in,
   428  			Out:      out,
   429  			Variadic: variadic,
   430  		}, nil
   431  	case reflect.Interface:
   432  		// Two special interfaces.
   433  		if t.NumMethod() == 0 {
   434  			return PredeclaredType("interface{}"), nil
   435  		}
   436  		if t == errorType {
   437  			return PredeclaredType("error"), nil
   438  		}
   439  	case reflect.Map:
   440  		kt, err := typeFromType(t.Key())
   441  		if err != nil {
   442  			return nil, err
   443  		}
   444  		return &MapType{
   445  			Key:   kt,
   446  			Value: elemType,
   447  		}, nil
   448  	case reflect.Ptr:
   449  		return &PointerType{
   450  			Type: elemType,
   451  		}, nil
   452  	case reflect.Slice:
   453  		return &ArrayType{
   454  			Len:  -1,
   455  			Type: elemType,
   456  		}, nil
   457  	case reflect.Struct:
   458  		if t.NumField() == 0 {
   459  			return PredeclaredType("struct{}"), nil
   460  		}
   461  	}
   462  
   463  	// TODO: Struct, UnsafePointer
   464  	return nil, fmt.Errorf("can't yet turn %v (%v) into a model.Type", t, t.Kind())
   465  }
   466  
   467  // impPath sanitizes the package path returned by `PkgPath` method of a reflect Type so that
   468  // it is importable. PkgPath might return a path that includes "vendor". These paths do not
   469  // compile, so we need to remove everything up to and including "/vendor/".
   470  // See https://github.com/golang/go/issues/12019.
   471  func impPath(imp string) string {
   472  	if strings.HasPrefix(imp, "vendor/") {
   473  		imp = "/" + imp
   474  	}
   475  	if i := strings.LastIndex(imp, "/vendor/"); i != -1 {
   476  		imp = imp[i+len("/vendor/"):]
   477  	}
   478  	return imp
   479  }
   480  
   481  // ErrorInterface represent built-in error interface.
   482  var ErrorInterface = Interface{
   483  	Name: "error",
   484  	Methods: []*Method{
   485  		{
   486  			Name: "Error",
   487  			Out: []*Parameter{
   488  				{
   489  					Name: "",
   490  					Type: PredeclaredType("string"),
   491  				},
   492  			},
   493  		},
   494  	},
   495  }