github.com/kerryoscer/gqlgen@v0.17.29/plugin/federation/fieldset/fieldset.go (about)

     1  package fieldset
     2  
     3  import (
     4  	"fmt"
     5  	"strings"
     6  
     7  	"github.com/kerryoscer/gqlgen/codegen"
     8  	"github.com/kerryoscer/gqlgen/codegen/templates"
     9  	"github.com/vektah/gqlparser/v2/ast"
    10  )
    11  
    12  // Set represents a FieldSet that is used in federation directives @key and @requires.
    13  // Would be happier to reuse FieldSet parsing from gqlparser, but this suits for now.
    14  type Set []Field
    15  
    16  // Field represents a single field in a FieldSet
    17  type Field []string
    18  
    19  // New parses a FieldSet string into a TinyFieldSet.
    20  func New(raw string, prefix []string) Set {
    21  	if !strings.Contains(raw, "{") {
    22  		return parseUnnestedKeyFieldSet(raw, prefix)
    23  	}
    24  
    25  	var (
    26  		ret       = Set{}
    27  		subPrefix = prefix
    28  	)
    29  	before, during, after := extractSubs(raw)
    30  
    31  	if before != "" {
    32  		befores := New(before, prefix)
    33  		if len(befores) > 0 {
    34  			subPrefix = befores[len(befores)-1]
    35  			ret = append(ret, befores[:len(befores)-1]...)
    36  		}
    37  	}
    38  	if during != "" {
    39  		ret = append(ret, New(during, subPrefix)...)
    40  	}
    41  	if after != "" {
    42  		ret = append(ret, New(after, prefix)...)
    43  	}
    44  	return ret
    45  }
    46  
    47  // FieldDefinition looks up a field in the type.
    48  func (f Field) FieldDefinition(schemaType *ast.Definition, schema *ast.Schema) *ast.FieldDefinition {
    49  	objType := schemaType
    50  	def := objType.Fields.ForName(f[0])
    51  
    52  	for _, part := range f[1:] {
    53  		if objType.Kind != ast.Object {
    54  			panic(fmt.Sprintf(`invalid sub-field reference "%s" in %v: `, objType.Name, f))
    55  		}
    56  		x := def.Type.Name()
    57  		objType = schema.Types[x]
    58  		if objType == nil {
    59  			panic("invalid schema type: " + x)
    60  		}
    61  		def = objType.Fields.ForName(part)
    62  	}
    63  	if def == nil {
    64  		return nil
    65  	}
    66  	ret := *def // shallow copy
    67  	ret.Name = f.ToGoPrivate()
    68  
    69  	return &ret
    70  }
    71  
    72  // TypeReference looks up the type of a field.
    73  func (f Field) TypeReference(obj *codegen.Object, objects codegen.Objects) *codegen.Field {
    74  	var def *codegen.Field
    75  
    76  	for _, part := range f {
    77  		def = fieldByName(obj, part)
    78  		if def == nil {
    79  			panic("unable to find field " + f[0])
    80  		}
    81  		obj = objects.ByName(def.TypeReference.Definition.Name)
    82  	}
    83  	return def
    84  }
    85  
    86  // ToGo converts a (possibly nested) field into a proper public Go name.
    87  func (f Field) ToGo() string {
    88  	var ret string
    89  
    90  	for _, field := range f {
    91  		ret += templates.ToGo(field)
    92  	}
    93  	return ret
    94  }
    95  
    96  // ToGoPrivate converts a (possibly nested) field into a proper private Go name.
    97  func (f Field) ToGoPrivate() string {
    98  	var ret string
    99  
   100  	for i, field := range f {
   101  		if i == 0 {
   102  			ret += templates.ToGoPrivate(field)
   103  			continue
   104  		}
   105  		ret += templates.ToGo(field)
   106  	}
   107  	return ret
   108  }
   109  
   110  // Join concatenates the field parts with a string separator between. Useful in templates.
   111  func (f Field) Join(str string) string {
   112  	return strings.Join(f, str)
   113  }
   114  
   115  // JoinGo concatenates the Go name of field parts with a string separator between. Useful in templates.
   116  func (f Field) JoinGo(str string) string {
   117  	strs := []string{}
   118  
   119  	for _, s := range f {
   120  		strs = append(strs, templates.ToGo(s))
   121  	}
   122  	return strings.Join(strs, str)
   123  }
   124  
   125  func (f Field) LastIndex() int {
   126  	return len(f) - 1
   127  }
   128  
   129  // local functions
   130  
   131  // parseUnnestedKeyFieldSet // handles simple case where none of the fields are nested.
   132  func parseUnnestedKeyFieldSet(raw string, prefix []string) Set {
   133  	ret := Set{}
   134  
   135  	for _, s := range strings.Fields(raw) {
   136  		next := append(prefix[:], s) //nolint:gocritic // slicing out on purpose
   137  		ret = append(ret, next)
   138  	}
   139  	return ret
   140  }
   141  
   142  // extractSubs splits out and trims sub-expressions from before, inside, and after "{}".
   143  func extractSubs(str string) (string, string, string) {
   144  	start := strings.Index(str, "{")
   145  	end := matchingBracketIndex(str, start)
   146  
   147  	if start < 0 || end < 0 {
   148  		panic("invalid key fieldSet: " + str)
   149  	}
   150  	return strings.TrimSpace(str[:start]), strings.TrimSpace(str[start+1 : end]), strings.TrimSpace(str[end+1:])
   151  }
   152  
   153  // matchingBracketIndex returns the index of the closing bracket, assuming an open bracket at start.
   154  func matchingBracketIndex(str string, start int) int {
   155  	if start < 0 || len(str) <= start+1 {
   156  		return -1
   157  	}
   158  	var depth int
   159  
   160  	for i, c := range str[start+1:] {
   161  		switch c {
   162  		case '{':
   163  			depth++
   164  		case '}':
   165  			if depth == 0 {
   166  				return start + 1 + i
   167  			}
   168  			depth--
   169  		}
   170  	}
   171  	return -1
   172  }
   173  
   174  func fieldByName(obj *codegen.Object, name string) *codegen.Field {
   175  	for _, field := range obj.Fields {
   176  		if field.Name == name {
   177  			return field
   178  		}
   179  	}
   180  	return nil
   181  }