github.com/nikandfor/tlog@v0.21.5-0.20231108111739-3ef89426a96d/tlwire/stream_decoder.go (about)

     1  package tlwire
     2  
     3  import (
     4  	"io"
     5  
     6  	"github.com/nikandfor/errors"
     7  )
     8  
     9  type StreamDecoder struct {
    10  	io.Reader
    11  
    12  	b    []byte
    13  	i    int
    14  	boff int64
    15  }
    16  
    17  const (
    18  	eUnexpectedEOF = -1 - iota
    19  	eBadFormat
    20  	eBadSpecial
    21  )
    22  
    23  func NewStreamDecoder(r io.Reader) *StreamDecoder {
    24  	return &StreamDecoder{
    25  		Reader: r,
    26  	}
    27  }
    28  
    29  func (d *StreamDecoder) Decode() (data []byte, err error) {
    30  	end, err := d.skipRead()
    31  	if err != nil {
    32  		return nil, err
    33  	}
    34  
    35  	st := d.i
    36  	d.i = end
    37  
    38  	return d.b[st:end:end], nil
    39  }
    40  
    41  func (d *StreamDecoder) Read(p []byte) (n int, err error) {
    42  	end, err := d.skipRead()
    43  	if err != nil {
    44  		return 0, err
    45  	}
    46  
    47  	if len(p) < end-d.i {
    48  		return 0, io.ErrShortBuffer
    49  	}
    50  
    51  	copy(p, d.b[d.i:end])
    52  	d.i = end
    53  
    54  	return len(p), nil
    55  }
    56  
    57  func (d *StreamDecoder) WriteTo(w io.Writer) (n int64, err error) {
    58  	for {
    59  		data, err := d.Decode()
    60  		if errors.Is(err, io.EOF) {
    61  			return n, nil
    62  		}
    63  		if err != nil {
    64  			return n, errors.Wrap(err, "decode")
    65  		}
    66  
    67  		m, err := w.Write(data)
    68  		n += int64(m)
    69  		if err != nil {
    70  			return n, errors.Wrap(err, "write")
    71  		}
    72  	}
    73  }
    74  
    75  func (d *StreamDecoder) skipRead() (end int, err error) {
    76  	for {
    77  		end = d.skip(d.i)
    78  		//	println("skip", d.i, end)
    79  		if end > 0 {
    80  			return end, nil
    81  		}
    82  
    83  		if end < eUnexpectedEOF {
    84  			return 0, errors.New("bad format")
    85  		}
    86  
    87  		err = d.more()
    88  		if err != nil {
    89  			return 0, err
    90  		}
    91  	}
    92  }
    93  
    94  func (d *StreamDecoder) skip(st int) (i int) {
    95  	tag, sub, i := readTag(d.b, st)
    96  	//	println("tag", st, tag, sub, i)
    97  	if i < 0 {
    98  		return i
    99  	}
   100  
   101  	switch tag {
   102  	case Int, Neg:
   103  		// already read
   104  	case Bytes, String:
   105  		i += int(sub)
   106  	case Array, Map:
   107  		for el := 0; sub == -1 || el < int(sub); el++ {
   108  			if i == len(d.b) {
   109  				return eUnexpectedEOF
   110  			}
   111  			if sub == -1 && d.b[i] == Special|Break {
   112  				i++
   113  				break
   114  			}
   115  
   116  			if tag == Map {
   117  				i = d.skip(i)
   118  				if i < 0 {
   119  					return i
   120  				}
   121  			}
   122  
   123  			i = d.skip(i)
   124  			if i < 0 {
   125  				return i
   126  			}
   127  		}
   128  	case Semantic:
   129  		return d.skip(i)
   130  	case Special:
   131  		switch sub {
   132  		case False,
   133  			True,
   134  			Nil,
   135  			Undefined,
   136  			Break:
   137  		case Float8:
   138  			i += 1
   139  		case Float16:
   140  			i += 2
   141  		case Float32:
   142  			i += 4
   143  		case Float64:
   144  			i += 8
   145  		default:
   146  			return eBadSpecial
   147  		}
   148  	}
   149  
   150  	if i > len(d.b) {
   151  		return eUnexpectedEOF
   152  	}
   153  
   154  	return i
   155  }
   156  
   157  func (d *StreamDecoder) more() (err error) {
   158  	copy(d.b, d.b[d.i:])
   159  	d.b = d.b[:len(d.b)-d.i]
   160  	d.boff += int64(d.i)
   161  	d.i = 0
   162  
   163  	end := len(d.b)
   164  
   165  	if len(d.b) == 0 {
   166  		d.b = make([]byte, 1024)
   167  	} else {
   168  		d.b = append(d.b, 0, 0, 0, 0, 0, 0, 0, 0)
   169  	}
   170  
   171  	d.b = d.b[:cap(d.b)]
   172  
   173  	n, err := d.Reader.Read(d.b[end:])
   174  	//	println("more", d.i, end, end+n, n, len(d.b))
   175  	d.b = d.b[:end+n]
   176  
   177  	if n != 0 && errors.Is(err, io.EOF) {
   178  		err = nil
   179  	}
   180  
   181  	return err
   182  }
   183  
   184  func readTag(b []byte, st int) (tag byte, sub int64, i int) {
   185  	if st >= len(b) {
   186  		return tag, sub, eUnexpectedEOF
   187  	}
   188  
   189  	i = st
   190  
   191  	tag = b[i] & TagMask
   192  	sub = int64(b[i] & TagDetMask)
   193  	i++
   194  
   195  	if tag == Special {
   196  		return
   197  	}
   198  
   199  	if sub < Len1 {
   200  		return
   201  	}
   202  
   203  	switch sub {
   204  	case LenBreak:
   205  		sub = -1
   206  	case Len1:
   207  		if i+1 > len(b) {
   208  			return tag, sub, eUnexpectedEOF
   209  		}
   210  
   211  		sub = int64(b[i])
   212  		i++
   213  	case Len2:
   214  		if i+2 > len(b) {
   215  			return tag, sub, eUnexpectedEOF
   216  		}
   217  
   218  		sub = int64(b[i])<<8 | int64(b[i+1])
   219  		i += 2
   220  	case Len4:
   221  		if i+4 > len(b) {
   222  			return tag, sub, eUnexpectedEOF
   223  		}
   224  
   225  		sub = int64(b[i])<<24 | int64(b[i+1])<<16 | int64(b[i+2])<<8 | int64(b[i+3])
   226  		i += 4
   227  	case Len8:
   228  		if i+8 > len(b) {
   229  			return tag, sub, eUnexpectedEOF
   230  		}
   231  
   232  		sub = int64(b[i])<<56 | int64(b[i+1])<<48 | int64(b[i+2])<<40 | int64(b[i+3])<<32 |
   233  			int64(b[i+4])<<24 | int64(b[i+5])<<16 | int64(b[i+6])<<8 | int64(b[i+7])
   234  		i += 8
   235  	default:
   236  		return tag, sub, eBadFormat
   237  	}
   238  
   239  	return tag, sub, i
   240  }