// Source: github.com/metacubex/sing-shadowsocks@v0.2.6/shadowaead/service_multi.go

     1  package shadowaead
     2  
     3  import (
     4  	"context"
     5  	"crypto/cipher"
     6  	"io"
     7  	"net"
     8  	"net/netip"
     9  
    10  	"github.com/metacubex/sing-shadowsocks"
    11  	"github.com/sagernet/sing/common/auth"
    12  	"github.com/sagernet/sing/common/buf"
    13  	E "github.com/sagernet/sing/common/exceptions"
    14  	M "github.com/sagernet/sing/common/metadata"
    15  	N "github.com/sagernet/sing/common/network"
    16  	"github.com/sagernet/sing/common/rw"
    17  	"github.com/sagernet/sing/common/udpnat"
    18  )
    19  
    20  var _ shadowsocks.MultiService[int] = (*MultiService[int])(nil)
    21  
    22  type MultiService[U comparable] struct {
    23  	name      string
    24  	methodMap map[U]*Method
    25  	handler   shadowsocks.Handler
    26  	udpNat    *udpnat.Service[netip.AddrPort]
    27  }
    28  
    29  func NewMultiService[U comparable](method string, udpTimeout int64, handler shadowsocks.Handler) (*MultiService[U], error) {
    30  	s := &MultiService[U]{
    31  		name:    method,
    32  		handler: handler,
    33  		udpNat:  udpnat.New[netip.AddrPort](udpTimeout, handler),
    34  	}
    35  	return s, nil
    36  }
    37  
    38  func (s *MultiService[U]) Name() string {
    39  	return s.name
    40  }
    41  
    42  func (s *MultiService[U]) UpdateUsers(userList []U, keyList [][]byte) error {
    43  	s.methodMap = make(map[U]*Method)
    44  	for i, user := range userList {
    45  		key := keyList[i]
    46  		method, err := New(s.name, key, "")
    47  		if err != nil {
    48  			return err
    49  		}
    50  		s.methodMap[user] = method
    51  	}
    52  	return nil
    53  }
    54  
    55  func (s *MultiService[U]) UpdateUsersWithPasswords(userList []U, passwordList []string) error {
    56  	s.methodMap = make(map[U]*Method)
    57  	for i, user := range userList {
    58  		password := passwordList[i]
    59  		method, err := New(s.name, nil, password)
    60  		if err != nil {
    61  			return err
    62  		}
    63  		s.methodMap[user] = method
    64  	}
    65  	return nil
    66  }
    67  
    68  func (s *MultiService[U]) NewConnection(ctx context.Context, conn net.Conn, metadata M.Metadata) error {
    69  	err := s.newConnection(ctx, conn, metadata)
    70  	if err != nil {
    71  		err = &shadowsocks.ServerConnError{Conn: conn, Source: metadata.Source, Cause: err}
    72  	}
    73  	return err
    74  }
    75  
    76  func (s *MultiService[U]) newConnection(ctx context.Context, conn net.Conn, metadata M.Metadata) error {
    77  	var user U
    78  	var method *Method
    79  	for u, m := range s.methodMap {
    80  		user, method = u, m
    81  		break
    82  	}
    83  	if method == nil {
    84  		return shadowsocks.ErrNoUsers
    85  	}
    86  	header := buf.NewSize(method.keySaltLength + PacketLengthBufferSize + Overhead)
    87  	defer header.Release()
    88  
    89  	_, err := header.ReadFullFrom(conn, header.FreeLen())
    90  	if err != nil {
    91  		return E.Cause(err, "read header")
    92  	} else if !header.IsFull() {
    93  		return ErrBadHeader
    94  	}
    95  
    96  	var reader *Reader
    97  	var readCipher cipher.AEAD
    98  	for u, m := range s.methodMap {
    99  		key := buf.NewSize(method.keySaltLength)
   100  		Kdf(m.key, header.To(m.keySaltLength), key)
   101  		readCipher, err = m.constructor(key.Bytes())
   102  		key.Release()
   103  		if err != nil {
   104  			return err
   105  		}
   106  		reader = NewReader(conn, readCipher, MaxPacketSize)
   107  
   108  		err = reader.ReadWithLengthChunk(header.From(method.keySaltLength))
   109  		if err != nil {
   110  			continue
   111  		}
   112  
   113  		user, method = u, m
   114  		break
   115  	}
   116  	if err != nil {
   117  		return err
   118  	}
   119  
   120  	destination, err := M.SocksaddrSerializer.ReadAddrPort(reader)
   121  	if err != nil {
   122  		return err
   123  	}
   124  
   125  	metadata.Protocol = "shadowsocks"
   126  	metadata.Destination = destination
   127  
   128  	return s.handler.NewConnection(auth.ContextWithUser(ctx, user), &serverConn{
   129  		Method: method,
   130  		Conn:   conn,
   131  		reader: reader,
   132  	}, metadata)
   133  }
   134  
   135  func (s *MultiService[U]) WriteIsThreadUnsafe() {
   136  }
   137  
   138  func (s *MultiService[U]) NewPacket(ctx context.Context, conn N.PacketConn, buffer *buf.Buffer, metadata M.Metadata) error {
   139  	err := s.newPacket(ctx, conn, buffer, metadata)
   140  	if err != nil {
   141  		err = &shadowsocks.ServerPacketError{Source: metadata.Source, Cause: err}
   142  	}
   143  	return err
   144  }
   145  
   146  func (s *MultiService[U]) newPacket(ctx context.Context, conn N.PacketConn, buffer *buf.Buffer, metadata M.Metadata) error {
   147  	var user U
   148  	var method *Method
   149  	for u, m := range s.methodMap {
   150  		user, method = u, m
   151  		break
   152  	}
   153  	if method == nil {
   154  		return shadowsocks.ErrNoUsers
   155  	}
   156  	if buffer.Len() < method.keySaltLength {
   157  		return io.ErrShortBuffer
   158  	}
   159  	var readCipher cipher.AEAD
   160  	var err error
   161  	for u, m := range s.methodMap {
   162  		key := buf.NewSize(m.keySaltLength)
   163  		Kdf(m.key, buffer.To(m.keySaltLength), key)
   164  		readCipher, err = m.constructor(key.Bytes())
   165  		key.Release()
   166  		if err != nil {
   167  			return err
   168  		}
   169  		var packet []byte
   170  		packet, err = readCipher.Open(buffer.Index(m.keySaltLength), rw.ZeroBytes[:readCipher.NonceSize()], buffer.From(m.keySaltLength), nil)
   171  		if err != nil {
   172  			continue
   173  		}
   174  
   175  		buffer.Advance(m.keySaltLength)
   176  		buffer.Truncate(len(packet))
   177  
   178  		user, method = u, m
   179  		break
   180  	}
   181  	if err != nil {
   182  		return err
   183  	}
   184  
   185  	destination, err := M.SocksaddrSerializer.ReadAddrPort(buffer)
   186  	if err != nil {
   187  		return err
   188  	}
   189  
   190  	metadata.Protocol = "shadowsocks"
   191  	metadata.Destination = destination
   192  	s.udpNat.NewPacket(auth.ContextWithUser(ctx, user), metadata.Source.AddrPort(), buffer, metadata, func(natConn N.PacketConn) N.PacketWriter {
   193  		return &serverPacketWriter{method, conn, natConn}
   194  	})
   195  	return nil
   196  }
   197  
   198  func (s *MultiService[U]) NewError(ctx context.Context, err error) {
   199  	s.handler.NewError(ctx, err)
   200  }