github.com/powerman/golang-tools@v0.1.11-0.20220410185822-5ad214d8d803/internal/lsp/source/stub.go (about)

     1  // Copyright 2022 The Go Authors. All rights reserved.
     2  // Use of this source code is governed by a BSD-style
     3  // license that can be found in the LICENSE file.
     4  
     5  package source
     6  
     7  import (
     8  	"bytes"
     9  	"context"
    10  	"fmt"
    11  	"go/ast"
    12  	"go/format"
    13  	"go/parser"
    14  	"go/token"
    15  	"go/types"
    16  	"strings"
    17  
    18  	"github.com/powerman/golang-tools/go/analysis"
    19  	"github.com/powerman/golang-tools/go/ast/astutil"
    20  	"github.com/powerman/golang-tools/internal/lsp/analysis/stubmethods"
    21  	"github.com/powerman/golang-tools/internal/lsp/protocol"
    22  	"github.com/powerman/golang-tools/internal/span"
    23  	"github.com/powerman/golang-tools/internal/typeparams"
    24  )
    25  
    26  func stubSuggestedFixFunc(ctx context.Context, snapshot Snapshot, fh VersionedFileHandle, rng protocol.Range) (*analysis.SuggestedFix, error) {
    27  	pkg, pgf, err := GetParsedFile(ctx, snapshot, fh, NarrowestPackage)
    28  	if err != nil {
    29  		return nil, fmt.Errorf("GetParsedFile: %w", err)
    30  	}
    31  	nodes, pos, err := getStubNodes(pgf, rng)
    32  	if err != nil {
    33  		return nil, fmt.Errorf("getNodes: %w", err)
    34  	}
    35  	si := stubmethods.GetStubInfo(pkg.GetTypesInfo(), nodes, pos)
    36  	if si == nil {
    37  		return nil, fmt.Errorf("nil interface request")
    38  	}
    39  	parsedConcreteFile, concreteFH, err := getStubFile(ctx, si.Concrete.Obj(), snapshot)
    40  	if err != nil {
    41  		return nil, fmt.Errorf("getFile(concrete): %w", err)
    42  	}
    43  	var (
    44  		methodsSrc  []byte
    45  		stubImports []*stubImport // additional imports needed for method stubs
    46  	)
    47  	if si.Interface.Pkg() == nil && si.Interface.Name() == "error" && si.Interface.Parent() == types.Universe {
    48  		methodsSrc = stubErr(ctx, parsedConcreteFile.File, si, snapshot)
    49  	} else {
    50  		methodsSrc, stubImports, err = stubMethods(ctx, parsedConcreteFile.File, si, snapshot)
    51  	}
    52  	if err != nil {
    53  		return nil, fmt.Errorf("stubMethods: %w", err)
    54  	}
    55  	nodes, _ = astutil.PathEnclosingInterval(parsedConcreteFile.File, si.Concrete.Obj().Pos(), si.Concrete.Obj().Pos())
    56  	concreteSrc, err := concreteFH.Read()
    57  	if err != nil {
    58  		return nil, fmt.Errorf("error reading concrete file source: %w", err)
    59  	}
    60  	insertPos := snapshot.FileSet().Position(nodes[1].End()).Offset
    61  	if insertPos >= len(concreteSrc) {
    62  		return nil, fmt.Errorf("insertion position is past the end of the file")
    63  	}
    64  	var buf bytes.Buffer
    65  	buf.Write(concreteSrc[:insertPos])
    66  	buf.WriteByte('\n')
    67  	buf.Write(methodsSrc)
    68  	buf.Write(concreteSrc[insertPos:])
    69  	fset := token.NewFileSet()
    70  	newF, err := parser.ParseFile(fset, parsedConcreteFile.File.Name.Name, buf.Bytes(), parser.ParseComments)
    71  	if err != nil {
    72  		return nil, fmt.Errorf("could not reparse file: %w", err)
    73  	}
    74  	for _, imp := range stubImports {
    75  		astutil.AddNamedImport(fset, newF, imp.Name, imp.Path)
    76  	}
    77  	var source bytes.Buffer
    78  	err = format.Node(&source, fset, newF)
    79  	if err != nil {
    80  		return nil, fmt.Errorf("format.Node: %w", err)
    81  	}
    82  	diffEdits, err := snapshot.View().Options().ComputeEdits(parsedConcreteFile.URI, string(parsedConcreteFile.Src), source.String())
    83  	if err != nil {
    84  		return nil, err
    85  	}
    86  	var edits []analysis.TextEdit
    87  	for _, edit := range diffEdits {
    88  		rng, err := edit.Span.Range(parsedConcreteFile.Mapper.Converter)
    89  		if err != nil {
    90  			return nil, err
    91  		}
    92  		edits = append(edits, analysis.TextEdit{
    93  			Pos:     rng.Start,
    94  			End:     rng.End,
    95  			NewText: []byte(edit.NewText),
    96  		})
    97  	}
    98  	return &analysis.SuggestedFix{
    99  		TextEdits: edits,
   100  	}, nil
   101  }
   102  
   103  // stubMethods returns the Go code of all methods
   104  // that implement the given interface
   105  func stubMethods(ctx context.Context, concreteFile *ast.File, si *stubmethods.StubInfo, snapshot Snapshot) ([]byte, []*stubImport, error) {
   106  	ifacePkg, err := deducePkgFromTypes(ctx, snapshot, si.Interface)
   107  	if err != nil {
   108  		return nil, nil, err
   109  	}
   110  	si.Concrete.Obj().Type()
   111  	concMS := types.NewMethodSet(types.NewPointer(si.Concrete.Obj().Type()))
   112  	missing, err := missingMethods(ctx, snapshot, concMS, si.Concrete.Obj().Pkg(), si.Interface, ifacePkg, map[string]struct{}{})
   113  	if err != nil {
   114  		return nil, nil, fmt.Errorf("missingMethods: %w", err)
   115  	}
   116  	if len(missing) == 0 {
   117  		return nil, nil, fmt.Errorf("no missing methods found")
   118  	}
   119  	var (
   120  		stubImports   []*stubImport
   121  		methodsBuffer bytes.Buffer
   122  	)
   123  	for _, mi := range missing {
   124  		for _, m := range mi.missing {
   125  			// TODO(marwan-at-work): this should share the same logic with source.FormatVarType
   126  			// as it also accounts for type aliases.
   127  			sig := types.TypeString(m.Type(), stubmethods.RelativeToFiles(si.Concrete.Obj().Pkg(), concreteFile, mi.file, func(name, path string) {
   128  				for _, imp := range stubImports {
   129  					if imp.Name == name && imp.Path == path {
   130  						return
   131  					}
   132  				}
   133  				stubImports = append(stubImports, &stubImport{name, path})
   134  			}))
   135  			_, err = methodsBuffer.Write(printStubMethod(methodData{
   136  				Method:    m.Name(),
   137  				Concrete:  getStubReceiver(si),
   138  				Interface: deduceIfaceName(si.Concrete.Obj().Pkg(), si.Interface.Pkg(), si.Interface),
   139  				Signature: strings.TrimPrefix(sig, "func"),
   140  			}))
   141  			if err != nil {
   142  				return nil, nil, fmt.Errorf("error printing method: %w", err)
   143  			}
   144  			methodsBuffer.WriteRune('\n')
   145  		}
   146  	}
   147  	return methodsBuffer.Bytes(), stubImports, nil
   148  }
   149  
   150  // stubErr reurns the Go code implementation
   151  // of an error interface relevant to the
   152  // concrete type
   153  func stubErr(ctx context.Context, concreteFile *ast.File, si *stubmethods.StubInfo, snapshot Snapshot) []byte {
   154  	return printStubMethod(methodData{
   155  		Method:    "Error",
   156  		Interface: "error",
   157  		Concrete:  getStubReceiver(si),
   158  		Signature: "() string",
   159  	})
   160  }
   161  
   162  // getStubReceiver returns the concrete type's name as a method receiver.
   163  // It accounts for type parameters if they exist.
   164  func getStubReceiver(si *stubmethods.StubInfo) string {
   165  	var concrete string
   166  	if si.Pointer {
   167  		concrete += "*"
   168  	}
   169  	concrete += si.Concrete.Obj().Name()
   170  	concrete += FormatTypeParams(typeparams.ForNamed(si.Concrete))
   171  	return concrete
   172  }
   173  
