github.com/Asutorufa/yuhaiin@v0.3.6-0.20240502055049-7984da7023a0/pkg/net/proxy/shadowsocksr/protocol/auth_chain_a.go (about)

     1  // https://github.com/shadowsocksr-backup/shadowsocks-rss/blob/master/doc/auth_chain_a.md
     2  
     3  package protocol
     4  
     5  import (
     6  	"bytes"
     7  	"crypto"
     8  	"crypto/aes"
     9  	"crypto/cipher"
    10  	crand "crypto/rand"
    11  	"crypto/rc4"
    12  	"encoding/base64"
    13  	"encoding/binary"
    14  	"strconv"
    15  	"strings"
    16  	"time"
    17  
    18  	"github.com/Asutorufa/yuhaiin/pkg/net/proxy/shadowsocks/core"
    19  	ssr "github.com/Asutorufa/yuhaiin/pkg/net/proxy/shadowsocksr/utils"
    20  )
    21  
    22  type authChainA struct {
    23  	Protocol
    24  	randomClient ssr.Shift128plusContext
    25  	randomServer ssr.Shift128plusContext
    26  	recvID       uint32
    27  
    28  	encrypter      cipher.Stream
    29  	decrypter      cipher.Stream
    30  	hasSentHeader  bool
    31  	lastClientHash []byte
    32  	lastServerHash []byte
    33  	userKey        []byte
    34  	userKeyLen     int
    35  	uid            [4]byte
    36  	salt           string
    37  	hmac           ssr.HMAC
    38  	rnd            func(dataLength int, random *ssr.Shift128plusContext, lastHash []byte, dataSizeList, dataSizeList2 []int, overhead int) int
    39  	dataSizeList   []int
    40  	dataSizeList2  []int
    41  	chunkID        uint32
    42  
    43  	overhead int
    44  }
    45  
    46  func NewAuthChainA(info Protocol) protocol { return newAuthChain(info, authChainAGetRandLen) }
    47  
    48  func newAuthChain(info Protocol, rnd func(dataLength int, random *ssr.Shift128plusContext, lastHash []byte, dataSizeList, dataSizeList2 []int, overhead int) int) *authChainA {
    49  	return &authChainA{
    50  		salt:     info.Name,
    51  		hmac:     ssr.HMAC(crypto.MD5),
    52  		rnd:      rnd,
    53  		recvID:   1,
    54  		Protocol: info,
    55  		overhead: 4 + info.ObfsOverhead,
    56  	}
    57  }
    58  
    59  func authChainAGetRandLen(dataLength int, random *ssr.Shift128plusContext, lastHash []byte, dataSizeList, dataSizeList2 []int, overhead int) int {
    60  	if dataLength > 1440 {
    61  		return 0
    62  	}
    63  	random.InitFromBinDatalen(lastHash[:16], dataLength)
    64  	if dataLength > 1300 {
    65  		return int(random.Next() % 31)
    66  	}
    67  	if dataLength > 900 {
    68  		return int(random.Next() % 127)
    69  	}
    70  	if dataLength > 400 {
    71  		return int(random.Next() % 521)
    72  	}
    73  	return int(random.Next() % 1021)
    74  }
    75  
    76  func getRandStartPos(random *ssr.Shift128plusContext, randLength int) int {
    77  	if randLength > 0 {
    78  		return int(random.Next() % 8589934609 % uint64(randLength))
    79  	}
    80  	return 0
    81  }
    82  
    83  func (a *authChainA) getClientRandLen(dataLength int, overhead int) int {
    84  	return a.rnd(dataLength, &a.randomClient, a.lastClientHash, a.dataSizeList, a.dataSizeList2, overhead)
    85  }
    86  
    87  func (a *authChainA) getServerRandLen(dataLength int, overhead int) int {
    88  	return a.rnd(dataLength, &a.randomServer, a.lastServerHash, a.dataSizeList, a.dataSizeList2, overhead)
    89  }
    90  
    91  func (a *authChainA) packedDataLen(data []byte) (chunkLength, randLength int) {
    92  	dataLength := len(data)
    93  	randLength = a.getClientRandLen(dataLength, a.overhead)
    94  	chunkLength = randLength + dataLength + 2 + 2
    95  	return
    96  }
    97  
    98  func (a *authChainA) packData(outData []byte, data []byte, randLength int) {
    99  	dataLength := len(data)
   100  	outLength := randLength + dataLength + 2
   101  	outData[0] = byte(dataLength) ^ a.lastClientHash[14]
   102  	outData[1] = byte(dataLength>>8) ^ a.lastClientHash[15]
   103  
   104  	{
   105  		if dataLength > 0 {
   106  			randPart1Length := getRandStartPos(&a.randomClient, randLength)
   107  			crand.Read(outData[2 : 2+randPart1Length])
   108  			a.encrypter.XORKeyStream(outData[2+randPart1Length:], data)
   109  			crand.Read(outData[2+randPart1Length+dataLength : outLength])
   110  		} else {
   111  			crand.Read(outData[2 : 2+randLength])
   112  		}
   113  	}
   114  
   115  	keyLen := a.userKeyLen + 4
   116  	key := make([]byte, keyLen)
   117  	copy(key, a.userKey)
   118  	a.chunkID++
   119  	binary.LittleEndian.PutUint32(key[a.userKeyLen:], a.chunkID)
   120  	a.lastClientHash = a.hmac.HMAC(key, outData[:outLength], nil)
   121  	copy(outData[outLength:], a.lastClientHash[:2])
   122  }
   123  
   124  const authheadLength = 4 + 8 + 4 + 16 + 4
   125  
   126  func (a *authChainA) packAuthData(data []byte) (outData []byte) {
   127  	outData = make([]byte, authheadLength, authheadLength+1500)
   128  
   129  	a.Protocol.Auth.nextAuth()
   130  
   131  	var key = make([]byte, a.IVSize()+len(a.Key()))
   132  	copy(key, a.IV)
   133  	copy(key[a.IVSize():], a.Key())
   134  
   135  	encrypt := make([]byte, 20)
   136  	t := time.Now().Unix()
   137  	binary.LittleEndian.PutUint32(encrypt[:4], uint32(t))
   138  	copy(encrypt[4:8], a.Protocol.Auth.clientID[:])
   139  	binary.LittleEndian.PutUint32(encrypt[8:], a.Protocol.Auth.connectionID.Load())
   140  	binary.LittleEndian.PutUint16(encrypt[12:], uint16(a.overhead))
   141  	binary.LittleEndian.PutUint16(encrypt[14:16], 0)
   142  
   143  	// first 12 bytes
   144  	{
   145  		crand.Read(outData[:4])
   146  		a.lastClientHash = a.hmac.HMAC(key, outData[:4], nil)
   147  		copy(outData[4:], a.lastClientHash[:8])
   148  	}
   149  	var base64UserKey string
   150  	// uid & 16 bytes auth data
   151  	{
   152  		uid := make([]byte, 4)
   153  		if a.userKey == nil {
   154  			params := strings.Split(a.Param, ":")
   155  			if len(params) >= 2 {
   156  				if userID, err := strconv.ParseUint(params[0], 10, 32); err == nil {
   157  					binary.LittleEndian.PutUint32(a.uid[:], uint32(userID))
   158  					a.userKeyLen = len(params[1])
   159  					a.userKey = []byte(params[1])
   160  				}
   161  			}
   162  			if a.userKey == nil {
   163  				crand.Read(a.uid[:])
   164  
   165  				a.userKeyLen = len(a.Key())
   166  				a.userKey = make([]byte, len(a.Key()))
   167  				copy(a.userKey, a.Key())
   168  			}
   169  		}
   170  		for i := 0; i < 4; i++ {
   171  			uid[i] = a.uid[i] ^ a.lastClientHash[8+i]
   172  		}
   173  		base64UserKey = base64.StdEncoding.EncodeToString(a.userKey)
   174  		aesCipherKey := core.KDF(base64UserKey+a.salt, 16)
   175  		block, err := aes.NewCipher(aesCipherKey)
   176  		if err != nil {
   177  			return
   178  		}
   179  		encryptData := make([]byte, 16)
   180  		iv := make([]byte, aes.BlockSize)
   181  		cbc := cipher.NewCBCEncrypter(block, iv)
   182  		cbc.CryptBlocks(encryptData, encrypt[:16])
   183  		copy(encrypt[:4], uid[:])
   184  		copy(encrypt[4:4+16], encryptData)
   185  	}
   186  	// final HMAC
   187  	{
   188  		a.lastServerHash = a.hmac.HMAC(a.userKey, encrypt[0:20], nil)
   189  
   190  		copy(outData[12:], encrypt)
   191  		copy(outData[12+20:], a.lastServerHash[:4])
   192  	}
   193  
   194  	// init cipher
   195  	password := make([]byte, len(base64UserKey)+base64.StdEncoding.EncodedLen(16))
   196  	copy(password, base64UserKey)
   197  	base64.StdEncoding.Encode(password[len(base64UserKey):], a.lastClientHash[:16])
   198  	a.initRC4Cipher(password)
   199  
   200  	// data
   201  	chunkLength, randLength := a.packedDataLen(data)
   202  	if authheadLength+chunkLength <= cap(outData) {
   203  		outData = outData[:authheadLength+chunkLength]
   204  	} else {
   205  		newOutData := make([]byte, authheadLength+chunkLength)
   206  		copy(newOutData, outData[:authheadLength])
   207  		outData = newOutData
   208  	}
   209  	a.packData(outData[authheadLength:], data, randLength)
   210  	return outData
   211  }
   212  
   213  func (a *authChainA) initRC4Cipher(key []byte) {
   214  	a.encrypter, _ = rc4.NewCipher(key)
   215  	a.decrypter, _ = rc4.NewCipher(key)
   216  }
   217  
   218  func (a *authChainA) EncryptStream(wbuf *bytes.Buffer, plainData []byte) (err error) {
   219  	dataLength := len(plainData)
   220  	offset := 0
   221  	if dataLength > 0 && !a.hasSentHeader {
   222  		headSize := 1200
   223  		if headSize > dataLength {
   224  			headSize = dataLength
   225  		}
   226  		wbuf.Write(a.packAuthData(plainData[:headSize]))
   227  		offset += headSize
   228  		dataLength -= headSize
   229  		a.hasSentHeader = true
   230  	}
   231  	var unitSize = a.TcpMss - a.overhead
   232  	for dataLength > unitSize {
   233  		dataLen, randLength := a.packedDataLen(plainData[offset : offset+unitSize])
   234  		b := make([]byte, dataLen)
   235  		a.packData(b, plainData[offset:offset+unitSize], randLength)
   236  		wbuf.Write(b)
   237  		dataLength -= unitSize
   238  		offset += unitSize
   239  	}
   240  	if dataLength > 0 {
   241  		dataLen, randLength := a.packedDataLen(plainData[offset:])
   242  		b := make([]byte, dataLen)
   243  		a.packData(b, plainData[offset:], randLength)
   244  		wbuf.Write(b)
   245  	}
   246  	return nil
   247  }
   248  
   249  func (a *authChainA) DecryptStream(rbuf *bytes.Buffer, plainData []byte) (n int, err error) {
   250  	key := make([]byte, len(a.userKey)+4)
   251  	readlenth := 0
   252  	copy(key, a.userKey)
   253  	for len(plainData) > 4 {
   254  		binary.LittleEndian.PutUint32(key[len(a.userKey):], a.recvID)
   255  		dataLen := (int)((uint(plainData[1]^a.lastServerHash[15]) << 8) + uint(plainData[0]^a.lastServerHash[14]))
   256  		randLen := a.getServerRandLen(dataLen, a.overhead)
   257  		length := randLen + dataLen
   258  		if length >= 4096 {
   259  			return 0, ssr.ErrAuthChainDataLengthError
   260  		}
   261  
   262  		length += 4
   263  		if length > len(plainData) {
   264  			break
   265  		}
   266  
   267  		hash := a.hmac.HMAC(key, plainData[:length-2], nil)
   268  		if !bytes.Equal(hash[:2], plainData[length-2:length]) {
   269  			return 0, ssr.ErrAuthChainIncorrectHMAC
   270  		}
   271  
   272  		dataPos := 2
   273  		if dataLen > 0 && randLen > 0 {
   274  			dataPos = 2 + getRandStartPos(&a.randomServer, randLen)
   275  		}
   276  
   277  		b := make([]byte, dataLen)
   278  		a.decrypter.XORKeyStream(b, plainData[dataPos:dataPos+dataLen])
   279  		rbuf.Write(b)
   280  		if a.recvID == 1 {
   281  			a.TcpMss = int(binary.LittleEndian.Uint16(rbuf.Next(2)))
   282  		}
   283  		a.lastServerHash = hash
   284  		a.recvID++
   285  		plainData = plainData[length:]
   286  		readlenth += length
   287  
   288  	}
   289  	return readlenth, nil
   290  }
   291  
   292  func (a *authChainA) GetOverhead() int {
   293  	return a.overhead
   294  }
   295  
   296  func (a *authChainA) EncryptPacket(b []byte) ([]byte, error) {
   297  	if a.userKey == nil {
   298  		params := strings.Split(a.Param, ":")
   299  		if len(params) >= 2 {
   300  			if userID, err := strconv.ParseUint(params[0], 10, 32); err == nil {
   301  				binary.LittleEndian.PutUint32(a.uid[:], uint32(userID))
   302  				a.userKeyLen = len(params[1])
   303  				a.userKey = []byte(params[1])
   304  			}
   305  		}
   306  		if a.userKey == nil {
   307  			crand.Read(a.uid[:])
   308  
   309  			a.userKeyLen = len(a.Key())
   310  			a.userKey = make([]byte, len(a.Key()))
   311  			copy(a.userKey, a.Key())
   312  		}
   313  	}
   314  	authData := make([]byte, 3)
   315  	crand.Read(authData)
   316  
   317  	md5Data := a.hmac.HMAC(a.userKey, authData, nil)
   318  	randDataLength := udpGetRandLength(md5Data, &a.randomClient)
   319  
   320  	key := core.KDF(base64.StdEncoding.EncodeToString(a.userKey)+base64.StdEncoding.EncodeToString(md5Data), 16)
   321  	rc4Cipher, err := rc4.NewCipher(key)
   322  	if err != nil {
   323  		return nil, err
   324  	}
   325  	wantedData := b[:len(b)-8-randDataLength]
   326  	rc4Cipher.XORKeyStream(wantedData, wantedData)
   327  	return wantedData, nil
   328  }
   329  
   330  func (a *authChainA) DecryptPacket(b []byte) ([]byte, error) {
   331  	if len(b) < 9 {
   332  		return nil, ssr.ErrAuthChainDataLengthError
   333  	}
   334  	if !bytes.Equal(a.hmac.HMAC(a.userKey, b[:len(b)-1], nil)[:1], b[len(b)-1:]) {
   335  		return nil, ssr.ErrAuthChainIncorrectHMAC
   336  	}
   337  	md5Data := a.hmac.HMAC(a.Key(), b[len(b)-8:len(b)-1], nil)
   338  
   339  	randDataLength := udpGetRandLength(md5Data, &a.randomServer)
   340  
   341  	key := core.KDF(base64.StdEncoding.EncodeToString(a.userKey)+base64.StdEncoding.EncodeToString(md5Data), 16)
   342  	rc4Cipher, err := rc4.NewCipher(key)
   343  	if err != nil {
   344  		return nil, err
   345  	}
   346  	wantedData := b[:len(b)-8-randDataLength]
   347  	rc4Cipher.XORKeyStream(wantedData, wantedData)
   348  	return wantedData, nil
   349  }
   350  
   351  func udpGetRandLength(lastHash []byte, random *ssr.Shift128plusContext) int {
   352  	random.InitFromBin(lastHash)
   353  	return int(random.Next() % 127)
   354  }