github.com/glide-im/glide@v1.6.0/pkg/gate/authenticator.go (about) 1 package gate 2 3 import ( 4 "bytes" 5 "crypto/aes" 6 "crypto/cipher" 7 "crypto/rand" 8 "crypto/sha512" 9 "encoding/base64" 10 "encoding/json" 11 "errors" 12 "github.com/glide-im/glide/pkg/hash" 13 "github.com/glide-im/glide/pkg/logger" 14 "github.com/glide-im/glide/pkg/messages" 15 "strings" 16 "time" 17 ) 18 19 type CredentialCrypto interface { 20 EncryptCredentials(c *ClientAuthCredentials) ([]byte, error) 21 22 DecryptCredentials(src []byte) (*ClientAuthCredentials, error) 23 } 24 25 // AesCBCCrypto cbc mode PKCS7 padding 26 type AesCBCCrypto struct { 27 Key []byte 28 } 29 30 func NewAesCBCCrypto(key []byte) *AesCBCCrypto { 31 keyLen := len(key) 32 count := 0 33 switch true { 34 case keyLen <= 16: 35 count = 16 - keyLen 36 case keyLen <= 24: 37 count = 24 - keyLen 38 case keyLen <= 32: 39 count = 32 - keyLen 40 default: 41 key = key[:32] 42 } 43 if count != 0 { 44 key = append(key, bytes.Repeat([]byte{0}, count)...) 45 } 46 return &AesCBCCrypto{Key: key} 47 } 48 49 func (a *AesCBCCrypto) EncryptCredentials(c *ClientAuthCredentials) ([]byte, error) { 50 jsonBytes, err := json.Marshal(c) 51 if err != nil { 52 return nil, err 53 } 54 55 // generate random iv 56 iv := make([]byte, aes.BlockSize) 57 _, err = rand.Read(iv) 58 if err != nil { 59 return nil, err 60 } 61 62 encryptBody, err := a.Encrypt(jsonBytes, iv) 63 if err != nil { 64 return nil, err 65 } 66 67 // NOTE: append iv 68 var encrypt []byte 69 encrypt = append(encrypt, iv...) 70 encrypt = append(encrypt, encryptBody...) 71 72 // base64 encoding encrypted json credentials 73 b64Bytes := make([]byte, base64.RawStdEncoding.EncodedLen(len(encrypt))) 74 base64.RawStdEncoding.Encode(b64Bytes, encrypt) 75 return b64Bytes, nil 76 } 77 78 func (a *AesCBCCrypto) DecryptCredentials(src []byte) (*ClientAuthCredentials, error) { 79 80 encrypt := make([]byte, base64.RawStdEncoding.DecodedLen(len(src))) 81 _, err := base64.RawStdEncoding.Decode(encrypt, src) 82 if err != nil { 83 return nil, err 84 } 85 var iv []byte 86 iv = append(iv, encrypt[:aes.BlockSize]...) 87 var encryptBody []byte 88 encryptBody = append(encryptBody, encrypt[aes.BlockSize:]...) 89 90 jsonBytes, err := a.Decrypt(encryptBody, iv) 91 if err != nil { 92 return nil, err 93 } 94 95 credentials := ClientAuthCredentials{} 96 err = json.Unmarshal(jsonBytes, &credentials) 97 if err != nil { 98 return nil, err 99 } 100 return &credentials, nil 101 } 102 103 func (a *AesCBCCrypto) Encrypt(src, iv []byte) ([]byte, error) { 104 105 block, err := aes.NewCipher(a.Key) 106 if err != nil { 107 return nil, err 108 } 109 // padding 110 blockSize := block.BlockSize() 111 padding := blockSize - len(src)%blockSize 112 padtext := bytes.Repeat([]byte{byte(padding)}, padding) 113 src = append(src, padtext...) 114 115 encryptData := make([]byte, len(src)) 116 117 if len(iv) != block.BlockSize() { 118 iv = a.cbcIVPending(iv, blockSize) 119 } 120 121 mode := cipher.NewCBCEncrypter(block, iv) 122 mode.CryptBlocks(encryptData, src) 123 124 return encryptData, nil 125 } 126 127 func (a *AesCBCCrypto) Decrypt(src, iv []byte) ([]byte, error) { 128 129 block, err := aes.NewCipher(a.Key) 130 if err != nil { 131 return nil, err 132 } 133 134 dst := make([]byte, len(src)) 135 blockSize := block.BlockSize() 136 if len(iv) != blockSize { 137 iv = a.cbcIVPending(iv, blockSize) 138 } 139 140 mode := cipher.NewCBCDecrypter(block, iv) 141 mode.CryptBlocks(dst, src) 142 143 length := len(dst) 144 if length == 0 { 145 return nil, errors.New("unpadding") 146 } 147 unpadding := int(dst[length-1]) 148 if length < unpadding { 149 return nil, errors.New("unpadding") 150 } 151 res := dst[:(length - unpadding)] 152 153 return res, nil 154 } 155 156 func (a *AesCBCCrypto) cbcIVPending(iv []byte, blockSize int) []byte { 157 k := len(iv) 158 if k < blockSize { 159 return append(iv, bytes.Repeat([]byte{0}, blockSize-k)...) 160 } else if k > blockSize { 161 return iv[0:blockSize] 162 } 163 return iv 164 } 165 166 // Authenticator handle client authentication message 167 type Authenticator struct { 168 credentialCrypto CredentialCrypto 169 gateway DefaultGateway 170 } 171 172 func NewAuthenticator(gateway DefaultGateway, key string) *Authenticator { 173 k := sha512.New().Sum([]byte(key)) 174 return &Authenticator{ 175 credentialCrypto: NewAesCBCCrypto(k), 176 gateway: gateway, 177 } 178 } 179 180 func (a *Authenticator) MessageInterceptor(dc DefaultClient, msg *messages.GlideMessage) bool { 181 182 switch msg.Action { 183 case messages.ActionGroupMessage, messages.ActionChatMessage, messages.ActionChatMessageResend: 184 break 185 default: 186 return false 187 } 188 189 if dc.GetCredentials() == nil || dc.GetCredentials().Secrets == nil { 190 _ = a.gateway.EnqueueMessage(dc.GetInfo().ID, messages.NewMessage(msg.GetSeq(), messages.ActionNotifyForbidden, "no credentials")) 191 return true 192 } 193 194 secret := dc.GetCredentials().Secrets.MessageDeliverSecret 195 if secret == "" { 196 _ = a.gateway.EnqueueMessage(dc.GetInfo().ID, messages.NewMessage(msg.GetSeq(), messages.ActionNotifyForbidden, "no message deliver secret")) 197 return true 198 } 199 200 var ticket = msg.Ticket 201 // sha1 hash 202 if len(ticket) != 40 { 203 _ = a.gateway.EnqueueMessage(dc.GetInfo().ID, messages.NewMessage(msg.GetSeq(), messages.ActionNotifyForbidden, "invalid ticket")) 204 return true 205 } 206 sum1 := hash.SHA1(secret + msg.To) 207 id := dc.GetInfo().ID 208 expectTicket := hash.SHA1(secret + id.UID() + sum1) 209 210 if strings.ToUpper(ticket) != strings.ToUpper(expectTicket) { 211 logger.I("invalid ticket, expected=%s, actually=%s, secret=%s, to=%s, from=%s", expectTicket, ticket, secret, msg.To, id.UID()) 212 // invalid ticket 213 _ = a.gateway.EnqueueMessage(dc.GetInfo().ID, messages.NewMessage(msg.GetSeq(), messages.ActionNotifyForbidden, "ticket expired")) 214 return true 215 } 216 return false 217 } 218 219 func (a *Authenticator) ClientAuthMessageInterceptor(dc DefaultClient, msg *messages.GlideMessage) (intercept bool) { 220 if msg.Action != messages.ActionAuthenticate { 221 return false 222 } 223 224 intercept = true 225 226 var err error 227 var errMsg string 228 var newId ID 229 var span int64 230 var authCredentials *ClientAuthCredentials 231 232 credential := EncryptedCredential{} 233 err = msg.Data.Deserialize(&credential) 234 if err != nil { 235 errMsg = "invalid authenticate message" 236 goto DONE 237 } 238 239 if len(credential.Credential) < 5 { 240 errMsg = "invalid authenticate message" 241 goto DONE 242 } 243 244 authCredentials, err = a.credentialCrypto.DecryptCredentials([]byte(credential.Credential)) 245 if err != nil { 246 errMsg = "invalid authenticate message" 247 goto DONE 248 } 249 250 span = time.Now().UnixMilli() - authCredentials.Timestamp 251 if span > 1500*1000 { 252 errMsg = "credential expired" 253 goto DONE 254 } 255 256 newId, err = a.updateClient(dc, authCredentials) 257 258 DONE: 259 260 logger.D("client auth message intercepted %s, %v", dc.GetInfo().ID, err) 261 262 if err != nil || errMsg != "" { 263 _ = a.gateway.EnqueueMessage(dc.GetInfo().ID, messages.NewMessage(msg.GetSeq(), messages.ActionNotifyError, errMsg)) 264 } else { 265 _ = a.gateway.EnqueueMessage(newId, messages.NewMessage(msg.GetSeq(), messages.ActionNotifySuccess, nil)) 266 } 267 return 268 } 269 270 func (a *Authenticator) updateClient(dc DefaultClient, authCredentials *ClientAuthCredentials) (ID, error) { 271 272 dc.SetCredentials(authCredentials) 273 274 oldID := dc.GetInfo().ID 275 newID := NewID2(authCredentials.UserID) 276 err := a.gateway.SetClientID(oldID, newID) 277 if IsIDAlreadyExist(err) { 278 if newID.Equals(oldID) { 279 // already authenticated 280 return newID, nil 281 } 282 tempID, _ := GenTempID("") 283 err = a.gateway.SetClientID(newID, tempID) 284 if err != nil { 285 return "", err 286 } 287 kickOut := messages.NewMessage(0, messages.ActionNotifyKickOut, &messages.KickOutNotify{ 288 DeviceName: authCredentials.DeviceName, 289 DeviceId: authCredentials.DeviceID, 290 }) 291 _ = a.gateway.EnqueueMessage(tempID, kickOut) 292 err = a.gateway.SetClientID(oldID, newID) 293 if err != nil { 294 return "", err 295 } 296 } 297 return newID, err 298 }