github.com/spread-ai/gqlgen@v0.0.0-20221124102857-a6c8ef538a1d/plugin/federation/federation.go (about)

     1  package federation
     2  
     3  import (
     4  	_ "embed"
     5  	"fmt"
     6  	"sort"
     7  	"strings"
     8  
     9  	"github.com/vektah/gqlparser/v2/ast"
    10  
    11  	"github.com/spread-ai/gqlgen/codegen"
    12  	"github.com/spread-ai/gqlgen/codegen/config"
    13  	"github.com/spread-ai/gqlgen/codegen/templates"
    14  	"github.com/spread-ai/gqlgen/plugin"
    15  	"github.com/spread-ai/gqlgen/plugin/federation/fieldset"
    16  )
    17  
// federationTemplate is the template source that GenerateCode passes to
// templates.Render, with the federation plugin itself as the template data.
//
//go:embed federation.gotpl
var federationTemplate string
    20  
// federation is the gqlgen plugin that injects federated directives and
// types into the schema and generates the entity resolution code.
type federation struct {
	// Entities lists every object type that carries @key, sorted by name.
	// It is populated by setEntities when InjectSourceLate runs.
	Entities []*Entity
	// Version is the federation spec version. New turns 0 into 1; versions 1
	// and 2 select different @key definitions, and 2 also enables the
	// Federation 2 directives.
	Version  int
}
    25  
    26  // New returns a federation plugin that injects
    27  // federated directives and types into the schema
    28  func New(version int) plugin.Plugin {
    29  	if version == 0 {
    30  		version = 1
    31  	}
    32  
    33  	return &federation{Version: version}
    34  }
    35  
    36  // Name returns the plugin name
    37  func (f *federation) Name() string {
    38  	return "federation"
    39  }
    40  
    41  // MutateConfig mutates the configuration
    42  func (f *federation) MutateConfig(cfg *config.Config) error {
    43  	builtins := config.TypeMap{
    44  		"_Service": {
    45  			Model: config.StringList{
    46  				"github.com/spread-ai/gqlgen/plugin/federation/fedruntime.Service",
    47  			},
    48  		},
    49  		"_Entity": {
    50  			Model: config.StringList{
    51  				"github.com/spread-ai/gqlgen/plugin/federation/fedruntime.Entity",
    52  			},
    53  		},
    54  		"Entity": {
    55  			Model: config.StringList{
    56  				"github.com/spread-ai/gqlgen/plugin/federation/fedruntime.Entity",
    57  			},
    58  		},
    59  		"_Any": {
    60  			Model: config.StringList{"github.com/spread-ai/gqlgen/graphql.Map"},
    61  		},
    62  	}
    63  
    64  	for typeName, entry := range builtins {
    65  		if cfg.Models.Exists(typeName) {
    66  			return fmt.Errorf("%v already exists which must be reserved when Federation is enabled", typeName)
    67  		}
    68  		cfg.Models[typeName] = entry
    69  	}
    70  	cfg.Directives["external"] = config.DirectiveConfig{SkipRuntime: true}
    71  	cfg.Directives["requires"] = config.DirectiveConfig{SkipRuntime: true}
    72  	cfg.Directives["provides"] = config.DirectiveConfig{SkipRuntime: true}
    73  	cfg.Directives["key"] = config.DirectiveConfig{SkipRuntime: true}
    74  	cfg.Directives["extends"] = config.DirectiveConfig{SkipRuntime: true}
    75  
    76  	// Federation 2 specific directives
    77  	if f.Version == 2 {
    78  		cfg.Directives["shareable"] = config.DirectiveConfig{SkipRuntime: true}
    79  		cfg.Directives["link"] = config.DirectiveConfig{SkipRuntime: true}
    80  		cfg.Directives["tag"] = config.DirectiveConfig{SkipRuntime: true}
    81  		cfg.Directives["override"] = config.DirectiveConfig{SkipRuntime: true}
    82  		cfg.Directives["inaccessible"] = config.DirectiveConfig{SkipRuntime: true}
    83  	}
    84  
    85  	return nil
    86  }
    87  
    88  func (f *federation) InjectSourceEarly() *ast.Source {
    89  	input := `
    90  	scalar _Any
    91  	scalar _FieldSet
    92  
    93  	directive @external on FIELD_DEFINITION
    94  	directive @requires(fields: _FieldSet!) on FIELD_DEFINITION
    95  	directive @provides(fields: _FieldSet!) on FIELD_DEFINITION
    96  	directive @extends on OBJECT | INTERFACE
    97  `
    98  	// add version-specific changes on key directive, as well as adding the new directives for federation 2
    99  	if f.Version == 1 {
   100  		input += `
   101  	directive @key(fields: _FieldSet!) repeatable on OBJECT | INTERFACE
   102  `
   103  	} else if f.Version == 2 {
   104  		input += `
   105  	directive @key(fields: _FieldSet!, resolvable: Boolean = true) repeatable on OBJECT | INTERFACE
   106  	directive @link(import: [String!], url: String!) repeatable on SCHEMA
   107  	directive @shareable on OBJECT | FIELD_DEFINITION
   108  	directive @tag(name: String!) repeatable on FIELD_DEFINITION | INTERFACE | OBJECT | UNION | ARGUMENT_DEFINITION | SCALAR | ENUM | ENUM_VALUE | INPUT_OBJECT | INPUT_FIELD_DEFINITION
   109  	directive @override(from: String!) on FIELD_DEFINITION
   110  	directive @inaccessible on SCALAR | OBJECT | FIELD_DEFINITION | ARGUMENT_DEFINITION | INTERFACE | UNION | ENUM | ENUM_VALUE | INPUT_OBJECT | INPUT_FIELD_DEFINITION
   111  `
   112  	}
   113  	return &ast.Source{
   114  		Name:    "federation/directives.graphql",
   115  		Input:   input,
   116  		BuiltIn: true,
   117  	}
   118  }
   119  
   120  // InjectSources creates a GraphQL Entity type with all
   121  // the fields that had the @key directive
   122  func (f *federation) InjectSourceLate(schema *ast.Schema) *ast.Source {
   123  	f.setEntities(schema)
   124  
   125  	var entities, resolvers, entityResolverInputDefinitions string
   126  	for i, e := range f.Entities {
   127  		if i != 0 {
   128  			entities += " | "
   129  		}
   130  		entities += e.Name
   131  
   132  		for _, r := range e.Resolvers {
   133  			if e.Multi {
   134  				if entityResolverInputDefinitions != "" {
   135  					entityResolverInputDefinitions += "\n\n"
   136  				}
   137  				entityResolverInputDefinitions += "input " + r.InputType + " {\n"
   138  				for _, keyField := range r.KeyFields {
   139  					entityResolverInputDefinitions += fmt.Sprintf("\t%s: %s\n", keyField.Field.ToGo(), keyField.Definition.Type.String())
   140  				}
   141  				entityResolverInputDefinitions += "}"
   142  				resolvers += fmt.Sprintf("\t%s(reps: [%s!]!): [%s]\n", r.ResolverName, r.InputType, e.Name)
   143  			} else {
   144  				resolverArgs := ""
   145  				for _, keyField := range r.KeyFields {
   146  					resolverArgs += fmt.Sprintf("%s: %s,", keyField.Field.ToGoPrivate(), keyField.Definition.Type.String())
   147  				}
   148  				resolvers += fmt.Sprintf("\t%s(%s): %s!\n", r.ResolverName, resolverArgs, e.Name)
   149  			}
   150  		}
   151  	}
   152  
   153  	var blocks []string
   154  	if entities != "" {
   155  		entities = `# a union of all types that use the @key directive
   156  union _Entity = ` + entities
   157  		blocks = append(blocks, entities)
   158  	}
   159  
   160  	// resolvers can be empty if a service defines only "empty
   161  	// extend" types.  This should be rare.
   162  	if resolvers != "" {
   163  		if entityResolverInputDefinitions != "" {
   164  			blocks = append(blocks, entityResolverInputDefinitions)
   165  		}
   166  		resolvers = `# fake type to build resolver interfaces for users to implement
   167  type Entity {
   168  	` + resolvers + `
   169  }`
   170  		blocks = append(blocks, resolvers)
   171  	}
   172  
   173  	_serviceTypeDef := `type _Service {
   174    sdl: String
   175  }`
   176  	blocks = append(blocks, _serviceTypeDef)
   177  
   178  	var additionalQueryFields string
   179  	// Quote from the Apollo Federation subgraph specification:
   180  	// If no types are annotated with the key directive, then the
   181  	// _Entity union and _entities field should be removed from the schema
   182  	if len(f.Entities) > 0 {
   183  		additionalQueryFields += `  _entities(representations: [_Any!]!): [_Entity]!
   184  `
   185  	}
   186  	// _service field is required in any case
   187  	additionalQueryFields += `  _service: _Service!`
   188  
   189  	extendTypeQueryDef := `extend type ` + schema.Query.Name + ` {
   190  ` + additionalQueryFields + `
   191  }`
   192  	blocks = append(blocks, extendTypeQueryDef)
   193  
   194  	return &ast.Source{
   195  		Name:    "federation/entity.graphql",
   196  		BuiltIn: true,
   197  		Input:   "\n" + strings.Join(blocks, "\n\n") + "\n",
   198  	}
}

