diff --git a/beacon/drand/drand.go b/beacon/drand/drand.go index c1181ca..4dec731 100644 --- a/beacon/drand/drand.go +++ b/beacon/drand/drand.go @@ -29,25 +29,14 @@ var log = logrus.WithFields(logrus.Fields{ "subsystem": "drand", }) -// DrandResponse structure representing response from drand network -type DrandResponse struct { - // PreviousSig is the previous signature generated - PreviousSig []byte - // Round is the round number this beacon is tied to - Round uint64 - // Signature is the BLS deterministic signature over Round || PreviousRand - Signature []byte - // Randomness for specific round generated by Drand - Randomness []byte -} - type DrandBeacon struct { DrandClient client.Client PublicKey kyber.Point drandResultChannel <-chan client.Result - cacheLock sync.Mutex - localCache map[uint64]types.BeaconEntry + cacheLock sync.Mutex + localCache map[uint64]types.BeaconEntry + latestDrandRound uint64 } func NewDrandBeacon(ps *pubsub.PubSub) (*DrandBeacon, error) { @@ -96,12 +85,26 @@ func NewDrandBeacon(ps *pubsub.PubSub) (*DrandBeacon, error) { db.PublicKey = drandChain.PublicKey db.drandResultChannel = db.DrandClient.Watch(context.TODO()) - + err = db.getLatestDrandResult() + if err != nil { + return nil, err + } go db.loop(context.TODO()) return db, nil } +func (db *DrandBeacon) getLatestDrandResult() error { + latestDround, err := db.DrandClient.Get(context.TODO(), 0) + if err != nil { + log.Errorf("failed to get latest drand round: %v", err) + return err + } + db.cacheValue(newBeaconResultFromDrandResult(latestDround)) + db.updateLatestDrandRound(latestDround.Round()) + return nil +} + func (db *DrandBeacon) loop(ctx context.Context) { for { select { @@ -111,7 +114,8 @@ func (db *DrandBeacon) loop(ctx context.Context) { } case res := <-db.drandResultChannel: { - db.cacheValue(types.NewBeaconEntry(res.Round(), res.Randomness(), map[string]interface{}{"signature": res.Signature()})) + db.cacheValue(newBeaconResultFromDrandResult(res)) + db.updateLatestDrandRound(res.Round()) } } } @@ -163,6 +167,12 @@ func (db *DrandBeacon) getCachedValue(round uint64) *types.BeaconEntry { return &v } +func (db *DrandBeacon) updateLatestDrandRound(round uint64) { + db.cacheLock.Lock() + defer db.cacheLock.Unlock() + db.latestDrandRound = round +} + func (db *DrandBeacon) VerifyEntry(curr, prev types.BeaconEntry) error { if prev.Round == 0 { return nil @@ -179,11 +189,13 @@ func (db *DrandBeacon) VerifyEntry(curr, prev types.BeaconEntry) error { } func (db *DrandBeacon) LatestBeaconRound() uint64 { - latestDround, err := db.DrandClient.Get(context.TODO(), 0) - if err != nil { - log.Errorf("failed to get latest drand round: %w", err) - } - return latestDround.Round() + db.cacheLock.Lock() + defer db.cacheLock.Unlock() + return db.latestDrandRound +} + +func newBeaconResultFromDrandResult(res client.Result) types.BeaconEntry { + return types.NewBeaconEntry(res.Round(), res.Randomness(), map[string]interface{}{"signature": res.Signature()}) } var _ beacon.BeaconAPI = (*DrandBeacon)(nil)