github.com/igoogolx/clash@v1.19.8/transport/ssr/protocol/auth_chain_a.go (about)

     1  package protocol
     2  
     3  import (
     4  	"bytes"
     5  	"crypto/cipher"
     6  	"crypto/rand"
     7  	"crypto/rc4"
     8  	"encoding/base64"
     9  	"encoding/binary"
    10  	"net"
    11  	"strconv"
    12  	"strings"
    13  
    14  	"github.com/igoogolx/clash/common/pool"
    15  	"github.com/igoogolx/clash/log"
    16  	"github.com/igoogolx/clash/transport/shadowsocks/core"
    17  	"github.com/igoogolx/clash/transport/ssr/tools"
    18  )
    19  
// init registers the "auth_chain_a" protocol with newAuthChainA as its
// constructor.
// NOTE(review): the trailing 4 is presumably the protocol overhead in bytes;
// confirm against register's signature.
func init() {
	register("auth_chain_a", newAuthChainA, 4)
}
    23  
// randDataLengthMethod returns how many random padding bytes accompany a
// payload of the given length. It receives the payload length, the last hash
// of the relevant direction and the PRNG used to derive the value.
type randDataLengthMethod func(int, []byte, *tools.XorShift128Plus) int
    25  
// authChainA holds the state of the auth_chain_a protocol. The value created
// by newAuthChainA acts as a template; StreamConn and PacketConn derive a
// fresh instance per connection from it.
type authChainA struct {
	*Base
	*authData
	*userData
	// iv is the value handed to StreamConn; packAuthData uses it as the
	// prefix of the MAC key for the check head.
	iv []byte
	// salt is the protocol name; it is passed to putEncryptedData and shown
	// in the protocol-param warning.
	salt string
	// hasSentHeader is set by Encode once the auth header has been packed.
	hasSentHeader bool
	// rawTrans is set by Decode after a length or checksum failure; from then
	// on Decode copies src to dst unchanged.
	rawTrans bool
	// lastClientHash is the HMAC-MD5 of the most recently packed outgoing
	// chunk; its bytes 14..15 mask the next chunk's length field.
	lastClientHash []byte
	// lastServerHash is the HMAC-MD5 of the most recently accepted incoming
	// chunk; its bytes 14..15 unmask the next chunk's length field.
	lastServerHash []byte
	// encrypter and decrypter are the RC4 streams created by initRC4Cipher
	// from the same derived key.
	encrypter cipher.Stream
	decrypter cipher.Stream
	// randomClient and randomServer drive padding length and position for
	// outgoing and incoming data respectively.
	randomClient tools.XorShift128Plus
	randomServer tools.XorShift128Plus
	// randDataLength is the padding-length strategy; StreamConn sets it to
	// getRandLength.
	randDataLength randDataLengthMethod
	// packID and recvID are appended to userKey to form the per-chunk HMAC
	// key for sent and received chunks; StreamConn starts both at 1.
	packID uint32
	recvID uint32
}
    44  
    45  func newAuthChainA(b *Base) Protocol {
    46  	a := &authChainA{
    47  		Base:     b,
    48  		authData: &authData{},
    49  		userData: &userData{},
    50  		salt:     "auth_chain_a",
    51  	}
    52  	a.initUserData()
    53  	return a
    54  }
    55  
    56  func (a *authChainA) initUserData() {
    57  	params := strings.Split(a.Param, ":")
    58  	if len(params) > 1 {
    59  		if userID, err := strconv.ParseUint(params[0], 10, 32); err == nil {
    60  			binary.LittleEndian.PutUint32(a.userID[:], uint32(userID))
    61  			a.userKey = []byte(params[1])
    62  		} else {
    63  			log.Warnln("Wrong protocol-param for %s, only digits are expected before ':'", a.salt)
    64  		}
    65  	}
    66  	if len(a.userKey) == 0 {
    67  		a.userKey = a.Key
    68  		rand.Read(a.userID[:])
    69  	}
    70  }
    71  
    72  func (a *authChainA) StreamConn(c net.Conn, iv []byte) net.Conn {
    73  	p := &authChainA{
    74  		Base:     a.Base,
    75  		authData: a.next(),
    76  		userData: a.userData,
    77  		salt:     a.salt,
    78  		packID:   1,
    79  		recvID:   1,
    80  	}
    81  	p.iv = iv
    82  	p.randDataLength = p.getRandLength
    83  	return &Conn{Conn: c, Protocol: p}
    84  }
    85  
    86  func (a *authChainA) PacketConn(c net.PacketConn) net.PacketConn {
    87  	p := &authChainA{
    88  		Base:     a.Base,
    89  		salt:     a.salt,
    90  		userData: a.userData,
    91  	}
    92  	return &PacketConn{PacketConn: c, Protocol: p}
    93  }
    94  
    95  func (a *authChainA) Decode(dst, src *bytes.Buffer) error {
    96  	if a.rawTrans {
    97  		dst.ReadFrom(src)
    98  		return nil
    99  	}
   100  	for src.Len() > 4 {
   101  		macKey := pool.Get(len(a.userKey) + 4)
   102  		defer pool.Put(macKey)
   103  		copy(macKey, a.userKey)
   104  		binary.LittleEndian.PutUint32(macKey[len(a.userKey):], a.recvID)
   105  
   106  		dataLength := int(binary.LittleEndian.Uint16(src.Bytes()[:2]) ^ binary.LittleEndian.Uint16(a.lastServerHash[14:16]))
   107  		randDataLength := a.randDataLength(dataLength, a.lastServerHash, &a.randomServer)
   108  		length := dataLength + randDataLength
   109  
   110  		if length >= 4096 {
   111  			a.rawTrans = true
   112  			src.Reset()
   113  			return errAuthChainLengthError
   114  		}
   115  
   116  		if 4+length > src.Len() {
   117  			break
   118  		}
   119  
   120  		serverHash := tools.HmacMD5(macKey, src.Bytes()[:length+2])
   121  		if !bytes.Equal(serverHash[:2], src.Bytes()[length+2:length+4]) {
   122  			a.rawTrans = true
   123  			src.Reset()
   124  			return errAuthChainChksumError
   125  		}
   126  		a.lastServerHash = serverHash
   127  
   128  		pos := 2
   129  		if dataLength > 0 && randDataLength > 0 {
   130  			pos += getRandStartPos(randDataLength, &a.randomServer)
   131  		}
   132  		wantedData := src.Bytes()[pos : pos+dataLength]
   133  		a.decrypter.XORKeyStream(wantedData, wantedData)
   134  		if a.recvID == 1 {
   135  			dst.Write(wantedData[2:])
   136  		} else {
   137  			dst.Write(wantedData)
   138  		}
   139  		a.recvID++
   140  		src.Next(length + 4)
   141  	}
   142  	return nil
   143  }
   144  
   145  func (a *authChainA) Encode(buf *bytes.Buffer, b []byte) error {
   146  	if !a.hasSentHeader {
   147  		dataLength := getDataLength(b)
   148  		a.packAuthData(buf, b[:dataLength])
   149  		b = b[dataLength:]
   150  		a.hasSentHeader = true
   151  	}
   152  	for len(b) > 2800 {
   153  		a.packData(buf, b[:2800])
   154  		b = b[2800:]
   155  	}
   156  	if len(b) > 0 {
   157  		a.packData(buf, b)
   158  	}
   159  	return nil
   160  }
   161  
   162  func (a *authChainA) DecodePacket(b []byte) ([]byte, error) {
   163  	if len(b) < 9 {
   164  		return nil, errAuthChainLengthError
   165  	}
   166  	if !bytes.Equal(tools.HmacMD5(a.userKey, b[:len(b)-1])[:1], b[len(b)-1:]) {
   167  		return nil, errAuthChainChksumError
   168  	}
   169  	md5Data := tools.HmacMD5(a.Key, b[len(b)-8:len(b)-1])
   170  
   171  	randDataLength := udpGetRandLength(md5Data, &a.randomServer)
   172  
   173  	key := core.Kdf(base64.StdEncoding.EncodeToString(a.userKey)+base64.StdEncoding.EncodeToString(md5Data), 16)
   174  	rc4Cipher, err := rc4.NewCipher(key)
   175  	if err != nil {
   176  		return nil, err
   177  	}
   178  	wantedData := b[:len(b)-8-randDataLength]
   179  	rc4Cipher.XORKeyStream(wantedData, wantedData)
   180  	return wantedData, nil
   181  }
   182  
   183  func (a *authChainA) EncodePacket(buf *bytes.Buffer, b []byte) error {
   184  	authData := pool.Get(3)
   185  	defer pool.Put(authData)
   186  	rand.Read(authData)
   187  
   188  	md5Data := tools.HmacMD5(a.Key, authData)
   189  
   190  	randDataLength := udpGetRandLength(md5Data, &a.randomClient)
   191  
   192  	key := core.Kdf(base64.StdEncoding.EncodeToString(a.userKey)+base64.StdEncoding.EncodeToString(md5Data), 16)
   193  	rc4Cipher, err := rc4.NewCipher(key)
   194  	if err != nil {
   195  		return err
   196  	}
   197  	rc4Cipher.XORKeyStream(b, b)
   198  
   199  	buf.Write(b)
   200  	tools.AppendRandBytes(buf, randDataLength)
   201  	buf.Write(authData)
   202  	binary.Write(buf, binary.LittleEndian, binary.LittleEndian.Uint32(a.userID[:])^binary.LittleEndian.Uint32(md5Data[:4]))
   203  	buf.Write(tools.HmacMD5(a.userKey, buf.Bytes())[:1])
   204  	return nil
   205  }
   206  