// Entity represents a federated type
// that was declared in the GQL schema.
// Entities are collected by setEntities from object types that carry @key.
type Entity struct {
	Name      string            // The same name as the type declaration
	Def       *ast.Definition   // The schema definition of the type
	Resolvers []*EntityResolver // One per @key directive; none when every field is @external
	Requires  []*Requires       // Field paths named by @requires on the type's fields
	Multi     bool              // From @entityResolver(multi: ...); resolvers then take a list of inputs
}
   208  	Multi     bool
}

// EntityResolver describes one resolver generated for an entity from a single
// @key directive.
type EntityResolver struct {
	ResolverName string      // The resolver name, such as findUserByID (findManyUserByIDs when Multi)
	KeyFields    []*KeyField // The fields declared in @key.
	InputType    string      // The Go generated input type for multi entity resolvers
}
   214  	InputType    string      // The Go generated input type for multi entity resolvers
}

// KeyField is one field path selected by an @key directive's field set.
type KeyField struct {
	Definition *ast.FieldDefinition  // As resolved by fieldset.Field.FieldDefinition on the entity type
	Field      fieldset.Field        // len > 1 for nested fields
	Type       *config.TypeReference // The Go representation of that field type; filled in by GenerateCode
}
   220  	Type       *config.TypeReference // The Go representation of that field type
}

