github.com/webmeshproj/webmesh-cni@v0.0.27/internal/host/hostnode.go (about)

     1  /*
     2  Copyright 2023 Avi Zimmerman <avi.zimmerman@gmail.com>.
     3  
     4  Licensed under the Apache License, Version 2.0 (the "License");
     5  you may not use this file except in compliance with the License.
     6  You may obtain a copy of the License at
     7  
     8      http://www.apache.org/licenses/LICENSE-2.0
     9  
    10  Unless required by applicable law or agreed to in writing, software
    11  distributed under the License is distributed on an "AS IS" BASIS,
    12  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    13  See the License for the specific language governing permissions and
    14  limitations under the License.
    15  */
    16  
    17  package host
    18  
    19  import (
    20  	"context"
    21  	"fmt"
    22  	"log/slog"
    23  	"net/netip"
    24  	"strings"
    25  	"sync"
    26  	"sync/atomic"
    27  
    28  	v1 "github.com/webmeshproj/api/go/v1"
    29  	meshcontext "github.com/webmeshproj/webmesh/pkg/context"
    30  	"github.com/webmeshproj/webmesh/pkg/logging"
    31  	"github.com/webmeshproj/webmesh/pkg/meshnet"
    32  	endpoints "github.com/webmeshproj/webmesh/pkg/meshnet/endpoints"
    33  	netutil "github.com/webmeshproj/webmesh/pkg/meshnet/netutil"
    34  	meshtransport "github.com/webmeshproj/webmesh/pkg/meshnet/transport"
    35  	meshnode "github.com/webmeshproj/webmesh/pkg/meshnode"
    36  	meshplugins "github.com/webmeshproj/webmesh/pkg/plugins"
    37  	meshbuiltins "github.com/webmeshproj/webmesh/pkg/plugins/builtins"
    38  	meshstorage "github.com/webmeshproj/webmesh/pkg/storage"
    39  	mesherrors "github.com/webmeshproj/webmesh/pkg/storage/errors"
    40  	meshtypes "github.com/webmeshproj/webmesh/pkg/storage/types"
    41  	"k8s.io/client-go/rest"
    42  	"sigs.k8s.io/controller-runtime/pkg/log"
    43  
    44  	"github.com/webmeshproj/webmesh-cni/internal/ipam"
    45  	cnitypes "github.com/webmeshproj/webmesh-cni/internal/types"
    46  )
    47  
    48  // Node is a representation of the host node running the CNI plugin
    49  // and allocating addresses for containers. This is the node that all
    50  // containers on the system peer with for access to the rest of the
    51  // cluster and/or the internet.
    52  type Node interface {
    53  	// ID returns the ID of the host node.
    54  	ID() meshtypes.NodeID
    55  	// Start starts the host node.
    56  	Start(ctx context.Context, cfg *rest.Config) error
    57  	// Started returns true if the host node has been started.
    58  	Started() bool
    59  	// Stop stops the host node. This is also closes the underlying
    60  	// storage provider.
    61  	Stop(ctx context.Context) error
    62  	// IPAM returns the IPv4 address allocator. This will be nil until
    63  	// Start is called.
    64  	IPAM() ipam.Allocator
    65  	// Node returns the underlying mesh node. This will be nil until
    66  	// Start is called.
    67  	Node() meshnode.Node
    68  	// NodeLogger returns the node's logger.
    69  	NodeLogger() *slog.Logger
    70  	// NodeContext returns a context with the node's logger.
    71  	NodeContext(context.Context) context.Context
    72  }
    73  
    74  // NewNode is the function for creating a new mesh node. Declared as a variable for testing purposes.
    75  var NewMeshNode = meshnode.NewWithLogger
    76  
    77  // NewNode creates a new host node.
    78  func NewNode(storage meshstorage.Provider, opts Config) Node {
    79  	node := &hostNode{
    80  		nodeID:  meshtypes.NodeID(opts.NodeID),
    81  		storage: storage,
    82  		config:  opts,
    83  	}
    84  	return node
    85  }
    86  
    87  // hostNode implements the Host interface.
    88  type hostNode struct {
    89  	nodeID     meshtypes.NodeID
    90  	storage    meshstorage.Provider
    91  	config     Config
    92  	started    atomic.Bool
    93  	networkV4  netip.Prefix
    94  	networkV6  netip.Prefix
    95  	meshDomain string
    96  	node       meshnode.Node
    97  	nodeLog    *slog.Logger
    98  	ipam       ipam.Allocator
    99  	mu         sync.Mutex
   100  }
   101  
   102  // ID returns the ID of the host node.
   103  func (h *hostNode) ID() meshtypes.NodeID {
   104  	return h.nodeID
   105  }
   106  
   107  // Started returns true if the host node has been started.
   108  func (h *hostNode) Started() bool {
   109  	return h.started.Load()
   110  }
   111  
   112  // Node returns the underlying mesh node.
   113  func (h *hostNode) Node() meshnode.Node {
   114  	return h.node
   115  }
   116  
   117  // IPAM returns the IPv4 address allocator.
   118  func (h *hostNode) IPAM() ipam.Allocator {
   119  	return h.ipam
   120  }
   121  
   122  // NodeLogger returns the node's logger.
   123  func (h *hostNode) NodeLogger() *slog.Logger {
   124  	return h.nodeLog
   125  }
   126  
   127  // NodeContext returns a context with the node's logger. If context is
   128  // nil, the background context is used.
   129  func (h *hostNode) NodeContext(ctx context.Context) context.Context {
   130  	if ctx != nil {
   131  		return meshcontext.WithLogger(ctx, h.nodeLog)
   132  	}
   133  	return meshcontext.WithLogger(context.Background(), h.nodeLog)
   134  }
   135  
   136  // Start starts the host node.
   137  func (h *hostNode) Start(ctx context.Context, cfg *rest.Config) error {
   138  	h.mu.Lock()
   139  	defer h.mu.Unlock()
   140  	if h.started.Load() {
   141  		return fmt.Errorf("host node already started")
   142  	}
   143  	log := log.FromContext(ctx).WithName("host-node")
   144  	err := h.bootstrap(ctx)
   145  	if err != nil {
   146  		return fmt.Errorf("failed to bootstrap host node: %w", err)
   147  	}
   148  	log.Info("Setting up host node")
   149  	log.V(1).Info("Starting IPAM allocator")
   150  	h.ipam, err = ipam.NewAllocator(cfg, ipam.Config{
   151  		IPAM: meshplugins.IPAMConfig{
   152  			Storage: h.storage.MeshDB(),
   153  		},
   154  		Lock: ipam.LockConfig{
   155  			ID:                 h.config.NodeID,
   156  			Namespace:          h.config.Namespace,
   157  			LockDuration:       h.config.LockDuration,
   158  			LockAcquireTimeout: h.config.LockAcquireTimeout,
   159  		},
   160  		Network: h.networkV4,
   161  	})
   162  	if err != nil {
   163  		return fmt.Errorf("failed to create IPAM allocator: %w", err)
   164  	}
   165  	// Detect the current endpoints on the machine.
   166  	log.Info("Detecting host endpoints")
   167  	eps, err := endpoints.Detect(ctx, endpoints.DetectOpts{
   168  		DetectPrivate:        true, // Required for finding endpoints for other containers on the local node.
   169  		DetectIPv6:           !h.config.Network.DisableIPv6,
   170  		AllowRemoteDetection: h.config.Network.RemoteEndpointDetection,
   171  		// Make configurable? It will at least need to account for any CNI interfaces
   172  		// from a previous run.
   173  		SkipInterfaces: []string{},
   174  	})
   175  	if err != nil {
   176  		return fmt.Errorf("failed to detect endpoints: %w", err)
   177  	}
   178  	key, err := h.config.WireGuard.LoadKey(ctx)
   179  	if err != nil {
   180  		return fmt.Errorf("failed to generate key: %w", err)
   181  	}
   182  	encodedPubKey, err := key.PublicKey().Encode()
   183  	if err != nil {
   184  		return fmt.Errorf("failed to encode public key: %w", err)
   185  	}
   186  	// We always allocate addresses for ourselves, even if we won't use them.
   187  	log.Info("Allocating a mesh IPv4 address")
   188  	err = h.ipam.Locker().Acquire(ctx)
   189  	if err != nil {
   190  		return fmt.Errorf("failed to acquire IPAM lock: %w", err)
   191  	}
   192  	defer h.ipam.Locker().Release(ctx)
   193  	var ipv4Addr, ipv6Addr string
   194  	alloc, err := h.ipam.Allocate(ctx, h.nodeID)
   195  	if err != nil {
   196  		return fmt.Errorf("failed to allocate IPv4 address: %w", err)
   197  	}
   198  	ipv4Addr = alloc.String()
   199  	log.Info("Allocating a mesh IPv6 address")
   200  	ipv6Addr = netutil.AssignToPrefix(h.networkV6, key.PublicKey()).String()
   201  	log.Info("Connecting to the webmesh network")
   202  	h.nodeLog = logging.NewLogger(h.config.LogLevel, "json")
   203  	hostNode := NewMeshNode(h.nodeLog, meshnode.Config{
   204  		Key:             key,
   205  		NodeID:          h.nodeID.String(),
   206  		ZoneAwarenessID: h.config.NodeID,
   207  		UseMeshDNS:      h.config.Network.WriteResolvConf,
   208  		LocalMeshDNSAddr: func() string {
   209  			if h.config.Services.MeshDNS.Enabled {
   210  				return fmt.Sprintf("%s:%d", alloc.Addr(), h.config.Services.MeshDNS.ListenPort())
   211  			}
   212  			return ""
   213  		}(),
   214  		LocalDNSOnly: true,
   215  		DisableIPv4:  h.config.Network.DisableIPv4,
   216  		DisableIPv6:  h.config.Network.DisableIPv6,
   217  	})
   218  	connectCtx, cancel := context.WithTimeout(ctx, h.config.ConnectTimeout)
   219  	defer cancel()
   220  	plugins, err := h.config.Plugins.NewPluginSet(connectCtx)
   221  	if err != nil {
   222  		return fmt.Errorf("failed to create plugin set: %w", err)
   223  	}
   224  	connectOpts := meshnode.ConnectOptions{
   225  		StorageProvider: h.storage,
   226  		MaxJoinRetries:  10,
   227  		Plugins:         plugins,
   228  		JoinRoundTripper: meshtransport.JoinRoundTripperFunc(func(ctx context.Context, req *v1.JoinRequest) (*v1.JoinResponse, error) {
   229  			// TODO: Check for pre-existing peers and return them.
   230  			return &v1.JoinResponse{
   231  				AddressIPv4: ipv4Addr,
   232  				AddressIPv6: ipv6Addr,
   233  				NetworkIPv4: h.networkV4.String(),
   234  				NetworkIPv6: h.networkV6.String(),
   235  				MeshDomain:  h.meshDomain,
   236  			}, nil
   237  		}),
   238  		LeaveRoundTripper: meshtransport.LeaveRoundTripperFunc(func(ctx context.Context, req *v1.LeaveRequest) (*v1.LeaveResponse, error) {
   239  			// No-op, we clean up on shutdown
   240  			return &v1.LeaveResponse{}, nil
   241  		}),
   242  		NetworkOptions: meshnet.Options{
   243  			ListenPort:            h.config.WireGuard.ListenPort,
   244  			InterfaceName:         cnitypes.IfNameFromID(h.nodeID.String()),
   245  			ForceReplace:          true,
   246  			MTU:                   h.config.WireGuard.MTU,
   247  			ZoneAwarenessID:       h.config.NodeID,
   248  			DisableIPv4:           h.config.Network.DisableIPv4,
   249  			DisableIPv6:           h.config.Network.DisableIPv6,
   250  			RecordMetrics:         h.config.WireGuard.RecordMetrics,
   251  			RecordMetricsInterval: h.config.WireGuard.RecordMetricsInterval,
   252  		},
   253  	}
   254  	if h.config.Services.API.MTLS {
   255  		// Add the MTLS plugin.
   256  		mtlsPlug, _ := meshbuiltins.NewClient("mtls")
   257  		if connectOpts.Plugins == nil {
   258  			connectOpts.Plugins = make(map[string]meshplugins.Plugin)
   259  		}
   260  		connectOpts.Plugins["mtls"] = meshplugins.Plugin{
   261  			Client: mtlsPlug,
   262  			Config: map[string]any{
   263  				"ca-file": h.config.Services.API.MTLSClientCAFile,
   264  			},
   265  		}
   266  	}
   267  	err = hostNode.Connect(connectCtx, connectOpts)
   268  	if err != nil {
   269  		return fmt.Errorf("failed to connect to webmesh network: %w", err)
   270  	}
   271  	select {
   272  	case <-connectCtx.Done():
   273  		return fmt.Errorf("timeout while connecting to webmesh network: %w", connectCtx.Err())
   274  	case <-hostNode.Ready():
   275  	}
   276  	// Register ourselves with the mesh.
   277  	log.Info("Host node is connected, registering endpoints with network")
   278  	wireguardPort, err := hostNode.Network().WireGuard().ListenPort()
   279  	if err != nil {
   280  		defer hostNode.Close(ctx)
   281  		return fmt.Errorf("failed to get wireguard listen port: %w", err)
   282  	}
   283  	var wgeps []string
   284  	for _, ep := range eps.AddrPorts(uint16(wireguardPort)) {
   285  		wgeps = append(wgeps, ep.String())
   286  	}
   287  	features := h.config.Services.NewFeatureSet(h.storage, h.config.Services.API.ListenPort())
   288  	peer := meshtypes.MeshNode{
   289  		MeshNode: &v1.MeshNode{
   290  			Id:        h.nodeID.String(),
   291  			PublicKey: encodedPubKey,
   292  			PrimaryEndpoint: func() string {
   293  				if eps.FirstPublicAddr().IsValid() {
   294  					return eps.FirstPublicAddr().String()
   295  				}
   296  				return eps.PrivateAddrs()[0].String()
   297  			}(),
   298  			WireguardEndpoints: wgeps,
   299  			ZoneAwarenessID:    h.nodeID.String(),
   300  			PrivateIPv4:        ipv4Addr,
   301  			PrivateIPv6:        ipv6Addr,
   302  			Features:           features,
   303  		},
   304  	}
   305  	err = h.storage.MeshDB().Peers().Put(ctx, peer)
   306  	if err != nil {
   307  		defer hostNode.Close(ctx)
   308  		return fmt.Errorf("failed to register with mesh: %w", err)
   309  	}
   310  	// Update our consensus record with our public key.
   311  	err = h.storage.Consensus().AddVoter(ctx, meshtypes.StoragePeer{StoragePeer: &v1.StoragePeer{
   312  		Id:            h.nodeID.String(),
   313  		PublicKey:     encodedPubKey,
   314  		Address:       fmt.Sprintf("%s:%d", h.nodeID, h.config.Services.API.ListenPort()),
   315  		ClusterStatus: v1.ClusterStatus_CLUSTER_VOTER,
   316  	}})
   317  	if err != nil {
   318  		defer hostNode.Close(ctx)
   319  		return fmt.Errorf("failed to register with consensus: %w", err)
   320  	}
   321  	// Put a default gateway route for ourselves.
   322  	err = h.storage.MeshDB().Networking().PutRoute(ctx, meshtypes.Route{
   323  		Route: &v1.Route{
   324  			Name: fmt.Sprintf("%s-node-gw", h.nodeID.String()),
   325  			Node: h.nodeID.String(),
   326  			DestinationCIDRs: func() []string {
   327  				out := h.config.Network.Routes
   328  				for _, ep := range append(eps, h.config.Network.CIDRs()...) {
   329  					out = append(out, ep.String())
   330  				}
   331  				return out
   332  			}(),
   333  		},
   334  	})
   335  	if err != nil {
   336  		defer hostNode.Close(ctx)
   337  		return fmt.Errorf("failed to register default gateway route: %w", err)
   338  	}
   339  	h.node = hostNode
   340  	h.started.Store(true)
   341  	return nil
   342  }
   343  
   344  // Stop stops the host node.
   345  func (h *hostNode) Stop(ctx context.Context) error {
   346  	h.mu.Lock()
   347  	defer h.mu.Unlock()
   348  	log := log.FromContext(ctx).WithName("host-node")
   349  	if !h.started.Load() {
   350  		return fmt.Errorf("host node must be started before it can be stopped")
   351  	}
   352  	// Try to remove ourself from the consensus group
   353  	err := h.storage.Consensus().RemovePeer(ctx, meshtypes.StoragePeer{
   354  		StoragePeer: &v1.StoragePeer{
   355  			Id: h.nodeID.String(),
   356  		},
   357  	}, false)
   358  	if err != nil {
   359  		log.Error(err, "Failed to remove host webmesh node from consensus group")
   360  	}
   361  	// Try to remove our peer from the mesh.
   362  	err = h.storage.MeshDB().Peers().Delete(ctx, h.nodeID)
   363  	if err != nil {
   364  		log.Error(err, "Failed to remove host webmesh node from network")
   365  	}
   366  	err = h.storage.MeshDB().Networking().DeleteRoute(ctx, fmt.Sprintf("%s-node-gw", h.nodeID.String()))
   367  	if err != nil {
   368  		log.Error(err, "Failed to remove default gateway route")
   369  	}
   370  	err = h.node.Close(ctx)
   371  	if err != nil {
   372  		log.Error(err, "Failed to close host webmesh node")
   373  	}
   374  	h.started.Store(false)
   375  	return nil
   376  }
   377  
   378  // bootstrap attempts to bootstrap the underlying storage provider and network state.
   379  // If the storage is already bootstrapped, it will read in the pre-existing state.
   380  func (h *hostNode) bootstrap(ctx context.Context) error {
   381  	log := log.FromContext(ctx).WithName("network-bootstrap")
   382  	log.Info("Checking that the webmesh network is bootstrapped")
   383  	log.V(1).Info("Attempting to bootstrap storage provider")
   384  	err := h.storage.Bootstrap(ctx)
   385  	if err != nil {
   386  		if !mesherrors.Is(err, mesherrors.ErrAlreadyBootstrapped) {
   387  			log.Error(err, "Unable to bootstrap storage provider")
   388  			return fmt.Errorf("failed to bootstrap storage provider: %w", err)
   389  		}
   390  		log.V(1).Info("Storage provider already bootstrapped, making sure network state is boostrapped")
   391  	}
   392  	var ipv4Cidr, ipv6Cidr string
   393  	for _, addr := range strings.Split(h.config.Network.PodCIDR, ",") {
   394  		prefix, err := netip.ParsePrefix(addr)
   395  		if err != nil {
   396  			return fmt.Errorf("invalid pod-cidr: %w", err)
   397  		}
   398  		if prefix.Addr().Is6() {
   399  			ipv6Cidr = prefix.String()
   400  		} else if prefix.Addr().Is4() {
   401  			ipv4Cidr = prefix.String()
   402  		}
   403  	}
   404  	if ipv4Cidr == "" {
   405  		ipv4Cidr = meshstorage.DefaultIPv4Network
   406  	}
   407  	// Make sure the network state is boostrapped.
   408  	bootstrapOpts := meshstorage.BootstrapOptions{
   409  		MeshDomain:           h.config.Network.ClusterDomain,
   410  		IPv4Network:          ipv4Cidr,
   411  		IPv6Network:          ipv6Cidr,
   412  		Admin:                meshstorage.DefaultMeshAdmin,
   413  		DefaultNetworkPolicy: meshstorage.DefaultNetworkPolicy,
   414  		DisableRBAC:          h.config.Network.DisableRBAC,
   415  	}
   416  	log.V(1).Info("Attempting to bootstrap network state", "options", bootstrapOpts)
   417  	networkState, err := meshstorage.Bootstrap(ctx, h.storage.MeshDB(), &bootstrapOpts)
   418  	if err != nil && !mesherrors.Is(err, mesherrors.ErrAlreadyBootstrapped) {
   419  		log.Error(err, "Unable to bootstrap network state")
   420  		return fmt.Errorf("failed to bootstrap network state: %w", err)
   421  	} else if mesherrors.Is(err, mesherrors.ErrAlreadyBootstrapped) {
   422  		log.V(1).Info("Network already bootstrapped")
   423  	} else {
   424  		log.Info("Network state bootstrapped for the first time")
   425  	}
   426  	h.networkV4 = networkState.NetworkV4
   427  	h.networkV6 = networkState.NetworkV6
   428  	h.meshDomain = networkState.MeshDomain
   429  	return nil
   430  }