github.com/grafana/pyroscope@v1.18.0/pkg/metastore/raftnode/node.go (about)

     1  package raftnode
     2  
     3  import (
     4  	"context"
     5  	"errors"
     6  	"flag"
     7  	"fmt"
     8  	"net"
     9  	"os"
    10  	"path/filepath"
    11  	"strings"
    12  	"time"
    13  
    14  	"github.com/go-kit/log"
    15  	"github.com/go-kit/log/level"
    16  	"github.com/google/uuid"
    17  	"github.com/grafana/dskit/flagext"
    18  	"github.com/hashicorp/raft"
    19  	raftwal "github.com/hashicorp/raft-wal"
    20  	"github.com/opentracing/opentracing-go"
    21  	"github.com/opentracing/opentracing-go/ext"
    22  	otlog "github.com/opentracing/opentracing-go/log"
    23  	"github.com/prometheus/client_golang/prometheus"
    24  	"google.golang.org/grpc"
    25  	"google.golang.org/protobuf/proto"
    26  
    27  	"github.com/grafana/pyroscope/pkg/metastore/fsm"
    28  	"github.com/grafana/pyroscope/pkg/metastore/raftnode/raftnodepb"
    29  )
    30  
// ContextRegistry stores request contexts keyed by an opaque ID so that
// they can be recovered later (e.g. by the FSM when applying a log entry
// whose Extensions field carries the ID; see Propose).
type ContextRegistry interface {
	// Store associates ctx with the given id.
	Store(id string, ctx context.Context)
}
    34  
// Config holds the raft node configuration: storage locations, cluster
// bootstrap/join behavior, network addresses, and tuning knobs for the
// WAL, snapshots, and transport. Fields marked doc:"hidden" are advanced
// options not surfaced in the public documentation.
type Config struct {
	// Dir is the persistent directory for WAL and raft state.
	Dir                string `yaml:"dir"`
	SnapshotsDir       string `yaml:"snapshots_dir" doc:"hidden"`
	SnapshotsImportDir string `yaml:"snapshots_import_dir" doc:"hidden"`

	// BootstrapPeers lists the initial cluster members; BootstrapExpectPeers
	// is the expected peer count including the local node.
	BootstrapPeers       []string `yaml:"bootstrap_peers"`
	BootstrapExpectPeers int      `yaml:"bootstrap_expect_peers"`
	AutoJoin             bool     `yaml:"auto_join"`

	ServerID         string `yaml:"server_id"`
	BindAddress      string `yaml:"bind_address"`
	AdvertiseAddress string `yaml:"advertise_address"`

	ApplyTimeout          time.Duration `yaml:"apply_timeout" doc:"hidden"`
	LogIndexCheckInterval time.Duration `yaml:"log_index_check_interval" doc:"hidden"`
	ReadIndexMaxDistance  uint64        `yaml:"read_index_max_distance" doc:"hidden"`

	WALCacheEntries       uint64        `yaml:"wal_cache_entries" doc:"hidden"`
	TrailingLogs          uint64        `yaml:"trailing_logs" doc:"hidden"`
	SnapshotsRetain       uint64        `yaml:"snapshots_retain" doc:"hidden"`
	SnapshotInterval      time.Duration `yaml:"snapshot_interval" doc:"hidden"`
	SnapshotThreshold     uint64        `yaml:"snapshot_threshold" doc:"hidden"`
	TransportConnPoolSize uint64        `yaml:"transport_conn_pool_size" doc:"hidden"`
	TransportTimeout      time.Duration `yaml:"transport_timeout" doc:"hidden"`
}
    60  
// Defaults for the raft node configuration; see RegisterFlagsWithPrefix.
const (
	defaultRaftDir      = "./data-metastore/raft"
	defaultSnapshotsDir = defaultRaftDir

	defaultWALCacheEntries       = 512
	defaultTrailingLogs          = 18 << 10 // 18Ki log entries kept after a snapshot.
	defaultSnapshotsRetain       = 3
	defaultSnapshotInterval      = 180 * time.Second
	defaultSnapshotThreshold     = 8 << 10 // 8Ki entries between snapshots.
	defaultTransportConnPoolSize = 10
	defaultTransportTimeout      = 10 * time.Second
)
    73  
