github.com/glide-im/glide@v1.6.0/pkg/gate/authenticator.go (about)

     1  package gate
     2  
     3  import (
     4  	"bytes"
     5  	"crypto/aes"
     6  	"crypto/cipher"
     7  	"crypto/rand"
     8  	"crypto/sha512"
     9  	"encoding/base64"
    10  	"encoding/json"
    11  	"errors"
    12  	"github.com/glide-im/glide/pkg/hash"
    13  	"github.com/glide-im/glide/pkg/logger"
    14  	"github.com/glide-im/glide/pkg/messages"
    15  	"strings"
    16  	"time"
    17  )
    18  
    19  type CredentialCrypto interface {
    20  	EncryptCredentials(c *ClientAuthCredentials) ([]byte, error)
    21  
    22  	DecryptCredentials(src []byte) (*ClientAuthCredentials, error)
    23  }
    24  
    25  // AesCBCCrypto cbc mode PKCS7 padding
    26  type AesCBCCrypto struct {
    27  	Key []byte
    28  }
    29  
    30  func NewAesCBCCrypto(key []byte) *AesCBCCrypto {
    31  	keyLen := len(key)
    32  	count := 0
    33  	switch true {
    34  	case keyLen <= 16:
    35  		count = 16 - keyLen
    36  	case keyLen <= 24:
    37  		count = 24 - keyLen
    38  	case keyLen <= 32:
    39  		count = 32 - keyLen
    40  	default:
    41  		key = key[:32]
    42  	}
    43  	if count != 0 {
    44  		key = append(key, bytes.Repeat([]byte{0}, count)...)
    45  	}
    46  	return &AesCBCCrypto{Key: key}
    47  }
    48  
    49  func (a *AesCBCCrypto) EncryptCredentials(c *ClientAuthCredentials) ([]byte, error) {
    50  	jsonBytes, err := json.Marshal(c)
    51  	if err != nil {
    52  		return nil, err
    53  	}
    54  
    55  	// generate random iv
    56  	iv := make([]byte, aes.BlockSize)
    57  	_, err = rand.Read(iv)
    58  	if err != nil {
    59  		return nil, err
    60  	}
    61  
    62  	encryptBody, err := a.Encrypt(jsonBytes, iv)
    63  	if err != nil {
    64  		return nil, err
    65  	}
    66  
    67  	// NOTE: append iv
    68  	var encrypt []byte
    69  	encrypt = append(encrypt, iv...)
    70  	encrypt = append(encrypt, encryptBody...)
    71  
    72  	// base64 encoding encrypted json credentials
    73  	b64Bytes := make([]byte, base64.RawStdEncoding.EncodedLen(len(encrypt)))
    74  	base64.RawStdEncoding.Encode(b64Bytes, encrypt)
    75  	return b64Bytes, nil
    76  }
    77  
    78  func (a *AesCBCCrypto) DecryptCredentials(src []byte) (*ClientAuthCredentials, error) {
    79  
    80  	encrypt := make([]byte, base64.RawStdEncoding.DecodedLen(len(src)))
    81  	_, err := base64.RawStdEncoding.Decode(encrypt, src)
    82  	if err != nil {
    83  		return nil, err
    84  	}
    85  	var iv []byte
    86  	iv = append(iv, encrypt[:aes.BlockSize]...)
    87  	var encryptBody []byte
    88  	encryptBody = append(encryptBody, encrypt[aes.BlockSize:]...)
    89  
    90  	jsonBytes, err := a.Decrypt(encryptBody, iv)
    91  	if err != nil {
    92  		return nil, err
    93  	}
    94  
    95  	credentials := ClientAuthCredentials{}
    96  	err = json.Unmarshal(jsonBytes, &credentials)
    97  	if err != nil {
    98  		return nil, err
    99  	}
   100  	return &credentials, nil
   101  }
   102  
   103  func (a *AesCBCCrypto) Encrypt(src, iv []byte) ([]byte, error) {
   104  
   105  	block, err := aes.NewCipher(a.Key)
   106  	if err != nil {
   107  		return nil, err
   108  	}
   109  	// padding
   110  	blockSize := block.BlockSize()
   111  	padding := blockSize - len(src)%blockSize
   112  	padtext := bytes.Repeat([]byte{byte(padding)}, padding)
   113  	src = append(src, padtext...)
   114  
   115  	encryptData := make([]byte, len(src))
   116  
   117  	if len(iv) != block.BlockSize() {
   118  		iv = a.cbcIVPending(iv, blockSize)
   119  	}
   120  
   121  	mode := cipher.NewCBCEncrypter(block, iv)
   122  	mode.CryptBlocks(encryptData, src)
   123  
   124  	return encryptData, nil
   125  }
   126  
   127  func (a *AesCBCCrypto) Decrypt(src, iv []byte) ([]byte, error) {
   128  
   129  	block, err := aes.NewCipher(a.Key)
   130  	if err != nil {
   131  		return nil, err
   132  	}
   133  
   134  	dst := make([]byte, len(src))
   135  	blockSize := block.BlockSize()
   136  	if len(iv) != blockSize {
   137  		iv = a.cbcIVPending(iv, blockSize)
   138  	}
   139  
   140  	mode := cipher.NewCBCDecrypter(block, iv)
   141  	mode.CryptBlocks(dst, src)
   142  
   143  	length := len(dst)
   144  	if length == 0 {
   145  		return nil, errors.New("unpadding")
   146  	}
   147  	unpadding := int(dst[length-1])
   148  	if length < unpadding {
   149  		return nil, errors.New("unpadding")
   150  	}
   151  	res := dst[:(length - unpadding)]
   152  
   153  	return res, nil
   154  }
   155  
   156  func (a *AesCBCCrypto) cbcIVPending(iv []byte, blockSize int) []byte {
   157  	k := len(iv)
   158  	if k < blockSize {
   159  		return append(iv, bytes.Repeat([]byte{0}, blockSize-k)...)
   160  	} else if k > blockSize {
   161  		return iv[0:blockSize]
   162  	}
   163  	return iv
   164  }
   165  
   166  // Authenticator handle client authentication message
   167  type Authenticator struct {
   168  	credentialCrypto CredentialCrypto
   169  	gateway          DefaultGateway
   170  }
   171  
   172  func NewAuthenticator(gateway DefaultGateway, key string) *Authenticator {
   173  	k := sha512.New().Sum([]byte(key))
   174  	return &Authenticator{
   175  		credentialCrypto: NewAesCBCCrypto(k),
   176  		gateway:          gateway,
   177  	}
   178  }
   179  
   180  func (a *Authenticator) MessageInterceptor(dc DefaultClient, msg *messages.GlideMessage) bool {
   181  
   182  	switch msg.Action {
   183  	case messages.ActionGroupMessage, messages.ActionChatMessage, messages.ActionChatMessageResend:
   184  		break
   185  	default:
   186  		return false
   187  	}
   188  
   189  	if dc.GetCredentials() == nil || dc.GetCredentials().Secrets == nil {
   190  		_ = a.gateway.EnqueueMessage(dc.GetInfo().ID, messages.NewMessage(msg.GetSeq(), messages.ActionNotifyForbidden, "no credentials"))
   191  		return true
   192  	}
   193  
   194  	secret := dc.GetCredentials().Secrets.MessageDeliverSecret
   195  	if secret == "" {
   196  		_ = a.gateway.EnqueueMessage(dc.GetInfo().ID, messages.NewMessage(msg.GetSeq(), messages.ActionNotifyForbidden, "no message deliver secret"))
   197  		return true
   198  	}
   199  
   200  	var ticket = msg.Ticket
   201  	// sha1 hash
   202  	if len(ticket) != 40 {
   203  		_ = a.gateway.EnqueueMessage(dc.GetInfo().ID, messages.NewMessage(msg.GetSeq(), messages.ActionNotifyForbidden, "invalid ticket"))
   204  		return true
   205  	}
   206  	sum1 := hash.SHA1(secret + msg.To)
   207  	id := dc.GetInfo().ID
   208  	expectTicket := hash.SHA1(secret + id.UID() + sum1)
   209  
   210  	if strings.ToUpper(ticket) != strings.ToUpper(expectTicket) {
   211  		logger.I("invalid ticket, expected=%s, actually=%s, secret=%s, to=%s, from=%s", expectTicket, ticket, secret, msg.To, id.UID())
   212  		// invalid ticket
   213  		_ = a.gateway.EnqueueMessage(dc.GetInfo().ID, messages.NewMessage(msg.GetSeq(), messages.ActionNotifyForbidden, "ticket expired"))
   214  		return true
   215  	}
   216  	return false
   217  }
   218  
   219  func (a *Authenticator) ClientAuthMessageInterceptor(dc DefaultClient, msg *messages.GlideMessage) (intercept bool) {
   220  	if msg.Action != messages.ActionAuthenticate {
   221  		return false
   222  	}
   223  
   224  	intercept = true
   225  
   226  	var err error
   227  	var errMsg string
   228  	var newId ID
   229  	var span int64
   230  	var authCredentials *ClientAuthCredentials
   231  
   232  	credential := EncryptedCredential{}
   233  	err = msg.Data.Deserialize(&credential)
   234  	if err != nil {
   235  		errMsg = "invalid authenticate message"
   236  		goto DONE
   237  	}
   238  
   239  	if len(credential.Credential) < 5 {
   240  		errMsg = "invalid authenticate message"
   241  		goto DONE
   242  	}
   243  
   244  	authCredentials, err = a.credentialCrypto.DecryptCredentials([]byte(credential.Credential))
   245  	if err != nil {
   246  		errMsg = "invalid authenticate message"
   247  		goto DONE
   248  	}
   249  
   250  	span = time.Now().UnixMilli() - authCredentials.Timestamp
   251  	if span > 1500*1000 {
   252  		errMsg = "credential expired"
   253  		goto DONE
   254  	}
   255  
   256  	newId, err = a.updateClient(dc, authCredentials)
   257  
   258  DONE:
   259  
   260  	logger.D("client auth message intercepted %s, %v", dc.GetInfo().ID, err)
   261  
   262  	if err != nil || errMsg != "" {
   263  		_ = a.gateway.EnqueueMessage(dc.GetInfo().ID, messages.NewMessage(msg.GetSeq(), messages.ActionNotifyError, errMsg))
   264  	} else {
   265  		_ = a.gateway.EnqueueMessage(newId, messages.NewMessage(msg.GetSeq(), messages.ActionNotifySuccess, nil))
   266  	}
   267  	return
   268  }
   269  
   270  func (a *Authenticator) updateClient(dc DefaultClient, authCredentials *ClientAuthCredentials) (ID, error) {
   271  
   272  	dc.SetCredentials(authCredentials)
   273  
   274  	oldID := dc.GetInfo().ID
   275  	newID := NewID2(authCredentials.UserID)
   276  	err := a.gateway.SetClientID(oldID, newID)
   277  	if IsIDAlreadyExist(err) {
   278  		if newID.Equals(oldID) {
   279  			// already authenticated
   280  			return newID, nil
   281  		}
   282  		tempID, _ := GenTempID("")
   283  		err = a.gateway.SetClientID(newID, tempID)
   284  		if err != nil {
   285  			return "", err
   286  		}
   287  		kickOut := messages.NewMessage(0, messages.ActionNotifyKickOut, &messages.KickOutNotify{
   288  			DeviceName: authCredentials.DeviceName,
   289  			DeviceId:   authCredentials.DeviceID,
   290  		})
   291  		_ = a.gateway.EnqueueMessage(tempID, kickOut)
   292  		err = a.gateway.SetClientID(oldID, newID)
   293  		if err != nil {
   294  			return "", err
   295  		}
   296  	}
   297  	return newID, err
   298  }