github.com/metacubex/mihomo@v1.18.5/transport/ssr/protocol/auth_chain_a.go (about)

     1  package protocol
     2  
     3  import (
     4  	"bytes"
     5  	"crypto/cipher"
     6  	"crypto/rand"
     7  	"crypto/rc4"
     8  	"encoding/base64"
     9  	"encoding/binary"
    10  	"errors"
    11  	"net"
    12  	"strconv"
    13  	"strings"
    14  
    15  	N "github.com/metacubex/mihomo/common/net"
    16  	"github.com/metacubex/mihomo/common/pool"
    17  	"github.com/metacubex/mihomo/log"
    18  	"github.com/metacubex/mihomo/transport/shadowsocks/core"
    19  	"github.com/metacubex/mihomo/transport/ssr/tools"
    20  )
    21  
    22  func init() {
    23  	register("auth_chain_a", newAuthChainA, 4)
    24  }
    25  
// randDataLengthMethod computes how many bytes of random padding accompany a
// frame carrying `int` payload bytes. The PRNG is (re)seeded from the given
// hash. Variants of the protocol can plug in different strategies.
type randDataLengthMethod func(int, []byte, *tools.XorShift128Plus) int

// authChainA holds the client-side state of the SSR auth_chain_a protocol.
// The value built by newAuthChainA is a template. StreamConn and PacketConn
// derive fresh per-connection copies that share Base and userData.
type authChainA struct {
	*Base
	*authData
	*userData
	// iv is the per-connection IV handed to StreamConn; packAuthData mixes it
	// into the MAC key of the header.
	iv []byte
	// salt is the protocol name; passed to putEncryptedData and used in log text.
	salt string
	// hasSentHeader is set by Encode once the auth header has been emitted.
	hasSentHeader bool
	// rawTrans switches Decode to pass-through after a framing/checksum error.
	rawTrans bool
	// lastClientHash is the HMAC of the last client frame; it masks outgoing
	// lengths, seeds client padding and (after the header) keys the RC4 streams.
	lastClientHash []byte
	// lastServerHash is the HMAC of the last verified server frame; it masks
	// incoming lengths and seeds server padding.
	lastServerHash []byte
	// encrypter/decrypter are the RC4 streams created by initRC4Cipher.
	encrypter cipher.Stream
	decrypter cipher.Stream
	// randomClient/randomServer drive padding sizes for each direction.
	randomClient tools.XorShift128Plus
	randomServer tools.XorShift128Plus
	// randDataLength picks the padding size; StreamConn sets it to getRandLength.
	randDataLength randDataLengthMethod
	// packID/recvID count sent/received frames and are mixed into the per-frame
	// MAC keys; StreamConn starts both at 1.
	packID uint32
	recvID uint32
}
    46  
    47  func newAuthChainA(b *Base) Protocol {
    48  	a := &authChainA{
    49  		Base:     b,
    50  		authData: &authData{},
    51  		userData: &userData{},
    52  		salt:     "auth_chain_a",
    53  	}
    54  	a.initUserData()
    55  	return a
    56  }
    57  
    58  func (a *authChainA) initUserData() {
    59  	params := strings.Split(a.Param, ":")
    60  	if len(params) > 1 {
    61  		if userID, err := strconv.ParseUint(params[0], 10, 32); err == nil {
    62  			binary.LittleEndian.PutUint32(a.userID[:], uint32(userID))
    63  			a.userKey = []byte(params[1])
    64  		} else {
    65  			log.Warnln("Wrong protocol-param for %s, only digits are expected before ':'", a.salt)
    66  		}
    67  	}
    68  	if len(a.userKey) == 0 {
    69  		a.userKey = a.Key
    70  		rand.Read(a.userID[:])
    71  	}
    72  }
    73  
    74  func (a *authChainA) StreamConn(c net.Conn, iv []byte) net.Conn {
    75  	p := &authChainA{
    76  		Base:     a.Base,
    77  		authData: a.next(),
    78  		userData: a.userData,
    79  		salt:     a.salt,
    80  		packID:   1,
    81  		recvID:   1,
    82  	}
    83  	p.iv = iv
    84  	p.randDataLength = p.getRandLength
    85  	return &Conn{Conn: c, Protocol: p}
    86  }
    87  
    88  func (a *authChainA) PacketConn(c N.EnhancePacketConn) N.EnhancePacketConn {
    89  	p := &authChainA{
    90  		Base:     a.Base,
    91  		salt:     a.salt,
    92  		userData: a.userData,
    93  	}
    94  	return &PacketConn{EnhancePacketConn: c, Protocol: p}
    95  }
    96  
    97  func (a *authChainA) Decode(dst, src *bytes.Buffer) error {
    98  	if a.rawTrans {
    99  		dst.ReadFrom(src)
   100  		return nil
   101  	}
   102  	for src.Len() > 4 {
   103  		macKey := pool.Get(len(a.userKey) + 4)
   104  		defer pool.Put(macKey)
   105  		copy(macKey, a.userKey)
   106  		binary.LittleEndian.PutUint32(macKey[len(a.userKey):], a.recvID)
   107  
   108  		dataLength := int(binary.LittleEndian.Uint16(src.Bytes()[:2]) ^ binary.LittleEndian.Uint16(a.lastServerHash[14:16]))
   109  		randDataLength := a.randDataLength(dataLength, a.lastServerHash, &a.randomServer)
   110  		length := dataLength + randDataLength
   111  		// Temporary workaround for https://github.com/metacubex/mihomo/issues/1352
   112  		if dataLength < 0 || randDataLength < 0 || length < 0 {
   113  			return errors.New("ssr crashing blocked")
   114  		}
   115  
   116  		if length >= 4096 {
   117  			a.rawTrans = true
   118  			src.Reset()
   119  			return errAuthChainLengthError
   120  		}
   121  
   122  		if 4+length > src.Len() {
   123  			break
   124  		}
   125  
   126  		serverHash := tools.HmacMD5(macKey, src.Bytes()[:length+2])
   127  		if !bytes.Equal(serverHash[:2], src.Bytes()[length+2:length+4]) {
   128  			a.rawTrans = true
   129  			src.Reset()
   130  			return errAuthChainChksumError
   131  		}
   132  		a.lastServerHash = serverHash
   133  
   134  		pos := 2
   135  		if dataLength > 0 && randDataLength > 0 {
   136  			pos += getRandStartPos(randDataLength, &a.randomServer)
   137  		}
   138  		// Temporary workaround for https://github.com/metacubex/mihomo/issues/1352
   139  		if pos < 0 || pos+dataLength < 0 || dataLength < 0 {
   140  			return errors.New("ssr crashing blocked")
   141  		}
   142  
   143  		wantedData := src.Bytes()[pos : pos+dataLength]
   144  		a.decrypter.XORKeyStream(wantedData, wantedData)
   145  		if a.recvID == 1 {
   146  			dst.Write(wantedData[2:])
   147  		} else {
   148  			dst.Write(wantedData)
   149  		}
   150  		a.recvID++
   151  		src.Next(length + 4)
   152  	}
   153  	return nil
   154  }
   155  
   156  func (a *authChainA) Encode(buf *bytes.Buffer, b []byte) error {
   157  	if !a.hasSentHeader {
   158  		dataLength := getDataLength(b)
   159  		a.packAuthData(buf, b[:dataLength])
   160  		b = b[dataLength:]
   161  		a.hasSentHeader = true
   162  	}
   163  	for len(b) > 2800 {
   164  		a.packData(buf, b[:2800])
   165  		b = b[2800:]
   166  	}
   167  	if len(b) > 0 {
   168  		a.packData(buf, b)
   169  	}
   170  	return nil
   171  }
   172  
   173  func (a *authChainA) DecodePacket(b []byte) ([]byte, error) {
   174  	if len(b) < 9 {
   175  		return nil, errAuthChainLengthError
   176  	}
   177  	if !bytes.Equal(tools.HmacMD5(a.userKey, b[:len(b)-1])[:1], b[len(b)-1:]) {
   178  		return nil, errAuthChainChksumError
   179  	}
   180  	md5Data := tools.HmacMD5(a.Key, b[len(b)-8:len(b)-1])
   181  
   182  	randDataLength := udpGetRandLength(md5Data, &a.randomServer)
   183  
   184  	key := core.Kdf(base64.StdEncoding.EncodeToString(a.userKey)+base64.StdEncoding.EncodeToString(md5Data), 16)
   185  	rc4Cipher, err := rc4.NewCipher(key)
   186  	if err != nil {
   187  		return nil, err
   188  	}
   189  	wantedData := b[:len(b)-8-randDataLength]
   190  	rc4Cipher.XORKeyStream(wantedData, wantedData)
   191  	return wantedData, nil
   192  }
   193  
   194  func (a *authChainA) EncodePacket(buf *bytes.Buffer, b []byte) error {
   195  	authData := pool.Get(3)
   196  	defer pool.Put(authData)
   197  	rand.Read(authData)
   198  
   199  	md5Data := tools.HmacMD5(a.Key, authData)
   200  
   201  	randDataLength := udpGetRandLength(md5Data, &a.randomClient)
   202  
   203  	key := core.Kdf(base64.StdEncoding.EncodeToString(a.userKey)+base64.StdEncoding.EncodeToString(md5Data), 16)
   204  	rc4Cipher, err := rc4.NewCipher(key)
   205  	if err != nil {
   206  		return err
   207  	}
   208  	rc4Cipher.XORKeyStream(b, b)
   209  
   210  	buf.Write(b)
   211  	tools.AppendRandBytes(buf, randDataLength)
   212  	buf.Write(authData)
   213  	binary.Write(buf, binary.LittleEndian, binary.LittleEndian.Uint32(a.userID[:])^binary.LittleEndian.Uint32(md5Data[:4]))
   214  	buf.Write(tools.HmacMD5(a.userKey, buf.Bytes())[:1])
   215  	return nil
   216  }
   217  
