Add missing checks, fix rateLimit parameter format to time.Duration

This commit is contained in:
ChronosX88 2021-01-25 17:26:51 +03:00
parent dee1df2117
commit 7cc2c6e929
Signed by: ChronosXYZ
GPG Key ID: 085A69A82C8C511A

View File

@ -46,22 +46,31 @@ func (t *Throttle) run() {
} }
case r:=<-t.queueChan: { case r:=<-t.queueChan: {
for { for {
if t.config.NumTokens > 0 || t.config.RefillRate == 0 { var hasEnoughTokens bool
r.resolveChan <- rxgo.Of(true) if t.config.Capacity != 0 {
close(r.resolveChan) hasEnoughTokens = t.config.NumTokens > 0
} else {
hasEnoughTokens = t.config.NumTokens >= 0
}
if hasEnoughTokens {
var cost float64 = 0 var cost float64 = 0
if r.cost != -1 { if r.cost != -1 {
cost = r.cost cost = r.cost
} else { } else {
cost = t.config.DefaultCost cost = t.config.DefaultCost
} }
t.config.NumTokens -= cost if t.config.NumTokens >= math.Min(cost, t.config.Capacity) {
break t.config.NumTokens -= cost
r.resolveChan <- rxgo.Of(true)
close(r.resolveChan)
break
}
} }
if t.lastTimestamp == -1 { if t.lastTimestamp == -1 {
t.lastTimestamp = time.Now().Unix() t.lastTimestamp = time.Now().UnixNano()
} }
now := time.Now().Unix() now := time.Now().UnixNano()
elapsed := now - t.lastTimestamp elapsed := now - t.lastTimestamp
t.lastTimestamp = now t.lastTimestamp = now
t.config.NumTokens = math.Min(t.config.Capacity, t.config.NumTokens+float64(elapsed)*t.config.RefillRate) t.config.NumTokens = math.Min(t.config.Capacity, t.config.NumTokens+float64(elapsed)*t.config.RefillRate)
@ -74,7 +83,7 @@ func (t *Throttle) run() {
} }
func (t *Throttle) Take(rateLimit float64, cost float64) rxgo.Observable { func (t *Throttle) Take(rateLimit time.Duration, cost float64) rxgo.Observable {
t.cfgMut.Lock() t.cfgMut.Lock()
defer t.cfgMut.Unlock() defer t.cfgMut.Unlock()
ch := make(chan rxgo.Item) ch := make(chan rxgo.Item)
@ -85,7 +94,7 @@ func (t *Throttle) Take(rateLimit float64, cost float64) rxgo.Observable {
if rateLimit == 0 { if rateLimit == 0 {
t.config.RefillRate = 0 t.config.RefillRate = 0
} else { } else {
t.config.RefillRate = 1/rateLimit t.config.RefillRate = float64(1 / rateLimit.Nanoseconds())
} }
return rxgo.FromChannel(ch) return rxgo.FromChannel(ch)
} }