github.com/database64128/shadowsocks-go@v1.10.2-0.20240315062903-143a773533f1/ss2022/crypto.go (about)

     1  package ss2022
     2  
     3  import (
     4  	"crypto/aes"
     5  	"crypto/cipher"
     6  	"encoding/base64"
     7  	"fmt"
     8  
     9  	"lukechampine.com/blake3"
    10  )
    11  
    12  const (
    13  	subkeyCtxSession  = "shadowsocks 2022 session subkey"
    14  	subkeyCtxIdentity = "shadowsocks 2022 identity subkey"
    15  )
    16  
    17  func deriveSubkey(psk, salt []byte, ctx string) []byte {
    18  	if len(psk) == 0 || len(salt) == 0 {
    19  		panic("empty psk or salt")
    20  	}
    21  	keyMaterial := make([]byte, len(psk)+len(salt))
    22  	copy(keyMaterial, psk)
    23  	copy(keyMaterial[len(psk):], salt)
    24  	key := make([]byte, len(psk))
    25  	blake3.DeriveKey(key, ctx, keyMaterial)
    26  	return key
    27  }
    28  
    29  func newAES(psk, salt []byte, ctx string) (cipher.Block, error) {
    30  	key := deriveSubkey(psk, salt, ctx)
    31  	return aes.NewCipher(key)
    32  }
    33  
    34  func newAESGCM(psk, salt []byte) (cipher.AEAD, error) {
    35  	block, err := newAES(psk, salt, subkeyCtxSession)
    36  	if err != nil {
    37  		return nil, err
    38  	}
    39  	return cipher.NewGCM(block)
    40  }
    41  
    42  // UserCipherConfig stores cipher configuration for a non-EIH client/server or an EIH user.
    43  type UserCipherConfig struct {
    44  	PSK   []byte
    45  	block cipher.Block
    46  }
    47  
    48  // NewUserCipherConfig returns a new UserCipherConfig.
    49  func NewUserCipherConfig(psk []byte, enableUDP bool) (c UserCipherConfig, err error) {
    50  	c.PSK = psk
    51  	if enableUDP {
    52  		c.block, err = aes.NewCipher(psk)
    53  	}
    54  	return
    55  }
    56  
    57  // AEAD derives a subkey from the salt and returns a new AEAD cipher.
    58  func (c UserCipherConfig) AEAD(salt []byte) (cipher.AEAD, error) {
    59  	return newAESGCM(c.PSK, salt)
    60  }
    61  
    62  func (c UserCipherConfig) ShadowStreamCipher(salt []byte) (*ShadowStreamCipher, error) {
    63  	aead, err := c.AEAD(salt)
    64  	if err != nil {
    65  		return nil, err
    66  	}
    67  	return NewShadowStreamCipher(aead), nil
    68  }
    69  
    70  // Block returns the block cipher for UDP separate header.
    71  func (c UserCipherConfig) Block() cipher.Block {
    72  	return c.block
    73  }
    74  
    75  // ClientCipherConfig stores cipher configuration for a client.
    76  type ClientCipherConfig struct {
    77  	UserCipherConfig
    78  	iPSKs        [][]byte
    79  	eihCiphers   []cipher.Block
    80  	eihPSKHashes [][IdentityHeaderLength]byte
    81  }
    82  
    83  // TCPIdentityHeaderCiphers creates block ciphers for a client TCP session's identity headers.
    84  func (c *ClientCipherConfig) TCPIdentityHeaderCiphers(salt []byte) ([]cipher.Block, error) {
    85  	ciphers := make([]cipher.Block, len(c.iPSKs))
    86  
    87  	for i := range ciphers {
    88  		var err error
    89  		ciphers[i], err = newAES(c.iPSKs[i], salt, subkeyCtxIdentity)
    90  		if err != nil {
    91  			return nil, err
    92  		}
    93  	}
    94  
    95  	return ciphers, nil
    96  }
    97  
    98  // UDPIdentityHeaderCiphers returns the block ciphers for a client UDP service's identity headers.
    99  func (c *ClientCipherConfig) UDPIdentityHeaderCiphers() []cipher.Block {
   100  	return c.eihCiphers
   101  }
   102  
   103  // EIHPSKHashes returns the truncated BLAKE3 hashes of c.iPSKs[1:] and c.PSK.
   104  func (c *ClientCipherConfig) EIHPSKHashes() [][IdentityHeaderLength]byte {
   105  	return c.eihPSKHashes
   106  }
   107  
   108  // UDPSeparateHeaderPackerCipher returns the block cipher used by the client packer to encrypt the separate header.
   109  func (c *ClientCipherConfig) UDPSeparateHeaderPackerCipher() cipher.Block {
   110  	if len(c.eihCiphers) > 0 {
   111  		return c.eihCiphers[0]
   112  	}
   113  	return c.block
   114  }
   115  
   116  func udpIdentityHeaderClientCiphers(iPSKs [][]byte) ([]cipher.Block, error) {
   117  	ciphers := make([]cipher.Block, len(iPSKs))
   118  
   119  	for i := range ciphers {
   120  		var err error
   121  		ciphers[i], err = aes.NewCipher(iPSKs[i])
   122  		if err != nil {
   123  			return nil, err
   124  		}
   125  	}
   126  
   127  	return ciphers, nil
   128  }
   129  
   130  func clientPSKHashes(iPSKs [][]byte, psk []byte) [][IdentityHeaderLength]byte {
   131  	if len(iPSKs) == 0 {
   132  		return nil
   133  	}
   134  
   135  	hashes := make([][IdentityHeaderLength]byte, len(iPSKs))
   136  
   137  	for i := 1; i < len(iPSKs); i++ {
   138  		hash := blake3.Sum512(iPSKs[i])
   139  		hashes[i-1] = [IdentityHeaderLength]byte(hash[:])
   140  	}
   141  
   142  	hash := blake3.Sum512(psk)
   143  	hashes[len(hashes)-1] = [IdentityHeaderLength]byte(hash[:])
   144  
   145  	return hashes
   146  }
   147  
   148  // NewClientCipherConfig returns a new ClientCipherConfig.
   149  func NewClientCipherConfig(psk []byte, iPSKs [][]byte, enableUDP bool) (c *ClientCipherConfig, err error) {
   150  	c = &ClientCipherConfig{
   151  		UserCipherConfig: UserCipherConfig{
   152  			PSK: psk,
   153  		},
   154  		iPSKs:        iPSKs,
   155  		eihPSKHashes: clientPSKHashes(iPSKs, psk),
   156  	}
   157  	if enableUDP {
   158  		c.block, err = aes.NewCipher(psk)
   159  		if err != nil {
   160  			return
   161  		}
   162  		c.eihCiphers, err = udpIdentityHeaderClientCiphers(iPSKs)
   163  	}
   164  	return
   165  }
   166  
   167  // ServerIdentityCipherConfig stores cipher configuration for a server's identity header.
   168  type ServerIdentityCipherConfig struct {
   169  	IPSK  []byte
   170  	block cipher.Block
   171  }
   172  
   173  // NewServerIdentityCipherConfig returns a new ServerIdentityCipherConfig.
   174  func NewServerIdentityCipherConfig(iPSK []byte, enableUDP bool) (c ServerIdentityCipherConfig, err error) {
   175  	c.IPSK = iPSK
   176  	if enableUDP {
   177  		c.block, err = aes.NewCipher(iPSK)
   178  	}
   179  	return
   180  }
   181  
   182  // TCP creates a block cipher for a server TCP session's identity header.
   183  func (c ServerIdentityCipherConfig) TCP(salt []byte) (cipher.Block, error) {
   184  	return newAES(c.IPSK, salt, subkeyCtxIdentity)
   185  }
   186  
   187  // UDP returns the block cipher for a server UDP service's identity header.
   188  func (c ServerIdentityCipherConfig) UDP() cipher.Block {
   189  	return c.block
   190  }
   191  
   192  // ServerUserCipherConfig stores cipher configuration for a server's EIH user.
   193  type ServerUserCipherConfig struct {
   194  	UserCipherConfig
   195  	Name string
   196  }
   197  
   198  // NewServerUserCipherConfig returns a new ServerUserCipherConfig.
   199  func NewServerUserCipherConfig(name string, psk []byte, enableUDP bool) (c *ServerUserCipherConfig, err error) {
   200  	c = &ServerUserCipherConfig{Name: name}
   201  	c.UserCipherConfig, err = NewUserCipherConfig(psk, enableUDP)
   202  	return
   203  }
   204  
   205  // PSKHash returns the given PSK's BLAKE3 hash truncated to [IdentityHeaderLength] bytes.
   206  func PSKHash(psk []byte) [IdentityHeaderLength]byte {
   207  	hash := blake3.Sum512(psk)
   208  	return [IdentityHeaderLength]byte(hash[:])
   209  }
   210  
   211  // PSKLengthForMethod returns the required length of the PSK for the given method.
   212  func PSKLengthForMethod(method string) (int, error) {
   213  	switch method {
   214  	case "2022-blake3-aes-128-gcm":
   215  		return 16, nil
   216  	case "2022-blake3-aes-256-gcm":
   217  		return 32, nil
   218  	default:
   219  		return 0, fmt.Errorf("unknown method: %s", method)
   220  	}
   221  }
   222  
   223  type PSKLengthError struct {
   224  	PSK            []byte
   225  	ExpectedLength int
   226  }
   227  
   228  func (e PSKLengthError) Error() string {
   229  	return fmt.Sprintf("expected PSK length %d, got %d from %s", e.ExpectedLength, len(e.PSK), base64.StdEncoding.EncodeToString(e.PSK))
   230  }
   231  
   232  // CheckPSKLength checks that the PSK is the correct length for the given method.
   233  func CheckPSKLength(method string, psk []byte, psks [][]byte) error {
   234  	pskLength, err := PSKLengthForMethod(method)
   235  	if err != nil {
   236  		return err
   237  	}
   238  
   239  	if len(psk) != pskLength {
   240  		return &PSKLengthError{psk, pskLength}
   241  	}
   242  
   243  	for _, psk := range psks {
   244  		if len(psk) != pskLength {
   245  			return &PSKLengthError{psk, pskLength}
   246  		}
   247  	}
   248  
   249  	return nil
   250  }
   251  
// UserLookupMap is a map of uPSK hashes to [*ServerUserCipherConfig].
// Upon decryption of an identity header, the uPSK hash is looked up in this map.
//
// NOTE(review): keys have the same type as [PSKHash]'s result, so they are
// presumably PSKHash(user PSK); confirm with the code that populates the map.
type UserLookupMap map[[IdentityHeaderLength]byte]*ServerUserCipherConfig