github.com/spread-ai/gqlgen@v0.0.0-20221124102857-a6c8ef538a1d/internal/code/compare.go (about)

     1  package code
     2  
     3  import (
     4  	"fmt"
     5  	"go/types"
     6  )
     7  
     8  // CompatibleTypes isnt a strict comparison, it allows for pointer differences
     9  func CompatibleTypes(expected types.Type, actual types.Type) error {
    10  	// Special case to deal with pointer mismatches
    11  	{
    12  		expectedPtr, expectedIsPtr := expected.(*types.Pointer)
    13  		actualPtr, actualIsPtr := actual.(*types.Pointer)
    14  
    15  		if expectedIsPtr && actualIsPtr {
    16  			return CompatibleTypes(expectedPtr.Elem(), actualPtr.Elem())
    17  		}
    18  		if expectedIsPtr && !actualIsPtr {
    19  			return CompatibleTypes(expectedPtr.Elem(), actual)
    20  		}
    21  		if !expectedIsPtr && actualIsPtr {
    22  			return CompatibleTypes(expected, actualPtr.Elem())
    23  		}
    24  	}
    25  
    26  	switch expected := expected.(type) {
    27  	case *types.Slice:
    28  		if actual, ok := actual.(*types.Slice); ok {
    29  			return CompatibleTypes(expected.Elem(), actual.Elem())
    30  		}
    31  
    32  	case *types.Array:
    33  		if actual, ok := actual.(*types.Array); ok {
    34  			if expected.Len() != actual.Len() {
    35  				return fmt.Errorf("array length differs")
    36  			}
    37  
    38  			return CompatibleTypes(expected.Elem(), actual.Elem())
    39  		}
    40  
    41  	case *types.Basic:
    42  		if actual, ok := actual.(*types.Basic); ok {
    43  			if actual.Kind() != expected.Kind() {
    44  				return fmt.Errorf("basic kind differs, %s != %s", expected.Name(), actual.Name())
    45  			}
    46  
    47  			return nil
    48  		}
    49  
    50  	case *types.Struct:
    51  		if actual, ok := actual.(*types.Struct); ok {
    52  			if expected.NumFields() != actual.NumFields() {
    53  				return fmt.Errorf("number of struct fields differ")
    54  			}
    55  
    56  			for i := 0; i < expected.NumFields(); i++ {
    57  				if expected.Field(i).Name() != actual.Field(i).Name() {
    58  					return fmt.Errorf("struct field %d name differs, %s != %s", i, expected.Field(i).Name(), actual.Field(i).Name())
    59  				}
    60  				if err := CompatibleTypes(expected.Field(i).Type(), actual.Field(i).Type()); err != nil {
    61  					return err
    62  				}
    63  			}
    64  			return nil
    65  		}
    66  
    67  	case *types.Tuple:
    68  		if actual, ok := actual.(*types.Tuple); ok {
    69  			if expected.Len() != actual.Len() {
    70  				return fmt.Errorf("tuple length differs, %d != %d", expected.Len(), actual.Len())
    71  			}
    72  
    73  			for i := 0; i < expected.Len(); i++ {
    74  				if err := CompatibleTypes(expected.At(i).Type(), actual.At(i).Type()); err != nil {
    75  					return err
    76  				}
    77  			}
    78  
    79  			return nil
    80  		}
    81  
    82  	case *types.Signature:
    83  		if actual, ok := actual.(*types.Signature); ok {
    84  			if err := CompatibleTypes(expected.Params(), actual.Params()); err != nil {
    85  				return err
    86  			}
    87  			if err := CompatibleTypes(expected.Results(), actual.Results()); err != nil {
    88  				return err
    89  			}
    90  
    91  			return nil
    92  		}
    93  	case *types.Interface:
    94  		if actual, ok := actual.(*types.Interface); ok {
    95  			if expected.NumMethods() != actual.NumMethods() {
    96  				return fmt.Errorf("interface method count differs, %d != %d", expected.NumMethods(), actual.NumMethods())
    97  			}
    98  
    99  			for i := 0; i < expected.NumMethods(); i++ {
   100  				if expected.Method(i).Name() != actual.Method(i).Name() {
   101  					return fmt.Errorf("interface method %d name differs, %s != %s", i, expected.Method(i).Name(), actual.Method(i).Name())
   102  				}
   103  				if err := CompatibleTypes(expected.Method(i).Type(), actual.Method(i).Type()); err != nil {
   104  					return err
   105  				}
   106  			}
   107  
   108  			return nil
   109  		}
   110  
   111  	case *types.Map:
   112  		if actual, ok := actual.(*types.Map); ok {
   113  			if err := CompatibleTypes(expected.Key(), actual.Key()); err != nil {
   114  				return err
   115  			}
   116  
   117  			if err := CompatibleTypes(expected.Elem(), actual.Elem()); err != nil {
   118  				return err
   119  			}
   120  
   121  			return nil
   122  		}
   123  
   124  	case *types.Chan:
   125  		if actual, ok := actual.(*types.Chan); ok {
   126  			return CompatibleTypes(expected.Elem(), actual.Elem())
   127  		}
   128  
   129  	case *types.Named:
   130  		if actual, ok := actual.(*types.Named); ok {
   131  			if NormalizeVendor(expected.Obj().Pkg().Path()) != NormalizeVendor(actual.Obj().Pkg().Path()) {
   132  				return fmt.Errorf(
   133  					"package name of named type differs, %s != %s",
   134  					NormalizeVendor(expected.Obj().Pkg().Path()),
   135  					NormalizeVendor(actual.Obj().Pkg().Path()),
   136  				)
   137  			}
   138  
   139  			if expected.Obj().Name() != actual.Obj().Name() {
   140  				return fmt.Errorf(
   141  					"named type name differs, %s != %s",
   142  					NormalizeVendor(expected.Obj().Name()),
   143  					NormalizeVendor(actual.Obj().Name()),
   144  				)
   145  			}
   146  
   147  			return nil
   148  		}
   149  
   150  		// Before models are generated all missing references will be Invalid Basic references.
   151  		// lets assume these are valid too.
   152  		if actual, ok := actual.(*types.Basic); ok && actual.Kind() == types.Invalid {
   153  			return nil
   154  		}
   155  
   156  	default:
   157  		return fmt.Errorf("missing support for %T", expected)
   158  	}
   159  
   160  	return fmt.Errorf("type mismatch %T != %T", expected, actual)
   161  }