github.com/sagernet/sing@v0.2.6/common/bufio/deadline/packet_reader.go (about)

     1  package deadline
     2  
     3  import (
     4  	"net"
     5  	"os"
     6  	"time"
     7  
     8  	"github.com/sagernet/sing/common/atomic"
     9  	"github.com/sagernet/sing/common/buf"
    10  	M "github.com/sagernet/sing/common/metadata"
    11  	N "github.com/sagernet/sing/common/network"
    12  )
    13  
    14  type TimeoutPacketReader interface {
    15  	N.NetPacketReader
    16  	SetReadDeadline(t time.Time) error
    17  }
    18  
    19  type PacketReader interface {
    20  	TimeoutPacketReader
    21  	N.WithUpstreamReader
    22  	N.ReaderWithUpstream
    23  }
    24  
    25  type packetReader struct {
    26  	TimeoutPacketReader
    27  	deadline     atomic.TypedValue[time.Time]
    28  	pipeDeadline pipeDeadline
    29  	result       chan *packetReadResult
    30  	done         chan struct{}
    31  }
    32  
    33  type packetReadResult struct {
    34  	buffer      *buf.Buffer
    35  	destination M.Socksaddr
    36  	err         error
    37  }
    38  
    39  func NewPacketReader(timeoutReader TimeoutPacketReader) PacketReader {
    40  	return &packetReader{
    41  		TimeoutPacketReader: timeoutReader,
    42  		pipeDeadline:        makePipeDeadline(),
    43  		result:              make(chan *packetReadResult, 1),
    44  		done:                makeFilledChan(),
    45  	}
    46  }
    47  
    48  func (r *packetReader) ReadFrom(p []byte) (n int, addr net.Addr, err error) {
    49  	select {
    50  	case result := <-r.result:
    51  		return r.pipeReturnFrom(result, p)
    52  	default:
    53  	}
    54  	select {
    55  	case <-r.done:
    56  		go r.pipeReadFrom(len(p))
    57  	default:
    58  	}
    59  	return r.readFrom(p)
    60  }
    61  
    62  func (r *packetReader) readFrom(p []byte) (n int, addr net.Addr, err error) {
    63  	select {
    64  	case result := <-r.result:
    65  		return r.pipeReturnFrom(result, p)
    66  	case <-r.pipeDeadline.wait():
    67  		return 0, nil, os.ErrDeadlineExceeded
    68  	}
    69  }
    70  
    71  func (r *packetReader) pipeReadFrom(pLen int) {
    72  	buffer := buf.NewSize(pLen)
    73  	n, addr, err := r.TimeoutPacketReader.ReadFrom(buffer.FreeBytes())
    74  	buffer.Truncate(n)
    75  	r.result <- &packetReadResult{
    76  		buffer:      buffer,
    77  		destination: M.SocksaddrFromNet(addr),
    78  		err:         err,
    79  	}
    80  	r.done <- struct{}{}
    81  }
    82  
    83  func (r *packetReader) pipeReturnFrom(result *packetReadResult, p []byte) (n int, addr net.Addr, err error) {
    84  	n = copy(p, result.buffer.Bytes())
    85  	if result.destination.IsValid() {
    86  		if result.destination.IsFqdn() {
    87  			addr = result.destination
    88  		} else {
    89  			addr = result.destination.UDPAddr()
    90  		}
    91  	}
    92  	result.buffer.Advance(n)
    93  	if result.buffer.IsEmpty() {
    94  		result.buffer.Release()
    95  		err = result.err
    96  	} else {
    97  		r.result <- result
    98  	}
    99  	return
   100  }
   101  
   102  func (r *packetReader) ReadPacket(buffer *buf.Buffer) (destination M.Socksaddr, err error) {
   103  	select {
   104  	case result := <-r.result:
   105  		return r.pipeReturnFromBuffer(result, buffer)
   106  	default:
   107  	}
   108  	select {
   109  	case <-r.done:
   110  		go r.pipeReadFromBuffer(buffer.FreeLen())
   111  	default:
   112  	}
   113  	return r.readPacket(buffer)
   114  }
   115  
   116  func (r *packetReader) readPacket(buffer *buf.Buffer) (destination M.Socksaddr, err error) {
   117  	select {
   118  	case result := <-r.result:
   119  		return r.pipeReturnFromBuffer(result, buffer)
   120  	case <-r.pipeDeadline.wait():
   121  		return M.Socksaddr{}, os.ErrDeadlineExceeded
   122  	}
   123  }
   124  
   125  func (r *packetReader) pipeReturnFromBuffer(result *packetReadResult, buffer *buf.Buffer) (M.Socksaddr, error) {
   126  	n, _ := buffer.Write(result.buffer.Bytes())
   127  	result.buffer.Advance(n)
   128  	if !result.buffer.IsEmpty() {
   129  		r.result <- result
   130  		return result.destination, nil
   131  	} else {
   132  		result.buffer.Release()
   133  		return result.destination, result.err
   134  	}
   135  }
   136  
   137  func (r *packetReader) pipeReadFromBuffer(pLen int) {
   138  	buffer := buf.NewSize(pLen)
   139  	destination, err := r.TimeoutPacketReader.ReadPacket(buffer)
   140  	r.result <- &packetReadResult{
   141  		buffer:      buffer,
   142  		destination: destination,
   143  		err:         err,
   144  	}
   145  	r.done <- struct{}{}
   146  }
   147  
   148  func (r *packetReader) SetReadDeadline(t time.Time) error {
   149  	r.deadline.Store(t)
   150  	r.pipeDeadline.set(t)
   151  	return nil
   152  }
   153  
   154  func (r *packetReader) ReaderReplaceable() bool {
   155  	select {
   156  	case <-r.done:
   157  		r.done <- struct{}{}
   158  	default:
   159  		return false
   160  	}
   161  	select {
   162  	case result := <-r.result:
   163  		r.result <- result
   164  		return false
   165  	default:
   166  	}
   167  	return r.deadline.Load().IsZero()
   168  }
   169  
   170  func (r *packetReader) UpstreamReader() any {
   171  	return r.TimeoutPacketReader
   172  }