github.com/nikandfor/tlog@v0.21.5-0.20231108111739-3ef89426a96d/tlwire/stream_decoder.go (about) 1 package tlwire 2 3 import ( 4 "io" 5 6 "github.com/nikandfor/errors" 7 ) 8 9 type StreamDecoder struct { 10 io.Reader 11 12 b []byte 13 i int 14 boff int64 15 } 16 17 const ( 18 eUnexpectedEOF = -1 - iota 19 eBadFormat 20 eBadSpecial 21 ) 22 23 func NewStreamDecoder(r io.Reader) *StreamDecoder { 24 return &StreamDecoder{ 25 Reader: r, 26 } 27 } 28 29 func (d *StreamDecoder) Decode() (data []byte, err error) { 30 end, err := d.skipRead() 31 if err != nil { 32 return nil, err 33 } 34 35 st := d.i 36 d.i = end 37 38 return d.b[st:end:end], nil 39 } 40 41 func (d *StreamDecoder) Read(p []byte) (n int, err error) { 42 end, err := d.skipRead() 43 if err != nil { 44 return 0, err 45 } 46 47 if len(p) < end-d.i { 48 return 0, io.ErrShortBuffer 49 } 50 51 copy(p, d.b[d.i:end]) 52 d.i = end 53 54 return len(p), nil 55 } 56 57 func (d *StreamDecoder) WriteTo(w io.Writer) (n int64, err error) { 58 for { 59 data, err := d.Decode() 60 if errors.Is(err, io.EOF) { 61 return n, nil 62 } 63 if err != nil { 64 return n, errors.Wrap(err, "decode") 65 } 66 67 m, err := w.Write(data) 68 n += int64(m) 69 if err != nil { 70 return n, errors.Wrap(err, "write") 71 } 72 } 73 } 74 75 func (d *StreamDecoder) skipRead() (end int, err error) { 76 for { 77 end = d.skip(d.i) 78 // println("skip", d.i, end) 79 if end > 0 { 80 return end, nil 81 } 82 83 if end < eUnexpectedEOF { 84 return 0, errors.New("bad format") 85 } 86 87 err = d.more() 88 if err != nil { 89 return 0, err 90 } 91 } 92 } 93 94 func (d *StreamDecoder) skip(st int) (i int) { 95 tag, sub, i := readTag(d.b, st) 96 // println("tag", st, tag, sub, i) 97 if i < 0 { 98 return i 99 } 100 101 switch tag { 102 case Int, Neg: 103 // already read 104 case Bytes, String: 105 i += int(sub) 106 case Array, Map: 107 for el := 0; sub == -1 || el < int(sub); el++ { 108 if i == len(d.b) { 109 return eUnexpectedEOF 110 } 111 if sub == -1 && d.b[i] == Special|Break { 112 i++ 113 break 114 } 115 116 if tag == Map { 117 i = d.skip(i) 118 if i < 0 { 119 return i 120 } 121 } 122 123 i = d.skip(i) 124 if i < 0 { 125 return i 126 } 127 } 128 case Semantic: 129 return d.skip(i) 130 case Special: 131 switch sub { 132 case False, 133 True, 134 Nil, 135 Undefined, 136 Break: 137 case Float8: 138 i += 1 139 case Float16: 140 i += 2 141 case Float32: 142 i += 4 143 case Float64: 144 i += 8 145 default: 146 return eBadSpecial 147 } 148 } 149 150 if i > len(d.b) { 151 return eUnexpectedEOF 152 } 153 154 return i 155 } 156 157 func (d *StreamDecoder) more() (err error) { 158 copy(d.b, d.b[d.i:]) 159 d.b = d.b[:len(d.b)-d.i] 160 d.boff += int64(d.i) 161 d.i = 0 162 163 end := len(d.b) 164 165 if len(d.b) == 0 { 166 d.b = make([]byte, 1024) 167 } else { 168 d.b = append(d.b, 0, 0, 0, 0, 0, 0, 0, 0) 169 } 170 171 d.b = d.b[:cap(d.b)] 172 173 n, err := d.Reader.Read(d.b[end:]) 174 // println("more", d.i, end, end+n, n, len(d.b)) 175 d.b = d.b[:end+n] 176 177 if n != 0 && errors.Is(err, io.EOF) { 178 err = nil 179 } 180 181 return err 182 } 183 184 func readTag(b []byte, st int) (tag byte, sub int64, i int) { 185 if st >= len(b) { 186 return tag, sub, eUnexpectedEOF 187 } 188 189 i = st 190 191 tag = b[i] & TagMask 192 sub = int64(b[i] & TagDetMask) 193 i++ 194 195 if tag == Special { 196 return 197 } 198 199 if sub < Len1 { 200 return 201 } 202 203 switch sub { 204 case LenBreak: 205 sub = -1 206 case Len1: 207 if i+1 > len(b) { 208 return tag, sub, eUnexpectedEOF 209 } 210 211 sub = int64(b[i]) 212 i++ 213 case Len2: 214 if i+2 > len(b) { 215 return tag, sub, eUnexpectedEOF 216 } 217 218 sub = int64(b[i])<<8 | int64(b[i+1]) 219 i += 2 220 case Len4: 221 if i+4 > len(b) { 222 return tag, sub, eUnexpectedEOF 223 } 224 225 sub = int64(b[i])<<24 | int64(b[i+1])<<16 | int64(b[i+2])<<8 | int64(b[i+3]) 226 i += 4 227 case Len8: 228 if i+8 > len(b) { 229 return tag, sub, eUnexpectedEOF 230 } 231 232 sub = int64(b[i])<<56 | int64(b[i+1])<<48 | int64(b[i+2])<<40 | int64(b[i+3])<<32 | 233 int64(b[i+4])<<24 | int64(b[i+5])<<16 | int64(b[i+6])<<8 | int64(b[i+7]) 234 i += 8 235 default: 236 return tag, sub, eBadFormat 237 } 238 239 return tag, sub, i 240 }