diff --git a/throttle.go b/throttle.go index 0e31938..86f2434 100644 --- a/throttle.go +++ b/throttle.go @@ -46,22 +46,31 @@ func (t *Throttle) run() { } case r:=<-t.queueChan: { for { - if t.config.NumTokens > 0 || t.config.RefillRate == 0 { - r.resolveChan <- rxgo.Of(true) - close(r.resolveChan) + var hasEnoughTokens bool + if t.config.Capacity != 0 { + hasEnoughTokens = t.config.NumTokens > 0 + } else { + hasEnoughTokens = t.config.NumTokens >= 0 + } + if hasEnoughTokens { var cost float64 = 0 if r.cost != -1 { cost = r.cost } else { cost = t.config.DefaultCost } - t.config.NumTokens -= cost - break + if t.config.NumTokens >= math.Min(cost, t.config.Capacity) { + t.config.NumTokens -= cost + r.resolveChan <- rxgo.Of(true) + close(r.resolveChan) + break + } } + if t.lastTimestamp == -1 { - t.lastTimestamp = time.Now().Unix() + t.lastTimestamp = time.Now().UnixNano() } - now := time.Now().Unix() + now := time.Now().UnixNano() elapsed := now - t.lastTimestamp t.lastTimestamp = now t.config.NumTokens = math.Min(t.config.Capacity, t.config.NumTokens+float64(elapsed)*t.config.RefillRate) @@ -74,7 +83,7 @@ func (t *Throttle) run() { } -func (t *Throttle) Take(rateLimit float64, cost float64) rxgo.Observable { +func (t *Throttle) Take(rateLimit time.Duration, cost float64) rxgo.Observable { t.cfgMut.Lock() defer t.cfgMut.Unlock() ch := make(chan rxgo.Item) @@ -85,7 +94,7 @@ func (t *Throttle) Take(rateLimit float64, cost float64) rxgo.Observable { if rateLimit == 0 { t.config.RefillRate = 0 } else { - t.config.RefillRate = 1/rateLimit + t.config.RefillRate = float64(1 / rateLimit.Nanoseconds()) } return rxgo.FromChannel(ch) }