github.com/operandinc/gqlgen@v0.16.1/plugin/federation/federation.go (about)

     1  package federation
     2  
     3  import (
     4  	"fmt"
     5  	"sort"
     6  	"strings"
     7  
     8  	"github.com/vektah/gqlparser/v2/ast"
     9  
    10  	"github.com/operandinc/gqlgen/codegen"
    11  	"github.com/operandinc/gqlgen/codegen/config"
    12  	"github.com/operandinc/gqlgen/codegen/templates"
    13  	"github.com/operandinc/gqlgen/plugin"
    14  	"github.com/operandinc/gqlgen/plugin/federation/fieldset"
    15  )
    16  
// federation is the gqlgen plugin that adds Apollo Federation support.
// Entities is populated by setEntities (called from InjectSourceLate) and is
// later read by GenerateCode, which also passes the whole plugin value to the
// federation template as its data.
type federation struct {
	Entities []*Entity // federated entity types, sorted by name
}
    20  
    21  // New returns a federation plugin that injects
    22  // federated directives and types into the schema
    23  func New() plugin.Plugin {
    24  	return &federation{}
    25  }
    26  
    27  // Name returns the plugin name
    28  func (f *federation) Name() string {
    29  	return "federation"
    30  }
    31  
    32  // MutateConfig mutates the configuration
    33  func (f *federation) MutateConfig(cfg *config.Config) error {
    34  	builtins := config.TypeMap{
    35  		"_Service": {
    36  			Model: config.StringList{
    37  				"github.com/operandinc/gqlgen/plugin/federation/fedruntime.Service",
    38  			},
    39  		},
    40  		"_Entity": {
    41  			Model: config.StringList{
    42  				"github.com/operandinc/gqlgen/plugin/federation/fedruntime.Entity",
    43  			},
    44  		},
    45  		"Entity": {
    46  			Model: config.StringList{
    47  				"github.com/operandinc/gqlgen/plugin/federation/fedruntime.Entity",
    48  			},
    49  		},
    50  		"_Any": {
    51  			Model: config.StringList{"github.com/operandinc/gqlgen/graphql.Map"},
    52  		},
    53  	}
    54  	for typeName, entry := range builtins {
    55  		if cfg.Models.Exists(typeName) {
    56  			return fmt.Errorf("%v already exists which must be reserved when Federation is enabled", typeName)
    57  		}
    58  		cfg.Models[typeName] = entry
    59  	}
    60  	cfg.Directives["external"] = config.DirectiveConfig{SkipRuntime: true}
    61  	cfg.Directives["requires"] = config.DirectiveConfig{SkipRuntime: true}
    62  	cfg.Directives["provides"] = config.DirectiveConfig{SkipRuntime: true}
    63  	cfg.Directives["key"] = config.DirectiveConfig{SkipRuntime: true}
    64  	cfg.Directives["extends"] = config.DirectiveConfig{SkipRuntime: true}
    65  
    66  	return nil
    67  }
    68  
    69  func (f *federation) InjectSourceEarly() *ast.Source {
    70  	return &ast.Source{
    71  		Name: "federation/directives.graphql",
    72  		Input: `
    73  scalar _Any
    74  scalar _FieldSet
    75  
    76  directive @external on FIELD_DEFINITION
    77  directive @requires(fields: _FieldSet!) on FIELD_DEFINITION
    78  directive @provides(fields: _FieldSet!) on FIELD_DEFINITION
    79  directive @key(fields: _FieldSet!) repeatable on OBJECT | INTERFACE
    80  directive @extends on OBJECT | INTERFACE
    81  `,
    82  		BuiltIn: true,
    83  	}
    84  }
    85  
    86  // InjectSources creates a GraphQL Entity type with all
    87  // the fields that had the @key directive
    88  func (f *federation) InjectSourceLate(schema *ast.Schema) *ast.Source {
    89  	f.setEntities(schema)
    90  
    91  	var entities, resolvers, entityResolverInputDefinitions string
    92  	for i, e := range f.Entities {
    93  		if i != 0 {
    94  			entities += " | "
    95  		}
    96  		entities += e.Name
    97  
    98  		for _, r := range e.Resolvers {
    99  			if e.Multi {
   100  				if entityResolverInputDefinitions != "" {
   101  					entityResolverInputDefinitions += "\n\n"
   102  				}
   103  				entityResolverInputDefinitions += "input " + r.InputType + " {\n"
   104  				for _, keyField := range r.KeyFields {
   105  					entityResolverInputDefinitions += fmt.Sprintf("\t%s: %s\n", keyField.Field.ToGo(), keyField.Definition.Type.String())
   106  				}
   107  				entityResolverInputDefinitions += "}"
   108  				resolvers += fmt.Sprintf("\t%s(reps: [%s!]!): [%s]\n", r.ResolverName, r.InputType, e.Name)
   109  			} else {
   110  				resolverArgs := ""
   111  				for _, keyField := range r.KeyFields {
   112  					resolverArgs += fmt.Sprintf("%s: %s,", keyField.Field.ToGoPrivate(), keyField.Definition.Type.String())
   113  				}
   114  				resolvers += fmt.Sprintf("\t%s(%s): %s!\n", r.ResolverName, resolverArgs, e.Name)
   115  			}
   116  		}
   117  	}
   118  
   119  	var blocks []string
   120  	if entities != "" {
   121  		entities = `# a union of all types that use the @key directive
   122  union _Entity = ` + entities
   123  		blocks = append(blocks, entities)
   124  	}
   125  
   126  	// resolvers can be empty if a service defines only "empty
   127  	// extend" types.  This should be rare.
   128  	if resolvers != "" {
   129  		if entityResolverInputDefinitions != "" {
   130  			blocks = append(blocks, entityResolverInputDefinitions)
   131  		}
   132  		resolvers = `# fake type to build resolver interfaces for users to implement
   133  type Entity {
   134  	` + resolvers + `
   135  }`
   136  		blocks = append(blocks, resolvers)
   137  	}
   138  
   139  	_serviceTypeDef := `type _Service {
   140    sdl: String
   141  }`
   142  	blocks = append(blocks, _serviceTypeDef)
   143  
   144  	var additionalQueryFields string
   145  	// Quote from the Apollo Federation subgraph specification:
   146  	// If no types are annotated with the key directive, then the
   147  	// _Entity union and _entities field should be removed from the schema
   148  	if len(f.Entities) > 0 {
   149  		additionalQueryFields += `  _entities(representations: [_Any!]!): [_Entity]!
   150  `
   151  	}
   152  	// _service field is required in any case
   153  	additionalQueryFields += `  _service: _Service!`
   154  
   155  	extendTypeQueryDef := `extend type ` + schema.Query.Name + ` {
   156  ` + additionalQueryFields + `
   157  }`
   158  	blocks = append(blocks, extendTypeQueryDef)
   159  
   160  	return &ast.Source{
   161  		Name:    "federation/entity.graphql",
   162  		BuiltIn: true,
   163  		Input:   "\n" + strings.Join(blocks, "\n\n") + "\n",
   164  	}
   165  }
   166  
   167  // Entity represents a federated type
   168  // that was declared in the GQL schema.
   169  type Entity struct {
   170  	Name      string // The same name as the type declaration
   171  	Def       *ast.Definition
   172  	Resolvers []*EntityResolver
   173  	Requires  []*Requires
   174  	Multi     bool
   175  }
   176  
   177  type EntityResolver struct {
   178  	ResolverName string      // The resolver name, such as FindUserByID
   179  	KeyFields    []*KeyField // The fields declared in @key.
   180  	InputType    string      // The Go generated input type for multi entity resolvers
   181  }
   182  
   183  type KeyField struct {
   184  	Definition *ast.FieldDefinition
   185  	Field      fieldset.Field        // len > 1 for nested fields
   186  	Type       *config.TypeReference // The Go representation of that field type
   187  }
   188  
   189  // Requires represents an @requires clause
   190  type Requires struct {
   191  	Name  string                // the name of the field
   192  	Field fieldset.Field        // source Field, len > 1 for nested fields
   193  	Type  *config.TypeReference // The Go representation of that field type
   194  }
   195  
   196  func (e *Entity) allFieldsAreExternal() bool {
   197  	for _, field := range e.Def.Fields {
   198  		if field.Directives.ForName("external") == nil {
   199  			return false
   200  		}
   201  	}
   202  	return true
   203  }
   204  
   205  func (f *federation) GenerateCode(data *codegen.Data) error {
   206  	if len(f.Entities) > 0 {
   207  		if data.Objects.ByName("Entity") != nil {
   208  			data.Objects.ByName("Entity").Root = true
   209  		}
   210  		for _, e := range f.Entities {
   211  			obj := data.Objects.ByName(e.Def.Name)
   212  
   213  			for _, r := range e.Resolvers {
   214  				// fill in types for key fields
   215  				//
   216  				for _, keyField := range r.KeyFields {
   217  					if len(keyField.Field) == 0 {
   218  						fmt.Println(
   219  							"skipping @key field " + keyField.Definition.Name + " in " + r.ResolverName + " in " + e.Def.Name,
   220  						)
   221  						continue
   222  					}
   223  					cgField := keyField.Field.TypeReference(obj, data.Objects)
   224  					keyField.Type = cgField.TypeReference
   225  				}
   226  			}
   227  
   228  			// fill in types for requires fields
   229  			//
   230  			for _, reqField := range e.Requires {
   231  				if len(reqField.Field) == 0 {
   232  					fmt.Println("skipping @requires field " + reqField.Name + " in " + e.Def.Name)
   233  					continue
   234  				}
   235  				cgField := reqField.Field.TypeReference(obj, data.Objects)
   236  				reqField.Type = cgField.TypeReference
   237  			}
   238  		}
   239  	}
   240  
   241  	return templates.Render(templates.Options{
   242  		PackageName:     data.Config.Federation.Package,
   243  		Filename:        data.Config.Federation.Filename,
   244  		Data:            f,
   245  		GeneratedHeader: true,
   246  		Packages:        data.Config.Packages,
   247  	})
   248  }
   249  
   250  func (f *federation) setEntities(schema *ast.Schema) {
   251  	for _, schemaType := range schema.Types {
   252  		keys, ok := isFederatedEntity(schemaType)
   253  		if !ok {
   254  			continue
   255  		}
   256  		e := &Entity{
   257  			Name:      schemaType.Name,
   258  			Def:       schemaType,
   259  			Resolvers: nil,
   260  			Requires:  nil,
   261  		}
   262  
   263  		// Let's process custom entity resolver settings.
   264  		dir := schemaType.Directives.ForName("entityResolver")
   265  		if dir != nil {
   266  			if dirArg := dir.Arguments.ForName("multi"); dirArg != nil {
   267  				if dirVal, err := dirArg.Value.Value(nil); err == nil {
   268  					e.Multi = dirVal.(bool)
   269  				}
   270  			}
   271  		}
   272  
   273  		// If our schema has a field with a type defined in
   274  		// another service, then we need to define an "empty
   275  		// extend" of that type in this service, so this service
   276  		// knows what the type is like.  But the graphql-server
   277  		// will never ask us to actually resolve this "empty
   278  		// extend", so we don't require a resolver function for
   279  		// it.  (Well, it will never ask in practice; it's
   280  		// unclear whether the spec guarantees this.  See
   281  		// https://github.com/apollographql/apollo-server/issues/3852
   282  		// ).  Example:
   283  		//    type MyType {
   284  		//       myvar: TypeDefinedInOtherService
   285  		//    }
   286  		//    // Federation needs this type, but
   287  		//    // it doesn't need a resolver for it!
   288  		//    extend TypeDefinedInOtherService @key(fields: "id") {
   289  		//       id: ID @external
   290  		//    }
   291  		if !e.allFieldsAreExternal() {
   292  			for _, dir := range keys {
   293  				if len(dir.Arguments) != 1 || dir.Arguments[0].Name != "fields" {
   294  					panic("Exactly one `fields` argument needed for @key declaration.")
   295  				}
   296  				arg := dir.Arguments[0]
   297  				keyFieldSet := fieldset.New(arg.Value.Raw, nil)
   298  
   299  				keyFields := make([]*KeyField, len(keyFieldSet))
   300  				resolverFields := []string{}
   301  				for i, field := range keyFieldSet {
   302  					def := field.FieldDefinition(schemaType, schema)
   303  
   304  					if def == nil {
   305  						panic(fmt.Sprintf("no field for %v", field))
   306  					}
   307  
   308  					keyFields[i] = &KeyField{Definition: def, Field: field}
   309  					resolverFields = append(resolverFields, keyFields[i].Field.ToGo())
   310  				}
   311  
   312  				resolverFieldsToGo := schemaType.Name + "By" + strings.Join(resolverFields, "And")
   313  				var resolverName string
   314  				if e.Multi {
   315  					resolverFieldsToGo += "s" // Pluralize for better API readability
   316  					resolverName = fmt.Sprintf("findMany%s", resolverFieldsToGo)
   317  				} else {
   318  					resolverName = fmt.Sprintf("find%s", resolverFieldsToGo)
   319  				}
   320  
   321  				e.Resolvers = append(e.Resolvers, &EntityResolver{
   322  					ResolverName: resolverName,
   323  					KeyFields:    keyFields,
   324  					InputType:    resolverFieldsToGo + "Input",
   325  				})
   326  			}
   327  
   328  			e.Requires = []*Requires{}
   329  			for _, f := range schemaType.Fields {
   330  				dir := f.Directives.ForName("requires")
   331  				if dir == nil {
   332  					continue
   333  				}
   334  				if len(dir.Arguments) != 1 || dir.Arguments[0].Name != "fields" {
   335  					panic("Exactly one `fields` argument needed for @requires declaration.")
   336  				}
   337  				requiresFieldSet := fieldset.New(dir.Arguments[0].Value.Raw, nil)
   338  				for _, field := range requiresFieldSet {
   339  					e.Requires = append(e.Requires, &Requires{
   340  						Name:  field.ToGoPrivate(),
   341  						Field: field,
   342  					})
   343  				}
   344  			}
   345  		}
   346  		f.Entities = append(f.Entities, e)
   347  	}
   348  
   349  	// make sure order remains stable across multiple builds
   350  	sort.Slice(f.Entities, func(i, j int) bool {
   351  		return f.Entities[i].Name < f.Entities[j].Name
   352  	})
   353  }
   354  
   355  func isFederatedEntity(schemaType *ast.Definition) ([]*ast.Directive, bool) {
   356  	switch schemaType.Kind {
   357  	case ast.Object:
   358  		keys := schemaType.Directives.ForNames("key")
   359  		if len(keys) > 0 {
   360  			return keys, true
   361  		}
   362  	case ast.Interface:
   363  		// TODO: support @key and @extends for interfaces
   364  		if dir := schemaType.Directives.ForName("key"); dir != nil {
   365  			fmt.Printf("@key directive found on \"interface %s\". Will be ignored.\n", schemaType.Name)
   366  		}
   367  		if dir := schemaType.Directives.ForName("extends"); dir != nil {
   368  			panic(
   369  				fmt.Sprintf(
   370  					"@extends directive is not currently supported for interfaces, use \"extend interface %s\" instead.",
   371  					schemaType.Name,
   372  				))
   373  		}
   374  	default:
   375  		// ignore
   376  	}
   377  	return nil, false
   378  }