github.com/aporeto-inc/trireme-lib@v10.358.0+incompatible/controller/pkg/tokens/binaryjwt.go (about)

     1  package tokens
     2  
     3  import (
     4  	"crypto/ecdsa"
     5  	"crypto/elliptic"
     6  	"crypto/rand"
     7  	"encoding/binary"
     8  	"fmt"
     9  	"math/big"
    10  	"sync"
    11  	"time"
    12  
    13  	"github.com/ugorji/go/codec"
    14  	enforcerconstants "go.aporeto.io/enforcerd/trireme-lib/controller/internal/enforcer/constants"
    15  	"go.aporeto.io/enforcerd/trireme-lib/controller/internal/enforcer/utils/ephemeralkeys"
    16  	"go.aporeto.io/enforcerd/trireme-lib/controller/pkg/claimsheader"
    17  	"go.aporeto.io/enforcerd/trireme-lib/controller/pkg/pkiverifier"
    18  	"go.aporeto.io/enforcerd/trireme-lib/controller/pkg/secrets"
    19  	"go.aporeto.io/enforcerd/trireme-lib/utils/cache"
    20  	localcrypto "go.aporeto.io/enforcerd/trireme-lib/utils/crypto"
    21  )
    22  
    23  // To generate the codecs,
    24  // codecgen -o binarycodec.go binaryjwtclaimtypes.go
    25  
    26  // Format of Binary Tokens
    27  //    0             1              2               3               4
    28  //  0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1
    29  //  +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
    30  //  |     D     |CT|E| Encoding |    R (reserved)                   |
    31  //  +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
    32  //  | Signature Position           |    nonce                       |
    33  //  +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
    34  //  |   ...                                                         |
    35  //  +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
    36  //  |   token                                                       |
    37  //  +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
    38  //  |   ...                                                         |
    39  //  +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
    40  //  | Signature                                                     |
    41  //  +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
    42  //  |   ...                                                         |
    43  //  +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
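//  D  [0:6]   - Datapath version
//  CT [6:8]   - Compressed tag type
//  E  [8:9]   - Encryption enabled
//  C  [9:12]  - Codec selector (labelled "Encoding" in the diagram)
//  R  [12:32] - Reserved
//  L  [32:48] - Signature position: big-endian uint16 byte offset of the
//               signature from the start of the token
//  Nonce bytes (NonceLength bytes; omitted in ACK tokens)
//  Token bytes (from the end of the nonce up to the signature position)
//  Signature bytes (from the signature position to the end of the token)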
    52  
    53  const (
    54  	binaryNoncePosition   = 6
    55  	lengthPosition        = 4
    56  	headerLength          = 4
    57  	sharedKeyCacheTimeout = 5 * time.Minute
    58  )
    59  
    60  //ClaimsEncodedBufSize is the size of maximum buffer that is required
    61  //for claims to be serialized into
    62  const ClaimsEncodedBufSize = 1400
    63  
    64  // AckPattern is added in SYN and ACK tokens.
    65  var AckPattern = []byte("PANWIDENTITY")
    66  var sha256KeyLength int = 32
    67  
    68  type sharedKeyStruct struct {
    69  	sharedKeys map[string][]byte
    70  	sync.RWMutex
    71  }
    72  
    73  func (s *sharedKeyStruct) Get(key string) []byte {
    74  
    75  	s.RLock()
    76  
    77  	if val, ok := s.sharedKeys[key]; ok {
    78  		s.RUnlock()
    79  		return val
    80  	}
    81  
    82  	s.RUnlock()
    83  	return nil
    84  }
    85  
    86  func (s *sharedKeyStruct) Put(key string, val []byte) {
    87  
    88  	s.Lock()
    89  	s.sharedKeys[key] = val
    90  	s.Unlock()
    91  
    92  	time.AfterFunc(sharedKeyCacheTimeout, func() {
    93  		s.Lock()
    94  		delete(s.sharedKeys, key)
    95  		s.Unlock()
    96  	})
    97  }
    98  
    99  // BinaryJWTConfig configures the JWT token generator with the standard parameters. One
   100  // configuration is assigned to each server
   101  type BinaryJWTConfig struct {
   102  	// ValidityPeriod  period of the JWT
   103  	ValidityPeriod time.Duration
   104  	// Issuer is the server that issues the JWT
   105  	Issuer string
   106  	// cache test
   107  	tokenCache cache.DataStore
   108  	// sharedKey is a cache of pre-shared keys.
   109  	sharedKeys *sharedKeyStruct
   110  }
   111  
   112  // NewBinaryJWT creates a new JWT token processor
   113  func NewBinaryJWT(validity time.Duration, issuer string) (*BinaryJWTConfig, error) {
   114  
   115  	return &BinaryJWTConfig{
   116  		ValidityPeriod: validity,
   117  		Issuer:         issuer,
   118  		tokenCache:     cache.NewCacheWithExpiration("JWTTokenCache", validity),
   119  		sharedKeys:     &sharedKeyStruct{sharedKeys: map[string][]byte{}},
   120  	}, nil
   121  }
   122  
   123  // DecodeSyn takes as argument the JWT token and the certificate of the issuer.
   124  // First it verifies the certificate with the local CA pool, and the decodes
   125  // the JWT if the certificate is trusted
   126  func (c *BinaryJWTConfig) DecodeSyn(isSynAck bool, data []byte, privateKey *ephemeralkeys.PrivateKey, secrets secrets.Secrets, connClaims *ConnectionClaims) ([]byte, *claimsheader.ClaimsHeader, []byte, *pkiverifier.PKIControllerInfo, bool, error) {
   127  	header, nonce, token, sig, err := unpackToken(false, data)
   128  	if err != nil {
   129  		return nil, nil, nil, nil, false, err
   130  	}
   131  	// Parse the claims header.
   132  	claimsHeader := claimsheader.HeaderBytes(header).ToClaimsHeader()
   133  
   134  	// Validate the header version.
   135  	if err := c.verifyClaimsHeader(claimsHeader); err != nil {
   136  		return nil, nil, nil, nil, false, err
   137  	}
   138  
   139  	// Decode the claims to a data structure.
   140  	binaryClaims, err := decode(token)
   141  	if err != nil {
   142  		return nil, nil, nil, nil, false, err
   143  	}
   144  
   145  	//Process 314 Protocol
   146  	if len(binaryClaims.DEK) == 0 {
   147  		secretKey, controller, err := c.process314Protocol(isSynAck, token, secrets, connClaims, binaryClaims, sig)
   148  		return secretKey, claimsHeader, nonce, controller, true, err
   149  	}
   150  
   151  	//Process 500 Protocol
   152  	secretKey, controller, err := c.process500Protocol(isSynAck, token, privateKey, secrets, connClaims, binaryClaims, sig)
   153  
   154  	return secretKey, claimsHeader, nonce, controller, false, err
   155  }
   156  
   157  // DecodeAck decodes the ack packet token
   158  func (c *BinaryJWTConfig) DecodeAck(proto314 bool, secretKey []byte, data []byte, connClaims *ConnectionClaims) error {
   159  	// Unpack the token first.
   160  	header, _, token, sig, err := unpackToken(true, data)
   161  	if err != nil {
   162  		return err
   163  	}
   164  
   165  	// Parse the claims header.
   166  	claimsHeader := claimsheader.HeaderBytes(header).ToClaimsHeader()
   167  
   168  	// Validate the header.
   169  	if err := c.verifyClaimsHeader(claimsHeader); err != nil {
   170  		return err
   171  	}
   172  
   173  	// Decode the claims to a data structure.
   174  	binaryClaims, err := decode(token)
   175  	if err != nil {
   176  		return err
   177  	}
   178  
   179  	if proto314 {
   180  		// Calculate the signature on the token and compare it with the incoming
   181  		// signature. Since this is simple symetric hashing this is simple.
   182  		if err := c.verifyWithSharedKey314(token, secretKey, sig); err != nil {
   183  			return err
   184  		}
   185  	} else {
   186  		if err := c.verifyWithSharedKey500(token, secretKey, sig[0:sha256KeyLength]); err != nil {
   187  			return err
   188  		}
   189  	}
   190  
   191  	CopyToConnectionClaims(binaryClaims, connClaims)
   192  	return nil
   193  }
   194  
   195  //CreateSynToken creates the token which is attached to the tcp syn packet.
   196  func (c *BinaryJWTConfig) CreateSynToken(claims *ConnectionClaims, encodedBuf []byte, nonce []byte, header *claimsheader.ClaimsHeader, secrets secrets.Secrets) ([]byte, error) {
   197  	// Set the appropriate claims header
   198  	header.SetCompressionType(claimsheader.CompressionTypeV1)
   199  	header.SetDatapathVersion(claimsheader.DatapathVersion1)
   200  
   201  	// Combine the application claims with the standard claims.
   202  	// In all cases for Syn/SynAck packets we also transmit our
   203  	// public key.
   204  	allclaims := ConvertToBinaryClaims(claims, c.ValidityPeriod)
   205  	allclaims.SignerKey = secrets.TransmittedKey()
   206  
   207  	// Encode the claims in a buffer.
   208  	err := encode(allclaims, &encodedBuf)
   209  	if err != nil {
   210  		return nil, logError(ErrTokenEncodeFailed, err.Error())
   211  	}
   212  
   213  	var sig []byte
   214  
   215  	encodedBuf = append(encodedBuf, AckPattern...)
   216  
   217  	sig, err = c.sign(encodedBuf, secrets.EncodingKey().(*ecdsa.PrivateKey))
   218  
   219  	if err != nil {
   220  		return nil, err
   221  	}
   222  
   223  	// Pack and return the token.
   224  	return packToken(header.ToBytes(), nonce, encodedBuf, sig), nil
   225  }
   226  
   227  //CreateSynAckToken creates syn/ack token which is attached to the syn/ack packet.
   228  func (c *BinaryJWTConfig) CreateSynAckToken(proto314 bool, claims *ConnectionClaims, encodedBuf []byte, nonce []byte, header *claimsheader.ClaimsHeader, secrets secrets.Secrets, secretKey []byte) ([]byte, error) {
   229  
   230  	// Set the appropriate claims header
   231  	header.SetCompressionType(claimsheader.CompressionTypeV1)
   232  	header.SetDatapathVersion(claimsheader.DatapathVersion1)
   233  
   234  	// Combine the application claims with the standard claims.
   235  	// In all cases for Syn/SynAck packets we also transmit our
   236  	// public key.
   237  	allclaims := ConvertToBinaryClaims(claims, c.ValidityPeriod)
   238  	allclaims.SignerKey = secrets.TransmittedKey()
   239  
   240  	// Encode the claims in a buffer.
   241  	err := encode(allclaims, &encodedBuf)
   242  	if err != nil {
   243  		return nil, logError(ErrTokenEncodeFailed, err.Error())
   244  	}
   245  
   246  	var sig []byte
   247  
   248  	encodedBuf = append(encodedBuf, AckPattern...)
   249  
   250  	if proto314 {
   251  		sig, err = hash314(encodedBuf, secretKey)
   252  		if err != nil {
   253  			return nil, err
   254  		}
   255  	} else {
   256  		sig, err = hash500(encodedBuf, secretKey)
   257  		if err != nil {
   258  			return nil, err
   259  		}
   260  	}
   261  
   262  	// Pack and return the token.
   263  	return packToken(header.ToBytes(), nonce, encodedBuf, sig), nil
   264  }
   265  
   266  // Randomize puts the random nonce in the syn token
   267  func (c *BinaryJWTConfig) Randomize(token []byte, nonce []byte) error {
   268  
   269  	if len(token) < 6+NonceLength {
   270  		return logError(ErrTokenTooSmall, "token is too small")
   271  	}
   272  
   273  	copy(token[6:], nonce)
   274  
   275  	return nil
   276  }
   277  
   278  //CreateAckToken creates ack token which is attached to the ack packet.
   279  func (c *BinaryJWTConfig) CreateAckToken(proto314 bool, secretKey []byte, claims *ConnectionClaims, encodedBuf []byte, header *claimsheader.ClaimsHeader) ([]byte, error) {
   280  
   281  	var pad []byte
   282  	// Combine the application claims with the standard claims
   283  	allclaims := ConvertToBinaryClaims(claims, c.ValidityPeriod)
   284  
   285  	// Encode the claims in a buffer.
   286  	err := encode(allclaims, &encodedBuf)
   287  	if err != nil {
   288  		return nil, logError(ErrTokenEncodeFailed, err.Error())
   289  	}
   290  	encodedBuf = append(encodedBuf, AckPattern...)
   291  
   292  	var sig []byte
   293  	// Sign the buffer with the pre-shared key.
   294  	if proto314 {
   295  		sig, err = hash314(encodedBuf, secretKey)
   296  		if err != nil {
   297  			return nil, err
   298  		}
   299  		pad = sig
   300  	} else {
   301  		pad = make([]byte, 64)
   302  		sig, err = hash500(encodedBuf, secretKey)
   303  		if err != nil {
   304  			return nil, err
   305  		}
   306  		copy(pad, sig)
   307  	}
   308  
   309  	// Pack and return the token.
   310  	return packToken(header.ToBytes(), nil, encodedBuf, pad), nil
   311  }
   312  
   313  func (c *BinaryJWTConfig) verifyClaimsHeader(h *claimsheader.ClaimsHeader) error {
   314  
   315  	if h.CompressionType() != claimsheader.CompressionTypeV1 {
   316  		return ErrCompressedTagMismatch
   317  
   318  	}
   319  
   320  	if h.DatapathVersion() != claimsheader.DatapathVersion1 {
   321  		return ErrDatapathVersionMismatch
   322  	}
   323  
   324  	return nil
   325  }
   326  
   327  // Sign takes in a slice of bytes and a private key, and returns a ecdsa signature.
   328  func (c *BinaryJWTConfig) Sign(buf []byte, key *ecdsa.PrivateKey) ([]byte, error) {
   329  	return c.sign(buf, key)
   330  }
   331  
   332  func (c *BinaryJWTConfig) sign(buf []byte, key *ecdsa.PrivateKey) ([]byte, error) {
   333  
   334  	// Create the hash and use this for the signature. This is a SHA256 hash
   335  	// of the token.
   336  	h, err := hash500(buf, nil)
   337  	if err != nil {
   338  		return nil, logError(ErrTokenHashFailed, err.Error())
   339  	}
   340  
   341  	// Sign the hash with the private key using the ECDSA algorithm
   342  	// and properly format the resulting signature.
   343  	r, s, err := ecdsa.Sign(rand.Reader, key, h)
   344  	if err != nil {
   345  		return nil, logError(ErrTokenSignFailed, err.Error())
   346  	}
   347  
   348  	curveBits := key.Curve.Params().BitSize
   349  	keyBytes := curveBits / 8
   350  	if curveBits%8 > 0 {
   351  		keyBytes++
   352  	}
   353  
   354  	// We serialize the outpus (r and s) into big-endian byte arrays and pad
   355  	// them with zeros on the left to make sure the sizes work out. Both arrays
   356  	// must be keyBytes long, and the output must be 2*keyBytes long.
   357  	tokenBytes := make([]byte, 2*keyBytes)
   358  
   359  	rBytes := r.Bytes()
   360  	copy(tokenBytes[keyBytes-len(rBytes):], rBytes)
   361  
   362  	sBytes := s.Bytes()
   363  	copy(tokenBytes[2*keyBytes-len(sBytes):], sBytes)
   364  
   365  	return tokenBytes, nil
   366  }
   367  
   368  func (c *BinaryJWTConfig) verify(buf []byte, sig []byte, key *ecdsa.PublicKey) error {
   369  
   370  	if len(sig) != 64 {
   371  		return ErrInvalidSignature
   372  	}
   373  
   374  	r := big.NewInt(0).SetBytes(sig[:32])
   375  	s := big.NewInt(0).SetBytes(sig[32:])
   376  
   377  	// Create the hash and use this for the signature. This is a SHA256 hash
   378  	// of the token.
   379  	h, err := hash500(buf, nil)
   380  	if err != nil {
   381  		return logError(ErrTokenHashFailed, err.Error())
   382  	}
   383  
   384  	if verifyStatus := ecdsa.Verify(key, h, r, s); verifyStatus {
   385  		return nil
   386  	}
   387  
   388  	return ErrInvalidSignature
   389  }
   390  
   391  func (c *BinaryJWTConfig) getSecretKey(privateKey *ephemeralkeys.PrivateKey, remotePublicKeyString string, isV1Proto bool) ([]byte, error) {
   392  
   393  	var remotePublicKey *ecdsa.PublicKey
   394  	var err error
   395  
   396  	hashKey := privateKey.PrivateKeyString + remotePublicKeyString
   397  
   398  	secretKey := c.sharedKeys.Get(hashKey)
   399  
   400  	if secretKey != nil {
   401  		return secretKey, nil
   402  	}
   403  
   404  	if isV1Proto {
   405  		remotePublicKey, err = localcrypto.DecodePublicKeyV1([]byte(remotePublicKeyString))
   406  		if err != nil {
   407  			return nil, err
   408  		}
   409  	} else {
   410  		remotePublicKey, err = localcrypto.DecodePublicKeyV2([]byte(remotePublicKeyString))
   411  		if err != nil {
   412  			return nil, err
   413  		}
   414  	}
   415  
   416  	if secretKey, err = symmetricKey(privateKey.PrivateKey, remotePublicKey); err != nil {
   417  		return nil, err
   418  	}
   419  
   420  	c.sharedKeys.Put(hashKey, secretKey)
   421  
   422  	return secretKey, nil
   423  }
   424  
   425  func encode(c *BinaryJWTClaims, buf *[]byte) error {
   426  	// Encode and sign the token
   427  	if cap(*buf) != ClaimsEncodedBufSize {
   428  		return fmt.Errorf("Not enough space in byte slice")
   429  	}
   430  
   431  	var h codec.Handle = new(codec.CborHandle)
   432  	enc := codec.NewEncoderBytes(buf, h)
   433  	if err := enc.Encode(c); err != nil {
   434  		return fmt.Errorf("unable to encode message: %s", err)
   435  	}
   436  
   437  	return nil
   438  }
   439  
   440  func decode(buf []byte) (*BinaryJWTClaims, error) {
   441  	// Decode the token into a structure.
   442  	binaryClaims := &BinaryJWTClaims{}
   443  	var h codec.Handle = new(codec.CborHandle)
   444  
   445  	dec := codec.NewDecoderBytes(buf, h)
   446  
   447  	if err := dec.Decode(binaryClaims); err != nil {
   448  		return nil, logError(ErrTokenDecodeFailed, err.Error())
   449  	}
   450  
   451  	if binaryClaims.ExpiresAt < time.Now().Unix() {
   452  		return nil, logError(ErrTokenExpired, fmt.Sprintf("token is expired since: %s", time.Unix(binaryClaims.ExpiresAt, 0)))
   453  	}
   454  
   455  	return binaryClaims, nil
   456  }
   457  
   458  func packToken(header, nonce, token, sig []byte) []byte {
   459  
   460  	binaryTokenPosition := binaryNoncePosition + len(nonce)
   461  	sigPosition := binaryTokenPosition + len(token)
   462  
   463  	// Token is the concatenation of
   464  	// [Position of Signature] [nonce] [token] [signature]
   465  	data := make([]byte, sigPosition+len(sig))
   466  
   467  	// Header bytes
   468  	copy(data[0:headerLength], header)
   469  	// Length of token
   470  	binary.BigEndian.PutUint16(data[lengthPosition:], uint16(sigPosition))
   471  
   472  	// nonce not required for ack packets
   473  	if len(nonce) > 0 {
   474  		copy(data[binaryNoncePosition:], nonce)
   475  	}
   476  
   477  	// token
   478  	copy(data[binaryTokenPosition:], token)
   479  
   480  	// signature
   481  	copy(data[sigPosition:], sig)
   482  
   483  	return data
   484  }
   485  
   486  // unpackToken returns nonce, token, signature or error if something fails
   487  func unpackToken(isAck bool, data []byte) ([]byte, []byte, []byte, []byte, error) {
   488  
   489  	// We must have enough data to read the length.
   490  	if len(data) < binaryNoncePosition {
   491  		return nil, nil, nil, nil, ErrInvalidTokenLength
   492  	}
   493  
   494  	header := make([]byte, headerLength)
   495  	copy(header, data[:lengthPosition])
   496  
   497  	sigPosition := int(binary.BigEndian.Uint16(data[lengthPosition : lengthPosition+2]))
   498  	// The token must be long enough to have at least 1 byte of signature.
   499  	if len(data) < sigPosition+1 || sigPosition == 0 {
   500  		return nil, nil, nil, nil, ErrMissingSignature
   501  	}
   502  
   503  	var nonce []byte
   504  
   505  	if !isAck {
   506  		nonce = make([]byte, 16)
   507  		copy(nonce, data[binaryNoncePosition:binaryNoncePosition+NonceLength])
   508  	}
   509  
   510  	// Only if nonce is found do we need to advance. So, use the
   511  	// actual length of the nonce and not just a constant here.
   512  	token := data[binaryNoncePosition+len(nonce) : sigPosition]
   513  
   514  	sig := data[sigPosition:]
   515  	return header, nonce, token, sig, nil
   516  }
   517  
   518  // symmetricKey returns a symmetric key for encryption
   519  func symmetricKey(privateKey *ecdsa.PrivateKey, remotePublic *ecdsa.PublicKey) ([]byte, error) {
   520  
   521  	c := elliptic.P256()
   522  
   523  	x, _ := c.ScalarMult(remotePublic.X, remotePublic.Y, privateKey.D.Bytes())
   524  
   525  	return hash500(x.Bytes(), nil)
   526  }
   527  
   528  func uncompressTags(binaryClaims *BinaryJWTClaims, publicKeyClaims []string) {
   529  
   530  	binaryClaims.T = append(binaryClaims.CT, enforcerconstants.TransmitterLabel+"="+binaryClaims.ID)
   531  
   532  	for _, pc := range publicKeyClaims {
   533  
   534  		if len(pc) <= claimsheader.CompressedTagLengthV1 {
   535  			binaryClaims.T = append(binaryClaims.T, pc)
   536  			continue
   537  		}
   538  
   539  		binaryClaims.T = append(binaryClaims.T, pc[:claimsheader.CompressedTagLengthV1])
   540  	}
   541  }