github.com/sagernet/sing@v0.4.0-beta.19.0.20240518125136-f67a0988a636/common/bufio/deadline/packet_reader_fallback.go (about)

     1  package deadline
     2  
     3  import (
     4  	"net"
     5  	"os"
     6  	"time"
     7  
     8  	"github.com/sagernet/sing/common/atomic"
     9  	"github.com/sagernet/sing/common/buf"
    10  	M "github.com/sagernet/sing/common/metadata"
    11  )
    12  
// fallbackPacketReader wraps a packetReader.
//
// Reads that start while no deadline is set bypass the pipe and call the
// wrapped TimeoutPacketReader directly. If SetReadDeadline is called while such
// a direct read is in progress, the pipe is disabled and deadlines are applied
// to the wrapped reader from then on.
type fallbackPacketReader struct {
	*packetReader
	// disablePipe, once true, makes reads and deadline updates bypass the pipe
	// and go straight to the wrapped TimeoutPacketReader. Nothing in this file
	// resets it.
	disablePipe atomic.Bool
	// inRead is true while a direct (pipe-less) read is in progress.
	inRead      atomic.Bool
}
    18  
    19  func NewFallbackPacketReader(timeoutReader TimeoutPacketReader) PacketReader {
    20  	return &fallbackPacketReader{packetReader: NewPacketReader(timeoutReader).(*packetReader)}
    21  }
    22  
    23  func (r *fallbackPacketReader) ReadFrom(p []byte) (n int, addr net.Addr, err error) {
    24  	select {
    25  	case result := <-r.result:
    26  		return r.pipeReturnFrom(result, p)
    27  	default:
    28  	}
    29  	select {
    30  	case result := <-r.result:
    31  		return r.pipeReturnFrom(result, p)
    32  	case <-r.pipeDeadline.wait():
    33  		return 0, nil, os.ErrDeadlineExceeded
    34  	case <-r.done:
    35  		if r.disablePipe.Load() {
    36  			return r.TimeoutPacketReader.ReadFrom(p)
    37  		} else if r.deadline.Load().IsZero() {
    38  			r.done <- struct{}{}
    39  			r.inRead.Store(true)
    40  			defer r.inRead.Store(false)
    41  			n, addr, err = r.TimeoutPacketReader.ReadFrom(p)
    42  			return
    43  		}
    44  		go r.pipeReadFrom(len(p))
    45  	}
    46  	select {
    47  	case result := <-r.result:
    48  		return r.pipeReturnFrom(result, p)
    49  	case <-r.pipeDeadline.wait():
    50  		return 0, nil, os.ErrDeadlineExceeded
    51  	}
    52  }
    53  
    54  func (r *fallbackPacketReader) ReadPacket(buffer *buf.Buffer) (destination M.Socksaddr, err error) {
    55  	select {
    56  	case result := <-r.result:
    57  		return r.pipeReturnFromBuffer(result, buffer)
    58  	default:
    59  	}
    60  	select {
    61  	case result := <-r.result:
    62  		return r.pipeReturnFromBuffer(result, buffer)
    63  	case <-r.pipeDeadline.wait():
    64  		return M.Socksaddr{}, os.ErrDeadlineExceeded
    65  	case <-r.done:
    66  		if r.disablePipe.Load() {
    67  			return r.TimeoutPacketReader.ReadPacket(buffer)
    68  		} else if r.deadline.Load().IsZero() {
    69  			r.done <- struct{}{}
    70  			r.inRead.Store(true)
    71  			defer r.inRead.Store(false)
    72  			destination, err = r.TimeoutPacketReader.ReadPacket(buffer)
    73  			return
    74  		}
    75  		go r.pipeReadFrom(buffer.FreeLen())
    76  	}
    77  	select {
    78  	case result := <-r.result:
    79  		return r.pipeReturnFromBuffer(result, buffer)
    80  	case <-r.pipeDeadline.wait():
    81  		return M.Socksaddr{}, os.ErrDeadlineExceeded
    82  	}
    83  }
    84  
    85  func (r *fallbackPacketReader) SetReadDeadline(t time.Time) error {
    86  	if r.disablePipe.Load() {
    87  		return r.TimeoutPacketReader.SetReadDeadline(t)
    88  	} else if r.inRead.Load() {
    89  		r.disablePipe.Store(true)
    90  		return r.TimeoutPacketReader.SetReadDeadline(t)
    91  	}
    92  	return r.packetReader.SetReadDeadline(t)
    93  }
    94  
    95  func (r *fallbackPacketReader) ReaderReplaceable() bool {
    96  	return r.disablePipe.Load() || r.packetReader.ReaderReplaceable()
    97  }
    98  
    99  func (r *fallbackPacketReader) UpstreamReader() any {
   100  	return r.packetReader.UpstreamReader()
   101  }