github.com/lastbackend/toolkit@v0.0.0-20241020043710-cafa37b95aad/pkg/util/converter/query.go (about)

     1  /*
     2  Copyright [2014] - [2023] The Last.Backend authors.
     3  
     4  Licensed under the Apache License, Version 2.0 (the "License");
     5  you may not use this file except in compliance with the License.
     6  You may obtain a copy of the License at
     7  
     8      http://www.apache.org/licenses/LICENSE-2.0
     9  
    10  Unless required by applicable law or agreed to in writing, software
    11  distributed under the License is distributed on an "AS IS" BASIS,
    12  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    13  See the License for the specific language governing permissions and
    14  limitations under the License.
    15  */
    16  
    17  package converter
    18  
    19  import (
    20  	"bytes"
    21  	"encoding/base64"
    22  	"fmt"
    23  	"io"
    24  	"net"
    25  	"net/http"
    26  	"net/textproto"
    27  	"net/url"
    28  	"regexp"
    29  	"strconv"
    30  	"strings"
    31  	"time"
    32  
    33  	"github.com/gorilla/mux"
    34  	"github.com/pkg/errors"
    35  	"google.golang.org/genproto/protobuf/field_mask"
    36  	"google.golang.org/grpc/codes"
    37  	"google.golang.org/grpc/status"
    38  	"google.golang.org/protobuf/proto"
    39  	"google.golang.org/protobuf/reflect/protoreflect"
    40  	"google.golang.org/protobuf/reflect/protoregistry"
    41  	"google.golang.org/protobuf/types/known/durationpb"
    42  	"google.golang.org/protobuf/types/known/timestamppb"
    43  	"google.golang.org/protobuf/types/known/wrapperspb"
    44  )
    45  
    46  var (
    47  	valuesKeyRegexp = regexp.MustCompile(`^(.*)\[(.*)\]$`)
    48  )
    49  
    50  const (
    51  	MetadataPrefix             = "router-"
    52  	MetadataHeaderPrefix       = "Router-Metadata-"
    53  	metadataHeaderBinarySuffix = "-Bin"
    54  	xHeaderPrefix              = "X-"
    55  	xForwardedFor              = "X-Forwarded-For"
    56  	xForwardedHostHeader       = "X-Forwarded-Host"
    57  )
    58  
    59  func NewReader(r io.Reader) (io.Reader, error) {
    60  	b, err := io.ReadAll(r)
    61  	if err != nil {
    62  		return nil, err
    63  	}
    64  	return bytes.NewReader(b), nil
    65  }
    66  
    67  func PrepareHeaderFromRequest(req *http.Request) (map[string]string, error) {
    68  	var headers = make(map[string]string, 0)
    69  
    70  	for k, v := range req.Header {
    71  		k = textproto.CanonicalMIMEHeaderKey(k)
    72  		for _, val := range v {
    73  			if h, ok := headerMatcher(k); ok {
    74  				if strings.HasSuffix(k, metadataHeaderBinarySuffix) {
    75  					b, err := decodeBinHeader(val)
    76  					if err != nil {
    77  						return nil, status.Errorf(codes.InvalidArgument, "invalid binary header %s: %s", k, err)
    78  					}
    79  					val = string(b)
    80  				}
    81  				headers[h] = val
    82  			}
    83  		}
    84  	}
    85  
    86  	if host := req.Header.Get(xForwardedHostHeader); host != "" {
    87  		headers[strings.ToLower(xForwardedHostHeader)] = host
    88  	} else if req.Host != "" {
    89  		headers[strings.ToLower(xForwardedHostHeader)] = req.Host
    90  	}
    91  
    92  	if addr := req.RemoteAddr; addr != "" {
    93  		if remoteIP, _, err := net.SplitHostPort(addr); err == nil {
    94  			if forward := req.Header.Get(xForwardedFor); forward == "" {
    95  				headers[strings.ToLower(xForwardedFor)] = remoteIP
    96  			} else {
    97  				headers[strings.ToLower(xForwardedFor)] = fmt.Sprintf("%s, %s", forward, remoteIP)
    98  			}
    99  		}
   100  	}
   101  
   102  	return headers, nil
   103  }
   104  
   105  func SetRawBodyToProto(r *http.Request, message proto.Message, param string) error {
   106  	raw, err := io.ReadAll(r.Body)
   107  	if err != nil {
   108  		return err
   109  	}
   110  	paramPath := strings.Split(param, ".")
   111  	return matchFieldValue(message.ProtoReflect(), paramPath, []string{base64.StdEncoding.EncodeToString(raw)})
   112  }
   113  
   114  func ParseRequestUrlParametersToProto(r *http.Request, message proto.Message, param string) error {
   115  	paramPath := strings.Split(param, ".")
   116  	return matchFieldValue(message.ProtoReflect(), paramPath, []string{mux.Vars(r)[param]})
   117  }
   118  
   119  func ParseRequestQueryParametersToProto(message proto.Message, values url.Values) error {
   120  	for k, v := range values {
   121  		match := valuesKeyRegexp.FindStringSubmatch(k)
   122  		if len(match) == 3 {
   123  			k = match[1]
   124  			v = append([]string{match[2]}, v...)
   125  		}
   126  		paramPath := strings.Split(k, ".")
   127  		if err := matchFieldValue(message.ProtoReflect(), paramPath, v); err != nil {
   128  			return err
   129  		}
   130  	}
   131  	return nil
   132  }
   133  
   134  func headerMatcher(key string) (string, bool) {
   135  	key = textproto.CanonicalMIMEHeaderKey(key)
   136  	if strings.HasPrefix(key, xHeaderPrefix) {
   137  		return key, true
   138  	} else if strings.HasPrefix(key, MetadataHeaderPrefix) {
   139  		return key[len(MetadataHeaderPrefix):], true
   140  	} else {
   141  		return MetadataPrefix + key, true
   142  	}
   143  }
   144  
   145  func decodeBinHeader(v string) ([]byte, error) {
   146  	if len(v)%4 == 0 {
   147  		return base64.StdEncoding.DecodeString(v)
   148  	}
   149  	return base64.RawStdEncoding.DecodeString(v)
   150  }
   151  
   152  func matchFieldValue(msgValue protoreflect.Message, paramPath []string, values []string) error {
   153  	if len(paramPath) < 1 {
   154  		return errors.New("no param path")
   155  	}
   156  	if len(values) < 1 {
   157  		return errors.New("no value provided")
   158  	}
   159  
   160  	var fieldDescriptor protoreflect.FieldDescriptor
   161  	for index, paramName := range paramPath {
   162  		fields := msgValue.Descriptor().Fields()
   163  		fieldDescriptor = fields.ByName(protoreflect.Name(paramName))
   164  		if fieldDescriptor == nil {
   165  			fieldDescriptor = fields.ByJSONName(paramName)
   166  			if fieldDescriptor == nil {
   167  				fmt.Println(fmt.Sprintf("field not found in %q: %q", msgValue.Descriptor().FullName(), strings.Join(paramPath, ".")))
   168  				return nil
   169  			}
   170  		}
   171  		if index == len(paramPath)-1 {
   172  			break
   173  		}
   174  		if fieldDescriptor.Message() == nil || fieldDescriptor.Cardinality() == protoreflect.Repeated {
   175  			return fmt.Errorf("invalid path: %q is not a message", paramName)
   176  		}
   177  		msgValue = msgValue.Mutable(fieldDescriptor).Message()
   178  	}
   179  	if of := fieldDescriptor.ContainingOneof(); of != nil {
   180  		if f := msgValue.WhichOneof(of); f != nil {
   181  			return fmt.Errorf("field already set for oneof %q", of.FullName().Name())
   182  		}
   183  	}
   184  	switch {
   185  	case fieldDescriptor.IsList():
   186  		return matchListField(fieldDescriptor, msgValue.Mutable(fieldDescriptor).List(), values)
   187  	case fieldDescriptor.IsMap():
   188  		return matchMapField(fieldDescriptor, msgValue.Mutable(fieldDescriptor).Map(), values)
   189  	}
   190  	if len(values) > 1 {
   191  		return fmt.Errorf("too many values for field %q: %s", fieldDescriptor.FullName().Name(), strings.Join(values, ", "))
   192  	}
   193  	return matchField(fieldDescriptor, msgValue, values[0])
   194  }
   195  
   196  func matchField(fieldDescriptor protoreflect.FieldDescriptor, message protoreflect.Message, value string) error {
   197  	v, err := parseField(fieldDescriptor, value)
   198  	if err != nil {
   199  		return fmt.Errorf("parsing field %q: %w", fieldDescriptor.FullName().Name(), err)
   200  	}
   201  	message.Set(fieldDescriptor, v)
   202  	return nil
   203  }
   204  
   205  func matchListField(fieldDescriptor protoreflect.FieldDescriptor, list protoreflect.List, values []string) error {
   206  	for _, v := range values {
   207  		v, err := parseField(fieldDescriptor, v)
   208  		if err != nil {
   209  			return fmt.Errorf("parsing list %q: %w", fieldDescriptor.FullName().Name(), err)
   210  		}
   211  		list.Append(v)
   212  	}
   213  	return nil
   214  }
   215  
   216  func matchMapField(fieldDescriptor protoreflect.FieldDescriptor, items protoreflect.Map, values []string) error {
   217  	if len(values) != 2 {
   218  		return fmt.Errorf("more than one value provided for key %q in map %q", values[0], fieldDescriptor.FullName())
   219  	}
   220  	key, err := parseField(fieldDescriptor.MapKey(), values[0])
   221  	if err != nil {
   222  		return fmt.Errorf("parsing map key %q: %w", fieldDescriptor.FullName().Name(), err)
   223  	}
   224  	value, err := parseField(fieldDescriptor.MapValue(), values[1])
   225  	if err != nil {
   226  		return fmt.Errorf("parsing map value %q: %w", fieldDescriptor.FullName().Name(), err)
   227  	}
   228  	items.Set(key.MapKey(), value)
   229  	return nil
   230  }
   231  
   232  func parseField(fieldDescriptor protoreflect.FieldDescriptor, value string) (protoreflect.Value, error) {
   233  	switch fieldDescriptor.Kind() {
   234  	case protoreflect.StringKind:
   235  		return protoreflect.ValueOfString(value), nil
   236  	case protoreflect.BytesKind:
   237  		v, err := base64.StdEncoding.DecodeString(value)
   238  		if err != nil {
   239  			return protoreflect.Value{}, err
   240  		}
   241  		return protoreflect.ValueOfBytes(v), nil
   242  	case protoreflect.BoolKind:
   243  		v, err := strconv.ParseBool(value)
   244  		if err != nil {
   245  			return protoreflect.Value{}, err
   246  		}
   247  		return protoreflect.ValueOfBool(v), nil
   248  	case protoreflect.Int32Kind, protoreflect.Sint32Kind, protoreflect.Sfixed32Kind:
   249  		v, err := strconv.ParseInt(value, 10, 32)
   250  		if err != nil {
   251  			return protoreflect.Value{}, err
   252  		}
   253  		return protoreflect.ValueOfInt32(int32(v)), nil
   254  	case protoreflect.Int64Kind, protoreflect.Sint64Kind, protoreflect.Sfixed64Kind:
   255  		v, err := strconv.ParseInt(value, 10, 64)
   256  		if err != nil {
   257  			return protoreflect.Value{}, err
   258  		}
   259  		return protoreflect.ValueOfInt64(v), nil
   260  	case protoreflect.Uint32Kind, protoreflect.Fixed32Kind:
   261  		v, err := strconv.ParseUint(value, 10, 32)
   262  		if err != nil {
   263  			return protoreflect.Value{}, err
   264  		}
   265  		return protoreflect.ValueOfUint32(uint32(v)), nil
   266  	case protoreflect.Uint64Kind, protoreflect.Fixed64Kind:
   267  		v, err := strconv.ParseUint(value, 10, 64)
   268  		if err != nil {
   269  			return protoreflect.Value{}, err
   270  		}
   271  		return protoreflect.ValueOfUint64(v), nil
   272  	case protoreflect.FloatKind:
   273  		v, err := strconv.ParseFloat(value, 32)
   274  		if err != nil {
   275  			return protoreflect.Value{}, err
   276  		}
   277  		return protoreflect.ValueOfFloat32(float32(v)), nil
   278  	case protoreflect.DoubleKind:
   279  		v, err := strconv.ParseFloat(value, 64)
   280  		if err != nil {
   281  			return protoreflect.Value{}, err
   282  		}
   283  		return protoreflect.ValueOfFloat64(v), nil
   284  	case protoreflect.MessageKind, protoreflect.GroupKind:
   285  		return parseMessage(fieldDescriptor.Message(), value)
   286  	case protoreflect.EnumKind:
   287  		enum, err := protoregistry.GlobalTypes.FindEnumByName(fieldDescriptor.Enum().FullName())
   288  		switch {
   289  		case err == protoregistry.NotFound:
   290  			return protoreflect.Value{}, fmt.Errorf("enum %q is not registered", fieldDescriptor.Enum().FullName())
   291  		case err != nil:
   292  			return protoreflect.Value{}, fmt.Errorf("failed to look up enum: %w", err)
   293  		}
   294  		v := enum.Descriptor().Values().ByName(protoreflect.Name(value))
   295  		if v == nil {
   296  			i, err := strconv.Atoi(value)
   297  			if err != nil {
   298  				return protoreflect.Value{}, fmt.Errorf("%q is not a valid value", value)
   299  			}
   300  			v = enum.Descriptor().Values().ByNumber(protoreflect.EnumNumber(i))
   301  			if v == nil {
   302  				return protoreflect.Value{}, fmt.Errorf("%q is not a valid value", value)
   303  			}
   304  		}
   305  		return protoreflect.ValueOfEnum(v.Number()), nil
   306  	default:
   307  		panic(fmt.Sprintf("unknown field kind: %v", fieldDescriptor.Kind()))
   308  	}
   309  }
   310  
   311  func parseMessage(msgDescriptor protoreflect.MessageDescriptor, value string) (protoreflect.Value, error) {
   312  	var protoMessage proto.Message
   313  	switch msgDescriptor.FullName() {
   314  	case "google.protobuf.BoolValue":
   315  		v, err := strconv.ParseBool(value)
   316  		if err != nil {
   317  			return protoreflect.Value{}, err
   318  		}
   319  		protoMessage = &wrapperspb.BoolValue{Value: v}
   320  	case "google.protobuf.StringValue":
   321  		protoMessage = &wrapperspb.StringValue{Value: value}
   322  	case "google.protobuf.FloatValue":
   323  		v, err := strconv.ParseFloat(value, 32)
   324  		if err != nil {
   325  			return protoreflect.Value{}, err
   326  		}
   327  		protoMessage = &wrapperspb.FloatValue{Value: float32(v)}
   328  	case "google.protobuf.Int64Value":
   329  		v, err := strconv.ParseInt(value, 10, 64)
   330  		if err != nil {
   331  			return protoreflect.Value{}, err
   332  		}
   333  		protoMessage = &wrapperspb.Int64Value{Value: v}
   334  	case "google.protobuf.Int32Value":
   335  		v, err := strconv.ParseInt(value, 10, 32)
   336  		if err != nil {
   337  			return protoreflect.Value{}, err
   338  		}
   339  		protoMessage = &wrapperspb.Int32Value{Value: int32(v)}
   340  	case "google.protobuf.UInt64Value":
   341  		v, err := strconv.ParseUint(value, 10, 64)
   342  		if err != nil {
   343  			return protoreflect.Value{}, err
   344  		}
   345  		protoMessage = &wrapperspb.UInt64Value{Value: v}
   346  	case "google.protobuf.UInt32Value":
   347  		v, err := strconv.ParseUint(value, 10, 32)
   348  		if err != nil {
   349  			return protoreflect.Value{}, err
   350  		}
   351  		protoMessage = &wrapperspb.UInt32Value{Value: uint32(v)}
   352  	case "google.protobuf.BytesValue":
   353  		v, err := base64.StdEncoding.DecodeString(value)
   354  		if err != nil {
   355  			return protoreflect.Value{}, err
   356  		}
   357  		protoMessage = &wrapperspb.BytesValue{Value: v}
   358  	case "google.protobuf.DoubleValue":
   359  		v, err := strconv.ParseFloat(value, 64)
   360  		if err != nil {
   361  			return protoreflect.Value{}, err
   362  		}
   363  		protoMessage = &wrapperspb.DoubleValue{Value: v}
   364  	case "google.protobuf.FieldMask":
   365  		fm := &field_mask.FieldMask{}
   366  		fm.Paths = append(fm.Paths, strings.Split(value, ",")...)
   367  		protoMessage = fm
   368  	case "google.protobuf.Duration":
   369  		if value == "null" {
   370  			break
   371  		}
   372  		d, err := time.ParseDuration(value)
   373  		if err != nil {
   374  			return protoreflect.Value{}, err
   375  		}
   376  		protoMessage = durationpb.New(d)
   377  	case "google.protobuf.Timestamp":
   378  		if value == "null" {
   379  			break
   380  		}
   381  		t, err := time.Parse(time.RFC3339Nano, value)
   382  		if err != nil {
   383  			return protoreflect.Value{}, err
   384  		}
   385  		protoMessage = timestamppb.New(t)
   386  	default:
   387  		return protoreflect.Value{}, fmt.Errorf("unsupported message type: %q", string(msgDescriptor.FullName()))
   388  	}
   389  
   390  	return protoreflect.ValueOfMessage(protoMessage.ProtoReflect()), nil
   391  }