// RegisterFlagsWithPrefix registers all raft node flags on f, each name
// prepended with prefix. Several advanced flags intentionally carry an
// empty usage string (they are hidden options; see the doc:"hidden" tags).
func (cfg *Config) RegisterFlagsWithPrefix(prefix string, f *flag.FlagSet) {
	f.StringVar(&cfg.Dir, prefix+"dir", defaultRaftDir, "Directory to store WAL and raft state. It must be a persistent directory, not a tmpfs or similar.")
	f.StringVar(&cfg.SnapshotsDir, prefix+"snapshots-dir", defaultSnapshotsDir, "Directory to store FSM snapshots. Raft creates 'snapshots' subdirectory in this directory. It must be a persistent directory, not a tmpfs or similar.")
	f.StringVar(&cfg.SnapshotsImportDir, prefix+"snapshots-import-dir", "", "Directory to import snapshots from; the directory must contain 'snapshots' subdirectory. If not set, no import will be done.")

	f.Var((*flagext.StringSlice)(&cfg.BootstrapPeers), prefix+"bootstrap-peers", "")
	f.IntVar(&cfg.BootstrapExpectPeers, prefix+"bootstrap-expect-peers", 1, "Expected number of peers including the local node.")
	f.BoolVar(&cfg.AutoJoin, prefix+"auto-join", false, "If enabled, new nodes (without a state) will try to join an existing cluster on startup.")

	f.StringVar(&cfg.ServerID, prefix+"server-id", "localhost:9099", "")
	f.StringVar(&cfg.BindAddress, prefix+"bind-address", "localhost:9099", "")
	f.StringVar(&cfg.AdvertiseAddress, prefix+"advertise-address", "localhost:9099", "")

	f.DurationVar(&cfg.ApplyTimeout, prefix+"apply-timeout", 5*time.Second, "")
	f.DurationVar(&cfg.LogIndexCheckInterval, prefix+"log-index-check-interval", 14*time.Millisecond, "")
	f.Uint64Var(&cfg.ReadIndexMaxDistance, prefix+"read-index-max-distance", 10<<10, "")

	f.Uint64Var(&cfg.WALCacheEntries, prefix+"wal-cache-entries", defaultWALCacheEntries, "")
	f.Uint64Var(&cfg.TrailingLogs, prefix+"trailing-logs", defaultTrailingLogs, "")
	f.Uint64Var(&cfg.SnapshotsRetain, prefix+"snapshots-retain", defaultSnapshotsRetain, "")
	f.DurationVar(&cfg.SnapshotInterval, prefix+"snapshot-interval", defaultSnapshotInterval, "")
	f.Uint64Var(&cfg.SnapshotThreshold, prefix+"snapshot-threshold", defaultSnapshotThreshold, "")
	f.Uint64Var(&cfg.TransportConnPoolSize, prefix+"transport-conn-pool-size", defaultTransportConnPoolSize, "")
	f.DurationVar(&cfg.TransportTimeout, prefix+"transport-timeout", defaultTransportTimeout, "")
}
    99  
   100  func (cfg *Config) Validate() error {
   101  	// TODO(kolesnikovae): Check the params.
   102  	return nil
   103  }
   104  
