github.com/laof/lite-speed-test@v0.0.0-20230930011949-1f39b7037845/transport/ssr/protocol/auth_chain_a.go (about)

     1  package protocol
     2  
     3  import (
     4  	"bytes"
     5  	"crypto/cipher"
     6  	"crypto/rand"
     7  	"crypto/rc4"
     8  	"encoding/base64"
     9  	"encoding/binary"
    10  	"net"
    11  	"strconv"
    12  	"strings"
    13  
    14  	"github.com/laof/lite-speed-test/common/pool"
    15  	"github.com/laof/lite-speed-test/log"
    16  	"github.com/laof/lite-speed-test/transport/ssr/tools"
    17  
    18  	"github.com/laof/go2/core"
    19  )
    20  
    21  func init() {
    22  	register("auth_chain_a", newAuthChainA, 4)
    23  }
    24  
    25  type randDataLengthMethod func(int, []byte, *tools.XorShift128Plus) int
    26  
// authChainA holds the state of the auth_chain_a protocol. newAuthChainA
// builds a template instance; StreamConn and PacketConn derive per-connection
// instances that share Base and userData with the template.
type authChainA struct {
	*Base
	*authData
	*userData
	// iv is the per-connection IV handed to StreamConn; packAuthData mixes it
	// into the header HMAC key.
	iv []byte
	// salt is the protocol name ("auth_chain_a"), passed to putEncryptedData
	// and used in log messages.
	salt string
	// hasSentHeader is set once Encode has emitted the auth header.
	hasSentHeader bool
	// rawTrans is set after a decode error; Decode then passes data through.
	rawTrans bool
	// lastClientHash is the HMAC of the latest outgoing frame; it masks the
	// next length field and seeds outgoing padding.
	lastClientHash []byte
	// lastServerHash is the HMAC of the latest incoming frame (first set by
	// packAuthData); it masks incoming length fields and seeds their padding.
	lastServerHash []byte
	// encrypter / decrypter are the RC4 streams for outgoing / incoming
	// payload, created by initRC4Cipher.
	encrypter cipher.Stream
	decrypter cipher.Stream
	// randomClient / randomServer draw padding sizes (and, for streams,
	// padding positions) for outgoing / incoming traffic.
	randomClient tools.XorShift128Plus
	randomServer tools.XorShift128Plus
	// randDataLength computes padding length; StreamConn sets it to getRandLength.
	randDataLength randDataLengthMethod
	// packID / recvID are frame counters mixed into the outgoing / incoming
	// MAC keys; StreamConn starts both at 1.
	packID uint32
	recvID uint32
}
    45  
    46  func newAuthChainA(b *Base) Protocol {
    47  	a := &authChainA{
    48  		Base:     b,
    49  		authData: &authData{},
    50  		userData: &userData{},
    51  		salt:     "auth_chain_a",
    52  	}
    53  	a.initUserData()
    54  	return a
    55  }
    56  
    57  func (a *authChainA) initUserData() {
    58  	params := strings.Split(a.Param, ":")
    59  	if len(params) > 1 {
    60  		if userID, err := strconv.ParseUint(params[0], 10, 32); err == nil {
    61  			binary.LittleEndian.PutUint32(a.userID[:], uint32(userID))
    62  			a.userKey = []byte(params[1])
    63  		} else {
    64  			log.Warnln("Wrong protocol-param for %s, only digits are expected before ':'", a.salt)
    65  		}
    66  	}
    67  	if len(a.userKey) == 0 {
    68  		a.userKey = a.Key
    69  		rand.Read(a.userID[:])
    70  	}
    71  }
    72  
    73  func (a *authChainA) StreamConn(c net.Conn, iv []byte) net.Conn {
    74  	p := &authChainA{
    75  		Base:     a.Base,
    76  		authData: a.next(),
    77  		userData: a.userData,
    78  		salt:     a.salt,
    79  		packID:   1,
    80  		recvID:   1,
    81  	}
    82  	p.iv = iv
    83  	p.randDataLength = p.getRandLength
    84  	return &Conn{Conn: c, Protocol: p}
    85  }
    86  
    87  func (a *authChainA) PacketConn(c net.PacketConn) net.PacketConn {
    88  	p := &authChainA{
    89  		Base:     a.Base,
    90  		salt:     a.salt,
    91  		userData: a.userData,
    92  	}
    93  	return &PacketConn{PacketConn: c, Protocol: p}
    94  }
    95  
    96  func (a *authChainA) Decode(dst, src *bytes.Buffer) error {
    97  	if a.rawTrans {
    98  		dst.ReadFrom(src)
    99  		return nil
   100  	}
   101  	for src.Len() > 4 {
   102  		macKey := pool.Get(len(a.userKey) + 4)
   103  		defer pool.Put(macKey)
   104  		copy(macKey, a.userKey)
   105  		binary.LittleEndian.PutUint32(macKey[len(a.userKey):], a.recvID)
   106  
   107  		dataLength := int(binary.LittleEndian.Uint16(src.Bytes()[:2]) ^ binary.LittleEndian.Uint16(a.lastServerHash[14:16]))
   108  		randDataLength := a.randDataLength(dataLength, a.lastServerHash, &a.randomServer)
   109  		length := dataLength + randDataLength
   110  
   111  		if length >= 4096 {
   112  			a.rawTrans = true
   113  			src.Reset()
   114  			return errAuthChainLengthError
   115  		}
   116  
   117  		if 4+length > src.Len() {
   118  			break
   119  		}
   120  
   121  		serverHash := tools.HmacMD5(macKey, src.Bytes()[:length+2])
   122  		if !bytes.Equal(serverHash[:2], src.Bytes()[length+2:length+4]) {
   123  			a.rawTrans = true
   124  			src.Reset()
   125  			return errAuthChainChksumError
   126  		}
   127  		a.lastServerHash = serverHash
   128  
   129  		pos := 2
   130  		if dataLength > 0 && randDataLength > 0 {
   131  			pos += getRandStartPos(randDataLength, &a.randomServer)
   132  		}
   133  		wantedData := src.Bytes()[pos : pos+dataLength]
   134  		a.decrypter.XORKeyStream(wantedData, wantedData)
   135  		if a.recvID == 1 {
   136  			dst.Write(wantedData[2:])
   137  		} else {
   138  			dst.Write(wantedData)
   139  		}
   140  		a.recvID++
   141  		src.Next(length + 4)
   142  	}
   143  	return nil
   144  }
   145  
   146  func (a *authChainA) Encode(buf *bytes.Buffer, b []byte) error {
   147  	if !a.hasSentHeader {
   148  		dataLength := getDataLength(b)
   149  		a.packAuthData(buf, b[:dataLength])
   150  		b = b[dataLength:]
   151  		a.hasSentHeader = true
   152  	}
   153  	for len(b) > 2800 {
   154  		a.packData(buf, b[:2800])
   155  		b = b[2800:]
   156  	}
   157  	if len(b) > 0 {
   158  		a.packData(buf, b)
   159  	}
   160  	return nil
   161  }
   162  
   163  func (a *authChainA) DecodePacket(b []byte) ([]byte, error) {
   164  	if len(b) < 9 {
   165  		return nil, errAuthChainLengthError
   166  	}
   167  	if !bytes.Equal(tools.HmacMD5(a.userKey, b[:len(b)-1])[:1], b[len(b)-1:]) {
   168  		return nil, errAuthChainChksumError
   169  	}
   170  	md5Data := tools.HmacMD5(a.Key, b[len(b)-8:len(b)-1])
   171  
   172  	randDataLength := udpGetRandLength(md5Data, &a.randomServer)
   173  
   174  	key := core.Kdf(base64.StdEncoding.EncodeToString(a.userKey)+base64.StdEncoding.EncodeToString(md5Data), 16)
   175  	rc4Cipher, err := rc4.NewCipher(key)
   176  	if err != nil {
   177  		return nil, err
   178  	}
   179  	wantedData := b[:len(b)-8-randDataLength]
   180  	rc4Cipher.XORKeyStream(wantedData, wantedData)
   181  	return wantedData, nil
   182  }
   183  
   184  func (a *authChainA) EncodePacket(buf *bytes.Buffer, b []byte) error {
   185  	authData := pool.Get(3)
   186  	defer pool.Put(authData)
   187  	rand.Read(authData)
   188  
   189  	md5Data := tools.HmacMD5(a.Key, authData)
   190  
   191  	randDataLength := udpGetRandLength(md5Data, &a.randomClient)
   192  
   193  	key := core.Kdf(base64.StdEncoding.EncodeToString(a.userKey)+base64.StdEncoding.EncodeToString(md5Data), 16)
   194  	rc4Cipher, err := rc4.NewCipher(key)
   195  	if err != nil {
   196  		return err
   197  	}
   198  	rc4Cipher.XORKeyStream(b, b)
   199  
   200  	buf.Write(b)
   201  	tools.AppendRandBytes(buf, randDataLength)
   202  	buf.Write(authData)
   203  	binary.Write(buf, binary.LittleEndian, binary.LittleEndian.Uint32(a.userID[:])^binary.LittleEndian.Uint32(md5Data[:4]))
   204  	buf.Write(tools.HmacMD5(a.userKey, buf.Bytes())[:1])
   205  	return nil
   206  }
   207  
   208  func (a *authChainA) packAuthData(poolBuf *bytes.Buffer, data []byte) {
   209  	/*
   210  		dataLength := len(data)
   211  		12:	checkHead(4) and hmac of checkHead(8)
   212  		4:	uint32 LittleEndian uid (uid = userID ^ last client hash)
   213  		16:	encrypted data of authdata(12), uint16 LittleEndian overhead(2) and uint16 LittleEndian number zero(2)
   214  		4:	last server hash(4)
   215  		packedAuthDataLength := 12 + 4 + 16 + 4 + dataLength
   216  	*/
   217  
   218  	macKey := pool.Get(len(a.iv) + len(a.Key))
   219  	defer pool.Put(macKey)
   220  	copy(macKey, a.iv)
   221  	copy(macKey[len(a.iv):], a.Key)
   222  
   223  	// check head
   224  	tools.AppendRandBytes(poolBuf, 4)
   225  	a.lastClientHash = tools.HmacMD5(macKey, poolBuf.Bytes())
   226  	a.initRC4Cipher()
   227  	poolBuf.Write(a.lastClientHash[:8])
   228  	// uid
   229  	binary.Write(poolBuf, binary.LittleEndian, binary.LittleEndian.Uint32(a.userID[:])^binary.LittleEndian.Uint32(a.lastClientHash[8:12]))
   230  	// encrypted data
   231  	err := a.putEncryptedData(poolBuf, a.userKey, [2]int{a.Overhead, 0}, a.salt)
   232  	if err != nil {
   233  		poolBuf.Reset()
   234  		return
   235  	}
   236  	// last server hash
   237  	a.lastServerHash = tools.HmacMD5(a.userKey, poolBuf.Bytes()[12:])
   238  	poolBuf.Write(a.lastServerHash[:4])
   239  	// packed data
   240  	a.packData(poolBuf, data)
   241  }
   242  
   243  func (a *authChainA) packData(poolBuf *bytes.Buffer, data []byte) {
   244  	a.encrypter.XORKeyStream(data, data)
   245  
   246  	macKey := pool.Get(len(a.userKey) + 4)
   247  	defer pool.Put(macKey)
   248  	copy(macKey, a.userKey)
   249  	binary.LittleEndian.PutUint32(macKey[len(a.userKey):], a.packID)
   250  	a.packID++
   251  
   252  	length := uint16(len(data)) ^ binary.LittleEndian.Uint16(a.lastClientHash[14:16])
   253  
   254  	originalLength := poolBuf.Len()
   255  	binary.Write(poolBuf, binary.LittleEndian, length)
   256  	a.putMixedRandDataAndData(poolBuf, data)
   257  	a.lastClientHash = tools.HmacMD5(macKey, poolBuf.Bytes()[originalLength:])
   258  	poolBuf.Write(a.lastClientHash[:2])
   259  }
   260  
   261  func (a *authChainA) putMixedRandDataAndData(poolBuf *bytes.Buffer, data []byte) {
   262  	randDataLength := a.randDataLength(len(data), a.lastClientHash, &a.randomClient)
   263  	if len(data) == 0 {
   264  		tools.AppendRandBytes(poolBuf, randDataLength)
   265  		return
   266  	}
   267  	if randDataLength > 0 {
   268  		startPos := getRandStartPos(randDataLength, &a.randomClient)
   269  		tools.AppendRandBytes(poolBuf, startPos)
   270  		poolBuf.Write(data)
   271  		tools.AppendRandBytes(poolBuf, randDataLength-startPos)
   272  		return
   273  	}
   274  	poolBuf.Write(data)
   275  }
   276  
   277  func getRandStartPos(length int, random *tools.XorShift128Plus) int {
   278  	if length == 0 {
   279  		return 0
   280  	}
   281  	return int(int64(random.Next()%8589934609) % int64(length))
   282  }
   283  
   284  func (a *authChainA) getRandLength(length int, lastHash []byte, random *tools.XorShift128Plus) int {
   285  	if length > 1440 {
   286  		return 0
   287  	}
   288  	random.InitFromBinAndLength(lastHash, length)
   289  	if length > 1300 {
   290  		return int(random.Next() % 31)
   291  	}
   292  	if length > 900 {
   293  		return int(random.Next() % 127)
   294  	}
   295  	if length > 400 {
   296  		return int(random.Next() % 521)
   297  	}
   298  	return int(random.Next() % 1021)
   299  }
   300  
   301  func (a *authChainA) initRC4Cipher() {
   302  	key := core.Kdf(base64.StdEncoding.EncodeToString(a.userKey)+base64.StdEncoding.EncodeToString(a.lastClientHash), 16)
   303  	a.encrypter, _ = rc4.NewCipher(key)
   304  	a.decrypter, _ = rc4.NewCipher(key)
   305  }
   306  
   307  func udpGetRandLength(lastHash []byte, random *tools.XorShift128Plus) int {
   308  	random.InitFromBin(lastHash)
   309  	return int(random.Next() % 127)
   310  }