github.com/sagernet/sing-mux@v0.2.1-0.20240124034317-9bfb33698bb6/client_conn.go (about)

     1  package mux
     2  
     3  import (
     4  	"encoding/binary"
     5  	"io"
     6  	"net"
     7  	"sync"
     8  
     9  	"github.com/sagernet/sing/common"
    10  	"github.com/sagernet/sing/common/buf"
    11  	E "github.com/sagernet/sing/common/exceptions"
    12  	M "github.com/sagernet/sing/common/metadata"
    13  	N "github.com/sagernet/sing/common/network"
    14  )
    15  
    16  type clientConn struct {
    17  	net.Conn
    18  	destination    M.Socksaddr
    19  	requestWritten bool
    20  	responseRead   bool
    21  }
    22  
    23  func (c *clientConn) NeedHandshake() bool {
    24  	return !c.requestWritten
    25  }
    26  
    27  func (c *clientConn) readResponse() error {
    28  	response, err := ReadStreamResponse(c.Conn)
    29  	if err != nil {
    30  		return err
    31  	}
    32  	if response.Status == statusError {
    33  		return E.New("remote error: ", response.Message)
    34  	}
    35  	return nil
    36  }
    37  
    38  func (c *clientConn) Read(b []byte) (n int, err error) {
    39  	if !c.responseRead {
    40  		err = c.readResponse()
    41  		if err != nil {
    42  			return
    43  		}
    44  		c.responseRead = true
    45  	}
    46  	return c.Conn.Read(b)
    47  }
    48  
    49  func (c *clientConn) Write(b []byte) (n int, err error) {
    50  	if c.requestWritten {
    51  		return c.Conn.Write(b)
    52  	}
    53  	request := StreamRequest{
    54  		Network:     N.NetworkTCP,
    55  		Destination: c.destination,
    56  	}
    57  	buffer := buf.NewSize(streamRequestLen(request) + len(b))
    58  	defer buffer.Release()
    59  	err = EncodeStreamRequest(request, buffer)
    60  	if err != nil {
    61  		return
    62  	}
    63  	buffer.Write(b)
    64  	_, err = c.Conn.Write(buffer.Bytes())
    65  	if err != nil {
    66  		return
    67  	}
    68  	c.requestWritten = true
    69  	return len(b), nil
    70  }
    71  
    72  func (c *clientConn) LocalAddr() net.Addr {
    73  	return c.Conn.LocalAddr()
    74  }
    75  
    76  func (c *clientConn) RemoteAddr() net.Addr {
    77  	return c.destination.TCPAddr()
    78  }
    79  
    80  func (c *clientConn) ReaderReplaceable() bool {
    81  	return c.responseRead
    82  }
    83  
    84  func (c *clientConn) WriterReplaceable() bool {
    85  	return c.requestWritten
    86  }
    87  
    88  func (c *clientConn) NeedAdditionalReadDeadline() bool {
    89  	return true
    90  }
    91  
    92  func (c *clientConn) Upstream() any {
    93  	return c.Conn
    94  }
    95  
    96  var _ N.NetPacketConn = (*clientPacketConn)(nil)
    97  
    98  type clientPacketConn struct {
    99  	N.AbstractConn
   100  	conn            N.ExtendedConn
   101  	access          sync.Mutex
   102  	destination     M.Socksaddr
   103  	requestWritten  bool
   104  	responseRead    bool
   105  	readWaitOptions N.ReadWaitOptions
   106  }
   107  
   108  func (c *clientPacketConn) NeedHandshake() bool {
   109  	return !c.requestWritten
   110  }
   111  
   112  func (c *clientPacketConn) readResponse() error {
   113  	response, err := ReadStreamResponse(c.conn)
   114  	if err != nil {
   115  		return err
   116  	}
   117  	if response.Status == statusError {
   118  		return E.New("remote error: ", response.Message)
   119  	}
   120  	return nil
   121  }
   122  
   123  func (c *clientPacketConn) Read(b []byte) (n int, err error) {
   124  	if !c.responseRead {
   125  		err = c.readResponse()
   126  		if err != nil {
   127  			return
   128  		}
   129  		c.responseRead = true
   130  	}
   131  	var length uint16
   132  	err = binary.Read(c.conn, binary.BigEndian, &length)
   133  	if err != nil {
   134  		return
   135  	}
   136  	if cap(b) < int(length) {
   137  		return 0, io.ErrShortBuffer
   138  	}
   139  	return io.ReadFull(c.conn, b[:length])
   140  }
   141  
   142  func (c *clientPacketConn) writeRequest(payload []byte) (n int, err error) {
   143  	request := StreamRequest{
   144  		Network:     N.NetworkUDP,
   145  		Destination: c.destination,
   146  	}
   147  	rLen := streamRequestLen(request)
   148  	if len(payload) > 0 {
   149  		rLen += 2 + len(payload)
   150  	}
   151  	buffer := buf.NewSize(rLen)
   152  	defer buffer.Release()
   153  	err = EncodeStreamRequest(request, buffer)
   154  	if err != nil {
   155  		return
   156  	}
   157  	if len(payload) > 0 {
   158  		common.Must(
   159  			binary.Write(buffer, binary.BigEndian, uint16(len(payload))),
   160  			common.Error(buffer.Write(payload)),
   161  		)
   162  	}
   163  	_, err = c.conn.Write(buffer.Bytes())
   164  	if err != nil {
   165  		return
   166  	}
   167  	c.requestWritten = true
   168  	return len(payload), nil
   169  }
   170  
   171  func (c *clientPacketConn) Write(b []byte) (n int, err error) {
   172  	if !c.requestWritten {
   173  		c.access.Lock()
   174  		if c.requestWritten {
   175  			c.access.Unlock()
   176  		} else {
   177  			defer c.access.Unlock()
   178  			return c.writeRequest(b)
   179  		}
   180  	}
   181  	err = binary.Write(c.conn, binary.BigEndian, uint16(len(b)))
   182  	if err != nil {
   183  		return
   184  	}
   185  	return c.conn.Write(b)
   186  }
   187  
   188  func (c *clientPacketConn) ReadBuffer(buffer *buf.Buffer) (err error) {
   189  	if !c.responseRead {
   190  		err = c.readResponse()
   191  		if err != nil {
   192  			return
   193  		}
   194  		c.responseRead = true
   195  	}
   196  	var length uint16
   197  	err = binary.Read(c.conn, binary.BigEndian, &length)
   198  	if err != nil {
   199  		return
   200  	}
   201  	_, err = buffer.ReadFullFrom(c.conn, int(length))
   202  	return
   203  }
   204  
   205  func (c *clientPacketConn) WriteBuffer(buffer *buf.Buffer) error {
   206  	if !c.requestWritten {
   207  		c.access.Lock()
   208  		if c.requestWritten {
   209  			c.access.Unlock()
   210  		} else {
   211  			defer c.access.Unlock()
   212  			defer buffer.Release()
   213  			return common.Error(c.writeRequest(buffer.Bytes()))
   214  		}
   215  	}
   216  	bLen := buffer.Len()
   217  	binary.BigEndian.PutUint16(buffer.ExtendHeader(2), uint16(bLen))
   218  	return c.conn.WriteBuffer(buffer)
   219  }
   220  
   221  func (c *clientPacketConn) FrontHeadroom() int {
   222  	return 2
   223  }
   224  
   225  func (c *clientPacketConn) ReadFrom(p []byte) (n int, addr net.Addr, err error) {
   226  	if !c.responseRead {
   227  		err = c.readResponse()
   228  		if err != nil {
   229  			return
   230  		}
   231  		c.responseRead = true
   232  	}
   233  	var length uint16
   234  	err = binary.Read(c.conn, binary.BigEndian, &length)
   235  	if err != nil {
   236  		return
   237  	}
   238  	if cap(p) < int(length) {
   239  		return 0, nil, io.ErrShortBuffer
   240  	}
   241  	n, err = io.ReadFull(c.conn, p[:length])
   242  	return
   243  }
   244  
   245  func (c *clientPacketConn) WriteTo(p []byte, addr net.Addr) (n int, err error) {
   246  	if !c.requestWritten {
   247  		c.access.Lock()
   248  		if c.requestWritten {
   249  			c.access.Unlock()
   250  		} else {
   251  			defer c.access.Unlock()
   252  			return c.writeRequest(p)
   253  		}
   254  	}
   255  	err = binary.Write(c.conn, binary.BigEndian, uint16(len(p)))
   256  	if err != nil {
   257  		return
   258  	}
   259  	return c.conn.Write(p)
   260  }
   261  
   262  func (c *clientPacketConn) ReadPacket(buffer *buf.Buffer) (destination M.Socksaddr, err error) {
   263  	err = c.ReadBuffer(buffer)
   264  	return
   265  }
   266  
   267  func (c *clientPacketConn) WritePacket(buffer *buf.Buffer, destination M.Socksaddr) error {
   268  	return c.WriteBuffer(buffer)
   269  }
   270  
   271  func (c *clientPacketConn) LocalAddr() net.Addr {
   272  	return c.conn.LocalAddr()
   273  }
   274  
   275  func (c *clientPacketConn) RemoteAddr() net.Addr {
   276  	return c.destination.UDPAddr()
   277  }
   278  
   279  func (c *clientPacketConn) NeedAdditionalReadDeadline() bool {
   280  	return true
   281  }
   282  
   283  func (c *clientPacketConn) Upstream() any {
   284  	return c.conn
   285  }
   286  
   287  var _ N.NetPacketConn = (*clientPacketAddrConn)(nil)
   288  
   289  type clientPacketAddrConn struct {
   290  	N.AbstractConn
   291  	conn            N.ExtendedConn
   292  	access          sync.Mutex
   293  	destination     M.Socksaddr
   294  	requestWritten  bool
   295  	responseRead    bool
   296  	readWaitOptions N.ReadWaitOptions
   297  }
   298  
   299  func (c *clientPacketAddrConn) NeedHandshake() bool {
   300  	return !c.requestWritten
   301  }
   302  
   303  func (c *clientPacketAddrConn) readResponse() error {
   304  	response, err := ReadStreamResponse(c.conn)
   305  	if err != nil {
   306  		return err
   307  	}
   308  	if response.Status == statusError {
   309  		return E.New("remote error: ", response.Message)
   310  	}
   311  	return nil
   312  }
   313  
   314  func (c *clientPacketAddrConn) ReadFrom(p []byte) (n int, addr net.Addr, err error) {
   315  	if !c.responseRead {
   316  		err = c.readResponse()
   317  		if err != nil {
   318  			return
   319  		}
   320  		c.responseRead = true
   321  	}
   322  	destination, err := M.SocksaddrSerializer.ReadAddrPort(c.conn)
   323  	if err != nil {
   324  		return
   325  	}
   326  	if destination.IsFqdn() {
   327  		addr = destination
   328  	} else {
   329  		addr = destination.UDPAddr()
   330  	}
   331  	var length uint16
   332  	err = binary.Read(c.conn, binary.BigEndian, &length)
   333  	if err != nil {
   334  		return
   335  	}
   336  	if cap(p) < int(length) {
   337  		return 0, nil, io.ErrShortBuffer
   338  	}
   339  	n, err = io.ReadFull(c.conn, p[:length])
   340  	return
   341  }
   342  
   343  func (c *clientPacketAddrConn) writeRequest(payload []byte, destination M.Socksaddr) (n int, err error) {
   344  	request := StreamRequest{
   345  		Network:     N.NetworkUDP,
   346  		Destination: c.destination,
   347  		PacketAddr:  true,
   348  	}
   349  	rLen := streamRequestLen(request)
   350  	if len(payload) > 0 {
   351  		rLen += M.SocksaddrSerializer.AddrPortLen(destination) + 2 + len(payload)
   352  	}
   353  	buffer := buf.NewSize(rLen)
   354  	defer buffer.Release()
   355  	err = EncodeStreamRequest(request, buffer)
   356  	if err != nil {
   357  		return
   358  	}
   359  	if len(payload) > 0 {
   360  		err = M.SocksaddrSerializer.WriteAddrPort(buffer, destination)
   361  		if err != nil {
   362  			return
   363  		}
   364  		common.Must(
   365  			binary.Write(buffer, binary.BigEndian, uint16(len(payload))),
   366  			common.Error(buffer.Write(payload)),
   367  		)
   368  	}
   369  	_, err = c.conn.Write(buffer.Bytes())
   370  	if err != nil {
   371  		return
   372  	}
   373  	c.requestWritten = true
   374  	return len(payload), nil
   375  }
   376  
   377  func (c *clientPacketAddrConn) WriteTo(p []byte, addr net.Addr) (n int, err error) {
   378  	if !c.requestWritten {
   379  		c.access.Lock()
   380  		if c.requestWritten {
   381  			c.access.Unlock()
   382  		} else {
   383  			defer c.access.Unlock()
   384  			return c.writeRequest(p, M.SocksaddrFromNet(addr))
   385  		}
   386  	}
   387  	err = M.SocksaddrSerializer.WriteAddrPort(c.conn, M.SocksaddrFromNet(addr))
   388  	if err != nil {
   389  		return
   390  	}
   391  	err = binary.Write(c.conn, binary.BigEndian, uint16(len(p)))
   392  	if err != nil {
   393  		return
   394  	}
   395  	return c.conn.Write(p)
   396  }
   397  
   398  func (c *clientPacketAddrConn) ReadPacket(buffer *buf.Buffer) (destination M.Socksaddr, err error) {
   399  	if !c.responseRead {
   400  		err = c.readResponse()
   401  		if err != nil {
   402  			return
   403  		}
   404  		c.responseRead = true
   405  	}
   406  	destination, err = M.SocksaddrSerializer.ReadAddrPort(c.conn)
   407  	if err != nil {
   408  		return
   409  	}
   410  	var length uint16
   411  	err = binary.Read(c.conn, binary.BigEndian, &length)
   412  	if err != nil {
   413  		return
   414  	}
   415  	_, err = buffer.ReadFullFrom(c.conn, int(length))
   416  	return
   417  }
   418  
   419  func (c *clientPacketAddrConn) WritePacket(buffer *buf.Buffer, destination M.Socksaddr) error {
   420  	if !c.requestWritten {
   421  		c.access.Lock()
   422  		if c.requestWritten {
   423  			c.access.Unlock()
   424  		} else {
   425  			defer c.access.Unlock()
   426  			defer buffer.Release()
   427  			return common.Error(c.writeRequest(buffer.Bytes(), destination))
   428  		}
   429  	}
   430  	bLen := buffer.Len()
   431  	header := buf.With(buffer.ExtendHeader(M.SocksaddrSerializer.AddrPortLen(destination) + 2))
   432  	err := M.SocksaddrSerializer.WriteAddrPort(header, destination)
   433  	if err != nil {
   434  		return err
   435  	}
   436  	common.Must(binary.Write(header, binary.BigEndian, uint16(bLen)))
   437  	return c.conn.WriteBuffer(buffer)
   438  }
   439  
   440  func (c *clientPacketAddrConn) LocalAddr() net.Addr {
   441  	return c.conn.LocalAddr()
   442  }
   443  
   444  func (c *clientPacketAddrConn) FrontHeadroom() int {
   445  	return 2 + M.MaxSocksaddrLength
   446  }
   447  
   448  func (c *clientPacketAddrConn) NeedAdditionalReadDeadline() bool {
   449  	return true
   450  }
   451  
   452  func (c *clientPacketAddrConn) Upstream() any {
   453  	return c.conn
   454  }