github.com/sagernet/sing@v0.2.6/common/bufio/counter_conn.go (about)

     1  package bufio
     2  
     3  import (
     4  	"io"
     5  	"net"
     6  
     7  	"github.com/sagernet/sing/common"
     8  	"github.com/sagernet/sing/common/atomic"
     9  	"github.com/sagernet/sing/common/buf"
    10  	N "github.com/sagernet/sing/common/network"
    11  )
    12  
    13  func NewInt64CounterConn(conn net.Conn, readCounter []*atomic.Int64, writeCounter []*atomic.Int64) *CounterConn {
    14  	return &CounterConn{
    15  		NewExtendedConn(conn),
    16  		common.Map(readCounter, func(it *atomic.Int64) N.CountFunc {
    17  			return func(n int64) {
    18  				it.Add(n)
    19  			}
    20  		}),
    21  		common.Map(writeCounter, func(it *atomic.Int64) N.CountFunc {
    22  			return func(n int64) {
    23  				it.Add(n)
    24  			}
    25  		}),
    26  	}
    27  }
    28  
    29  func NewCounterConn(conn net.Conn, readCounter []N.CountFunc, writeCounter []N.CountFunc) *CounterConn {
    30  	return &CounterConn{NewExtendedConn(conn), readCounter, writeCounter}
    31  }
    32  
    33  type CounterConn struct {
    34  	N.ExtendedConn
    35  	readCounter  []N.CountFunc
    36  	writeCounter []N.CountFunc
    37  }
    38  
    39  func (c *CounterConn) Read(p []byte) (n int, err error) {
    40  	n, err = c.ExtendedConn.Read(p)
    41  	if n > 0 {
    42  		for _, counter := range c.readCounter {
    43  			counter(int64(n))
    44  		}
    45  	}
    46  	return n, err
    47  }
    48  
    49  func (c *CounterConn) ReadBuffer(buffer *buf.Buffer) error {
    50  	err := c.ExtendedConn.ReadBuffer(buffer)
    51  	if err != nil {
    52  		return err
    53  	}
    54  	if buffer.Len() > 0 {
    55  		for _, counter := range c.readCounter {
    56  			counter(int64(buffer.Len()))
    57  		}
    58  	}
    59  	return nil
    60  }
    61  
    62  func (c *CounterConn) Write(p []byte) (n int, err error) {
    63  	n, err = c.ExtendedConn.Write(p)
    64  	if n > 0 {
    65  		for _, counter := range c.writeCounter {
    66  			counter(int64(n))
    67  		}
    68  	}
    69  	return n, err
    70  }
    71  
    72  func (c *CounterConn) WriteBuffer(buffer *buf.Buffer) error {
    73  	dataLen := int64(buffer.Len())
    74  	err := c.ExtendedConn.WriteBuffer(buffer)
    75  	if err != nil {
    76  		return err
    77  	}
    78  	if dataLen > 0 {
    79  		for _, counter := range c.writeCounter {
    80  			counter(dataLen)
    81  		}
    82  	}
    83  	return nil
    84  }
    85  
    86  func (c *CounterConn) UnwrapReader() (io.Reader, []N.CountFunc) {
    87  	return c.ExtendedConn, c.readCounter
    88  }
    89  
    90  func (c *CounterConn) UnwrapWriter() (io.Writer, []N.CountFunc) {
    91  	return c.ExtendedConn, c.writeCounter
    92  }
    93  
    94  func (c *CounterConn) Upstream() any {
    95  	return c.ExtendedConn
    96  }