github.com/sagernet/sing-shadowsocks2@v0.2.0/shadowaead/method.go (about) 1 package shadowaead 2 3 import ( 4 "context" 5 "crypto/aes" 6 "crypto/cipher" 7 "net" 8 9 C "github.com/sagernet/sing-shadowsocks2/cipher" 10 "github.com/sagernet/sing-shadowsocks2/internal/legacykey" 11 "github.com/sagernet/sing-shadowsocks2/internal/shadowio" 12 "github.com/sagernet/sing/common" 13 "github.com/sagernet/sing/common/buf" 14 "github.com/sagernet/sing/common/bufio" 15 E "github.com/sagernet/sing/common/exceptions" 16 M "github.com/sagernet/sing/common/metadata" 17 N "github.com/sagernet/sing/common/network" 18 "github.com/sagernet/sing/common/rw" 19 20 "golang.org/x/crypto/chacha20poly1305" 21 ) 22 23 var MethodList = []string{ 24 "aes-128-gcm", 25 "aes-192-gcm", 26 "aes-256-gcm", 27 "chacha20-ietf-poly1305", 28 "xchacha20-ietf-poly1305", 29 } 30 31 func init() { 32 C.RegisterMethod(MethodList, func(ctx context.Context, methodName string, options C.MethodOptions) (C.Method, error) { 33 return NewMethod(ctx, methodName, options) 34 }) 35 } 36 37 type Method struct { 38 keySaltLength int 39 constructor func(key []byte) (cipher.AEAD, error) 40 key []byte 41 } 42 43 func NewMethod(ctx context.Context, methodName string, options C.MethodOptions) (*Method, error) { 44 m := &Method{} 45 switch methodName { 46 case "aes-128-gcm": 47 m.keySaltLength = 16 48 m.constructor = aeadCipher(aes.NewCipher, cipher.NewGCM) 49 case "aes-192-gcm": 50 m.keySaltLength = 24 51 m.constructor = aeadCipher(aes.NewCipher, cipher.NewGCM) 52 case "aes-256-gcm": 53 m.keySaltLength = 32 54 m.constructor = aeadCipher(aes.NewCipher, cipher.NewGCM) 55 case "chacha20-ietf-poly1305": 56 m.keySaltLength = 32 57 m.constructor = chacha20poly1305.New 58 case "xchacha20-ietf-poly1305": 59 m.keySaltLength = 32 60 m.constructor = chacha20poly1305.NewX 61 } 62 if len(options.Key) == m.keySaltLength { 63 m.key = options.Key 64 } else if len(options.Key) > 0 { 65 return nil, E.New("bad key length, required ", m.keySaltLength, ", got ", len(options.Key)) 66 } else if options.Password == "" { 67 return nil, C.ErrMissingPassword 68 } else { 69 m.key = legacykey.Key([]byte(options.Password), m.keySaltLength) 70 } 71 return m, nil 72 } 73 74 func aeadCipher(block func(key []byte) (cipher.Block, error), aead func(block cipher.Block) (cipher.AEAD, error)) func(key []byte) (cipher.AEAD, error) { 75 return func(key []byte) (cipher.AEAD, error) { 76 b, err := block(key) 77 if err != nil { 78 return nil, err 79 } 80 return aead(b) 81 } 82 } 83 84 func (m *Method) DialConn(conn net.Conn, destination M.Socksaddr) (net.Conn, error) { 85 ssConn := &clientConn{ 86 Conn: conn, 87 method: m, 88 destination: destination, 89 } 90 return ssConn, ssConn.writeRequest(nil) 91 } 92 93 func (m *Method) DialEarlyConn(conn net.Conn, destination M.Socksaddr) net.Conn { 94 return &clientConn{ 95 Conn: conn, 96 method: m, 97 destination: destination, 98 } 99 } 100 101 func (m *Method) DialPacketConn(conn net.Conn) N.NetPacketConn { 102 return &clientPacketConn{ 103 AbstractConn: conn, 104 reader: bufio.NewExtendedReader(conn), 105 writer: bufio.NewExtendedWriter(conn), 106 method: m, 107 } 108 } 109 110 var _ N.ExtendedConn = (*clientConn)(nil) 111 112 type clientConn struct { 113 net.Conn 114 method *Method 115 destination M.Socksaddr 116 reader *shadowio.Reader 117 readWaitOptions N.ReadWaitOptions 118 writer *shadowio.Writer 119 shadowio.WriterInterface 120 } 121 122 func (c *clientConn) writeRequest(payload []byte) error { 123 requestBuffer := buf.New() 124 requestBuffer.WriteRandom(c.method.keySaltLength) 125 key := make([]byte, c.method.keySaltLength) 126 legacykey.Kdf(c.method.key, requestBuffer.Bytes(), key) 127 writeCipher, err := c.method.constructor(key) 128 if err != nil { 129 return err 130 } 131 bufferedRequestWriter := bufio.NewBufferedWriter(c.Conn, requestBuffer) 132 requestContentWriter := shadowio.NewWriter(bufferedRequestWriter, writeCipher, nil, MaxPacketSize) 133 bufferedRequestContentWriter := bufio.NewBufferedWriter(requestContentWriter, buf.New()) 134 err = M.SocksaddrSerializer.WriteAddrPort(bufferedRequestContentWriter, c.destination) 135 if err != nil { 136 return err 137 } 138 _, err = bufferedRequestContentWriter.Write(payload) 139 if err != nil { 140 return err 141 } 142 err = bufferedRequestContentWriter.Fallthrough() 143 if err != nil { 144 return err 145 } 146 err = bufferedRequestWriter.Fallthrough() 147 if err != nil { 148 return err 149 } 150 c.writer = shadowio.NewWriter(c.Conn, writeCipher, requestContentWriter.TakeNonce(), MaxPacketSize) 151 return nil 152 } 153 154 func (c *clientConn) readResponse() error { 155 buffer := buf.NewSize(c.method.keySaltLength) 156 defer buffer.Release() 157 _, err := buffer.ReadFullFrom(c.Conn, c.method.keySaltLength) 158 if err != nil { 159 return err 160 } 161 legacykey.Kdf(c.method.key, buffer.Bytes(), buffer.Bytes()) 162 readCipher, err := c.method.constructor(buffer.Bytes()) 163 if err != nil { 164 return err 165 } 166 reader := shadowio.NewReader(c.Conn, readCipher) 167 reader.InitializeReadWaiter(c.readWaitOptions) 168 c.reader = reader 169 return nil 170 } 171 172 func (c *clientConn) Read(p []byte) (n int, err error) { 173 if c.reader == nil { 174 err = c.readResponse() 175 if err != nil { 176 return 177 } 178 } 179 return c.reader.Read(p) 180 } 181 182 func (c *clientConn) ReadBuffer(buffer *buf.Buffer) error { 183 if c.reader == nil { 184 err := c.readResponse() 185 if err != nil { 186 return err 187 } 188 } 189 return c.reader.ReadBuffer(buffer) 190 } 191 192 func (c *clientConn) Write(p []byte) (n int, err error) { 193 if c.writer == nil { 194 err = c.writeRequest(p) 195 if err == nil { 196 n = len(p) 197 } 198 return 199 } 200 return c.writer.Write(p) 201 } 202 203 func (c *clientConn) WriteBuffer(buffer *buf.Buffer) error { 204 if c.writer == nil { 205 defer buffer.Release() 206 return c.writeRequest(buffer.Bytes()) 207 } 208 return c.writer.WriteBuffer(buffer) 209 } 210 211 func (c *clientConn) NeedHandshake() bool { 212 return c.writer == nil 213 } 214 215 func (c *clientConn) Upstream() any { 216 return c.Conn 217 } 218 219 func (c *clientConn) WriterMTU() int { 220 return MaxPacketSize 221 } 222 223 type clientPacketConn struct { 224 N.AbstractConn 225 reader N.ExtendedReader 226 writer N.ExtendedWriter 227 method *Method 228 } 229 230 func (c *clientPacketConn) ReadPacket(buffer *buf.Buffer) (destination M.Socksaddr, err error) { 231 err = c.reader.ReadBuffer(buffer) 232 if err != nil { 233 return 234 } 235 return c.readPacket(buffer) 236 } 237 238 func (c *clientPacketConn) readPacket(buffer *buf.Buffer) (destination M.Socksaddr, err error) { 239 if buffer.Len() < c.method.keySaltLength { 240 return M.Socksaddr{}, C.ErrPacketTooShort 241 } 242 key := buf.NewSize(c.method.keySaltLength) 243 legacykey.Kdf(c.method.key, buffer.To(c.method.keySaltLength), key.Extend(c.method.keySaltLength)) 244 readCipher, err := c.method.constructor(key.Bytes()) 245 key.Release() 246 if err != nil { 247 return 248 } 249 packet, err := readCipher.Open(buffer.Index(c.method.keySaltLength), rw.ZeroBytes[:readCipher.NonceSize()], buffer.From(c.method.keySaltLength), nil) 250 if err != nil { 251 return 252 } 253 buffer.Advance(c.method.keySaltLength) 254 buffer.Truncate(len(packet)) 255 if err != nil { 256 return 257 } 258 destination, err = M.SocksaddrSerializer.ReadAddrPort(buffer) 259 if err != nil { 260 return 261 } 262 return destination.Unwrap(), nil 263 } 264 265 func (c *clientPacketConn) WritePacket(buffer *buf.Buffer, destination M.Socksaddr) error { 266 header := buf.With(buffer.ExtendHeader(c.method.keySaltLength + M.SocksaddrSerializer.AddrPortLen(destination))) 267 header.WriteRandom(c.method.keySaltLength) 268 err := M.SocksaddrSerializer.WriteAddrPort(header, destination) 269 if err != nil { 270 return err 271 } 272 key := buf.NewSize(c.method.keySaltLength) 273 legacykey.Kdf(c.method.key, header.To(c.method.keySaltLength), key.Extend(c.method.keySaltLength)) 274 writeCipher, err := c.method.constructor(key.Bytes()) 275 key.Release() 276 if err != nil { 277 return err 278 } 279 writeCipher.Seal(buffer.Index(c.method.keySaltLength), rw.ZeroBytes[:writeCipher.NonceSize()], buffer.From(c.method.keySaltLength), nil) 280 buffer.Extend(shadowio.Overhead) 281 return c.writer.WriteBuffer(buffer) 282 } 283 284 func (c *clientPacketConn) ReadFrom(p []byte) (n int, addr net.Addr, err error) { 285 n, err = c.reader.Read(p) 286 if err != nil { 287 return 288 } 289 if n < c.method.keySaltLength { 290 err = C.ErrPacketTooShort 291 return 292 } 293 key := buf.NewSize(c.method.keySaltLength) 294 legacykey.Kdf(c.method.key, p[:c.method.keySaltLength], key.Extend(c.method.keySaltLength)) 295 readCipher, err := c.method.constructor(key.Bytes()) 296 key.Release() 297 if err != nil { 298 return 299 } 300 packet, err := readCipher.Open(p[c.method.keySaltLength:c.method.keySaltLength], rw.ZeroBytes[:readCipher.NonceSize()], p[c.method.keySaltLength:n], nil) 301 if err != nil { 302 return 303 } 304 packetContent := buf.As(packet) 305 destination, err := M.SocksaddrSerializer.ReadAddrPort(packetContent) 306 if err != nil { 307 return 308 } 309 if !destination.IsFqdn() { 310 addr = destination.UDPAddr() 311 } else { 312 addr = destination 313 } 314 n = copy(p, packetContent.Bytes()) 315 return 316 } 317 318 func (c *clientPacketConn) WriteTo(p []byte, addr net.Addr) (n int, err error) { 319 destination := M.SocksaddrFromNet(addr) 320 buffer := buf.NewSize(c.method.keySaltLength + M.SocksaddrSerializer.AddrPortLen(destination) + len(p) + shadowio.Overhead) 321 defer buffer.Release() 322 buffer.WriteRandom(c.method.keySaltLength) 323 err = M.SocksaddrSerializer.WriteAddrPort(buffer, destination) 324 if err != nil { 325 return 326 } 327 common.Must1(buffer.Write(p)) 328 key := buf.NewSize(c.method.keySaltLength) 329 legacykey.Kdf(c.method.key, buffer.To(c.method.keySaltLength), key.Extend(c.method.keySaltLength)) 330 writeCipher, err := c.method.constructor(key.Bytes()) 331 key.Release() 332 if err != nil { 333 return 334 } 335 writeCipher.Seal(buffer.Index(c.method.keySaltLength), rw.ZeroBytes[:writeCipher.NonceSize()], buffer.From(c.method.keySaltLength), nil) 336 buffer.Extend(shadowio.Overhead) 337 _, err = c.writer.Write(buffer.Bytes()) 338 if err != nil { 339 return 340 } 341 return len(p), nil 342 } 343 344 func (c *clientPacketConn) FrontHeadroom() int { 345 return c.method.keySaltLength + M.MaxSocksaddrLength 346 } 347 348 func (c *clientPacketConn) RearHeadroom() int { 349 return shadowio.Overhead 350 } 351 352 func (c *clientPacketConn) ReaderMTU() int { 353 return MaxPacketSize 354 } 355 356 func (c *clientPacketConn) WriterMTU() int { 357 return MaxPacketSize 358 } 359 360 func (c *clientPacketConn) Upstream() any { 361 return c.AbstractConn 362 }