github.com/sagernet/sing@v0.4.0-beta.19.0.20240518125136-f67a0988a636/common/canceler/packet_timeout.go (about)

     1  package canceler
     2  
     3  import (
     4  	"context"
     5  	"net"
     6  	"time"
     7  
     8  	"github.com/sagernet/sing/common"
     9  	"github.com/sagernet/sing/common/buf"
    10  	E "github.com/sagernet/sing/common/exceptions"
    11  	M "github.com/sagernet/sing/common/metadata"
    12  	N "github.com/sagernet/sing/common/network"
    13  )
    14  
// TimeoutPacketConn wraps an N.PacketConn and records when a packet was last
// successfully read or written. ReadPacket uses that timestamp to decide
// whether a read timeout means the connection has been idle for longer than
// timeout, in which case it invokes cancel.
type TimeoutPacketConn struct {
	N.PacketConn
	// timeout is both the read-deadline interval armed before every read and
	// the idle threshold compared against the time since active.
	timeout time.Duration
	// cancel is called with the timeout error by ReadPacket on idle expiry
	// and with net.ErrClosed by Close.
	cancel common.ContextCancelCauseFunc
	// active is the time of the last successful ReadPacket or WritePacket.
	// The constructor does not set it, so it is the zero time until then.
	active time.Time
}
    21  
    22  func NewTimeoutPacketConn(ctx context.Context, conn N.PacketConn, timeout time.Duration) (context.Context, PacketConn) {
    23  	ctx, cancel := common.ContextWithCancelCause(ctx)
    24  	return ctx, &TimeoutPacketConn{
    25  		PacketConn: conn,
    26  		timeout:    timeout,
    27  		cancel:     cancel,
    28  	}
    29  }
    30  
    31  func (c *TimeoutPacketConn) ReadPacket(buffer *buf.Buffer) (destination M.Socksaddr, err error) {
    32  	for {
    33  		err = c.PacketConn.SetReadDeadline(time.Now().Add(c.timeout))
    34  		if err != nil {
    35  			return
    36  		}
    37  		destination, err = c.PacketConn.ReadPacket(buffer)
    38  		if err == nil {
    39  			c.active = time.Now()
    40  			return
    41  		} else if E.IsTimeout(err) {
    42  			if time.Since(c.active) > c.timeout {
    43  				c.cancel(err)
    44  				return
    45  			}
    46  		} else {
    47  			return
    48  		}
    49  	}
    50  }
    51  
    52  func (c *TimeoutPacketConn) WritePacket(buffer *buf.Buffer, destination M.Socksaddr) error {
    53  	err := c.PacketConn.WritePacket(buffer, destination)
    54  	if err == nil {
    55  		c.active = time.Now()
    56  	}
    57  	return err
    58  }
    59  
// Timeout returns the currently configured timeout duration.
func (c *TimeoutPacketConn) Timeout() time.Duration {
	return c.timeout
}
    63  
    64  func (c *TimeoutPacketConn) SetTimeout(timeout time.Duration) {
    65  	c.timeout = timeout
    66  	c.PacketConn.SetReadDeadline(time.Now())
    67  }
    68  
    69  func (c *TimeoutPacketConn) Close() error {
    70  	c.cancel(net.ErrClosed)
    71  	return c.PacketConn.Close()
    72  }
    73  
// Upstream returns the wrapped N.PacketConn.
func (c *TimeoutPacketConn) Upstream() any {
	return c.PacketConn
}