github.com/sagernet/sing@v0.2.6/common/canceler/packet_timeout.go (about)

     1  package canceler
     2  
     3  import (
     4  	"context"
     5  	"time"
     6  
     7  	"github.com/sagernet/sing/common"
     8  	"github.com/sagernet/sing/common/buf"
     9  	E "github.com/sagernet/sing/common/exceptions"
    10  	M "github.com/sagernet/sing/common/metadata"
    11  	N "github.com/sagernet/sing/common/network"
    12  )
    13  
    14  type TimeoutPacketConn struct {
    15  	N.PacketConn
    16  	timeout time.Duration
    17  	cancel  common.ContextCancelCauseFunc
    18  	active  time.Time
    19  }
    20  
    21  func NewTimeoutPacketConn(ctx context.Context, conn N.PacketConn, timeout time.Duration) (context.Context, PacketConn) {
    22  	ctx, cancel := common.ContextWithCancelCause(ctx)
    23  	return ctx, &TimeoutPacketConn{
    24  		PacketConn: conn,
    25  		timeout:    timeout,
    26  		cancel:     cancel,
    27  	}
    28  }
    29  
    30  func (c *TimeoutPacketConn) ReadPacket(buffer *buf.Buffer) (destination M.Socksaddr, err error) {
    31  	for {
    32  		err = c.PacketConn.SetReadDeadline(time.Now().Add(c.timeout))
    33  		if err != nil {
    34  			return M.Socksaddr{}, err
    35  		}
    36  		destination, err = c.PacketConn.ReadPacket(buffer)
    37  		if err == nil {
    38  			c.active = time.Now()
    39  			return
    40  		} else if E.IsTimeout(err) {
    41  			if time.Since(c.active) > c.timeout {
    42  				c.cancel(err)
    43  				return
    44  			}
    45  		} else {
    46  			return M.Socksaddr{}, err
    47  		}
    48  	}
    49  }
    50  
    51  func (c *TimeoutPacketConn) WritePacket(buffer *buf.Buffer, destination M.Socksaddr) error {
    52  	err := c.PacketConn.WritePacket(buffer, destination)
    53  	if err == nil {
    54  		c.active = time.Now()
    55  	}
    56  	return err
    57  }
    58  
    59  func (c *TimeoutPacketConn) Timeout() time.Duration {
    60  	return c.timeout
    61  }
    62  
    63  func (c *TimeoutPacketConn) SetTimeout(timeout time.Duration) {
    64  	c.timeout = timeout
    65  	c.PacketConn.SetReadDeadline(time.Now())
    66  }
    67  
    68  func (c *TimeoutPacketConn) Close() error {
    69  	return c.PacketConn.Close()
    70  }
    71  
    72  func (c *TimeoutPacketConn) Upstream() any {
    73  	return c.PacketConn
    74  }