github.com/yaling888/clash@v1.53.0/transport/ssr/protocol/auth_chain_a.go (about)

     1  package protocol
     2  
     3  import (
     4  	"bytes"
     5  	"crypto/cipher"
     6  	"crypto/rand"
     7  	"crypto/rc4"
     8  	"encoding/base64"
     9  	"encoding/binary"
    10  	"net"
    11  	"strconv"
    12  	"strings"
    13  
    14  	"github.com/phuslu/log"
    15  
    16  	"github.com/yaling888/clash/common/pool"
    17  	"github.com/yaling888/clash/transport/shadowsocks/core"
    18  	"github.com/yaling888/clash/transport/ssr/tools"
    19  )
    20  
// init registers this protocol under the name "auth_chain_a" with
// newAuthChainA as its constructor. The meaning of the trailing 4 is defined
// by register (presumably the per-frame overhead in bytes — confirm).
func init() {
	register("auth_chain_a", newAuthChainA, 4)
}
    24  
// randDataLengthMethod chooses how many random padding bytes accompany a
// payload. Callers in this file pass the payload length, the most recent
// client or server hash, and the PRNG to draw from.
type randDataLengthMethod func(int, []byte, *tools.XorShift128Plus) int
    26  
// authChainA holds the state of one auth_chain_a protocol instance.
//
// The instance built by newAuthChainA acts as a template: StreamConn and
// PacketConn create fresh instances that share its Base and userData.
// The embedded Base, authData and userData types are declared elsewhere in
// the package; this file reads Key, Param, Overhead, userKey and userID
// through them.
type authChainA struct {
	*Base
	*authData
	*userData
	iv             []byte                // IV handed to StreamConn; part of the handshake MAC key
	salt           string                // protocol name; passed to putEncryptedData and used in log output
	hasSentHeader  bool                  // true once Encode has emitted the auth header
	rawTrans       bool                  // set after a decode failure; Decode then copies src to dst untouched
	lastClientHash []byte                // latest outbound hash (handshake head, then each packed frame)
	lastServerHash []byte                // latest inbound hash (handshake tail, then each verified frame)
	encrypter      cipher.Stream         // RC4 stream for outbound payloads, created by initRC4Cipher
	decrypter      cipher.Stream         // RC4 stream for inbound payloads, created by initRC4Cipher
	randomClient   tools.XorShift128Plus // PRNG driving outbound padding decisions
	randomServer   tools.XorShift128Plus // PRNG driving inbound padding decisions
	randDataLength randDataLengthMethod  // padding-length strategy; StreamConn installs getRandLength
	packID         uint32                // outbound frame counter mixed into the MAC key; StreamConn starts it at 1
	recvID         uint32                // inbound frame counter mixed into the MAC key; StreamConn starts it at 1
}
    45  
    46  func newAuthChainA(b *Base) Protocol {
    47  	a := &authChainA{
    48  		Base:     b,
    49  		authData: &authData{},
    50  		userData: &userData{},
    51  		salt:     "auth_chain_a",
    52  	}
    53  	a.initUserData()
    54  	return a
    55  }
    56  
    57  func (a *authChainA) initUserData() {
    58  	params := strings.Split(a.Param, ":")
    59  	if len(params) > 1 {
    60  		if userID, err := strconv.ParseUint(params[0], 10, 32); err == nil {
    61  			binary.LittleEndian.PutUint32(a.userID[:], uint32(userID))
    62  			a.userKey = []byte(params[1])
    63  		} else {
    64  			log.Warn().Msgf("[SSR] wrong protocol-param for %s, only digits are expected before ':'", a.salt)
    65  		}
    66  	}
    67  	if len(a.userKey) == 0 {
    68  		a.userKey = a.Key
    69  		_, _ = rand.Read(a.userID[:])
    70  	}
    71  }
    72  
    73  func (a *authChainA) StreamConn(c net.Conn, iv []byte) net.Conn {
    74  	p := &authChainA{
    75  		Base:     a.Base,
    76  		authData: a.next(),
    77  		userData: a.userData,
    78  		salt:     a.salt,
    79  		packID:   1,
    80  		recvID:   1,
    81  	}
    82  	p.iv = iv
    83  	p.randDataLength = p.getRandLength
    84  	return &Conn{Conn: c, Protocol: p}
    85  }
    86  
    87  func (a *authChainA) PacketConn(c net.PacketConn) net.PacketConn {
    88  	p := &authChainA{
    89  		Base:     a.Base,
    90  		salt:     a.salt,
    91  		userData: a.userData,
    92  	}
    93  	return &PacketConn{PacketConn: c, Protocol: p}
    94  }
    95  
    96  func (a *authChainA) Decode(dst, src *bytes.Buffer) error {
    97  	if a.rawTrans {
    98  		_, _ = dst.ReadFrom(src)
    99  		return nil
   100  	}
   101  	for src.Len() > 4 {
   102  		macKey := pool.GetBufferWriter()
   103  		macKey.Grow(len(a.userKey) + 4)
   104  		copy(*macKey, a.userKey)
   105  		binary.LittleEndian.PutUint32((*macKey)[len(a.userKey):], a.recvID)
   106  
   107  		dataLength := int(binary.LittleEndian.Uint16(src.Bytes()[:2]) ^ binary.LittleEndian.Uint16(a.lastServerHash[14:16]))
   108  		randDataLength := a.randDataLength(dataLength, a.lastServerHash, &a.randomServer)
   109  		length := dataLength + randDataLength
   110  
   111  		if length >= 4096 {
   112  			a.rawTrans = true
   113  			src.Reset()
   114  			pool.PutBufferWriter(macKey)
   115  			return errAuthChainLengthError
   116  		}
   117  
   118  		if 4+length > src.Len() {
   119  			pool.PutBufferWriter(macKey)
   120  			break
   121  		}
   122  
   123  		serverHash := tools.HmacMD5(macKey.Bytes(), src.Bytes()[:length+2])
   124  		pool.PutBufferWriter(macKey)
   125  		if !bytes.Equal(serverHash[:2], src.Bytes()[length+2:length+4]) {
   126  			a.rawTrans = true
   127  			src.Reset()
   128  			return errAuthChainChksumError
   129  		}
   130  		a.lastServerHash = serverHash
   131  
   132  		pos := 2
   133  		if dataLength > 0 && randDataLength > 0 {
   134  			pos += getRandStartPos(randDataLength, &a.randomServer)
   135  		}
   136  		wantedData := src.Bytes()[pos : pos+dataLength]
   137  		a.decrypter.XORKeyStream(wantedData, wantedData)
   138  		if a.recvID == 1 {
   139  			dst.Write(wantedData[2:])
   140  		} else {
   141  			dst.Write(wantedData)
   142  		}
   143  		a.recvID++
   144  		src.Next(length + 4)
   145  	}
   146  	return nil
   147  }
   148  
   149  func (a *authChainA) Encode(buf *bytes.Buffer, b []byte) error {
   150  	if !a.hasSentHeader {
   151  		dataLength := getDataLength(b)
   152  		a.packAuthData(buf, b[:dataLength])
   153  		b = b[dataLength:]
   154  		a.hasSentHeader = true
   155  	}
   156  	for len(b) > 2800 {
   157  		a.packData(buf, b[:2800])
   158  		b = b[2800:]
   159  	}
   160  	if len(b) > 0 {
   161  		a.packData(buf, b)
   162  	}
   163  	return nil
   164  }
   165  
   166  func (a *authChainA) DecodePacket(b []byte) ([]byte, error) {
   167  	if len(b) < 9 {
   168  		return nil, errAuthChainLengthError
   169  	}
   170  	if !bytes.Equal(tools.HmacMD5(a.userKey, b[:len(b)-1])[:1], b[len(b)-1:]) {
   171  		return nil, errAuthChainChksumError
   172  	}
   173  	md5Data := tools.HmacMD5(a.Key, b[len(b)-8:len(b)-1])
   174  
   175  	randDataLength := udpGetRandLength(md5Data, &a.randomServer)
   176  
   177  	key := core.Kdf(base64.StdEncoding.EncodeToString(a.userKey)+base64.StdEncoding.EncodeToString(md5Data), 16)
   178  	rc4Cipher, err := rc4.NewCipher(key)
   179  	if err != nil {
   180  		return nil, err
   181  	}
   182  	wantedData := b[:len(b)-8-randDataLength]
   183  	rc4Cipher.XORKeyStream(wantedData, wantedData)
   184  	return wantedData, nil
   185  }
   186  
   187  func (a *authChainA) EncodePacket(buf *bytes.Buffer, b []byte) error {
   188  	auth := pool.GetBufferWriter()
   189  	auth.Grow(3)
   190  	defer pool.PutBufferWriter(auth)
   191  	_, _ = rand.Read((*auth)[:3])
   192  
   193  	md5Data := tools.HmacMD5(a.Key, auth.Bytes())
   194  
   195  	randDataLength := udpGetRandLength(md5Data, &a.randomClient)
   196  
   197  	key := core.Kdf(base64.StdEncoding.EncodeToString(a.userKey)+base64.StdEncoding.EncodeToString(md5Data), 16)
   198  	rc4Cipher, err := rc4.NewCipher(key)
   199  	if err != nil {
   200  		return err
   201  	}
   202  	rc4Cipher.XORKeyStream(b, b)
   203  
   204  	buf.Write(b)
   205  	tools.AppendRandBytes(buf, randDataLength)
   206  	buf.Write(auth.Bytes())
   207  	_ = binary.Write(buf, binary.LittleEndian, binary.LittleEndian.Uint32(a.userID[:])^binary.LittleEndian.Uint32(md5Data[:4]))
   208  	buf.Write(tools.HmacMD5(a.userKey, buf.Bytes())[:1])
   209  	return nil
   210  }
   211  