// packAuthData writes the auth_chain_a connection header into poolBuf,
// followed by the first data frame carrying data.
//
// Besides writing bytes, it initialises lastClientHash, the RC4
// encrypter/decrypter (via initRC4Cipher) and lastServerHash. Decode and
// packData depend on all three.
//
// NOTE(review): the hashes below are taken over poolBuf.Bytes() from offset 0
// and offset 12, so poolBuf appears to be assumed empty on entry.
// NOTE(review): if putEncryptedData fails, poolBuf is reset and the error is
// dropped (no return value). The caller cannot tell the header was not written.
func (a *authChainA) packAuthData(poolBuf *bytes.Buffer, data []byte) {
	/*
		dataLength := len(data)
		12:	checkHead(4) and hmac of checkHead(8)
		4:	uint32 LittleEndian uid (uid = userID ^ last client hash)
		16:	encrypted data of authdata(12), uint16 LittleEndian overhead(2) and uint16 LittleEndian number zero(2)
		4:	last server hash(4)
		packedAuthDataLength := 12 + 4 + 16 + 4 + dataLength
	*/

	// Header MAC key = iv || Key.
	macKey := pool.Get(len(a.iv) + len(a.Key))
	defer pool.Put(macKey)
	copy(macKey, a.iv)
	copy(macKey[len(a.iv):], a.Key)

	// check head
	// 4 random bytes, then the first 8 bytes of their HMAC. The full HMAC
	// becomes lastClientHash and keys the RC4 streams, so initRC4Cipher must
	// run after it is assigned.
	tools.AppendRandBytes(poolBuf, 4)
	a.lastClientHash = tools.HmacMD5(macKey, poolBuf.Bytes())
	a.initRC4Cipher()
	poolBuf.Write(a.lastClientHash[:8])
	// uid
	// The user ID is masked with bytes 8..12 of lastClientHash.
	binary.Write(poolBuf, binary.LittleEndian, binary.LittleEndian.Uint32(a.userID[:])^binary.LittleEndian.Uint32(a.lastClientHash[8:12]))
	// encrypted data
	err := a.putEncryptedData(poolBuf, a.userKey, [2]int{a.Overhead, 0}, a.salt)
	if err != nil {
		poolBuf.Reset()
		return
	}
	// last server hash
	// HMAC over everything after the 12-byte check head. Its first 4 bytes
	// are sent, and the full value seeds Decode's first length mask.
	a.lastServerHash = tools.HmacMD5(a.userKey, poolBuf.Bytes()[12:])
	poolBuf.Write(a.lastServerHash[:4])
	// packed data
	a.packData(poolBuf, data)
}
   252  
   253  func (a *authChainA) packData(poolBuf *bytes.Buffer, data []byte) {
   254  	a.encrypter.XORKeyStream(data, data)
   255  
   256  	macKey := pool.Get(len(a.userKey) + 4)
   257  	defer pool.Put(macKey)
   258  	copy(macKey, a.userKey)
   259  	binary.LittleEndian.PutUint32(macKey[len(a.userKey):], a.packID)
   260  	a.packID++
   261  
   262  	length := uint16(len(data)) ^ binary.LittleEndian.Uint16(a.lastClientHash[14:16])
   263  
   264  	originalLength := poolBuf.Len()
   265  	binary.Write(poolBuf, binary.LittleEndian, length)
   266  	a.putMixedRandDataAndData(poolBuf, data)
   267  	a.lastClientHash = tools.HmacMD5(macKey, poolBuf.Bytes()[originalLength:])
   268  	poolBuf.Write(a.lastClientHash[:2])
   269  }
   270  
   271  func (a *authChainA) putMixedRandDataAndData(poolBuf *bytes.Buffer, data []byte) {
   272  	randDataLength := a.randDataLength(len(data), a.lastClientHash, &a.randomClient)
   273  	if len(data) == 0 {
   274  		tools.AppendRandBytes(poolBuf, randDataLength)
   275  		return
   276  	}
   277  	if randDataLength > 0 {
   278  		startPos := getRandStartPos(randDataLength, &a.randomClient)
   279  		tools.AppendRandBytes(poolBuf, startPos)
   280  		poolBuf.Write(data)
   281  		tools.AppendRandBytes(poolBuf, randDataLength-startPos)
   282  		return
   283  	}
   284  	poolBuf.Write(data)
   285  }
   286  
   287  func getRandStartPos(length int, random *tools.XorShift128Plus) int {
   288  	if length == 0 {
   289  		return 0
   290  	}
   291  	return int(int64(random.Next()%8589934609) % int64(length))
   292  }
   293  
   294  func (a *authChainA) getRandLength(length int, lastHash []byte, random *tools.XorShift128Plus) int {
   295  	if length > 1440 {
   296  		return 0
   297  	}
   298  	random.InitFromBinAndLength(lastHash, length)
   299  	if length > 1300 {
   300  		return int(random.Next() % 31)
   301  	}
   302  	if length > 900 {
   303  		return int(random.Next() % 127)
   304  	}
   305  	if length > 400 {
   306  		return int(random.Next() % 521)
   307  	}
   308  	return int(random.Next() % 1021)
   309  }
   310  
   311  func (a *authChainA) initRC4Cipher() {
   312  	key := core.Kdf(base64.StdEncoding.EncodeToString(a.userKey)+base64.StdEncoding.EncodeToString(a.lastClientHash), 16)
   313  	a.encrypter, _ = rc4.NewCipher(key)
   314  	a.decrypter, _ = rc4.NewCipher(key)
   315  }
   316  
   317  func udpGetRandLength(lastHash []byte, random *tools.XorShift128Plus) int {
   318  	random.InitFromBin(lastHash)
   319  	return int(random.Next() % 127)
   320  }