github.com/sagernet/sing-box@v1.2.7/common/mux/session.go (about)

     1  package mux
     2  
     3  import (
     4  	"io"
     5  	"net"
     6  
     7  	"github.com/sagernet/sing/common"
     8  	"github.com/sagernet/sing/common/buf"
     9  	"github.com/sagernet/sing/common/bufio"
    10  	N "github.com/sagernet/sing/common/network"
    11  	"github.com/sagernet/smux"
    12  )
    13  
    14  type abstractSession interface {
    15  	Open() (net.Conn, error)
    16  	Accept() (net.Conn, error)
    17  	NumStreams() int
    18  	Close() error
    19  	IsClosed() bool
    20  }
    21  
    22  var _ abstractSession = (*smuxSession)(nil)
    23  
    24  type smuxSession struct {
    25  	*smux.Session
    26  }
    27  
    28  func (s *smuxSession) Open() (net.Conn, error) {
    29  	return s.OpenStream()
    30  }
    31  
    32  func (s *smuxSession) Accept() (net.Conn, error) {
    33  	return s.AcceptStream()
    34  }
    35  
    36  type protocolConn struct {
    37  	net.Conn
    38  	protocol        Protocol
    39  	protocolWritten bool
    40  }
    41  
    42  func (c *protocolConn) Write(p []byte) (n int, err error) {
    43  	if c.protocolWritten {
    44  		return c.Conn.Write(p)
    45  	}
    46  	_buffer := buf.StackNewSize(2 + len(p))
    47  	defer common.KeepAlive(_buffer)
    48  	buffer := common.Dup(_buffer)
    49  	defer buffer.Release()
    50  	EncodeRequest(buffer, Request{
    51  		Protocol: c.protocol,
    52  	})
    53  	common.Must(common.Error(buffer.Write(p)))
    54  	n, err = c.Conn.Write(buffer.Bytes())
    55  	if err == nil {
    56  		n--
    57  	}
    58  	c.protocolWritten = true
    59  	return n, err
    60  }
    61  
    62  func (c *protocolConn) ReadFrom(r io.Reader) (n int64, err error) {
    63  	if !c.protocolWritten {
    64  		return bufio.ReadFrom0(c, r)
    65  	}
    66  	return bufio.Copy(c.Conn, r)
    67  }
    68  
    69  func (c *protocolConn) Upstream() any {
    70  	return c.Conn
    71  }
    72  
    73  type vectorisedProtocolConn struct {
    74  	protocolConn
    75  	N.VectorisedWriter
    76  }
    77  
    78  func (c *vectorisedProtocolConn) WriteVectorised(buffers []*buf.Buffer) error {
    79  	if c.protocolWritten {
    80  		return c.VectorisedWriter.WriteVectorised(buffers)
    81  	}
    82  	c.protocolWritten = true
    83  	_buffer := buf.StackNewSize(2)
    84  	defer common.KeepAlive(_buffer)
    85  	buffer := common.Dup(_buffer)
    86  	defer buffer.Release()
    87  	EncodeRequest(buffer, Request{
    88  		Protocol: c.protocol,
    89  	})
    90  	return c.VectorisedWriter.WriteVectorised(append([]*buf.Buffer{buffer}, buffers...))
    91  }