github.com/mstephano/gqlgen-schemagen@v0.0.0-20230113041936-dd2cd4ea46aa/plugin/federation/federation.go (about)

     1  package federation
     2  
     3  import (
     4  	_ "embed"
     5  	"fmt"
     6  	"sort"
     7  	"strings"
     8  
     9  	"github.com/vektah/gqlparser/v2/ast"
    10  
    11  	"github.com/mstephano/gqlgen-schemagen/codegen"
    12  	"github.com/mstephano/gqlgen-schemagen/codegen/config"
    13  	"github.com/mstephano/gqlgen-schemagen/codegen/templates"
    14  	"github.com/mstephano/gqlgen-schemagen/plugin"
    15  	"github.com/mstephano/gqlgen-schemagen/plugin/federation/fieldset"
    16  )
    17  
    18  //go:embed federation.gotpl
    19  var federationTemplate string
    20  
    21  type federation struct {
    22  	Entities []*Entity
    23  	Version  int
    24  }
    25  
    26  // New returns a federation plugin that injects
    27  // federated directives and types into the schema
    28  func New(version int) plugin.Plugin {
    29  	if version == 0 {
    30  		version = 1
    31  	}
    32  
    33  	return &federation{Version: version}
    34  }
    35  
    36  // Name returns the plugin name
    37  func (f *federation) Name() string {
    38  	return "federation"
    39  }
    40  
    41  // MutateConfig mutates the configuration
    42  func (f *federation) MutateConfig(cfg *config.Config) error {
    43  	builtins := config.TypeMap{
    44  		"_Service": {
    45  			Model: config.StringList{
    46  				"github.com/mstephano/gqlgen-schemagen/plugin/federation/fedruntime.Service",
    47  			},
    48  		},
    49  		"_Entity": {
    50  			Model: config.StringList{
    51  				"github.com/mstephano/gqlgen-schemagen/plugin/federation/fedruntime.Entity",
    52  			},
    53  		},
    54  		"Entity": {
    55  			Model: config.StringList{
    56  				"github.com/mstephano/gqlgen-schemagen/plugin/federation/fedruntime.Entity",
    57  			},
    58  		},
    59  		"_Any": {
    60  			Model: config.StringList{"github.com/mstephano/gqlgen-schemagen/graphql.Map"},
    61  		},
    62  	}
    63  
    64  	for typeName, entry := range builtins {
    65  		if cfg.Models.Exists(typeName) {
    66  			return fmt.Errorf("%v already exists which must be reserved when Federation is enabled", typeName)
    67  		}
    68  		cfg.Models[typeName] = entry
    69  	}
    70  	cfg.Directives["external"] = config.DirectiveConfig{SkipRuntime: true}
    71  	cfg.Directives["requires"] = config.DirectiveConfig{SkipRuntime: true}
    72  	cfg.Directives["provides"] = config.DirectiveConfig{SkipRuntime: true}
    73  	cfg.Directives["key"] = config.DirectiveConfig{SkipRuntime: true}
    74  	cfg.Directives["extends"] = config.DirectiveConfig{SkipRuntime: true}
    75  
    76  	// Federation 2 specific directives
    77  	if f.Version == 2 {
    78  		cfg.Directives["shareable"] = config.DirectiveConfig{SkipRuntime: true}
    79  		cfg.Directives["link"] = config.DirectiveConfig{SkipRuntime: true}
    80  		cfg.Directives["tag"] = config.DirectiveConfig{SkipRuntime: true}
    81  		cfg.Directives["override"] = config.DirectiveConfig{SkipRuntime: true}
    82  		cfg.Directives["inaccessible"] = config.DirectiveConfig{SkipRuntime: true}
    83  	}
    84  
    85  	return nil
    86  }
    87  
    88  func (f *federation) InjectSourceEarly() *ast.Source {
    89  	input := `
    90  	scalar _Any
    91  	scalar _FieldSet
    92  
    93  	directive @external on FIELD_DEFINITION
    94  	directive @requires(fields: _FieldSet!) on FIELD_DEFINITION
    95  	directive @provides(fields: _FieldSet!) on FIELD_DEFINITION
    96  	directive @extends on OBJECT | INTERFACE
    97  `
    98  	// add version-specific changes on key directive, as well as adding the new directives for federation 2
    99  	if f.Version == 1 {
   100  		input += `
   101  	directive @key(fields: _FieldSet!) repeatable on OBJECT | INTERFACE
   102  `
   103  	} else if f.Version == 2 {
   104  		input += `
   105  	directive @key(fields: _FieldSet!, resolvable: Boolean = true) repeatable on OBJECT | INTERFACE
   106  	directive @link(import: [String!], url: String!) repeatable on SCHEMA
   107  	directive @shareable on OBJECT | FIELD_DEFINITION
   108  	directive @tag(name: String!) repeatable on FIELD_DEFINITION | INTERFACE | OBJECT | UNION | ARGUMENT_DEFINITION | SCALAR | ENUM | ENUM_VALUE | INPUT_OBJECT | INPUT_FIELD_DEFINITION
   109  	directive @override(from: String!) on FIELD_DEFINITION
   110  	directive @inaccessible on SCALAR | OBJECT | FIELD_DEFINITION | ARGUMENT_DEFINITION | INTERFACE | UNION | ENUM | ENUM_VALUE | INPUT_OBJECT | INPUT_FIELD_DEFINITION
   111  `
   112  	}
   113  	return &ast.Source{
   114  		Name:    "federation/directives.graphql",
   115  		Input:   input,
   116  		BuiltIn: true,
   117  	}
   118  }
   119  
   120  // InjectSources creates a GraphQL Entity type with all
   121  // the fields that had the @key directive
   122  func (f *federation) InjectSourceLate(schema *ast.Schema) *ast.Source {
   123  	f.setEntities(schema)
   124  
   125  	var entities, resolvers, entityResolverInputDefinitions string
   126  	for i, e := range f.Entities {
   127  		if i != 0 {
   128  			entities += " | "
   129  		}
   130  		entities += e.Name
   131  
   132  		for _, r := range e.Resolvers {
   133  			if e.Multi {
   134  				if entityResolverInputDefinitions != "" {
   135  					entityResolverInputDefinitions += "\n\n"
   136  				}
   137  				entityResolverInputDefinitions += "input " + r.InputType + " {\n"
   138  				for _, keyField := range r.KeyFields {
   139  					entityResolverInputDefinitions += fmt.Sprintf("\t%s: %s\n", keyField.Field.ToGo(), keyField.Definition.Type.String())
   140  				}
   141  				entityResolverInputDefinitions += "}"
   142  				resolvers += fmt.Sprintf("\t%s(reps: [%s!]!): [%s]\n", r.ResolverName, r.InputType, e.Name)
   143  			} else {
   144  				resolverArgs := ""
   145  				for _, keyField := range r.KeyFields {
   146  					resolverArgs += fmt.Sprintf("%s: %s,", keyField.Field.ToGoPrivate(), keyField.Definition.Type.String())
   147  				}
   148  				resolvers += fmt.Sprintf("\t%s(%s): %s!\n", r.ResolverName, resolverArgs, e.Name)
   149  			}
   150  		}
   151  	}
   152  
   153  	var blocks []string
   154  	if entities != "" {
   155  		entities = `# a union of all types that use the @key directive
   156  union _Entity = ` + entities
   157  		blocks = append(blocks, entities)
   158  	}
   159  
   160  	// resolvers can be empty if a service defines only "empty
   161  	// extend" types.  This should be rare.
   162  	if resolvers != "" {
   163  		if entityResolverInputDefinitions != "" {
   164  			blocks = append(blocks, entityResolverInputDefinitions)
   165  		}
   166  		resolvers = `# fake type to build resolver interfaces for users to implement
   167  type Entity {
   168  	` + resolvers + `
   169  }`
   170  		blocks = append(blocks, resolvers)
   171  	}
   172  
   173  	_serviceTypeDef := `type _Service {
   174    sdl: String
   175  }`
   176  	blocks = append(blocks, _serviceTypeDef)
   177  
   178  	var additionalQueryFields string
   179  	// Quote from the Apollo Federation subgraph specification:
   180  	// If no types are annotated with the key directive, then the
   181  	// _Entity union and _entities field should be removed from the schema
   182  	if len(f.Entities) > 0 {
   183  		additionalQueryFields += `  _entities(representations: [_Any!]!): [_Entity]!
   184  `
   185  	}
   186  	// _service field is required in any case
   187  	additionalQueryFields += `  _service: _Service!`
   188  
   189  	extendTypeQueryDef := `extend type ` + schema.Query.Name + ` {
   190  ` + additionalQueryFields + `
   191  }`
   192  	blocks = append(blocks, extendTypeQueryDef)
   193  
   194  	return &ast.Source{
   195  		Name:    "federation/entity.graphql",
   196  		BuiltIn: true,
   197  		Input:   "\n" + strings.Join(blocks, "\n\n") + "\n",
   198  	}
   199  }
   200  
   201  func (f *federation) GenerateCode(data *codegen.Data) error {
   202  	if len(f.Entities) > 0 {
   203  		if data.Objects.ByName("Entity") != nil {
   204  			data.Objects.ByName("Entity").Root = true
   205  		}
   206  		for _, e := range f.Entities {
   207  			obj := data.Objects.ByName(e.Def.Name)
   208  
   209  			for _, r := range e.Resolvers {
   210  				// fill in types for key fields
   211  				//
   212  				for _, keyField := range r.KeyFields {
   213  					if len(keyField.Field) == 0 {
   214  						fmt.Println(
   215  							"skipping @key field " + keyField.Definition.Name + " in " + r.ResolverName + " in " + e.Def.Name,
   216  						)
   217  						continue
   218  					}
   219  					cgField := keyField.Field.TypeReference(obj, data.Objects)
   220  					keyField.Type = cgField.TypeReference
   221  				}
   222  			}
   223  
   224  			// fill in types for requires fields
   225  			//
   226  			for _, reqField := range e.Requires {
   227  				if len(reqField.Field) == 0 {
   228  					fmt.Println("skipping @requires field " + reqField.Name + " in " + e.Def.Name)
   229  					continue
   230  				}
   231  				cgField := reqField.Field.TypeReference(obj, data.Objects)
   232  				reqField.Type = cgField.TypeReference
   233  			}
   234  		}
   235  	}
   236  
   237  	return templates.Render(templates.Options{
   238  		PackageName:     data.Config.Federation.Package,
   239  		Filename:        data.Config.Federation.Filename,
   240  		Data:            f,
   241  		GeneratedHeader: true,
   242  		Packages:        data.Config.Packages,
   243  		Template:        federationTemplate,
   244  	})
   245  }
   246  
   247  func (f *federation) setEntities(schema *ast.Schema) {
   248  	for _, schemaType := range schema.Types {
   249  		keys, ok := isFederatedEntity(schemaType)
   250  		if !ok {
   251  			continue
   252  		}
   253  		e := &Entity{
   254  			Name:      schemaType.Name,
   255  			Def:       schemaType,
   256  			Resolvers: nil,
   257  			Requires:  nil,
   258  		}
   259  
   260  		// Let's process custom entity resolver settings.
   261  		dir := schemaType.Directives.ForName("entityResolver")
   262  		if dir != nil {
   263  			if dirArg := dir.Arguments.ForName("multi"); dirArg != nil {
   264  				if dirVal, err := dirArg.Value.Value(nil); err == nil {
   265  					e.Multi = dirVal.(bool)
   266  				}
   267  			}
   268  		}
   269  
   270  		// If our schema has a field with a type defined in
   271  		// another service, then we need to define an "empty
   272  		// extend" of that type in this service, so this service
   273  		// knows what the type is like.  But the graphql-server
   274  		// will never ask us to actually resolve this "empty
   275  		// extend", so we don't require a resolver function for
   276  		// it.  (Well, it will never ask in practice; it's
   277  		// unclear whether the spec guarantees this.  See
   278  		// https://github.com/apollographql/apollo-server/issues/3852
   279  		// ).  Example:
   280  		//    type MyType {
   281  		//       myvar: TypeDefinedInOtherService
   282  		//    }
   283  		//    // Federation needs this type, but
   284  		//    // it doesn't need a resolver for it!
   285  		//    extend TypeDefinedInOtherService @key(fields: "id") {
   286  		//       id: ID @external
   287  		//    }
   288  		if !e.allFieldsAreExternal(f.Version) {
   289  			for _, dir := range keys {
   290  				if len(dir.Arguments) > 2 {
   291  					panic("More than two arguments provided for @key declaration.")
   292  				}
   293  				var arg *ast.Argument
   294  
   295  				// since keys are able to now have multiple arguments, we need to check both possible for a possible @key(fields="" fields="")
   296  				for _, a := range dir.Arguments {
   297  					if a.Name == "fields" {
   298  						if arg != nil {
   299  							panic("More than one `fields` provided for @key declaration.")
   300  						}
   301  						arg = a
   302  					}
   303  				}
   304  
   305  				keyFieldSet := fieldset.New(arg.Value.Raw, nil)
   306  
   307  				keyFields := make([]*KeyField, len(keyFieldSet))
   308  				resolverFields := []string{}
   309  				for i, field := range keyFieldSet {
   310  					def := field.FieldDefinition(schemaType, schema)
   311  
   312  					if def == nil {
   313  						panic(fmt.Sprintf("no field for %v", field))
   314  					}
   315  
   316  					keyFields[i] = &KeyField{Definition: def, Field: field}
   317  					resolverFields = append(resolverFields, keyFields[i].Field.ToGo())
   318  				}
   319  
   320  				resolverFieldsToGo := schemaType.Name + "By" + strings.Join(resolverFields, "And")
   321  				var resolverName string
   322  				if e.Multi {
   323  					resolverFieldsToGo += "s" // Pluralize for better API readability
   324  					resolverName = fmt.Sprintf("findMany%s", resolverFieldsToGo)
   325  				} else {
   326  					resolverName = fmt.Sprintf("find%s", resolverFieldsToGo)
   327  				}
   328  
   329  				e.Resolvers = append(e.Resolvers, &EntityResolver{
   330  					ResolverName: resolverName,
   331  					KeyFields:    keyFields,
   332  					InputType:    resolverFieldsToGo + "Input",
   333  				})
   334  			}
   335  
   336  			e.Requires = []*Requires{}
   337  			for _, f := range schemaType.Fields {
   338  				dir := f.Directives.ForName("requires")
   339  				if dir == nil {
   340  					continue
   341  				}
   342  				if len(dir.Arguments) != 1 || dir.Arguments[0].Name != "fields" {
   343  					panic("Exactly one `fields` argument needed for @requires declaration.")
   344  				}
   345  				requiresFieldSet := fieldset.New(dir.Arguments[0].Value.Raw, nil)
   346  				for _, field := range requiresFieldSet {
   347  					e.Requires = append(e.Requires, &Requires{
   348  						Name:  field.ToGoPrivate(),
   349  						Field: field,
   350  					})
   351  				}
   352  			}
   353  		}
   354  		f.Entities = append(f.Entities, e)
   355  	}
   356  
   357  	// make sure order remains stable across multiple builds
   358  	sort.Slice(f.Entities, func(i, j int) bool {
   359  		return f.Entities[i].Name < f.Entities[j].Name
   360  	})
   361  }
   362  
   363  func isFederatedEntity(schemaType *ast.Definition) ([]*ast.Directive, bool) {
   364  	switch schemaType.Kind {
   365  	case ast.Object:
   366  		keys := schemaType.Directives.ForNames("key")
   367  		if len(keys) > 0 {
   368  			return keys, true
   369  		}
   370  	case ast.Interface:
   371  		// TODO: support @key and @extends for interfaces
   372  		if dir := schemaType.Directives.ForName("key"); dir != nil {
   373  			fmt.Printf("@key directive found on \"interface %s\". Will be ignored.\n", schemaType.Name)
   374  		}
   375  		if dir := schemaType.Directives.ForName("extends"); dir != nil {
   376  			panic(
   377  				fmt.Sprintf(
   378  					"@extends directive is not currently supported for interfaces, use \"extend interface %s\" instead.",
   379  					schemaType.Name,
   380  				))
   381  		}
   382  	default:
   383  		// ignore
   384  	}
   385  	return nil, false
   386  }