github.com/niko0xdev/gqlgen@v0.17.55-0.20240120102243-2ecff98c3e37/codegen/interface.go (about)

     1  package codegen
     2  
     3  import (
     4  	"fmt"
     5  	"go/types"
     6  	"sort"
     7  
     8  	"github.com/vektah/gqlparser/v2/ast"
     9  
    10  	"github.com/niko0xdev/gqlgen/codegen/config"
    11  )
    12  
    13  type Interface struct {
    14  	*ast.Definition
    15  	Type         types.Type
    16  	Implementors []InterfaceImplementor
    17  	InTypemap    bool
    18  }
    19  
    20  type InterfaceImplementor struct {
    21  	*ast.Definition
    22  
    23  	Type    types.Type
    24  	TakeRef bool
    25  }
    26  
    27  func (b *builder) buildInterface(typ *ast.Definition) (*Interface, error) {
    28  	obj, err := b.Binder.DefaultUserObject(typ.Name)
    29  	if err != nil {
    30  		panic(err)
    31  	}
    32  
    33  	i := &Interface{
    34  		Definition: typ,
    35  		Type:       obj,
    36  		InTypemap:  b.Config.Models.UserDefined(typ.Name),
    37  	}
    38  
    39  	interfaceType, err := findGoInterface(i.Type)
    40  	if interfaceType == nil || err != nil {
    41  		return nil, fmt.Errorf("%s is not an interface", i.Type)
    42  	}
    43  
    44  	// Sort so that more specific types are evaluated first.
    45  	implementors := b.Schema.GetPossibleTypes(typ)
    46  	sort.Slice(implementors, func(i, j int) bool {
    47  		return len(implementors[i].Interfaces) > len(implementors[j].Interfaces)
    48  	})
    49  
    50  	for _, implementor := range implementors {
    51  		obj, err := b.Binder.DefaultUserObject(implementor.Name)
    52  		if err != nil {
    53  			return nil, fmt.Errorf("%s has no backing go type", implementor.Name)
    54  		}
    55  
    56  		implementorType, err := findGoNamedType(obj)
    57  		if err != nil {
    58  			return nil, fmt.Errorf("can not find backing go type %s: %w", obj.String(), err)
    59  		} else if implementorType == nil {
    60  			return nil, fmt.Errorf("can not find backing go type %s", obj.String())
    61  		}
    62  
    63  		anyValid := false
    64  
    65  		// first check if the value receiver can be nil, eg can we type switch on case Thing:
    66  		if types.Implements(implementorType, interfaceType) {
    67  			i.Implementors = append(i.Implementors, InterfaceImplementor{
    68  				Definition: implementor,
    69  				Type:       obj,
    70  				TakeRef:    !types.IsInterface(obj),
    71  			})
    72  			anyValid = true
    73  		}
    74  
    75  		// then check if the pointer receiver can be nil, eg can we type switch on case *Thing:
    76  		if types.Implements(types.NewPointer(implementorType), interfaceType) {
    77  			i.Implementors = append(i.Implementors, InterfaceImplementor{
    78  				Definition: implementor,
    79  				Type:       types.NewPointer(obj),
    80  			})
    81  			anyValid = true
    82  		}
    83  
    84  		if !anyValid {
    85  			return nil, fmt.Errorf("%s does not satisfy the interface %s", implementorType.String(), i.Type.String())
    86  		}
    87  	}
    88  
    89  	return i, nil
    90  }
    91  
    92  func (i *InterfaceImplementor) CanBeNil() bool {
    93  	return config.IsNilable(i.Type)
    94  }