github.com/aporeto-inc/trireme-lib@v10.358.0+incompatible/controller/internal/enforcer/nfqdatapath/tokenaccessor/tokenaccessor.go (about)

     1  package tokenaccessor
     2  
     3  import (
     4  	"bytes"
     5  	"crypto/ecdsa"
     6  	"errors"
     7  	"fmt"
     8  	"time"
     9  
    10  	enforcerconstants "go.aporeto.io/enforcerd/trireme-lib/controller/internal/enforcer/constants"
    11  	"go.aporeto.io/enforcerd/trireme-lib/controller/internal/enforcer/utils/ephemeralkeys"
    12  	"go.aporeto.io/enforcerd/trireme-lib/controller/pkg/claimsheader"
    13  	"go.aporeto.io/enforcerd/trireme-lib/controller/pkg/pkiverifier"
    14  	"go.aporeto.io/enforcerd/trireme-lib/controller/pkg/secrets"
    15  	"go.aporeto.io/enforcerd/trireme-lib/controller/pkg/tokens"
    16  )
    17  
    18  // tokenAccessor is a wrapper around tokenEngine to provide locks for accessing
    19  type tokenAccessor struct {
    20  	tokens   tokens.TokenEngine
    21  	serverID string
    22  	validity time.Duration
    23  }
    24  
    25  // New creates a new instance of TokenAccessor interface
    26  func New(serverID string, validity time.Duration, secret secrets.Secrets) (TokenAccessor, error) {
    27  
    28  	var tokenEngine tokens.TokenEngine
    29  	var err error
    30  
    31  	tokenEngine, err = tokens.NewBinaryJWT(validity, serverID)
    32  	if err != nil {
    33  		return nil, err
    34  	}
    35  
    36  	return &tokenAccessor{
    37  		tokens:   tokenEngine,
    38  		serverID: serverID,
    39  		validity: validity,
    40  	}, nil
    41  }
    42  
    43  // GetTokenValidity returns the duration the token is valid for
    44  func (t *tokenAccessor) GetTokenValidity() time.Duration {
    45  	return t.validity
    46  }
    47  
    48  // GetTokenServerID returns the server ID which is used the generate the token.
    49  func (t *tokenAccessor) GetTokenServerID() string {
    50  	return t.serverID
    51  }
    52  
    53  // CreateAckPacketToken creates the authentication token
    54  func (t *tokenAccessor) CreateAckPacketToken(proto314 bool, secretKey []byte, claims *tokens.ConnectionClaims, encodedBuf []byte) ([]byte, error) {
    55  
    56  	token, err := t.tokens.CreateAckToken(proto314, secretKey, claims, encodedBuf, claimsheader.NewClaimsHeader())
    57  	if err != nil {
    58  		return nil, fmt.Errorf("unable to create ack token: %v", err)
    59  	}
    60  
    61  	return token, nil
    62  }
    63  
    64  func (t *tokenAccessor) Randomize(token []byte, nonce []byte) error {
    65  	return t.tokens.Randomize(token, nonce)
    66  }
    67  
    68  func (t *tokenAccessor) Sign(buf []byte, key *ecdsa.PrivateKey) ([]byte, error) {
    69  	return t.tokens.Sign(buf, key)
    70  }
    71  
    72  // createSynPacketToken creates the authentication token
    73  func (t *tokenAccessor) CreateSynPacketToken(claims *tokens.ConnectionClaims, encodedBuf []byte, nonce []byte, claimsHeader *claimsheader.ClaimsHeader, secrets secrets.Secrets) ([]byte, error) {
    74  	token, err := t.tokens.CreateSynToken(claims, encodedBuf, nonce, claimsHeader, secrets)
    75  	if err != nil {
    76  		return nil, fmt.Errorf("unable to create syn token: %v", err)
    77  	}
    78  
    79  	return token, nil
    80  }
    81  
    82  // createSynAckPacketToken  creates the authentication token for SynAck packets
    83  // We need to sign the received token. No caching possible here
    84  func (t *tokenAccessor) CreateSynAckPacketToken(proto314 bool, claims *tokens.ConnectionClaims, encodedBuf []byte, nonce []byte, claimsHeader *claimsheader.ClaimsHeader, secrets secrets.Secrets, secretKey []byte) ([]byte, error) {
    85  	token, err := t.tokens.CreateSynAckToken(proto314, claims, encodedBuf, nonce, claimsHeader, secrets, secretKey)
    86  	if err != nil {
    87  		return nil, fmt.Errorf("unable to create synack token: %v", err)
    88  	}
    89  
    90  	return token, nil
    91  }
    92  
    93  // parsePacketToken parses the packet token and populates the right state.
    94  // Returns an error if the token cannot be parsed or the signature fails
    95  func (t *tokenAccessor) ParsePacketToken(privateKey *ephemeralkeys.PrivateKey, data []byte, secrets secrets.Secrets, claims *tokens.ConnectionClaims, isSynAck bool) ([]byte, *claimsheader.ClaimsHeader, *pkiverifier.PKIControllerInfo, []byte, string, bool, error) {
    96  
    97  	// Validate the certificate and parse the token
    98  	secretKey, header, nonce, controller, proto314, err := t.tokens.DecodeSyn(isSynAck, data, privateKey, secrets, claims)
    99  	if err != nil {
   100  		return nil, nil, nil, nil, "", false, err
   101  	}
   102  
   103  	// We always a need a valid remote context ID
   104  	if claims.T == nil {
   105  		return nil, nil, nil, nil, "", false, errors.New("no claims found")
   106  	}
   107  
   108  	remoteContextID, ok := claims.T.Get(enforcerconstants.TransmitterLabel)
   109  	if !ok {
   110  		return nil, nil, nil, nil, "", false, errors.New("no transmitter label")
   111  	}
   112  
   113  	return secretKey, header, controller, nonce, remoteContextID, proto314, nil
   114  }
   115  
   116  // parseAckToken parses the tokens in Ack packets. They don't carry all the state context
   117  // and it needs to be recovered
   118  func (t *tokenAccessor) ParseAckToken(proto314 bool, secretKey []byte, nonce []byte, data []byte, connClaims *tokens.ConnectionClaims) error {
   119  
   120  	// Validate the certificate and parse the token
   121  	if err := t.tokens.DecodeAck(proto314, secretKey, data, connClaims); err != nil {
   122  		return err
   123  	}
   124  
   125  	if !bytes.Equal(connClaims.RMT, nonce) {
   126  		return errors.New("failed to match context in ack packet")
   127  	}
   128  
   129  	return nil
   130  }