github.com/sagernet/sing@v0.4.0-beta.19.0.20240518125136-f67a0988a636/common/uot/conn.go (about) 1 package uot 2 3 import ( 4 "encoding/binary" 5 "io" 6 "net" 7 8 "github.com/sagernet/sing/common" 9 "github.com/sagernet/sing/common/buf" 10 "github.com/sagernet/sing/common/bufio" 11 E "github.com/sagernet/sing/common/exceptions" 12 M "github.com/sagernet/sing/common/metadata" 13 N "github.com/sagernet/sing/common/network" 14 ) 15 16 var ( 17 _ N.NetPacketConn = (*Conn)(nil) 18 _ N.PacketReadWaiter = (*Conn)(nil) 19 ) 20 21 type Conn struct { 22 net.Conn 23 isConnect bool 24 destination M.Socksaddr 25 writer N.VectorisedWriter 26 readWaitOptions N.ReadWaitOptions 27 } 28 29 func NewConn(conn net.Conn, request Request) *Conn { 30 uConn := &Conn{ 31 Conn: conn, 32 isConnect: request.IsConnect, 33 destination: request.Destination, 34 } 35 uConn.writer, _ = bufio.CreateVectorisedWriter(conn) 36 return uConn 37 } 38 39 func (c *Conn) Read(p []byte) (n int, err error) { 40 n, _, err = c.ReadFrom(p) 41 return 42 } 43 44 func (c *Conn) Write(p []byte) (n int, err error) { 45 return c.WriteTo(p, c.destination) 46 } 47 48 func (c *Conn) ReadFrom(p []byte) (n int, addr net.Addr, err error) { 49 var destination M.Socksaddr 50 if c.isConnect { 51 destination = c.destination 52 } else { 53 destination, err = AddrParser.ReadAddrPort(c.Conn) 54 if err != nil { 55 return 56 } 57 } 58 var length uint16 59 err = binary.Read(c.Conn, binary.BigEndian, &length) 60 if err != nil { 61 return 62 } 63 if len(p) < int(length) { 64 err = E.Cause(io.ErrShortBuffer, "UoT read") 65 return 66 } 67 n, err = io.ReadFull(c.Conn, p[:length]) 68 if err == nil { 69 addr = destination.UDPAddr() 70 } 71 return 72 } 73 74 func (c *Conn) WriteTo(p []byte, addr net.Addr) (n int, err error) { 75 destination := M.SocksaddrFromNet(addr) 76 var bufferLen int 77 if !c.isConnect { 78 bufferLen += AddrParser.AddrPortLen(destination) 79 } 80 bufferLen += 2 81 if c.writer == nil { 82 bufferLen += len(p) 83 } 84 buffer := buf.NewSize(bufferLen) 85 defer buffer.Release() 86 if !c.isConnect { 87 err = AddrParser.WriteAddrPort(buffer, destination) 88 if err != nil { 89 return 90 } 91 } 92 common.Must(binary.Write(buffer, binary.BigEndian, uint16(len(p)))) 93 if c.writer == nil { 94 common.Must1(buffer.Write(p)) 95 return c.Conn.Write(buffer.Bytes()) 96 } 97 err = c.writer.WriteVectorised([]*buf.Buffer{buffer, buf.As(p)}) 98 if err == nil { 99 n = len(p) 100 } 101 return 102 } 103 104 func (c *Conn) ReadPacket(buffer *buf.Buffer) (destination M.Socksaddr, err error) { 105 if c.isConnect { 106 destination = c.destination 107 } else { 108 destination, err = AddrParser.ReadAddrPort(c.Conn) 109 if err != nil { 110 return 111 } 112 } 113 var length uint16 114 err = binary.Read(c.Conn, binary.BigEndian, &length) 115 if err != nil { 116 return 117 } 118 _, err = buffer.ReadFullFrom(c.Conn, int(length)) 119 if err != nil { 120 return M.Socksaddr{}, E.Cause(err, "UoT read") 121 } 122 return 123 } 124 125 func (c *Conn) WritePacket(buffer *buf.Buffer, destination M.Socksaddr) error { 126 var headerLen int 127 if !c.isConnect { 128 headerLen += AddrParser.AddrPortLen(destination) 129 } 130 headerLen += 2 131 if c.writer == nil { 132 headerLen += buffer.Len() 133 } 134 header := buf.NewSize(headerLen) 135 defer header.Release() 136 if !c.isConnect { 137 err := AddrParser.WriteAddrPort(header, destination) 138 if err != nil { 139 return err 140 } 141 } 142 common.Must(binary.Write(header, binary.BigEndian, uint16(buffer.Len()))) 143 if c.writer == nil { 144 common.Must1(header.Write(buffer.Bytes())) 145 return common.Error(c.Conn.Write(header.Bytes())) 146 } 147 return c.writer.WriteVectorised([]*buf.Buffer{header, buffer}) 148 } 149 150 func (c *Conn) NeedAdditionalReadDeadline() bool { 151 return true 152 } 153 154 func (c *Conn) Upstream() any { 155 return c.Conn 156 }