github.com/inturn/pre-commit-gobuild@v1.0.12/internal/errchecker/embedded_walker.go (about)

     1  package errchecker
     2  
     3  import (
     4  	"fmt"
     5  	"go/types"
     6  )
     7  
     8  // walkThroughEmbeddedInterfaces returns a slice of Interfaces that
     9  // we need to walk through in order to reach the actual definition,
    10  // in an Interface, of the method selected by the given selection.
    11  //
    12  // false will be returned in the second return value if:
    13  //   - the right side of the selection is not a function
    14  //   - the actual definition of the function is not in an Interface
    15  //
    16  // The returned slice will contain all the interface types that need
    17  // to be walked through to reach the actual definition.
    18  //
    19  // For example, say we have:
    20  //
    21  //    type Inner interface {Method()}
    22  //    type Middle interface {Inner}
    23  //    type Outer interface {Middle}
    24  //    type T struct {Outer}
    25  //    type U struct {T}
    26  //    type V struct {U}
    27  //
    28  // And then the selector:
    29  //
    30  //    V.Method
    31  //
    32  // We'll return [Outer, Middle, Inner] by first walking through the embedded structs
    33  // until we reach the Outer interface, then descending through the embedded interfaces
    34  // until we find the one that actually explicitly defines Method.
    35  func walkThroughEmbeddedInterfaces(sel *types.Selection) ([]types.Type, bool) {
    36  	fn, ok := sel.Obj().(*types.Func)
    37  	if !ok {
    38  		return nil, false
    39  	}
    40  
    41  	// Start off at the receiver.
    42  	currentT := sel.Recv()
    43  
    44  	// First, we can walk through any Struct fields provided
    45  	// by the selection Index() method. We ignore the last
    46  	// index because it would give the method itself.
    47  	indexes := sel.Index()
    48  	for _, fieldIndex := range indexes[:len(indexes)-1] {
    49  		currentT = getTypeAtFieldIndex(currentT, fieldIndex)
    50  	}
    51  
    52  	// Now currentT is either a type implementing the actual function,
    53  	// an Invalid type (if the receiver is a package), or an interface.
    54  	//
    55  	// If it's not an Interface, then we're done, as this function
    56  	// only cares about Interface-defined functions.
    57  	//
    58  	// If it is an Interface, we potentially need to continue digging until
    59  	// we find the Interface that actually explicitly defines the function.
    60  	interfaceT, ok := maybeUnname(currentT).(*types.Interface)
    61  	if !ok {
    62  		return nil, false
    63  	}
    64  
    65  	// The first interface we pass through is this one we've found. We return the possibly
    66  	// wrapping types.Named because it is more useful to work with for callers.
    67  	result := []types.Type{currentT}
    68  
    69  	// If this interface itself explicitly defines the given method
    70  	// then we're done digging.
    71  	for !explicitlyDefinesMethod(interfaceT, fn) {
    72  		// Otherwise, we find which of the embedded interfaces _does_
    73  		// define the method, add it to our list, and loop.
    74  		namedInterfaceT, ok := getEmbeddedInterfaceDefiningMethod(interfaceT, fn)
    75  		if !ok {
    76  			// This should be impossible as long as we type-checked: either the
    77  			// interface or one of its embedded ones must implement the method...
    78  			panic(fmt.Sprintf("either %v or one of its embedded interfaces must implement %v", currentT, fn))
    79  		}
    80  		result = append(result, namedInterfaceT)
    81  		interfaceT = namedInterfaceT.Underlying().(*types.Interface)
    82  	}
    83  
    84  	return result, true
    85  }
    86  
    87  func getTypeAtFieldIndex(startingAt types.Type, fieldIndex int) types.Type {
    88  	t := maybeUnname(maybeDereference(startingAt))
    89  	s, ok := t.(*types.Struct)
    90  	if !ok {
    91  		panic(fmt.Sprintf("cannot get Field of a type that is not a struct, got a %T", t))
    92  	}
    93  
    94  	return s.Field(fieldIndex).Type()
    95  }
    96  
    97  // getEmbeddedInterfaceDefiningMethod searches through any embedded interfaces of the
    98  // passed interface searching for one that defines the given function. If found, the
    99  // types.Named wrapping that interface will be returned along with true in the second value.
   100  //
   101  // If no such embedded interface is found, nil and false are returned.
   102  func getEmbeddedInterfaceDefiningMethod(interfaceT *types.Interface, fn *types.Func) (*types.Named, bool) {
   103  	for i := 0; i < interfaceT.NumEmbeddeds(); i++ {
   104  		embedded := interfaceT.Embedded(i)
   105  		if definesMethod(embedded.Underlying().(*types.Interface), fn) {
   106  			return embedded, true
   107  		}
   108  	}
   109  	return nil, false
   110  }
   111  
   112  func explicitlyDefinesMethod(interfaceT *types.Interface, fn *types.Func) bool {
   113  	for i := 0; i < interfaceT.NumExplicitMethods(); i++ {
   114  		if interfaceT.ExplicitMethod(i) == fn {
   115  			return true
   116  		}
   117  	}
   118  	return false
   119  }
   120  
   121  func definesMethod(interfaceT *types.Interface, fn *types.Func) bool {
   122  	for i := 0; i < interfaceT.NumMethods(); i++ {
   123  		if interfaceT.Method(i) == fn {
   124  			return true
   125  		}
   126  	}
   127  	return false
   128  }
   129  
   130  func maybeDereference(t types.Type) types.Type {
   131  	p, ok := t.(*types.Pointer)
   132  	if ok {
   133  		return p.Elem()
   134  	}
   135  	return t
   136  }
   137  
   138  func maybeUnname(t types.Type) types.Type {
   139  	n, ok := t.(*types.Named)
   140  	if ok {
   141  		return n.Underlying()
   142  	}
   143  	return t
   144  }