github.com/sagernet/sing@v0.4.0-beta.19.0.20240518125136-f67a0988a636/common/canceler/packet.go (about)

     1  package canceler
     2  
     3  import (
     4  	"context"
     5  	"time"
     6  
     7  	"github.com/sagernet/sing/common"
     8  	"github.com/sagernet/sing/common/buf"
     9  	M "github.com/sagernet/sing/common/metadata"
    10  	N "github.com/sagernet/sing/common/network"
    11  )
    12  
// PacketConn is an N.PacketConn whose timeout can be read back and changed
// after construction. NewPacketConn looks for this interface (via common.Cast)
// so that a connection that already carries a timeout is adjusted in place
// instead of being wrapped a second time.
type PacketConn interface {
	N.PacketConn
	// Timeout returns the currently configured timeout.
	Timeout() time.Duration
	// SetTimeout replaces the currently configured timeout.
	SetTimeout(timeout time.Duration)
}
    18  
// TimerPacketConn is the fallback wrapper that NewPacketConn returns when the
// connection rejects SetReadDeadline(time.Time{}). It embeds the original
// connection, so every method not overridden here is promoted from it. It
// pairs the connection with a canceler Instance that is refreshed (Update)
// after each successful ReadPacket or WritePacket.
type TimerPacketConn struct {
	N.PacketConn
	instance *Instance // canceler driving the timeout; also closed by Close
}
    23  
    24  func NewPacketConn(ctx context.Context, conn N.PacketConn, timeout time.Duration) (context.Context, N.PacketConn) {
    25  	if timeoutConn, isTimeoutConn := common.Cast[PacketConn](conn); isTimeoutConn {
    26  		oldTimeout := timeoutConn.Timeout()
    27  		if timeout < oldTimeout {
    28  			timeoutConn.SetTimeout(timeout)
    29  		}
    30  		return ctx, conn
    31  	}
    32  	err := conn.SetReadDeadline(time.Time{})
    33  	if err == nil {
    34  		return NewTimeoutPacketConn(ctx, conn, timeout)
    35  	}
    36  	ctx, cancel := common.ContextWithCancelCause(ctx)
    37  	instance := New(ctx, cancel, timeout)
    38  	return ctx, &TimerPacketConn{conn, instance}
    39  }
    40  
    41  func (c *TimerPacketConn) ReadPacket(buffer *buf.Buffer) (destination M.Socksaddr, err error) {
    42  	destination, err = c.PacketConn.ReadPacket(buffer)
    43  	if err == nil {
    44  		c.instance.Update()
    45  	}
    46  	return
    47  }
    48  
    49  func (c *TimerPacketConn) WritePacket(buffer *buf.Buffer, destination M.Socksaddr) error {
    50  	err := c.PacketConn.WritePacket(buffer, destination)
    51  	if err == nil {
    52  		c.instance.Update()
    53  	}
    54  	return err
    55  }
    56  
// Timeout reports the timeout currently held by the canceler Instance. It
// implements the Timeout method of PacketConn.
func (c *TimerPacketConn) Timeout() time.Duration {
	return c.instance.Timeout()
}
    60  
// SetTimeout forwards timeout to the canceler Instance. It implements the
// SetTimeout method of PacketConn. NewPacketConn calls it only to shorten an
// existing timeout.
func (c *TimerPacketConn) SetTimeout(timeout time.Duration) {
	c.instance.SetTimeout(timeout)
}
    64  
    65  func (c *TimerPacketConn) Close() error {
    66  	return common.Close(
    67  		c.PacketConn,
    68  		c.instance,
    69  	)
    70  }
    71  
// Upstream returns the wrapped N.PacketConn. Presumably this exposes the inner
// connection to layer-unwrapping helpers such as common.Cast; confirm against
// the common package.
func (c *TimerPacketConn) Upstream() any {
	return c.PacketConn
}