github.com/sagernet/sing@v0.4.0-beta.19.0.20240518125136-f67a0988a636/common/bufio/deadline/packet_reader.go (about)

     1  package deadline
     2  
     3  import (
     4  	"net"
     5  	"os"
     6  	"time"
     7  
     8  	"github.com/sagernet/sing/common/atomic"
     9  	"github.com/sagernet/sing/common/buf"
    10  	M "github.com/sagernet/sing/common/metadata"
    11  	N "github.com/sagernet/sing/common/network"
    12  )
    13  
    14  type TimeoutPacketReader interface {
    15  	N.NetPacketReader
    16  	SetReadDeadline(t time.Time) error
    17  }
    18  
    19  type PacketReader interface {
    20  	TimeoutPacketReader
    21  	N.WithUpstreamReader
    22  	N.ReaderWithUpstream
    23  }
    24  
    25  type packetReader struct {
    26  	TimeoutPacketReader
    27  	deadline     atomic.TypedValue[time.Time]
    28  	pipeDeadline pipeDeadline
    29  	result       chan *packetReadResult
    30  	done         chan struct{}
    31  }
    32  
    33  type packetReadResult struct {
    34  	buffer      *buf.Buffer
    35  	destination M.Socksaddr
    36  	err         error
    37  }
    38  
    39  func NewPacketReader(timeoutReader TimeoutPacketReader) PacketReader {
    40  	return &packetReader{
    41  		TimeoutPacketReader: timeoutReader,
    42  		pipeDeadline:        makePipeDeadline(),
    43  		result:              make(chan *packetReadResult, 1),
    44  		done:                makeFilledChan(),
    45  	}
    46  }
    47  
    48  func (r *packetReader) ReadFrom(p []byte) (n int, addr net.Addr, err error) {
    49  	select {
    50  	case result := <-r.result:
    51  		return r.pipeReturnFrom(result, p)
    52  	default:
    53  	}
    54  	select {
    55  	case result := <-r.result:
    56  		return r.pipeReturnFrom(result, p)
    57  	case <-r.pipeDeadline.wait():
    58  		return 0, nil, os.ErrDeadlineExceeded
    59  	case <-r.done:
    60  		go r.pipeReadFrom(len(p))
    61  	}
    62  	select {
    63  	case result := <-r.result:
    64  		return r.pipeReturnFrom(result, p)
    65  	case <-r.pipeDeadline.wait():
    66  		return 0, nil, os.ErrDeadlineExceeded
    67  	}
    68  }
    69  
    70  func (r *packetReader) pipeReadFrom(pLen int) {
    71  	buffer := buf.NewSize(pLen)
    72  	n, addr, err := r.TimeoutPacketReader.ReadFrom(buffer.FreeBytes())
    73  	buffer.Truncate(n)
    74  	r.result <- &packetReadResult{
    75  		buffer:      buffer,
    76  		destination: M.SocksaddrFromNet(addr),
    77  		err:         err,
    78  	}
    79  	r.done <- struct{}{}
    80  }
    81  
    82  func (r *packetReader) pipeReturnFrom(result *packetReadResult, p []byte) (n int, addr net.Addr, err error) {
    83  	n = copy(p, result.buffer.Bytes())
    84  	if result.destination.IsValid() {
    85  		if result.destination.IsFqdn() {
    86  			addr = result.destination
    87  		} else {
    88  			addr = result.destination.UDPAddr()
    89  		}
    90  	}
    91  	result.buffer.Advance(n)
    92  	if result.buffer.IsEmpty() {
    93  		result.buffer.Release()
    94  		err = result.err
    95  	} else {
    96  		r.result <- result
    97  	}
    98  	return
    99  }
   100  
   101  func (r *packetReader) ReadPacket(buffer *buf.Buffer) (destination M.Socksaddr, err error) {
   102  	select {
   103  	case result := <-r.result:
   104  		return r.pipeReturnFromBuffer(result, buffer)
   105  	default:
   106  	}
   107  	select {
   108  	case result := <-r.result:
   109  		return r.pipeReturnFromBuffer(result, buffer)
   110  	case <-r.pipeDeadline.wait():
   111  		return M.Socksaddr{}, os.ErrDeadlineExceeded
   112  	case <-r.done:
   113  		go r.pipeReadFrom(buffer.FreeLen())
   114  	}
   115  	select {
   116  	case result := <-r.result:
   117  		return r.pipeReturnFromBuffer(result, buffer)
   118  	case <-r.pipeDeadline.wait():
   119  		return M.Socksaddr{}, os.ErrDeadlineExceeded
   120  	}
   121  }
   122  
   123  func (r *packetReader) pipeReturnFromBuffer(result *packetReadResult, buffer *buf.Buffer) (M.Socksaddr, error) {
   124  	n, _ := buffer.Write(result.buffer.Bytes())
   125  	result.buffer.Advance(n)
   126  	if !result.buffer.IsEmpty() {
   127  		r.result <- result
   128  		return result.destination, nil
   129  	} else {
   130  		result.buffer.Release()
   131  		return result.destination, result.err
   132  	}
   133  }
   134  
   135  func (r *packetReader) SetReadDeadline(t time.Time) error {
   136  	r.deadline.Store(t)
   137  	r.pipeDeadline.set(t)
   138  	return nil
   139  }
   140  
   141  func (r *packetReader) ReaderReplaceable() bool {
   142  	select {
   143  	case <-r.done:
   144  		r.done <- struct{}{}
   145  	default:
   146  		return false
   147  	}
   148  	select {
   149  	case result := <-r.result:
   150  		r.result <- result
   151  		return false
   152  	default:
   153  	}
   154  	return r.deadline.Load().IsZero()
   155  }
   156  
   157  func (r *packetReader) UpstreamReader() any {
   158  	return r.TimeoutPacketReader
   159  }