github.com/xtls/xray-core@v1.8.12-0.20240518155711-3168d27b0bdb/transport/internet/quic/conn.go (about) 1 package quic 2 3 import ( 4 "crypto/cipher" 5 "crypto/rand" 6 "errors" 7 "syscall" 8 "time" 9 10 "github.com/quic-go/quic-go" 11 "github.com/xtls/xray-core/common" 12 "github.com/xtls/xray-core/common/buf" 13 "github.com/xtls/xray-core/common/net" 14 "github.com/xtls/xray-core/transport/internet" 15 ) 16 17 type sysConn struct { 18 conn *net.UDPConn 19 header internet.PacketHeader 20 auth cipher.AEAD 21 } 22 23 func wrapSysConn(rawConn *net.UDPConn, config *Config) (*sysConn, error) { 24 header, err := getHeader(config) 25 if err != nil { 26 return nil, err 27 } 28 auth, err := getAuth(config) 29 if err != nil { 30 return nil, err 31 } 32 return &sysConn{ 33 conn: rawConn, 34 header: header, 35 auth: auth, 36 }, nil 37 } 38 39 var errInvalidPacket = errors.New("invalid packet") 40 41 func (c *sysConn) readFromInternal(p []byte) (int, net.Addr, error) { 42 buffer := getBuffer() 43 defer putBuffer(buffer) 44 45 nBytes, addr, err := c.conn.ReadFrom(buffer) 46 if err != nil { 47 return 0, nil, err 48 } 49 50 payload := buffer[:nBytes] 51 if c.header != nil { 52 if len(payload) <= int(c.header.Size()) { 53 return 0, nil, errInvalidPacket 54 } 55 payload = payload[c.header.Size():] 56 } 57 58 if c.auth == nil { 59 n := copy(p, payload) 60 return n, addr, nil 61 } 62 63 if len(payload) <= c.auth.NonceSize() { 64 return 0, nil, errInvalidPacket 65 } 66 67 nonce := payload[:c.auth.NonceSize()] 68 payload = payload[c.auth.NonceSize():] 69 70 p, err = c.auth.Open(p[:0], nonce, payload, nil) 71 if err != nil { 72 return 0, nil, errInvalidPacket 73 } 74 75 return len(p), addr, nil 76 } 77 78 func (c *sysConn) ReadFrom(p []byte) (int, net.Addr, error) { 79 if c.header == nil && c.auth == nil { 80 return c.conn.ReadFrom(p) 81 } 82 83 for { 84 n, addr, err := c.readFromInternal(p) 85 if err != nil && err != errInvalidPacket { 86 return 0, nil, err 87 } 88 if err == nil { 89 return n, addr, nil 90 } 91 } 92 } 93 94 func (c *sysConn) WriteTo(p []byte, addr net.Addr) (int, error) { 95 if c.header == nil && c.auth == nil { 96 return c.conn.WriteTo(p, addr) 97 } 98 99 buffer := getBuffer() 100 defer putBuffer(buffer) 101 102 payload := buffer 103 n := 0 104 if c.header != nil { 105 c.header.Serialize(payload) 106 n = int(c.header.Size()) 107 } 108 109 if c.auth == nil { 110 nBytes := copy(payload[n:], p) 111 n += nBytes 112 } else { 113 nounce := payload[n : n+c.auth.NonceSize()] 114 common.Must2(rand.Read(nounce)) 115 n += c.auth.NonceSize() 116 pp := c.auth.Seal(payload[:n], nounce, p, nil) 117 n = len(pp) 118 } 119 120 return c.conn.WriteTo(payload[:n], addr) 121 } 122 123 func (c *sysConn) Close() error { 124 return c.conn.Close() 125 } 126 127 func (c *sysConn) LocalAddr() net.Addr { 128 return c.conn.LocalAddr() 129 } 130 131 func (c *sysConn) SetReadBuffer(bytes int) error { 132 return c.conn.SetReadBuffer(bytes) 133 } 134 135 func (c *sysConn) SetWriteBuffer(bytes int) error { 136 return c.conn.SetWriteBuffer(bytes) 137 } 138 139 func (c *sysConn) SetDeadline(t time.Time) error { 140 return c.conn.SetDeadline(t) 141 } 142 143 func (c *sysConn) SetReadDeadline(t time.Time) error { 144 return c.conn.SetReadDeadline(t) 145 } 146 147 func (c *sysConn) SetWriteDeadline(t time.Time) error { 148 return c.conn.SetWriteDeadline(t) 149 } 150 151 func (c *sysConn) SyscallConn() (syscall.RawConn, error) { 152 return c.conn.SyscallConn() 153 } 154 155 type interConn struct { 156 stream quic.Stream 157 local net.Addr 158 remote net.Addr 159 } 160 161 func (c *interConn) Read(b []byte) (int, error) { 162 return c.stream.Read(b) 163 } 164 165 func (c *interConn) WriteMultiBuffer(mb buf.MultiBuffer) error { 166 mb = buf.Compact(mb) 167 mb, err := buf.WriteMultiBuffer(c, mb) 168 buf.ReleaseMulti(mb) 169 return err 170 } 171 172 func (c *interConn) Write(b []byte) (int, error) { 173 return c.stream.Write(b) 174 } 175 176 func (c *interConn) Close() error { 177 return c.stream.Close() 178 } 179 180 func (c *interConn) LocalAddr() net.Addr { 181 return c.local 182 } 183 184 func (c *interConn) RemoteAddr() net.Addr { 185 return c.remote 186 } 187 188 func (c *interConn) SetDeadline(t time.Time) error { 189 return c.stream.SetDeadline(t) 190 } 191 192 func (c *interConn) SetReadDeadline(t time.Time) error { 193 return c.stream.SetReadDeadline(t) 194 } 195 196 func (c *interConn) SetWriteDeadline(t time.Time) error { 197 return c.stream.SetWriteDeadline(t) 198 }