github.com/hack0072008/kafka-go@v1.0.1/compress/snappy/xerial.go (about) 1 package snappy 2 3 import ( 4 "bytes" 5 "encoding/binary" 6 "io" 7 8 "github.com/golang/snappy" 9 ) 10 11 const defaultBufferSize = 32 * 1024 12 13 // An implementation of io.Reader which consumes a stream of xerial-framed 14 // snappy-encoeded data. The framing is optional, if no framing is detected 15 // the reader will simply forward the bytes from its underlying stream. 16 type xerialReader struct { 17 reader io.Reader 18 header [16]byte 19 input []byte 20 output []byte 21 offset int64 22 nbytes int64 23 decode func([]byte, []byte) ([]byte, error) 24 } 25 26 func (x *xerialReader) Reset(r io.Reader) { 27 x.reader = r 28 x.input = x.input[:0] 29 x.output = x.output[:0] 30 x.header = [16]byte{} 31 x.offset = 0 32 x.nbytes = 0 33 } 34 35 func (x *xerialReader) Read(b []byte) (int, error) { 36 for { 37 if x.offset < int64(len(x.output)) { 38 n := copy(b, x.output[x.offset:]) 39 x.offset += int64(n) 40 return n, nil 41 } 42 43 n, err := x.readChunk(b) 44 if err != nil { 45 return 0, err 46 } 47 if n > 0 { 48 return n, nil 49 } 50 } 51 } 52 53 func (x *xerialReader) WriteTo(w io.Writer) (int64, error) { 54 wn := int64(0) 55 56 for { 57 for x.offset < int64(len(x.output)) { 58 n, err := w.Write(x.output[x.offset:]) 59 wn += int64(n) 60 x.offset += int64(n) 61 if err != nil { 62 return wn, err 63 } 64 } 65 66 if _, err := x.readChunk(nil); err != nil { 67 if err == io.EOF { 68 err = nil 69 } 70 return wn, err 71 } 72 } 73 } 74 75 func (x *xerialReader) readChunk(dst []byte) (int, error) { 76 x.output = x.output[:0] 77 x.offset = 0 78 prefix := 0 79 80 if x.nbytes == 0 { 81 n, err := x.readFull(x.header[:]) 82 if err != nil && n == 0 { 83 return 0, err 84 } 85 prefix = n 86 } 87 88 if isXerialHeader(x.header[:]) { 89 if cap(x.input) < 4 { 90 x.input = make([]byte, 4, defaultBufferSize) 91 } else { 92 x.input = x.input[:4] 93 } 94 95 _, err := x.readFull(x.input) 96 if err != nil { 97 return 0, err 98 } 99 100 frame := int(binary.BigEndian.Uint32(x.input)) 101 if cap(x.input) < frame { 102 x.input = make([]byte, frame, align(frame, defaultBufferSize)) 103 } else { 104 x.input = x.input[:frame] 105 } 106 107 if _, err := x.readFull(x.input); err != nil { 108 return 0, err 109 } 110 } else { 111 if cap(x.input) == 0 { 112 x.input = make([]byte, 0, defaultBufferSize) 113 } else { 114 x.input = x.input[:0] 115 } 116 117 if prefix > 0 { 118 x.input = append(x.input, x.header[:prefix]...) 119 } 120 121 for { 122 if len(x.input) == cap(x.input) { 123 b := make([]byte, len(x.input), 2*cap(x.input)) 124 copy(b, x.input) 125 x.input = b 126 } 127 128 n, err := x.read(x.input[len(x.input):cap(x.input)]) 129 x.input = x.input[:len(x.input)+n] 130 if err != nil { 131 if err == io.EOF && len(x.input) > 0 { 132 break 133 } 134 return 0, err 135 } 136 } 137 } 138 139 var n int 140 var err error 141 142 if x.decode == nil { 143 x.output, x.input, err = x.input, x.output, nil 144 } else if n, err = snappy.DecodedLen(x.input); n <= len(dst) && err == nil { 145 // If the output buffer is large enough to hold the decode value, 146 // write it there directly instead of using the intermediary output 147 // buffer. 148 _, err = x.decode(dst, x.input) 149 } else { 150 var b []byte 151 n = 0 152 b, err = x.decode(x.output[:cap(x.output)], x.input) 153 if err == nil { 154 x.output = b 155 } 156 } 157 158 return n, err 159 } 160 161 func (x *xerialReader) read(b []byte) (int, error) { 162 n, err := x.reader.Read(b) 163 x.nbytes += int64(n) 164 return n, err 165 } 166 167 func (x *xerialReader) readFull(b []byte) (int, error) { 168 n, err := io.ReadFull(x.reader, b) 169 x.nbytes += int64(n) 170 return n, err 171 } 172 173 // An implementation of a xerial-framed snappy-encoded output stream. 174 // Each Write made to the writer is framed with a xerial header. 175 type xerialWriter struct { 176 writer io.Writer 177 header [16]byte 178 input []byte 179 output []byte 180 nbytes int64 181 framed bool 182 encode func([]byte, []byte) []byte 183 } 184 185 func (x *xerialWriter) Reset(w io.Writer) { 186 x.writer = w 187 x.input = x.input[:0] 188 x.output = x.output[:0] 189 x.nbytes = 0 190 } 191 192 func (x *xerialWriter) ReadFrom(r io.Reader) (int64, error) { 193 wn := int64(0) 194 195 if cap(x.input) == 0 { 196 x.input = make([]byte, 0, defaultBufferSize) 197 } 198 199 for { 200 if x.full() { 201 x.grow() 202 } 203 204 n, err := r.Read(x.input[len(x.input):cap(x.input)]) 205 wn += int64(n) 206 x.input = x.input[:len(x.input)+n] 207 208 if x.fullEnough() { 209 if err := x.Flush(); err != nil { 210 return wn, err 211 } 212 } 213 214 if err != nil { 215 if err == io.EOF { 216 err = nil 217 } 218 return wn, err 219 } 220 } 221 } 222 223 func (x *xerialWriter) Write(b []byte) (int, error) { 224 wn := 0 225 226 if cap(x.input) == 0 { 227 x.input = make([]byte, 0, defaultBufferSize) 228 } 229 230 for len(b) > 0 { 231 if x.full() { 232 x.grow() 233 } 234 235 n := copy(x.input[len(x.input):cap(x.input)], b) 236 b = b[n:] 237 wn += n 238 x.input = x.input[:len(x.input)+n] 239 240 if x.fullEnough() { 241 if err := x.Flush(); err != nil { 242 return wn, err 243 } 244 } 245 } 246 247 return wn, nil 248 } 249 250 func (x *xerialWriter) Flush() error { 251 if len(x.input) == 0 { 252 return nil 253 } 254 255 var b []byte 256 if x.encode == nil { 257 b = x.input 258 } else { 259 x.output = x.encode(x.output[:cap(x.output)], x.input) 260 b = x.output 261 } 262 263 x.input = x.input[:0] 264 x.output = x.output[:0] 265 266 if x.framed && x.nbytes == 0 { 267 writeXerialHeader(x.header[:]) 268 _, err := x.write(x.header[:]) 269 if err != nil { 270 return err 271 } 272 } 273 274 if x.framed { 275 writeXerialFrame(x.header[:4], len(b)) 276 _, err := x.write(x.header[:4]) 277 if err != nil { 278 return err 279 } 280 } 281 282 _, err := x.write(b) 283 return err 284 } 285 286 func (x *xerialWriter) write(b []byte) (int, error) { 287 n, err := x.writer.Write(b) 288 x.nbytes += int64(n) 289 return n, err 290 } 291 292 func (x *xerialWriter) full() bool { 293 return len(x.input) == cap(x.input) 294 } 295 296 func (x *xerialWriter) fullEnough() bool { 297 return x.framed && (cap(x.input)-len(x.input)) < 1024 298 } 299 300 func (x *xerialWriter) grow() { 301 tmp := make([]byte, len(x.input), 2*cap(x.input)) 302 copy(tmp, x.input) 303 x.input = tmp 304 } 305 306 func align(n, a int) int { 307 if (n % a) == 0 { 308 return n 309 } 310 return ((n / a) + 1) * a 311 } 312 313 var ( 314 xerialHeader = [...]byte{130, 83, 78, 65, 80, 80, 89, 0} 315 xerialVersionInfo = [...]byte{0, 0, 0, 1, 0, 0, 0, 1} 316 ) 317 318 func isXerialHeader(src []byte) bool { 319 return len(src) >= 16 && bytes.Equal(src[:8], xerialHeader[:]) 320 } 321 322 func writeXerialHeader(b []byte) { 323 copy(b[:8], xerialHeader[:]) 324 copy(b[8:], xerialVersionInfo[:]) 325 } 326 327 func writeXerialFrame(b []byte, n int) { 328 binary.BigEndian.PutUint32(b, uint32(n)) 329 }