github.com/v2fly/tools@v0.100.0/internal/lsp/source/implementation.go (about) 1 // Copyright 2019 The Go Authors. All rights reserved. 2 // Use of this source code is governed by a BSD-style 3 // license that can be found in the LICENSE file. 4 5 package source 6 7 import ( 8 "context" 9 "errors" 10 "fmt" 11 "go/ast" 12 "go/token" 13 "go/types" 14 "sort" 15 16 "github.com/v2fly/tools/internal/event" 17 "github.com/v2fly/tools/internal/lsp/protocol" 18 "golang.org/x/xerrors" 19 ) 20 21 func Implementation(ctx context.Context, snapshot Snapshot, f FileHandle, pp protocol.Position) ([]protocol.Location, error) { 22 ctx, done := event.Start(ctx, "source.Implementation") 23 defer done() 24 25 impls, err := implementations(ctx, snapshot, f, pp) 26 if err != nil { 27 return nil, err 28 } 29 var locations []protocol.Location 30 for _, impl := range impls { 31 if impl.pkg == nil || len(impl.pkg.CompiledGoFiles()) == 0 { 32 continue 33 } 34 rng, err := objToMappedRange(snapshot, impl.pkg, impl.obj) 35 if err != nil { 36 return nil, err 37 } 38 pr, err := rng.Range() 39 if err != nil { 40 return nil, err 41 } 42 locations = append(locations, protocol.Location{ 43 URI: protocol.URIFromSpanURI(rng.URI()), 44 Range: pr, 45 }) 46 } 47 sort.Slice(locations, func(i, j int) bool { 48 li, lj := locations[i], locations[j] 49 if li.URI == lj.URI { 50 return protocol.CompareRange(li.Range, lj.Range) < 0 51 } 52 return li.URI < lj.URI 53 }) 54 return locations, nil 55 } 56 57 var ErrNotAType = errors.New("not a type name or method") 58 59 // implementations returns the concrete implementations of the specified 60 // interface, or the interfaces implemented by the specified concrete type. 
func implementations(ctx context.Context, s Snapshot, f FileHandle, pp protocol.Position) ([]qualifiedObject, error) {
	var (
		impls []qualifiedObject
		seen  = make(map[token.Position]bool)
		fset  = s.FileSet()

		// allNamed lists every non-alias named type defined in the snapshot's
		// known packages, and pkgs maps each known *types.Package back to its
		// Package. Neither depends on the query object, so both are built at
		// most once, on first need, instead of once per entry of qos.
		allNamed []*types.Named
		pkgs     map[*types.Package]Package
	)

	qos, err := qualifiedObjsAtProtocolPos(ctx, s, f, pp)
	if err != nil {
		return nil, err
	}
	for _, qo := range qos {
		var (
			queryType   types.Type
			queryMethod *types.Func
		)

		switch obj := qo.obj.(type) {
		case *types.Func:
			queryMethod = obj
			if recv := obj.Type().(*types.Signature).Recv(); recv != nil {
				queryType = ensurePointer(recv.Type())
			}
		case *types.TypeName:
			queryType = ensurePointer(obj.Type())
		}

		if queryType == nil {
			return nil, ErrNotAType
		}

		// Types with an empty method set are not reported (an empty
		// interface would match every type). Skip this package variant, but
		// keep results already found for other variants (e.g. a test variant
		// that declares extra methods) rather than discarding them.
		if types.NewMethodSet(queryType).Len() == 0 {
			continue
		}

		if pkgs == nil {
			// Find all named types, even local types (which can have methods
			// due to promotion).
			knownPkgs, err := s.KnownPackages(ctx)
			if err != nil {
				return nil, err
			}
			pkgs = make(map[*types.Package]Package, len(knownPkgs))
			for _, pkg := range knownPkgs {
				pkgs[pkg.GetTypes()] = pkg
				for _, obj := range pkg.GetTypesInfo().Defs {
					tn, ok := obj.(*types.TypeName)
					// We ignore aliases 'type M = N' to avoid duplicate reporting
					// of the Named type N.
					if !ok || tn.IsAlias() {
						continue
					}
					if named, ok := tn.Type().(*types.Named); ok {
						allNamed = append(allNamed, named)
					}
				}
			}
		}

		// Find all the named types that match our query.
		for _, named := range allNamed {
			var (
				candObj  types.Object = named.Obj()
				candType              = ensurePointer(named)
			)

			if !concreteImplementsIntf(candType, queryType) {
				continue
			}

			ms := types.NewMethodSet(candType)
			if ms.Len() == 0 {
				// Skip empty interfaces.
				continue
			}

			// If client queried a method, look up corresponding candType method.
			if queryMethod != nil {
				sel := ms.Lookup(queryMethod.Pkg(), queryMethod.Name())
				if sel == nil {
					continue
				}
				candObj = sel.Obj()
			}

			// Never report the queried method itself, and report each source
			// position only once across all package variants.
			pos := fset.Position(candObj.Pos())
			if candObj == queryMethod || seen[pos] {
				continue
			}

			seen[pos] = true

			impls = append(impls, qualifiedObject{
				obj: candObj,
				pkg: pkgs[candObj.Pkg()],
			})
		}
	}

	return impls, nil
}

