github.com/mstephano/gqlgen-schemagen@v0.0.0-20230113041936-dd2cd4ea46aa/internal/code/compare.go (about)

     1  package code
     2  
     3  import (
     4  	"fmt"
     5  	"go/types"
     6  )
     7  
     8  // CompatibleTypes isnt a strict comparison, it allows for pointer differences
     9  func CompatibleTypes(expected types.Type, actual types.Type) error {
    10  	// Special case to deal with pointer mismatches
    11  	{
    12  		expectedPtr, expectedIsPtr := expected.(*types.Pointer)
    13  		actualPtr, actualIsPtr := actual.(*types.Pointer)
    14  
    15  		if expectedIsPtr && actualIsPtr {
    16  			return CompatibleTypes(expectedPtr.Elem(), actualPtr.Elem())
    17  		}
    18  		if expectedIsPtr && !actualIsPtr {
    19  			return CompatibleTypes(expectedPtr.Elem(), actual)
    20  		}
    21  		if !expectedIsPtr && actualIsPtr {
    22  			return CompatibleTypes(expected, actualPtr.Elem())
    23  		}
    24  	}
    25  
    26  	switch expected := expected.(type) {
    27  	case *types.Slice:
    28  		if actualSlice, ok := actual.(*types.Slice); ok {
    29  			return CompatibleTypes(expected.Elem(), actualSlice.Elem())
    30  		} else if actual, ok := actual.(*types.Named); ok {
    31  			if underlyingSlice, ok := actual.Underlying().(*types.Slice); ok {
    32  				return CompatibleTypes(expected.Elem(), underlyingSlice.Elem())
    33  			}
    34  		}
    35  
    36  	case *types.Array:
    37  		if actual, ok := actual.(*types.Array); ok {
    38  			if expected.Len() != actual.Len() {
    39  				return fmt.Errorf("array length differs")
    40  			}
    41  
    42  			return CompatibleTypes(expected.Elem(), actual.Elem())
    43  		}
    44  
    45  	case *types.Basic:
    46  		if actualBasic, ok := actual.(*types.Basic); ok {
    47  			if actualBasic.Kind() != expected.Kind() {
    48  				return fmt.Errorf("basic kind differs, %s != %s", expected.Name(), actualBasic.Name())
    49  			}
    50  
    51  			return nil
    52  		} else if actual, ok := actual.(*types.Named); ok {
    53  			if underlyingBasic, ok := actual.Underlying().(*types.Basic); ok {
    54  				return CompatibleTypes(expected, underlyingBasic)
    55  			}
    56  		}
    57  
    58  	case *types.Struct:
    59  		if actual, ok := actual.(*types.Struct); ok {
    60  			if expected.NumFields() != actual.NumFields() {
    61  				return fmt.Errorf("number of struct fields differ")
    62  			}
    63  
    64  			for i := 0; i < expected.NumFields(); i++ {
    65  				if expected.Field(i).Name() != actual.Field(i).Name() {
    66  					return fmt.Errorf("struct field %d name differs, %s != %s", i, expected.Field(i).Name(), actual.Field(i).Name())
    67  				}
    68  				if err := CompatibleTypes(expected.Field(i).Type(), actual.Field(i).Type()); err != nil {
    69  					return err
    70  				}
    71  			}
    72  			return nil
    73  		}
    74  
    75  	case *types.Tuple:
    76  		if actual, ok := actual.(*types.Tuple); ok {
    77  			if expected.Len() != actual.Len() {
    78  				return fmt.Errorf("tuple length differs, %d != %d", expected.Len(), actual.Len())
    79  			}
    80  
    81  			for i := 0; i < expected.Len(); i++ {
    82  				if err := CompatibleTypes(expected.At(i).Type(), actual.At(i).Type()); err != nil {
    83  					return err
    84  				}
    85  			}
    86  
    87  			return nil
    88  		}
    89  
    90  	case *types.Signature:
    91  		if actual, ok := actual.(*types.Signature); ok {
    92  			if err := CompatibleTypes(expected.Params(), actual.Params()); err != nil {
    93  				return err
    94  			}
    95  			if err := CompatibleTypes(expected.Results(), actual.Results()); err != nil {
    96  				return err
    97  			}
    98  
    99  			return nil
   100  		}
   101  	case *types.Interface:
   102  		if actual, ok := actual.(*types.Interface); ok {
   103  			if expected.NumMethods() != actual.NumMethods() {
   104  				return fmt.Errorf("interface method count differs, %d != %d", expected.NumMethods(), actual.NumMethods())
   105  			}
   106  
   107  			for i := 0; i < expected.NumMethods(); i++ {
   108  				if expected.Method(i).Name() != actual.Method(i).Name() {
   109  					return fmt.Errorf("interface method %d name differs, %s != %s", i, expected.Method(i).Name(), actual.Method(i).Name())
   110  				}
   111  				if err := CompatibleTypes(expected.Method(i).Type(), actual.Method(i).Type()); err != nil {
   112  					return err
   113  				}
   114  			}
   115  
   116  			return nil
   117  		}
   118  
   119  	case *types.Map:
   120  		if actualMap, ok := actual.(*types.Map); ok {
   121  			if err := CompatibleTypes(expected.Key(), actualMap.Key()); err != nil {
   122  				return err
   123  			}
   124  
   125  			if err := CompatibleTypes(expected.Elem(), actualMap.Elem()); err != nil {
   126  				return err
   127  			}
   128  
   129  			return nil
   130  		} else if actual, ok := actual.(*types.Named); ok {
   131  			if underlyingBasic, ok := actual.Underlying().(*types.Map); ok {
   132  				return CompatibleTypes(expected.Elem(), underlyingBasic.Elem())
   133  			}
   134  		}
   135  
   136  	case *types.Chan:
   137  		if actual, ok := actual.(*types.Chan); ok {
   138  			return CompatibleTypes(expected.Elem(), actual.Elem())
   139  		}
   140  
   141  	case *types.Named:
   142  		if actual, ok := actual.(*types.Named); ok {
   143  			if NormalizeVendor(expected.Obj().Pkg().Path()) != NormalizeVendor(actual.Obj().Pkg().Path()) {
   144  				return fmt.Errorf(
   145  					"package name of named type differs, %s != %s",
   146  					NormalizeVendor(expected.Obj().Pkg().Path()),
   147  					NormalizeVendor(actual.Obj().Pkg().Path()),
   148  				)
   149  			}
   150  
   151  			if expected.Obj().Name() != actual.Obj().Name() {
   152  				return fmt.Errorf(
   153  					"named type name differs, %s != %s",
   154  					NormalizeVendor(expected.Obj().Name()),
   155  					NormalizeVendor(actual.Obj().Name()),
   156  				)
   157  			}
   158  
   159  			return nil
   160  		}
   161  
   162  		// Before models are generated all missing references will be Invalid Basic references.
   163  		// lets assume these are valid too.
   164  		if actual, ok := actual.(*types.Basic); ok && actual.Kind() == types.Invalid {
   165  			return nil
   166  		}
   167  
   168  	default:
   169  		return fmt.Errorf("missing support for %T", expected)
   170  	}
   171  
   172  	return fmt.Errorf("type mismatch %T != %T", expected, actual)
   173  }