github.com/aporeto-inc/trireme-lib@v10.358.0+incompatible/controller/pkg/tokens/binaryjwt.go (about) 1 package tokens 2 3 import ( 4 "crypto/ecdsa" 5 "crypto/elliptic" 6 "crypto/rand" 7 "encoding/binary" 8 "fmt" 9 "math/big" 10 "sync" 11 "time" 12 13 "github.com/ugorji/go/codec" 14 enforcerconstants "go.aporeto.io/enforcerd/trireme-lib/controller/internal/enforcer/constants" 15 "go.aporeto.io/enforcerd/trireme-lib/controller/internal/enforcer/utils/ephemeralkeys" 16 "go.aporeto.io/enforcerd/trireme-lib/controller/pkg/claimsheader" 17 "go.aporeto.io/enforcerd/trireme-lib/controller/pkg/pkiverifier" 18 "go.aporeto.io/enforcerd/trireme-lib/controller/pkg/secrets" 19 "go.aporeto.io/enforcerd/trireme-lib/utils/cache" 20 localcrypto "go.aporeto.io/enforcerd/trireme-lib/utils/crypto" 21 ) 22 23 // To generate the codecs, 24 // codecgen -o binarycodec.go binaryjwtclaimtypes.go 25 26 // Format of Binary Tokens 27 // 0 1 2 3 4 28 // 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 29 // +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ 30 // | D |CT|E| Encoding | R (reserved) | 31 // +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ 32 // | Signature Position | nonce | 33 // +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ 34 // | ... | 35 // +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ 36 // | token | 37 // +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ 38 // | ... | 39 // +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ 40 // | Signature | 41 // +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ 42 // | ... | 43 // +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ 44 // D [0:6] - Datapath version 45 // CT [6:8] - Compressed tag type 46 // E [8:9] - Encryption enabled 47 // C [9:12] - Codec selector 48 // R [12:32] - Reserved 49 // L [32:48] - Token Length 50 // Token bytes (equal to token length) 51 // Signature bytes 52 53 const ( 54 binaryNoncePosition = 6 55 lengthPosition = 4 56 headerLength = 4 57 sharedKeyCacheTimeout = 5 * time.Minute 58 ) 59 60 //ClaimsEncodedBufSize is the size of maximum buffer that is required 61 //for claims to be serialized into 62 const ClaimsEncodedBufSize = 1400 63 64 // AckPattern is added in SYN and ACK tokens. 65 var AckPattern = []byte("PANWIDENTITY") 66 var sha256KeyLength int = 32 67 68 type sharedKeyStruct struct { 69 sharedKeys map[string][]byte 70 sync.RWMutex 71 } 72 73 func (s *sharedKeyStruct) Get(key string) []byte { 74 75 s.RLock() 76 77 if val, ok := s.sharedKeys[key]; ok { 78 s.RUnlock() 79 return val 80 } 81 82 s.RUnlock() 83 return nil 84 } 85 86 func (s *sharedKeyStruct) Put(key string, val []byte) { 87 88 s.Lock() 89 s.sharedKeys[key] = val 90 s.Unlock() 91 92 time.AfterFunc(sharedKeyCacheTimeout, func() { 93 s.Lock() 94 delete(s.sharedKeys, key) 95 s.Unlock() 96 }) 97 } 98 99 // BinaryJWTConfig configures the JWT token generator with the standard parameters. One 100 // configuration is assigned to each server 101 type BinaryJWTConfig struct { 102 // ValidityPeriod period of the JWT 103 ValidityPeriod time.Duration 104 // Issuer is the server that issues the JWT 105 Issuer string 106 // cache test 107 tokenCache cache.DataStore 108 // sharedKey is a cache of pre-shared keys. 109 sharedKeys *sharedKeyStruct 110 } 111 112 // NewBinaryJWT creates a new JWT token processor 113 func NewBinaryJWT(validity time.Duration, issuer string) (*BinaryJWTConfig, error) { 114 115 return &BinaryJWTConfig{ 116 ValidityPeriod: validity, 117 Issuer: issuer, 118 tokenCache: cache.NewCacheWithExpiration("JWTTokenCache", validity), 119 sharedKeys: &sharedKeyStruct{sharedKeys: map[string][]byte{}}, 120 }, nil 121 } 122 123 // DecodeSyn takes as argument the JWT token and the certificate of the issuer. 124 // First it verifies the certificate with the local CA pool, and the decodes 125 // the JWT if the certificate is trusted 126 func (c *BinaryJWTConfig) DecodeSyn(isSynAck bool, data []byte, privateKey *ephemeralkeys.PrivateKey, secrets secrets.Secrets, connClaims *ConnectionClaims) ([]byte, *claimsheader.ClaimsHeader, []byte, *pkiverifier.PKIControllerInfo, bool, error) { 127 header, nonce, token, sig, err := unpackToken(false, data) 128 if err != nil { 129 return nil, nil, nil, nil, false, err 130 } 131 // Parse the claims header. 132 claimsHeader := claimsheader.HeaderBytes(header).ToClaimsHeader() 133 134 // Validate the header version. 135 if err := c.verifyClaimsHeader(claimsHeader); err != nil { 136 return nil, nil, nil, nil, false, err 137 } 138 139 // Decode the claims to a data structure. 140 binaryClaims, err := decode(token) 141 if err != nil { 142 return nil, nil, nil, nil, false, err 143 } 144 145 //Process 314 Protocol 146 if len(binaryClaims.DEK) == 0 { 147 secretKey, controller, err := c.process314Protocol(isSynAck, token, secrets, connClaims, binaryClaims, sig) 148 return secretKey, claimsHeader, nonce, controller, true, err 149 } 150 151 //Process 500 Protocol 152 secretKey, controller, err := c.process500Protocol(isSynAck, token, privateKey, secrets, connClaims, binaryClaims, sig) 153 154 return secretKey, claimsHeader, nonce, controller, false, err 155 } 156 157 // DecodeAck decodes the ack packet token 158 func (c *BinaryJWTConfig) DecodeAck(proto314 bool, secretKey []byte, data []byte, connClaims *ConnectionClaims) error { 159 // Unpack the token first. 160 header, _, token, sig, err := unpackToken(true, data) 161 if err != nil { 162 return err 163 } 164 165 // Parse the claims header. 166 claimsHeader := claimsheader.HeaderBytes(header).ToClaimsHeader() 167 168 // Validate the header. 169 if err := c.verifyClaimsHeader(claimsHeader); err != nil { 170 return err 171 } 172 173 // Decode the claims to a data structure. 174 binaryClaims, err := decode(token) 175 if err != nil { 176 return err 177 } 178 179 if proto314 { 180 // Calculate the signature on the token and compare it with the incoming 181 // signature. Since this is simple symetric hashing this is simple. 182 if err := c.verifyWithSharedKey314(token, secretKey, sig); err != nil { 183 return err 184 } 185 } else { 186 if err := c.verifyWithSharedKey500(token, secretKey, sig[0:sha256KeyLength]); err != nil { 187 return err 188 } 189 } 190 191 CopyToConnectionClaims(binaryClaims, connClaims) 192 return nil 193 } 194 195 //CreateSynToken creates the token which is attached to the tcp syn packet. 196 func (c *BinaryJWTConfig) CreateSynToken(claims *ConnectionClaims, encodedBuf []byte, nonce []byte, header *claimsheader.ClaimsHeader, secrets secrets.Secrets) ([]byte, error) { 197 // Set the appropriate claims header 198 header.SetCompressionType(claimsheader.CompressionTypeV1) 199 header.SetDatapathVersion(claimsheader.DatapathVersion1) 200 201 // Combine the application claims with the standard claims. 202 // In all cases for Syn/SynAck packets we also transmit our 203 // public key. 204 allclaims := ConvertToBinaryClaims(claims, c.ValidityPeriod) 205 allclaims.SignerKey = secrets.TransmittedKey() 206 207 // Encode the claims in a buffer. 208 err := encode(allclaims, &encodedBuf) 209 if err != nil { 210 return nil, logError(ErrTokenEncodeFailed, err.Error()) 211 } 212 213 var sig []byte 214 215 encodedBuf = append(encodedBuf, AckPattern...) 216 217 sig, err = c.sign(encodedBuf, secrets.EncodingKey().(*ecdsa.PrivateKey)) 218 219 if err != nil { 220 return nil, err 221 } 222 223 // Pack and return the token. 224 return packToken(header.ToBytes(), nonce, encodedBuf, sig), nil 225 } 226 227 //CreateSynAckToken creates syn/ack token which is attached to the syn/ack packet. 228 func (c *BinaryJWTConfig) CreateSynAckToken(proto314 bool, claims *ConnectionClaims, encodedBuf []byte, nonce []byte, header *claimsheader.ClaimsHeader, secrets secrets.Secrets, secretKey []byte) ([]byte, error) { 229 230 // Set the appropriate claims header 231 header.SetCompressionType(claimsheader.CompressionTypeV1) 232 header.SetDatapathVersion(claimsheader.DatapathVersion1) 233 234 // Combine the application claims with the standard claims. 235 // In all cases for Syn/SynAck packets we also transmit our 236 // public key. 237 allclaims := ConvertToBinaryClaims(claims, c.ValidityPeriod) 238 allclaims.SignerKey = secrets.TransmittedKey() 239 240 // Encode the claims in a buffer. 241 err := encode(allclaims, &encodedBuf) 242 if err != nil { 243 return nil, logError(ErrTokenEncodeFailed, err.Error()) 244 } 245 246 var sig []byte 247 248 encodedBuf = append(encodedBuf, AckPattern...) 249 250 if proto314 { 251 sig, err = hash314(encodedBuf, secretKey) 252 if err != nil { 253 return nil, err 254 } 255 } else { 256 sig, err = hash500(encodedBuf, secretKey) 257 if err != nil { 258 return nil, err 259 } 260 } 261 262 // Pack and return the token. 263 return packToken(header.ToBytes(), nonce, encodedBuf, sig), nil 264 } 265 266 // Randomize puts the random nonce in the syn token 267 func (c *BinaryJWTConfig) Randomize(token []byte, nonce []byte) error { 268 269 if len(token) < 6+NonceLength { 270 return logError(ErrTokenTooSmall, "token is too small") 271 } 272 273 copy(token[6:], nonce) 274 275 return nil 276 } 277 278 //CreateAckToken creates ack token which is attached to the ack packet. 279 func (c *BinaryJWTConfig) CreateAckToken(proto314 bool, secretKey []byte, claims *ConnectionClaims, encodedBuf []byte, header *claimsheader.ClaimsHeader) ([]byte, error) { 280 281 var pad []byte 282 // Combine the application claims with the standard claims 283 allclaims := ConvertToBinaryClaims(claims, c.ValidityPeriod) 284 285 // Encode the claims in a buffer. 286 err := encode(allclaims, &encodedBuf) 287 if err != nil { 288 return nil, logError(ErrTokenEncodeFailed, err.Error()) 289 } 290 encodedBuf = append(encodedBuf, AckPattern...) 291 292 var sig []byte 293 // Sign the buffer with the pre-shared key. 294 if proto314 { 295 sig, err = hash314(encodedBuf, secretKey) 296 if err != nil { 297 return nil, err 298 } 299 pad = sig 300 } else { 301 pad = make([]byte, 64) 302 sig, err = hash500(encodedBuf, secretKey) 303 if err != nil { 304 return nil, err 305 } 306 copy(pad, sig) 307 } 308 309 // Pack and return the token. 310 return packToken(header.ToBytes(), nil, encodedBuf, pad), nil 311 } 312 313 func (c *BinaryJWTConfig) verifyClaimsHeader(h *claimsheader.ClaimsHeader) error { 314 315 if h.CompressionType() != claimsheader.CompressionTypeV1 { 316 return ErrCompressedTagMismatch 317 318 } 319 320 if h.DatapathVersion() != claimsheader.DatapathVersion1 { 321 return ErrDatapathVersionMismatch 322 } 323 324 return nil 325 } 326 327 // Sign takes in a slice of bytes and a private key, and returns a ecdsa signature. 328 func (c *BinaryJWTConfig) Sign(buf []byte, key *ecdsa.PrivateKey) ([]byte, error) { 329 return c.sign(buf, key) 330 } 331 332 func (c *BinaryJWTConfig) sign(buf []byte, key *ecdsa.PrivateKey) ([]byte, error) { 333 334 // Create the hash and use this for the signature. This is a SHA256 hash 335 // of the token. 336 h, err := hash500(buf, nil) 337 if err != nil { 338 return nil, logError(ErrTokenHashFailed, err.Error()) 339 } 340 341 // Sign the hash with the private key using the ECDSA algorithm 342 // and properly format the resulting signature. 343 r, s, err := ecdsa.Sign(rand.Reader, key, h) 344 if err != nil { 345 return nil, logError(ErrTokenSignFailed, err.Error()) 346 } 347 348 curveBits := key.Curve.Params().BitSize 349 keyBytes := curveBits / 8 350 if curveBits%8 > 0 { 351 keyBytes++ 352 } 353 354 // We serialize the outpus (r and s) into big-endian byte arrays and pad 355 // them with zeros on the left to make sure the sizes work out. Both arrays 356 // must be keyBytes long, and the output must be 2*keyBytes long. 357 tokenBytes := make([]byte, 2*keyBytes) 358 359 rBytes := r.Bytes() 360 copy(tokenBytes[keyBytes-len(rBytes):], rBytes) 361 362 sBytes := s.Bytes() 363 copy(tokenBytes[2*keyBytes-len(sBytes):], sBytes) 364 365 return tokenBytes, nil 366 } 367 368 func (c *BinaryJWTConfig) verify(buf []byte, sig []byte, key *ecdsa.PublicKey) error { 369 370 if len(sig) != 64 { 371 return ErrInvalidSignature 372 } 373 374 r := big.NewInt(0).SetBytes(sig[:32]) 375 s := big.NewInt(0).SetBytes(sig[32:]) 376 377 // Create the hash and use this for the signature. This is a SHA256 hash 378 // of the token. 379 h, err := hash500(buf, nil) 380 if err != nil { 381 return logError(ErrTokenHashFailed, err.Error()) 382 } 383 384 if verifyStatus := ecdsa.Verify(key, h, r, s); verifyStatus { 385 return nil 386 } 387 388 return ErrInvalidSignature 389 } 390 391 func (c *BinaryJWTConfig) getSecretKey(privateKey *ephemeralkeys.PrivateKey, remotePublicKeyString string, isV1Proto bool) ([]byte, error) { 392 393 var remotePublicKey *ecdsa.PublicKey 394 var err error 395 396 hashKey := privateKey.PrivateKeyString + remotePublicKeyString 397 398 secretKey := c.sharedKeys.Get(hashKey) 399 400 if secretKey != nil { 401 return secretKey, nil 402 } 403 404 if isV1Proto { 405 remotePublicKey, err = localcrypto.DecodePublicKeyV1([]byte(remotePublicKeyString)) 406 if err != nil { 407 return nil, err 408 } 409 } else { 410 remotePublicKey, err = localcrypto.DecodePublicKeyV2([]byte(remotePublicKeyString)) 411 if err != nil { 412 return nil, err 413 } 414 } 415 416 if secretKey, err = symmetricKey(privateKey.PrivateKey, remotePublicKey); err != nil { 417 return nil, err 418 } 419 420 c.sharedKeys.Put(hashKey, secretKey) 421 422 return secretKey, nil 423 } 424 425 func encode(c *BinaryJWTClaims, buf *[]byte) error { 426 // Encode and sign the token 427 if cap(*buf) != ClaimsEncodedBufSize { 428 return fmt.Errorf("Not enough space in byte slice") 429 } 430 431 var h codec.Handle = new(codec.CborHandle) 432 enc := codec.NewEncoderBytes(buf, h) 433 if err := enc.Encode(c); err != nil { 434 return fmt.Errorf("unable to encode message: %s", err) 435 } 436 437 return nil 438 } 439 440 func decode(buf []byte) (*BinaryJWTClaims, error) { 441 // Decode the token into a structure. 442 binaryClaims := &BinaryJWTClaims{} 443 var h codec.Handle = new(codec.CborHandle) 444 445 dec := codec.NewDecoderBytes(buf, h) 446 447 if err := dec.Decode(binaryClaims); err != nil { 448 return nil, logError(ErrTokenDecodeFailed, err.Error()) 449 } 450 451 if binaryClaims.ExpiresAt < time.Now().Unix() { 452 return nil, logError(ErrTokenExpired, fmt.Sprintf("token is expired since: %s", time.Unix(binaryClaims.ExpiresAt, 0))) 453 } 454 455 return binaryClaims, nil 456 } 457 458 func packToken(header, nonce, token, sig []byte) []byte { 459 460 binaryTokenPosition := binaryNoncePosition + len(nonce) 461 sigPosition := binaryTokenPosition + len(token) 462 463 // Token is the concatenation of 464 // [Position of Signature] [nonce] [token] [signature] 465 data := make([]byte, sigPosition+len(sig)) 466 467 // Header bytes 468 copy(data[0:headerLength], header) 469 // Length of token 470 binary.BigEndian.PutUint16(data[lengthPosition:], uint16(sigPosition)) 471 472 // nonce not required for ack packets 473 if len(nonce) > 0 { 474 copy(data[binaryNoncePosition:], nonce) 475 } 476 477 // token 478 copy(data[binaryTokenPosition:], token) 479 480 // signature 481 copy(data[sigPosition:], sig) 482 483 return data 484 } 485 486 // unpackToken returns nonce, token, signature or error if something fails 487 func unpackToken(isAck bool, data []byte) ([]byte, []byte, []byte, []byte, error) { 488 489 // We must have enough data to read the length. 490 if len(data) < binaryNoncePosition { 491 return nil, nil, nil, nil, ErrInvalidTokenLength 492 } 493 494 header := make([]byte, headerLength) 495 copy(header, data[:lengthPosition]) 496 497 sigPosition := int(binary.BigEndian.Uint16(data[lengthPosition : lengthPosition+2])) 498 // The token must be long enough to have at least 1 byte of signature. 499 if len(data) < sigPosition+1 || sigPosition == 0 { 500 return nil, nil, nil, nil, ErrMissingSignature 501 } 502 503 var nonce []byte 504 505 if !isAck { 506 nonce = make([]byte, 16) 507 copy(nonce, data[binaryNoncePosition:binaryNoncePosition+NonceLength]) 508 } 509 510 // Only if nonce is found do we need to advance. So, use the 511 // actual length of the nonce and not just a constant here. 512 token := data[binaryNoncePosition+len(nonce) : sigPosition] 513 514 sig := data[sigPosition:] 515 return header, nonce, token, sig, nil 516 } 517 518 // symmetricKey returns a symmetric key for encryption 519 func symmetricKey(privateKey *ecdsa.PrivateKey, remotePublic *ecdsa.PublicKey) ([]byte, error) { 520 521 c := elliptic.P256() 522 523 x, _ := c.ScalarMult(remotePublic.X, remotePublic.Y, privateKey.D.Bytes()) 524 525 return hash500(x.Bytes(), nil) 526 } 527 528 func uncompressTags(binaryClaims *BinaryJWTClaims, publicKeyClaims []string) { 529 530 binaryClaims.T = append(binaryClaims.CT, enforcerconstants.TransmitterLabel+"="+binaryClaims.ID) 531 532 for _, pc := range publicKeyClaims { 533 534 if len(pc) <= claimsheader.CompressedTagLengthV1 { 535 binaryClaims.T = append(binaryClaims.T, pc) 536 continue 537 } 538 539 binaryClaims.T = append(binaryClaims.T, pc[:claimsheader.CompressedTagLengthV1]) 540 } 541 }