github.com/lastbackend/toolkit@v0.0.0-20241020043710-cafa37b95aad/pkg/util/converter/query.go (about) 1 /* 2 Copyright [2014] - [2023] The Last.Backend authors. 3 4 Licensed under the Apache License, Version 2.0 (the "License"); 5 you may not use this file except in compliance with the License. 6 You may obtain a copy of the License at 7 8 http://www.apache.org/licenses/LICENSE-2.0 9 10 Unless required by applicable law or agreed to in writing, software 11 distributed under the License is distributed on an "AS IS" BASIS, 12 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 See the License for the specific language governing permissions and 14 limitations under the License. 15 */ 16 17 package converter 18 19 import ( 20 "bytes" 21 "encoding/base64" 22 "fmt" 23 "io" 24 "net" 25 "net/http" 26 "net/textproto" 27 "net/url" 28 "regexp" 29 "strconv" 30 "strings" 31 "time" 32 33 "github.com/gorilla/mux" 34 "github.com/pkg/errors" 35 "google.golang.org/genproto/protobuf/field_mask" 36 "google.golang.org/grpc/codes" 37 "google.golang.org/grpc/status" 38 "google.golang.org/protobuf/proto" 39 "google.golang.org/protobuf/reflect/protoreflect" 40 "google.golang.org/protobuf/reflect/protoregistry" 41 "google.golang.org/protobuf/types/known/durationpb" 42 "google.golang.org/protobuf/types/known/timestamppb" 43 "google.golang.org/protobuf/types/known/wrapperspb" 44 ) 45 46 var ( 47 valuesKeyRegexp = regexp.MustCompile(`^(.*)\[(.*)\]$`) 48 ) 49 50 const ( 51 MetadataPrefix = "router-" 52 MetadataHeaderPrefix = "Router-Metadata-" 53 metadataHeaderBinarySuffix = "-Bin" 54 xHeaderPrefix = "X-" 55 xForwardedFor = "X-Forwarded-For" 56 xForwardedHostHeader = "X-Forwarded-Host" 57 ) 58 59 func NewReader(r io.Reader) (io.Reader, error) { 60 b, err := io.ReadAll(r) 61 if err != nil { 62 return nil, err 63 } 64 return bytes.NewReader(b), nil 65 } 66 67 func PrepareHeaderFromRequest(req *http.Request) (map[string]string, error) { 68 var headers = make(map[string]string, 0) 69 70 for k, v := range req.Header { 71 k = textproto.CanonicalMIMEHeaderKey(k) 72 for _, val := range v { 73 if h, ok := headerMatcher(k); ok { 74 if strings.HasSuffix(k, metadataHeaderBinarySuffix) { 75 b, err := decodeBinHeader(val) 76 if err != nil { 77 return nil, status.Errorf(codes.InvalidArgument, "invalid binary header %s: %s", k, err) 78 } 79 val = string(b) 80 } 81 headers[h] = val 82 } 83 } 84 } 85 86 if host := req.Header.Get(xForwardedHostHeader); host != "" { 87 headers[strings.ToLower(xForwardedHostHeader)] = host 88 } else if req.Host != "" { 89 headers[strings.ToLower(xForwardedHostHeader)] = req.Host 90 } 91 92 if addr := req.RemoteAddr; addr != "" { 93 if remoteIP, _, err := net.SplitHostPort(addr); err == nil { 94 if forward := req.Header.Get(xForwardedFor); forward == "" { 95 headers[strings.ToLower(xForwardedFor)] = remoteIP 96 } else { 97 headers[strings.ToLower(xForwardedFor)] = fmt.Sprintf("%s, %s", forward, remoteIP) 98 } 99 } 100 } 101 102 return headers, nil 103 } 104 105 func SetRawBodyToProto(r *http.Request, message proto.Message, param string) error { 106 raw, err := io.ReadAll(r.Body) 107 if err != nil { 108 return err 109 } 110 paramPath := strings.Split(param, ".") 111 return matchFieldValue(message.ProtoReflect(), paramPath, []string{base64.StdEncoding.EncodeToString(raw)}) 112 } 113 114 func ParseRequestUrlParametersToProto(r *http.Request, message proto.Message, param string) error { 115 paramPath := strings.Split(param, ".") 116 return matchFieldValue(message.ProtoReflect(), paramPath, []string{mux.Vars(r)[param]}) 117 } 118 119 func ParseRequestQueryParametersToProto(message proto.Message, values url.Values) error { 120 for k, v := range values { 121 match := valuesKeyRegexp.FindStringSubmatch(k) 122 if len(match) == 3 { 123 k = match[1] 124 v = append([]string{match[2]}, v...) 125 } 126 paramPath := strings.Split(k, ".") 127 if err := matchFieldValue(message.ProtoReflect(), paramPath, v); err != nil { 128 return err 129 } 130 } 131 return nil 132 } 133 134 func headerMatcher(key string) (string, bool) { 135 key = textproto.CanonicalMIMEHeaderKey(key) 136 if strings.HasPrefix(key, xHeaderPrefix) { 137 return key, true 138 } else if strings.HasPrefix(key, MetadataHeaderPrefix) { 139 return key[len(MetadataHeaderPrefix):], true 140 } else { 141 return MetadataPrefix + key, true 142 } 143 } 144 145 func decodeBinHeader(v string) ([]byte, error) { 146 if len(v)%4 == 0 { 147 return base64.StdEncoding.DecodeString(v) 148 } 149 return base64.RawStdEncoding.DecodeString(v) 150 } 151 152 func matchFieldValue(msgValue protoreflect.Message, paramPath []string, values []string) error { 153 if len(paramPath) < 1 { 154 return errors.New("no param path") 155 } 156 if len(values) < 1 { 157 return errors.New("no value provided") 158 } 159 160 var fieldDescriptor protoreflect.FieldDescriptor 161 for index, paramName := range paramPath { 162 fields := msgValue.Descriptor().Fields() 163 fieldDescriptor = fields.ByName(protoreflect.Name(paramName)) 164 if fieldDescriptor == nil { 165 fieldDescriptor = fields.ByJSONName(paramName) 166 if fieldDescriptor == nil { 167 fmt.Println(fmt.Sprintf("field not found in %q: %q", msgValue.Descriptor().FullName(), strings.Join(paramPath, "."))) 168 return nil 169 } 170 } 171 if index == len(paramPath)-1 { 172 break 173 } 174 if fieldDescriptor.Message() == nil || fieldDescriptor.Cardinality() == protoreflect.Repeated { 175 return fmt.Errorf("invalid path: %q is not a message", paramName) 176 } 177 msgValue = msgValue.Mutable(fieldDescriptor).Message() 178 } 179 if of := fieldDescriptor.ContainingOneof(); of != nil { 180 if f := msgValue.WhichOneof(of); f != nil { 181 return fmt.Errorf("field already set for oneof %q", of.FullName().Name()) 182 } 183 } 184 switch { 185 case fieldDescriptor.IsList(): 186 return matchListField(fieldDescriptor, msgValue.Mutable(fieldDescriptor).List(), values) 187 case fieldDescriptor.IsMap(): 188 return matchMapField(fieldDescriptor, msgValue.Mutable(fieldDescriptor).Map(), values) 189 } 190 if len(values) > 1 { 191 return fmt.Errorf("too many values for field %q: %s", fieldDescriptor.FullName().Name(), strings.Join(values, ", ")) 192 } 193 return matchField(fieldDescriptor, msgValue, values[0]) 194 } 195 196 func matchField(fieldDescriptor protoreflect.FieldDescriptor, message protoreflect.Message, value string) error { 197 v, err := parseField(fieldDescriptor, value) 198 if err != nil { 199 return fmt.Errorf("parsing field %q: %w", fieldDescriptor.FullName().Name(), err) 200 } 201 message.Set(fieldDescriptor, v) 202 return nil 203 } 204 205 func matchListField(fieldDescriptor protoreflect.FieldDescriptor, list protoreflect.List, values []string) error { 206 for _, v := range values { 207 v, err := parseField(fieldDescriptor, v) 208 if err != nil { 209 return fmt.Errorf("parsing list %q: %w", fieldDescriptor.FullName().Name(), err) 210 } 211 list.Append(v) 212 } 213 return nil 214 } 215 216 func matchMapField(fieldDescriptor protoreflect.FieldDescriptor, items protoreflect.Map, values []string) error { 217 if len(values) != 2 { 218 return fmt.Errorf("more than one value provided for key %q in map %q", values[0], fieldDescriptor.FullName()) 219 } 220 key, err := parseField(fieldDescriptor.MapKey(), values[0]) 221 if err != nil { 222 return fmt.Errorf("parsing map key %q: %w", fieldDescriptor.FullName().Name(), err) 223 } 224 value, err := parseField(fieldDescriptor.MapValue(), values[1]) 225 if err != nil { 226 return fmt.Errorf("parsing map value %q: %w", fieldDescriptor.FullName().Name(), err) 227 } 228 items.Set(key.MapKey(), value) 229 return nil 230 } 231 232 func parseField(fieldDescriptor protoreflect.FieldDescriptor, value string) (protoreflect.Value, error) { 233 switch fieldDescriptor.Kind() { 234 case protoreflect.StringKind: 235 return protoreflect.ValueOfString(value), nil 236 case protoreflect.BytesKind: 237 v, err := base64.StdEncoding.DecodeString(value) 238 if err != nil { 239 return protoreflect.Value{}, err 240 } 241 return protoreflect.ValueOfBytes(v), nil 242 case protoreflect.BoolKind: 243 v, err := strconv.ParseBool(value) 244 if err != nil { 245 return protoreflect.Value{}, err 246 } 247 return protoreflect.ValueOfBool(v), nil 248 case protoreflect.Int32Kind, protoreflect.Sint32Kind, protoreflect.Sfixed32Kind: 249 v, err := strconv.ParseInt(value, 10, 32) 250 if err != nil { 251 return protoreflect.Value{}, err 252 } 253 return protoreflect.ValueOfInt32(int32(v)), nil 254 case protoreflect.Int64Kind, protoreflect.Sint64Kind, protoreflect.Sfixed64Kind: 255 v, err := strconv.ParseInt(value, 10, 64) 256 if err != nil { 257 return protoreflect.Value{}, err 258 } 259 return protoreflect.ValueOfInt64(v), nil 260 case protoreflect.Uint32Kind, protoreflect.Fixed32Kind: 261 v, err := strconv.ParseUint(value, 10, 32) 262 if err != nil { 263 return protoreflect.Value{}, err 264 } 265 return protoreflect.ValueOfUint32(uint32(v)), nil 266 case protoreflect.Uint64Kind, protoreflect.Fixed64Kind: 267 v, err := strconv.ParseUint(value, 10, 64) 268 if err != nil { 269 return protoreflect.Value{}, err 270 } 271 return protoreflect.ValueOfUint64(v), nil 272 case protoreflect.FloatKind: 273 v, err := strconv.ParseFloat(value, 32) 274 if err != nil { 275 return protoreflect.Value{}, err 276 } 277 return protoreflect.ValueOfFloat32(float32(v)), nil 278 case protoreflect.DoubleKind: 279 v, err := strconv.ParseFloat(value, 64) 280 if err != nil { 281 return protoreflect.Value{}, err 282 } 283 return protoreflect.ValueOfFloat64(v), nil 284 case protoreflect.MessageKind, protoreflect.GroupKind: 285 return parseMessage(fieldDescriptor.Message(), value) 286 case protoreflect.EnumKind: 287 enum, err := protoregistry.GlobalTypes.FindEnumByName(fieldDescriptor.Enum().FullName()) 288 switch { 289 case err == protoregistry.NotFound: 290 return protoreflect.Value{}, fmt.Errorf("enum %q is not registered", fieldDescriptor.Enum().FullName()) 291 case err != nil: 292 return protoreflect.Value{}, fmt.Errorf("failed to look up enum: %w", err) 293 } 294 v := enum.Descriptor().Values().ByName(protoreflect.Name(value)) 295 if v == nil { 296 i, err := strconv.Atoi(value) 297 if err != nil { 298 return protoreflect.Value{}, fmt.Errorf("%q is not a valid value", value) 299 } 300 v = enum.Descriptor().Values().ByNumber(protoreflect.EnumNumber(i)) 301 if v == nil { 302 return protoreflect.Value{}, fmt.Errorf("%q is not a valid value", value) 303 } 304 } 305 return protoreflect.ValueOfEnum(v.Number()), nil 306 default: 307 panic(fmt.Sprintf("unknown field kind: %v", fieldDescriptor.Kind())) 308 } 309 } 310 311 func parseMessage(msgDescriptor protoreflect.MessageDescriptor, value string) (protoreflect.Value, error) { 312 var protoMessage proto.Message 313 switch msgDescriptor.FullName() { 314 case "google.protobuf.BoolValue": 315 v, err := strconv.ParseBool(value) 316 if err != nil { 317 return protoreflect.Value{}, err 318 } 319 protoMessage = &wrapperspb.BoolValue{Value: v} 320 case "google.protobuf.StringValue": 321 protoMessage = &wrapperspb.StringValue{Value: value} 322 case "google.protobuf.FloatValue": 323 v, err := strconv.ParseFloat(value, 32) 324 if err != nil { 325 return protoreflect.Value{}, err 326 } 327 protoMessage = &wrapperspb.FloatValue{Value: float32(v)} 328 case "google.protobuf.Int64Value": 329 v, err := strconv.ParseInt(value, 10, 64) 330 if err != nil { 331 return protoreflect.Value{}, err 332 } 333 protoMessage = &wrapperspb.Int64Value{Value: v} 334 case "google.protobuf.Int32Value": 335 v, err := strconv.ParseInt(value, 10, 32) 336 if err != nil { 337 return protoreflect.Value{}, err 338 } 339 protoMessage = &wrapperspb.Int32Value{Value: int32(v)} 340 case "google.protobuf.UInt64Value": 341 v, err := strconv.ParseUint(value, 10, 64) 342 if err != nil { 343 return protoreflect.Value{}, err 344 } 345 protoMessage = &wrapperspb.UInt64Value{Value: v} 346 case "google.protobuf.UInt32Value": 347 v, err := strconv.ParseUint(value, 10, 32) 348 if err != nil { 349 return protoreflect.Value{}, err 350 } 351 protoMessage = &wrapperspb.UInt32Value{Value: uint32(v)} 352 case "google.protobuf.BytesValue": 353 v, err := base64.StdEncoding.DecodeString(value) 354 if err != nil { 355 return protoreflect.Value{}, err 356 } 357 protoMessage = &wrapperspb.BytesValue{Value: v} 358 case "google.protobuf.DoubleValue": 359 v, err := strconv.ParseFloat(value, 64) 360 if err != nil { 361 return protoreflect.Value{}, err 362 } 363 protoMessage = &wrapperspb.DoubleValue{Value: v} 364 case "google.protobuf.FieldMask": 365 fm := &field_mask.FieldMask{} 366 fm.Paths = append(fm.Paths, strings.Split(value, ",")...) 367 protoMessage = fm 368 case "google.protobuf.Duration": 369 if value == "null" { 370 break 371 } 372 d, err := time.ParseDuration(value) 373 if err != nil { 374 return protoreflect.Value{}, err 375 } 376 protoMessage = durationpb.New(d) 377 case "google.protobuf.Timestamp": 378 if value == "null" { 379 break 380 } 381 t, err := time.Parse(time.RFC3339Nano, value) 382 if err != nil { 383 return protoreflect.Value{}, err 384 } 385 protoMessage = timestamppb.New(t) 386 default: 387 return protoreflect.Value{}, fmt.Errorf("unsupported message type: %q", string(msgDescriptor.FullName())) 388 } 389 390 return protoreflect.ValueOfMessage(protoMessage.ProtoReflect()), nil 391 }