From 0695d2386645f5710547ef5a06c33fc202d1d479 Mon Sep 17 00:00:00 2001 From: ChronosX88 Date: Thu, 3 Jun 2021 00:19:52 +0300 Subject: [PATCH] Refactor pubsub package --- consensus/consensus.go | 101 +++++++++++++++++++++++-------- consensus/consensus_validator.go | 57 +++++++++-------- consensus/msg_log.go | 16 ++--- consensus/types/message.go | 8 +-- node/node.go | 2 +- node/node_dep_providers.go | 4 +- pubsub/handler.go | 7 --- pubsub/message.go | 20 ++++++ pubsub/pubsub_router.go | 18 +++--- 9 files changed, 145 insertions(+), 88 deletions(-) delete mode 100644 pubsub/handler.go create mode 100644 pubsub/message.go diff --git a/consensus/consensus.go b/consensus/consensus.go index d63e0f6..8079a5f 100644 --- a/consensus/consensus.go +++ b/consensus/consensus.go @@ -4,6 +4,8 @@ import ( "math/big" "sync" + "github.com/fxamacker/cbor/v2" + "github.com/Secured-Finance/dione/cache" "github.com/Secured-Finance/dione/consensus/types" @@ -45,9 +47,9 @@ func NewPBFTConsensusManager(psb *pubsub.PubSubRouter, minApprovals int, privKey pcm.ethereumClient = ethereumClient pcm.cache = evc pcm.consensusMap = map[string]*Consensus{} - pcm.psb.Hook(types.MessageTypePrePrepare, pcm.handlePrePrepare) - pcm.psb.Hook(types.MessageTypePrepare, pcm.handlePrepare) - pcm.psb.Hook(types.MessageTypeCommit, pcm.handleCommit) + pcm.psb.Hook(pubsub.PrePrepareMessageType, pcm.handlePrePrepare) + pcm.psb.Hook(pubsub.PrepareMessageType, pcm.handlePrepare) + pcm.psb.Hook(pubsub.CommitMessageType, pcm.handleCommit) return pcm } @@ -62,44 +64,54 @@ func (pcm *PBFTConsensusManager) Propose(task types2.DioneTask) error { return nil } -func (pcm *PBFTConsensusManager) handlePrePrepare(message *types.Message) { - if message.Payload.Task.Miner == pcm.miner.address { +func (pcm *PBFTConsensusManager) handlePrePrepare(message *pubsub.PubSubMessage) { + cmsg, err := unmarshalPayload(message) + if err != nil { return } - if pcm.msgLog.Exists(*message) { + + if cmsg.Task.Miner == pcm.miner.address { + return + } + if pcm.msgLog.Exists(cmsg) { logrus.Debugf("received existing pre_prepare msg, dropping...") return } - if !pcm.validator.Valid(*message) { + if !pcm.validator.Valid(cmsg) { logrus.Warn("received invalid pre_prepare msg, dropping...") return } - pcm.msgLog.AddMessage(*message) + pcm.msgLog.AddMessage(cmsg) - prepareMsg, err := NewMessage(message, types.MessageTypePrepare) + prepareMsg, err := NewMessage(message, pubsub.PrepareMessageType) if err != nil { logrus.Errorf("failed to create prepare message: %v", err) } - pcm.createConsensusInfo(&message.Payload.Task, false) + pcm.createConsensusInfo(&cmsg.Task, false) pcm.psb.BroadcastToServiceTopic(&prepareMsg) } -func (pcm *PBFTConsensusManager) handlePrepare(message *types.Message) { - if pcm.msgLog.Exists(*message) { +func (pcm *PBFTConsensusManager) handlePrepare(message *pubsub.PubSubMessage) { + cmsg, err := unmarshalPayload(message) + if err != nil { + return + } + + if pcm.msgLog.Exists(cmsg) { logrus.Debugf("received existing prepare msg, dropping...") return } - if !pcm.validator.Valid(*message) { + if !pcm.validator.Valid(cmsg) { logrus.Warn("received invalid prepare msg, dropping...") return } - pcm.msgLog.AddMessage(*message) + pcm.msgLog.AddMessage(cmsg) - if len(pcm.msgLog.GetMessagesByTypeAndConsensusID(types.MessageTypePrepare, message.Payload.Task.ConsensusID)) >= pcm.minApprovals { + if len(pcm.msgLog.Get(types.MessageTypePrepare, cmsg.Task.ConsensusID)) >= pcm.minApprovals { commitMsg, err := NewMessage(message, types.MessageTypeCommit) if err != nil { logrus.Errorf("failed to create commit message: %w", err) @@ -108,21 +120,25 @@ func (pcm *PBFTConsensusManager) handlePrepare(message *types.Message) { } } -func (pcm *PBFTConsensusManager) handleCommit(message *types.Message) { - if pcm.msgLog.Exists(*message) { +func (pcm *PBFTConsensusManager) handleCommit(message *pubsub.PubSubMessage) { + cmsg, err := unmarshalPayload(message) + if err != nil { + return + } + + if pcm.msgLog.Exists(cmsg) { logrus.Debugf("received existing commit msg, dropping...") return } - if !pcm.validator.Valid(*message) { + if !pcm.validator.Valid(cmsg) { logrus.Warn("received invalid commit msg, dropping...") return } - pcm.msgLog.AddMessage(*message) + pcm.msgLog.AddMessage(cmsg) - consensusMsg := message.Payload - if len(pcm.msgLog.GetMessagesByTypeAndConsensusID(types.MessageTypeCommit, message.Payload.Task.ConsensusID)) >= pcm.minApprovals { - info := pcm.GetConsensusInfo(consensusMsg.Task.ConsensusID) + if len(pcm.msgLog.Get(types.MessageTypeCommit, cmsg.Task.ConsensusID)) >= pcm.minApprovals { + info := pcm.GetConsensusInfo(cmsg.Task.ConsensusID) if info == nil { logrus.Debugf("consensus doesn't exist in our consensus map - skipping...") return @@ -133,13 +149,13 @@ func (pcm *PBFTConsensusManager) handleCommit(message *types.Message) { return } if info.IsCurrentMinerLeader { - logrus.Infof("Submitting on-chain result for consensus ID: %s", consensusMsg.Task.ConsensusID) - reqID, ok := new(big.Int).SetString(consensusMsg.Task.RequestID, 10) + logrus.Infof("Submitting on-chain result for consensus ID: %s", cmsg.Task.ConsensusID) + reqID, ok := new(big.Int).SetString(cmsg.Task.RequestID, 10) if !ok { - logrus.Errorf("Failed to parse request ID: %v", consensusMsg.Task.RequestID) + logrus.Errorf("Failed to parse request ID: %v", cmsg.Task.RequestID) } - err := pcm.ethereumClient.SubmitRequestAnswer(reqID, consensusMsg.Task.Payload) + err := pcm.ethereumClient.SubmitRequestAnswer(reqID, cmsg.Task.Payload) if err != nil { logrus.Errorf("Failed to submit on-chain result: %v", err) } @@ -167,3 +183,36 @@ func (pcm *PBFTConsensusManager) GetConsensusInfo(consensusID string) *Consensus return c } + +func unmarshalPayload(msg *pubsub.PubSubMessage) (types.ConsensusMessage, error) { + var task types2.DioneTask + err := cbor.Unmarshal(msg.Payload, &task) + if err != nil { + logrus.Debug(err) + return types.ConsensusMessage{}, err + } + var consensusMessageType types.MessageType + switch msg.Type { + case pubsub.PrePrepareMessageType: + { + consensusMessageType = types.MessageTypePrePrepare + break + } + case pubsub.PrepareMessageType: + { + consensusMessageType = types.MessageTypePrepare + break + } + case pubsub.CommitMessageType: + { + consensusMessageType = types.MessageTypeCommit + break + } + } + cmsg := types.ConsensusMessage{ + Type: consensusMessageType, + From: msg.From, + Task: task, + } + return cmsg, nil +} diff --git a/consensus/consensus_validator.go b/consensus/consensus_validator.go index 4c37be8..b6d16d5 100644 --- a/consensus/consensus_validator.go +++ b/consensus/consensus_validator.go @@ -12,7 +12,7 @@ import ( ) type ConsensusValidator struct { - validationFuncMap map[types2.MessageType]func(msg types2.Message) bool + validationFuncMap map[types2.MessageType]func(msg types2.ConsensusMessage) bool cache cache.Cache miner *Miner } @@ -23,13 +23,12 @@ func NewConsensusValidator(ec cache.Cache, miner *Miner) *ConsensusValidator { miner: miner, } - cv.validationFuncMap = map[types2.MessageType]func(msg types2.Message) bool{ - types2.MessageTypePrePrepare: func(msg types2.Message) bool { + cv.validationFuncMap = map[types2.MessageType]func(msg types2.ConsensusMessage) bool{ + types2.MessageTypePrePrepare: func(msg types2.ConsensusMessage) bool { // TODO here we need to do validation of tx itself - consensusMsg := msg.Payload // === verify task signature === - err := VerifyTaskSignature(consensusMsg.Task) + err := VerifyTaskSignature(msg.Task) if err != nil { logrus.Errorf("unable to verify signature: %v", err) return false @@ -38,15 +37,15 @@ func NewConsensusValidator(ec cache.Cache, miner *Miner) *ConsensusValidator { // === verify if request exists in cache === var requestEvent *dioneOracle.DioneOracleNewOracleRequest - err = cv.cache.Get("request_"+consensusMsg.Task.RequestID, &requestEvent) + err = cv.cache.Get("request_"+msg.Task.RequestID, &requestEvent) if err != nil { logrus.Errorf("the request doesn't exist in the cache or has been failed to decode: %v", err) return false } - if requestEvent.OriginChain != consensusMsg.Task.OriginChain || - requestEvent.RequestType != consensusMsg.Task.RequestType || - requestEvent.RequestParams != consensusMsg.Task.RequestParams { + if requestEvent.OriginChain != msg.Task.OriginChain || + requestEvent.RequestType != msg.Task.RequestType || + requestEvent.RequestParams != msg.Task.RequestParams { logrus.Errorf("the incoming task and cached request requestEvent don't match!") return false @@ -54,14 +53,14 @@ func NewConsensusValidator(ec cache.Cache, miner *Miner) *ConsensusValidator { ///////////////////////////////// // === verify election proof wincount preliminarily === - if consensusMsg.Task.ElectionProof.WinCount < 1 { + if msg.Task.ElectionProof.WinCount < 1 { logrus.Error("miner isn't a winner!") return false } ///////////////////////////////// // === verify miner's eligibility to propose this task === - err = cv.miner.IsMinerEligibleToProposeTask(common.HexToAddress(consensusMsg.Task.MinerEth)) + err = cv.miner.IsMinerEligibleToProposeTask(common.HexToAddress(msg.Task.MinerEth)) if err != nil { logrus.Errorf("miner is not eligible to propose task: %v", err) return false @@ -69,22 +68,22 @@ func NewConsensusValidator(ec cache.Cache, miner *Miner) *ConsensusValidator { ///////////////////////////////// // === verify election proof vrf === - minerAddressMarshalled, err := consensusMsg.Task.Miner.MarshalBinary() + minerAddressMarshalled, err := msg.Task.Miner.MarshalBinary() if err != nil { logrus.Errorf("failed to marshal miner address: %v", err) return false } electionProofRandomness, err := DrawRandomness( - consensusMsg.Task.BeaconEntries[1].Data, + msg.Task.BeaconEntries[1].Data, crypto.DomainSeparationTag_ElectionProofProduction, - consensusMsg.Task.DrandRound, + msg.Task.DrandRound, minerAddressMarshalled, ) if err != nil { logrus.Errorf("failed to draw electionProofRandomness: %v", err) return false } - err = VerifyVRF(consensusMsg.Task.Miner, electionProofRandomness, consensusMsg.Task.ElectionProof.VRFProof) + err = VerifyVRF(msg.Task.Miner, electionProofRandomness, msg.Task.ElectionProof.VRFProof) if err != nil { logrus.Errorf("failed to verify election proof vrf: %v", err) } @@ -92,9 +91,9 @@ func NewConsensusValidator(ec cache.Cache, miner *Miner) *ConsensusValidator { // === verify ticket vrf === ticketRandomness, err := DrawRandomness( - consensusMsg.Task.BeaconEntries[1].Data, + msg.Task.BeaconEntries[1].Data, crypto.DomainSeparationTag_TicketProduction, - consensusMsg.Task.DrandRound-types.TicketRandomnessLookback, + msg.Task.DrandRound-types.TicketRandomnessLookback, minerAddressMarshalled, ) if err != nil { @@ -102,48 +101,48 @@ func NewConsensusValidator(ec cache.Cache, miner *Miner) *ConsensusValidator { return false } - err = VerifyVRF(consensusMsg.Task.Miner, ticketRandomness, consensusMsg.Task.Ticket.VRFProof) + err = VerifyVRF(msg.Task.Miner, ticketRandomness, msg.Task.Ticket.VRFProof) if err != nil { logrus.Errorf("failed to verify ticket vrf: %v", err) } ////////////////////////////////////// // === compute wincount locally and verify values === - mStake, nStake, err := cv.miner.GetStakeInfo(common.HexToAddress(consensusMsg.Task.MinerEth)) + mStake, nStake, err := cv.miner.GetStakeInfo(common.HexToAddress(msg.Task.MinerEth)) if err != nil { logrus.Errorf("failed to get miner stake: %v", err) return false } - actualWinCount := consensusMsg.Task.ElectionProof.ComputeWinCount(*mStake, *nStake) - if consensusMsg.Task.ElectionProof.WinCount != actualWinCount { + actualWinCount := msg.Task.ElectionProof.ComputeWinCount(*mStake, *nStake) + if msg.Task.ElectionProof.WinCount != actualWinCount { logrus.Errorf("locally computed wincount isn't matching received value!", err) return false } ////////////////////////////////////// // === validate payload by specific-chain checks === - if validationFunc := validation.GetValidationMethod(consensusMsg.Task.OriginChain, consensusMsg.Task.RequestType); validationFunc != nil { - err := validationFunc(consensusMsg.Task.Payload) + if validationFunc := validation.GetValidationMethod(msg.Task.OriginChain, msg.Task.RequestType); validationFunc != nil { + err := validationFunc(msg.Task.Payload) if err != nil { logrus.Errorf("payload validation has failed: %v", err) return false } } else { - logrus.Debugf("Origin chain [%v]/request type[%v] doesn't have any payload validation!", consensusMsg.Task.OriginChain, consensusMsg.Task.RequestType) + logrus.Debugf("Origin chain [%v]/request type[%v] doesn't have any payload validation!", msg.Task.OriginChain, msg.Task.RequestType) } ///////////////////////////////// return true }, - types2.MessageTypePrepare: func(msg types2.Message) bool { - err := VerifyTaskSignature(msg.Payload.Task) + types2.MessageTypePrepare: func(msg types2.ConsensusMessage) bool { + err := VerifyTaskSignature(msg.Task) if err != nil { return false } return true }, - types2.MessageTypeCommit: func(msg types2.Message) bool { - err := VerifyTaskSignature(msg.Payload.Task) + types2.MessageTypeCommit: func(msg types2.ConsensusMessage) bool { + err := VerifyTaskSignature(msg.Task) if err != nil { return false } @@ -154,6 +153,6 @@ func NewConsensusValidator(ec cache.Cache, miner *Miner) *ConsensusValidator { return cv } -func (cv *ConsensusValidator) Valid(msg types2.Message) bool { +func (cv *ConsensusValidator) Valid(msg types2.ConsensusMessage) bool { return cv.validationFuncMap[msg.Type](msg) } diff --git a/consensus/msg_log.go b/consensus/msg_log.go index 0bd8011..3f0bac7 100644 --- a/consensus/msg_log.go +++ b/consensus/msg_log.go @@ -8,7 +8,7 @@ import ( type MessageLog struct { messages mapset.Set maxLogSize int - validationFuncMap map[types2.MessageType]func(message types2.Message) + validationFuncMap map[types2.MessageType]func(message types2.ConsensusMessage) } func NewMessageLog() *MessageLog { @@ -20,21 +20,21 @@ func NewMessageLog() *MessageLog { return msgLog } -func (ml *MessageLog) AddMessage(msg types2.Message) { +func (ml *MessageLog) AddMessage(msg types2.ConsensusMessage) { ml.messages.Add(msg) } -func (ml *MessageLog) Exists(msg types2.Message) bool { +func (ml *MessageLog) Exists(msg types2.ConsensusMessage) bool { return ml.messages.Contains(msg) } -func (ml *MessageLog) GetMessagesByTypeAndConsensusID(typ types2.MessageType, consensusID string) []types2.Message { - var result []types2.Message +func (ml *MessageLog) Get(typ types2.MessageType, consensusID string) []*types2.ConsensusMessage { + var result []*types2.ConsensusMessage for v := range ml.messages.Iter() { - msg := v.(types2.Message) - if msg.Type == typ && msg.Payload.Task.ConsensusID == consensusID { - result = append(result, msg) + msg := v.(types2.ConsensusMessage) + if msg.Type == typ && msg.Task.ConsensusID == consensusID { + result = append(result, &msg) } } diff --git a/consensus/types/message.go b/consensus/types/message.go index b2995e9..f603839 100644 --- a/consensus/types/message.go +++ b/consensus/types/message.go @@ -17,10 +17,6 @@ const ( type ConsensusMessage struct { Task types.DioneTask -} - -type Message struct { - Type MessageType - Payload ConsensusMessage - From peer.ID `cbor:"-"` + From peer.ID + Type MessageType } diff --git a/node/node.go b/node/node.go index 8f1ae21..9675405 100644 --- a/node/node.go +++ b/node/node.go @@ -167,7 +167,7 @@ func NewNode(config *config.Config, prvKey crypto.PrivKey, pexDiscoveryUpdateTim r := provideP2PRPCClient(lhost) // initialize sync manager - sm, err := provideSyncManager(bp, mp, r, baddrs[0]) // FIXME here we just pick up first bootstrap in list + sm, err := provideSyncManager(bp, mp, r, baddrs[0], psb) // FIXME here we just pick up first bootstrap in list if err != nil { logrus.Fatal(err) } diff --git a/node/node_dep_providers.go b/node/node_dep_providers.go index f52b55c..6107361 100644 --- a/node/node_dep_providers.go +++ b/node/node_dep_providers.go @@ -159,12 +159,12 @@ func provideMemPool() (*pool.Mempool, error) { return pool.NewMempool() } -func provideSyncManager(bp *pool.BlockPool, mp *pool.Mempool, r *gorpc.Client, bootstrap multiaddr.Multiaddr) (sync.SyncManager, error) { +func provideSyncManager(bp *pool.BlockPool, mp *pool.Mempool, r *gorpc.Client, bootstrap multiaddr.Multiaddr, psb *pubsub.PubSubRouter) (sync.SyncManager, error) { addr, err := peer.AddrInfoFromP2pAddr(bootstrap) if err != nil { return nil, err } - return sync.NewSyncManager(bp, mp, r, addr.ID), nil + return sync.NewSyncManager(bp, mp, r, addr.ID, psb), nil } func provideP2PRPCClient(h host.Host) *gorpc.Client { diff --git a/pubsub/handler.go b/pubsub/handler.go deleted file mode 100644 index d53ef51..0000000 --- a/pubsub/handler.go +++ /dev/null @@ -1,7 +0,0 @@ -package pubsub - -import ( - "github.com/Secured-Finance/dione/consensus/types" -) - -type Handler func(message *types.Message) diff --git a/pubsub/message.go b/pubsub/message.go new file mode 100644 index 0000000..5537d51 --- /dev/null +++ b/pubsub/message.go @@ -0,0 +1,20 @@ +package pubsub + +import "github.com/libp2p/go-libp2p-core/peer" + +type PubSubMessageType int + +const ( + UnknownMessageType = iota + PrePrepareMessageType + PrepareMessageType + CommitMessageType + NewTxMessageType + NewBlockMessageType +) + +type PubSubMessage struct { + Type PubSubMessageType + From peer.ID `cbor:"-"` + Payload []byte +} diff --git a/pubsub/pubsub_router.go b/pubsub/pubsub_router.go index 3d1e7f3..bd39edb 100644 --- a/pubsub/pubsub_router.go +++ b/pubsub/pubsub_router.go @@ -6,10 +6,8 @@ import ( "github.com/fxamacker/cbor/v2" - "github.com/Secured-Finance/dione/consensus/types" - - host "github.com/libp2p/go-libp2p-core/host" - peer "github.com/libp2p/go-libp2p-core/peer" + "github.com/libp2p/go-libp2p-core/host" + "github.com/libp2p/go-libp2p-core/peer" pubsub "github.com/libp2p/go-libp2p-pubsub" "github.com/sirupsen/logrus" ) @@ -20,11 +18,13 @@ type PubSubRouter struct { context context.Context contextCancel context.CancelFunc serviceSubscription *pubsub.Subscription - handlers map[types.MessageType][]Handler + handlers map[PubSubMessageType][]Handler oracleTopicName string oracleTopic *pubsub.Topic } +type Handler func(message *PubSubMessage) + func NewPubSubRouter(h host.Host, oracleTopic string, isBootstrap bool) *PubSubRouter { ctx, ctxCancel := context.WithCancel(context.Background()) @@ -32,7 +32,7 @@ func NewPubSubRouter(h host.Host, oracleTopic string, isBootstrap bool) *PubSubR node: h, context: ctx, contextCancel: ctxCancel, - handlers: make(map[types.MessageType][]Handler), + handlers: make(map[PubSubMessageType][]Handler), } var pbOptions []pubsub.Option @@ -102,7 +102,7 @@ func (psr *PubSubRouter) handleMessage(p *pubsub.Message) { if senderPeerID == psr.node.ID() { return } - var message types.Message + var message PubSubMessage err = cbor.Unmarshal(p.Data, &message) if err != nil { logrus.Warn("Unable to decode message data! " + err.Error()) @@ -119,7 +119,7 @@ func (psr *PubSubRouter) handleMessage(p *pubsub.Message) { } } -func (psr *PubSubRouter) Hook(messageType types.MessageType, handler Handler) { +func (psr *PubSubRouter) Hook(messageType PubSubMessageType, handler Handler) { _, ok := psr.handlers[messageType] if !ok { psr.handlers[messageType] = []Handler{} @@ -127,7 +127,7 @@ func (psr *PubSubRouter) Hook(messageType types.MessageType, handler Handler) { psr.handlers[messageType] = append(psr.handlers[messageType], handler) } -func (psr *PubSubRouter) BroadcastToServiceTopic(msg *types.Message) error { +func (psr *PubSubRouter) BroadcastToServiceTopic(msg *PubSubMessage) error { data, err := cbor.Marshal(msg) if err != nil { return err