// methodData holds the text fragments that printStubMethod substitutes into a
// stub method declaration.
type methodData struct {
	Method    string // method name
	Interface string // interface name shown in the stub's comment, e.g. "error" or "pkg.Name"
	Concrete  string // receiver type expression, e.g. "T" or "*T"
	Signature string // parameters and results without the leading "func", e.g. "() string"
}
   180  
   181  // printStubMethod takes methodData and returns Go code that represents the given method such as:
   182  // 	// {{ .Method }} implements {{ .Interface }}
   183  // 	func ({{ .Concrete }}) {{ .Method }}{{ .Signature }} {
   184  // 		panic("unimplemented")
   185  // 	}
   186  func printStubMethod(md methodData) []byte {
   187  	var b bytes.Buffer
   188  	fmt.Fprintf(&b, "// %s implements %s\n", md.Method, md.Interface)
   189  	fmt.Fprintf(&b, "func (%s) %s%s {\n\t", md.Concrete, md.Method, md.Signature)
   190  	fmt.Fprintln(&b, `panic("unimplemented")`)
   191  	fmt.Fprintln(&b, "}")
   192  	return b.Bytes()
   193  }
   194  
   195  func deducePkgFromTypes(ctx context.Context, snapshot Snapshot, ifaceObj types.Object) (Package, error) {
   196  	pkgs, err := snapshot.KnownPackages(ctx)
   197  	if err != nil {
   198  		return nil, err
   199  	}
   200  	for _, p := range pkgs {
   201  		if p.PkgPath() == ifaceObj.Pkg().Path() {
   202  			return p, nil
   203  		}
   204  	}
   205  	return nil, fmt.Errorf("pkg %q not found", ifaceObj.Pkg().Path())
   206  }
   207  
   208  func deduceIfaceName(concretePkg, ifacePkg *types.Package, ifaceObj types.Object) string {
   209  	if concretePkg.Path() == ifacePkg.Path() {
   210  		return ifaceObj.Name()
   211  	}
   212  	return fmt.Sprintf("%s.%s", ifacePkg.Name(), ifaceObj.Name())
   213  }
   214  
   215  func getStubNodes(pgf *ParsedGoFile, pRng protocol.Range) ([]ast.Node, token.Pos, error) {
   216  	spn, err := pgf.Mapper.RangeSpan(pRng)
   217  	if err != nil {
   218  		return nil, 0, err
   219  	}
   220  	rng, err := spn.Range(pgf.Mapper.Converter)
   221  	if err != nil {
   222  		return nil, 0, err
   223  	}
   224  	nodes, _ := astutil.PathEnclosingInterval(pgf.File, rng.Start, rng.End)
   225  	return nodes, rng.Start, nil
   226  }
   227  
   228  /*
   229  missingMethods takes a concrete type and returns any missing methods for the given interface as well as
   230  any missing interface that might have been embedded to its parent. For example:
   231  
   232  type I interface {
   233  	io.Writer
   234  	Hello()
   235  }
   236  returns []*missingInterface{
   237  	{
   238  		iface: *types.Interface (io.Writer),
   239  		file: *ast.File: io.go,
   240  		missing []*types.Func{Write},
   241  	},
   242  	{
   243  		iface: *types.Interface (I),
   244  		file: *ast.File: myfile.go,
   245  		missing: []*types.Func{Hello}
   246  	},
   247  }
   248  */
   249  func missingMethods(ctx context.Context, snapshot Snapshot, concMS *types.MethodSet, concPkg *types.Package, ifaceObj types.Object, ifacePkg Package, visited map[string]struct{}) ([]*missingInterface, error) {
   250  	iface, ok := ifaceObj.Type().Underlying().(*types.Interface)
   251  	if !ok {
   252  		return nil, fmt.Errorf("expected %v to be an interface but got %T", iface, ifaceObj.Type().Underlying())
   253  	}
   254  	missing := []*missingInterface{}
   255  	for i := 0; i < iface.NumEmbeddeds(); i++ {
   256  		eiface := iface.Embedded(i).Obj()
   257  		depPkg := ifacePkg
   258  		if eiface.Pkg().Path() != ifacePkg.PkgPath() {
   259  			var err error
   260  			depPkg, err = ifacePkg.GetImport(eiface.Pkg().Path())
   261  			if err != nil {
   262  				return nil, err
   263  			}
   264  		}
   265  		em, err := missingMethods(ctx, snapshot, concMS, concPkg, eiface, depPkg, visited)
   266  		if err != nil {
   267  			return nil, err
   268  		}
   269  		missing = append(missing, em...)
   270  	}
   271  	parsedFile, _, err := getStubFile(ctx, ifaceObj, snapshot)
   272  	if err != nil {
   273  		return nil, fmt.Errorf("error getting iface file: %w", err)
   274  	}
   275  	mi := &missingInterface{
   276  		pkg:   ifacePkg,
   277  		iface: iface,
   278  		file:  parsedFile.File,
   279  	}
   280  	if mi.file == nil {
   281  		return nil, fmt.Errorf("could not find ast.File for %v", ifaceObj.Name())
   282  	}
   283  	for i := 0; i < iface.NumExplicitMethods(); i++ {
   284  		method := iface.ExplicitMethod(i)
   285  		// if the concrete type does not have the interface method
   286  		if concMS.Lookup(concPkg, method.Name()) == nil {
   287  			if _, ok := visited[method.Name()]; !ok {
   288  				mi.missing = append(mi.missing, method)
   289  				visited[method.Name()] = struct{}{}
   290  			}
   291  		}
   292  		if sel := concMS.Lookup(concPkg, method.Name()); sel != nil {
   293  			implSig := sel.Type().(*types.Signature)
   294  			ifaceSig := method.Type().(*types.Signature)
   295  			if !types.Identical(ifaceSig, implSig) {
   296  				return nil, fmt.Errorf("mimsatched %q function signatures:\nhave: %s\nwant: %s", method.Name(), implSig, ifaceSig)
   297  			}
   298  		}
   299  	}
   300  	if len(mi.missing) > 0 {
   301  		missing = append(missing, mi)
   302  	}
   303  	return missing, nil
   304  }
   305  
   306  func getStubFile(ctx context.Context, obj types.Object, snapshot Snapshot) (*ParsedGoFile, VersionedFileHandle, error) {
   307  	objPos := snapshot.FileSet().Position(obj.Pos())
   308  	objFile := span.URIFromPath(objPos.Filename)
   309  	objectFH := snapshot.FindFile(objFile)
   310  	_, goFile, err := GetParsedFile(ctx, snapshot, objectFH, WidestPackage)
   311  	if err != nil {
   312  		return nil, nil, fmt.Errorf("GetParsedFile: %w", err)
   313  	}
   314  	return goFile, objectFH, nil
   315  }
   316  
// missingInterface represents an interface
// that has all or some of its methods missing
// from the destination concrete type.
type missingInterface struct {
	iface   *types.Interface // underlying type of the interface
	file    *ast.File        // file declaring the interface; passed to RelativeToFiles when rendering stub signatures
	pkg     Package          // package declaring the interface (set by missingMethods; not read elsewhere in this file as shown)
	missing []*types.Func    // methods the concrete type lacks; stubs are generated in this order
}
   326  
   327  // stubImport represents a newly added import
   328  // statement to the concrete type. If name is not
   329  // empty, then that import is required to have that name.
   330  type stubImport struct{ Name, Path string }