github.com/grpc-ecosystem/grpc-gateway/v2@v2.19.1/runtime/query.go (about) 1 package runtime 2 3 import ( 4 "errors" 5 "fmt" 6 "net/url" 7 "regexp" 8 "strconv" 9 "strings" 10 "time" 11 12 "github.com/grpc-ecosystem/grpc-gateway/v2/utilities" 13 "google.golang.org/grpc/grpclog" 14 "google.golang.org/protobuf/encoding/protojson" 15 "google.golang.org/protobuf/proto" 16 "google.golang.org/protobuf/reflect/protoreflect" 17 "google.golang.org/protobuf/reflect/protoregistry" 18 "google.golang.org/protobuf/types/known/durationpb" 19 field_mask "google.golang.org/protobuf/types/known/fieldmaskpb" 20 "google.golang.org/protobuf/types/known/structpb" 21 "google.golang.org/protobuf/types/known/timestamppb" 22 "google.golang.org/protobuf/types/known/wrapperspb" 23 ) 24 25 var valuesKeyRegexp = regexp.MustCompile(`^(.*)\[(.*)\]$`) 26 27 var currentQueryParser QueryParameterParser = &DefaultQueryParser{} 28 29 // QueryParameterParser defines interface for all query parameter parsers 30 type QueryParameterParser interface { 31 Parse(msg proto.Message, values url.Values, filter *utilities.DoubleArray) error 32 } 33 34 // PopulateQueryParameters parses query parameters 35 // into "msg" using current query parser 36 func PopulateQueryParameters(msg proto.Message, values url.Values, filter *utilities.DoubleArray) error { 37 return currentQueryParser.Parse(msg, values, filter) 38 } 39 40 // DefaultQueryParser is a QueryParameterParser which implements the default 41 // query parameters parsing behavior. 42 // 43 // See https://github.com/grpc-ecosystem/grpc-gateway/issues/2632 for more context. 44 type DefaultQueryParser struct{} 45 46 // Parse populates "values" into "msg". 47 // A value is ignored if its key starts with one of the elements in "filter". 48 func (*DefaultQueryParser) Parse(msg proto.Message, values url.Values, filter *utilities.DoubleArray) error { 49 for key, values := range values { 50 if match := valuesKeyRegexp.FindStringSubmatch(key); len(match) == 3 { 51 key = match[1] 52 values = append([]string{match[2]}, values...) 53 } 54 55 msgValue := msg.ProtoReflect() 56 fieldPath := normalizeFieldPath(msgValue, strings.Split(key, ".")) 57 if filter.HasCommonPrefix(fieldPath) { 58 continue 59 } 60 if err := populateFieldValueFromPath(msgValue, fieldPath, values); err != nil { 61 return err 62 } 63 } 64 return nil 65 } 66 67 // PopulateFieldFromPath sets a value in a nested Protobuf structure. 68 func PopulateFieldFromPath(msg proto.Message, fieldPathString string, value string) error { 69 fieldPath := strings.Split(fieldPathString, ".") 70 return populateFieldValueFromPath(msg.ProtoReflect(), fieldPath, []string{value}) 71 } 72 73 func normalizeFieldPath(msgValue protoreflect.Message, fieldPath []string) []string { 74 newFieldPath := make([]string, 0, len(fieldPath)) 75 for i, fieldName := range fieldPath { 76 fields := msgValue.Descriptor().Fields() 77 fieldDesc := fields.ByTextName(fieldName) 78 if fieldDesc == nil { 79 fieldDesc = fields.ByJSONName(fieldName) 80 } 81 if fieldDesc == nil { 82 // return initial field path values if no matching message field was found 83 return fieldPath 84 } 85 86 newFieldPath = append(newFieldPath, string(fieldDesc.Name())) 87 88 // If this is the last element, we're done 89 if i == len(fieldPath)-1 { 90 break 91 } 92 93 // Only singular message fields are allowed 94 if fieldDesc.Message() == nil || fieldDesc.Cardinality() == protoreflect.Repeated { 95 return fieldPath 96 } 97 98 // Get the nested message 99 msgValue = msgValue.Get(fieldDesc).Message() 100 } 101 102 return newFieldPath 103 } 104 105 func populateFieldValueFromPath(msgValue protoreflect.Message, fieldPath []string, values []string) error { 106 if len(fieldPath) < 1 { 107 return errors.New("no field path") 108 } 109 if len(values) < 1 { 110 return errors.New("no value provided") 111 } 112 113 var fieldDescriptor protoreflect.FieldDescriptor 114 for i, fieldName := range fieldPath { 115 fields := msgValue.Descriptor().Fields() 116 117 // Get field by name 118 fieldDescriptor = fields.ByName(protoreflect.Name(fieldName)) 119 if fieldDescriptor == nil { 120 fieldDescriptor = fields.ByJSONName(fieldName) 121 if fieldDescriptor == nil { 122 // We're not returning an error here because this could just be 123 // an extra query parameter that isn't part of the request. 124 grpclog.Infof("field not found in %q: %q", msgValue.Descriptor().FullName(), strings.Join(fieldPath, ".")) 125 return nil 126 } 127 } 128 129 // If this is the last element, we're done 130 if i == len(fieldPath)-1 { 131 break 132 } 133 134 // Only singular message fields are allowed 135 if fieldDescriptor.Message() == nil || fieldDescriptor.Cardinality() == protoreflect.Repeated { 136 return fmt.Errorf("invalid path: %q is not a message", fieldName) 137 } 138 139 // Get the nested message 140 msgValue = msgValue.Mutable(fieldDescriptor).Message() 141 } 142 143 // Check if oneof already set 144 if of := fieldDescriptor.ContainingOneof(); of != nil { 145 if f := msgValue.WhichOneof(of); f != nil { 146 return fmt.Errorf("field already set for oneof %q", of.FullName().Name()) 147 } 148 } 149 150 switch { 151 case fieldDescriptor.IsList(): 152 return populateRepeatedField(fieldDescriptor, msgValue.Mutable(fieldDescriptor).List(), values) 153 case fieldDescriptor.IsMap(): 154 return populateMapField(fieldDescriptor, msgValue.Mutable(fieldDescriptor).Map(), values) 155 } 156 157 if len(values) > 1 { 158 return fmt.Errorf("too many values for field %q: %s", fieldDescriptor.FullName().Name(), strings.Join(values, ", ")) 159 } 160 161 return populateField(fieldDescriptor, msgValue, values[0]) 162 } 163 164 func populateField(fieldDescriptor protoreflect.FieldDescriptor, msgValue protoreflect.Message, value string) error { 165 v, err := parseField(fieldDescriptor, value) 166 if err != nil { 167 return fmt.Errorf("parsing field %q: %w", fieldDescriptor.FullName().Name(), err) 168 } 169 170 msgValue.Set(fieldDescriptor, v) 171 return nil 172 } 173 174 func populateRepeatedField(fieldDescriptor protoreflect.FieldDescriptor, list protoreflect.List, values []string) error { 175 for _, value := range values { 176 v, err := parseField(fieldDescriptor, value) 177 if err != nil { 178 return fmt.Errorf("parsing list %q: %w", fieldDescriptor.FullName().Name(), err) 179 } 180 list.Append(v) 181 } 182 183 return nil 184 } 185 186 func populateMapField(fieldDescriptor protoreflect.FieldDescriptor, mp protoreflect.Map, values []string) error { 187 if len(values) != 2 { 188 return fmt.Errorf("more than one value provided for key %q in map %q", values[0], fieldDescriptor.FullName()) 189 } 190 191 key, err := parseField(fieldDescriptor.MapKey(), values[0]) 192 if err != nil { 193 return fmt.Errorf("parsing map key %q: %w", fieldDescriptor.FullName().Name(), err) 194 } 195 196 value, err := parseField(fieldDescriptor.MapValue(), values[1]) 197 if err != nil { 198 return fmt.Errorf("parsing map value %q: %w", fieldDescriptor.FullName().Name(), err) 199 } 200 201 mp.Set(key.MapKey(), value) 202 203 return nil 204 } 205 206 func parseField(fieldDescriptor protoreflect.FieldDescriptor, value string) (protoreflect.Value, error) { 207 switch fieldDescriptor.Kind() { 208 case protoreflect.BoolKind: 209 v, err := strconv.ParseBool(value) 210 if err != nil { 211 return protoreflect.Value{}, err 212 } 213 return protoreflect.ValueOfBool(v), nil 214 case protoreflect.EnumKind: 215 enum, err := protoregistry.GlobalTypes.FindEnumByName(fieldDescriptor.Enum().FullName()) 216 if err != nil { 217 if errors.Is(err, protoregistry.NotFound) { 218 return protoreflect.Value{}, fmt.Errorf("enum %q is not registered", fieldDescriptor.Enum().FullName()) 219 } 220 return protoreflect.Value{}, fmt.Errorf("failed to look up enum: %w", err) 221 } 222 // Look for enum by name 223 v := enum.Descriptor().Values().ByName(protoreflect.Name(value)) 224 if v == nil { 225 i, err := strconv.Atoi(value) 226 if err != nil { 227 return protoreflect.Value{}, fmt.Errorf("%q is not a valid value", value) 228 } 229 // Look for enum by number 230 if v = enum.Descriptor().Values().ByNumber(protoreflect.EnumNumber(i)); v == nil { 231 return protoreflect.Value{}, fmt.Errorf("%q is not a valid value", value) 232 } 233 } 234 return protoreflect.ValueOfEnum(v.Number()), nil 235 case protoreflect.Int32Kind, protoreflect.Sint32Kind, protoreflect.Sfixed32Kind: 236 v, err := strconv.ParseInt(value, 10, 32) 237 if err != nil { 238 return protoreflect.Value{}, err 239 } 240 return protoreflect.ValueOfInt32(int32(v)), nil 241 case protoreflect.Int64Kind, protoreflect.Sint64Kind, protoreflect.Sfixed64Kind: 242 v, err := strconv.ParseInt(value, 10, 64) 243 if err != nil { 244 return protoreflect.Value{}, err 245 } 246 return protoreflect.ValueOfInt64(v), nil 247 case protoreflect.Uint32Kind, protoreflect.Fixed32Kind: 248 v, err := strconv.ParseUint(value, 10, 32) 249 if err != nil { 250 return protoreflect.Value{}, err 251 } 252 return protoreflect.ValueOfUint32(uint32(v)), nil 253 case protoreflect.Uint64Kind, protoreflect.Fixed64Kind: 254 v, err := strconv.ParseUint(value, 10, 64) 255 if err != nil { 256 return protoreflect.Value{}, err 257 } 258 return protoreflect.ValueOfUint64(v), nil 259 case protoreflect.FloatKind: 260 v, err := strconv.ParseFloat(value, 32) 261 if err != nil { 262 return protoreflect.Value{}, err 263 } 264 return protoreflect.ValueOfFloat32(float32(v)), nil 265 case protoreflect.DoubleKind: 266 v, err := strconv.ParseFloat(value, 64) 267 if err != nil { 268 return protoreflect.Value{}, err 269 } 270 return protoreflect.ValueOfFloat64(v), nil 271 case protoreflect.StringKind: 272 return protoreflect.ValueOfString(value), nil 273 case protoreflect.BytesKind: 274 v, err := Bytes(value) 275 if err != nil { 276 return protoreflect.Value{}, err 277 } 278 return protoreflect.ValueOfBytes(v), nil 279 case protoreflect.MessageKind, protoreflect.GroupKind: 280 return parseMessage(fieldDescriptor.Message(), value) 281 default: 282 panic(fmt.Sprintf("unknown field kind: %v", fieldDescriptor.Kind())) 283 } 284 } 285 286 func parseMessage(msgDescriptor protoreflect.MessageDescriptor, value string) (protoreflect.Value, error) { 287 var msg proto.Message 288 switch msgDescriptor.FullName() { 289 case "google.protobuf.Timestamp": 290 t, err := time.Parse(time.RFC3339Nano, value) 291 if err != nil { 292 return protoreflect.Value{}, err 293 } 294 msg = timestamppb.New(t) 295 case "google.protobuf.Duration": 296 d, err := time.ParseDuration(value) 297 if err != nil { 298 return protoreflect.Value{}, err 299 } 300 msg = durationpb.New(d) 301 case "google.protobuf.DoubleValue": 302 v, err := strconv.ParseFloat(value, 64) 303 if err != nil { 304 return protoreflect.Value{}, err 305 } 306 msg = wrapperspb.Double(v) 307 case "google.protobuf.FloatValue": 308 v, err := strconv.ParseFloat(value, 32) 309 if err != nil { 310 return protoreflect.Value{}, err 311 } 312 msg = wrapperspb.Float(float32(v)) 313 case "google.protobuf.Int64Value": 314 v, err := strconv.ParseInt(value, 10, 64) 315 if err != nil { 316 return protoreflect.Value{}, err 317 } 318 msg = wrapperspb.Int64(v) 319 case "google.protobuf.Int32Value": 320 v, err := strconv.ParseInt(value, 10, 32) 321 if err != nil { 322 return protoreflect.Value{}, err 323 } 324 msg = wrapperspb.Int32(int32(v)) 325 case "google.protobuf.UInt64Value": 326 v, err := strconv.ParseUint(value, 10, 64) 327 if err != nil { 328 return protoreflect.Value{}, err 329 } 330 msg = wrapperspb.UInt64(v) 331 case "google.protobuf.UInt32Value": 332 v, err := strconv.ParseUint(value, 10, 32) 333 if err != nil { 334 return protoreflect.Value{}, err 335 } 336 msg = wrapperspb.UInt32(uint32(v)) 337 case "google.protobuf.BoolValue": 338 v, err := strconv.ParseBool(value) 339 if err != nil { 340 return protoreflect.Value{}, err 341 } 342 msg = wrapperspb.Bool(v) 343 case "google.protobuf.StringValue": 344 msg = wrapperspb.String(value) 345 case "google.protobuf.BytesValue": 346 v, err := Bytes(value) 347 if err != nil { 348 return protoreflect.Value{}, err 349 } 350 msg = wrapperspb.Bytes(v) 351 case "google.protobuf.FieldMask": 352 fm := &field_mask.FieldMask{} 353 fm.Paths = append(fm.Paths, strings.Split(value, ",")...) 354 msg = fm 355 case "google.protobuf.Value": 356 var v structpb.Value 357 if err := protojson.Unmarshal([]byte(value), &v); err != nil { 358 return protoreflect.Value{}, err 359 } 360 msg = &v 361 case "google.protobuf.Struct": 362 var v structpb.Struct 363 if err := protojson.Unmarshal([]byte(value), &v); err != nil { 364 return protoreflect.Value{}, err 365 } 366 msg = &v 367 default: 368 return protoreflect.Value{}, fmt.Errorf("unsupported message type: %q", string(msgDescriptor.FullName())) 369 } 370 371 return protoreflect.ValueOfMessage(msg.ProtoReflect()), nil 372 }