github.com/sagernet/sing-shadowsocks@v0.2.6/shadowaead/service.go (about) 1 package shadowaead 2 3 import ( 4 "context" 5 "crypto/rand" 6 "io" 7 "net" 8 "net/netip" 9 "sync" 10 11 "github.com/sagernet/sing-shadowsocks" 12 "github.com/sagernet/sing/common" 13 "github.com/sagernet/sing/common/buf" 14 E "github.com/sagernet/sing/common/exceptions" 15 M "github.com/sagernet/sing/common/metadata" 16 N "github.com/sagernet/sing/common/network" 17 "github.com/sagernet/sing/common/rw" 18 "github.com/sagernet/sing/common/udpnat" 19 ) 20 21 var ErrBadHeader = E.New("bad header") 22 23 var _ shadowsocks.Service = (*Service)(nil) 24 25 type Service struct { 26 *Method 27 password string 28 handler shadowsocks.Handler 29 udpNat *udpnat.Service[netip.AddrPort] 30 } 31 32 func NewService(method string, key []byte, password string, udpTimeout int64, handler shadowsocks.Handler) (*Service, error) { 33 m, err := New(method, key, password) 34 if err != nil { 35 return nil, err 36 } 37 s := &Service{ 38 Method: m, 39 handler: handler, 40 udpNat: udpnat.New[netip.AddrPort](udpTimeout, handler), 41 } 42 return s, nil 43 } 44 45 func (s *Service) Name() string { 46 return s.name 47 } 48 49 func (s *Service) Password() string { 50 return s.password 51 } 52 53 func (s *Service) NewConnection(ctx context.Context, conn net.Conn, metadata M.Metadata) error { 54 err := s.newConnection(ctx, conn, metadata) 55 if err != nil { 56 err = &shadowsocks.ServerConnError{Conn: conn, Source: metadata.Source, Cause: err} 57 } 58 return err 59 } 60 61 func (s *Service) newConnection(ctx context.Context, conn net.Conn, metadata M.Metadata) error { 62 header := buf.NewSize(s.keySaltLength + PacketLengthBufferSize + Overhead) 63 defer header.Release() 64 65 _, err := header.ReadFullFrom(conn, header.FreeLen()) 66 if err != nil { 67 return E.Cause(err, "read header") 68 } else if !header.IsFull() { 69 return ErrBadHeader 70 } 71 72 key := buf.NewSize(s.keySaltLength) 73 Kdf(s.key, header.To(s.keySaltLength), key) 74 readCipher, err := s.constructor(key.Bytes()) 75 key.Release() 76 if err != nil { 77 return err 78 } 79 reader := NewReader(conn, readCipher, MaxPacketSize) 80 81 err = reader.ReadWithLengthChunk(header.From(s.keySaltLength)) 82 if err != nil { 83 return err 84 } 85 86 destination, err := M.SocksaddrSerializer.ReadAddrPort(reader) 87 if err != nil { 88 return err 89 } 90 91 metadata.Protocol = "shadowsocks" 92 metadata.Destination = destination 93 94 return s.handler.NewConnection(ctx, &serverConn{ 95 Method: s.Method, 96 Conn: conn, 97 reader: reader, 98 }, metadata) 99 } 100 101 func (s *Service) NewError(ctx context.Context, err error) { 102 s.handler.NewError(ctx, err) 103 } 104 105 type serverConn struct { 106 *Method 107 net.Conn 108 access sync.Mutex 109 reader *Reader 110 writer *Writer 111 } 112 113 func (c *serverConn) writeResponse(payload []byte) (n int, err error) { 114 salt := buf.NewSize(c.keySaltLength) 115 salt.WriteRandom(c.keySaltLength) 116 117 key := buf.NewSize(c.keySaltLength) 118 119 Kdf(c.key, salt.Bytes(), key) 120 writeCipher, err := c.constructor(key.Bytes()) 121 key.Release() 122 if err != nil { 123 salt.Release() 124 return 125 } 126 writer := NewWriter(c.Conn, writeCipher, MaxPacketSize) 127 128 header := writer.Buffer() 129 common.Must1(header.Write(salt.Bytes())) 130 salt.Release() 131 132 bufferedWriter := writer.BufferedWriter(header.Len()) 133 if len(payload) > 0 { 134 n, err = bufferedWriter.Write(payload) 135 if err != nil { 136 return 137 } 138 } 139 140 err = bufferedWriter.Flush() 141 if err != nil { 142 return 143 } 144 145 c.writer = writer 146 return 147 } 148 149 func (c *serverConn) Read(b []byte) (n int, err error) { 150 return c.reader.Read(b) 151 } 152 153 func (c *serverConn) Write(p []byte) (n int, err error) { 154 if c.writer != nil { 155 return c.writer.Write(p) 156 } 157 c.access.Lock() 158 if c.writer != nil { 159 c.access.Unlock() 160 return c.writer.Write(p) 161 } 162 defer c.access.Unlock() 163 return c.writeResponse(p) 164 } 165 166 func (c *serverConn) WriteTo(w io.Writer) (n int64, err error) { 167 return c.reader.WriteTo(w) 168 } 169 170 func (c *serverConn) NeedAdditionalReadDeadline() bool { 171 return true 172 } 173 174 func (c *serverConn) Upstream() any { 175 return c.Conn 176 } 177 178 func (c *serverConn) ReaderMTU() int { 179 return MaxPacketSize 180 } 181 182 func (c *Service) WriteIsThreadUnsafe() { 183 } 184 185 func (s *Service) NewPacket(ctx context.Context, conn N.PacketConn, buffer *buf.Buffer, metadata M.Metadata) error { 186 err := s.newPacket(ctx, conn, buffer, metadata) 187 if err != nil { 188 err = &shadowsocks.ServerPacketError{Source: metadata.Source, Cause: err} 189 } 190 return err 191 } 192 193 func (s *Service) newPacket(ctx context.Context, conn N.PacketConn, buffer *buf.Buffer, metadata M.Metadata) error { 194 if buffer.Len() < s.keySaltLength { 195 return io.ErrShortBuffer 196 } 197 key := buf.NewSize(s.keySaltLength) 198 Kdf(s.key, buffer.To(s.keySaltLength), key) 199 readCipher, err := s.constructor(key.Bytes()) 200 key.Release() 201 if err != nil { 202 return err 203 } 204 packet, err := readCipher.Open(buffer.Index(s.keySaltLength), rw.ZeroBytes[:readCipher.NonceSize()], buffer.From(s.keySaltLength), nil) 205 if err != nil { 206 return err 207 } 208 buffer.Advance(s.keySaltLength) 209 buffer.Truncate(len(packet)) 210 211 destination, err := M.SocksaddrSerializer.ReadAddrPort(buffer) 212 if err != nil { 213 return err 214 } 215 216 metadata.Protocol = "shadowsocks" 217 metadata.Destination = destination 218 s.udpNat.NewPacket(ctx, metadata.Source.AddrPort(), buffer, metadata, func(natConn N.PacketConn) N.PacketWriter { 219 return &serverPacketWriter{s.Method, conn, natConn} 220 }) 221 return nil 222 } 223 224 type serverPacketWriter struct { 225 *Method 226 source N.PacketConn 227 nat N.PacketConn 228 } 229 230 func (w *serverPacketWriter) WritePacket(buffer *buf.Buffer, destination M.Socksaddr) error { 231 header := buffer.ExtendHeader(w.keySaltLength + M.SocksaddrSerializer.AddrPortLen(destination)) 232 common.Must1(io.ReadFull(rand.Reader, header[:w.keySaltLength])) 233 err := M.SocksaddrSerializer.WriteAddrPort(buf.With(header[w.keySaltLength:]), destination) 234 if err != nil { 235 buffer.Release() 236 return err 237 } 238 key := buf.NewSize(w.keySaltLength) 239 Kdf(w.key, buffer.To(w.keySaltLength), key) 240 writeCipher, err := w.constructor(key.Bytes()) 241 key.Release() 242 if err != nil { 243 return err 244 } 245 writeCipher.Seal(buffer.From(w.keySaltLength)[:0], rw.ZeroBytes[:writeCipher.NonceSize()], buffer.From(w.keySaltLength), nil) 246 buffer.Extend(Overhead) 247 return w.source.WritePacket(buffer, M.SocksaddrFromNet(w.nat.LocalAddr())) 248 } 249 250 func (w *serverPacketWriter) FrontHeadroom() int { 251 return w.keySaltLength + M.MaxSocksaddrLength 252 } 253 254 func (w *serverPacketWriter) RearHeadroom() int { 255 return Overhead 256 } 257 258 func (w *serverPacketWriter) WriterMTU() int { 259 return MaxPacketSize 260 } 261 262 func (w *serverPacketWriter) Upstream() any { 263 return w.source 264 } 265 266 func (w *serverPacketWriter) ReaderMTU() int { 267 return MaxPacketSize 268 } 269 270 func (w *serverPacketWriter) WriteIsThreadUnsafe() { 271 }