go.chromium.org/luci@v0.0.0-20240309015107-7cdc2e660f33/common/proto/fieldmasks.go (about)

     1  // Copyright 2019 The LUCI Authors.
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //      http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License.
    14  
    15  package proto
    16  
    17  import (
    18  	"bytes"
    19  	"encoding/json"
    20  	"fmt"
    21  	"reflect"
    22  	"strconv"
    23  	"strings"
    24  	"sync"
    25  	"unicode"
    26  
    27  	"github.com/golang/protobuf/proto"
    28  	"google.golang.org/protobuf/types/known/fieldmaskpb"
    29  	"google.golang.org/protobuf/types/known/structpb"
    30  )
    31  
    32  var (
    33  	fieldMaskType    = reflect.TypeOf((*fieldmaskpb.FieldMask)(nil))
    34  	structType       = reflect.TypeOf((*structpb.Struct)(nil))
    35  	protoMessageType = reflect.TypeOf((*proto.Message)(nil)).Elem()
    36  )
    37  
    38  // FixFieldMasksBeforeUnmarshal reads FieldMask fields from a JSON-encoded message,
    39  // parses them as a string according to
    40  // https://github.com/protocolbuffers/protobuf/blob/ec1a70913e5793a7d0a7b5fbf7e0e4f75409dd41/src/google/protobuf/field_mask.proto#L180
    41  // and converts them to a JSON serialization format that Golang Protobuf library
    42  // can unmarshal from.
    43  // It is a workaround for https://github.com/golang/protobuf/issues/745.
    44  //
    45  // This function is a reverse of FixFieldMasksAfterMarshal.
    46  //
    47  // messageType must be a struct, not a struct pointer.
    48  //
    49  // WARNING: AVOID. LIKELY BUGGY, see https://crbug.com/1028915.
    50  func FixFieldMasksBeforeUnmarshal(jsonMessage []byte, messageType reflect.Type) ([]byte, error) {
    51  	var msg map[string]any
    52  	if err := json.Unmarshal(jsonMessage, &msg); err != nil {
    53  		return nil, err
    54  	}
    55  
    56  	if err := fixFieldMasksBeforeUnmarshal(make([]string, 0, 10), msg, messageType); err != nil {
    57  		return nil, err
    58  	}
    59  
    60  	return json.Marshal(msg)
    61  }
    62  
    63  func fixFieldMasksBeforeUnmarshal(fieldPath []string, msg map[string]any, messageType reflect.Type) error {
    64  	fieldTypes := getFieldTypes(messageType)
    65  	for name, val := range msg {
    66  		localPath := append(fieldPath, name)
    67  		typ := fieldTypes[name]
    68  		if typ == nil {
    69  			return fmt.Errorf("unexpected field path %q", strings.Join(localPath, "."))
    70  		}
    71  
    72  		switch val := val.(type) {
    73  		case string:
    74  			if typ == fieldMaskType {
    75  				msg[name] = convertFieldMask(val)
    76  			}
    77  
    78  		case map[string]any:
    79  			if typ != structType && typ.Implements(protoMessageType) {
    80  				if err := fixFieldMasksBeforeUnmarshal(localPath, val, typ.Elem()); err != nil {
    81  					return err
    82  				}
    83  			}
    84  
    85  		case []any:
    86  			if typ.Kind() == reflect.Slice && typ.Elem().Implements(protoMessageType) {
    87  				subMsgType := typ.Elem().Elem()
    88  				for i, el := range val {
    89  					if subMsg, ok := el.(map[string]any); ok {
    90  						elPath := append(localPath, strconv.Itoa(i))
    91  						if err := fixFieldMasksBeforeUnmarshal(elPath, subMsg, subMsgType); err != nil {
    92  							return err
    93  						}
    94  					}
    95  				}
    96  			}
    97  		}
    98  	}
    99  	return nil
   100  }
   101  
   102  // convertFieldMask converts a FieldMask from a string according to
   103  // https://github.com/protocolbuffers/protobuf/blob/ec1a70913e5793a7d0a7b5fbf7e0e4f75409dd41/src/google/protobuf/field_mask.proto#L180
   104  // and converts them to a JSON object that Golang Protobuf library understands.
   105  func convertFieldMask(s string) map[string]any {
   106  	paths := parseFieldMaskString(s)
   107  	for i := range paths {
   108  		paths[i] = toSnakeCase(paths[i])
   109  	}
   110  	return map[string]any{
   111  		"paths": paths,
   112  	}
   113  }
   114  
   115  func toSnakeCase(s string) string {
   116  	buf := &bytes.Buffer{}
   117  	buf.Grow(len(s) + 5) // accounts for 5 underscores
   118  	for _, c := range s {
   119  		if unicode.IsUpper(c) {
   120  			buf.WriteString("_")
   121  			buf.WriteRune(unicode.ToLower(c))
   122  		} else {
   123  			buf.WriteRune(c)
   124  		}
   125  	}
   126  	return buf.String()
   127  }
   128  
   129  // parseFieldMaskString parses a google.protobuf.FieldMask string according to
   130  // https://github.com/protocolbuffers/protobuf/blob/ec1a70913e5793a7d0a7b5fbf7e0e4f75409dd41/src/google/protobuf/field_mask.proto#L180
   131  // Does not convert JSON names (e.g. fooBar) to original names (e.g. foo_bar).
   132  func parseFieldMaskString(s string) (paths []string) {
   133  	inQuote := false
   134  	var seps []int
   135  	for i, c := range s {
   136  		switch {
   137  		case c == '`':
   138  			inQuote = !inQuote
   139  
   140  		case inQuote:
   141  			continue
   142  
   143  		case c == ',':
   144  			seps = append(seps, i)
   145  		}
   146  	}
   147  
   148  	if len(seps) == 0 {
   149  		return []string{s}
   150  	}
   151  
   152  	paths = make([]string, 0, len(seps)+1)
   153  	for i := range seps {
   154  		start := 0
   155  		if i > 0 {
   156  			start = seps[i-1] + 1
   157  		}
   158  		paths = append(paths, s[start:seps[i]])
   159  	}
   160  	paths = append(paths, s[seps[len(seps)-1]+1:])
   161  	return paths
   162  }
   163  
   164  var fieldTypeCache struct {
   165  	sync.RWMutex
   166  	types map[reflect.Type]map[string]reflect.Type
   167  }
   168  
   169  func init() {
   170  	fieldTypeCache.types = map[reflect.Type]map[string]reflect.Type{}
   171  }
   172  
   173  // getFieldTypes returns a map from JSON field name to a Go type.
   174  func getFieldTypes(t reflect.Type) map[string]reflect.Type {
   175  	fieldTypeCache.RLock()
   176  	ret, ok := fieldTypeCache.types[t]
   177  	fieldTypeCache.RUnlock()
   178  
   179  	if ok {
   180  		return ret
   181  	}
   182  
   183  	ret = map[string]reflect.Type{}
   184  
   185  	addFieldType := func(p *proto.Properties, fieldType reflect.Type) {
   186  		jsonName := p.JSONName
   187  		if jsonName == "" {
   188  			// it set only for fields where the JSON name is different.
   189  			jsonName = p.OrigName
   190  		}
   191  		ret[jsonName] = fieldType
   192  	}
   193  
   194  	n := t.NumField()
   195  	fields := make(map[string]reflect.StructField, n)
   196  	for i := 0; i < n; i++ {
   197  		f := t.Field(i)
   198  		fields[f.Name] = f
   199  	}
   200  	props := proto.GetProperties(t)
   201  	for _, p := range props.Prop {
   202  		addFieldType(p, fields[p.Name].Type)
   203  	}
   204  
   205  	for _, of := range props.OneofTypes {
   206  		addFieldType(of.Prop, of.Type.Elem().Field(0).Type)
   207  	}
   208  
   209  	fieldTypeCache.Lock()
   210  	fieldTypeCache.types[t] = ret
   211  	fieldTypeCache.Unlock()
   212  	return ret
   213  }