github.com/nikandfor/tlog@v0.21.5-0.20231108111739-3ef89426a96d/convert/rewriter.go (about)

     1  package convert
     2  
     3  import (
     4  	"bytes"
     5  	"errors"
     6  	"io"
     7  	"sort"
     8  
     9  	"github.com/nikandfor/tlog"
    10  	"github.com/nikandfor/tlog/tlwire"
    11  )
    12  
    13  type (
    14  	Rewriter struct {
    15  		io.Writer
    16  		tlwire.Decoder
    17  		tlwire.Encoder
    18  
    19  		Rule RewriterRule
    20  
    21  		b []byte
    22  	}
    23  
    24  	RewriterRule interface {
    25  		Rewrite(b, p []byte, path []tlog.RawMessage, kst, st int) ([]byte, int, error)
    26  	}
    27  
    28  	RewriterFunc func(b, p []byte, path []tlog.RawMessage, kst, st int) ([]byte, int, error)
    29  
    30  	KeyRenamer struct {
    31  		rules []RenameRule
    32  
    33  		d tlwire.Decoder
    34  		e tlwire.Encoder
    35  
    36  		Rewriter RewriterRule
    37  		Fallback RewriterRule
    38  	}
    39  
    40  	RenameRule struct {
    41  		Path []tlog.RawMessage
    42  
    43  		Rename []byte
    44  		Prefix []byte
    45  		Remove bool
    46  	}
    47  )
    48  
    49  var ErrFallback = errors.New("fallback")
    50  
    51  func NewRewriter(w io.Writer) *Rewriter {
    52  	return &Rewriter{Writer: w}
    53  }
    54  
    55  func (w *Rewriter) Write(p []byte) (n int, err error) {
    56  	w.b, n, err = w.Rewrite(w.b[:0], p, nil, -1, 0)
    57  	if err != nil {
    58  		return 0, err
    59  	}
    60  
    61  	if w.Writer != nil {
    62  		_, err = w.Writer.Write(w.b)
    63  	}
    64  
    65  	return
    66  }
    67  
    68  func (w *Rewriter) Rewrite(b, p []byte, path []tlog.RawMessage, kst, st int) (r []byte, i int, err error) {
    69  	if w.Rule != nil {
    70  		r, i, err = w.Rule.Rewrite(b, p, path, kst, st)
    71  		if !errors.Is(err, ErrFallback) {
    72  			return
    73  		}
    74  	}
    75  
    76  	if st == len(p) {
    77  		return b, st, nil
    78  	}
    79  
    80  	tag, sub, i := w.Tag(p, st)
    81  
    82  	if kst != -1 && tag != tlwire.Semantic {
    83  		b = append(b, p[kst:st]...)
    84  	}
    85  
    86  	switch tag {
    87  	case tlwire.Int, tlwire.Neg:
    88  	case tlwire.String, tlwire.Bytes:
    89  		i = w.Skip(p, st)
    90  	case tlwire.Array, tlwire.Map:
    91  		b = append(b, p[st:i]...)
    92  		kst := -1
    93  		subp := path
    94  
    95  		if tag == tlwire.Array {
    96  			subp = append(subp, []byte{tlwire.Array})
    97  		}
    98  
    99  		for el := 0; sub == -1 || el < int(sub); el++ {
   100  			if sub == -1 && w.Break(p, &i) {
   101  				break
   102  			}
   103  
   104  			if tag == tlwire.Map {
   105  				kst = i
   106  				_, i = w.Bytes(p, i)
   107  
   108  				subp = append(subp[:len(path)], p[kst:i])
   109  			}
   110  
   111  			b, i, err = w.Rewrite(b, p, subp, kst, i)
   112  			if err != nil {
   113  				return
   114  			}
   115  		}
   116  
   117  		if sub == -1 {
   118  			b = w.AppendBreak(b)
   119  		}
   120  
   121  		return b, i, nil
   122  	case tlwire.Semantic:
   123  		path = append(path, p[st:i])
   124  
   125  		return w.Rewrite(b, p, path, kst, i)
   126  	case tlwire.Special:
   127  		switch sub {
   128  		case tlwire.False,
   129  			tlwire.True,
   130  			tlwire.Nil,
   131  			tlwire.Undefined,
   132  			tlwire.None,
   133  			tlwire.Hidden,
   134  			tlwire.SelfRef,
   135  			tlwire.Break:
   136  		case tlwire.Float8:
   137  			i += 1
   138  		case tlwire.Float16:
   139  			i += 2
   140  		case tlwire.Float32:
   141  			i += 4
   142  		case tlwire.Float64:
   143  			i += 8
   144  		default:
   145  			panic("unsupported special")
   146  		}
   147  	}
   148  
   149  	b = append(b, p[st:i]...)
   150  
   151  	return b, i, nil
   152  }
   153  
   154  func (f RewriterFunc) Rewrite(b, p []byte, path []tlog.RawMessage, kst, st int) ([]byte, int, error) {
   155  	return f(b, p, path, kst, st)
   156  }
   157  
   158  func NewKeyRenamer(rew RewriterRule, rules ...RenameRule) *KeyRenamer {
   159  	w := &KeyRenamer{
   160  		Rewriter: rew,
   161  	}
   162  
   163  	w.Append(rules...)
   164  
   165  	return w
   166  }
   167  
   168  func (w *KeyRenamer) Append(rules ...RenameRule) {
   169  	w.rules = append(w.rules, rules...)
   170  
   171  	sort.Slice(w.rules, func(i, j int) bool {
   172  		return w.cmp(w.rules[i].Path, w.rules[j].Path) < 0
   173  	})
   174  }
   175  
   176  func (w *KeyRenamer) Rewrite(b, p []byte, path []tlog.RawMessage, kst, st int) ([]byte, int, error) {
   177  	pos := sort.Search(len(w.rules), func(i int) bool {
   178  		rule := w.rules[i]
   179  
   180  		return w.cmp(path, rule.Path) <= 0
   181  	})
   182  
   183  	//	fmt.Printf("rewrite  %q -> %v %v\n", path, pos, pos < len(w.rules) && w.cmp(path, w.rules[pos].Path) == 0)
   184  
   185  	if pos == len(w.rules) || w.cmp(path, w.rules[pos].Path) != 0 {
   186  		return w.fallback(b, p, path, kst, st)
   187  	}
   188  
   189  	rule := w.rules[pos]
   190  
   191  	if rule.Remove {
   192  		end := w.d.Skip(p, st)
   193  		return b, end, nil
   194  	}
   195  
   196  	key, kend := w.d.Bytes(p, kst)
   197  
   198  	l := len(rule.Prefix)
   199  
   200  	if rule.Rename != nil {
   201  		l += len(rule.Rename)
   202  	} else {
   203  		l += len(key)
   204  	}
   205  
   206  	b = w.e.AppendTag(b, tlwire.String, l)
   207  	b = append(b, rule.Prefix...)
   208  
   209  	if rule.Rename != nil {
   210  		b = append(b, rule.Rename...)
   211  	} else {
   212  		b = append(b, key...)
   213  	}
   214  
   215  	b = append(b, p[kend:st]...)
   216  
   217  	if w.Rewriter != nil {
   218  		return w.Rewriter.Rewrite(b, p, path, -1, st)
   219  	}
   220  
   221  	end := w.d.Skip(p, st)
   222  	b = append(b, p[st:end]...)
   223  
   224  	return b, end, nil
   225  }
   226  
   227  func (w *KeyRenamer) fallback(b, p []byte, path []tlog.RawMessage, kst, st int) ([]byte, int, error) {
   228  	if w.Fallback == nil {
   229  		return b, st, ErrFallback
   230  	}
   231  
   232  	return w.Fallback.Rewrite(b, p, path, kst, st)
   233  }
   234  
   235  func (w *KeyRenamer) cmp(x, y []tlog.RawMessage) (r int) {
   236  	//	defer func() {
   237  	//		fmt.Printf("cmp %q %q -> %d  from %v\n", x, y, r, loc.Caller(1))
   238  	//	}()
   239  	for i := 0; i < min(len(x), len(y)); i++ {
   240  		r = bytes.Compare(x[i], y[i])
   241  		if r != 0 {
   242  			return r
   243  		}
   244  	}
   245  
   246  	if len(x) != len(y) {
   247  		if len(x) < len(y) {
   248  			return -1
   249  		}
   250  
   251  		return 1
   252  	}
   253  
   254  	return 0
   255  }
   256  
   257  func min(a, b int) int {
   258  	if a < b {
   259  		return a
   260  	}
   261  
   262  	return b
   263  }