github.com/cilium/cilium@v1.16.2/pkg/auth/spire/delegate.go (about)

     1  // SPDX-License-Identifier: Apache-2.0
     2  // Copyright Authors of Cilium
     3  
     4  package spire
     5  
     6  import (
     7  	"bytes"
     8  	"context"
     9  	"crypto/x509"
    10  	"errors"
    11  	"fmt"
    12  	"os"
    13  
    14  	"github.com/cilium/hive/cell"
    15  	"github.com/sirupsen/logrus"
    16  	"github.com/spf13/pflag"
    17  	delegatedidentityv1 "github.com/spiffe/spire-api-sdk/proto/spire/api/agent/delegatedidentity/v1"
    18  	spiffeTypes "github.com/spiffe/spire-api-sdk/proto/spire/api/types"
    19  	"google.golang.org/grpc"
    20  	"google.golang.org/grpc/credentials/insecure"
    21  
    22  	"github.com/cilium/cilium/pkg/auth/certs"
    23  	"github.com/cilium/cilium/pkg/backoff"
    24  	"github.com/cilium/cilium/pkg/lock"
    25  	"github.com/cilium/cilium/pkg/logging"
    26  	"github.com/cilium/cilium/pkg/time"
    27  )
    28  
// SpireDelegateClient subscribes to X509-SVID and X509 bundle updates through
// the SPIRE Delegated Identity API and keeps the most recent state in memory.
type SpireDelegateClient struct {
	// cfg holds the socket path, trust domain and rotation queue size.
	cfg SpireDelegateConfig
	log logrus.FieldLogger

	// connectionAttempts is incremented after every failed connection attempt
	// and after every stream error; it is the attempt number handed to the
	// exponential backoff.
	connectionAttempts int

	// stream and trustStream are the currently open subscriptions. They are
	// (re)assigned by openStream.
	stream      delegatedidentityv1.DelegatedIdentity_SubscribeToX509SVIDsClient
	trustStream delegatedidentityv1.DelegatedIdentity_SubscribeToX509BundlesClient

	// svidStore maps "spiffe://<trust domain><path>" to the latest SVID
	// received for that ID. It is replaced wholesale on every SVID update and
	// is guarded by svidStoreMutex.
	svidStore      map[string]*delegatedidentityv1.X509SVIDWithKey
	svidStoreMutex lock.RWMutex
	// trustBundle is rebuilt from scratch on every bundle update.
	trustBundle    *x509.CertPool

	// cancelListenForUpdates is set by onStart and invoked by onStop.
	cancelListenForUpdates context.CancelFunc

	// rotatedIdentitiesChan receives an event for every SVID that changed or
	// disappeared in an update. Sends are non-blocking: events are dropped
	// when the channel is full.
	rotatedIdentitiesChan chan certs.CertificateRotationEvent

	// logLimiter rate-limits the "channel is full" warnings.
	logLimiter logging.Limiter

	// connected and lastConnectError describe the outcome of the most recent
	// connection attempt; both are guarded by connectedMutex.
	connected        bool
	lastConnectError error
	connectedMutex   lock.RWMutex
}
    52  
// SpireDelegateConfig is the user-facing configuration of the SPIRE delegate
// client. The flags and their defaults are registered in Flags.
type SpireDelegateConfig struct {
	// SpireAdminSocketPath is the path of the SPIRE agent admin Unix socket.
	// An empty value disables the client.
	SpireAdminSocketPath string `mapstructure:"mesh-auth-spire-admin-socket"`
	// SpiffeTrustDomain is the trust domain that received SVIDs are compared
	// against.
	SpiffeTrustDomain    string `mapstructure:"mesh-auth-spiffe-trust-domain"`
	// RotatedQueueSize is the buffer size of the rotated-identities channel.
	RotatedQueueSize     int    `mapstructure:"mesh-auth-rotated-identities-queue-size"`
}
    58  