// packAuthData appends the one-time authentication header to poolBuf,
// followed by data packed as the first regular frame.
//
// Side effects: sets lastClientHash and lastServerHash and creates the RC4
// encrypter/decrypter. If putEncryptedData fails, poolBuf is reset and the
// function returns without setting lastServerHash or packing data.
//
// NOTE(review): the HMAC inputs and the [12:] offset below are computed over
// poolBuf.Bytes(), so they assume poolBuf is empty on entry — confirm against
// callers.
func (a *authChainA) packAuthData(poolBuf *bytes.Buffer, data []byte) {
	/*
		dataLength := len(data)
		12:	checkHead(4) and hmac of checkHead(8)
		4:	uint32 LittleEndian uid (uid = userID ^ last client hash)
		16:	encrypted data of authdata(12), uint16 LittleEndian overhead(2) and uint16 LittleEndian number zero(2)
		4:	last server hash(4)
		packedAuthDataLength := 12 + 4 + 16 + 4 + dataLength
	*/

	// MAC key for the check head is iv followed by Key.
	macKey := pool.GetBufferWriter()
	defer pool.PutBufferWriter(macKey)
	macKey.PutSlice(a.iv)
	macKey.PutSlice(a.Key)

	// check head
	tools.AppendRandBytes(poolBuf, 4)
	a.lastClientHash = tools.HmacMD5(macKey.Bytes(), poolBuf.Bytes())
	// Must follow the assignment above: the RC4 key is derived from lastClientHash.
	a.initRC4Cipher()
	poolBuf.Write(a.lastClientHash[:8])
	// uid
	_ = binary.Write(poolBuf, binary.LittleEndian, binary.LittleEndian.Uint32(a.userID[:])^binary.LittleEndian.Uint32(a.lastClientHash[8:12]))
	// encrypted data
	err := a.putEncryptedData(poolBuf, a.userKey, [2]int{a.Overhead, 0}, a.salt)
	if err != nil {
		poolBuf.Reset()
		return
	}
	// last server hash
	a.lastServerHash = tools.HmacMD5(a.userKey, poolBuf.Bytes()[12:])
	poolBuf.Write(a.lastServerHash[:4])
	// packed data
	a.packData(poolBuf, data)
}
   246  
   247  func (a *authChainA) packData(poolBuf *bytes.Buffer, data []byte) {
   248  	a.encrypter.XORKeyStream(data, data)
   249  
   250  	macKey := pool.GetBufferWriter()
   251  	macKey.Grow(len(a.userKey) + 4)
   252  	defer pool.PutBufferWriter(macKey)
   253  	copy(*macKey, a.userKey)
   254  	binary.LittleEndian.PutUint32((*macKey)[len(a.userKey):], a.packID)
   255  	a.packID++
   256  
   257  	length := uint16(len(data)) ^ binary.LittleEndian.Uint16(a.lastClientHash[14:16])
   258  
   259  	originalLength := poolBuf.Len()
   260  	_ = binary.Write(poolBuf, binary.LittleEndian, length)
   261  	a.putMixedRandDataAndData(poolBuf, data)
   262  	a.lastClientHash = tools.HmacMD5(macKey.Bytes(), poolBuf.Bytes()[originalLength:])
   263  	poolBuf.Write(a.lastClientHash[:2])
   264  }
   265  
   266  func (a *authChainA) putMixedRandDataAndData(poolBuf *bytes.Buffer, data []byte) {
   267  	randDataLength := a.randDataLength(len(data), a.lastClientHash, &a.randomClient)
   268  	if len(data) == 0 {
   269  		tools.AppendRandBytes(poolBuf, randDataLength)
   270  		return
   271  	}
   272  	if randDataLength > 0 {
   273  		startPos := getRandStartPos(randDataLength, &a.randomClient)
   274  		tools.AppendRandBytes(poolBuf, startPos)
   275  		poolBuf.Write(data)
   276  		tools.AppendRandBytes(poolBuf, randDataLength-startPos)
   277  		return
   278  	}
   279  	poolBuf.Write(data)
   280  }
   281  
   282  func getRandStartPos(length int, random *tools.XorShift128Plus) int {
   283  	if length == 0 {
   284  		return 0
   285  	}
   286  	return int(int64(random.Next()%8589934609) % int64(length))
   287  }
   288  
   289  func (a *authChainA) getRandLength(length int, lastHash []byte, random *tools.XorShift128Plus) int {
   290  	if length > 1440 {
   291  		return 0
   292  	}
   293  	random.InitFromBinAndLength(lastHash, length)
   294  	if length > 1300 {
   295  		return int(random.Next() % 31)
   296  	}
   297  	if length > 900 {
   298  		return int(random.Next() % 127)
   299  	}
   300  	if length > 400 {
   301  		return int(random.Next() % 521)
   302  	}
   303  	return int(random.Next() % 1021)
   304  }
   305  
   306  func (a *authChainA) initRC4Cipher() {
   307  	key := core.Kdf(base64.StdEncoding.EncodeToString(a.userKey)+base64.StdEncoding.EncodeToString(a.lastClientHash), 16)
   308  	a.encrypter, _ = rc4.NewCipher(key)
   309  	a.decrypter, _ = rc4.NewCipher(key)
   310  }
   311  
   312  func udpGetRandLength(lastHash []byte, random *tools.XorShift128Plus) int {
   313  	random.InitFromBin(lastHash)
   314  	return int(random.Next() % 127)
   315  }