github.com/confluentinc/confluent-kafka-go@v1.9.2/schemaregistry/serde/protobuf/protobuf.go (about) 1 /** 2 * Copyright 2022 Confluent Inc. 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 */ 16 17 package protobuf 18 19 import ( 20 "encoding/binary" 21 "fmt" 22 "io" 23 "log" 24 "strings" 25 26 "github.com/confluentinc/confluent-kafka-go/schemaregistry" 27 "github.com/confluentinc/confluent-kafka-go/schemaregistry/confluent" 28 "github.com/confluentinc/confluent-kafka-go/schemaregistry/confluent/types" 29 "github.com/confluentinc/confluent-kafka-go/schemaregistry/serde" 30 protoV1 "github.com/golang/protobuf/proto" 31 "github.com/jhump/protoreflect/desc" 32 "github.com/jhump/protoreflect/desc/protoparse" 33 "github.com/jhump/protoreflect/desc/protoprint" 34 "google.golang.org/genproto/googleapis/type/calendarperiod" 35 "google.golang.org/genproto/googleapis/type/color" 36 "google.golang.org/genproto/googleapis/type/date" 37 "google.golang.org/genproto/googleapis/type/datetime" 38 "google.golang.org/genproto/googleapis/type/dayofweek" 39 "google.golang.org/genproto/googleapis/type/expr" 40 "google.golang.org/genproto/googleapis/type/fraction" 41 "google.golang.org/genproto/googleapis/type/latlng" 42 "google.golang.org/genproto/googleapis/type/money" 43 "google.golang.org/genproto/googleapis/type/postaladdress" 44 "google.golang.org/genproto/googleapis/type/quaternion" 45 "google.golang.org/genproto/googleapis/type/timeofday" 46 "google.golang.org/genproto/protobuf/field_mask" 47 "google.golang.org/genproto/protobuf/source_context" 48 "google.golang.org/protobuf/proto" 49 "google.golang.org/protobuf/reflect/protodesc" 50 "google.golang.org/protobuf/reflect/protoreflect" 51 "google.golang.org/protobuf/reflect/protoregistry" 52 "google.golang.org/protobuf/types/descriptorpb" 53 "google.golang.org/protobuf/types/known/anypb" 54 "google.golang.org/protobuf/types/known/durationpb" 55 "google.golang.org/protobuf/types/known/emptypb" 56 "google.golang.org/protobuf/types/known/structpb" 57 "google.golang.org/protobuf/types/known/timestamppb" 58 "google.golang.org/protobuf/types/known/typepb" 59 "google.golang.org/protobuf/types/known/wrapperspb" 60 ) 61 62 // Serializer represents a Protobuf serializer 63 type Serializer struct { 64 serde.BaseSerializer 65 } 66 67 // Deserializer represents a Protobuf deserializer 68 type Deserializer struct { 69 serde.BaseDeserializer 70 ProtoRegistry *protoregistry.Types 71 } 72 73 var _ serde.Serializer = new(Serializer) 74 var _ serde.Deserializer = new(Deserializer) 75 76 var builtInDeps = make(map[string]string) 77 78 func init() { 79 builtins := map[string]protoreflect.FileDescriptor{ 80 "confluent/meta.proto": confluent.File_schemaregistry_confluent_meta_proto, 81 "confluent/type/decimal.proto": types.File_schemaregistry_confluent_type_decimal_proto, 82 "google/type/calendar_period.proto": calendarperiod.File_google_type_calendar_period_proto, 83 "google/type/color.proto": color.File_google_type_color_proto, 84 "google/type/date.proto": date.File_google_type_date_proto, 85 "google/type/datetime.proto": datetime.File_google_type_datetime_proto, 86 "google/type/dayofweek.proto": dayofweek.File_google_type_dayofweek_proto, 87 "google/type/expr.proto": expr.File_google_type_expr_proto, 88 "google/type/fraction.proto": fraction.File_google_type_fraction_proto, 89 "google/type/latlng.proto": latlng.File_google_type_latlng_proto, 90 "google/type/money.proto": money.File_google_type_money_proto, 91 "google/type/month.proto": money.File_google_type_money_proto, 92 "google/type/postal_address.proto": postaladdress.File_google_type_postal_address_proto, 93 "google/type/quaternion.proto": quaternion.File_google_type_quaternion_proto, 94 "google/type/timeofday.proto": timeofday.File_google_type_timeofday_proto, 95 "google/protobuf/any.proto": anypb.File_google_protobuf_any_proto, 96 "google/protobuf/api.proto": anypb.File_google_protobuf_any_proto, 97 "google/protobuf/descriptor.proto": descriptorpb.File_google_protobuf_descriptor_proto, 98 "google/protobuf/duration.proto": durationpb.File_google_protobuf_duration_proto, 99 "google/protobuf/empty.proto": emptypb.File_google_protobuf_empty_proto, 100 "google/protobuf/field_mask.proto": field_mask.File_google_protobuf_field_mask_proto, 101 "google/protobuf/source_context.proto": source_context.File_google_protobuf_source_context_proto, 102 "google/protobuf/struct.proto": structpb.File_google_protobuf_struct_proto, 103 "google/protobuf/timestamp.proto": timestamppb.File_google_protobuf_timestamp_proto, 104 "google/protobuf/type.proto": typepb.File_google_protobuf_type_proto, 105 "google/protobuf/wrappers.proto": wrapperspb.File_google_protobuf_wrappers_proto, 106 } 107 var fds []*descriptorpb.FileDescriptorProto 108 for _, value := range builtins { 109 fd := protodesc.ToFileDescriptorProto(value) 110 fds = append(fds, fd) 111 } 112 fdMap, err := desc.CreateFileDescriptors(fds) 113 if err != nil { 114 log.Fatalf("Could not create fds") 115 } 116 printer := protoprint.Printer{OmitComments: protoprint.CommentsAll} 117 for key, value := range fdMap { 118 var writer strings.Builder 119 err = printer.PrintProtoFile(value, &writer) 120 if err != nil { 121 log.Fatalf("Could not print %s", key) 122 } 123 builtInDeps[key] = writer.String() 124 } 125 } 126 127 // NewSerializer creates a Protobuf serializer for Protobuf-generated objects 128 func NewSerializer(client schemaregistry.Client, serdeType serde.Type, conf *SerializerConfig) (*Serializer, error) { 129 s := &Serializer{} 130 err := s.ConfigureSerializer(client, serdeType, &conf.SerializerConfig) 131 if err != nil { 132 return nil, err 133 } 134 return s, nil 135 } 136 137 // ConfigureDeserializer configures the Protobuf deserializer 138 func (s *Deserializer) ConfigureDeserializer(client schemaregistry.Client, serdeType serde.Type, conf *serde.DeserializerConfig) error { 139 if client == nil { 140 return fmt.Errorf("schema registry client missing") 141 } 142 s.Client = client 143 s.Conf = conf 144 s.SerdeType = serdeType 145 s.SubjectNameStrategy = serde.TopicNameStrategy 146 s.MessageFactory = s.protoMessageFactory 147 s.ProtoRegistry = new(protoregistry.Types) 148 return nil 149 } 150 151 // Serialize implements serialization of Protobuf data 152 func (s *Serializer) Serialize(topic string, msg interface{}) ([]byte, error) { 153 if msg == nil { 154 return nil, nil 155 } 156 var protoMsg proto.Message 157 switch t := msg.(type) { 158 case proto.Message: 159 protoMsg = t 160 default: 161 return nil, fmt.Errorf("serialization target must be a protobuf message. Got '%v'", t) 162 } 163 autoRegister := s.Conf.AutoRegisterSchemas 164 normalize := s.Conf.NormalizeSchemas 165 fileDesc, deps, err := s.toProtobufSchema(protoMsg) 166 if err != nil { 167 return nil, err 168 } 169 metadata, err := s.resolveDependencies(fileDesc, deps, "", autoRegister, normalize) 170 if err != nil { 171 return nil, err 172 } 173 info := schemaregistry.SchemaInfo{ 174 Schema: metadata.Schema, 175 SchemaType: metadata.SchemaType, 176 References: metadata.References, 177 } 178 id, err := s.GetID(topic, protoMsg, info) 179 if err != nil { 180 return nil, err 181 } 182 msgIndexBytes := toMessageIndexBytes(protoMsg.ProtoReflect().Descriptor()) 183 msgBytes, err := proto.Marshal(protoMsg) 184 if err != nil { 185 return nil, err 186 } 187 payload, err := s.WriteBytes(id, append(msgIndexBytes, msgBytes...)) 188 if err != nil { 189 return nil, err 190 } 191 return payload, nil 192 } 193 194 func (s *Serializer) toProtobufSchema(msg proto.Message) (*desc.FileDescriptor, map[string]string, error) { 195 messageDesc, err := desc.LoadMessageDescriptorForMessage(protoV1.MessageV1(msg)) 196 fileDesc := messageDesc.GetFile() 197 if err != nil { 198 return nil, nil, err 199 } 200 deps := make(map[string]string) 201 err = s.toDependencies(fileDesc, deps) 202 if err != nil { 203 return nil, nil, err 204 } 205 return fileDesc, deps, nil 206 } 207 208 func (s *Serializer) toDependencies(fileDesc *desc.FileDescriptor, deps map[string]string) error { 209 printer := protoprint.Printer{OmitComments: protoprint.CommentsAll} 210 var writer strings.Builder 211 err := printer.PrintProtoFile(fileDesc, &writer) 212 if err != nil { 213 return err 214 } 215 deps[fileDesc.GetName()] = writer.String() 216 for _, d := range fileDesc.GetDependencies() { 217 if ignoreFile(d.GetName()) { 218 continue 219 } 220 err = s.toDependencies(d, deps) 221 if err != nil { 222 return err 223 } 224 } 225 for _, d := range fileDesc.GetPublicDependencies() { 226 if ignoreFile(d.GetName()) { 227 continue 228 } 229 err = s.toDependencies(d, deps) 230 if err != nil { 231 return err 232 } 233 } 234 return nil 235 } 236 237 func (s *Serializer) resolveDependencies(fileDesc *desc.FileDescriptor, deps map[string]string, subject string, autoRegister bool, normalize bool) (schemaregistry.SchemaMetadata, error) { 238 refs := make([]schemaregistry.Reference, 0, len(fileDesc.GetDependencies())+len(fileDesc.GetPublicDependencies())) 239 for _, d := range fileDesc.GetDependencies() { 240 if ignoreFile(d.GetName()) { 241 continue 242 } 243 ref, err := s.resolveDependencies(d, deps, d.GetName(), autoRegister, normalize) 244 if err != nil { 245 return schemaregistry.SchemaMetadata{}, err 246 } 247 refs = append(refs, schemaregistry.Reference{d.GetName(), ref.Subject, ref.Version}) 248 } 249 for _, d := range fileDesc.GetPublicDependencies() { 250 if ignoreFile(d.GetName()) { 251 continue 252 } 253 ref, err := s.resolveDependencies(d, deps, d.GetName(), autoRegister, normalize) 254 if err != nil { 255 return schemaregistry.SchemaMetadata{}, err 256 } 257 refs = append(refs, schemaregistry.Reference{d.GetName(), ref.Subject, ref.Version}) 258 } 259 info := schemaregistry.SchemaInfo{ 260 Schema: deps[fileDesc.GetName()], 261 SchemaType: "PROTOBUF", 262 References: refs, 263 } 264 var id = -1 265 var err error 266 var version = 0 267 if subject != "" { 268 if autoRegister { 269 id, err = s.Client.Register(subject, info, normalize) 270 if err != nil { 271 return schemaregistry.SchemaMetadata{}, err 272 } 273 } else { 274 id, err = s.Client.GetID(subject, info, normalize) 275 if err != nil { 276 return schemaregistry.SchemaMetadata{}, err 277 } 278 } 279 version, err = s.Client.GetVersion(subject, info, normalize) 280 if err != nil { 281 return schemaregistry.SchemaMetadata{}, err 282 } 283 } 284 metadata := schemaregistry.SchemaMetadata{ 285 SchemaInfo: info, 286 ID: id, 287 Subject: subject, 288 Version: version, 289 } 290 return metadata, nil 291 } 292 293 func toMessageIndexBytes(descriptor protoreflect.Descriptor) []byte { 294 if descriptor.Index() == 0 { 295 switch descriptor.Parent().(type) { 296 case protoreflect.FileDescriptor: 297 // This is an optimization for the first message in the schema 298 return []byte{0} 299 } 300 } 301 msgIndexes := toMessageIndexes(descriptor, 0) 302 buf := make([]byte, (1+len(msgIndexes))*binary.MaxVarintLen64) 303 length := binary.PutVarint(buf, int64(len(msgIndexes))) 304 305 for _, element := range msgIndexes { 306 length += binary.PutVarint(buf[length:], int64(element)) 307 } 308 return buf[0:length] 309 } 310 311 // Adapted from ideasculptor, see https://github.com/riferrei/srclient/issues/17 312 func toMessageIndexes(descriptor protoreflect.Descriptor, count int) []int { 313 index := descriptor.Index() 314 switch v := descriptor.Parent().(type) { 315 case protoreflect.FileDescriptor: 316 // parent is FileDescriptor, we reached the top of the stack, so we are 317 // done. Allocate an array large enough to hold count+1 entries and 318 // populate first value with index 319 msgIndexes := make([]int, count+1) 320 msgIndexes[0] = index 321 return msgIndexes[0:1] 322 default: 323 // parent is another MessageDescriptor. We were nested so get that 324 // descriptor's indexes and append the index of this one 325 msgIndexes := toMessageIndexes(v, count+1) 326 return append(msgIndexes, index) 327 } 328 } 329 330 func ignoreFile(name string) bool { 331 return strings.HasPrefix(name, "confluent/") || 332 strings.HasPrefix(name, "google/protobuf/") || 333 strings.HasPrefix(name, "google/type/") 334 } 335 336 // NewDeserializer creates a Protobuf deserializer for Protobuf-generated objects 337 func NewDeserializer(client schemaregistry.Client, serdeType serde.Type, conf *DeserializerConfig) (*Deserializer, error) { 338 s := &Deserializer{} 339 err := s.ConfigureDeserializer(client, serdeType, &conf.DeserializerConfig) 340 if err != nil { 341 return nil, err 342 } 343 return s, nil 344 } 345 346 // Deserialize implements deserialization of Protobuf data 347 func (s *Deserializer) Deserialize(topic string, payload []byte) (interface{}, error) { 348 if payload == nil { 349 return nil, nil 350 } 351 info, err := s.GetSchema(topic, payload) 352 if err != nil { 353 return nil, err 354 } 355 fd, err := s.toFileDesc(info) 356 if err != nil { 357 return nil, err 358 } 359 bytesRead, msgIndexes, err := readMessageIndexes(payload[5:]) 360 if err != nil { 361 return nil, err 362 } 363 messageDesc, err := toMessageDesc(fd, msgIndexes) 364 if err != nil { 365 return nil, err 366 } 367 subject, err := s.SubjectNameStrategy(topic, s.SerdeType, info) 368 if err != nil { 369 return nil, err 370 } 371 msg, err := s.MessageFactory(subject, messageDesc.GetFullyQualifiedName()) 372 if err != nil { 373 return nil, err 374 } 375 var protoMsg proto.Message 376 switch t := msg.(type) { 377 case proto.Message: 378 protoMsg = t 379 default: 380 return nil, fmt.Errorf("deserialization target must be a protobuf message. Got '%v'", t) 381 } 382 err = proto.Unmarshal(payload[5+bytesRead:], protoMsg) 383 return protoMsg, err 384 } 385 386 // DeserializeInto implements deserialization of Protobuf data to the given object 387 func (s *Deserializer) DeserializeInto(topic string, payload []byte, msg interface{}) error { 388 if payload == nil { 389 return nil 390 } 391 var protoMsg proto.Message 392 switch t := msg.(type) { 393 case proto.Message: 394 protoMsg = t 395 default: 396 return fmt.Errorf("deserialization target must be a protobuf message. Got '%v'", t) 397 } 398 bytesRead, _, err := readMessageIndexes(payload[5:]) 399 if err != nil { 400 return err 401 } 402 return proto.Unmarshal(payload[5+bytesRead:], protoMsg) 403 } 404 405 func (s *Deserializer) toFileDesc(info schemaregistry.SchemaInfo) (*desc.FileDescriptor, error) { 406 deps := make(map[string]string) 407 err := serde.ResolveReferences(s.Client, info, deps) 408 if err != nil { 409 return nil, err 410 } 411 parser := protoparse.Parser{ 412 Accessor: func(filename string) (io.ReadCloser, error) { 413 var schema string 414 if filename == "." { 415 schema = info.Schema 416 } else { 417 schema = deps[filename] 418 } 419 if schema == "" { 420 schema = builtInDeps[filename] 421 } 422 return io.NopCloser(strings.NewReader(schema)), nil 423 }, 424 } 425 426 fileDescriptors, err := parser.ParseFiles(".") 427 if err != nil { 428 return nil, err 429 } 430 431 if len(fileDescriptors) != 1 { 432 return nil, fmt.Errorf("could not resolve schema") 433 } 434 return fileDescriptors[0], nil 435 } 436 437 func readMessageIndexes(payload []byte) (int, []int, error) { 438 arrayLen, bytesRead := binary.Varint(payload) 439 if bytesRead <= 0 { 440 return bytesRead, nil, fmt.Errorf("unable to read message indexes") 441 } 442 if arrayLen == 0 { 443 // Handle the optimization for the first message in the schema 444 return bytesRead, []int{0}, nil 445 } 446 msgIndexes := make([]int, arrayLen) 447 for i := 0; i < int(arrayLen); i++ { 448 idx, read := binary.Varint(payload[bytesRead:]) 449 if read <= 0 { 450 return bytesRead, nil, fmt.Errorf("unable to read message indexes") 451 } 452 bytesRead += read 453 msgIndexes[i] = int(idx) 454 } 455 return bytesRead, msgIndexes, nil 456 } 457 458 func toMessageDesc(descriptor desc.Descriptor, msgIndexes []int) (*desc.MessageDescriptor, error) { 459 index := msgIndexes[0] 460 461 switch v := descriptor.(type) { 462 case *desc.FileDescriptor: 463 if len(msgIndexes) == 1 { 464 return v.GetMessageTypes()[index], nil 465 } 466 return toMessageDesc(v.GetMessageTypes()[index], msgIndexes[1:]) 467 case *desc.MessageDescriptor: 468 if len(msgIndexes) == 1 { 469 return v.GetNestedMessageTypes()[index], nil 470 } 471 return toMessageDesc(v.GetNestedMessageTypes()[index], msgIndexes[1:]) 472 default: 473 return nil, fmt.Errorf("unexpected type") 474 } 475 } 476 477 func (s *Deserializer) protoMessageFactory(subject string, name string) (interface{}, error) { 478 mt, err := s.ProtoRegistry.FindMessageByName(protoreflect.FullName(name)) 479 if mt == nil { 480 err = fmt.Errorf("unable to find MessageType %s", name) 481 } 482 if err != nil { 483 return nil, err 484 } 485 msg := mt.New() 486 return msg.Interface(), nil 487 }