github.com/aacfactory/avro@v1.2.12/internal/base/protocol.go (about) 1 package base 2 3 import ( 4 "crypto/md5" 5 "encoding/hex" 6 "errors" 7 "fmt" 8 "os" 9 10 jsoniter "github.com/json-iterator/go" 11 "github.com/mitchellh/mapstructure" 12 ) 13 14 var ( 15 protocolReserved = []string{"doc", "types", "messages", "protocol", "namespace"} 16 messageReserved = []string{"doc", "response", "request", "errors", "one-way"} 17 ) 18 19 type protocolConfig struct { 20 doc string 21 props map[string]any 22 } 23 24 // ProtocolOption is a function that sets a protocol option. 25 type ProtocolOption func(*protocolConfig) 26 27 // WithProtoDoc sets the doc on a protocol. 28 func WithProtoDoc(doc string) ProtocolOption { 29 return func(opts *protocolConfig) { 30 opts.doc = doc 31 } 32 } 33 34 // WithProtoProps sets the properties on a protocol. 35 func WithProtoProps(props map[string]any) ProtocolOption { 36 return func(opts *protocolConfig) { 37 opts.props = props 38 } 39 } 40 41 // Protocol is an Avro protocol. 42 type Protocol struct { 43 name 44 properties 45 46 types []NamedSchema 47 messages map[string]*Message 48 49 doc string 50 51 hash string 52 } 53 54 // NewProtocol creates a protocol instance. 55 func NewProtocol( 56 name, namepsace string, 57 types []NamedSchema, 58 messages map[string]*Message, 59 opts ...ProtocolOption, 60 ) (*Protocol, error) { 61 var cfg protocolConfig 62 for _, opt := range opts { 63 opt(&cfg) 64 } 65 66 n, err := newName(name, namepsace, nil) 67 if err != nil { 68 return nil, err 69 } 70 71 p := &Protocol{ 72 name: n, 73 properties: newProperties(cfg.props, protocolReserved), 74 types: types, 75 messages: messages, 76 doc: cfg.doc, 77 } 78 79 b := md5.Sum([]byte(p.String())) 80 p.hash = hex.EncodeToString(b[:]) 81 82 return p, nil 83 } 84 85 // Message returns a message with the given name or nil. 86 func (p *Protocol) Message(name string) *Message { 87 return p.messages[name] 88 } 89 90 // Doc returns the protocol doc. 91 func (p *Protocol) Doc() string { 92 return p.doc 93 } 94 95 // Hash returns the MD5 hash of the protocol. 96 func (p *Protocol) Hash() string { 97 return p.hash 98 } 99 100 // Types returns the types of the protocol. 101 func (p *Protocol) Types() []NamedSchema { 102 return p.types 103 } 104 105 // String returns the canonical form of the protocol. 106 func (p *Protocol) String() string { 107 types := "" 108 for _, f := range p.types { 109 types += f.String() + "," 110 } 111 if len(types) > 0 { 112 types = types[:len(types)-1] 113 } 114 115 messages := "" 116 for k, m := range p.messages { 117 messages += `"` + k + `":` + m.String() + "," 118 } 119 if len(messages) > 0 { 120 messages = messages[:len(messages)-1] 121 } 122 123 return `{"protocol":"` + p.Name() + 124 `","namespace":"` + p.Namespace() + 125 `","types":[` + types + `],"messages":{` + messages + `}}` 126 } 127 128 // Message is an Avro protocol message. 129 type Message struct { 130 properties 131 132 req *RecordSchema 133 resp Schema 134 errs *UnionSchema 135 oneWay bool 136 137 doc string 138 } 139 140 // NewMessage creates a protocol message instance. 141 func NewMessage(req *RecordSchema, resp Schema, errors *UnionSchema, oneWay bool, opts ...ProtocolOption) *Message { 142 var cfg protocolConfig 143 for _, opt := range opts { 144 opt(&cfg) 145 } 146 147 return &Message{ 148 properties: newProperties(cfg.props, messageReserved), 149 req: req, 150 resp: resp, 151 errs: errors, 152 oneWay: oneWay, 153 doc: cfg.doc, 154 } 155 } 156 157 // Request returns the message request schema. 158 func (m *Message) Request() *RecordSchema { 159 return m.req 160 } 161 162 // Response returns the message response schema. 163 func (m *Message) Response() Schema { 164 return m.resp 165 } 166 167 // Errors returns the message errors union schema. 168 func (m *Message) Errors() *UnionSchema { 169 return m.errs 170 } 171 172 // OneWay determines of the message is a one way message. 173 func (m *Message) OneWay() bool { 174 return m.oneWay 175 } 176 177 // Doc returns the message doc. 178 func (m *Message) Doc() string { 179 return m.doc 180 } 181 182 // String returns the canonical form of the message. 183 func (m *Message) String() string { 184 fields := "" 185 for _, f := range m.req.fields { 186 fields += f.String() + "," 187 } 188 if len(fields) > 0 { 189 fields = fields[:len(fields)-1] 190 } 191 192 str := `{"request":[` + fields + `]` 193 if m.resp != nil { 194 str += `,"response":` + m.resp.String() 195 } 196 if m.errs != nil && len(m.errs.Types()) > 1 { 197 errs, _ := NewUnionSchema(m.errs.Types()[1:]) 198 str += `,"errors":` + errs.String() 199 } 200 str += "}" 201 return str 202 } 203 204 // ParseProtocolFile parses an Avro protocol from a file. 205 func ParseProtocolFile(path string) (*Protocol, error) { 206 s, err := os.ReadFile(path) 207 if err != nil { 208 return nil, err 209 } 210 211 return ParseProtocol(string(s)) 212 } 213 214 // MustParseProtocol parses an Avro protocol, panicing if there is an error. 215 func MustParseProtocol(protocol string) *Protocol { 216 parsed, err := ParseProtocol(protocol) 217 if err != nil { 218 panic(err) 219 } 220 221 return parsed 222 } 223 224 // ParseProtocol parses an Avro protocol. 225 func ParseProtocol(protocol string) (*Protocol, error) { 226 cache := &SchemaCache{} 227 228 var m map[string]any 229 if err := jsoniter.Unmarshal([]byte(protocol), &m); err != nil { 230 return nil, err 231 } 232 233 seen := seenCache{} 234 return parseProtocol(m, seen, cache) 235 } 236 237 type protocol struct { 238 Protocol string `mapstructure:"protocol"` 239 Namespace string `mapstructure:"namespace"` 240 Doc string `mapstructure:"doc"` 241 Types []any `mapstructure:"types"` 242 Messages map[string]map[string]any `mapstructure:"messages"` 243 Props map[string]any `mapstructure:",remain"` 244 } 245 246 func parseProtocol(m map[string]any, seen seenCache, cache *SchemaCache) (*Protocol, error) { 247 var ( 248 p protocol 249 meta mapstructure.Metadata 250 ) 251 if err := decodeMap(m, &p, &meta); err != nil { 252 return nil, fmt.Errorf("avro: error decoding protocol: %w", err) 253 } 254 255 if err := checkParsedName(p.Protocol, p.Namespace, hasKey(meta.Keys, "namespace")); err != nil { 256 return nil, err 257 } 258 259 var ( 260 types []NamedSchema 261 err error 262 ) 263 if len(p.Types) > 0 { 264 types, err = parseProtocolTypes(p.Namespace, p.Types, seen, cache) 265 if err != nil { 266 return nil, err 267 } 268 } 269 270 messages := map[string]*Message{} 271 if len(p.Messages) > 0 { 272 for k, msg := range p.Messages { 273 message, err := parseMessage(p.Namespace, msg, seen, cache) 274 if err != nil { 275 return nil, err 276 } 277 278 messages[k] = message 279 } 280 } 281 282 return NewProtocol(p.Protocol, p.Namespace, types, messages, WithProtoDoc(p.Doc), WithProtoProps(p.Props)) 283 } 284 285 func parseProtocolTypes(namespace string, types []any, seen seenCache, cache *SchemaCache) ([]NamedSchema, error) { 286 ts := make([]NamedSchema, len(types)) 287 for i, typ := range types { 288 schema, err := parseType(namespace, typ, seen, cache) 289 if err != nil { 290 return nil, err 291 } 292 293 namedSchema, ok := schema.(NamedSchema) 294 if !ok { 295 return nil, errors.New("avro: protocol types must be named schemas") 296 } 297 298 ts[i] = namedSchema 299 } 300 301 return ts, nil 302 } 303 304 type message struct { 305 Doc string `mapstructure:"doc"` 306 Request []map[string]any `mapstructure:"request"` 307 Response any `mapstructure:"response"` 308 Errors []any `mapstructure:"errors"` 309 OneWay bool `mapstructure:"one-way"` 310 Props map[string]any `mapstructure:",remain"` 311 } 312 313 func parseMessage(namespace string, m map[string]any, seen seenCache, cache *SchemaCache) (*Message, error) { 314 var ( 315 msg message 316 meta mapstructure.Metadata 317 ) 318 if err := decodeMap(m, &msg, &meta); err != nil { 319 return nil, fmt.Errorf("avro: error decoding message: %w", err) 320 } 321 322 fields := make([]*Field, len(msg.Request)) 323 for i, f := range msg.Request { 324 field, err := parseField(namespace, f, seen, cache) 325 if err != nil { 326 return nil, err 327 } 328 fields[i] = field 329 } 330 request := &RecordSchema{ 331 name: name{}, 332 properties: properties{}, 333 fields: fields, 334 } 335 336 var response Schema 337 if msg.Response != nil { 338 schema, err := parseType(namespace, msg.Response, seen, cache) 339 if err != nil { 340 return nil, err 341 } 342 343 if schema.Type() != Null { 344 response = schema 345 } 346 } 347 348 types := []Schema{NewPrimitiveSchema(String, nil)} 349 if len(msg.Errors) > 0 { 350 for _, e := range msg.Errors { 351 schema, err := parseType(namespace, e, seen, cache) 352 if err != nil { 353 return nil, err 354 } 355 356 if rec, ok := schema.(*RecordSchema); ok && !rec.IsError() { 357 return nil, errors.New("avro: errors record schema must be of type error") 358 } 359 360 types = append(types, schema) 361 } 362 } 363 errs, err := NewUnionSchema(types) 364 if err != nil { 365 return nil, err 366 } 367 368 oneWay := msg.OneWay 369 if hasKey(meta.Keys, "one-way") && oneWay && (len(errs.Types()) > 1 || response != nil) { 370 return nil, errors.New("avro: one-way messages cannot not have a response or errors") 371 } 372 if !oneWay && len(errs.Types()) <= 1 && response == nil { 373 oneWay = true 374 } 375 376 return NewMessage(request, response, errs, oneWay, WithProtoDoc(msg.Doc), WithProtoProps(msg.Props)), nil 377 }