github.com/sagernet/sing@v0.4.0-beta.19.0.20240518125136-f67a0988a636/common/bufio/counter_packet_conn.go (about)

     1  package bufio
     2  
     3  import (
     4  	"github.com/sagernet/sing/common"
     5  	"github.com/sagernet/sing/common/atomic"
     6  	"github.com/sagernet/sing/common/buf"
     7  	M "github.com/sagernet/sing/common/metadata"
     8  	N "github.com/sagernet/sing/common/network"
     9  )
    10  
    11  type CounterPacketConn struct {
    12  	N.PacketConn
    13  	readCounter  []N.CountFunc
    14  	writeCounter []N.CountFunc
    15  }
    16  
    17  func NewInt64CounterPacketConn(conn N.PacketConn, readCounter []*atomic.Int64, writeCounter []*atomic.Int64) *CounterPacketConn {
    18  	return &CounterPacketConn{
    19  		conn,
    20  		common.Map(readCounter, func(it *atomic.Int64) N.CountFunc {
    21  			return func(n int64) {
    22  				it.Add(n)
    23  			}
    24  		}),
    25  		common.Map(writeCounter, func(it *atomic.Int64) N.CountFunc {
    26  			return func(n int64) {
    27  				it.Add(n)
    28  			}
    29  		}),
    30  	}
    31  }
    32  
    33  func NewCounterPacketConn(conn N.PacketConn, readCounter []N.CountFunc, writeCounter []N.CountFunc) *CounterPacketConn {
    34  	return &CounterPacketConn{conn, readCounter, writeCounter}
    35  }
    36  
    37  func (c *CounterPacketConn) ReadPacket(buffer *buf.Buffer) (destination M.Socksaddr, err error) {
    38  	destination, err = c.PacketConn.ReadPacket(buffer)
    39  	if err == nil {
    40  		if buffer.Len() > 0 {
    41  			for _, counter := range c.readCounter {
    42  				counter(int64(buffer.Len()))
    43  			}
    44  		}
    45  	}
    46  	return
    47  }
    48  
    49  func (c *CounterPacketConn) WritePacket(buffer *buf.Buffer, destination M.Socksaddr) error {
    50  	dataLen := int64(buffer.Len())
    51  	err := c.PacketConn.WritePacket(buffer, destination)
    52  	if err != nil {
    53  		return err
    54  	}
    55  	if dataLen > 0 {
    56  		for _, counter := range c.writeCounter {
    57  			counter(dataLen)
    58  		}
    59  	}
    60  	return nil
    61  }
    62  
    63  func (c *CounterPacketConn) UnwrapPacketReader() (N.PacketReader, []N.CountFunc) {
    64  	return c.PacketConn, c.readCounter
    65  }
    66  
    67  func (c *CounterPacketConn) UnwrapPacketWriter() (N.PacketWriter, []N.CountFunc) {
    68  	return c.PacketConn, c.writeCounter
    69  }
    70  
    71  func (c *CounterPacketConn) Upstream() any {
    72  	return c.PacketConn
    73  }