// Cell is the "spire-delegate" module. It registers SpireDelegateConfig as
// configuration and provides the certificate provider built by
// newSpireDelegateClient.
var Cell = cell.Module(
	"spire-delegate",
	"Spire Delegate API Client",
	cell.Provide(newSpireDelegateClient),
	cell.Config(SpireDelegateConfig{}),
)
    65  
    66  func newSpireDelegateClient(lc cell.Lifecycle, cfg SpireDelegateConfig, log logrus.FieldLogger) certs.CertificateProvider {
    67  	if cfg.SpireAdminSocketPath == "" {
    68  		log.Info("Spire Delegate API Client is disabled as no socket path is configured")
    69  		return nil
    70  	}
    71  	client := &SpireDelegateClient{
    72  		cfg:                   cfg,
    73  		log:                   log,
    74  		svidStore:             map[string]*delegatedidentityv1.X509SVIDWithKey{},
    75  		rotatedIdentitiesChan: make(chan certs.CertificateRotationEvent, cfg.RotatedQueueSize),
    76  		logLimiter:            logging.NewLimiter(10*time.Second, 3),
    77  	}
    78  
    79  	lc.Append(cell.Hook{OnStart: client.onStart, OnStop: client.onStop})
    80  
    81  	return client
    82  }
    83  
    84  func (cfg SpireDelegateConfig) Flags(flags *pflag.FlagSet) {
    85  	flags.StringVar(&cfg.SpireAdminSocketPath, "mesh-auth-spire-admin-socket", "", "The path for the SPIRE admin agent Unix socket.") // default is /run/spire/sockets/admin.sock
    86  	flags.StringVar(&cfg.SpiffeTrustDomain, "mesh-auth-spiffe-trust-domain", "spiffe.cilium", "The trust domain for the SPIFFE identity.")
    87  	flags.IntVar(&cfg.RotatedQueueSize, "mesh-auth-rotated-identities-queue-size", 1024, "The size of the queue for signaling rotated identities.")
    88  }
    89  
    90  func (s *SpireDelegateClient) onStart(ctx cell.HookContext) error {
    91  	s.log.Info("Spire Delegate API Client is running")
    92  
    93  	listenCtx, cancel := context.WithCancel(context.Background())
    94  	go s.listenForUpdates(listenCtx)
    95  
    96  	s.cancelListenForUpdates = cancel
    97  
    98  	return nil
    99  }
   100  
   101  func (s *SpireDelegateClient) onStop(ctx cell.HookContext) error {
   102  	s.log.Info("SPIFFE Delegate API Client is stopping")
   103  
   104  	s.cancelListenForUpdates()
   105  
   106  	if s.stream != nil {
   107  		s.stream.CloseSend()
   108  	}
   109  
   110  	return nil
   111  }
   112  
   113  func (s *SpireDelegateClient) listenForUpdates(ctx context.Context) {
   114  	s.openStream(ctx)
   115  
   116  	listenCtx, cancel := context.WithCancel(ctx)
   117  	err := make(chan error)
   118  
   119  	go s.listenForSVIDUpdates(listenCtx, err)
   120  	go s.listenForBundleUpdates(listenCtx, err)
   121  
   122  	backoffTime := backoff.Exponential{Min: 100 * time.Millisecond, Max: 10 * time.Second}
   123  	for {
   124  		select {
   125  		case <-ctx.Done():
   126  			cancel()
   127  			return
   128  		case e := <-err:
   129  			s.log.WithError(e).Error("Error in delegate stream, restarting")
   130  			time.Sleep(backoffTime.Duration(s.connectionAttempts))
   131  			cancel()
   132  			s.connectionAttempts++
   133  			s.listenForUpdates(ctx)
   134  			return
   135  		}
   136  	}
   137  }
   138  
   139  func (s *SpireDelegateClient) listenForSVIDUpdates(ctx context.Context, errorChan chan error) {
   140  	for {
   141  		select {
   142  		case <-ctx.Done():
   143  			return
   144  		default:
   145  			resp, err := s.stream.Recv()
   146  			if err != nil {
   147  				errorChan <- err
   148  				return
   149  			}
   150  
   151  			s.log.
   152  				WithField("nr_of_svids", len(resp.X509Svids)).
   153  				Debug("Received X509-SVID update")
   154  			s.handleX509SVIDUpdate(resp.X509Svids)
   155  		}
   156  	}
   157  }
   158  
   159  func (s *SpireDelegateClient) listenForBundleUpdates(ctx context.Context, errorChan chan error) {
   160  	for {
   161  		select {
   162  		case <-ctx.Done():
   163  			return
   164  		default:
   165  			resp, err := s.trustStream.Recv()
   166  			if err != nil {
   167  				errorChan <- err
   168  				return
   169  			}
   170  
   171  			s.log.
   172  				WithField("nr_of_bundles", len(resp.CaCertificates)).
   173  				Debug("Received X509-Bundle update", len(resp.CaCertificates))
   174  			s.handleX509BundleUpdate(resp.CaCertificates)
   175  		}
   176  	}
   177  }
   178  
   179  func (s *SpireDelegateClient) handleX509SVIDUpdate(svids []*delegatedidentityv1.X509SVIDWithKey) {
   180  	newSvidStore := map[string]*delegatedidentityv1.X509SVIDWithKey{}
   181  
   182  	s.svidStoreMutex.RLock()
   183  	updatedKeys := []string{}
   184  	deletedKeys := []string{}
   185  
   186  	for _, svid := range svids {
   187  
   188  		if svid.X509Svid.Id.TrustDomain != s.cfg.SpiffeTrustDomain {
   189  			s.log.
   190  				WithField("trust_domain", svid.X509Svid.Id.TrustDomain).
   191  				Debug("Skipping X509-SVID update as it does not match ours")
   192  			s.svidStoreMutex.RUnlock()
   193  			return
   194  		}
   195  
   196  		key := fmt.Sprintf("spiffe://%s%s", svid.X509Svid.Id.TrustDomain, svid.X509Svid.Id.Path)
   197  
   198  		if _, exists := s.svidStore[key]; exists {
   199  			old := s.svidStore[key]
   200  			if old.X509Svid.ExpiresAt != svid.X509Svid.ExpiresAt || !equalCertChains(old.X509Svid.CertChain, svid.X509Svid.CertChain) {
   201  				updatedKeys = append(updatedKeys, key)
   202  			}
   203  		} else {
   204  			s.log.
   205  				WithField("spiffe_id", key).
   206  				Debug("Adding newly discovered X509-SVID")
   207  		}
   208  		newSvidStore[key] = svid
   209  
   210  	}
   211  
   212  	// check for deleted keys
   213  	for key := range s.svidStore {
   214  		if _, exists := newSvidStore[key]; !exists {
   215  			deletedKeys = append(deletedKeys, key)
   216  		}
   217  	}
   218  
   219  	s.svidStoreMutex.RUnlock()
   220  
   221  	s.svidStoreMutex.Lock()
   222  	s.svidStore = newSvidStore
   223  	s.svidStoreMutex.Unlock()
   224  
   225  	for _, key := range deletedKeys {
   226  		// we send an update event to re-trigger a handshake if needed
   227  		id, err := s.spiffeIDToNumericIdentity(key)
   228  		if err != nil {
   229  			s.log.
   230  				WithError(err).
   231  				WithField("spiffe_id", key).
   232  				Error("Failed to convert SPIFFE ID to numeric identity")
   233  			continue
   234  		}
   235  		select {
   236  		case s.rotatedIdentitiesChan <- certs.CertificateRotationEvent{Identity: id, Deleted: true}:
   237  			s.log.
   238  				WithField("spiffe_id", key).
   239  				Debug("X509-SVID has been deleted, signaling this")
   240  		default:
   241  			if s.logLimiter.Allow() {
   242  				s.log.
   243  					WithField("identity", id).
   244  					Warn("Skip sending deleted identity as channel is full")
   245  			}
   246  		}
   247  	}
   248  
   249  	for _, key := range updatedKeys {
   250  		// we send an update event to re-trigger a handshake if needed
   251  		id, err := s.spiffeIDToNumericIdentity(key)
   252  		if err != nil {
   253  			s.log.
   254  				WithError(err).
   255  				WithField("spiffe_id", key).
   256  				Error("Failed to convert SPIFFE ID to numeric identity")
   257  			continue
   258  		}
   259  		select {
   260  		case s.rotatedIdentitiesChan <- certs.CertificateRotationEvent{Identity: id}:
   261  			s.log.
   262  				WithField("spiffe_id", key).
   263  				Debug("X509-SVID has changed, signaling this")
   264  		default:
   265  			if s.logLimiter.Allow() {
   266  				s.log.
   267  					WithField("identity", id).
   268  					Warn("Skip sending rotated identity as channel is full")
   269  			}
   270  		}
   271  	}
   272  }
   273  
   274  func (s *SpireDelegateClient) handleX509BundleUpdate(bundles map[string][]byte) {
   275  	pool := x509.NewCertPool()
   276  
   277  	for trustDomain, bundle := range bundles {
   278  		s.log.
   279  			WithField("trust_domain", trustDomain).
   280  			Debug("Processing trust domain cert bundle", trustDomain)
   281  
   282  		certs, err := x509.ParseCertificates(bundle)
   283  		if err != nil {
   284  			s.log.
   285  				WithError(err).
   286  				WithField("trust_domain", trustDomain).
   287  				Error("Failed to parse X.509 DER bundle")
   288  			continue
   289  		}
   290  
   291  		for _, cert := range certs {
   292  			pool.AddCert(cert)
   293  		}
   294  	}
   295  
   296  	s.trustBundle = pool
   297  }
   298  
   299  func (s *SpireDelegateClient) openStream(ctx context.Context) {
   300  	// try to init the watcher with a backoff
   301  	backoffTime := backoff.Exponential{Min: 100 * time.Millisecond, Max: 10 * time.Second}
   302  
   303  	// a retry might have happened, signal that we are disconnected
   304  	s.connectedMutex.Lock()
   305  	s.connected = false
   306  	s.connectedMutex.Unlock()
   307  
   308  	for {
   309  		s.log.Info("Connecting to SPIRE Delegate API Client")
   310  
   311  		var err error
   312  		s.stream, s.trustStream, err = s.initWatcher(ctx)
   313  		if err != nil {
   314  			s.log.WithError(err).Warn("SPIRE Delegate API Client failed to init watcher, retrying")
   315  
   316  			s.connectedMutex.Lock()
   317  			s.connected = false
   318  			s.lastConnectError = err
   319  			s.connectedMutex.Unlock()
   320  
   321  			time.Sleep(backoffTime.Duration(s.connectionAttempts))
   322  			s.connectionAttempts++
   323  			continue
   324  		}
   325  
   326  		s.connectedMutex.Lock()
   327  		s.connected = true
   328  		s.lastConnectError = nil
   329  		s.connectedMutex.Unlock()
   330  		break
   331  	}
   332  }
   333  
   334  func (s *SpireDelegateClient) initWatcher(ctx context.Context) (delegatedidentityv1.DelegatedIdentity_SubscribeToX509SVIDsClient, delegatedidentityv1.DelegatedIdentity_SubscribeToX509BundlesClient, error) {
   335  	if _, err := os.Stat(s.cfg.SpireAdminSocketPath); errors.Is(err, os.ErrNotExist) {
   336  		return nil, nil, fmt.Errorf("SPIRE admin socket (%s) does not exist: %w", s.cfg.SpireAdminSocketPath, err)
   337  	}
   338  
   339  	unixPath := fmt.Sprintf("unix://%s", s.cfg.SpireAdminSocketPath)
   340  
   341  	conn, err := grpc.Dial(unixPath, grpc.WithTransportCredentials(insecure.NewCredentials()),
   342  		grpc.WithDefaultCallOptions(
   343  			grpc.MaxCallRecvMsgSize(20*1024*1024),
   344  			grpc.MaxCallSendMsgSize(20*1024*1024))) // setting this to 20MB to handle large bundles TODO: improve this once fixed upstream (https://github.com/cilium/cilium/issues/24297)
   345  	if err != nil {
   346  		return nil, nil, fmt.Errorf("grpc.Dial() failed on %s: %w", unixPath, err)
   347  	}
   348  
   349  	client := delegatedidentityv1.NewDelegatedIdentityClient(conn)
   350  
   351  	stream, err := client.SubscribeToX509SVIDs(ctx, &delegatedidentityv1.SubscribeToX509SVIDsRequest{
   352  		Selectors: []*spiffeTypes.Selector{
   353  			{
   354  				Type:  "cilium",
   355  				Value: "mutual-auth",
   356  			},
   357  		},
   358  	})
   359  
   360  	if err != nil {
   361  		conn.Close()
   362  		return nil, nil, fmt.Errorf("stream failed on %s: %w", unixPath, err)
   363  	}
   364  
   365  	trustStream, err := client.SubscribeToX509Bundles(ctx, &delegatedidentityv1.SubscribeToX509BundlesRequest{})
   366  	if err != nil {
   367  		conn.Close()
   368  		return nil, nil, fmt.Errorf("stream for x509 bundle failed on %s: %w", unixPath, err)
   369  	}
   370  
   371  	return stream, trustStream, nil
   372  }
   373  
   374  func equalCertChains(a, b [][]byte) bool {
   375  	if len(a) != len(b) {
   376  		return false
   377  	}
   378  	for i := range a {
   379  		if !bytes.Equal(a[i], b[i]) {
   380  			return false
   381  		}
   382  	}
   383  	return true
   384  }