github.com/database64128/shadowsocks-go@v1.10.2-0.20240315062903-143a773533f1/zerocopy/stream.go (about) 1 package zerocopy 2 3 import ( 4 "bytes" 5 "context" 6 "io" 7 "sync" 8 ) 9 10 // defaultBufferSize is the default buffer size to use 11 // when neither the reader nor the writer has buffer size requirements. 12 // It's the same default as io.Copy. 13 const defaultBufferSize = 32768 14 15 // ReaderInfo contains information about a reader. 16 type ReaderInfo struct { 17 Headroom Headroom 18 19 // MinPayloadBufferSizePerRead is the minimum size of payload buffer 20 // the ReadZeroCopy method requires for an unbuffered read. 21 // 22 // This is usually required by chunk-based protocols to be able to read 23 // whole chunks without needing internal caching. 24 MinPayloadBufferSizePerRead int 25 } 26 27 // Reader provides a stream interface for reading. 28 type Reader interface { 29 // ReaderInfo returns information about the reader. 30 ReaderInfo() ReaderInfo 31 32 // ReadZeroCopy uses b as buffer space to initiate a read operation. 33 // 34 // b must have at least [ReaderInfo.Headroom.Front] bytes before payloadBufStart 35 // and [ReaderInfo.Headroom.Rear] bytes after payloadBufStart + payloadBufLen. 36 // 37 // payloadBufLen must be at least [ReaderInfo.MinPayloadBufferSizePerRead]. 38 // 39 // The read operation may use the whole space of b. 40 // The actual payload will be confined in [payloadBufStart, payloadBufLen). 41 // 42 // If no error occurs, the returned payload is b[payloadBufStart : payloadBufStart+payloadLen]. 43 ReadZeroCopy(b []byte, payloadBufStart, payloadBufLen int) (payloadLen int, err error) 44 } 45 46 // WriterInfo contains information about a writer. 47 type WriterInfo struct { 48 Headroom Headroom 49 50 // MaxPayloadSizePerWrite is the maximum size of payload 51 // the WriteZeroCopy method can write at a time. 52 // 53 // This is usually required by chunk-based protocols to be able to write 54 // one chunk at a time without needing to break up the payload. 55 // 56 // 0 means no size limit. 57 MaxPayloadSizePerWrite int 58 } 59 60 // Writer provides a stream interface for writing. 61 type Writer interface { 62 // WriterInfo returns information about the writer. 63 WriterInfo() WriterInfo 64 65 // WriteZeroCopy uses b as buffer space to initiate a write operation. 66 // 67 // b must have at least [WriterInfo.Headroom.Front] bytes before payloadBufStart 68 // and [WriterInfo.Headroom.Rear] bytes after payloadBufStart + payloadBufLen. 69 // 70 // payloadLen must not exceed [WriterInfo.MaxPayloadSizePerWrite]. 71 // 72 // The write operation may use the whole space of b. 73 WriteZeroCopy(b []byte, payloadStart, payloadLen int) (payloadWritten int, err error) 74 } 75 76 // DirectReader provides access to the underlying [io.Reader]. 77 type DirectReader interface { 78 // DirectReader returns the underlying reader for direct reads. 79 DirectReader() io.Reader 80 } 81 82 // DirectWriter provides access to the underlying [io.Writer]. 83 type DirectWriter interface { 84 // DirectWriter returns the underlying writer for direct writes. 85 DirectWriter() io.Writer 86 } 87 88 // Relay reads from r and writes to w using zero-copy methods. 89 // It returns the number of bytes transferred, and any error occurred during transfer. 90 func Relay(w Writer, r Reader) (n int64, err error) { 91 // Use direct read/write when possible. 92 if dr, ok := r.(DirectReader); ok { 93 if dw, ok := w.(DirectWriter); ok { 94 r := dr.DirectReader() 95 w := dw.DirectWriter() 96 return io.Copy(w, r) 97 } 98 } 99 100 // Process reader and writer info. 101 ri := r.ReaderInfo() 102 wi := w.WriterInfo() 103 headroom := MaxHeadroom(ri.Headroom, wi.Headroom) 104 105 // Check payload buffer size requirement compatibility. 106 if wi.MaxPayloadSizePerWrite > 0 && ri.MinPayloadBufferSizePerRead > wi.MaxPayloadSizePerWrite { 107 return relayFallback(w, r, headroom.Front, headroom.Rear, ri.MinPayloadBufferSizePerRead, wi.MaxPayloadSizePerWrite) 108 } 109 110 payloadBufSize := ri.MinPayloadBufferSizePerRead 111 if payloadBufSize == 0 { 112 payloadBufSize = wi.MaxPayloadSizePerWrite 113 if payloadBufSize == 0 { 114 payloadBufSize = defaultBufferSize 115 } 116 } 117 118 // Make buffer. 119 b := make([]byte, headroom.Front+payloadBufSize+headroom.Rear) 120 121 // Main relay loop. 122 for { 123 var payloadLen int 124 payloadLen, err = r.ReadZeroCopy(b, headroom.Front, payloadBufSize) 125 if payloadLen == 0 { 126 if err == io.EOF { 127 err = nil 128 } 129 return 130 } 131 132 payloadWritten, werr := w.WriteZeroCopy(b, headroom.Front, payloadLen) 133 n += int64(payloadWritten) 134 if werr != nil { 135 err = werr 136 return 137 } 138 139 if err != nil { 140 if err == io.EOF { 141 err = nil 142 } 143 return 144 } 145 } 146 } 147 148 // relayFallback uses copying to handle situations where the reader requires more payload buffer space than the writer can handle in one write call. 149 func relayFallback(w Writer, r Reader, frontHeadroom, rearHeadroom, readMaxPayloadSize, writeMaxPayloadSize int) (n int64, err error) { 150 br := make([]byte, frontHeadroom+readMaxPayloadSize+rearHeadroom) 151 bw := make([]byte, frontHeadroom+writeMaxPayloadSize+rearHeadroom) 152 153 for { 154 var payloadLen int 155 payloadLen, err = r.ReadZeroCopy(br, frontHeadroom, readMaxPayloadSize) 156 if payloadLen == 0 { 157 if err == io.EOF { 158 err = nil 159 } 160 return 161 } 162 163 // Short-circuit to avoid copying if payload can fit in one write. 164 if payloadLen <= writeMaxPayloadSize { 165 payloadWritten, werr := w.WriteZeroCopy(br, frontHeadroom, payloadLen) 166 n += int64(payloadWritten) 167 if werr != nil { 168 err = werr 169 } 170 if err != nil { 171 return 172 } 173 continue 174 } 175 176 // Loop until all of br[frontHeadroom : frontHeadroom+payloadLen] is written. 177 for i, j := 0, 0; i < payloadLen; i += j { 178 j = copy(bw[frontHeadroom:frontHeadroom+writeMaxPayloadSize], br[frontHeadroom+i:frontHeadroom+payloadLen]) 179 payloadWritten, werr := w.WriteZeroCopy(bw, frontHeadroom, j) 180 n += int64(payloadWritten) 181 if werr != nil { 182 err = werr 183 return 184 } 185 } 186 187 if err != nil { 188 if err == io.EOF { 189 err = nil 190 } 191 return 192 } 193 } 194 } 195 196 // CloseRead provides the CloseRead method. 197 type CloseRead interface { 198 // CloseRead indicates to the underlying reader that no further reads will happen. 199 CloseRead() error 200 } 201 202 // CloseWrite provides the CloseWrite method. 203 type CloseWrite interface { 204 // CloseWrite indicates to the underlying writer that no further writes will happen. 205 CloseWrite() error 206 } 207 208 // ReadWriter provides a stream interface for reading and writing. 209 type ReadWriter interface { 210 Reader 211 Writer 212 CloseRead 213 CloseWrite 214 io.Closer 215 } 216 217 // TwoWayRelay relays data between left and right using zero-copy methods. 218 // It returns the number of bytes sent from left to right, from right to left, 219 // and any error occurred during transfer. 220 func TwoWayRelay(left, right ReadWriter) (nl2r, nr2l int64, err error) { 221 var l2rErr error 222 l2rDone := make(chan struct{}) 223 224 go func() { 225 nl2r, l2rErr = Relay(right, left) 226 right.CloseWrite() 227 close(l2rDone) 228 }() 229 230 nr2l, err = Relay(left, right) 231 left.CloseWrite() 232 <-l2rDone 233 234 if l2rErr != nil { 235 err = l2rErr 236 } 237 return 238 } 239 240 // DirectReadWriteCloser extends io.ReadWriteCloser with CloseRead and CloseWrite. 241 type DirectReadWriteCloser interface { 242 io.ReadWriteCloser 243 CloseRead 244 CloseWrite 245 } 246 247 // DirectTwoWayRelay relays data between left and right using [io.Copy]. 248 // It returns the number of bytes sent from left to right, from right to left, 249 // and any error occurred during transfer. 250 func DirectTwoWayRelay(left, right DirectReadWriteCloser) (nl2r, nr2l int64, err error) { 251 var l2rErr error 252 l2rDone := make(chan struct{}) 253 254 go func() { 255 nl2r, l2rErr = io.Copy(right, left) 256 right.CloseWrite() 257 close(l2rDone) 258 }() 259 260 nr2l, err = io.Copy(left, right) 261 left.CloseWrite() 262 <-l2rDone 263 264 if l2rErr != nil { 265 err = l2rErr 266 } 267 return 268 } 269 270 // DirectReadWriteCloserOpener provides the Open method to open a [DirectReadWriteCloser]. 271 type DirectReadWriteCloserOpener interface { 272 // Open opens a [DirectReadWriteCloser] with the specified initial payload. 273 Open(ctx context.Context, b []byte) (DirectReadWriteCloser, error) 274 } 275 276 // SimpleDirectReadWriteCloserOpener wraps a [DirectReadWriteCloser] for the Open method to return. 277 type SimpleDirectReadWriteCloserOpener struct { 278 DirectReadWriteCloser 279 } 280 281 // Open implements the DirectReadWriteCloserOpener Open method. 282 func (o *SimpleDirectReadWriteCloserOpener) Open(ctx context.Context, b []byte) (DirectReadWriteCloser, error) { 283 _, err := o.DirectReadWriteCloser.Write(b) 284 return o.DirectReadWriteCloser, err 285 } 286 287 // ReadWriterTestFunc tests the left and right ReadWriters by performing 2 writes 288 // on each ReadWriter and validating the read results. 289 // 290 // The left and right ReadWriters must be connected with a duplex pipe. 291 func ReadWriterTestFunc(t tester, l, r ReadWriter) { 292 defer r.Close() 293 defer l.Close() 294 295 var ( 296 hello = []byte{'h', 'e', 'l', 'l', 'o'} 297 world = []byte{'w', 'o', 'r', 'l', 'd'} 298 ) 299 300 lri := l.ReaderInfo() 301 lwi := l.WriterInfo() 302 lwmax := lwi.MaxPayloadSizePerWrite 303 if lwmax == 0 { 304 lwmax = 5 305 } 306 lrmin := lri.MinPayloadBufferSizePerRead 307 if lrmin == 0 { 308 lrmin = 5 309 } 310 lwbuf := make([]byte, lwi.Headroom.Front+lwmax+lwi.Headroom.Rear) 311 lrbuf := make([]byte, lri.Headroom.Front+lrmin+lri.Headroom.Rear) 312 313 rri := r.ReaderInfo() 314 rwi := r.WriterInfo() 315 rwmax := rwi.MaxPayloadSizePerWrite 316 if rwmax == 0 { 317 rwmax = 5 318 } 319 rrmin := rri.MinPayloadBufferSizePerRead 320 if rrmin == 0 { 321 rrmin = 5 322 } 323 rwbuf := make([]byte, rwi.Headroom.Front+rwmax+rwi.Headroom.Rear) 324 rrbuf := make([]byte, rri.Headroom.Front+rrmin+rri.Headroom.Rear) 325 326 var wg sync.WaitGroup 327 wg.Add(2) 328 329 // Start read goroutines. 330 go func() { 331 defer wg.Done() 332 333 pl, err := l.ReadZeroCopy(lrbuf, lri.Headroom.Front, lrmin) 334 if err != nil { 335 t.Error(err) 336 } 337 if pl != 5 { 338 t.Errorf("Expected payloadLen 5, got %d", pl) 339 } 340 p := lrbuf[lri.Headroom.Front : lri.Headroom.Front+pl] 341 if !bytes.Equal(p, world) { 342 t.Errorf("Expected payload %v, got %v", world, p) 343 } 344 345 pl, err = l.ReadZeroCopy(lrbuf, lri.Headroom.Front, lrmin) 346 if err != nil { 347 t.Error(err) 348 } 349 if pl != 5 { 350 t.Errorf("Expected payloadLen 5, got %d", pl) 351 } 352 p = lrbuf[lri.Headroom.Front : lri.Headroom.Front+pl] 353 if !bytes.Equal(p, hello) { 354 t.Errorf("Expected payload %v, got %v", hello, p) 355 } 356 357 pl, err = l.ReadZeroCopy(lrbuf, lri.Headroom.Front, lrmin) 358 if err != io.EOF { 359 t.Errorf("Expected io.EOF, got %v", err) 360 } 361 if pl != 0 { 362 t.Errorf("Expected payloadLen 0, got %v", pl) 363 } 364 }() 365 366 go func() { 367 defer wg.Done() 368 369 pl, err := r.ReadZeroCopy(rrbuf, rri.Headroom.Front, rrmin) 370 if err != nil { 371 t.Error(err) 372 } 373 if pl != 5 { 374 t.Errorf("Expected payloadLen 5, got %d", pl) 375 } 376 p := rrbuf[rri.Headroom.Front : rri.Headroom.Front+pl] 377 if !bytes.Equal(p, hello) { 378 t.Errorf("Expected payload %v, got %v", hello, p) 379 } 380 381 pl, err = r.ReadZeroCopy(rrbuf, rri.Headroom.Front, rrmin) 382 if err != nil { 383 t.Error(err) 384 } 385 if pl != 5 { 386 t.Errorf("Expected payloadLen 5, got %d", pl) 387 } 388 p = rrbuf[rri.Headroom.Front : rri.Headroom.Front+pl] 389 if !bytes.Equal(p, world) { 390 t.Errorf("Expected payload %v, got %v", world, p) 391 } 392 393 pl, err = r.ReadZeroCopy(rrbuf, rri.Headroom.Front, rrmin) 394 if err != io.EOF { 395 t.Errorf("Expected io.EOF, got %v", err) 396 } 397 if pl != 0 { 398 t.Errorf("Expected payloadLen 0, got %v", pl) 399 } 400 }() 401 402 // Write from left to right. 403 n := copy(lwbuf[lwi.Headroom.Front:], hello) 404 written, err := l.WriteZeroCopy(lwbuf, lwi.Headroom.Front, n) 405 if err != nil { 406 t.Error(err) 407 } 408 if written != n { 409 t.Errorf("Expected bytes written: %d, got %d", n, written) 410 } 411 412 n = copy(lwbuf[lwi.Headroom.Front:], world) 413 written, err = l.WriteZeroCopy(lwbuf, lwi.Headroom.Front, n) 414 if err != nil { 415 t.Error(err) 416 } 417 if written != n { 418 t.Errorf("Expected bytes written: %d, got %d", n, written) 419 } 420 421 err = l.CloseWrite() 422 if err != nil { 423 t.Error(err) 424 } 425 426 // Write from right to left. 427 n = copy(rwbuf[rwi.Headroom.Front:], world) 428 written, err = r.WriteZeroCopy(rwbuf, rwi.Headroom.Front, n) 429 if err != nil { 430 t.Error(err) 431 } 432 if written != n { 433 t.Errorf("Expected bytes written: %d, got %d", n, written) 434 } 435 436 n = copy(rwbuf[rwi.Headroom.Front:], hello) 437 written, err = r.WriteZeroCopy(rwbuf, rwi.Headroom.Front, n) 438 if err != nil { 439 t.Error(err) 440 } 441 if written != n { 442 t.Errorf("Expected bytes written: %d, got %d", n, written) 443 } 444 445 err = r.CloseWrite() 446 if err != nil { 447 t.Error(err) 448 } 449 450 wg.Wait() 451 } 452 453 // CopyReadWriter wraps a ReadWriter and provides the io.ReadWriter Read and Write methods 454 // by copying from and to internal buffers and using the zerocopy methods on them. 455 // 456 // The io.ReaderFrom ReadFrom method is implemented using the internal write buffer without copying. 457 type CopyReadWriter struct { 458 ReadWriter 459 460 readHeadroom Headroom 461 writeHeadroom Headroom 462 463 readBuf []byte 464 readBufStart int 465 readBufLength int 466 467 writeBuf []byte 468 } 469 470 func NewCopyReadWriter(rw ReadWriter) *CopyReadWriter { 471 ri := rw.ReaderInfo() 472 wi := rw.WriterInfo() 473 474 readBufSize := ri.MinPayloadBufferSizePerRead 475 if readBufSize == 0 { 476 readBufSize = defaultBufferSize 477 } 478 479 writeBufSize := wi.MaxPayloadSizePerWrite 480 if writeBufSize == 0 { 481 writeBufSize = defaultBufferSize 482 } 483 484 return &CopyReadWriter{ 485 ReadWriter: rw, 486 readHeadroom: ri.Headroom, 487 writeHeadroom: wi.Headroom, 488 readBuf: make([]byte, ri.Headroom.Front+readBufSize+ri.Headroom.Front), 489 writeBuf: make([]byte, wi.Headroom.Front+writeBufSize+wi.Headroom.Rear), 490 } 491 } 492 493 // Read implements the io.Reader Read method. 494 func (rw *CopyReadWriter) Read(b []byte) (n int, err error) { 495 if rw.readBufLength == 0 { 496 rw.readBufStart = rw.readHeadroom.Front 497 rw.readBufLength = len(rw.readBuf) - rw.readHeadroom.Front - rw.readHeadroom.Rear 498 rw.readBufLength, err = rw.ReadWriter.ReadZeroCopy(rw.readBuf, rw.readBufStart, rw.readBufLength) 499 if err != nil { 500 return 501 } 502 } 503 504 n = copy(b, rw.readBuf[rw.readBufStart:rw.readBufStart+rw.readBufLength]) 505 rw.readBufStart += n 506 rw.readBufLength -= n 507 return n, nil 508 } 509 510 // Write implements the io.Writer Write method. 511 func (rw *CopyReadWriter) Write(b []byte) (n int, err error) { 512 payloadBuf := rw.writeBuf[rw.writeHeadroom.Front : len(rw.writeBuf)-rw.writeHeadroom.Rear] 513 514 for n < len(b) { 515 payloadLength := copy(payloadBuf, b[n:]) 516 var payloadWritten int 517 payloadWritten, err = rw.ReadWriter.WriteZeroCopy(rw.writeBuf, rw.writeHeadroom.Front, payloadLength) 518 n += payloadWritten 519 if err != nil { 520 return 521 } 522 } 523 524 return 525 } 526 527 // ReadFrom implements the io.ReaderFrom ReadFrom method. 528 func (rw *CopyReadWriter) ReadFrom(r io.Reader) (n int64, err error) { 529 for { 530 nr, err := r.Read(rw.writeBuf[rw.writeHeadroom.Front : len(rw.writeBuf)-rw.writeHeadroom.Rear]) 531 n += int64(nr) 532 switch err { 533 case nil: 534 case io.EOF: 535 return n, nil 536 default: 537 return n, err 538 } 539 540 _, err = rw.ReadWriter.WriteZeroCopy(rw.writeBuf, rw.writeHeadroom.Front, nr) 541 if err != nil { 542 return n, err 543 } 544 } 545 } 546 547 func CopyWriteOnce(w Writer, b []byte) (n int, err error) { 548 wi := w.WriterInfo() 549 writeBufSize := wi.MaxPayloadSizePerWrite 550 if writeBufSize == 0 { 551 writeBufSize = defaultBufferSize 552 } 553 if writeBufSize > len(b) { 554 writeBufSize = len(b) 555 } 556 557 writeBuf := make([]byte, wi.Headroom.Front+writeBufSize+wi.Headroom.Rear) 558 payloadBuf := writeBuf[wi.Headroom.Front : wi.Headroom.Front+writeBufSize] 559 560 for n < len(b) { 561 payloadLength := copy(payloadBuf, b[n:]) 562 var payloadWritten int 563 payloadWritten, err = w.WriteZeroCopy(writeBuf, wi.Headroom.Front, payloadLength) 564 n += payloadWritten 565 if err != nil { 566 return 567 } 568 } 569 570 return 571 }