// Node wraps a hashicorp/raft instance together with its storage
// (WAL-backed log/stable stores and a file snapshot store), TCP transport,
// state observer, and the gRPC service that exposes raft-node operations.
type Node struct {
	logger          log.Logger
	config          Config
	metrics         *metrics
	reg             prometheus.Registerer
	fsm             raft.FSM
	contextRegistry ContextRegistry

	walDir        string
	wal           *raftwal.WAL
	snapshots     *raft.FileSnapshotStore
	transport     *raft.NetworkTransport
	raft          *raft.Raft
	// logStore and stableStore are both backed by the WAL; logStore is
	// additionally wrapped in an in-memory cache (see openStore).
	logStore      raft.LogStore
	stableStore   raft.StableStore
	snapshotStore raft.SnapshotStore

	observer *Observer
	service  *RaftNodeService

	raftNodeClient raftnodepb.RaftNodeServiceClient
}
   127  
   128  func NewNode(
   129  	logger log.Logger,
   130  	config Config,
   131  	reg prometheus.Registerer,
   132  	fsm raft.FSM,
   133  	contextRegistry ContextRegistry,
   134  	raftNodeClient raftnodepb.RaftNodeServiceClient,
   135  ) (_ *Node, err error) {
   136  	n := Node{
   137  		logger:          logger,
   138  		config:          config,
   139  		metrics:         newMetrics(reg),
   140  		reg:             reg,
   141  		fsm:             fsm,
   142  		contextRegistry: contextRegistry,
   143  		raftNodeClient:  raftNodeClient,
   144  	}
   145  
   146  	defer func() {
   147  		if err != nil {
   148  			// If the initialization fails, initialized components
   149  			// should be de-initialized gracefully.
   150  			n.Shutdown()
   151  		}
   152  	}()
   153  
   154  	addr, err := net.ResolveTCPAddr("tcp", config.AdvertiseAddress)
   155  	if err != nil {
   156  		return nil, err
   157  	}
   158  	n.transport, err = raft.NewTCPTransport(
   159  		config.BindAddress, addr,
   160  		int(config.TransportConnPoolSize),
   161  		config.TransportTimeout,
   162  		os.Stderr)
   163  	if err != nil {
   164  		return nil, err
   165  	}
   166  
   167  	if err = n.openStore(); err != nil {
   168  		return nil, err
   169  	}
   170  
   171  	return &n, nil
   172  }
   173  
   174  func (n *Node) Init() (err error) {
   175  	raftConfig := raft.DefaultConfig()
   176  	// TODO: Wrap gokit
   177  	//	config.Logger
   178  	raftConfig.LogLevel = "debug"
   179  
   180  	raftConfig.TrailingLogs = n.config.TrailingLogs
   181  	raftConfig.SnapshotThreshold = n.config.SnapshotThreshold
   182  	raftConfig.SnapshotInterval = n.config.SnapshotInterval
   183  	raftConfig.LocalID = raft.ServerID(n.config.ServerID)
   184  
   185  	n.raft, err = raft.NewRaft(raftConfig, n.fsm, n.logStore, n.stableStore, n.snapshotStore, n.transport)
   186  	if err != nil {
   187  		return fmt.Errorf("starting raft node: %w", err)
   188  	}
   189  	n.observer = NewRaftStateObserver(n.logger, n.raft, n.metrics.state)
   190  	n.service = NewRaftNodeService(n)
   191  
   192  	hasState, err := raft.HasExistingState(n.logStore, n.stableStore, n.snapshotStore)
   193  	if err != nil {
   194  		return fmt.Errorf("failed to check for existing state: %w", err)
   195  	}
   196  	if !hasState {
   197  		if n.config.AutoJoin {
   198  			level.Info(n.logger).Log("msg", "no existing state found and auto-join is enabled, trying to join existing raft cluster...")
   199  			if err = n.tryAutoJoin(); err != nil {
   200  				level.Warn(n.logger).Log("msg", "failed to auto-join raft cluster", "err", err)
   201  			} else {
   202  				level.Info(n.logger).Log("msg", "successfully joined existing raft cluster")
   203  				return nil
   204  			}
   205  		}
   206  
   207  		level.Info(n.logger).Log("msg", "no existing state found and auto-join is disabled, bootstrapping raft cluster...")
   208  		if err = n.bootstrap(); err != nil {
   209  			return fmt.Errorf("failed to bootstrap cluster: %w", err)
   210  		}
   211  	} else {
   212  		level.Debug(n.logger).Log("msg", "restoring existing state, not bootstrapping")
   213  	}
   214  
   215  	return nil
   216  }
   217  
   218  func (n *Node) openStore() (err error) {
   219  	if err = n.createDirs(); err != nil {
   220  		return err
   221  	}
   222  	n.wal, err = raftwal.Open(n.walDir)
   223  	if err != nil {
   224  		return fmt.Errorf("failed to open WAL: %w", err)
   225  	}
   226  	if err = n.importSnapshots(); err != nil {
   227  		return fmt.Errorf("failed to copy snapshots: %w", err)
   228  	}
   229  	n.snapshots, err = raft.NewFileSnapshotStore(n.config.SnapshotsDir, int(n.config.SnapshotsRetain), os.Stderr)
   230  	if err != nil {
   231  		return fmt.Errorf("failed to open shapshot store: %w", err)
   232  	}
   233  	n.logStore = n.wal
   234  	n.logStore, _ = raft.NewLogCache(int(n.config.WALCacheEntries), n.logStore)
   235  	n.stableStore = n.wal
   236  	n.snapshotStore = n.snapshots
   237  	return nil
   238  }
   239  
   240  func (n *Node) createDirs() (err error) {
   241  	n.walDir = filepath.Join(n.config.Dir, "wal")
   242  	if err = os.MkdirAll(n.walDir, 0755); err != nil {
   243  		return fmt.Errorf("WAL dir: %w", err)
   244  	}
   245  	// Raft will create 'snapshots' subdirectory in the SnapshotsDir.
   246  	if err = os.MkdirAll(n.config.SnapshotsDir, 0755); err != nil {
   247  		return fmt.Errorf("snapshot directory: %w", err)
   248  	}
   249  	return nil
   250  }
   251  
