github.com/relnod/pegomock@v2.0.1+incompatible/modelgen/gomock/parse.go (about)

     1  // Copyright 2012 Google Inc.
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //     http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License.
    14  
    15  package gomock
    16  
    17  // This file contains the model construction by parsing source files.
    18  
    19  import (
    20  	"flag"
    21  	"fmt"
    22  	"go/ast"
    23  	"go/parser"
    24  	"go/token"
    25  	"log"
    26  	"path"
    27  	"strconv"
    28  	"strings"
    29  
    30  	"github.com/petergtz/pegomock/model"
    31  )
    32  
    33  var (
    34  	imports  = flag.String("imports", "", "(source mode) Comma-separated name=path pairs of explicit imports to use.")
    35  	auxFiles = flag.String("aux_files", "", "(source mode) Comma-separated pkg=path pairs of auxiliary Go source files.")
    36  )
    37  
    38  // TODO: simplify error reporting
    39  
    40  func ParseFile(source string) (*model.Package, error) {
    41  	fs := token.NewFileSet()
    42  	file, err := parser.ParseFile(fs, source, nil, 0)
    43  	if err != nil {
    44  		return nil, fmt.Errorf("failed parsing source file %v: %v", source, err)
    45  	}
    46  
    47  	p := &fileParser{
    48  		fileSet:       fs,
    49  		imports:       make(map[string]string),
    50  		auxInterfaces: make(map[string]map[string]*ast.InterfaceType),
    51  	}
    52  
    53  	// Handle -imports.
    54  	dotImports := make(map[string]bool)
    55  	if *imports != "" {
    56  		for _, kv := range strings.Split(*imports, ",") {
    57  			eq := strings.Index(kv, "=")
    58  			k, v := kv[:eq], kv[eq+1:]
    59  			if k == "." {
    60  				// TODO: Catch dupes?
    61  				dotImports[v] = true
    62  			} else {
    63  				// TODO: Catch dupes?
    64  				p.imports[k] = v
    65  			}
    66  		}
    67  	}
    68  
    69  	// Handle -aux_files.
    70  	if err := p.parseAuxFiles(*auxFiles); err != nil {
    71  		return nil, err
    72  	}
    73  	p.addAuxInterfacesFromFile("", file) // this file
    74  
    75  	pkg, err := p.parseFile(file)
    76  	if err != nil {
    77  		return nil, err
    78  	}
    79  	pkg.DotImports = make([]string, 0, len(dotImports))
    80  	for path := range dotImports {
    81  		pkg.DotImports = append(pkg.DotImports, path)
    82  	}
    83  	return pkg, nil
    84  }
    85  
    86  type fileParser struct {
    87  	fileSet *token.FileSet
    88  	imports map[string]string // package name => import path
    89  
    90  	auxFiles      []*ast.File
    91  	auxInterfaces map[string]map[string]*ast.InterfaceType // package (or "") => name => interface
    92  }
    93  
    94  func (p *fileParser) errorf(pos token.Pos, format string, args ...interface{}) error {
    95  	ps := p.fileSet.Position(pos)
    96  	format = "%s:%d:%d: " + format
    97  	args = append([]interface{}{ps.Filename, ps.Line, ps.Column}, args...)
    98  	return fmt.Errorf(format, args...)
    99  }
   100  
   101  func (p *fileParser) parseAuxFiles(auxFiles string) error {
   102  	auxFiles = strings.TrimSpace(auxFiles)
   103  	if auxFiles == "" {
   104  		return nil
   105  	}
   106  	for _, kv := range strings.Split(auxFiles, ",") {
   107  		parts := strings.SplitN(kv, "=", 2)
   108  		if len(parts) != 2 {
   109  			return fmt.Errorf("bad aux file spec: %v", kv)
   110  		}
   111  		file, err := parser.ParseFile(p.fileSet, parts[1], nil, 0)
   112  		if err != nil {
   113  			return err
   114  		}
   115  		p.auxFiles = append(p.auxFiles, file)
   116  		p.addAuxInterfacesFromFile(parts[0], file)
   117  	}
   118  	return nil
   119  }
   120  
   121  func (p *fileParser) addAuxInterfacesFromFile(pkg string, file *ast.File) {
   122  	if _, ok := p.auxInterfaces[pkg]; !ok {
   123  		p.auxInterfaces[pkg] = make(map[string]*ast.InterfaceType)
   124  	}
   125  	for ni := range iterInterfaces(file) {
   126  		p.auxInterfaces[pkg][ni.name.Name] = ni.it
   127  	}
   128  }
   129  
   130  func (p *fileParser) parseFile(file *ast.File) (*model.Package, error) {
   131  	allImports := importsOfFile(file)
   132  	// Don't stomp imports provided by -imports. Those should take precedence.
   133  	for pkg, path := range allImports {
   134  		if _, ok := p.imports[pkg]; !ok {
   135  			p.imports[pkg] = path
   136  		}
   137  	}
   138  	// Add imports from auxiliary files, which might be needed for embedded interfaces.
   139  	// Don't stomp any other imports.
   140  	for _, f := range p.auxFiles {
   141  		for pkg, path := range importsOfFile(f) {
   142  			if _, ok := p.imports[pkg]; !ok {
   143  				p.imports[pkg] = path
   144  			}
   145  		}
   146  	}
   147  
   148  	var is []*model.Interface
   149  	for ni := range iterInterfaces(file) {
   150  		i, err := p.parseInterface(ni.name.String(), "", ni.it)
   151  		if err != nil {
   152  			return nil, err
   153  		}
   154  		is = append(is, i)
   155  	}
   156  	return &model.Package{
   157  		Name:       file.Name.String(),
   158  		Interfaces: is,
   159  	}, nil
   160  }
   161  
   162  func (p *fileParser) parseInterface(name, pkg string, it *ast.InterfaceType) (*model.Interface, error) {
   163  	intf := &model.Interface{Name: name}
   164  	for _, field := range it.Methods.List {
   165  		switch v := field.Type.(type) {
   166  		case *ast.FuncType:
   167  			if nn := len(field.Names); nn != 1 {
   168  				return nil, fmt.Errorf("expected one name for interface %v, got %d", intf.Name, nn)
   169  			}
   170  			m := &model.Method{
   171  				Name: field.Names[0].String(),
   172  			}
   173  			var err error
   174  			m.In, m.Variadic, m.Out, err = p.parseFunc(pkg, v)
   175  			if err != nil {
   176  				return nil, err
   177  			}
   178  			intf.Methods = append(intf.Methods, m)
   179  		case *ast.Ident:
   180  			// Embedded interface in this package.
   181  			ei := p.auxInterfaces[""][v.String()]
   182  			if ei == nil {
   183  				return nil, p.errorf(v.Pos(), "unknown embedded interface %s", v.String())
   184  			}
   185  			eintf, err := p.parseInterface(v.String(), pkg, ei)
   186  			if err != nil {
   187  				return nil, err
   188  			}
   189  			// Copy the methods.
   190  			// TODO: apply shadowing rules.
   191  			for _, m := range eintf.Methods {
   192  				intf.Methods = append(intf.Methods, m)
   193  			}
   194  		case *ast.SelectorExpr:
   195  			// Embedded interface in another package.
   196  			fpkg, sel := v.X.(*ast.Ident).String(), v.Sel.String()
   197  			ei := p.auxInterfaces[fpkg][sel]
   198  			if ei == nil {
   199  				return nil, p.errorf(v.Pos(), "unknown embedded interface %s.%s", fpkg, sel)
   200  			}
   201  			epkg, ok := p.imports[fpkg]
   202  			if !ok {
   203  				return nil, p.errorf(v.X.Pos(), "unknown package %s", fpkg)
   204  			}
   205  			eintf, err := p.parseInterface(sel, epkg, ei)
   206  			if err != nil {
   207  				return nil, err
   208  			}
   209  			// Copy the methods.
   210  			// TODO: apply shadowing rules.
   211  			for _, m := range eintf.Methods {
   212  				intf.Methods = append(intf.Methods, m)
   213  			}
   214  		default:
   215  			return nil, fmt.Errorf("don't know how to mock method of type %T", field.Type)
   216  		}
   217  	}
   218  	return intf, nil
   219  }
   220  
   221  func (p *fileParser) parseFunc(pkg string, f *ast.FuncType) (in []*model.Parameter, variadic *model.Parameter, out []*model.Parameter, err error) {
   222  	if f.Params != nil {
   223  		regParams := f.Params.List
   224  		if isVariadic(f) {
   225  			n := len(regParams)
   226  			varParams := regParams[n-1:]
   227  			regParams = regParams[:n-1]
   228  			vp, err := p.parseFieldList(pkg, varParams)
   229  			if err != nil {
   230  				return nil, nil, nil, p.errorf(varParams[0].Pos(), "failed parsing variadic argument: %v", err)
   231  			}
   232  			variadic = vp[0]
   233  		}
   234  		in, err = p.parseFieldList(pkg, regParams)
   235  		if err != nil {
   236  			return nil, nil, nil, p.errorf(f.Pos(), "failed parsing arguments: %v", err)
   237  		}
   238  	}
   239  	if f.Results != nil {
   240  		out, err = p.parseFieldList(pkg, f.Results.List)
   241  		if err != nil {
   242  			return nil, nil, nil, p.errorf(f.Pos(), "failed parsing returns: %v", err)
   243  		}
   244  	}
   245  	return
   246  }
   247  
   248  func (p *fileParser) parseFieldList(pkg string, fields []*ast.Field) ([]*model.Parameter, error) {
   249  	nf := 0
   250  	for _, f := range fields {
   251  		nn := len(f.Names)
   252  		if nn == 0 {
   253  			nn = 1 // anonymous parameter
   254  		}
   255  		nf += nn
   256  	}
   257  	if nf == 0 {
   258  		return nil, nil
   259  	}
   260  	ps := make([]*model.Parameter, nf)
   261  	i := 0 // destination index
   262  	for _, f := range fields {
   263  		t, err := p.parseType(pkg, f.Type)
   264  		if err != nil {
   265  			return nil, err
   266  		}
   267  
   268  		if len(f.Names) == 0 {
   269  			// anonymous arg
   270  			ps[i] = &model.Parameter{Type: t}
   271  			i++
   272  			continue
   273  		}
   274  		for _, name := range f.Names {
   275  			ps[i] = &model.Parameter{Name: name.Name, Type: t}
   276  			i++
   277  		}
   278  	}
   279  	return ps, nil
   280  }
   281  
   282  func (p *fileParser) parseType(pkg string, typ ast.Expr) (model.Type, error) {
   283  	switch v := typ.(type) {
   284  	case *ast.ArrayType:
   285  		ln := -1
   286  		if v.Len != nil {
   287  			x, err := strconv.Atoi(v.Len.(*ast.BasicLit).Value)
   288  			if err != nil {
   289  				return nil, p.errorf(v.Len.Pos(), "bad array size: %v", err)
   290  			}
   291  			ln = x
   292  		}
   293  		t, err := p.parseType(pkg, v.Elt)
   294  		if err != nil {
   295  			return nil, err
   296  		}
   297  		return &model.ArrayType{Len: ln, Type: t}, nil
   298  	case *ast.ChanType:
   299  		t, err := p.parseType(pkg, v.Value)
   300  		if err != nil {
   301  			return nil, err
   302  		}
   303  		var dir model.ChanDir
   304  		if v.Dir == ast.SEND {
   305  			dir = model.SendDir
   306  		}
   307  		if v.Dir == ast.RECV {
   308  			dir = model.RecvDir
   309  		}
   310  		return &model.ChanType{Dir: dir, Type: t}, nil
   311  	case *ast.Ellipsis:
   312  		// assume we're parsing a variadic argument
   313  		return p.parseType(pkg, v.Elt)
   314  	case *ast.FuncType:
   315  		in, variadic, out, err := p.parseFunc(pkg, v)
   316  		if err != nil {
   317  			return nil, err
   318  		}
   319  		return &model.FuncType{In: in, Out: out, Variadic: variadic}, nil
   320  	case *ast.Ident:
   321  		if v.IsExported() {
   322  			// assume type in this package
   323  			return &model.NamedType{Package: pkg, Type: v.Name}, nil
   324  		} else {
   325  			// assume predeclared type
   326  			return model.PredeclaredType(v.Name), nil
   327  		}
   328  	case *ast.InterfaceType:
   329  		if v.Methods != nil && len(v.Methods.List) > 0 {
   330  			return nil, p.errorf(v.Pos(), "can't handle non-empty unnamed interface types")
   331  		}
   332  		return model.PredeclaredType("interface{}"), nil
   333  	case *ast.MapType:
   334  		key, err := p.parseType(pkg, v.Key)
   335  		if err != nil {
   336  			return nil, err
   337  		}
   338  		value, err := p.parseType(pkg, v.Value)
   339  		if err != nil {
   340  			return nil, err
   341  		}
   342  		return &model.MapType{Key: key, Value: value}, nil
   343  	case *ast.SelectorExpr:
   344  		pkgName := v.X.(*ast.Ident).String()
   345  		pkg, ok := p.imports[pkgName]
   346  		if !ok {
   347  			return nil, p.errorf(v.Pos(), "unknown package %q", pkgName)
   348  		}
   349  		return &model.NamedType{Package: pkg, Type: v.Sel.String()}, nil
   350  	case *ast.StarExpr:
   351  		t, err := p.parseType(pkg, v.X)
   352  		if err != nil {
   353  			return nil, err
   354  		}
   355  		return &model.PointerType{Type: t}, nil
   356  	case *ast.StructType:
   357  		if v.Fields != nil && len(v.Fields.List) > 0 {
   358  			return nil, p.errorf(v.Pos(), "can't handle non-empty unnamed struct types")
   359  		}
   360  		return model.PredeclaredType("struct{}"), nil
   361  	}
   362  
   363  	return nil, fmt.Errorf("don't know how to parse type %T", typ)
   364  }
   365  
   366  // importsOfFile returns a map of package name to import path
   367  // of the imports in file.
   368  func importsOfFile(file *ast.File) map[string]string {
   369  	/* We have to make guesses about some imports, because imports are not required
   370  	 * to have names. Named imports are always certain. Unnamed imports are guessed
   371  	 * to have a name of the last path component; if the last path component has dots,
   372  	 * the first dot-delimited field is used as the name.
   373  	 */
   374  
   375  	m := make(map[string]string)
   376  	for _, decl := range file.Decls {
   377  		gd, ok := decl.(*ast.GenDecl)
   378  		if !ok || gd.Tok != token.IMPORT {
   379  			continue
   380  		}
   381  		for _, spec := range gd.Specs {
   382  			is, ok := spec.(*ast.ImportSpec)
   383  			if !ok {
   384  				continue
   385  			}
   386  			pkg, importPath := "", string(is.Path.Value)
   387  			importPath = importPath[1 : len(importPath)-1] // remove quotes
   388  
   389  			if is.Name != nil {
   390  				if is.Name.Name == "_" {
   391  					continue
   392  				}
   393  				pkg = removeDot(is.Name.Name)
   394  			} else {
   395  				_, last := path.Split(importPath)
   396  				pkg = strings.SplitN(last, ".", 2)[0]
   397  			}
   398  			if _, ok := m[pkg]; ok {
   399  				log.Fatalf("imported package collision: %q imported twice", pkg)
   400  			}
   401  			m[pkg] = importPath
   402  		}
   403  	}
   404  	return m
   405  }
   406  
   407  func removeDot(s string) string {
   408  	if len(s) > 0 && s[len(s)-1] == '.' {
   409  		return s[0 : len(s)-1]
   410  	}
   411  	return s
   412  }
   413  
   414  type namedInterface struct {
   415  	name *ast.Ident
   416  	it   *ast.InterfaceType
   417  }
   418  
   419  // Create an iterator over all interfaces in file.
   420  func iterInterfaces(file *ast.File) <-chan namedInterface {
   421  	ch := make(chan namedInterface)
   422  	go func() {
   423  		for _, decl := range file.Decls {
   424  			gd, ok := decl.(*ast.GenDecl)
   425  			if !ok || gd.Tok != token.TYPE {
   426  				continue
   427  			}
   428  			for _, spec := range gd.Specs {
   429  				ts, ok := spec.(*ast.TypeSpec)
   430  				if !ok {
   431  					continue
   432  				}
   433  				it, ok := ts.Type.(*ast.InterfaceType)
   434  				if !ok {
   435  					continue
   436  				}
   437  
   438  				ch <- namedInterface{ts.Name, it}
   439  			}
   440  		}
   441  		close(ch)
   442  	}()
   443  	return ch
   444  }
   445  
   446  // isVariadic returns whether the function is variadic.
   447  func isVariadic(f *ast.FuncType) bool {
   448  	nargs := len(f.Params.List)
   449  	if nargs == 0 {
   450  		return false
   451  	}
   452  	_, ok := f.Params.List[nargs-1].Type.(*ast.Ellipsis)
   453  	return ok
   454  }