github.com/kamalshkeir/kencoding@v0.0.2-0.20230409043843-44b609a0475a/proto/rewrite.go (about)

     1  package proto
     2  
     3  import (
     4  	"fmt"
     5  
     6  	"github.com/kamalshkeir/kencoding/json"
     7  )
     8  
     9  // Rewriter is an interface implemented by types that support rewriting protobuf
    10  // messages.
    11  type Rewriter interface {
    12  	// The function is expected to append the new content to the byte slice
    13  	// passed as argument. If it wasn't able to perform the rewrite, it must
    14  	// return a non-nil error.
    15  	Rewrite(out, in []byte) ([]byte, error)
    16  }
    17  
    18  type identity struct{}
    19  
    20  func (identity) Rewrite(out, in []byte) ([]byte, error) {
    21  	return append(out, in...), nil
    22  }
    23  
    24  // MultiRewriter constructs a Rewriter which applies all rewriters passed as
    25  // arguments.
    26  func MultiRewriter(rewriters ...Rewriter) Rewriter {
    27  	if len(rewriters) == 1 {
    28  		return rewriters[0]
    29  	}
    30  	m := &multiRewriter{rewriters: make([]Rewriter, len(rewriters))}
    31  	copy(m.rewriters, rewriters)
    32  	return m
    33  }
    34  
    35  type multiRewriter struct {
    36  	rewriters []Rewriter
    37  }
    38  
    39  func (m *multiRewriter) Rewrite(out, in []byte) ([]byte, error) {
    40  	var err error
    41  
    42  	for _, rw := range m.rewriters {
    43  		if out, err = rw.Rewrite(out, in); err != nil {
    44  			return out, err
    45  		}
    46  	}
    47  
    48  	return out, nil
    49  }
    50  
    51  // RewriteFunc is a function type implementing the Rewriter interface.
    52  type RewriteFunc func([]byte, []byte) ([]byte, error)
    53  
    54  // Rewrite satisfies the Rewriter interface.
    55  func (r RewriteFunc) Rewrite(out, in []byte) ([]byte, error) {
    56  	return r(out, in)
    57  }
    58  
    59  // MessageRewriter maps field numbers to rewrite rules, satisfying the Rewriter
    60  // interace to support composing rewrite rules.
    61  type MessageRewriter []Rewriter
    62  
    63  // Rewrite applies the rewrite rule matching f in r, satisfies the Rewriter
    64  // interface.
    65  func (r MessageRewriter) Rewrite(out, in []byte) ([]byte, error) {
    66  	seen := make(fieldset, 4)
    67  
    68  	if n := seen.len(); len(r) >= n {
    69  		seen = makeFieldset(len(r) + 1)
    70  	}
    71  
    72  	for len(in) != 0 {
    73  		f, t, v, m, err := Parse(in)
    74  		if err != nil {
    75  			return out, err
    76  		}
    77  
    78  		if i := int(f); i >= 0 && i < len(r) && r[i] != nil {
    79  			if !seen.has(i) {
    80  				seen.set(i)
    81  				if out, err = r[i].Rewrite(out, v); err != nil {
    82  					return out, err
    83  				}
    84  			}
    85  		} else {
    86  			out = Append(out, f, t, v)
    87  		}
    88  
    89  		in = m
    90  	}
    91  
    92  	for i, f := range r {
    93  		if f != nil && !seen.has(i) {
    94  			b, err := r[i].Rewrite(out, nil)
    95  			if err != nil {
    96  				return b, err
    97  			}
    98  			out = b
    99  		}
   100  	}
   101  
   102  	return out, nil
   103  }
   104  
   105  type fieldset []uint64
   106  
   107  func makeFieldset(n int) fieldset {
   108  	if (n % 64) != 0 {
   109  		n = (n + 1) / 64
   110  	} else {
   111  		n /= 64
   112  	}
   113  	return make(fieldset, n)
   114  }
   115  
   116  func (f fieldset) len() int {
   117  	return len(f) * 64
   118  }
   119  
   120  func (f fieldset) has(i int) bool {
   121  	x, y := f.index(i)
   122  	return ((f[x] >> y) & 1) != 0
   123  }
   124  
   125  func (f fieldset) set(i int) {
   126  	x, y := f.index(i)
   127  	f[x] |= 1 << y
   128  }
   129  
   130  func (f fieldset) unset(i int) {
   131  	x, y := f.index(i)
   132  	f[x] &= ^(1 << y)
   133  }
   134  
   135  func (f fieldset) index(i int) (int, int) {
   136  	return i / 64, i % 64
   137  }
   138  
   139  // ParseRewriteTemplate constructs a Rewriter for a protobuf type using the
   140  // given json template to describe the rewrite rules.
   141  //
   142  // The json template contains a representation of the
   143  func ParseRewriteTemplate(typ Type, jsonTemplate []byte) (Rewriter, error) {
   144  	switch typ.Kind() {
   145  	case Struct:
   146  		return parseRewriteTemplateStruct(typ, 0, jsonTemplate)
   147  	default:
   148  		return nil, fmt.Errorf("cannot construct a rewrite template from a non-struct type %s", typ.Name())
   149  	}
   150  }
   151  
   152  func parseRewriteTemplate(t Type, f FieldNumber, j json.RawMessage) (Rewriter, error) {
   153  	switch t.Kind() {
   154  	case Bool:
   155  		return parseRewriteTemplateBool(t, f, j)
   156  	case Int32:
   157  		return parseRewriteTemplateInt32(t, f, j)
   158  	case Int64:
   159  		return parseRewriteTemplateInt64(t, f, j)
   160  	case Sint32:
   161  		return parseRewriteTemplateSint32(t, f, j)
   162  	case Sint64:
   163  		return parseRewriteTemplateSint64(t, f, j)
   164  	case Uint32:
   165  		return parseRewriteTemplateUint64(t, f, j)
   166  	case Uint64:
   167  		return parseRewriteTemplateUint64(t, f, j)
   168  	case Fix32:
   169  		return parseRewriteTemplateFix32(t, f, j)
   170  	case Fix64:
   171  		return parseRewriteTemplateFix64(t, f, j)
   172  	case Sfix32:
   173  		return parseRewriteTemplateSfix32(t, f, j)
   174  	case Sfix64:
   175  		return parseRewriteTemplateSfix64(t, f, j)
   176  	case Float:
   177  		return parseRewriteTemplateFloat(t, f, j)
   178  	case Double:
   179  		return parseRewriteTemplateDouble(t, f, j)
   180  	case String:
   181  		return parseRewriteTemplateString(t, f, j)
   182  	case Bytes:
   183  		return parseRewriteTemplateBytes(t, f, j)
   184  	case Map:
   185  		return parseRewriteTemplateMap(t, f, j)
   186  	case Struct:
   187  		return parseRewriteTemplateStruct(t, f, j)
   188  	default:
   189  		return nil, fmt.Errorf("cannot construct a rewriter from type %s", t.Name())
   190  	}
   191  }
   192  
   193  func parseRewriteTemplateBool(t Type, f FieldNumber, j json.RawMessage) (Rewriter, error) {
   194  	var v bool
   195  	err := json.Unmarshal(j, &v)
   196  	if !v || err != nil {
   197  		return nil, err
   198  	}
   199  	return f.Bool(v), nil
   200  }
   201  
   202  func parseRewriteTemplateInt32(t Type, f FieldNumber, j json.RawMessage) (Rewriter, error) {
   203  	var v int32
   204  	err := json.Unmarshal(j, &v)
   205  	if v == 0 || err != nil {
   206  		return nil, err
   207  	}
   208  	return f.Int32(v), nil
   209  }
   210  
   211  func parseRewriteTemplateInt64(t Type, f FieldNumber, j json.RawMessage) (Rewriter, error) {
   212  	var v int64
   213  	err := json.Unmarshal(j, &v)
   214  	if v == 0 || err != nil {
   215  		return nil, err
   216  	}
   217  	return f.Int64(v), nil
   218  }
   219  
   220  func parseRewriteTemplateSint32(t Type, f FieldNumber, j json.RawMessage) (Rewriter, error) {
   221  	var v int32
   222  	err := json.Unmarshal(j, &v)
   223  	if v == 0 || err != nil {
   224  		return nil, err
   225  	}
   226  	return f.Uint32(encodeZigZag32(v)), nil
   227  }
   228  
   229  func parseRewriteTemplateSint64(t Type, f FieldNumber, j json.RawMessage) (Rewriter, error) {
   230  	var v int64
   231  	err := json.Unmarshal(j, &v)
   232  	if v == 0 || err != nil {
   233  		return nil, err
   234  	}
   235  	return f.Uint64(encodeZigZag64(v)), nil
   236  }
   237  
   238  func parseRewriteTemplateUint32(t Type, f FieldNumber, j json.RawMessage) (Rewriter, error) {
   239  	var v uint32
   240  	err := json.Unmarshal(j, &v)
   241  	if v == 0 || err != nil {
   242  		return nil, err
   243  	}
   244  	return f.Uint32(v), nil
   245  }
   246  
   247  func parseRewriteTemplateUint64(t Type, f FieldNumber, j json.RawMessage) (Rewriter, error) {
   248  	var v uint64
   249  	err := json.Unmarshal(j, &v)
   250  	if v == 0 || err != nil {
   251  		return nil, err
   252  	}
   253  	return f.Uint64(v), nil
   254  }
   255  
   256  func parseRewriteTemplateFix32(t Type, f FieldNumber, j json.RawMessage) (Rewriter, error) {
   257  	var v uint32
   258  	err := json.Unmarshal(j, &v)
   259  	if v == 0 || err != nil {
   260  		return nil, err
   261  	}
   262  	return f.Fixed32(v), nil
   263  }
   264  
   265  func parseRewriteTemplateFix64(t Type, f FieldNumber, j json.RawMessage) (Rewriter, error) {
   266  	var v uint64
   267  	err := json.Unmarshal(j, &v)
   268  	if v == 0 || err != nil {
   269  		return nil, err
   270  	}
   271  	return f.Fixed64(v), nil
   272  }
   273  
   274  func parseRewriteTemplateSfix32(t Type, f FieldNumber, j json.RawMessage) (Rewriter, error) {
   275  	var v int32
   276  	err := json.Unmarshal(j, &v)
   277  	if v == 0 || err != nil {
   278  		return nil, err
   279  	}
   280  	return f.Fixed32(encodeZigZag32(v)), nil
   281  }
   282  
   283  func parseRewriteTemplateSfix64(t Type, f FieldNumber, j json.RawMessage) (Rewriter, error) {
   284  	var v int64
   285  	err := json.Unmarshal(j, &v)
   286  	if v == 0 || err != nil {
   287  		return nil, err
   288  	}
   289  	return f.Fixed64(encodeZigZag64(v)), nil
   290  }
   291  
   292  func parseRewriteTemplateFloat(t Type, f FieldNumber, j json.RawMessage) (Rewriter, error) {
   293  	var v float32
   294  	err := json.Unmarshal(j, &v)
   295  	if v == 0 || err != nil {
   296  		return nil, err
   297  	}
   298  	return f.Float32(v), nil
   299  }
   300  
   301  func parseRewriteTemplateDouble(t Type, f FieldNumber, j json.RawMessage) (Rewriter, error) {
   302  	var v float64
   303  	err := json.Unmarshal(j, &v)
   304  	if v == 0 || err != nil {
   305  		return nil, err
   306  	}
   307  	return f.Float64(v), nil
   308  }
   309  
   310  func parseRewriteTemplateString(t Type, f FieldNumber, j json.RawMessage) (Rewriter, error) {
   311  	var v string
   312  	err := json.Unmarshal(j, &v)
   313  	if v == "" || err != nil {
   314  		return nil, err
   315  	}
   316  	return f.String(v), nil
   317  }
   318  
   319  func parseRewriteTemplateBytes(t Type, f FieldNumber, j json.RawMessage) (Rewriter, error) {
   320  	var v string
   321  	err := json.Unmarshal(j, &v)
   322  	if v == "" || err != nil {
   323  		return nil, err
   324  	}
   325  	return f.Bytes([]byte(v)), nil
   326  }
   327  
   328  func parseRewriteTemplateMap(t Type, f FieldNumber, j json.RawMessage) (Rewriter, error) {
   329  	st := &structType{
   330  		name: t.Name(),
   331  		fields: []Field{
   332  			{Index: 0, Number: 1, Name: "key", Type: t.Key()},
   333  			{Index: 1, Number: 2, Name: "value", Type: t.Elem()},
   334  		},
   335  		fieldsByName:   make(map[string]int),
   336  		fieldsByNumber: make(map[FieldNumber]int),
   337  	}
   338  
   339  	for _, f := range st.fields {
   340  		st.fieldsByName[f.Name] = f.Index
   341  		st.fieldsByNumber[f.Number] = f.Index
   342  	}
   343  
   344  	template := map[string]json.RawMessage{}
   345  
   346  	if err := json.Unmarshal(j, &template); err != nil {
   347  		return nil, err
   348  	}
   349  
   350  	maplist := make([]json.RawMessage, 0, len(template))
   351  
   352  	for key, value := range template {
   353  		b, err := json.Marshal(struct {
   354  			Key   string          `json:"key"`
   355  			Value json.RawMessage `json:"value"`
   356  		}{
   357  			Key:   key,
   358  			Value: value,
   359  		})
   360  		if err != nil {
   361  			return nil, err
   362  		}
   363  		maplist = append(maplist, b)
   364  	}
   365  
   366  	rewriters := make([]Rewriter, len(maplist))
   367  
   368  	for i, b := range maplist {
   369  		r, err := parseRewriteTemplateStruct(st, f, b)
   370  		if err != nil {
   371  			return nil, err
   372  		}
   373  		rewriters[i] = r
   374  	}
   375  
   376  	return MultiRewriter(rewriters...), nil
   377  }
   378  
   379  func parseRewriteTemplateStruct(t Type, f FieldNumber, j json.RawMessage) (Rewriter, error) {
   380  	template := map[string]json.RawMessage{}
   381  
   382  	if err := json.Unmarshal(j, &template); err != nil {
   383  		return nil, err
   384  	}
   385  
   386  	fieldsByName := map[string]Field{}
   387  
   388  	for i, n := 0, t.NumField(); i < n; i++ {
   389  		f := t.Field(i)
   390  		fieldsByName[f.Name] = f
   391  	}
   392  
   393  	message := MessageRewriter{}
   394  	rewriters := []Rewriter{}
   395  
   396  	for k, v := range template {
   397  		f, ok := fieldsByName[k]
   398  		if !ok {
   399  			return nil, fmt.Errorf("rewrite template contained an invalid field named %q", k)
   400  		}
   401  
   402  		var fields []json.RawMessage
   403  		if f.Repeated {
   404  			if err := json.Unmarshal(v, &fields); err != nil {
   405  				return nil, err
   406  			}
   407  		} else {
   408  			fields = []json.RawMessage{v}
   409  		}
   410  
   411  		rewriters = rewriters[:0]
   412  
   413  		for _, v := range fields {
   414  			rw, err := parseRewriteTemplate(f.Type, f.Number, v)
   415  			if err != nil {
   416  				return nil, fmt.Errorf("%s: %w", k, err)
   417  			}
   418  			if rw != nil {
   419  				rewriters = append(rewriters, rw)
   420  			}
   421  		}
   422  
   423  		if cap(message) <= int(f.Number) {
   424  			m := make(MessageRewriter, f.Number+1)
   425  			copy(m, message)
   426  			message = m
   427  		}
   428  
   429  		message[f.Number] = MultiRewriter(rewriters...)
   430  	}
   431  
   432  	if f != 0 {
   433  		return &embddedRewriter{number: f, message: message}, nil
   434  	}
   435  
   436  	return message, nil
   437  }
   438  
   439  type embddedRewriter struct {
   440  	number  FieldNumber
   441  	message MessageRewriter
   442  }
   443  
   444  func (f *embddedRewriter) Rewrite(out, in []byte) ([]byte, error) {
   445  	prefix := len(out)
   446  
   447  	out, err := f.message.Rewrite(out, in)
   448  	if err != nil {
   449  		return nil, err
   450  	}
   451  	if len(out) == prefix {
   452  		return out, nil
   453  	}
   454  
   455  	b := [24]byte{}
   456  	n1, _ := encodeVarint(b[:], EncodeTag(f.number, Varlen))
   457  	n2, _ := encodeVarint(b[n1:], uint64(len(out)-prefix))
   458  	tagAndLen := n1 + n2
   459  
   460  	out = append(out, b[:tagAndLen]...)
   461  	copy(out[prefix+tagAndLen:], out[prefix:])
   462  	copy(out[prefix:], b[:tagAndLen])
   463  	return out, nil
   464  }