// Requires represents one field path named in an @requires clause.
type Requires struct {
	Name  string                // the private Go name of the field path (Field.ToGoPrivate)
	Field fieldset.Field        // source Field, len > 1 for nested fields
	Type  *config.TypeReference // The Go representation of that field type; filled in by GenerateCode
}
   227  	Type  *config.TypeReference // The Go representation of that field type
   228  }
   229  
   230  func (e *Entity) allFieldsAreExternal() bool {
   231  	for _, field := range e.Def.Fields {
   232  		if field.Directives.ForName("external") == nil {
   233  			return false
   234  		}
   235  	}
   236  	return true
   237  }
   238  
   239  func (f *federation) GenerateCode(data *codegen.Data) error {
   240  	if len(f.Entities) > 0 {
   241  		if data.Objects.ByName("Entity") != nil {
   242  			data.Objects.ByName("Entity").Root = true
   243  		}
   244  		for _, e := range f.Entities {
   245  			obj := data.Objects.ByName(e.Def.Name)
   246  
   247  			for _, r := range e.Resolvers {
   248  				// fill in types for key fields
   249  				//
   250  				for _, keyField := range r.KeyFields {
   251  					if len(keyField.Field) == 0 {
   252  						fmt.Println(
   253  							"skipping @key field " + keyField.Definition.Name + " in " + r.ResolverName + " in " + e.Def.Name,
   254  						)
   255  						continue
   256  					}
   257  					cgField := keyField.Field.TypeReference(obj, data.Objects)
   258  					keyField.Type = cgField.TypeReference
   259  				}
   260  			}
   261  
   262  			// fill in types for requires fields
   263  			//
   264  			for _, reqField := range e.Requires {
   265  				if len(reqField.Field) == 0 {
   266  					fmt.Println("skipping @requires field " + reqField.Name + " in " + e.Def.Name)
   267  					continue
   268  				}
   269  				cgField := reqField.Field.TypeReference(obj, data.Objects)
   270  				reqField.Type = cgField.TypeReference
   271  			}
   272  		}
   273  	}
   274  
   275  	return templates.Render(templates.Options{
   276  		PackageName:     data.Config.Federation.Package,
   277  		Filename:        data.Config.Federation.Filename,
   278  		Data:            f,
   279  		GeneratedHeader: true,
   280  		Packages:        data.Config.Packages,
   281  		Template:        federationTemplate,
   282  	})
   283  }
   284  
   285  func (f *federation) setEntities(schema *ast.Schema) {
   286  	for _, schemaType := range schema.Types {
   287  		keys, ok := isFederatedEntity(schemaType)
   288  		if !ok {
   289  			continue
   290  		}
   291  		e := &Entity{
   292  			Name:      schemaType.Name,
   293  			Def:       schemaType,
   294  			Resolvers: nil,
   295  			Requires:  nil,
   296  		}
   297  
   298  		// Let's process custom entity resolver settings.
   299  		dir := schemaType.Directives.ForName("entityResolver")
   300  		if dir != nil {
   301  			if dirArg := dir.Arguments.ForName("multi"); dirArg != nil {
   302  				if dirVal, err := dirArg.Value.Value(nil); err == nil {
   303  					e.Multi = dirVal.(bool)
   304  				}
   305  			}
   306  		}
   307  
   308  		// If our schema has a field with a type defined in
   309  		// another service, then we need to define an "empty
   310  		// extend" of that type in this service, so this service
   311  		// knows what the type is like.  But the graphql-server
   312  		// will never ask us to actually resolve this "empty
   313  		// extend", so we don't require a resolver function for
   314  		// it.  (Well, it will never ask in practice; it's
   315  		// unclear whether the spec guarantees this.  See
   316  		// https://github.com/apollographql/apollo-server/issues/3852
   317  		// ).  Example:
   318  		//    type MyType {
   319  		//       myvar: TypeDefinedInOtherService
   320  		//    }
   321  		//    // Federation needs this type, but
   322  		//    // it doesn't need a resolver for it!
   323  		//    extend TypeDefinedInOtherService @key(fields: "id") {
   324  		//       id: ID @external
   325  		//    }
   326  		if !e.allFieldsAreExternal() {
   327  			for _, dir := range keys {
   328  				if len(dir.Arguments) > 2 {
   329  					panic("More than two arguments provided for @key declaration.")
   330  				}
   331  				var arg *ast.Argument
   332  
   333  				// since keys are able to now have multiple arguments, we need to check both possible for a possible @key(fields="" fields="")
   334  				for _, a := range dir.Arguments {
   335  					if a.Name == "fields" {
   336  						if arg != nil {
   337  							panic("More than one `fields` provided for @key declaration.")
   338  						}
   339  						arg = a
   340  					}
   341  				}
   342  
   343  				keyFieldSet := fieldset.New(arg.Value.Raw, nil)
   344  
   345  				keyFields := make([]*KeyField, len(keyFieldSet))
   346  				resolverFields := []string{}
   347  				for i, field := range keyFieldSet {
   348  					def := field.FieldDefinition(schemaType, schema)
   349  
   350  					if def == nil {
   351  						panic(fmt.Sprintf("no field for %v", field))
   352  					}
   353  
   354  					keyFields[i] = &KeyField{Definition: def, Field: field}
   355  					resolverFields = append(resolverFields, keyFields[i].Field.ToGo())
   356  				}
   357  
   358  				resolverFieldsToGo := schemaType.Name + "By" + strings.Join(resolverFields, "And")
   359  				var resolverName string
   360  				if e.Multi {
   361  					resolverFieldsToGo += "s" // Pluralize for better API readability
   362  					resolverName = fmt.Sprintf("findMany%s", resolverFieldsToGo)
   363  				} else {
   364  					resolverName = fmt.Sprintf("find%s", resolverFieldsToGo)
   365  				}
   366  
   367  				e.Resolvers = append(e.Resolvers, &EntityResolver{
   368  					ResolverName: resolverName,
   369  					KeyFields:    keyFields,
   370  					InputType:    resolverFieldsToGo + "Input",
   371  				})
   372  			}
   373  
   374  			e.Requires = []*Requires{}
   375  			for _, f := range schemaType.Fields {
   376  				dir := f.Directives.ForName("requires")
   377  				if dir == nil {
   378  					continue
   379  				}
   380  				if len(dir.Arguments) != 1 || dir.Arguments[0].Name != "fields" {
   381  					panic("Exactly one `fields` argument needed for @requires declaration.")
   382  				}
   383  				requiresFieldSet := fieldset.New(dir.Arguments[0].Value.Raw, nil)
   384  				for _, field := range requiresFieldSet {
   385  					e.Requires = append(e.Requires, &Requires{
   386  						Name:  field.ToGoPrivate(),
   387  						Field: field,
   388  					})
   389  				}
   390  			}
   391  		}
   392  		f.Entities = append(f.Entities, e)
   393  	}
   394  
   395  	// make sure order remains stable across multiple builds
   396  	sort.Slice(f.Entities, func(i, j int) bool {
   397  		return f.Entities[i].Name < f.Entities[j].Name
   398  	})
   399  }
   400  
   401  func isFederatedEntity(schemaType *ast.Definition) ([]*ast.Directive, bool) {
   402  	switch schemaType.Kind {
   403  	case ast.Object:
   404  		keys := schemaType.Directives.ForNames("key")
   405  		if len(keys) > 0 {
   406  			return keys, true
   407  		}
   408  	case ast.Interface:
   409  		// TODO: support @key and @extends for interfaces
   410  		if dir := schemaType.Directives.ForName("key"); dir != nil {
   411  			fmt.Printf("@key directive found on \"interface %s\". Will be ignored.\n", schemaType.Name)
   412  		}
   413  		if dir := schemaType.Directives.ForName("extends"); dir != nil {
   414  			panic(
   415  				fmt.Sprintf(
   416  					"@extends directive is not currently supported for interfaces, use \"extend interface %s\" instead.",
   417  					schemaType.Name,
   418  				))
   419  		}
   420  	default:
   421  		// ignore
   422  	}
   423  	return nil, false
   424  }