github.com/sagernet/sing-mux@v0.2.1-0.20240124034317-9bfb33698bb6/protocol_conn.go (about)

     1  package mux
     2  
     3  import (
     4  	"net"
     5  
     6  	"github.com/sagernet/sing/common/buf"
     7  	"github.com/sagernet/sing/common/bufio"
     8  	N "github.com/sagernet/sing/common/network"
     9  )
    10  
// protocolConn wraps a net.Conn and combines the encoded Request with the
// first payload written through it (see Write). All other net.Conn methods
// come from the embedded connection.
type protocolConn struct {
	net.Conn
	// request is encoded (via EncodeRequest) together with the first write.
	request        Request
	// requestWritten is set once the request has been handed to the
	// underlying connection; it is set even when that write fails, so the
	// request is never sent a second time.
	requestWritten bool
}
    16  
    17  func newProtocolConn(conn net.Conn, request Request) net.Conn {
    18  	writer, isVectorised := bufio.CreateVectorisedWriter(conn)
    19  	if isVectorised {
    20  		return &vectorisedProtocolConn{
    21  			protocolConn{
    22  				Conn:    conn,
    23  				request: request,
    24  			},
    25  			writer,
    26  		}
    27  	} else {
    28  		return &protocolConn{
    29  			Conn:    conn,
    30  			request: request,
    31  		}
    32  	}
    33  }
    34  
    35  func (c *protocolConn) NeedHandshake() bool {
    36  	return !c.requestWritten
    37  }
    38  
    39  func (c *protocolConn) Write(p []byte) (n int, err error) {
    40  	if c.requestWritten {
    41  		return c.Conn.Write(p)
    42  	}
    43  	buffer := EncodeRequest(c.request, p)
    44  	n, err = c.Conn.Write(buffer.Bytes())
    45  	buffer.Release()
    46  	if err == nil {
    47  		n--
    48  	}
    49  	c.requestWritten = true
    50  	return n, err
    51  }
    52  
    53  func (c *protocolConn) Upstream() any {
    54  	return c.Conn
    55  }
    56  
// vectorisedProtocolConn is a protocolConn whose underlying connection
// supports vectorised writes; newProtocolConn returns it when
// bufio.CreateVectorisedWriter succeeds for the wrapped conn.
type vectorisedProtocolConn struct {
	protocolConn
	// writer is the vectorised writer created from the same underlying
	// connection that protocolConn embeds.
	writer N.VectorisedWriter
}
    61  
    62  func (c *vectorisedProtocolConn) WriteVectorised(buffers []*buf.Buffer) error {
    63  	if c.requestWritten {
    64  		return c.writer.WriteVectorised(buffers)
    65  	}
    66  	c.requestWritten = true
    67  	buffer := EncodeRequest(c.request, nil)
    68  	return c.writer.WriteVectorised(append([]*buf.Buffer{buffer}, buffers...))
    69  }