// concreteImplementsIntf returns true if a is an interface type implemented by
// concrete type b, or vice versa.
func concreteImplementsIntf(a, b types.Type) bool {
	aIsIntf, bIsIntf := IsInterface(a), IsInterface(b)

	// Make sure exactly one is an interface type.
	if aIsIntf == bIsIntf {
		return false
	}

	// Rearrange if needed so "a" is the concrete type.
	if aIsIntf {
		a, b = b, a
	}

	return types.AssignableTo(a, b)
}

// ensurePointer wraps T in a *types.Pointer if T is a named, non-interface
// type. This is useful to make sure you consider a named type's full method
// set, including methods with pointer receivers.
func ensurePointer(T types.Type) types.Type {
	if _, ok := T.(*types.Named); ok && !IsInterface(T) {
		return types.NewPointer(T)
	}

	return T
}

// qualifiedObject pairs a types.Object with the packages involved in
// finding it.
type qualifiedObject struct {
	obj types.Object

	// pkg is the Package that contains obj's definition.
	pkg Package

	// node is the *ast.Ident or *ast.ImportSpec we followed to find obj, if any.
	node ast.Node

	// sourcePkg is the Package that contains node, if any.
	sourcePkg Package
}

var (
	errBuiltin       = errors.New("builtin object")
	errNoObjectFound = errors.New("no object found")
)

// qualifiedObjsAtProtocolPos returns info for all the types.Objects
// referenced at the given position. An object will be returned for
// every package that the file belongs to, in every typechecking mode
// applicable. It fails with errBuiltin for universe-scope objects and
// with errNoObjectFound when no object could be found at all.
func qualifiedObjsAtProtocolPos(ctx context.Context, s Snapshot, fh FileHandle, pp protocol.Position) ([]qualifiedObject, error) {
	searchPkgs, err := s.PackagesForFile(ctx, fh.URI(), TypecheckAll)
	if err != nil {
		return nil, err
	}
	// Check all the packages that the file belongs to.
	var qualifiedObjs []qualifiedObject
	for _, searchpkg := range searchPkgs {
		astFile, pos, err := getASTFile(searchpkg, fh, pp)
		if err != nil {
			return nil, err
		}
		path := pathEnclosingObjNode(astFile, pos)
		if path == nil {
			continue
		}
		var objs []types.Object
		switch leaf := path[0].(type) {
		case *ast.Ident:
			// If leaf represents an implicit type switch object or the type
			// switch "assign" variable, expand to all of the type switch's
			// implicit objects.
			if implicits, _ := typeSwitchImplicits(searchpkg, path); len(implicits) > 0 {
				objs = append(objs, implicits...)
			} else {
				obj := searchpkg.GetTypesInfo().ObjectOf(leaf)
				if obj == nil {
					return nil, xerrors.Errorf("%w for %q", errNoObjectFound, leaf.Name)
				}
				objs = append(objs, obj)
			}
		case *ast.ImportSpec:
			// Look up the implicit *types.PkgName.
			obj := searchpkg.GetTypesInfo().Implicits[leaf]
			if obj == nil {
				return nil, xerrors.Errorf("%w for import %q", errNoObjectFound, ImportPath(leaf))
			}
			objs = append(objs, obj)
		}
		// Index the search package and all of its transitive dependencies by
		// *types.Package so each object can be paired with its Package.
		deps := make(map[*types.Package]Package)
		var addPkg func(pkg Package)
		addPkg = func(pkg Package) {
			deps[pkg.GetTypes()] = pkg
			for _, imp := range pkg.Imports() {
				if _, ok := deps[imp.GetTypes()]; !ok {
					addPkg(imp)
				}
			}
		}
		addPkg(searchpkg)
		for _, obj := range objs {
			if obj.Parent() == types.Universe {
				return nil, xerrors.Errorf("%q: %w", obj.Name(), errBuiltin)
			}
			pkg, ok := deps[obj.Pkg()]
			if !ok {
				// Best effort: log a real error describing the missing package
				// (not the stale, always-nil err from this scope) and skip it.
				event.Error(ctx, "skipping object", fmt.Errorf("no package for obj %s: %v", obj, obj.Pkg()))
				continue
			}
			qualifiedObjs = append(qualifiedObjs, qualifiedObject{
				obj:       obj,
				pkg:       pkg,
				sourcePkg: searchpkg,
				node:      path[0],
			})
		}
	}
	// Return an error if no objects were found since callers will assume that
	// the slice has at least 1 element.
	if len(qualifiedObjs) == 0 {
		return nil, errNoObjectFound
	}
	return qualifiedObjs, nil
}

