github.com/database64128/shadowsocks-go@v1.7.0/zerocopy/stream.go (about) 1 package zerocopy 2 3 import ( 4 "bytes" 5 "io" 6 ) 7 8 // defaultBufferSize is the default buffer size to use 9 // when neither the reader nor the writer has buffer size requirements. 10 // It's the same default as io.Copy. 11 const defaultBufferSize = 32768 12 13 // ReaderInfo contains information about a reader. 14 type ReaderInfo struct { 15 Headroom Headroom 16 17 // MinPayloadBufferSizePerRead is the minimum size of payload buffer 18 // the ReadZeroCopy method requires for an unbuffered read. 19 // 20 // This is usually required by chunk-based protocols to be able to read 21 // whole chunks without needing internal caching. 22 MinPayloadBufferSizePerRead int 23 } 24 25 // Reader provides a stream interface for reading. 26 type Reader interface { 27 // ReaderInfo returns information about the reader. 28 ReaderInfo() ReaderInfo 29 30 // ReadZeroCopy uses b as buffer space to initiate a read operation. 31 // 32 // b must have at least [ReaderInfo.Headroom.Front] bytes before payloadBufStart 33 // and [ReaderInfo.Headroom.Rear] bytes after payloadBufStart + payloadBufLen. 34 // 35 // payloadBufLen must be at least [ReaderInfo.MinPayloadBufferSizePerRead]. 36 // 37 // The read operation may use the whole space of b. 38 // The actual payload will be confined in [payloadBufStart, payloadBufLen). 39 // 40 // If no error occurs, the returned payload is b[payloadBufStart : payloadBufStart+payloadLen]. 41 ReadZeroCopy(b []byte, payloadBufStart, payloadBufLen int) (payloadLen int, err error) 42 } 43 44 // WriterInfo contains information about a writer. 45 type WriterInfo struct { 46 Headroom Headroom 47 48 // MaxPayloadSizePerWrite is the maximum size of payload 49 // the WriteZeroCopy method can write at a time. 50 // 51 // This is usually required by chunk-based protocols to be able to write 52 // one chunk at a time without needing to break up the payload. 53 // 54 // 0 means no size limit. 55 MaxPayloadSizePerWrite int 56 } 57 58 // Writer provides a stream interface for writing. 59 type Writer interface { 60 // WriterInfo returns information about the writer. 61 WriterInfo() WriterInfo 62 63 // WriteZeroCopy uses b as buffer space to initiate a write operation. 64 // 65 // b must have at least [WriterInfo.Headroom.Front] bytes before payloadBufStart 66 // and [WriterInfo.Headroom.Rear] bytes after payloadBufStart + payloadBufLen. 67 // 68 // payloadLen must not exceed [WriterInfo.MaxPayloadSizePerWrite]. 69 // 70 // The write operation may use the whole space of b. 71 WriteZeroCopy(b []byte, payloadStart, payloadLen int) (payloadWritten int, err error) 72 } 73 74 // DirectReader provides access to the underlying [io.Reader]. 75 type DirectReader interface { 76 // DirectReader returns the underlying reader for direct reads. 77 DirectReader() io.Reader 78 } 79 80 // DirectWriter provides access to the underlying [io.Writer]. 81 type DirectWriter interface { 82 // DirectWriter returns the underlying writer for direct writes. 83 DirectWriter() io.Writer 84 } 85 86 // Relay reads from r and writes to w using zero-copy methods. 87 // It returns the number of bytes transferred, and any error occurred during transfer. 88 func Relay(w Writer, r Reader) (n int64, err error) { 89 // Use direct read/write when possible. 90 if dr, ok := r.(DirectReader); ok { 91 if dw, ok := w.(DirectWriter); ok { 92 r := dr.DirectReader() 93 w := dw.DirectWriter() 94 return io.Copy(w, r) 95 } 96 } 97 98 // Process reader and writer info. 99 ri := r.ReaderInfo() 100 wi := w.WriterInfo() 101 headroom := MaxHeadroom(ri.Headroom, wi.Headroom) 102 103 // Check payload buffer size requirement compatibility. 104 if wi.MaxPayloadSizePerWrite > 0 && ri.MinPayloadBufferSizePerRead > wi.MaxPayloadSizePerWrite { 105 return relayFallback(w, r, headroom.Front, headroom.Rear, ri.MinPayloadBufferSizePerRead, wi.MaxPayloadSizePerWrite) 106 } 107 108 payloadBufSize := ri.MinPayloadBufferSizePerRead 109 if payloadBufSize == 0 { 110 payloadBufSize = wi.MaxPayloadSizePerWrite 111 if payloadBufSize == 0 { 112 payloadBufSize = defaultBufferSize 113 } 114 } 115 116 // Make buffer. 117 b := make([]byte, headroom.Front+payloadBufSize+headroom.Rear) 118 119 // Main relay loop. 120 for { 121 var payloadLen int 122 payloadLen, err = r.ReadZeroCopy(b, headroom.Front, payloadBufSize) 123 if payloadLen == 0 { 124 if err == io.EOF { 125 err = nil 126 } 127 return 128 } 129 130 payloadWritten, werr := w.WriteZeroCopy(b, headroom.Front, payloadLen) 131 n += int64(payloadWritten) 132 if werr != nil { 133 err = werr 134 return 135 } 136 137 if err != nil { 138 if err == io.EOF { 139 err = nil 140 } 141 return 142 } 143 } 144 } 145 146 // relayFallback uses copying to handle situations where the reader requires more payload buffer space than the writer can handle in one write call. 147 func relayFallback(w Writer, r Reader, frontHeadroom, rearHeadroom, readMaxPayloadSize, writeMaxPayloadSize int) (n int64, err error) { 148 br := make([]byte, frontHeadroom+readMaxPayloadSize+rearHeadroom) 149 bw := make([]byte, frontHeadroom+writeMaxPayloadSize+rearHeadroom) 150 151 for { 152 var payloadLen int 153 payloadLen, err = r.ReadZeroCopy(br, frontHeadroom, readMaxPayloadSize) 154 if payloadLen == 0 { 155 if err == io.EOF { 156 err = nil 157 } 158 return 159 } 160 161 // Short-circuit to avoid copying if payload can fit in one write. 162 if payloadLen <= writeMaxPayloadSize { 163 payloadWritten, werr := w.WriteZeroCopy(br, frontHeadroom, payloadLen) 164 n += int64(payloadWritten) 165 if werr != nil { 166 err = werr 167 } 168 if err != nil { 169 return 170 } 171 continue 172 } 173 174 // Loop until all of br[frontHeadroom : frontHeadroom+payloadLen] is written. 175 for i, j := 0, 0; i < payloadLen; i += j { 176 j = copy(bw[frontHeadroom:frontHeadroom+writeMaxPayloadSize], br[frontHeadroom+i:frontHeadroom+payloadLen]) 177 payloadWritten, werr := w.WriteZeroCopy(bw, frontHeadroom, j) 178 n += int64(payloadWritten) 179 if werr != nil { 180 err = werr 181 return 182 } 183 } 184 185 if err != nil { 186 if err == io.EOF { 187 err = nil 188 } 189 return 190 } 191 } 192 } 193 194 // CloseRead provides the CloseRead method. 195 type CloseRead interface { 196 // CloseRead indicates to the underlying reader that no further reads will happen. 197 CloseRead() error 198 } 199 200 // CloseWrite provides the CloseWrite method. 201 type CloseWrite interface { 202 // CloseWrite indicates to the underlying writer that no further writes will happen. 203 CloseWrite() error 204 } 205 206 // ReadWriter provides a stream interface for reading and writing. 207 type ReadWriter interface { 208 Reader 209 Writer 210 CloseRead 211 CloseWrite 212 io.Closer 213 } 214 215 // TwoWayRelay relays data between left and right using zero-copy methods. 216 // It returns the number of bytes sent from left to right, from right to left, 217 // and any error occurred during transfer. 218 func TwoWayRelay(left, right ReadWriter) (nl2r, nr2l int64, err error) { 219 var l2rErr error 220 ctrlCh := make(chan struct{}) 221 222 go func() { 223 nl2r, l2rErr = Relay(right, left) 224 right.CloseWrite() 225 ctrlCh <- struct{}{} 226 }() 227 228 nr2l, err = Relay(left, right) 229 left.CloseWrite() 230 <-ctrlCh 231 232 if l2rErr != nil { 233 err = l2rErr 234 } 235 return 236 } 237 238 // DirectReadWriteCloser extends io.ReadWriteCloser with CloseRead and CloseWrite. 239 type DirectReadWriteCloser interface { 240 io.ReadWriteCloser 241 CloseRead 242 CloseWrite 243 } 244 245 // DirectTwoWayRelay relays data between left and right using [io.Copy]. 246 // It returns the number of bytes sent from left to right, from right to left, 247 // and any error occurred during transfer. 248 func DirectTwoWayRelay(left, right DirectReadWriteCloser) (nl2r, nr2l int64, err error) { 249 var l2rErr error 250 ctrlCh := make(chan struct{}) 251 252 go func() { 253 nl2r, l2rErr = io.Copy(right, left) 254 right.CloseWrite() 255 ctrlCh <- struct{}{} 256 }() 257 258 nr2l, err = io.Copy(left, right) 259 left.CloseWrite() 260 <-ctrlCh 261 262 if l2rErr != nil { 263 err = l2rErr 264 } 265 return 266 } 267 268 // DirectReadWriteCloserOpener provides the Open method to open a [DirectReadWriteCloser]. 269 type DirectReadWriteCloserOpener interface { 270 // Open opens a [DirectReadWriteCloser] with the specified initial payload. 271 Open(b []byte) (DirectReadWriteCloser, error) 272 } 273 274 // SimpleDirectReadWriteCloserOpener wraps a [DirectReadWriteCloser] for the Open method to return. 275 type SimpleDirectReadWriteCloserOpener struct { 276 DirectReadWriteCloser 277 } 278 279 // Open implements the DirectReadWriteCloserOpener Open method. 280 func (o *SimpleDirectReadWriteCloserOpener) Open(b []byte) (DirectReadWriteCloser, error) { 281 _, err := o.DirectReadWriteCloser.Write(b) 282 return o.DirectReadWriteCloser, err 283 } 284 285 // ReadWriterTestFunc tests the left and right ReadWriters by performing 2 writes 286 // on each ReadWriter and validating the read results. 287 // 288 // The left and right ReadWriters must be connected with a duplex pipe. 289 func ReadWriterTestFunc(t tester, l, r ReadWriter) { 290 defer r.Close() 291 defer l.Close() 292 293 var ( 294 hello = []byte{'h', 'e', 'l', 'l', 'o'} 295 world = []byte{'w', 'o', 'r', 'l', 'd'} 296 ) 297 298 lri := l.ReaderInfo() 299 lwi := l.WriterInfo() 300 lwmax := lwi.MaxPayloadSizePerWrite 301 if lwmax == 0 { 302 lwmax = 5 303 } 304 lrmin := lri.MinPayloadBufferSizePerRead 305 if lrmin == 0 { 306 lrmin = 5 307 } 308 lwbuf := make([]byte, lwi.Headroom.Front+lwmax+lwi.Headroom.Rear) 309 lrbuf := make([]byte, lri.Headroom.Front+lrmin+lri.Headroom.Rear) 310 311 rri := r.ReaderInfo() 312 rwi := r.WriterInfo() 313 rwmax := rwi.MaxPayloadSizePerWrite 314 if rwmax == 0 { 315 rwmax = 5 316 } 317 rrmin := rri.MinPayloadBufferSizePerRead 318 if rrmin == 0 { 319 rrmin = 5 320 } 321 rwbuf := make([]byte, rwi.Headroom.Front+rwmax+rwi.Headroom.Rear) 322 rrbuf := make([]byte, rri.Headroom.Front+rrmin+rri.Headroom.Rear) 323 324 ctrlCh := make(chan struct{}) 325 326 // Start read goroutines. 327 go func() { 328 pl, err := l.ReadZeroCopy(lrbuf, lri.Headroom.Front, lrmin) 329 if err != nil { 330 t.Error(err) 331 } 332 if pl != 5 { 333 t.Errorf("Expected payloadLen 5, got %d", pl) 334 } 335 p := lrbuf[lri.Headroom.Front : lri.Headroom.Front+pl] 336 if !bytes.Equal(p, world) { 337 t.Errorf("Expected payload %v, got %v", world, p) 338 } 339 340 pl, err = l.ReadZeroCopy(lrbuf, lri.Headroom.Front, lrmin) 341 if err != nil { 342 t.Error(err) 343 } 344 if pl != 5 { 345 t.Errorf("Expected payloadLen 5, got %d", pl) 346 } 347 p = lrbuf[lri.Headroom.Front : lri.Headroom.Front+pl] 348 if !bytes.Equal(p, hello) { 349 t.Errorf("Expected payload %v, got %v", hello, p) 350 } 351 352 pl, err = l.ReadZeroCopy(lrbuf, lri.Headroom.Front, lrmin) 353 if err != io.EOF { 354 t.Errorf("Expected io.EOF, got %v", err) 355 } 356 if pl != 0 { 357 t.Errorf("Expected payloadLen 0, got %v", pl) 358 } 359 360 ctrlCh <- struct{}{} 361 }() 362 363 go func() { 364 pl, err := r.ReadZeroCopy(rrbuf, rri.Headroom.Front, rrmin) 365 if err != nil { 366 t.Error(err) 367 } 368 if pl != 5 { 369 t.Errorf("Expected payloadLen 5, got %d", pl) 370 } 371 p := rrbuf[rri.Headroom.Front : rri.Headroom.Front+pl] 372 if !bytes.Equal(p, hello) { 373 t.Errorf("Expected payload %v, got %v", hello, p) 374 } 375 376 pl, err = r.ReadZeroCopy(rrbuf, rri.Headroom.Front, rrmin) 377 if err != nil { 378 t.Error(err) 379 } 380 if pl != 5 { 381 t.Errorf("Expected payloadLen 5, got %d", pl) 382 } 383 p = rrbuf[rri.Headroom.Front : rri.Headroom.Front+pl] 384 if !bytes.Equal(p, world) { 385 t.Errorf("Expected payload %v, got %v", world, p) 386 } 387 388 pl, err = r.ReadZeroCopy(rrbuf, rri.Headroom.Front, rrmin) 389 if err != io.EOF { 390 t.Errorf("Expected io.EOF, got %v", err) 391 } 392 if pl != 0 { 393 t.Errorf("Expected payloadLen 0, got %v", pl) 394 } 395 396 ctrlCh <- struct{}{} 397 }() 398 399 // Write from left to right. 400 n := copy(lwbuf[lwi.Headroom.Front:], hello) 401 written, err := l.WriteZeroCopy(lwbuf, lwi.Headroom.Front, n) 402 if err != nil { 403 t.Error(err) 404 } 405 if written != n { 406 t.Errorf("Expected bytes written: %d, got %d", n, written) 407 } 408 409 n = copy(lwbuf[lwi.Headroom.Front:], world) 410 written, err = l.WriteZeroCopy(lwbuf, lwi.Headroom.Front, n) 411 if err != nil { 412 t.Error(err) 413 } 414 if written != n { 415 t.Errorf("Expected bytes written: %d, got %d", n, written) 416 } 417 418 err = l.CloseWrite() 419 if err != nil { 420 t.Error(err) 421 } 422 423 // Write from right to left. 424 n = copy(rwbuf[rwi.Headroom.Front:], world) 425 written, err = r.WriteZeroCopy(rwbuf, rwi.Headroom.Front, n) 426 if err != nil { 427 t.Error(err) 428 } 429 if written != n { 430 t.Errorf("Expected bytes written: %d, got %d", n, written) 431 } 432 433 n = copy(rwbuf[rwi.Headroom.Front:], hello) 434 written, err = r.WriteZeroCopy(rwbuf, rwi.Headroom.Front, n) 435 if err != nil { 436 t.Error(err) 437 } 438 if written != n { 439 t.Errorf("Expected bytes written: %d, got %d", n, written) 440 } 441 442 err = r.CloseWrite() 443 if err != nil { 444 t.Error(err) 445 } 446 447 <-ctrlCh 448 <-ctrlCh 449 } 450 451 // CopyReadWriter wraps a ReadWriter and provides the io.ReadWriter Read and Write methods 452 // by copying from and to internal buffers and using the zerocopy methods on them. 453 // 454 // The io.ReaderFrom ReadFrom method is implemented using the internal write buffer without copying. 455 type CopyReadWriter struct { 456 ReadWriter 457 458 readHeadroom Headroom 459 writeHeadroom Headroom 460 461 readBuf []byte 462 readBufStart int 463 readBufLength int 464 465 writeBuf []byte 466 } 467 468 func NewCopyReadWriter(rw ReadWriter) *CopyReadWriter { 469 ri := rw.ReaderInfo() 470 wi := rw.WriterInfo() 471 472 readBufSize := ri.MinPayloadBufferSizePerRead 473 if readBufSize == 0 { 474 readBufSize = defaultBufferSize 475 } 476 477 writeBufSize := wi.MaxPayloadSizePerWrite 478 if writeBufSize == 0 { 479 writeBufSize = defaultBufferSize 480 } 481 482 return &CopyReadWriter{ 483 ReadWriter: rw, 484 readHeadroom: ri.Headroom, 485 writeHeadroom: wi.Headroom, 486 readBuf: make([]byte, ri.Headroom.Front+readBufSize+ri.Headroom.Front), 487 writeBuf: make([]byte, wi.Headroom.Front+writeBufSize+wi.Headroom.Rear), 488 } 489 } 490 491 // Read implements the io.Reader Read method. 492 func (rw *CopyReadWriter) Read(b []byte) (n int, err error) { 493 if rw.readBufLength == 0 { 494 rw.readBufStart = rw.readHeadroom.Front 495 rw.readBufLength = len(rw.readBuf) - rw.readHeadroom.Front - rw.readHeadroom.Rear 496 rw.readBufLength, err = rw.ReadWriter.ReadZeroCopy(rw.readBuf, rw.readBufStart, rw.readBufLength) 497 if err != nil { 498 return 499 } 500 } 501 502 n = copy(b, rw.readBuf[rw.readBufStart:rw.readBufStart+rw.readBufLength]) 503 rw.readBufStart += n 504 rw.readBufLength -= n 505 return n, nil 506 } 507 508 // Write implements the io.Writer Write method. 509 func (rw *CopyReadWriter) Write(b []byte) (n int, err error) { 510 payloadBuf := rw.writeBuf[rw.writeHeadroom.Front : len(rw.writeBuf)-rw.writeHeadroom.Rear] 511 512 for n < len(b) { 513 payloadLength := copy(payloadBuf, b[n:]) 514 var payloadWritten int 515 payloadWritten, err = rw.ReadWriter.WriteZeroCopy(rw.writeBuf, rw.writeHeadroom.Front, payloadLength) 516 n += payloadWritten 517 if err != nil { 518 return 519 } 520 } 521 522 return 523 } 524 525 // ReadFrom implements the io.ReaderFrom ReadFrom method. 526 func (rw *CopyReadWriter) ReadFrom(r io.Reader) (n int64, err error) { 527 for { 528 nr, err := r.Read(rw.writeBuf[rw.writeHeadroom.Front : len(rw.writeBuf)-rw.writeHeadroom.Rear]) 529 n += int64(nr) 530 switch err { 531 case nil: 532 case io.EOF: 533 return n, nil 534 default: 535 return n, err 536 } 537 538 _, err = rw.ReadWriter.WriteZeroCopy(rw.writeBuf, rw.writeHeadroom.Front, nr) 539 if err != nil { 540 return n, err 541 } 542 } 543 } 544 545 func CopyWriteOnce(w Writer, b []byte) (n int, err error) { 546 wi := w.WriterInfo() 547 writeBufSize := wi.MaxPayloadSizePerWrite 548 if writeBufSize == 0 { 549 writeBufSize = defaultBufferSize 550 } 551 if writeBufSize > len(b) { 552 writeBufSize = len(b) 553 } 554 555 writeBuf := make([]byte, wi.Headroom.Front+writeBufSize+wi.Headroom.Rear) 556 payloadBuf := writeBuf[wi.Headroom.Front : wi.Headroom.Front+writeBufSize] 557 558 for n < len(b) { 559 payloadLength := copy(payloadBuf, b[n:]) 560 var payloadWritten int 561 payloadWritten, err = w.WriteZeroCopy(writeBuf, wi.Headroom.Front, payloadLength) 562 n += payloadWritten 563 if err != nil { 564 return 565 } 566 } 567 568 return 569 }