golang.org/x/exp@v0.0.0-20240506185415-9bf2ced13842/apidiff/correspondence.go (about)

     1  package apidiff
     2  
     3  import (
     4  	"fmt"
     5  	"go/types"
     6  	"sort"
     7  )
     8  
     9  // Two types are correspond if they are identical except for defined types,
    10  // which must correspond.
    11  //
    12  // Two defined types correspond if they can be interchanged in the old and new APIs,
    13  // possibly after a renaming.
    14  //
    15  // This is not a pure function. If we come across named types while traversing,
    16  // we establish correspondence.
    17  func (d *differ) correspond(old, new types.Type) bool {
    18  	return d.corr(old, new, nil)
    19  }
    20  
    21  // corr determines whether old and new correspond. The argument p is a list of
    22  // known interface identities, to avoid infinite recursion.
    23  //
    24  // corr calls itself recursively as much as possible, to establish more
    25  // correspondences and so check more of the API. E.g. if the new function has more
    26  // parameters than the old, compare all the old ones before returning false.
    27  //
    28  // Compare this to the implementation of go/types.Identical.
    29  func (d *differ) corr(old, new types.Type, p *ifacePair) bool {
    30  	// Structure copied from types.Identical.
    31  	switch old := old.(type) {
    32  	case *types.Basic:
    33  		if new, ok := new.(*types.Basic); ok {
    34  			return old.Kind() == new.Kind()
    35  		}
    36  
    37  	case *types.Array:
    38  		if new, ok := new.(*types.Array); ok {
    39  			return d.corr(old.Elem(), new.Elem(), p) && old.Len() == new.Len()
    40  		}
    41  
    42  	case *types.Slice:
    43  		if new, ok := new.(*types.Slice); ok {
    44  			return d.corr(old.Elem(), new.Elem(), p)
    45  		}
    46  
    47  	case *types.Map:
    48  		if new, ok := new.(*types.Map); ok {
    49  			return d.corr(old.Key(), new.Key(), p) && d.corr(old.Elem(), new.Elem(), p)
    50  		}
    51  
    52  	case *types.Chan:
    53  		if new, ok := new.(*types.Chan); ok {
    54  			return d.corr(old.Elem(), new.Elem(), p) && old.Dir() == new.Dir()
    55  		}
    56  
    57  	case *types.Pointer:
    58  		if new, ok := new.(*types.Pointer); ok {
    59  			return d.corr(old.Elem(), new.Elem(), p)
    60  		}
    61  
    62  	case *types.Signature:
    63  		if new, ok := new.(*types.Signature); ok {
    64  			pe := d.corr(old.Params(), new.Params(), p)
    65  			re := d.corr(old.Results(), new.Results(), p)
    66  			return old.Variadic() == new.Variadic() && pe && re
    67  		}
    68  
    69  	case *types.Tuple:
    70  		if new, ok := new.(*types.Tuple); ok {
    71  			for i := 0; i < old.Len(); i++ {
    72  				if i >= new.Len() || !d.corr(old.At(i).Type(), new.At(i).Type(), p) {
    73  					return false
    74  				}
    75  			}
    76  			return old.Len() == new.Len()
    77  		}
    78  
    79  	case *types.Struct:
    80  		if new, ok := new.(*types.Struct); ok {
    81  			for i := 0; i < old.NumFields(); i++ {
    82  				if i >= new.NumFields() {
    83  					return false
    84  				}
    85  				of := old.Field(i)
    86  				nf := new.Field(i)
    87  				if of.Anonymous() != nf.Anonymous() ||
    88  					old.Tag(i) != new.Tag(i) ||
    89  					!d.corr(of.Type(), nf.Type(), p) ||
    90  					!d.corrFieldNames(of, nf) {
    91  					return false
    92  				}
    93  			}
    94  			return old.NumFields() == new.NumFields()
    95  		}
    96  
    97  	case *types.Interface:
    98  		if new, ok := new.(*types.Interface); ok {
    99  			// Deal with circularity. See the comment in types.Identical.
   100  			q := &ifacePair{old, new, p}
   101  			for p != nil {
   102  				if p.identical(q) {
   103  					return true // same pair was compared before
   104  				}
   105  				p = p.prev
   106  			}
   107  			oldms := d.sortedMethods(old)
   108  			newms := d.sortedMethods(new)
   109  			for i, om := range oldms {
   110  				if i >= len(newms) {
   111  					return false
   112  				}
   113  				nm := newms[i]
   114  				if d.methodID(om) != d.methodID(nm) || !d.corr(om.Type(), nm.Type(), q) {
   115  					return false
   116  				}
   117  			}
   118  			return old.NumMethods() == new.NumMethods()
   119  		}
   120  
   121  	case *types.Named:
   122  		return d.establishCorrespondence(old, new)
   123  
   124  	case *types.TypeParam:
   125  		if new, ok := new.(*types.TypeParam); ok {
   126  			if old.Index() == new.Index() {
   127  				return true
   128  			}
   129  		}
   130  
   131  	default:
   132  		panic(fmt.Sprintf("unknown type kind %T", old))
   133  	}
   134  	return false
   135  }
   136  
   137  // Compare old and new field names. We are determining correspondence across packages,
   138  // so just compare names, not packages. For an unexported, embedded field of named
   139  // type (non-named embedded fields are possible with aliases), we check that the type
   140  // names correspond. We check the types for correspondence before this is called, so
   141  // we've established correspondence.
   142  func (d *differ) corrFieldNames(of, nf *types.Var) bool {
   143  	if of.Anonymous() && nf.Anonymous() && !of.Exported() && !nf.Exported() {
   144  		if on, ok := of.Type().(*types.Named); ok {
   145  			nn := nf.Type().(*types.Named)
   146  			return d.establishCorrespondence(on, nn)
   147  		}
   148  	}
   149  	return of.Name() == nf.Name()
   150  }
   151  
   152  // establishCorrespondence records and validates a correspondence between
   153  // old and new.
   154  //
   155  // If this is the first type corresponding to old, it checks that the type
   156  // declaration is compatible with old and records its correspondence.
   157  // Otherwise, it checks that new is equivalent to the previously recorded
   158  // type corresponding to old.
   159  func (d *differ) establishCorrespondence(old *types.Named, new types.Type) bool {
   160  	oldname := old.Obj()
   161  	// If there already is a corresponding new type for old, check that they
   162  	// are the same.
   163  	if c := d.correspondMap.At(old); c != nil {
   164  		return types.Identical(c.(types.Type), new)
   165  	}
   166  	// Attempt to establish a correspondence.
   167  	// Assume the types don't correspond unless they have the same
   168  	// ID, or are from the old and new packages, respectively.
   169  	//
   170  	// This is too conservative. For instance,
   171  	//    [old] type A = q.B; [new] type A q.C
   172  	// could be OK if in package q, B is an alias for C.
   173  	// Or, using p as the name of the current old/new packages:
   174  	//    [old] type A = q.B; [new] type A int
   175  	// could be OK if in q,
   176  	//    [old] type B int; [new] type B = p.A
   177  	// In this case, p.A and q.B name the same type in both old and new worlds.
   178  	// Note that this case doesn't imply circular package imports: it's possible
   179  	// that in the old world, p imports q, but in the new, q imports p.
   180  	//
   181  	// However, if we didn't do something here, then we'd incorrectly allow cases
   182  	// like the first one above in which q.B is not an alias for q.C
   183  	//
   184  	// What we should do is check that the old type, in the new world's package
   185  	// of the same path, doesn't correspond to something other than the new type.
   186  	// That is a bit hard, because there is no easy way to find a new package
   187  	// matching an old one.
   188  	switch new := new.(type) {
   189  	case *types.Named:
   190  		newn := new
   191  		oobj := old.Obj()
   192  		nobj := newn.Obj()
   193  		if oobj.Pkg() != d.old || nobj.Pkg() != d.new {
   194  			// Compare the fully qualified names of the types.
   195  			//
   196  			// TODO(jba): when comparing modules, we should only look at the
   197  			// paths relative to the module path, because the module paths may differ.
   198  			// See cmd/gorelease/testdata/internalcompat.
   199  			var opath, npath string
   200  			if oobj.Pkg() != nil {
   201  				opath = oobj.Pkg().Path()
   202  			}
   203  			if nobj.Pkg() != nil {
   204  				npath = nobj.Pkg().Path()
   205  			}
   206  			return oobj.Name() == nobj.Name() && opath == npath
   207  		}
   208  		// Two generic named types correspond if their type parameter lists correspond.
   209  		// Since one or the other of those lists will be empty, it doesn't hurt
   210  		// to check both.
   211  		oldOrigin := old.Origin()
   212  		newOrigin := newn.Origin()
   213  		if oldOrigin != old {
   214  			// old is an instantiated type.
   215  			if newOrigin == newn {
   216  				// new is not; they cannot correspond.
   217  				return false
   218  			}
   219  			// Two instantiated types correspond if their origins correspond and
   220  			// their type argument lists correspond.
   221  			if !d.correspond(oldOrigin, newOrigin) {
   222  				return false
   223  			}
   224  			if !d.typeListsCorrespond(old.TypeArgs(), newn.TypeArgs()) {
   225  				return false
   226  			}
   227  		} else {
   228  			if !d.typeParamListsCorrespond(old.TypeParams(), newn.TypeParams()) {
   229  				return false
   230  			}
   231  		}
   232  	case *types.Basic:
   233  		if old.Obj().Pkg() != d.old {
   234  			// A named type from a package other than old never corresponds to a basic type.
   235  			return false
   236  		}
   237  	default:
   238  		// Only named and basic types can correspond.
   239  		return false
   240  	}
   241  	// If there is no correspondence, create one.
   242  	d.correspondMap.Set(old, new)
   243  	// Check that the corresponding types are compatible.
   244  	d.checkCompatibleDefined(oldname, old, new)
   245  	return true
   246  }
   247  
   248  func (d *differ) typeListsCorrespond(tl1, tl2 *types.TypeList) bool {
   249  	if tl1.Len() != tl2.Len() {
   250  		return false
   251  	}
   252  	for i := 0; i < tl1.Len(); i++ {
   253  		if !d.correspond(tl1.At(i), tl2.At(i)) {
   254  			return false
   255  		}
   256  	}
   257  	return true
   258  }
   259  
   260  // Two list of type parameters correspond if they are the same length, and
   261  // the constraints of corresponding type parameters correspond.
   262  func (d *differ) typeParamListsCorrespond(tps1, tps2 *types.TypeParamList) bool {
   263  	if tps1.Len() != tps2.Len() {
   264  		return false
   265  	}
   266  	for i := 0; i < tps1.Len(); i++ {
   267  		if !d.correspond(tps1.At(i).Constraint(), tps2.At(i).Constraint()) {
   268  			return false
   269  		}
   270  	}
   271  	return true
   272  }
   273  
   274  func (d *differ) sortedMethods(iface *types.Interface) []*types.Func {
   275  	ms := make([]*types.Func, iface.NumMethods())
   276  	for i := 0; i < iface.NumMethods(); i++ {
   277  		ms[i] = iface.Method(i)
   278  	}
   279  	sort.Slice(ms, func(i, j int) bool { return d.methodID(ms[i]) < d.methodID(ms[j]) })
   280  	return ms
   281  }
   282  
   283  func (d *differ) methodID(m *types.Func) string {
   284  	// If the method belongs to one of the two packages being compared, use
   285  	// just its name even if it's unexported. That lets us treat unexported names
   286  	// from the old and new packages as equal.
   287  	if m.Pkg() == d.old || m.Pkg() == d.new {
   288  		return m.Name()
   289  	}
   290  	return m.Id()
   291  }
   292  
   293  // Copied from the go/types package:
   294  
   295  // An ifacePair is a node in a stack of interface type pairs compared for identity.
   296  type ifacePair struct {
   297  	x, y *types.Interface
   298  	prev *ifacePair
   299  }
   300  
   301  func (p *ifacePair) identical(q *ifacePair) bool {
   302  	return p.x == q.x && p.y == q.y || p.x == q.y && p.y == q.x
   303  }