github.com/kerryoscer/gqlgen@v0.17.29/plugin/federation/fieldset/fieldset.go (about) 1 package fieldset 2 3 import ( 4 "fmt" 5 "strings" 6 7 "github.com/kerryoscer/gqlgen/codegen" 8 "github.com/kerryoscer/gqlgen/codegen/templates" 9 "github.com/vektah/gqlparser/v2/ast" 10 ) 11 12 // Set represents a FieldSet that is used in federation directives @key and @requires. 13 // Would be happier to reuse FieldSet parsing from gqlparser, but this suits for now. 14 type Set []Field 15 16 // Field represents a single field in a FieldSet 17 type Field []string 18 19 // New parses a FieldSet string into a TinyFieldSet. 20 func New(raw string, prefix []string) Set { 21 if !strings.Contains(raw, "{") { 22 return parseUnnestedKeyFieldSet(raw, prefix) 23 } 24 25 var ( 26 ret = Set{} 27 subPrefix = prefix 28 ) 29 before, during, after := extractSubs(raw) 30 31 if before != "" { 32 befores := New(before, prefix) 33 if len(befores) > 0 { 34 subPrefix = befores[len(befores)-1] 35 ret = append(ret, befores[:len(befores)-1]...) 36 } 37 } 38 if during != "" { 39 ret = append(ret, New(during, subPrefix)...) 40 } 41 if after != "" { 42 ret = append(ret, New(after, prefix)...) 43 } 44 return ret 45 } 46 47 // FieldDefinition looks up a field in the type. 48 func (f Field) FieldDefinition(schemaType *ast.Definition, schema *ast.Schema) *ast.FieldDefinition { 49 objType := schemaType 50 def := objType.Fields.ForName(f[0]) 51 52 for _, part := range f[1:] { 53 if objType.Kind != ast.Object { 54 panic(fmt.Sprintf(`invalid sub-field reference "%s" in %v: `, objType.Name, f)) 55 } 56 x := def.Type.Name() 57 objType = schema.Types[x] 58 if objType == nil { 59 panic("invalid schema type: " + x) 60 } 61 def = objType.Fields.ForName(part) 62 } 63 if def == nil { 64 return nil 65 } 66 ret := *def // shallow copy 67 ret.Name = f.ToGoPrivate() 68 69 return &ret 70 } 71 72 // TypeReference looks up the type of a field. 
// TypeReference looks up the type of a field. Starting at obj, it resolves each
// part of f with fieldByName and then descends into the object named by that
// field's type (looked up in objects). It returns the codegen field for the last
// part, or nil if f is empty.
//
// It panics, naming the offending part, if a part cannot be found. This includes
// the case where the previous part's type has no entry in objects.
func (f Field) TypeReference(obj *codegen.Object, objects codegen.Objects) *codegen.Field {
	var def *codegen.Field

	for _, part := range f {
		def = nil
		// obj is nil when objects.ByName found no object for the previous part's
		// type, so nothing can be nested under it.
		if obj != nil {
			def = fieldByName(obj, part)
		}
		if def == nil {
			// Report the part that is actually missing, not always the root of the path.
			panic("unable to find field " + part)
		}
		obj = objects.ByName(def.TypeReference.Definition.Name)
	}
	return def
}

// ToGo converts a (possibly nested) field into a proper public Go name by
// concatenating the public Go name of every part.
func (f Field) ToGo() string {
	var b strings.Builder

	for _, field := range f {
		b.WriteString(templates.ToGo(field))
	}
	return b.String()
}

// ToGoPrivate converts a (possibly nested) field into a proper private Go name.
// The first part is converted with templates.ToGoPrivate and every following
// part with templates.ToGo.
func (f Field) ToGoPrivate() string {
	var b strings.Builder

	for i, field := range f {
		if i == 0 {
			b.WriteString(templates.ToGoPrivate(field))
			continue
		}
		b.WriteString(templates.ToGo(field))
	}
	return b.String()
}

// Join concatenates the field parts with a string separator between. Useful in templates.
func (f Field) Join(str string) string {
	return strings.Join(f, str)
}

// JoinGo concatenates the Go name of field parts with a string separator between. Useful in templates.
func (f Field) JoinGo(str string) string {
	strs := make([]string, 0, len(f))

	for _, s := range f {
		strs = append(strs, templates.ToGo(s))
	}
	return strings.Join(strs, str)
}

// LastIndex returns the index of the last part of f, that is len(f)-1. It
// returns -1 for an empty Field.
func (f Field) LastIndex() int {
	return len(f) - 1
}

// local functions

// parseUnnestedKeyFieldSet handles the simple case where none of the fields are
// nested. Each whitespace-separated name in raw becomes a Field made of prefix
// followed by that name.
func parseUnnestedKeyFieldSet(raw string, prefix []string) Set {
	ret := Set{}

	for _, s := range strings.Fields(raw) {
		// Cap prefix at its length so append always allocates a fresh backing
		// array. Appending to prefix[:] directly lets sibling fields share, and
		// overwrite, one array whenever prefix has spare capacity. For example,
		// "a { b { c { d e } } }" would produce [a b c e] twice.
		next := append(prefix[:len(prefix):len(prefix)], s) //nolint:gocritic // appending to a capped copy on purpose
		ret = append(ret, next)
	}
	return ret
}
143 func extractSubs(str string) (string, string, string) { 144 start := strings.Index(str, "{") 145 end := matchingBracketIndex(str, start) 146 147 if start < 0 || end < 0 { 148 panic("invalid key fieldSet: " + str) 149 } 150 return strings.TrimSpace(str[:start]), strings.TrimSpace(str[start+1 : end]), strings.TrimSpace(str[end+1:]) 151 } 152 153 // matchingBracketIndex returns the index of the closing bracket, assuming an open bracket at start. 154 func matchingBracketIndex(str string, start int) int { 155 if start < 0 || len(str) <= start+1 { 156 return -1 157 } 158 var depth int 159 160 for i, c := range str[start+1:] { 161 switch c { 162 case '{': 163 depth++ 164 case '}': 165 if depth == 0 { 166 return start + 1 + i 167 } 168 depth-- 169 } 170 } 171 return -1 172 } 173 174 func fieldByName(obj *codegen.Object, name string) *codegen.Field { 175 for _, field := range obj.Fields { 176 if field.Name == name { 177 return field 178 } 179 } 180 return nil 181 }