github.com/sagernet/sing@v0.4.0-beta.19.0.20240518125136-f67a0988a636/common/uot/conn.go (about)

     1  package uot
     2  
     3  import (
     4  	"encoding/binary"
     5  	"io"
     6  	"net"
     7  
     8  	"github.com/sagernet/sing/common"
     9  	"github.com/sagernet/sing/common/buf"
    10  	"github.com/sagernet/sing/common/bufio"
    11  	E "github.com/sagernet/sing/common/exceptions"
    12  	M "github.com/sagernet/sing/common/metadata"
    13  	N "github.com/sagernet/sing/common/network"
    14  )
    15  
    16  var (
    17  	_ N.NetPacketConn    = (*Conn)(nil)
    18  	_ N.PacketReadWaiter = (*Conn)(nil)
    19  )
    20  
    21  type Conn struct {
    22  	net.Conn
    23  	isConnect       bool
    24  	destination     M.Socksaddr
    25  	writer          N.VectorisedWriter
    26  	readWaitOptions N.ReadWaitOptions
    27  }
    28  
    29  func NewConn(conn net.Conn, request Request) *Conn {
    30  	uConn := &Conn{
    31  		Conn:        conn,
    32  		isConnect:   request.IsConnect,
    33  		destination: request.Destination,
    34  	}
    35  	uConn.writer, _ = bufio.CreateVectorisedWriter(conn)
    36  	return uConn
    37  }
    38  
    39  func (c *Conn) Read(p []byte) (n int, err error) {
    40  	n, _, err = c.ReadFrom(p)
    41  	return
    42  }
    43  
    44  func (c *Conn) Write(p []byte) (n int, err error) {
    45  	return c.WriteTo(p, c.destination)
    46  }
    47  
    48  func (c *Conn) ReadFrom(p []byte) (n int, addr net.Addr, err error) {
    49  	var destination M.Socksaddr
    50  	if c.isConnect {
    51  		destination = c.destination
    52  	} else {
    53  		destination, err = AddrParser.ReadAddrPort(c.Conn)
    54  		if err != nil {
    55  			return
    56  		}
    57  	}
    58  	var length uint16
    59  	err = binary.Read(c.Conn, binary.BigEndian, &length)
    60  	if err != nil {
    61  		return
    62  	}
    63  	if len(p) < int(length) {
    64  		err = E.Cause(io.ErrShortBuffer, "UoT read")
    65  		return
    66  	}
    67  	n, err = io.ReadFull(c.Conn, p[:length])
    68  	if err == nil {
    69  		addr = destination.UDPAddr()
    70  	}
    71  	return
    72  }
    73  
    74  func (c *Conn) WriteTo(p []byte, addr net.Addr) (n int, err error) {
    75  	destination := M.SocksaddrFromNet(addr)
    76  	var bufferLen int
    77  	if !c.isConnect {
    78  		bufferLen += AddrParser.AddrPortLen(destination)
    79  	}
    80  	bufferLen += 2
    81  	if c.writer == nil {
    82  		bufferLen += len(p)
    83  	}
    84  	buffer := buf.NewSize(bufferLen)
    85  	defer buffer.Release()
    86  	if !c.isConnect {
    87  		err = AddrParser.WriteAddrPort(buffer, destination)
    88  		if err != nil {
    89  			return
    90  		}
    91  	}
    92  	common.Must(binary.Write(buffer, binary.BigEndian, uint16(len(p))))
    93  	if c.writer == nil {
    94  		common.Must1(buffer.Write(p))
    95  		return c.Conn.Write(buffer.Bytes())
    96  	}
    97  	err = c.writer.WriteVectorised([]*buf.Buffer{buffer, buf.As(p)})
    98  	if err == nil {
    99  		n = len(p)
   100  	}
   101  	return
   102  }
   103  
   104  func (c *Conn) ReadPacket(buffer *buf.Buffer) (destination M.Socksaddr, err error) {
   105  	if c.isConnect {
   106  		destination = c.destination
   107  	} else {
   108  		destination, err = AddrParser.ReadAddrPort(c.Conn)
   109  		if err != nil {
   110  			return
   111  		}
   112  	}
   113  	var length uint16
   114  	err = binary.Read(c.Conn, binary.BigEndian, &length)
   115  	if err != nil {
   116  		return
   117  	}
   118  	_, err = buffer.ReadFullFrom(c.Conn, int(length))
   119  	if err != nil {
   120  		return M.Socksaddr{}, E.Cause(err, "UoT read")
   121  	}
   122  	return
   123  }
   124  
   125  func (c *Conn) WritePacket(buffer *buf.Buffer, destination M.Socksaddr) error {
   126  	var headerLen int
   127  	if !c.isConnect {
   128  		headerLen += AddrParser.AddrPortLen(destination)
   129  	}
   130  	headerLen += 2
   131  	if c.writer == nil {
   132  		headerLen += buffer.Len()
   133  	}
   134  	header := buf.NewSize(headerLen)
   135  	defer header.Release()
   136  	if !c.isConnect {
   137  		err := AddrParser.WriteAddrPort(header, destination)
   138  		if err != nil {
   139  			return err
   140  		}
   141  	}
   142  	common.Must(binary.Write(header, binary.BigEndian, uint16(buffer.Len())))
   143  	if c.writer == nil {
   144  		common.Must1(header.Write(buffer.Bytes()))
   145  		return common.Error(c.Conn.Write(header.Bytes()))
   146  	}
   147  	return c.writer.WriteVectorised([]*buf.Buffer{header, buffer})
   148  }
   149  
   150  func (c *Conn) NeedAdditionalReadDeadline() bool {
   151  	return true
   152  }
   153  
   154  func (c *Conn) Upstream() any {
   155  	return c.Conn
   156  }