store/engine/raft/node.go (448 lines of code) (raw):

/* * Licensed to the Apache Software Foundation (ASF) under one * or more contributor license agreements. See the NOTICE file * distributed with this work for additional information * regarding copyright ownership. The ASF licenses this file * to you under the Apache License, Version 2.0 (the * "License"); you may not use this file except in compliance * with the License. You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, * software distributed under the License is distributed on an * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY * KIND, either express or implied. See the License for the * specific language governing permissions and limitations * under the License. * */ package raft import ( "context" "encoding/json" "errors" "fmt" "net/http" "net/url" "os" "sync" "sync/atomic" "time" "github.com/apache/kvrocks-controller/logger" "github.com/apache/kvrocks-controller/store/engine" "go.etcd.io/etcd/client/pkg/v3/types" "go.etcd.io/etcd/raft/v3" "go.etcd.io/etcd/raft/v3/raftpb" "go.etcd.io/etcd/server/v3/etcdserver/api/rafthttp" stats "go.etcd.io/etcd/server/v3/etcdserver/api/v2stats" "go.uber.org/zap" ) const ( defaultSnapshotThreshold = 10000 defaultCompactThreshold = 1024 ) const ( opGet = iota + 1 opSet opDelete ) type Event struct { Op int `json:"op"` Key string `json:"key"` Value []byte `json:"value"` } type Node struct { config *Config addr string raftNode raft.Node transport *rafthttp.Transport httpServer *http.Server dataStore *DataStore leaderChanged chan bool logger *zap.Logger peers sync.Map leader uint64 appliedIndex uint64 snapshotIndex uint64 confState raftpb.ConfState snapshotThreshold atomic.Uint64 compactThreshold atomic.Uint64 wg sync.WaitGroup shutdown chan struct{} isRunning atomic.Bool } var _ engine.Engine = (*Node)(nil) func New(config *Config) (*Node, error) { config.init() if err := config.validate(); err != nil { return nil, err } logger := logger.Get().With(zap.Uint64("node_id", config.ID)) n := &Node{ config: config, leader: raft.None, dataStore: NewDataStore(config.DataDir), leaderChanged: make(chan bool), logger: logger, } n.snapshotThreshold.Store(defaultSnapshotThreshold) n.compactThreshold.Store(defaultCompactThreshold) if err := n.run(); err != nil { return nil, err } return n, nil } func (n *Node) Addr() string { return n.addr } func (n *Node) ListPeers() map[uint64]string { peers := make(map[uint64]string) n.peers.Range(func(key, value interface{}) bool { id, _ := key.(uint64) peer, _ := value.(string) peers[id] = peer return true }) return peers } func (n *Node) SetSnapshotThreshold(threshold uint64) { n.snapshotThreshold.Store(threshold) } func (n *Node) run() error { // The node is already running if !n.isRunning.CompareAndSwap(false, true) { return nil } n.shutdown = make(chan struct{}) peers := make([]raft.Peer, len(n.config.Peers)) for i, peer := range n.config.Peers { peers[i] = raft.Peer{ ID: uint64(i + 1), Context: []byte(peer), } } raftConfig := &raft.Config{ ID: n.config.ID, HeartbeatTick: n.config.HeartbeatSeconds, ElectionTick: n.config.ElectionSeconds, MaxInflightMsgs: 128, MaxSizePerMsg: 10 * 1024 * 1024, // 10 MiB Storage: n.dataStore.raftStorage, Logger: Logger{SugaredLogger: n.logger.Sugar()}, } // WAL existing check must be done before replayWAL since it will create a new WAL if not exists walExists := n.dataStore.walExists() snapshot, err := n.dataStore.replayWAL() if err != nil { return err } n.appliedIndex = snapshot.Metadata.Index n.snapshotIndex = snapshot.Metadata.Index n.confState = snapshot.Metadata.ConfState if n.config.ClusterState == ClusterStateExisting || walExists { n.raftNode = raft.RestartNode(raftConfig) } else { n.raftNode = raft.StartNode(raftConfig, peers) } if err := n.runTransport(); err != nil { return err } n.watchLeaderChange() return n.runRaftMessages() } func (n *Node) runTransport() error { logger := logger.Get() idString := fmt.Sprintf("%d", n.config.ID) transport := &rafthttp.Transport{ ID: types.ID(n.config.ID), Logger: logger, ClusterID: 0x6666, Raft: n, LeaderStats: stats.NewLeaderStats(logger, idString), ServerStats: stats.NewServerStats("raft", idString), ErrorC: make(chan error), } if err := transport.Start(); err != nil { return fmt.Errorf("unable to start transport: %w", err) } for i, peer := range n.config.Peers { // Don't add self to transport if uint64(i+1) != n.config.ID { transport.AddPeer(types.ID(i+1), []string{peer}) } n.peers.Store(uint64(i+1), peer) } n.addr = n.config.Peers[n.config.ID-1] url, err := url.Parse(n.addr) if err != nil { return err } httpServer := &http.Server{ Addr: url.Host, Handler: transport.Handler(), } n.wg.Add(1) go func() { defer n.wg.Done() if err := httpServer.ListenAndServe(); err != nil && !errors.Is(err, http.ErrServerClosed) { n.logger.Fatal("Unable to start http server", zap.Error(err)) os.Exit(1) } }() n.transport = transport n.httpServer = httpServer return nil } func (n *Node) watchLeaderChange() { n.wg.Add(1) go func() { defer n.wg.Done() ticker := time.NewTicker(time.Second) defer ticker.Stop() for { select { case <-n.shutdown: return case <-ticker.C: lead := n.GetRaftLead() if lead != n.leader { n.leader = lead select { case <-n.shutdown: return case n.leaderChanged <- true: } n.logger.Info("Found leader changed", zap.Uint64("leader", lead)) } } } }() } func (n *Node) runRaftMessages() error { n.wg.Add(1) go func() { ticker := time.NewTicker(time.Second) defer func() { ticker.Stop() n.wg.Done() }() for { select { case <-ticker.C: n.raftNode.Tick() case rd := <-n.raftNode.Ready(): // Save to wal and storage first if !raft.IsEmptySnap(rd.Snapshot) { if err := n.dataStore.saveSnapshot(rd.Snapshot); err != nil { n.logger.Error("Failed to save snapshot", zap.Error(err)) } } if err := n.dataStore.wal.Save(rd.HardState, rd.Entries); err != nil { n.logger.Error("Failed to save to wal", zap.Error(err)) } // Replay the entries into the raft storage if err := n.applySnapshot(rd.Snapshot); err != nil { n.logger.Error("Failed to apply snapshot", zap.Error(err)) } _ = n.dataStore.raftStorage.Append(rd.Entries) for _, msg := range rd.Messages { if msg.Type == raftpb.MsgApp { msg.Snapshot.Metadata.ConfState = n.confState } } n.transport.Send(rd.Messages) // Apply the committed entries to the state machine n.applyEntries(rd.CommittedEntries) if err := n.triggerSnapshotIfNeed(); err != nil { n.logger.Error("Failed to trigger snapshot", zap.Error(err)) } n.raftNode.Advance() case err := <-n.transport.ErrorC: n.logger.Fatal("Found transport error", zap.Error(err)) return case <-n.shutdown: n.logger.Info("Shutting down raft node") return } } }() return nil } func (n *Node) triggerSnapshotIfNeed() error { if n.appliedIndex-n.snapshotIndex <= n.snapshotThreshold.Load() { return nil } snapshotBytes, err := n.dataStore.GetDataStoreSnapshot() if err != nil { return err } snap, err := n.dataStore.raftStorage.CreateSnapshot(n.appliedIndex, &n.confState, snapshotBytes) if err != nil { return err } if err := n.dataStore.saveSnapshot(snap); err != nil { return err } compactIndex := uint64(1) if n.appliedIndex > n.compactThreshold.Load() { compactIndex = n.appliedIndex - n.compactThreshold.Load() } if err := n.dataStore.raftStorage.Compact(compactIndex); err != nil && !errors.Is(err, raft.ErrCompacted) { return err } n.snapshotIndex = n.appliedIndex return nil } func (n *Node) Set(ctx context.Context, key string, value []byte) error { bytes, err := json.Marshal(&Event{ Op: opSet, Key: key, Value: value, }) if err != nil { return err } return n.raftNode.Propose(ctx, bytes) } func (n *Node) AddPeer(ctx context.Context, nodeID uint64, peer string) error { cc := raftpb.ConfChange{ Type: raftpb.ConfChangeAddNode, NodeID: nodeID, Context: []byte(peer), } return n.raftNode.ProposeConfChange(ctx, cc) } func (n *Node) RemovePeer(ctx context.Context, nodeID uint64) error { cc := raftpb.ConfChange{ Type: raftpb.ConfChangeRemoveNode, NodeID: nodeID, } return n.raftNode.ProposeConfChange(ctx, cc) } func (n *Node) ID() string { return fmt.Sprintf("%d", n.config.ID) } func (n *Node) Leader() string { return fmt.Sprintf("%d", n.GetRaftLead()) } func (n *Node) GetRaftLead() uint64 { return n.raftNode.Status().Lead } func (n *Node) IsReady(ctx context.Context) bool { tries := 0 for { select { case <-n.shutdown: return false case <-time.After(200 * time.Millisecond): // wait for the leader to be elected if n.GetRaftLead() != raft.None { return true } tries++ if tries >= 10 { // waiting too long, just return the running status n.logger.Warn("Leader not elected, return the running status") return n.isRunning.Load() } case <-ctx.Done(): return false } } } func (n *Node) LeaderChange() <-chan bool { return n.leaderChanged } func (n *Node) Get(_ context.Context, key string) ([]byte, error) { return n.dataStore.Get(key) } func (n *Node) Exists(_ context.Context, key string) (bool, error) { _, err := n.dataStore.Get(key) if err != nil { if errors.Is(err, ErrKeyNotFound) { return false, nil } return false, err } return true, nil } func (n *Node) Delete(ctx context.Context, key string) error { bytes, err := json.Marshal(&Event{ Op: opDelete, Key: key, }) if err != nil { return err } return n.raftNode.Propose(ctx, bytes) } func (n *Node) List(_ context.Context, prefix string) ([]engine.Entry, error) { return n.dataStore.List(prefix), nil } func (n *Node) applySnapshot(snapshot raftpb.Snapshot) error { if raft.IsEmptySnap(snapshot) { return nil } _ = n.dataStore.raftStorage.ApplySnapshot(snapshot) if n.appliedIndex >= snapshot.Metadata.Index { return fmt.Errorf("snapshot index [%d] should be greater than applied index [%d]", snapshot.Metadata.Index, n.appliedIndex) } // Load the snapshot into the key-value store. if err := n.dataStore.reloadSnapshot(); err != nil { return err } n.confState = snapshot.Metadata.ConfState n.appliedIndex = snapshot.Metadata.Index n.snapshotIndex = snapshot.Metadata.Index return nil } func (n *Node) applyEntries(entries []raftpb.Entry) { if len(entries) == 0 || entries[0].Index > n.appliedIndex+1 { return } firstEntryIndex := entries[0].Index // remove entries that have been applied if n.appliedIndex-firstEntryIndex+1 < uint64(len(entries)) { entries = entries[n.appliedIndex-firstEntryIndex+1:] } for _, entry := range entries { if err := n.applyEntry(entry); err != nil { n.logger.Error("failed to apply entry", zap.Error(err)) } } n.appliedIndex = entries[len(entries)-1].Index } func (n *Node) applyEntry(entry raftpb.Entry) error { switch entry.Type { case raftpb.EntryNormal: return n.dataStore.applyDataEntry(entry) case raftpb.EntryConfChangeV2, raftpb.EntryConfChange: // apply config change to the state machine var cc raftpb.ConfChange if err := cc.Unmarshal(entry.Data); err != nil { return err } n.confState = *n.raftNode.ApplyConfChange(cc) switch cc.Type { case raftpb.ConfChangeAddNode: if cc.NodeID != n.config.ID && len(cc.Context) > 0 { n.logger.Info("Add the new peer", zap.String("context", string(cc.Context))) n.transport.AddPeer(types.ID(cc.NodeID), []string{string(cc.Context)}) n.peers.Store(cc.NodeID, string(cc.Context)) } case raftpb.ConfChangeRemoveNode: if cc.NodeID == n.config.ID { n.logger.Info("I have been removed from the cluster, will shutdown") n.Close() return nil } n.transport.RemovePeer(types.ID(cc.NodeID)) n.peers.Delete(cc.NodeID) n.logger.Info("Remove the peer", zap.Uint64("node_id", cc.NodeID)) case raftpb.ConfChangeUpdateNode: n.transport.UpdatePeer(types.ID(cc.NodeID), []string{string(cc.Context)}) if _, ok := n.peers.Load(cc.NodeID); ok { n.peers.Store(cc.NodeID, string(cc.Context)) } case raftpb.ConfChangeAddLearnerNode: // TODO: add the learner node } } return nil } func (n *Node) Process(ctx context.Context, m raftpb.Message) error { return n.raftNode.Step(ctx, m) } func (n *Node) IsIDRemoved(_ uint64) bool { return false } func (n *Node) ReportUnreachable(id uint64) { n.raftNode.ReportUnreachable(id) } func (n *Node) ReportSnapshot(id uint64, status raft.SnapshotStatus) { n.raftNode.ReportSnapshot(id, status) } func (n *Node) Close() error { if !n.isRunning.CompareAndSwap(true, false) { return nil } close(n.shutdown) n.raftNode.Stop() n.transport.Stop() if err := n.httpServer.Close(); err != nil { return err } n.dataStore.Close() n.wg.Wait() return nil }