github.com/kamalshkeir/kencoding@v0.0.2-0.20230409043843-44b609a0475a/proto/rewrite.go (about) 1 package proto 2 3 import ( 4 "fmt" 5 6 "github.com/kamalshkeir/kencoding/json" 7 ) 8 9 // Rewriter is an interface implemented by types that support rewriting protobuf 10 // messages. 11 type Rewriter interface { 12 // The function is expected to append the new content to the byte slice 13 // passed as argument. If it wasn't able to perform the rewrite, it must 14 // return a non-nil error. 15 Rewrite(out, in []byte) ([]byte, error) 16 } 17 18 type identity struct{} 19 20 func (identity) Rewrite(out, in []byte) ([]byte, error) { 21 return append(out, in...), nil 22 } 23 24 // MultiRewriter constructs a Rewriter which applies all rewriters passed as 25 // arguments. 26 func MultiRewriter(rewriters ...Rewriter) Rewriter { 27 if len(rewriters) == 1 { 28 return rewriters[0] 29 } 30 m := &multiRewriter{rewriters: make([]Rewriter, len(rewriters))} 31 copy(m.rewriters, rewriters) 32 return m 33 } 34 35 type multiRewriter struct { 36 rewriters []Rewriter 37 } 38 39 func (m *multiRewriter) Rewrite(out, in []byte) ([]byte, error) { 40 var err error 41 42 for _, rw := range m.rewriters { 43 if out, err = rw.Rewrite(out, in); err != nil { 44 return out, err 45 } 46 } 47 48 return out, nil 49 } 50 51 // RewriteFunc is a function type implementing the Rewriter interface. 52 type RewriteFunc func([]byte, []byte) ([]byte, error) 53 54 // Rewrite satisfies the Rewriter interface. 55 func (r RewriteFunc) Rewrite(out, in []byte) ([]byte, error) { 56 return r(out, in) 57 } 58 59 // MessageRewriter maps field numbers to rewrite rules, satisfying the Rewriter 60 // interace to support composing rewrite rules. 61 type MessageRewriter []Rewriter 62 63 // Rewrite applies the rewrite rule matching f in r, satisfies the Rewriter 64 // interface. 65 func (r MessageRewriter) Rewrite(out, in []byte) ([]byte, error) { 66 seen := make(fieldset, 4) 67 68 if n := seen.len(); len(r) >= n { 69 seen = makeFieldset(len(r) + 1) 70 } 71 72 for len(in) != 0 { 73 f, t, v, m, err := Parse(in) 74 if err != nil { 75 return out, err 76 } 77 78 if i := int(f); i >= 0 && i < len(r) && r[i] != nil { 79 if !seen.has(i) { 80 seen.set(i) 81 if out, err = r[i].Rewrite(out, v); err != nil { 82 return out, err 83 } 84 } 85 } else { 86 out = Append(out, f, t, v) 87 } 88 89 in = m 90 } 91 92 for i, f := range r { 93 if f != nil && !seen.has(i) { 94 b, err := r[i].Rewrite(out, nil) 95 if err != nil { 96 return b, err 97 } 98 out = b 99 } 100 } 101 102 return out, nil 103 } 104 105 type fieldset []uint64 106 107 func makeFieldset(n int) fieldset { 108 if (n % 64) != 0 { 109 n = (n + 1) / 64 110 } else { 111 n /= 64 112 } 113 return make(fieldset, n) 114 } 115 116 func (f fieldset) len() int { 117 return len(f) * 64 118 } 119 120 func (f fieldset) has(i int) bool { 121 x, y := f.index(i) 122 return ((f[x] >> y) & 1) != 0 123 } 124 125 func (f fieldset) set(i int) { 126 x, y := f.index(i) 127 f[x] |= 1 << y 128 } 129 130 func (f fieldset) unset(i int) { 131 x, y := f.index(i) 132 f[x] &= ^(1 << y) 133 } 134 135 func (f fieldset) index(i int) (int, int) { 136 return i / 64, i % 64 137 } 138 139 // ParseRewriteTemplate constructs a Rewriter for a protobuf type using the 140 // given json template to describe the rewrite rules. 141 // 142 // The json template contains a representation of the 143 func ParseRewriteTemplate(typ Type, jsonTemplate []byte) (Rewriter, error) { 144 switch typ.Kind() { 145 case Struct: 146 return parseRewriteTemplateStruct(typ, 0, jsonTemplate) 147 default: 148 return nil, fmt.Errorf("cannot construct a rewrite template from a non-struct type %s", typ.Name()) 149 } 150 } 151 152 func parseRewriteTemplate(t Type, f FieldNumber, j json.RawMessage) (Rewriter, error) { 153 switch t.Kind() { 154 case Bool: 155 return parseRewriteTemplateBool(t, f, j) 156 case Int32: 157 return parseRewriteTemplateInt32(t, f, j) 158 case Int64: 159 return parseRewriteTemplateInt64(t, f, j) 160 case Sint32: 161 return parseRewriteTemplateSint32(t, f, j) 162 case Sint64: 163 return parseRewriteTemplateSint64(t, f, j) 164 case Uint32: 165 return parseRewriteTemplateUint64(t, f, j) 166 case Uint64: 167 return parseRewriteTemplateUint64(t, f, j) 168 case Fix32: 169 return parseRewriteTemplateFix32(t, f, j) 170 case Fix64: 171 return parseRewriteTemplateFix64(t, f, j) 172 case Sfix32: 173 return parseRewriteTemplateSfix32(t, f, j) 174 case Sfix64: 175 return parseRewriteTemplateSfix64(t, f, j) 176 case Float: 177 return parseRewriteTemplateFloat(t, f, j) 178 case Double: 179 return parseRewriteTemplateDouble(t, f, j) 180 case String: 181 return parseRewriteTemplateString(t, f, j) 182 case Bytes: 183 return parseRewriteTemplateBytes(t, f, j) 184 case Map: 185 return parseRewriteTemplateMap(t, f, j) 186 case Struct: 187 return parseRewriteTemplateStruct(t, f, j) 188 default: 189 return nil, fmt.Errorf("cannot construct a rewriter from type %s", t.Name()) 190 } 191 } 192 193 func parseRewriteTemplateBool(t Type, f FieldNumber, j json.RawMessage) (Rewriter, error) { 194 var v bool 195 err := json.Unmarshal(j, &v) 196 if !v || err != nil { 197 return nil, err 198 } 199 return f.Bool(v), nil 200 } 201 202 func parseRewriteTemplateInt32(t Type, f FieldNumber, j json.RawMessage) (Rewriter, error) { 203 var v int32 204 err := json.Unmarshal(j, &v) 205 if v == 0 || err != nil { 206 return nil, err 207 } 208 return f.Int32(v), nil 209 } 210 211 func parseRewriteTemplateInt64(t Type, f FieldNumber, j json.RawMessage) (Rewriter, error) { 212 var v int64 213 err := json.Unmarshal(j, &v) 214 if v == 0 || err != nil { 215 return nil, err 216 } 217 return f.Int64(v), nil 218 } 219 220 func parseRewriteTemplateSint32(t Type, f FieldNumber, j json.RawMessage) (Rewriter, error) { 221 var v int32 222 err := json.Unmarshal(j, &v) 223 if v == 0 || err != nil { 224 return nil, err 225 } 226 return f.Uint32(encodeZigZag32(v)), nil 227 } 228 229 func parseRewriteTemplateSint64(t Type, f FieldNumber, j json.RawMessage) (Rewriter, error) { 230 var v int64 231 err := json.Unmarshal(j, &v) 232 if v == 0 || err != nil { 233 return nil, err 234 } 235 return f.Uint64(encodeZigZag64(v)), nil 236 } 237 238 func parseRewriteTemplateUint32(t Type, f FieldNumber, j json.RawMessage) (Rewriter, error) { 239 var v uint32 240 err := json.Unmarshal(j, &v) 241 if v == 0 || err != nil { 242 return nil, err 243 } 244 return f.Uint32(v), nil 245 } 246 247 func parseRewriteTemplateUint64(t Type, f FieldNumber, j json.RawMessage) (Rewriter, error) { 248 var v uint64 249 err := json.Unmarshal(j, &v) 250 if v == 0 || err != nil { 251 return nil, err 252 } 253 return f.Uint64(v), nil 254 } 255 256 func parseRewriteTemplateFix32(t Type, f FieldNumber, j json.RawMessage) (Rewriter, error) { 257 var v uint32 258 err := json.Unmarshal(j, &v) 259 if v == 0 || err != nil { 260 return nil, err 261 } 262 return f.Fixed32(v), nil 263 } 264 265 func parseRewriteTemplateFix64(t Type, f FieldNumber, j json.RawMessage) (Rewriter, error) { 266 var v uint64 267 err := json.Unmarshal(j, &v) 268 if v == 0 || err != nil { 269 return nil, err 270 } 271 return f.Fixed64(v), nil 272 } 273 274 func parseRewriteTemplateSfix32(t Type, f FieldNumber, j json.RawMessage) (Rewriter, error) { 275 var v int32 276 err := json.Unmarshal(j, &v) 277 if v == 0 || err != nil { 278 return nil, err 279 } 280 return f.Fixed32(encodeZigZag32(v)), nil 281 } 282 283 func parseRewriteTemplateSfix64(t Type, f FieldNumber, j json.RawMessage) (Rewriter, error) { 284 var v int64 285 err := json.Unmarshal(j, &v) 286 if v == 0 || err != nil { 287 return nil, err 288 } 289 return f.Fixed64(encodeZigZag64(v)), nil 290 } 291 292 func parseRewriteTemplateFloat(t Type, f FieldNumber, j json.RawMessage) (Rewriter, error) { 293 var v float32 294 err := json.Unmarshal(j, &v) 295 if v == 0 || err != nil { 296 return nil, err 297 } 298 return f.Float32(v), nil 299 } 300 301 func parseRewriteTemplateDouble(t Type, f FieldNumber, j json.RawMessage) (Rewriter, error) { 302 var v float64 303 err := json.Unmarshal(j, &v) 304 if v == 0 || err != nil { 305 return nil, err 306 } 307 return f.Float64(v), nil 308 } 309 310 func parseRewriteTemplateString(t Type, f FieldNumber, j json.RawMessage) (Rewriter, error) { 311 var v string 312 err := json.Unmarshal(j, &v) 313 if v == "" || err != nil { 314 return nil, err 315 } 316 return f.String(v), nil 317 } 318 319 func parseRewriteTemplateBytes(t Type, f FieldNumber, j json.RawMessage) (Rewriter, error) { 320 var v string 321 err := json.Unmarshal(j, &v) 322 if v == "" || err != nil { 323 return nil, err 324 } 325 return f.Bytes([]byte(v)), nil 326 } 327 328 func parseRewriteTemplateMap(t Type, f FieldNumber, j json.RawMessage) (Rewriter, error) { 329 st := &structType{ 330 name: t.Name(), 331 fields: []Field{ 332 {Index: 0, Number: 1, Name: "key", Type: t.Key()}, 333 {Index: 1, Number: 2, Name: "value", Type: t.Elem()}, 334 }, 335 fieldsByName: make(map[string]int), 336 fieldsByNumber: make(map[FieldNumber]int), 337 } 338 339 for _, f := range st.fields { 340 st.fieldsByName[f.Name] = f.Index 341 st.fieldsByNumber[f.Number] = f.Index 342 } 343 344 template := map[string]json.RawMessage{} 345 346 if err := json.Unmarshal(j, &template); err != nil { 347 return nil, err 348 } 349 350 maplist := make([]json.RawMessage, 0, len(template)) 351 352 for key, value := range template { 353 b, err := json.Marshal(struct { 354 Key string `json:"key"` 355 Value json.RawMessage `json:"value"` 356 }{ 357 Key: key, 358 Value: value, 359 }) 360 if err != nil { 361 return nil, err 362 } 363 maplist = append(maplist, b) 364 } 365 366 rewriters := make([]Rewriter, len(maplist)) 367 368 for i, b := range maplist { 369 r, err := parseRewriteTemplateStruct(st, f, b) 370 if err != nil { 371 return nil, err 372 } 373 rewriters[i] = r 374 } 375 376 return MultiRewriter(rewriters...), nil 377 } 378 379 func parseRewriteTemplateStruct(t Type, f FieldNumber, j json.RawMessage) (Rewriter, error) { 380 template := map[string]json.RawMessage{} 381 382 if err := json.Unmarshal(j, &template); err != nil { 383 return nil, err 384 } 385 386 fieldsByName := map[string]Field{} 387 388 for i, n := 0, t.NumField(); i < n; i++ { 389 f := t.Field(i) 390 fieldsByName[f.Name] = f 391 } 392 393 message := MessageRewriter{} 394 rewriters := []Rewriter{} 395 396 for k, v := range template { 397 f, ok := fieldsByName[k] 398 if !ok { 399 return nil, fmt.Errorf("rewrite template contained an invalid field named %q", k) 400 } 401 402 var fields []json.RawMessage 403 if f.Repeated { 404 if err := json.Unmarshal(v, &fields); err != nil { 405 return nil, err 406 } 407 } else { 408 fields = []json.RawMessage{v} 409 } 410 411 rewriters = rewriters[:0] 412 413 for _, v := range fields { 414 rw, err := parseRewriteTemplate(f.Type, f.Number, v) 415 if err != nil { 416 return nil, fmt.Errorf("%s: %w", k, err) 417 } 418 if rw != nil { 419 rewriters = append(rewriters, rw) 420 } 421 } 422 423 if cap(message) <= int(f.Number) { 424 m := make(MessageRewriter, f.Number+1) 425 copy(m, message) 426 message = m 427 } 428 429 message[f.Number] = MultiRewriter(rewriters...) 430 } 431 432 if f != 0 { 433 return &embddedRewriter{number: f, message: message}, nil 434 } 435 436 return message, nil 437 } 438 439 type embddedRewriter struct { 440 number FieldNumber 441 message MessageRewriter 442 } 443 444 func (f *embddedRewriter) Rewrite(out, in []byte) ([]byte, error) { 445 prefix := len(out) 446 447 out, err := f.message.Rewrite(out, in) 448 if err != nil { 449 return nil, err 450 } 451 if len(out) == prefix { 452 return out, nil 453 } 454 455 b := [24]byte{} 456 n1, _ := encodeVarint(b[:], EncodeTag(f.number, Varlen)) 457 n2, _ := encodeVarint(b[n1:], uint64(len(out)-prefix)) 458 tagAndLen := n1 + n2 459 460 out = append(out, b[:tagAndLen]...) 461 copy(out[prefix+tagAndLen:], out[prefix:]) 462 copy(out[prefix:], b[:tagAndLen]) 463 return out, nil 464 }