// getASTFile returns the parsed syntax tree of file f within pkg, together
// with the token.Pos corresponding to the protocol position pos.
func getASTFile(pkg Package, f FileHandle, pos protocol.Position) (*ast.File, token.Pos, error) {
	pgf, err := pkg.File(f.URI())
	if err != nil {
		return nil, 0, err
	}
	spn, err := pgf.Mapper.PointSpan(pos)
	if err != nil {
		return nil, 0, err
	}
	rng, err := spn.Range(pgf.Mapper.Converter)
	if err != nil {
		return nil, 0, err
	}
	return pgf.File, rng.Start, nil
}

// pathEnclosingObjNode returns the AST path to the object-defining
// node associated with pos, leaf first and *ast.File last, or nil if
// there is none. "Object-defining" means either an *ast.Ident mapped
// directly to a types.Object or an ast.Node (an *ast.ImportSpec)
// mapped implicitly to a types.Object. For a renamed import the name
// identifier is reported as the leaf, and a position on a '*' or a
// selector's '.' is moved onto the adjacent operand.
func pathEnclosingObjNode(f *ast.File, pos token.Pos) []ast.Node {
	// stack holds the nodes from the root down to the node being visited;
	// done flips to true once the target node has been located.
	var (
		stack []ast.Node
		done  bool
	)

	visit := func(node ast.Node) bool {
		switch {
		case done:
			// Target already located: stop descending and leave the stack
			// untouched so it still describes the target's ancestry.
			return false
		case node == nil:
			// Post-order callback for a fully visited node: pop it.
			stack = stack[:len(stack)-1]
			return false
		}

		stack = append(stack, node)

		switch x := node.(type) {
		case *ast.Ident:
			// Accept a cursor just past the identifier (pos == End) so that
			// a name the user is still typing resolves. This replaces two
			// astutil.PathEnclosingInterval calls (for pos and pos-1).
			done = x.Pos() <= pos && pos <= x.End()
		case *ast.ImportSpec:
			if x.Path.Pos() <= pos && pos < x.Path.End() {
				done = true
				// For a renamed import, report the name as the leaf even
				// though the cursor sits on the path literal.
				if x.Name != nil {
					stack = append(stack, x.Name)
				}
			}
		case *ast.StarExpr:
			// Cursor on the '*': retarget to the operand.
			if x.Star == pos {
				pos = x.X.Pos()
			}
		case *ast.SelectorExpr:
			// Cursor on the '.': retarget to the selected name.
			if x.X.End() == pos {
				pos = x.Sel.Pos()
			}
		}

		return !done
	}
	ast.Inspect(f, visit)

	if len(stack) == 0 {
		return nil
	}

	// Reverse in place so the leaf comes first and the *ast.File last.
	for lo, hi := 0, len(stack)-1; lo < hi; lo, hi = lo+1, hi-1 {
		stack[lo], stack[hi] = stack[hi], stack[lo]
	}
	return stack
}