// Shutdown tears down the node: raft first (so no new entries are produced),
// then the transport, then the WAL. Each component is guarded by a nil check
// so that Shutdown is safe to call on a partially initialized node — NewNode
// relies on this in its error path. Errors are logged, not returned.
func (n *Node) Shutdown() {
	if n.raft != nil {
		if err := n.raft.Shutdown().Error(); err != nil {
			level.Error(n.logger).Log("msg", "failed to shutdown raft", "err", err)
		}
		// observer is set right after raft in Init, so it is non-nil here.
		n.observer.Deregister()
	}
	if n.transport != nil {
		if err := n.transport.Close(); err != nil {
			level.Error(n.logger).Log("msg", "failed to close transport", "err", err)
		}
	}
	if n.wal != nil {
		if err := n.wal.Close(); err != nil {
			level.Error(n.logger).Log("msg", "failed to close WAL", "err", err)
		}
	}
}
   270  
// ListSnapshots returns metadata of the snapshots available in the
// file snapshot store, delegating to raft's snapshot store List.
func (n *Node) ListSnapshots() ([]*raft.SnapshotMeta, error) {
	return n.snapshots.List()
}
   274  
// Register registers the raft node gRPC service on the given server.
func (n *Node) Register(server *grpc.Server) {
	raftnodepb.RegisterRaftNodeServiceServer(server, n.service)
}
   278  
// LeaderActivity is started when the node becomes a leader and stopped
// when it stops being a leader. The implementation MUST be idempotent:
// Start and Stop may be invoked repeatedly as raft state changes are
// observed (see RunOnLeader).
type LeaderActivity interface {
	Start()
	Stop()
}
   285  
   286  type leaderStateHandler struct{ activity LeaderActivity }
   287  
   288  func (h *leaderStateHandler) Observe(state raft.RaftState) {
   289  	if state == raft.Leader {
   290  		h.activity.Start()
   291  	} else {
   292  		h.activity.Stop()
   293  	}
   294  }
   295  
// RunOnLeader arranges for a to run only while this node is the raft
// leader: it is started on becoming leader and stopped on losing
// leadership (see LeaderActivity for the idempotency requirement).
func (n *Node) RunOnLeader(a LeaderActivity) {
	n.observer.RegisterHandler(&leaderStateHandler{activity: a})
}
   299  
   300  func (n *Node) TransferLeadership() (err error) {
   301  	switch err = n.raft.LeadershipTransfer().Error(); {
   302  	case err == nil:
   303  	case errors.Is(err, raft.ErrNotLeader):
   304  		// Not a leader, nothing to do.
   305  	case strings.Contains(err.Error(), "cannot find peer"):
   306  		// No peers, nothing to do.
   307  	default:
   308  		level.Error(n.logger).Log("msg", "failed to transfer leadership", "err", err)
   309  	}
   310  	return err
   311  }
   312  
// Propose makes an attempt to apply the given command to the FSM.
// The function returns an error if node is not the leader.
//
// The caller's context is stashed in the context registry under a fresh
// UUID, and that UUID travels with the raft log entry via Extensions so
// the FSM can recover the context when applying the entry. The whole
// round trip (marshal, apply, wait) is traced and timed.
func (n *Node) Propose(ctx context.Context, t fsm.RaftLogEntryType, m proto.Message) (resp proto.Message, err error) {
	span, ctx := opentracing.StartSpanFromContext(ctx, "node.Propose")
	defer func() {
		if err != nil {
			ext.LogError(span, err)
		}
		span.Finish()
	}()

	// Register the context before applying so it is available to the FSM
	// by the time the entry is committed.
	ctxID := uuid.New().String()
	n.contextRegistry.Store(ctxID, ctx)

	span.LogFields(otlog.String("msg", "marshalling log entry"))

	raw, err := fsm.MarshalEntry(t, m)
	if err != nil {
		return nil, err
	}

	span.LogFields(otlog.String("msg", "log entry marshalled"))
	timer := prometheus.NewTimer(n.metrics.apply)
	defer timer.ObserveDuration()

	span.LogFields(otlog.String("msg", "applying log entry"))

	future := n.raft.ApplyLog(raft.Log{
		Data:       raw,
		Extensions: []byte(ctxID),
	}, n.config.ApplyTimeout)

	span.LogFields(otlog.String("msg", "waiting for apply result"))

	if err = future.Error(); err != nil {
		// Decorate the error with leader status so the client can redirect.
		return nil, WithRaftLeaderStatusDetails(err, n.raft)
	}
	// The FSM always returns fsm.Response for applied entries.
	r := future.Response().(fsm.Response)

	span.LogFields(otlog.String("msg", "apply result received"))
	if r.Data != nil {
		resp = r.Data
	}
	return resp, r.Err
}