github.com/metacubex/sing-shadowsocks@v0.2.6/shadowaead/service.go (about)

     1  package shadowaead
     2  
     3  import (
     4  	"context"
     5  	"crypto/rand"
     6  	"io"
     7  	"net"
     8  	"net/netip"
     9  	"sync"
    10  
    11  	"github.com/metacubex/sing-shadowsocks"
    12  	"github.com/sagernet/sing/common"
    13  	"github.com/sagernet/sing/common/buf"
    14  	E "github.com/sagernet/sing/common/exceptions"
    15  	M "github.com/sagernet/sing/common/metadata"
    16  	N "github.com/sagernet/sing/common/network"
    17  	"github.com/sagernet/sing/common/rw"
    18  	"github.com/sagernet/sing/common/udpnat"
    19  )
    20  
    21  var ErrBadHeader = E.New("bad header")
    22  
    23  var _ shadowsocks.Service = (*Service)(nil)
    24  
// Service is a Shadowsocks AEAD server. It embeds the cipher Method, hands
// decoded TCP connections to handler and routes UDP packets through udpNat.
type Service struct {
	*Method
	password string                          // value reported by Password
	handler  shadowsocks.Handler             // receives decoded connections and errors
	udpNat   *udpnat.Service[netip.AddrPort] // UDP NAT table keyed by the client's source address (see newPacket)
}
    31  
    32  func NewService(method string, key []byte, password string, udpTimeout int64, handler shadowsocks.Handler) (*Service, error) {
    33  	m, err := New(method, key, password)
    34  	if err != nil {
    35  		return nil, err
    36  	}
    37  	s := &Service{
    38  		Method:  m,
    39  		handler: handler,
    40  		udpNat:  udpnat.New[netip.AddrPort](udpTimeout, handler),
    41  	}
    42  	return s, nil
    43  }
    44  
    45  func (s *Service) Name() string {
    46  	return s.name
    47  }
    48  
    49  func (s *Service) Password() string {
    50  	return s.password
    51  }
    52  
    53  func (s *Service) NewConnection(ctx context.Context, conn net.Conn, metadata M.Metadata) error {
    54  	err := s.newConnection(ctx, conn, metadata)
    55  	if err != nil {
    56  		err = &shadowsocks.ServerConnError{Conn: conn, Source: metadata.Source, Cause: err}
    57  	}
    58  	return err
    59  }
    60  
    61  func (s *Service) newConnection(ctx context.Context, conn net.Conn, metadata M.Metadata) error {
    62  	header := buf.NewSize(s.keySaltLength + PacketLengthBufferSize + Overhead)
    63  	defer header.Release()
    64  
    65  	_, err := header.ReadFullFrom(conn, header.FreeLen())
    66  	if err != nil {
    67  		return E.Cause(err, "read header")
    68  	} else if !header.IsFull() {
    69  		return ErrBadHeader
    70  	}
    71  
    72  	key := buf.NewSize(s.keySaltLength)
    73  	Kdf(s.key, header.To(s.keySaltLength), key)
    74  	readCipher, err := s.constructor(key.Bytes())
    75  	key.Release()
    76  	if err != nil {
    77  		return err
    78  	}
    79  	reader := NewReader(conn, readCipher, MaxPacketSize)
    80  
    81  	err = reader.ReadWithLengthChunk(header.From(s.keySaltLength))
    82  	if err != nil {
    83  		return err
    84  	}
    85  
    86  	destination, err := M.SocksaddrSerializer.ReadAddrPort(reader)
    87  	if err != nil {
    88  		return err
    89  	}
    90  
    91  	metadata.Protocol = "shadowsocks"
    92  	metadata.Destination = destination
    93  
    94  	return s.handler.NewConnection(ctx, &serverConn{
    95  		Method: s.Method,
    96  		Conn:   conn,
    97  		reader: reader,
    98  	}, metadata)
    99  }
   100  
   101  func (s *Service) NewError(ctx context.Context, err error) {
   102  	s.handler.NewError(ctx, err)
   103  }
   104  
