github.com/inazumav/sing-box@v0.0.0-20230926072359-ab51429a14f1/transport/clashssr/protocol/auth_chain_a.go (about)

     1  package protocol
     2  
     3  import (
     4  	"bytes"
     5  	"crypto/cipher"
     6  	"crypto/rand"
     7  	"crypto/rc4"
     8  	"encoding/base64"
     9  	"encoding/binary"
    10  	"net"
    11  	"strconv"
    12  	"strings"
    13  
    14  	"github.com/Dreamacro/clash/common/pool"
    15  	"github.com/Dreamacro/clash/transport/shadowsocks/core"
    16  	"github.com/Dreamacro/clash/transport/ssr/tools"
    17  )
    18  
// init registers this protocol under the name "auth_chain_a", with
// newAuthChainA as its constructor.
// NOTE(review): the meaning of the trailing 4 is defined by register
// (elsewhere in the package) — confirm before changing it.
func init() {
	register("auth_chain_a", newAuthChainA, 4)
}
    22  
// randDataLengthMethod chooses how many padding bytes accompany a frame.
// It receives the payload length, the hash of a previous frame and the PRNG
// to draw from. StreamConn installs (*authChainA).getRandLength.
type randDataLengthMethod func(int, []byte, *tools.XorShift128Plus) int
    24  