// packAuthData writes the one-time auth header followed by the first data
// chunk into poolBuf. As a side effect it sets lastClientHash, creates the
// RC4 streams via initRC4Cipher and sets lastServerHash.
//
// The hashes below are computed over poolBuf.Bytes() with fixed offsets, so
// poolBuf is expected to be empty on entry.
//
// If putEncryptedData fails, poolBuf is reset and the function returns
// without reporting the error to the caller.
func (a *authChainA) packAuthData(poolBuf *bytes.Buffer, data []byte) {
	/*
		dataLength := len(data)
		12:	checkHead(4) and hmac of checkHead(8)
		4:	uint32 LittleEndian uid (uid = userID ^ last client hash)
		16:	encrypted data of authdata(12), uint16 LittleEndian overhead(2) and uint16 LittleEndian number zero(2)
		4:	last server hash(4)
		packedAuthDataLength := 12 + 4 + 16 + 4 + dataLength
	*/

	// MAC key for the check head: iv followed by the shared key.
	macKey := pool.Get(len(a.iv) + len(a.Key))
	defer pool.Put(macKey)
	copy(macKey, a.iv)
	copy(macKey[len(a.iv):], a.Key)

	// check head
	tools.AppendRandBytes(poolBuf, 4)
	a.lastClientHash = tools.HmacMD5(macKey, poolBuf.Bytes())
	// The RC4 key depends on lastClientHash, so the ciphers are created only
	// after the hash above has been computed.
	a.initRC4Cipher()
	poolBuf.Write(a.lastClientHash[:8])
	// uid
	binary.Write(poolBuf, binary.LittleEndian, binary.LittleEndian.Uint32(a.userID[:])^binary.LittleEndian.Uint32(a.lastClientHash[8:12]))
	// encrypted data
	err := a.putEncryptedData(poolBuf, a.userKey, [2]int{a.Overhead, 0}, a.salt)
	if err != nil {
		poolBuf.Reset()
		return
	}
	// last server hash: HMAC over everything after the 12-byte check head
	a.lastServerHash = tools.HmacMD5(a.userKey, poolBuf.Bytes()[12:])
	poolBuf.Write(a.lastServerHash[:4])
	// packed data
	a.packData(poolBuf, data)
}
   241  
   242  func (a *authChainA) packData(poolBuf *bytes.Buffer, data []byte) {
   243  	a.encrypter.XORKeyStream(data, data)
   244  
   245  	macKey := pool.Get(len(a.userKey) + 4)
   246  	defer pool.Put(macKey)
   247  	copy(macKey, a.userKey)
   248  	binary.LittleEndian.PutUint32(macKey[len(a.userKey):], a.packID)
   249  	a.packID++
   250  
   251  	length := uint16(len(data)) ^ binary.LittleEndian.Uint16(a.lastClientHash[14:16])
   252  
   253  	originalLength := poolBuf.Len()
   254  	binary.Write(poolBuf, binary.LittleEndian, length)
   255  	a.putMixedRandDataAndData(poolBuf, data)
   256  	a.lastClientHash = tools.HmacMD5(macKey, poolBuf.Bytes()[originalLength:])
   257  	poolBuf.Write(a.lastClientHash[:2])
   258  }
   259  
   260  func (a *authChainA) putMixedRandDataAndData(poolBuf *bytes.Buffer, data []byte) {
   261  	randDataLength := a.randDataLength(len(data), a.lastClientHash, &a.randomClient)
   262  	if len(data) == 0 {
   263  		tools.AppendRandBytes(poolBuf, randDataLength)
   264  		return
   265  	}
   266  	if randDataLength > 0 {
   267  		startPos := getRandStartPos(randDataLength, &a.randomClient)
   268  		tools.AppendRandBytes(poolBuf, startPos)
   269  		poolBuf.Write(data)
   270  		tools.AppendRandBytes(poolBuf, randDataLength-startPos)
   271  		return
   272  	}
   273  	poolBuf.Write(data)
   274  }
   275  
   276  func getRandStartPos(length int, random *tools.XorShift128Plus) int {
   277  	if length == 0 {
   278  		return 0
   279  	}
   280  	return int(int64(random.Next()%8589934609) % int64(length))
   281  }
   282  
   283  func (a *authChainA) getRandLength(length int, lastHash []byte, random *tools.XorShift128Plus) int {
   284  	if length > 1440 {
   285  		return 0
   286  	}
   287  	random.InitFromBinAndLength(lastHash, length)
   288  	if length > 1300 {
   289  		return int(random.Next() % 31)
   290  	}
   291  	if length > 900 {
   292  		return int(random.Next() % 127)
   293  	}
   294  	if length > 400 {
   295  		return int(random.Next() % 521)
   296  	}
   297  	return int(random.Next() % 1021)
   298  }
   299  
   300  func (a *authChainA) initRC4Cipher() {
   301  	key := core.Kdf(base64.StdEncoding.EncodeToString(a.userKey)+base64.StdEncoding.EncodeToString(a.lastClientHash), 16)
   302  	a.encrypter, _ = rc4.NewCipher(key)
   303  	a.decrypter, _ = rc4.NewCipher(key)
   304  }
   305  
   306  func udpGetRandLength(lastHash []byte, random *tools.XorShift128Plus) int {
   307  	random.InitFromBin(lastHash)
   308  	return int(random.Next() % 127)
   309  }