github.com/confluentinc/confluent-kafka-go@v1.9.2/schemaregistry/serde/protobuf/protobuf.go (about)

     1  /**
     2   * Copyright 2022 Confluent Inc.
     3   *
     4   * Licensed under the Apache License, Version 2.0 (the "License");
     5   * you may not use this file except in compliance with the License.
     6   * You may obtain a copy of the License at
     7   *
     8   * http://www.apache.org/licenses/LICENSE-2.0
     9   *
    10   * Unless required by applicable law or agreed to in writing, software
    11   * distributed under the License is distributed on an "AS IS" BASIS,
    12   * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    13   * See the License for the specific language governing permissions and
    14   * limitations under the License.
    15   */
    16  
    17  package protobuf
    18  
import (
	"encoding/binary"
	"fmt"
	"io"
	"log"
	"strings"

	"github.com/confluentinc/confluent-kafka-go/schemaregistry"
	"github.com/confluentinc/confluent-kafka-go/schemaregistry/confluent"
	"github.com/confluentinc/confluent-kafka-go/schemaregistry/confluent/types"
	"github.com/confluentinc/confluent-kafka-go/schemaregistry/serde"
	protoV1 "github.com/golang/protobuf/proto"
	"github.com/jhump/protoreflect/desc"
	"github.com/jhump/protoreflect/desc/protoparse"
	"github.com/jhump/protoreflect/desc/protoprint"
	"google.golang.org/genproto/googleapis/type/calendarperiod"
	"google.golang.org/genproto/googleapis/type/color"
	"google.golang.org/genproto/googleapis/type/date"
	"google.golang.org/genproto/googleapis/type/datetime"
	"google.golang.org/genproto/googleapis/type/dayofweek"
	"google.golang.org/genproto/googleapis/type/expr"
	"google.golang.org/genproto/googleapis/type/fraction"
	"google.golang.org/genproto/googleapis/type/latlng"
	"google.golang.org/genproto/googleapis/type/money"
	"google.golang.org/genproto/googleapis/type/month"
	"google.golang.org/genproto/googleapis/type/postaladdress"
	"google.golang.org/genproto/googleapis/type/quaternion"
	"google.golang.org/genproto/googleapis/type/timeofday"
	"google.golang.org/genproto/protobuf/field_mask"
	"google.golang.org/genproto/protobuf/source_context"
	"google.golang.org/protobuf/proto"
	"google.golang.org/protobuf/reflect/protodesc"
	"google.golang.org/protobuf/reflect/protoreflect"
	"google.golang.org/protobuf/reflect/protoregistry"
	"google.golang.org/protobuf/types/descriptorpb"
	"google.golang.org/protobuf/types/known/anypb"
	"google.golang.org/protobuf/types/known/apipb"
	"google.golang.org/protobuf/types/known/durationpb"
	"google.golang.org/protobuf/types/known/emptypb"
	"google.golang.org/protobuf/types/known/structpb"
	"google.golang.org/protobuf/types/known/timestamppb"
	"google.golang.org/protobuf/types/known/typepb"
	"google.golang.org/protobuf/types/known/wrapperspb"
)
    61  
    62  // Serializer represents a Protobuf serializer
    63  type Serializer struct {
    64  	serde.BaseSerializer
    65  }
    66  
    67  // Deserializer represents a Protobuf deserializer
    68  type Deserializer struct {
    69  	serde.BaseDeserializer
    70  	ProtoRegistry *protoregistry.Types
    71  }
    72  
    73  var _ serde.Serializer = new(Serializer)
    74  var _ serde.Deserializer = new(Deserializer)
    75  
    76  var builtInDeps = make(map[string]string)
    77  
    78  func init() {
    79  	builtins := map[string]protoreflect.FileDescriptor{
    80  		"confluent/meta.proto":                 confluent.File_schemaregistry_confluent_meta_proto,
    81  		"confluent/type/decimal.proto":         types.File_schemaregistry_confluent_type_decimal_proto,
    82  		"google/type/calendar_period.proto":    calendarperiod.File_google_type_calendar_period_proto,
    83  		"google/type/color.proto":              color.File_google_type_color_proto,
    84  		"google/type/date.proto":               date.File_google_type_date_proto,
    85  		"google/type/datetime.proto":           datetime.File_google_type_datetime_proto,
    86  		"google/type/dayofweek.proto":          dayofweek.File_google_type_dayofweek_proto,
    87  		"google/type/expr.proto":               expr.File_google_type_expr_proto,
    88  		"google/type/fraction.proto":           fraction.File_google_type_fraction_proto,
    89  		"google/type/latlng.proto":             latlng.File_google_type_latlng_proto,
    90  		"google/type/money.proto":              money.File_google_type_money_proto,
    91  		"google/type/month.proto":              money.File_google_type_money_proto,
    92  		"google/type/postal_address.proto":     postaladdress.File_google_type_postal_address_proto,
    93  		"google/type/quaternion.proto":         quaternion.File_google_type_quaternion_proto,
    94  		"google/type/timeofday.proto":          timeofday.File_google_type_timeofday_proto,
    95  		"google/protobuf/any.proto":            anypb.File_google_protobuf_any_proto,
    96  		"google/protobuf/api.proto":            anypb.File_google_protobuf_any_proto,
    97  		"google/protobuf/descriptor.proto":     descriptorpb.File_google_protobuf_descriptor_proto,
    98  		"google/protobuf/duration.proto":       durationpb.File_google_protobuf_duration_proto,
    99  		"google/protobuf/empty.proto":          emptypb.File_google_protobuf_empty_proto,
   100  		"google/protobuf/field_mask.proto":     field_mask.File_google_protobuf_field_mask_proto,
   101  		"google/protobuf/source_context.proto": source_context.File_google_protobuf_source_context_proto,
   102  		"google/protobuf/struct.proto":         structpb.File_google_protobuf_struct_proto,
   103  		"google/protobuf/timestamp.proto":      timestamppb.File_google_protobuf_timestamp_proto,
   104  		"google/protobuf/type.proto":           typepb.File_google_protobuf_type_proto,
   105  		"google/protobuf/wrappers.proto":       wrapperspb.File_google_protobuf_wrappers_proto,
   106  	}
   107  	var fds []*descriptorpb.FileDescriptorProto
   108  	for _, value := range builtins {
   109  		fd := protodesc.ToFileDescriptorProto(value)
   110  		fds = append(fds, fd)
   111  	}
   112  	fdMap, err := desc.CreateFileDescriptors(fds)
   113  	if err != nil {
   114  		log.Fatalf("Could not create fds")
   115  	}
   116  	printer := protoprint.Printer{OmitComments: protoprint.CommentsAll}
   117  	for key, value := range fdMap {
   118  		var writer strings.Builder
   119  		err = printer.PrintProtoFile(value, &writer)
   120  		if err != nil {
   121  			log.Fatalf("Could not print %s", key)
   122  		}
   123  		builtInDeps[key] = writer.String()
   124  	}
   125  }
   126  
   127  // NewSerializer creates a Protobuf serializer for Protobuf-generated objects
   128  func NewSerializer(client schemaregistry.Client, serdeType serde.Type, conf *SerializerConfig) (*Serializer, error) {
   129  	s := &Serializer{}
   130  	err := s.ConfigureSerializer(client, serdeType, &conf.SerializerConfig)
   131  	if err != nil {
   132  		return nil, err
   133  	}
   134  	return s, nil
   135  }
   136  
   137  // ConfigureDeserializer configures the Protobuf deserializer
   138  func (s *Deserializer) ConfigureDeserializer(client schemaregistry.Client, serdeType serde.Type, conf *serde.DeserializerConfig) error {
   139  	if client == nil {
   140  		return fmt.Errorf("schema registry client missing")
   141  	}
   142  	s.Client = client
   143  	s.Conf = conf
   144  	s.SerdeType = serdeType
   145  	s.SubjectNameStrategy = serde.TopicNameStrategy
   146  	s.MessageFactory = s.protoMessageFactory
   147  	s.ProtoRegistry = new(protoregistry.Types)
   148  	return nil
   149  }
   150  
   151  // Serialize implements serialization of Protobuf data
   152  func (s *Serializer) Serialize(topic string, msg interface{}) ([]byte, error) {
   153  	if msg == nil {
   154  		return nil, nil
   155  	}
   156  	var protoMsg proto.Message
   157  	switch t := msg.(type) {
   158  	case proto.Message:
   159  		protoMsg = t
   160  	default:
   161  		return nil, fmt.Errorf("serialization target must be a protobuf message. Got '%v'", t)
   162  	}
   163  	autoRegister := s.Conf.AutoRegisterSchemas
   164  	normalize := s.Conf.NormalizeSchemas
   165  	fileDesc, deps, err := s.toProtobufSchema(protoMsg)
   166  	if err != nil {
   167  		return nil, err
   168  	}
   169  	metadata, err := s.resolveDependencies(fileDesc, deps, "", autoRegister, normalize)
   170  	if err != nil {
   171  		return nil, err
   172  	}
   173  	info := schemaregistry.SchemaInfo{
   174  		Schema:     metadata.Schema,
   175  		SchemaType: metadata.SchemaType,
   176  		References: metadata.References,
   177  	}
   178  	id, err := s.GetID(topic, protoMsg, info)
   179  	if err != nil {
   180  		return nil, err
   181  	}
   182  	msgIndexBytes := toMessageIndexBytes(protoMsg.ProtoReflect().Descriptor())
   183  	msgBytes, err := proto.Marshal(protoMsg)
   184  	if err != nil {
   185  		return nil, err
   186  	}
   187  	payload, err := s.WriteBytes(id, append(msgIndexBytes, msgBytes...))
   188  	if err != nil {
   189  		return nil, err
   190  	}
   191  	return payload, nil
   192  }
   193  
   194  func (s *Serializer) toProtobufSchema(msg proto.Message) (*desc.FileDescriptor, map[string]string, error) {
   195  	messageDesc, err := desc.LoadMessageDescriptorForMessage(protoV1.MessageV1(msg))
   196  	fileDesc := messageDesc.GetFile()
   197  	if err != nil {
   198  		return nil, nil, err
   199  	}
   200  	deps := make(map[string]string)
   201  	err = s.toDependencies(fileDesc, deps)
   202  	if err != nil {
   203  		return nil, nil, err
   204  	}
   205  	return fileDesc, deps, nil
   206  }
   207  
   208  func (s *Serializer) toDependencies(fileDesc *desc.FileDescriptor, deps map[string]string) error {
   209  	printer := protoprint.Printer{OmitComments: protoprint.CommentsAll}
   210  	var writer strings.Builder
   211  	err := printer.PrintProtoFile(fileDesc, &writer)
   212  	if err != nil {
   213  		return err
   214  	}
   215  	deps[fileDesc.GetName()] = writer.String()
   216  	for _, d := range fileDesc.GetDependencies() {
   217  		if ignoreFile(d.GetName()) {
   218  			continue
   219  		}
   220  		err = s.toDependencies(d, deps)
   221  		if err != nil {
   222  			return err
   223  		}
   224  	}
   225  	for _, d := range fileDesc.GetPublicDependencies() {
   226  		if ignoreFile(d.GetName()) {
   227  			continue
   228  		}
   229  		err = s.toDependencies(d, deps)
   230  		if err != nil {
   231  			return err
   232  		}
   233  	}
   234  	return nil
   235  }
   236  
   237  func (s *Serializer) resolveDependencies(fileDesc *desc.FileDescriptor, deps map[string]string, subject string, autoRegister bool, normalize bool) (schemaregistry.SchemaMetadata, error) {
   238  	refs := make([]schemaregistry.Reference, 0, len(fileDesc.GetDependencies())+len(fileDesc.GetPublicDependencies()))
   239  	for _, d := range fileDesc.GetDependencies() {
   240  		if ignoreFile(d.GetName()) {
   241  			continue
   242  		}
   243  		ref, err := s.resolveDependencies(d, deps, d.GetName(), autoRegister, normalize)
   244  		if err != nil {
   245  			return schemaregistry.SchemaMetadata{}, err
   246  		}
   247  		refs = append(refs, schemaregistry.Reference{d.GetName(), ref.Subject, ref.Version})
   248  	}
   249  	for _, d := range fileDesc.GetPublicDependencies() {
   250  		if ignoreFile(d.GetName()) {
   251  			continue
   252  		}
   253  		ref, err := s.resolveDependencies(d, deps, d.GetName(), autoRegister, normalize)
   254  		if err != nil {
   255  			return schemaregistry.SchemaMetadata{}, err
   256  		}
   257  		refs = append(refs, schemaregistry.Reference{d.GetName(), ref.Subject, ref.Version})
   258  	}
   259  	info := schemaregistry.SchemaInfo{
   260  		Schema:     deps[fileDesc.GetName()],
   261  		SchemaType: "PROTOBUF",
   262  		References: refs,
   263  	}
   264  	var id = -1
   265  	var err error
   266  	var version = 0
   267  	if subject != "" {
   268  		if autoRegister {
   269  			id, err = s.Client.Register(subject, info, normalize)
   270  			if err != nil {
   271  				return schemaregistry.SchemaMetadata{}, err
   272  			}
   273  		} else {
   274  			id, err = s.Client.GetID(subject, info, normalize)
   275  			if err != nil {
   276  				return schemaregistry.SchemaMetadata{}, err
   277  			}
   278  		}
   279  		version, err = s.Client.GetVersion(subject, info, normalize)
   280  		if err != nil {
   281  			return schemaregistry.SchemaMetadata{}, err
   282  		}
   283  	}
   284  	metadata := schemaregistry.SchemaMetadata{
   285  		SchemaInfo: info,
   286  		ID:         id,
   287  		Subject:    subject,
   288  		Version:    version,
   289  	}
   290  	return metadata, nil
   291  }
   292  
   293  func toMessageIndexBytes(descriptor protoreflect.Descriptor) []byte {
   294  	if descriptor.Index() == 0 {
   295  		switch descriptor.Parent().(type) {
   296  		case protoreflect.FileDescriptor:
   297  			// This is an optimization for the first message in the schema
   298  			return []byte{0}
   299  		}
   300  	}
   301  	msgIndexes := toMessageIndexes(descriptor, 0)
   302  	buf := make([]byte, (1+len(msgIndexes))*binary.MaxVarintLen64)
   303  	length := binary.PutVarint(buf, int64(len(msgIndexes)))
   304  
   305  	for _, element := range msgIndexes {
   306  		length += binary.PutVarint(buf[length:], int64(element))
   307  	}
   308  	return buf[0:length]
   309  }
   310  
   311  // Adapted from ideasculptor, see https://github.com/riferrei/srclient/issues/17
   312  func toMessageIndexes(descriptor protoreflect.Descriptor, count int) []int {
   313  	index := descriptor.Index()
   314  	switch v := descriptor.Parent().(type) {
   315  	case protoreflect.FileDescriptor:
   316  		// parent is FileDescriptor, we reached the top of the stack, so we are
   317  		// done. Allocate an array large enough to hold count+1 entries and
   318  		// populate first value with index
   319  		msgIndexes := make([]int, count+1)
   320  		msgIndexes[0] = index
   321  		return msgIndexes[0:1]
   322  	default:
   323  		// parent is another MessageDescriptor.  We were nested so get that
   324  		// descriptor's indexes and append the index of this one
   325  		msgIndexes := toMessageIndexes(v, count+1)
   326  		return append(msgIndexes, index)
   327  	}
   328  }
   329  
   330  func ignoreFile(name string) bool {
   331  	return strings.HasPrefix(name, "confluent/") ||
   332  		strings.HasPrefix(name, "google/protobuf/") ||
   333  		strings.HasPrefix(name, "google/type/")
   334  }
   335  
   336  // NewDeserializer creates a Protobuf deserializer for Protobuf-generated objects
   337  func NewDeserializer(client schemaregistry.Client, serdeType serde.Type, conf *DeserializerConfig) (*Deserializer, error) {
   338  	s := &Deserializer{}
   339  	err := s.ConfigureDeserializer(client, serdeType, &conf.DeserializerConfig)
   340  	if err != nil {
   341  		return nil, err
   342  	}
   343  	return s, nil
   344  }
   345  
   346  // Deserialize implements deserialization of Protobuf data
   347  func (s *Deserializer) Deserialize(topic string, payload []byte) (interface{}, error) {
   348  	if payload == nil {
   349  		return nil, nil
   350  	}
   351  	info, err := s.GetSchema(topic, payload)
   352  	if err != nil {
   353  		return nil, err
   354  	}
   355  	fd, err := s.toFileDesc(info)
   356  	if err != nil {
   357  		return nil, err
   358  	}
   359  	bytesRead, msgIndexes, err := readMessageIndexes(payload[5:])
   360  	if err != nil {
   361  		return nil, err
   362  	}
   363  	messageDesc, err := toMessageDesc(fd, msgIndexes)
   364  	if err != nil {
   365  		return nil, err
   366  	}
   367  	subject, err := s.SubjectNameStrategy(topic, s.SerdeType, info)
   368  	if err != nil {
   369  		return nil, err
   370  	}
   371  	msg, err := s.MessageFactory(subject, messageDesc.GetFullyQualifiedName())
   372  	if err != nil {
   373  		return nil, err
   374  	}
   375  	var protoMsg proto.Message
   376  	switch t := msg.(type) {
   377  	case proto.Message:
   378  		protoMsg = t
   379  	default:
   380  		return nil, fmt.Errorf("deserialization target must be a protobuf message. Got '%v'", t)
   381  	}
   382  	err = proto.Unmarshal(payload[5+bytesRead:], protoMsg)
   383  	return protoMsg, err
   384  }
   385  
   386  // DeserializeInto implements deserialization of Protobuf data to the given object
   387  func (s *Deserializer) DeserializeInto(topic string, payload []byte, msg interface{}) error {
   388  	if payload == nil {
   389  		return nil
   390  	}
   391  	var protoMsg proto.Message
   392  	switch t := msg.(type) {
   393  	case proto.Message:
   394  		protoMsg = t
   395  	default:
   396  		return fmt.Errorf("deserialization target must be a protobuf message. Got '%v'", t)
   397  	}
   398  	bytesRead, _, err := readMessageIndexes(payload[5:])
   399  	if err != nil {
   400  		return err
   401  	}
   402  	return proto.Unmarshal(payload[5+bytesRead:], protoMsg)
   403  }
   404  
   405  func (s *Deserializer) toFileDesc(info schemaregistry.SchemaInfo) (*desc.FileDescriptor, error) {
   406  	deps := make(map[string]string)
   407  	err := serde.ResolveReferences(s.Client, info, deps)
   408  	if err != nil {
   409  		return nil, err
   410  	}
   411  	parser := protoparse.Parser{
   412  		Accessor: func(filename string) (io.ReadCloser, error) {
   413  			var schema string
   414  			if filename == "." {
   415  				schema = info.Schema
   416  			} else {
   417  				schema = deps[filename]
   418  			}
   419  			if schema == "" {
   420  				schema = builtInDeps[filename]
   421  			}
   422  			return io.NopCloser(strings.NewReader(schema)), nil
   423  		},
   424  	}
   425  
   426  	fileDescriptors, err := parser.ParseFiles(".")
   427  	if err != nil {
   428  		return nil, err
   429  	}
   430  
   431  	if len(fileDescriptors) != 1 {
   432  		return nil, fmt.Errorf("could not resolve schema")
   433  	}
   434  	return fileDescriptors[0], nil
   435  }
   436  
   437  func readMessageIndexes(payload []byte) (int, []int, error) {
   438  	arrayLen, bytesRead := binary.Varint(payload)
   439  	if bytesRead <= 0 {
   440  		return bytesRead, nil, fmt.Errorf("unable to read message indexes")
   441  	}
   442  	if arrayLen == 0 {
   443  		// Handle the optimization for the first message in the schema
   444  		return bytesRead, []int{0}, nil
   445  	}
   446  	msgIndexes := make([]int, arrayLen)
   447  	for i := 0; i < int(arrayLen); i++ {
   448  		idx, read := binary.Varint(payload[bytesRead:])
   449  		if read <= 0 {
   450  			return bytesRead, nil, fmt.Errorf("unable to read message indexes")
   451  		}
   452  		bytesRead += read
   453  		msgIndexes[i] = int(idx)
   454  	}
   455  	return bytesRead, msgIndexes, nil
   456  }
   457  
   458  func toMessageDesc(descriptor desc.Descriptor, msgIndexes []int) (*desc.MessageDescriptor, error) {
   459  	index := msgIndexes[0]
   460  
   461  	switch v := descriptor.(type) {
   462  	case *desc.FileDescriptor:
   463  		if len(msgIndexes) == 1 {
   464  			return v.GetMessageTypes()[index], nil
   465  		}
   466  		return toMessageDesc(v.GetMessageTypes()[index], msgIndexes[1:])
   467  	case *desc.MessageDescriptor:
   468  		if len(msgIndexes) == 1 {
   469  			return v.GetNestedMessageTypes()[index], nil
   470  		}
   471  		return toMessageDesc(v.GetNestedMessageTypes()[index], msgIndexes[1:])
   472  	default:
   473  		return nil, fmt.Errorf("unexpected type")
   474  	}
   475  }
   476  
   477  func (s *Deserializer) protoMessageFactory(subject string, name string) (interface{}, error) {
   478  	mt, err := s.ProtoRegistry.FindMessageByName(protoreflect.FullName(name))
   479  	if mt == nil {
   480  		err = fmt.Errorf("unable to find MessageType %s", name)
   481  	}
   482  	if err != nil {
   483  		return nil, err
   484  	}
   485  	msg := mt.New()
   486  	return msg.Interface(), nil
   487  }