istio.io/istio@v0.0.0-20240520182934-d79c90f27776/pkg/proto/merge/merge.go (about)

     1  // Copyright Istio Authors
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //     http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License.
    14  
    15  package merge
    16  
    17  /*
    18   CODE Copied and modified from https://github.com/kumahq/kuma/blob/master/pkg/util/proto/google_proto.go
    19   because of: https://github.com/golang/protobuf/issues/1359
    20  
    21    Copyright 2019 The Go Authors. All rights reserved.
    22    Use of this source code is governed by a BSD-style
    23    license that can be found in the LICENSE file.
    24  */
    25  
    26  import (
    27  	"fmt"
    28  
    29  	"google.golang.org/protobuf/proto"
    30  	"google.golang.org/protobuf/reflect/protoreflect"
    31  	"google.golang.org/protobuf/types/known/durationpb"
    32  )
    33  
    34  type (
    35  	MergeFunction func(dst, src protoreflect.Message)
    36  	mergeOptions  struct {
    37  		customMergeFn map[protoreflect.FullName]MergeFunction
    38  	}
    39  )
    40  type OptionFn func(options mergeOptions) mergeOptions
    41  
    42  func MergeFunctionOptionFn(name protoreflect.FullName, function MergeFunction) OptionFn {
    43  	return func(options mergeOptions) mergeOptions {
    44  		options.customMergeFn[name] = function
    45  		return options
    46  	}
    47  }
    48  
    49  // ReplaceMergeFn instead of merging all subfields one by one, takes src and set it to dest
    50  var ReplaceMergeFn MergeFunction = func(dst, src protoreflect.Message) {
    51  	dst.Range(func(fd protoreflect.FieldDescriptor, v protoreflect.Value) bool {
    52  		dst.Clear(fd)
    53  		return true
    54  	})
    55  	src.Range(func(fd protoreflect.FieldDescriptor, v protoreflect.Value) bool {
    56  		dst.Set(fd, v)
    57  		return true
    58  	})
    59  }
    60  
    61  var options = []OptionFn{
    62  	// Workaround https://github.com/golang/protobuf/issues/1359, merge duration properly
    63  	MergeFunctionOptionFn((&durationpb.Duration{}).ProtoReflect().Descriptor().FullName(), ReplaceMergeFn),
    64  }
    65  
    66  func Merge(dst, src proto.Message) {
    67  	merge(dst, src, options...)
    68  }
    69  
    70  // Merge Code of proto.Merge with modifications to support custom types
    71  func merge(dst, src proto.Message, opts ...OptionFn) {
    72  	mo := mergeOptions{customMergeFn: map[protoreflect.FullName]MergeFunction{}}
    73  	for _, opt := range opts {
    74  		mo = opt(mo)
    75  	}
    76  	dstMsg, srcMsg := dst.ProtoReflect(), src.ProtoReflect()
    77  	if dstMsg.Descriptor() != srcMsg.Descriptor() {
    78  		if got, want := dstMsg.Descriptor().FullName(), srcMsg.Descriptor().FullName(); got != want {
    79  			panic(fmt.Sprintf("descriptor mismatch: %v != %v", got, want))
    80  		}
    81  		panic("descriptor mismatch")
    82  	}
    83  	mo.mergeMessage(dstMsg, srcMsg)
    84  }
    85  
    86  func (o mergeOptions) mergeMessage(dst, src protoreflect.Message) {
    87  	// The regular proto.mergeMessage would have a fast path method option here.
    88  	// As we want to have exceptions we always use the slow path.
    89  	if !dst.IsValid() {
    90  		panic(fmt.Sprintf("cannot merge into invalid %v message", dst.Descriptor().FullName()))
    91  	}
    92  
    93  	src.Range(func(fd protoreflect.FieldDescriptor, v protoreflect.Value) bool {
    94  		switch {
    95  		case fd.IsList():
    96  			o.mergeList(dst.Mutable(fd).List(), v.List(), fd)
    97  		case fd.IsMap():
    98  			o.mergeMap(dst.Mutable(fd).Map(), v.Map(), fd.MapValue())
    99  		case fd.Message() != nil:
   100  			mergeFn, exists := o.customMergeFn[fd.Message().FullName()]
   101  			if exists {
   102  				mergeFn(dst.Mutable(fd).Message(), v.Message())
   103  			} else {
   104  				o.mergeMessage(dst.Mutable(fd).Message(), v.Message())
   105  			}
   106  		case fd.Kind() == protoreflect.BytesKind:
   107  			dst.Set(fd, o.cloneBytes(v))
   108  		default:
   109  			dst.Set(fd, v)
   110  		}
   111  		return true
   112  	})
   113  
   114  	if len(src.GetUnknown()) > 0 {
   115  		dst.SetUnknown(append(dst.GetUnknown(), src.GetUnknown()...))
   116  	}
   117  }
   118  
   119  func (o mergeOptions) mergeList(dst, src protoreflect.List, fd protoreflect.FieldDescriptor) {
   120  	// Merge semantics appends to the end of the existing list.
   121  	for i, n := 0, src.Len(); i < n; i++ {
   122  		switch v := src.Get(i); {
   123  		case fd.Message() != nil:
   124  			dstv := dst.NewElement()
   125  			o.mergeMessage(dstv.Message(), v.Message())
   126  			dst.Append(dstv)
   127  		case fd.Kind() == protoreflect.BytesKind:
   128  			dst.Append(o.cloneBytes(v))
   129  		default:
   130  			dst.Append(v)
   131  		}
   132  	}
   133  }
   134  
   135  func (o mergeOptions) mergeMap(dst, src protoreflect.Map, fd protoreflect.FieldDescriptor) {
   136  	// Merge semantics replaces, rather than merges into existing entries.
   137  	src.Range(func(k protoreflect.MapKey, v protoreflect.Value) bool {
   138  		switch {
   139  		case fd.Message() != nil:
   140  			dstv := dst.NewValue()
   141  			o.mergeMessage(dstv.Message(), v.Message())
   142  			dst.Set(k, dstv)
   143  		case fd.Kind() == protoreflect.BytesKind:
   144  			dst.Set(k, o.cloneBytes(v))
   145  		default:
   146  			dst.Set(k, v)
   147  		}
   148  		return true
   149  	})
   150  }
   151  
   152  func (o mergeOptions) cloneBytes(v protoreflect.Value) protoreflect.Value {
   153  	return protoreflect.ValueOfBytes(append([]byte{}, v.Bytes()...))
   154  }