// serverConn wraps an accepted client connection after newConnection has
// parsed the request header. Reads go through reader. The writer is created
// by the first Write (via writeResponse), and access is locked around that
// creation.
type serverConn struct {
	*Method
	net.Conn
	access sync.Mutex // locked by Write around writer creation in writeResponse
	reader *Reader    // reader built in newConnection from the request salt
	writer *Writer    // nil until writeResponse flushes successfully
}
   112  
   113  func (c *serverConn) writeResponse(payload []byte) (n int, err error) {
   114  	salt := buf.NewSize(c.keySaltLength)
   115  	salt.WriteRandom(c.keySaltLength)
   116  
   117  	key := buf.NewSize(c.keySaltLength)
   118  
   119  	Kdf(c.key, salt.Bytes(), key)
   120  	writeCipher, err := c.constructor(key.Bytes())
   121  	key.Release()
   122  	if err != nil {
   123  		salt.Release()
   124  		return
   125  	}
   126  	writer := NewWriter(c.Conn, writeCipher, MaxPacketSize)
   127  
   128  	header := writer.Buffer()
   129  	common.Must1(header.Write(salt.Bytes()))
   130  	salt.Release()
   131  
   132  	bufferedWriter := writer.BufferedWriter(header.Len())
   133  	if len(payload) > 0 {
   134  		n, err = bufferedWriter.Write(payload)
   135  		if err != nil {
   136  			return
   137  		}
   138  	}
   139  
   140  	err = bufferedWriter.Flush()
   141  	if err != nil {
   142  		return
   143  	}
   144  
   145  	c.writer = writer
   146  	return
   147  }
   148  
   149  func (c *serverConn) Read(b []byte) (n int, err error) {
   150  	return c.reader.Read(b)
   151  }
   152  
   153  func (c *serverConn) Write(p []byte) (n int, err error) {
   154  	if c.writer != nil {
   155  		return c.writer.Write(p)
   156  	}
   157  	c.access.Lock()
   158  	if c.writer != nil {
   159  		c.access.Unlock()
   160  		return c.writer.Write(p)
   161  	}
   162  	defer c.access.Unlock()
   163  	return c.writeResponse(p)
   164  }
   165  
   166  func (c *serverConn) WriteTo(w io.Writer) (n int64, err error) {
   167  	return c.reader.WriteTo(w)
   168  }
   169  
   170  func (c *serverConn) NeedAdditionalReadDeadline() bool {
   171  	return true
   172  }
   173  
   174  func (c *serverConn) Upstream() any {
   175  	return c.Conn
   176  }
   177  
   178  func (c *serverConn) ReaderMTU() int {
   179  	return MaxPacketSize
   180  }
   181  
   182  func (c *Service) WriteIsThreadUnsafe() {
   183  }
   184  
   185  func (s *Service) NewPacket(ctx context.Context, conn N.PacketConn, buffer *buf.Buffer, metadata M.Metadata) error {
   186  	err := s.newPacket(ctx, conn, buffer, metadata)
   187  	if err != nil {
   188  		err = &shadowsocks.ServerPacketError{Source: metadata.Source, Cause: err}
   189  	}
   190  	return err
   191  }
   192  
   193  func (s *Service) newPacket(ctx context.Context, conn N.PacketConn, buffer *buf.Buffer, metadata M.Metadata) error {
   194  	if buffer.Len() < s.keySaltLength {
   195  		return io.ErrShortBuffer
   196  	}
   197  	key := buf.NewSize(s.keySaltLength)
   198  	Kdf(s.key, buffer.To(s.keySaltLength), key)
   199  	readCipher, err := s.constructor(key.Bytes())
   200  	key.Release()
   201  	if err != nil {
   202  		return err
   203  	}
   204  	packet, err := readCipher.Open(buffer.Index(s.keySaltLength), rw.ZeroBytes[:readCipher.NonceSize()], buffer.From(s.keySaltLength), nil)
   205  	if err != nil {
   206  		return err
   207  	}
   208  	buffer.Advance(s.keySaltLength)
   209  	buffer.Truncate(len(packet))
   210  
   211  	destination, err := M.SocksaddrSerializer.ReadAddrPort(buffer)
   212  	if err != nil {
   213  		return err
   214  	}
   215  
   216  	metadata.Protocol = "shadowsocks"
   217  	metadata.Destination = destination
   218  	s.udpNat.NewPacket(ctx, metadata.Source.AddrPort(), buffer, metadata, func(natConn N.PacketConn) N.PacketWriter {
   219  		return &serverPacketWriter{s.Method, conn, natConn}
   220  	})
   221  	return nil
   222  }
   223  
// serverPacketWriter is the N.PacketWriter that newPacket hands to udpnat
// for a client. Its WritePacket encrypts a packet and writes it to source.
type serverPacketWriter struct {
	*Method
	source N.PacketConn // server packet conn that encrypted packets are written to
	nat    N.PacketConn // NAT conn; its LocalAddr is used as the destination in WritePacket
}
   229  
   230  func (w *serverPacketWriter) WritePacket(buffer *buf.Buffer, destination M.Socksaddr) error {
   231  	header := buffer.ExtendHeader(w.keySaltLength + M.SocksaddrSerializer.AddrPortLen(destination))
   232  	common.Must1(io.ReadFull(rand.Reader, header[:w.keySaltLength]))
   233  	err := M.SocksaddrSerializer.WriteAddrPort(buf.With(header[w.keySaltLength:]), destination)
   234  	if err != nil {
   235  		buffer.Release()
   236  		return err
   237  	}
   238  	key := buf.NewSize(w.keySaltLength)
   239  	Kdf(w.key, buffer.To(w.keySaltLength), key)
   240  	writeCipher, err := w.constructor(key.Bytes())
   241  	key.Release()
   242  	if err != nil {
   243  		return err
   244  	}
   245  	writeCipher.Seal(buffer.From(w.keySaltLength)[:0], rw.ZeroBytes[:writeCipher.NonceSize()], buffer.From(w.keySaltLength), nil)
   246  	buffer.Extend(Overhead)
   247  	return w.source.WritePacket(buffer, M.SocksaddrFromNet(w.nat.LocalAddr()))
   248  }
   249  
   250  func (w *serverPacketWriter) FrontHeadroom() int {
   251  	return w.keySaltLength + M.MaxSocksaddrLength
   252  }
   253  
   254  func (w *serverPacketWriter) RearHeadroom() int {
   255  	return Overhead
   256  }
   257  
   258  func (w *serverPacketWriter) WriterMTU() int {
   259  	return MaxPacketSize
   260  }
   261  
   262  func (w *serverPacketWriter) Upstream() any {
   263  	return w.source
   264  }
   265  
   266  func (w *serverPacketWriter) ReaderMTU() int {
   267  	return MaxPacketSize
   268  }
   269  
   270  func (w *serverPacketWriter) WriteIsThreadUnsafe() {
   271  }