diff --git a/rate.go b/rate.go index 1cfc36e..5bb2029 100644 --- a/rate.go +++ b/rate.go @@ -3,6 +3,7 @@ package redis_rate import ( "context" "fmt" + "reflect" "strconv" "time" @@ -116,10 +117,20 @@ func (l Limiter) AllowN( return nil, err } + allowed, err := convertToInt(values[0]) + if err != nil { + return nil, err + } + + remaining, err := convertToInt(values[1]) + if err != nil { + return nil, err + } + res := &Result{ Limit: limit, - Allowed: int(values[0].(int64)), - Remaining: int(values[1].(int64)), + Allowed: allowed, + Remaining: remaining, RetryAfter: dur(retryAfter), ResetAfter: dur(resetAfter), } @@ -152,10 +163,20 @@ func (l Limiter) AllowAtMost( return nil, err } + allowed, err := convertToInt(values[0]) + if err != nil { + return nil, err + } + + remaining, err := convertToInt(values[1]) + if err != nil { + return nil, err + } + res := &Result{ Limit: limit, - Allowed: int(values[0].(int64)), - Remaining: int(values[1].(int64)), + Allowed: allowed, + Remaining: remaining, RetryAfter: dur(retryAfter), ResetAfter: dur(resetAfter), } @@ -174,6 +195,17 @@ func dur(f float64) time.Duration { return time.Duration(f * float64(time.Second)) } +func convertToInt(value interface{}) (int, error) { + switch v := value.(type) { + case int64: + return int(v), nil + case float64: + return int(v), nil + default: + return 0, fmt.Errorf("value type is not match. Type: %s", reflect.TypeOf(value).Name()) + } +} + type Result struct { // Limit is the limit that was used to obtain this result. Limit Limit