// authChainA implements the ShadowsocksR "auth_chain_a" obfuscation protocol.
// The instance built by newAuthChainA is a template; StreamConn and PacketConn
// create per-connection copies that carry the mutable state below.
type authChainA struct {
	// Base supplies the shared Key, Param and Overhead read by this protocol.
	*Base
	// authData is the authentication state; StreamConn takes a fresh one via next().
	*authData
	// userData holds userID and userKey, resolved once by initUserData.
	*userData
	// iv is the IV given to StreamConn; packAuthData prefixes it to Key to
	// form the check-head HMAC key.
	iv []byte
	// salt is the protocol name, passed to putEncryptedData in the handshake.
	salt string
	// hasSentHeader records that Encode has already emitted the handshake header.
	hasSentHeader bool
	// rawTrans is set after a framing error; Decode then passes data through unchanged.
	rawTrans bool
	// lastClientHash is the HMAC of the most recent outgoing frame. It masks
	// outgoing lengths and seeds outgoing padding.
	lastClientHash []byte
	// lastServerHash is the HMAC of the most recent incoming frame (initially
	// derived in packAuthData). It masks incoming lengths and seeds incoming padding.
	lastServerHash []byte
	// encrypter and decrypter are the RC4 streams created by initRC4Cipher.
	encrypter cipher.Stream
	decrypter cipher.Stream
	// randomClient and randomServer are PRNGs for outgoing and incoming
	// padding, respectively.
	randomClient tools.XorShift128Plus
	randomServer tools.XorShift128Plus
	// randDataLength picks padding lengths for stream frames. PacketConn
	// instances leave it nil.
	randDataLength randDataLengthMethod
	// packID and recvID are per-direction frame counters appended to userKey
	// to form each frame's MAC key; StreamConn starts both at 1.
	packID uint32
	recvID uint32
}
    43  
    44  func newAuthChainA(b *Base) Protocol {
    45  	a := &authChainA{
    46  		Base:     b,
    47  		authData: &authData{},
    48  		userData: &userData{},
    49  		salt:     "auth_chain_a",
    50  	}
    51  	a.initUserData()
    52  	return a
    53  }
    54  
    55  func (a *authChainA) initUserData() {
    56  	params := strings.Split(a.Param, ":")
    57  	if len(params) > 1 {
    58  		if userID, err := strconv.ParseUint(params[0], 10, 32); err == nil {
    59  			binary.LittleEndian.PutUint32(a.userID[:], uint32(userID))
    60  			a.userKey = []byte(params[1])
    61  		}
    62  	}
    63  	if len(a.userKey) == 0 {
    64  		a.userKey = a.Key
    65  		rand.Read(a.userID[:])
    66  	}
    67  }
    68  
    69  func (a *authChainA) StreamConn(c net.Conn, iv []byte) net.Conn {
    70  	p := &authChainA{
    71  		Base:     a.Base,
    72  		authData: a.next(),
    73  		userData: a.userData,
    74  		salt:     a.salt,
    75  		packID:   1,
    76  		recvID:   1,
    77  	}
    78  	p.iv = iv
    79  	p.randDataLength = p.getRandLength
    80  	return &Conn{Conn: c, Protocol: p}
    81  }
    82  
    83  func (a *authChainA) PacketConn(c net.PacketConn) net.PacketConn {
    84  	p := &authChainA{
    85  		Base:     a.Base,
    86  		salt:     a.salt,
    87  		userData: a.userData,
    88  	}
    89  	return &PacketConn{PacketConn: c, Protocol: p}
    90  }
    91  
    92  func (a *authChainA) Decode(dst, src *bytes.Buffer) error {
    93  	if a.rawTrans {
    94  		dst.ReadFrom(src)
    95  		return nil
    96  	}
    97  	for src.Len() > 4 {
    98  		macKey := pool.Get(len(a.userKey) + 4)
    99  		defer pool.Put(macKey)
   100  		copy(macKey, a.userKey)
   101  		binary.LittleEndian.PutUint32(macKey[len(a.userKey):], a.recvID)
   102  
   103  		dataLength := int(binary.LittleEndian.Uint16(src.Bytes()[:2]) ^ binary.LittleEndian.Uint16(a.lastServerHash[14:16]))
   104  		randDataLength := a.randDataLength(dataLength, a.lastServerHash, &a.randomServer)
   105  		length := dataLength + randDataLength
   106  
   107  		if length >= 4096 {
   108  			a.rawTrans = true
   109  			src.Reset()
   110  			return errAuthChainLengthError
   111  		}
   112  
   113  		if 4+length > src.Len() {
   114  			break
   115  		}
   116  
   117  		serverHash := tools.HmacMD5(macKey, src.Bytes()[:length+2])
   118  		if !bytes.Equal(serverHash[:2], src.Bytes()[length+2:length+4]) {
   119  			a.rawTrans = true
   120  			src.Reset()
   121  			return errAuthChainChksumError
   122  		}
   123  		a.lastServerHash = serverHash
   124  
   125  		pos := 2
   126  		if dataLength > 0 && randDataLength > 0 {
   127  			pos += getRandStartPos(randDataLength, &a.randomServer)
   128  		}
   129  		wantedData := src.Bytes()[pos : pos+dataLength]
   130  		a.decrypter.XORKeyStream(wantedData, wantedData)
   131  		if a.recvID == 1 {
   132  			dst.Write(wantedData[2:])
   133  		} else {
   134  			dst.Write(wantedData)
   135  		}
   136  		a.recvID++
   137  		src.Next(length + 4)
   138  	}
   139  	return nil
   140  }
   141  
   142  func (a *authChainA) Encode(buf *bytes.Buffer, b []byte) error {
   143  	if !a.hasSentHeader {
   144  		dataLength := getDataLength(b)
   145  		a.packAuthData(buf, b[:dataLength])
   146  		b = b[dataLength:]
   147  		a.hasSentHeader = true
   148  	}
   149  	for len(b) > 2800 {
   150  		a.packData(buf, b[:2800])
   151  		b = b[2800:]
   152  	}
   153  	if len(b) > 0 {
   154  		a.packData(buf, b)
   155  	}
   156  	return nil
   157  }
   158  
   159  func (a *authChainA) DecodePacket(b []byte) ([]byte, error) {
   160  	if len(b) < 9 {
   161  		return nil, errAuthChainLengthError
   162  	}
   163  	if !bytes.Equal(tools.HmacMD5(a.userKey, b[:len(b)-1])[:1], b[len(b)-1:]) {
   164  		return nil, errAuthChainChksumError
   165  	}
   166  	md5Data := tools.HmacMD5(a.Key, b[len(b)-8:len(b)-1])
   167  
   168  	randDataLength := udpGetRandLength(md5Data, &a.randomServer)
   169  
   170  	key := core.Kdf(base64.StdEncoding.EncodeToString(a.userKey)+base64.StdEncoding.EncodeToString(md5Data), 16)
   171  	rc4Cipher, err := rc4.NewCipher(key)
   172  	if err != nil {
   173  		return nil, err
   174  	}
   175  	wantedData := b[:len(b)-8-randDataLength]
   176  	rc4Cipher.XORKeyStream(wantedData, wantedData)
   177  	return wantedData, nil
   178  }
   179  
   180  func (a *authChainA) EncodePacket(buf *bytes.Buffer, b []byte) error {
   181  	authData := pool.Get(3)
   182  	defer pool.Put(authData)
   183  	rand.Read(authData)
   184  
   185  	md5Data := tools.HmacMD5(a.Key, authData)
   186  
   187  	randDataLength := udpGetRandLength(md5Data, &a.randomClient)
   188  
   189  	key := core.Kdf(base64.StdEncoding.EncodeToString(a.userKey)+base64.StdEncoding.EncodeToString(md5Data), 16)
   190  	rc4Cipher, err := rc4.NewCipher(key)
   191  	if err != nil {
   192  		return err
   193  	}
   194  	rc4Cipher.XORKeyStream(b, b)
   195  
   196  	buf.Write(b)
   197  	tools.AppendRandBytes(buf, randDataLength)
   198  	buf.Write(authData)
   199  	binary.Write(buf, binary.LittleEndian, binary.LittleEndian.Uint32(a.userID[:])^binary.LittleEndian.Uint32(md5Data[:4]))
   200  	buf.Write(tools.HmacMD5(a.userKey, buf.Bytes())[:1])
   201  	return nil
   202  }
   203  
