github.com/sagernet/sing-shadowsocks@v0.2.6/shadowaead/service_multi.go (about)

     1  package shadowaead
     2  
     3  import (
     4  	"context"
     5  	"crypto/cipher"
     6  	"io"
     7  	"net"
     8  	"net/netip"
     9  
    10  	"github.com/sagernet/sing-shadowsocks"
    11  	"github.com/sagernet/sing/common/auth"
    12  	"github.com/sagernet/sing/common/buf"
    13  	"github.com/sagernet/sing/common/bufio/deadline"
    14  	E "github.com/sagernet/sing/common/exceptions"
    15  	M "github.com/sagernet/sing/common/metadata"
    16  	N "github.com/sagernet/sing/common/network"
    17  	"github.com/sagernet/sing/common/rw"
    18  	"github.com/sagernet/sing/common/udpnat"
    19  )
    20  
    21  var _ shadowsocks.MultiService[int] = (*MultiService[int])(nil)
    22  
    23  type MultiService[U comparable] struct {
    24  	name      string
    25  	methodMap map[U]*Method
    26  	handler   shadowsocks.Handler
    27  	udpNat    *udpnat.Service[netip.AddrPort]
    28  }
    29  
    30  func NewMultiService[U comparable](method string, udpTimeout int64, handler shadowsocks.Handler) (*MultiService[U], error) {
    31  	s := &MultiService[U]{
    32  		name:    method,
    33  		handler: handler,
    34  		udpNat:  udpnat.New[netip.AddrPort](udpTimeout, handler),
    35  	}
    36  	return s, nil
    37  }
    38  
    39  func (s *MultiService[U]) Name() string {
    40  	return s.name
    41  }
    42  
    43  func (s *MultiService[U]) UpdateUsers(userList []U, keyList [][]byte) error {
    44  	s.methodMap = make(map[U]*Method)
    45  	for i, user := range userList {
    46  		key := keyList[i]
    47  		method, err := New(s.name, key, "")
    48  		if err != nil {
    49  			return err
    50  		}
    51  		s.methodMap[user] = method
    52  	}
    53  	return nil
    54  }
    55  
    56  func (s *MultiService[U]) UpdateUsersWithPasswords(userList []U, passwordList []string) error {
    57  	s.methodMap = make(map[U]*Method)
    58  	for i, user := range userList {
    59  		password := passwordList[i]
    60  		method, err := New(s.name, nil, password)
    61  		if err != nil {
    62  			return err
    63  		}
    64  		s.methodMap[user] = method
    65  	}
    66  	return nil
    67  }
    68  
    69  func (s *MultiService[U]) NewConnection(ctx context.Context, conn net.Conn, metadata M.Metadata) error {
    70  	err := s.newConnection(ctx, conn, metadata)
    71  	if err != nil {
    72  		err = &shadowsocks.ServerConnError{Conn: conn, Source: metadata.Source, Cause: err}
    73  	}
    74  	return err
    75  }
    76  
    77  func (s *MultiService[U]) newConnection(ctx context.Context, conn net.Conn, metadata M.Metadata) error {
    78  	var user U
    79  	var method *Method
    80  	for u, m := range s.methodMap {
    81  		user, method = u, m
    82  		break
    83  	}
    84  	if method == nil {
    85  		return shadowsocks.ErrNoUsers
    86  	}
    87  	header := buf.NewSize(method.keySaltLength + PacketLengthBufferSize + Overhead)
    88  	defer header.Release()
    89  
    90  	_, err := header.ReadFullFrom(conn, header.FreeLen())
    91  	if err != nil {
    92  		return E.Cause(err, "read header")
    93  	} else if !header.IsFull() {
    94  		return ErrBadHeader
    95  	}
    96  
    97  	var reader *Reader
    98  	var readCipher cipher.AEAD
    99  	for u, m := range s.methodMap {
   100  		key := buf.NewSize(method.keySaltLength)
   101  		Kdf(m.key, header.To(m.keySaltLength), key)
   102  		readCipher, err = m.constructor(key.Bytes())
   103  		key.Release()
   104  		if err != nil {
   105  			return err
   106  		}
   107  		reader = NewReader(conn, readCipher, MaxPacketSize)
   108  
   109  		err = reader.ReadWithLengthChunk(header.From(method.keySaltLength))
   110  		if err != nil {
   111  			continue
   112  		}
   113  
   114  		user, method = u, m
   115  		break
   116  	}
   117  	if err != nil {
   118  		return err
   119  	}
   120  
   121  	destination, err := M.SocksaddrSerializer.ReadAddrPort(reader)
   122  	if err != nil {
   123  		return err
   124  	}
   125  
   126  	metadata.Protocol = "shadowsocks"
   127  	metadata.Destination = destination
   128  
   129  	return s.handler.NewConnection(auth.ContextWithUser(ctx, user), deadline.NewConn(&serverConn{
   130  		Method: method,
   131  		Conn:   conn,
   132  		reader: reader,
   133  	}), metadata)
   134  }
   135  
   136  func (s *MultiService[U]) WriteIsThreadUnsafe() {
   137  }
   138  
   139  func (s *MultiService[U]) NewPacket(ctx context.Context, conn N.PacketConn, buffer *buf.Buffer, metadata M.Metadata) error {
   140  	err := s.newPacket(ctx, conn, buffer, metadata)
   141  	if err != nil {
   142  		err = &shadowsocks.ServerPacketError{Source: metadata.Source, Cause: err}
   143  	}
   144  	return err
   145  }
   146  
   147  func (s *MultiService[U]) newPacket(ctx context.Context, conn N.PacketConn, buffer *buf.Buffer, metadata M.Metadata) error {
   148  	var user U
   149  	var method *Method
   150  	for u, m := range s.methodMap {
   151  		user, method = u, m
   152  		break
   153  	}
   154  	if method == nil {
   155  		return shadowsocks.ErrNoUsers
   156  	}
   157  	if buffer.Len() < method.keySaltLength {
   158  		return io.ErrShortBuffer
   159  	}
   160  	var readCipher cipher.AEAD
   161  	var err error
   162  	for u, m := range s.methodMap {
   163  		key := buf.NewSize(m.keySaltLength)
   164  		Kdf(m.key, buffer.To(m.keySaltLength), key)
   165  		readCipher, err = m.constructor(key.Bytes())
   166  		key.Release()
   167  		if err != nil {
   168  			return err
   169  		}
   170  		var packet []byte
   171  		packet, err = readCipher.Open(buffer.Index(m.keySaltLength), rw.ZeroBytes[:readCipher.NonceSize()], buffer.From(m.keySaltLength), nil)
   172  		if err != nil {
   173  			continue
   174  		}
   175  
   176  		buffer.Advance(m.keySaltLength)
   177  		buffer.Truncate(len(packet))
   178  
   179  		user, method = u, m
   180  		break
   181  	}
   182  	if err != nil {
   183  		return err
   184  	}
   185  
   186  	destination, err := M.SocksaddrSerializer.ReadAddrPort(buffer)
   187  	if err != nil {
   188  		return err
   189  	}
   190  
   191  	metadata.Protocol = "shadowsocks"
   192  	metadata.Destination = destination
   193  	s.udpNat.NewPacket(auth.ContextWithUser(ctx, user), metadata.Source.AddrPort(), buffer, metadata, func(natConn N.PacketConn) N.PacketWriter {
   194  		return &serverPacketWriter{method, conn, natConn}
   195  	})
   196  	return nil
   197  }
   198  
   199  func (s *MultiService[U]) NewError(ctx context.Context, err error) {
   200  	s.handler.NewError(ctx, err)
   201  }