github.com/sagernet/sing-shadowsocks@v0.2.6/shadowaead/protocol.go (about) 1 package shadowaead 2 3 import ( 4 "crypto/aes" 5 "crypto/cipher" 6 "crypto/sha1" 7 "io" 8 "net" 9 10 "github.com/sagernet/sing-shadowsocks" 11 "github.com/sagernet/sing/common" 12 "github.com/sagernet/sing/common/buf" 13 M "github.com/sagernet/sing/common/metadata" 14 N "github.com/sagernet/sing/common/network" 15 "github.com/sagernet/sing/common/rw" 16 17 "golang.org/x/crypto/chacha20poly1305" 18 "golang.org/x/crypto/hkdf" 19 ) 20 21 var List = []string{ 22 "aes-128-gcm", 23 "aes-192-gcm", 24 "aes-256-gcm", 25 "chacha20-ietf-poly1305", 26 "xchacha20-ietf-poly1305", 27 } 28 29 var _ shadowsocks.Method = (*Method)(nil) 30 31 func New(method string, key []byte, password string) (*Method, error) { 32 m := &Method{ 33 name: method, 34 } 35 switch method { 36 case "aes-128-gcm": 37 m.keySaltLength = 16 38 m.constructor = aeadCipher(aes.NewCipher, cipher.NewGCM) 39 case "aes-192-gcm": 40 m.keySaltLength = 24 41 m.constructor = aeadCipher(aes.NewCipher, cipher.NewGCM) 42 case "aes-256-gcm": 43 m.keySaltLength = 32 44 m.constructor = aeadCipher(aes.NewCipher, cipher.NewGCM) 45 case "chacha20-ietf-poly1305": 46 m.keySaltLength = 32 47 m.constructor = chacha20poly1305.New 48 case "xchacha20-ietf-poly1305": 49 m.keySaltLength = 32 50 m.constructor = chacha20poly1305.NewX 51 } 52 if len(key) == m.keySaltLength { 53 m.key = key 54 } else if len(key) > 0 { 55 return nil, shadowsocks.ErrBadKey 56 } else if password == "" { 57 return nil, shadowsocks.ErrMissingPassword 58 } else { 59 m.key = shadowsocks.Key([]byte(password), m.keySaltLength) 60 } 61 return m, nil 62 } 63 64 func Kdf(key, iv []byte, buffer *buf.Buffer) { 65 kdf := hkdf.New(sha1.New, key, iv, []byte("ss-subkey")) 66 common.Must1(buffer.ReadFullFrom(kdf, buffer.FreeLen())) 67 } 68 69 func aeadCipher(block func(key []byte) (cipher.Block, error), aead func(block cipher.Block) (cipher.AEAD, error)) func(key []byte) (cipher.AEAD, error) { 70 return func(key []byte) (cipher.AEAD, error) { 71 b, err := block(key) 72 if err != nil { 73 return nil, err 74 } 75 return aead(b) 76 } 77 } 78 79 type Method struct { 80 name string 81 keySaltLength int 82 constructor func(key []byte) (cipher.AEAD, error) 83 key []byte 84 } 85 86 func (m *Method) Name() string { 87 return m.name 88 } 89 90 func (m *Method) DialConn(conn net.Conn, destination M.Socksaddr) (net.Conn, error) { 91 shadowsocksConn := &clientConn{ 92 Conn: conn, 93 Method: m, 94 destination: destination, 95 } 96 return shadowsocksConn, shadowsocksConn.writeRequest(nil) 97 } 98 99 func (m *Method) DialEarlyConn(conn net.Conn, destination M.Socksaddr) net.Conn { 100 return &clientConn{ 101 Conn: conn, 102 Method: m, 103 destination: destination, 104 } 105 } 106 107 func (m *Method) DialPacketConn(conn net.Conn) N.NetPacketConn { 108 return &clientPacketConn{m, conn} 109 } 110 111 type clientConn struct { 112 net.Conn 113 *Method 114 destination M.Socksaddr 115 reader *Reader 116 writer *Writer 117 } 118 119 func (c *clientConn) writeRequest(payload []byte) error { 120 salt := buf.NewSize(c.keySaltLength) 121 defer salt.Release() 122 salt.WriteRandom(c.keySaltLength) 123 124 key := buf.NewSize(c.keySaltLength) 125 126 Kdf(c.key, salt.Bytes(), key) 127 writeCipher, err := c.constructor(key.Bytes()) 128 key.Release() 129 if err != nil { 130 return err 131 } 132 writer := NewWriter(c.Conn, writeCipher, MaxPacketSize) 133 header := writer.Buffer() 134 common.Must1(header.Write(salt.Bytes())) 135 bufferedWriter := writer.BufferedWriter(header.Len()) 136 137 if len(payload) > 0 { 138 err = M.SocksaddrSerializer.WriteAddrPort(bufferedWriter, c.destination) 139 if err != nil { 140 return err 141 } 142 143 _, err = bufferedWriter.Write(payload) 144 if err != nil { 145 return err 146 } 147 } else { 148 err = M.SocksaddrSerializer.WriteAddrPort(bufferedWriter, c.destination) 149 if err != nil { 150 return err 151 } 152 } 153 154 err = bufferedWriter.Flush() 155 if err != nil { 156 return err 157 } 158 159 c.writer = writer 160 return nil 161 } 162 163 func (c *clientConn) readResponse() error { 164 salt := buf.NewSize(c.keySaltLength) 165 defer salt.Release() 166 _, err := salt.ReadFullFrom(c.Conn, c.keySaltLength) 167 if err != nil { 168 return err 169 } 170 key := buf.NewSize(c.keySaltLength) 171 defer key.Release() 172 Kdf(c.key, salt.Bytes(), key) 173 readCipher, err := c.constructor(key.Bytes()) 174 if err != nil { 175 return err 176 } 177 c.reader = NewReader( 178 c.Conn, 179 readCipher, 180 MaxPacketSize, 181 ) 182 return nil 183 } 184 185 func (c *clientConn) Read(p []byte) (n int, err error) { 186 if c.reader == nil { 187 if err = c.readResponse(); err != nil { 188 return 189 } 190 } 191 return c.reader.Read(p) 192 } 193 194 func (c *clientConn) WriteTo(w io.Writer) (n int64, err error) { 195 if c.reader == nil { 196 if err = c.readResponse(); err != nil { 197 return 198 } 199 } 200 return c.reader.WriteTo(w) 201 } 202 203 func (c *clientConn) Write(p []byte) (n int, err error) { 204 if c.writer == nil { 205 err = c.writeRequest(p) 206 if err != nil { 207 return 208 } 209 return len(p), nil 210 } 211 return c.writer.Write(p) 212 } 213 214 func (c *clientConn) NeedHandshake() bool { 215 return c.writer == nil 216 } 217 218 func (c *clientConn) NeedAdditionalReadDeadline() bool { 219 return true 220 } 221 222 func (c *clientConn) Upstream() any { 223 return c.Conn 224 } 225 226 type clientPacketConn struct { 227 *Method 228 net.Conn 229 } 230 231 func (c *clientPacketConn) WritePacket(buffer *buf.Buffer, destination M.Socksaddr) error { 232 defer buffer.Release() 233 header := buf.With(buffer.ExtendHeader(c.keySaltLength + M.SocksaddrSerializer.AddrPortLen(destination))) 234 header.WriteRandom(c.keySaltLength) 235 err := M.SocksaddrSerializer.WriteAddrPort(header, destination) 236 if err != nil { 237 return err 238 } 239 key := buf.NewSize(c.keySaltLength) 240 Kdf(c.key, buffer.To(c.keySaltLength), key) 241 writeCipher, err := c.constructor(key.Bytes()) 242 key.Release() 243 if err != nil { 244 return err 245 } 246 writeCipher.Seal(buffer.Index(c.keySaltLength), rw.ZeroBytes[:writeCipher.NonceSize()], buffer.From(c.keySaltLength), nil) 247 buffer.Extend(Overhead) 248 return common.Error(c.Write(buffer.Bytes())) 249 } 250 251 func (c *clientPacketConn) ReadPacket(buffer *buf.Buffer) (M.Socksaddr, error) { 252 n, err := c.Read(buffer.FreeBytes()) 253 if err != nil { 254 return M.Socksaddr{}, err 255 } 256 buffer.Truncate(n) 257 if buffer.Len() < c.keySaltLength { 258 return M.Socksaddr{}, io.ErrShortBuffer 259 } 260 key := buf.NewSize(c.keySaltLength) 261 Kdf(c.key, buffer.To(c.keySaltLength), key) 262 readCipher, err := c.constructor(key.Bytes()) 263 key.Release() 264 if err != nil { 265 return M.Socksaddr{}, err 266 } 267 packet, err := readCipher.Open(buffer.Index(c.keySaltLength), rw.ZeroBytes[:readCipher.NonceSize()], buffer.From(c.keySaltLength), nil) 268 if err != nil { 269 return M.Socksaddr{}, err 270 } 271 buffer.Advance(c.keySaltLength) 272 buffer.Truncate(len(packet)) 273 if err != nil { 274 return M.Socksaddr{}, err 275 } 276 return M.SocksaddrSerializer.ReadAddrPort(buffer) 277 } 278 279 func (c *clientPacketConn) ReadFrom(p []byte) (n int, addr net.Addr, err error) { 280 buffer := buf.With(p) 281 destination, err := c.ReadPacket(buffer) 282 if err != nil { 283 return 284 } 285 if destination.IsFqdn() { 286 addr = destination 287 } else { 288 addr = destination.UDPAddr() 289 } 290 n = copy(p, buffer.Bytes()) 291 return 292 } 293 294 func (c *clientPacketConn) WriteTo(p []byte, addr net.Addr) (n int, err error) { 295 destination := M.SocksaddrFromNet(addr) 296 buffer := buf.NewSize(c.keySaltLength + M.SocksaddrSerializer.AddrPortLen(destination) + len(p) + Overhead) 297 defer buffer.Release() 298 buffer.Resize(c.keySaltLength+M.SocksaddrSerializer.AddrPortLen(destination), 0) 299 common.Must1(buffer.Write(p)) 300 err = c.WritePacket(buffer, destination) 301 if err != nil { 302 return 303 } 304 return len(p), nil 305 } 306 307 func (c *clientPacketConn) FrontHeadroom() int { 308 return c.keySaltLength + M.MaxSocksaddrLength 309 } 310 311 func (c *clientPacketConn) RearHeadroom() int { 312 return Overhead 313 } 314 315 func (c *clientPacketConn) ReaderMTU() int { 316 return MaxPacketSize 317 } 318 319 func (c *clientPacketConn) WriterMTU() int { 320 return MaxPacketSize 321 } 322 323 func (c *clientPacketConn) Upstream() any { 324 return c.Conn 325 }