// packAuthData writes the auth_chain handshake header to poolBuf, followed by
// data packed as the first regular frame (via packData).
//
// Header layout:
//
//	[4]  random check head
//	[8]  HMAC-MD5(iv || Key, check head)[:8]
//	[4]  uid: LE32(userID) XOR LE32(that HMAC[8:12])
//	[16] encrypted auth block written by putEncryptedData
//	[4]  HMAC-MD5(userKey, uid || encrypted block)[:4], kept as lastServerHash
//
// Side effects: sets lastClientHash, initialises the RC4 encrypter/decrypter
// from it (initRC4Cipher), sets lastServerHash, and advances packID through
// packData.
//
// NOTE(review): the HMAC over poolBuf.Bytes() and the [12:] slice assume
// poolBuf is empty on entry. If putEncryptedData fails, poolBuf is reset and
// the function returns without reporting the error. Encode (the caller)
// therefore cannot tell that the header was dropped.
func (a *authChainA) packAuthData(poolBuf *bytes.Buffer, data []byte) {
	/*
		dataLength := len(data)
		12:	checkHead(4) and hmac of checkHead(8)
		4:	uint32 LittleEndian uid (uid = userID ^ last client hash[8:12])
		16:	encrypted data of authdata(12), uint16 LittleEndian overhead(2) and uint16 LittleEndian number zero(2)
		4:	last server hash(4)
		packedAuthDataLength := 12 + 4 + 16 + 4 + framed data (2 + padding + dataLength + 2, see packData)
	*/

	// HMAC key for the check head: iv followed by the shared Key.
	macKey := pool.Get(len(a.iv) + len(a.Key))
	defer pool.Put(macKey)
	copy(macKey, a.iv)
	copy(macKey[len(a.iv):], a.Key)

	// check head
	tools.AppendRandBytes(poolBuf, 4)
	a.lastClientHash = tools.HmacMD5(macKey, poolBuf.Bytes())
	// The RC4 key is derived from the fresh lastClientHash, so this must follow it.
	a.initRC4Cipher()
	poolBuf.Write(a.lastClientHash[:8])
	// uid
	binary.Write(poolBuf, binary.LittleEndian, binary.LittleEndian.Uint32(a.userID[:])^binary.LittleEndian.Uint32(a.lastClientHash[8:12]))
	// encrypted data
	err := a.putEncryptedData(poolBuf, a.userKey, [2]int{a.Overhead, 0}, a.salt)
	if err != nil {
		// Best-effort: the whole buffer is discarded and the error is not propagated.
		poolBuf.Reset()
		return
	}
	// last server hash: covers uid and the encrypted block (bytes 12 onward)
	a.lastServerHash = tools.HmacMD5(a.userKey, poolBuf.Bytes()[12:])
	poolBuf.Write(a.lastServerHash[:4])
	// packed data
	a.packData(poolBuf, data)
}
   238  
   239  func (a *authChainA) packData(poolBuf *bytes.Buffer, data []byte) {
   240  	a.encrypter.XORKeyStream(data, data)
   241  
   242  	macKey := pool.Get(len(a.userKey) + 4)
   243  	defer pool.Put(macKey)
   244  	copy(macKey, a.userKey)
   245  	binary.LittleEndian.PutUint32(macKey[len(a.userKey):], a.packID)
   246  	a.packID++
   247  
   248  	length := uint16(len(data)) ^ binary.LittleEndian.Uint16(a.lastClientHash[14:16])
   249  
   250  	originalLength := poolBuf.Len()
   251  	binary.Write(poolBuf, binary.LittleEndian, length)
   252  	a.putMixedRandDataAndData(poolBuf, data)
   253  	a.lastClientHash = tools.HmacMD5(macKey, poolBuf.Bytes()[originalLength:])
   254  	poolBuf.Write(a.lastClientHash[:2])
   255  }
   256  
   257  func (a *authChainA) putMixedRandDataAndData(poolBuf *bytes.Buffer, data []byte) {
   258  	randDataLength := a.randDataLength(len(data), a.lastClientHash, &a.randomClient)
   259  	if len(data) == 0 {
   260  		tools.AppendRandBytes(poolBuf, randDataLength)
   261  		return
   262  	}
   263  	if randDataLength > 0 {
   264  		startPos := getRandStartPos(randDataLength, &a.randomClient)
   265  		tools.AppendRandBytes(poolBuf, startPos)
   266  		poolBuf.Write(data)
   267  		tools.AppendRandBytes(poolBuf, randDataLength-startPos)
   268  		return
   269  	}
   270  	poolBuf.Write(data)
   271  }
   272  
   273  func getRandStartPos(length int, random *tools.XorShift128Plus) int {
   274  	if length == 0 {
   275  		return 0
   276  	}
   277  	return int(int64(random.Next()%8589934609) % int64(length))
   278  }
   279  
   280  func (a *authChainA) getRandLength(length int, lastHash []byte, random *tools.XorShift128Plus) int {
   281  	if length > 1440 {
   282  		return 0
   283  	}
   284  	random.InitFromBinAndLength(lastHash, length)
   285  	if length > 1300 {
   286  		return int(random.Next() % 31)
   287  	}
   288  	if length > 900 {
   289  		return int(random.Next() % 127)
   290  	}
   291  	if length > 400 {
   292  		return int(random.Next() % 521)
   293  	}
   294  	return int(random.Next() % 1021)
   295  }
   296  
   297  func (a *authChainA) initRC4Cipher() {
   298  	key := core.Kdf(base64.StdEncoding.EncodeToString(a.userKey)+base64.StdEncoding.EncodeToString(a.lastClientHash), 16)
   299  	a.encrypter, _ = rc4.NewCipher(key)
   300  	a.decrypter, _ = rc4.NewCipher(key)
   301  }
   302  
   303  func udpGetRandLength(lastHash []byte, random *tools.XorShift128Plus) int {
   304  	random.InitFromBin(lastHash)
   305  	return int(random.Next() % 127)
   306  }