Redis实现限流器


目录

  1. 使用场景
  2. 三种实现方式

参考

使用场景

说到限流器,可能会想到Spring Cloudhystrix或者Spring Cloud Alibabasentinel。他们都是在系统架构上,对接口或者方法层面的整体限流组件。

但是如果考虑业务上自定义的限流规则,上面的限流组件可能粒度太大了,那怎么实现业务上的自定义的限流呢?

Redis就是一个很好的实现方法,场景的有三种实现限流的方式:

  • 基于Redis的setnx的操作
  • 基于Redis的数据结构zset
  • 基于Redis的令牌桶算法

三种实现方式

基于Redis的setnx的操作

我们在使用Redis的分布式锁的时候,大家都知道是依靠了setnx的指令,在CAS(Compare and swap)的操作的时候,同时给指定的key设置了过期实践(expire),我们在限流的主要目的就是为了在单位时间内,有且仅有N数量的请求能够访问我的代码程序。所以依靠setnx可以很轻松的做到这方面的功能。比如我们需要在10秒内限定20个请求,那么我们在setnx的时候可以设置过期时间10,当请求的setnx数量达到20时候即达到了限流效果。代码比较简单就不做展示了。当然这种做法的弊端是很多的,比如当统计1-10秒的时候,无法统计2-11秒之内,如果需要统计N秒内的M个请求,那么我们的Redis中需要保持N个key等等问题

基于Redis的数据结构zset

我们可以将请求打造成一个zset数组,当每一次请求进来的时候,value保持唯一,可以用UUID生成,而score可以用当前时间戳表示,因为score我们可以用来计算当前时间戳之内有多少的请求数量。而zset数据结构也提供了range方法让我们可以很轻易的获取到2个时间戳内有多少请求


/**
 * 计数器方法: 在10秒内,只有20个请求能访问我的代码
 * 实现方式: setnx key 10 value  每多一个请求,value就+1,,当value等于10时,不执行接下来的代码
 * 缺点: 能知道1-10s内有多少次请求,但是2-11s的请求数不知道
 * 时间滑动窗口
 * 下面介绍这种
 */
@AllArgsConstructor
public class RedisSimpleRateLimiter {

    private Jedis jedis;

    //无论返回什么,都表示用户已经点击过,通过这个方法,接下来的代码时要去判断该用户点击行为是否有效
    public boolean isActionAllowed(String userId, String actionKey, int period, int maxCount) {
        String key = userId + ":" + actionKey;
        long now = System.currentTimeMillis();
        //大量操作一次性执行
        Pipeline pipeline = jedis.pipelined();
        pipeline.multi();
        pipeline.zadd(key, now, "" + now);//key , score , value
        //删除now-period*1000秒之前的数据(也就是说保留periods秒之内的数据)
        // 看清楚了这是直接删除了
        pipeline.zremrangeByScore(key, 0, now - period * 1000);
        //查看在period秒内用户点击了多少次(无效点击+有效点击都算在period秒之内,所以如果用户一直连续不断点击,会一直提示无效点击)
        //用户必须停下来等待几秒才可以,不能一直连续点击,前端可以做个计数器,一直点,后面的就不发请求了,过几秒再发
        Response<Long> count = pipeline.zcard(key);//返回count
        //这里给key设置过期时间,是防止冷用户占用内存
        pipeline.expire(key, period + 1);
        pipeline.exec();
        pipeline.close();
        //看看redis里面存的消息数量是不是小区我们限流的
        return count.get() <= maxCount;
    }

    public static void main(String[] args) throws InterruptedException {
        Jedis jedis = new RedisUtil().getJedis();
        RedisSimpleRateLimiter redisSimpleRateLimiter = new RedisSimpleRateLimiter(jedis);
        for (int x = 0; x < 15; x++) {
            System.out.println(redisSimpleRateLimiter.isActionAllowed("cixingrui", "reply", 10, 2));
            Thread.sleep(1000l);
        }
    }

}

基于Redis的令牌桶算法

令牌桶算法是网络流量整形(Traffic Shaping)和速率限制(Rate Limiting)中最常使用的一种算法。典型情况下,令牌桶算法用来控制发送到网络上的数据的数目,并允许突发数据的发送。

大小固定的令牌桶可自行以恒定的速率源源不断地产生令牌。如果令牌不被消耗,或者被消耗的速度小于产生的速度,令牌就会不断地增多,直到把桶填满。后面再产生的令牌就会从桶中溢出。最后桶中可以保存的最大令牌数永远不会超过桶的大小。

传送到令牌桶的数据包需要消耗令牌。不同大小的数据包,消耗的令牌数量不一样。令牌桶这种控制机制基于令牌桶中是否存在令牌来指示什么时候可以发送流量。令牌桶中的每一个令牌都代表一个字节。如果令牌桶中存在令牌,则允许发送流量;而如果令牌桶中不存在令牌,则不允许发送流量。因此,如果突发门限被合理地配置并且令牌桶中有足够的令牌,那么流量就可以以峰值速率发送。

算法过程

算法描述:

  • 假如用户配置的平均发送速率为r,则每隔1/r秒一个令牌被加入到桶中(每秒会有r个令牌放入桶中);
  • 假设桶中最多可以存放b个令牌。如果令牌到达时令牌桶已经满了,那么这个令牌会被丢弃;
  • 当一个n个字节的数据包到达时,就从令牌桶中删除n个令牌(不同大小的数据包,消耗的令牌数量不一样),并且数据包被发送到网络;
  • 如果令牌桶中少于n个令牌,那么不会删除令牌,并且认为这个数据包在流量限制之外(n个字节,需要n个令牌。该数据包将被缓存或丢弃);
  • 算法允许最长b个字节的突发,但从长期运行结果看,数据包的速率被限制成常量r。对于在流量限制外的数据包可以以不同的方式处理:(1)它们可以被丢弃;(2)它们可以排放在队列中以便当令牌桶中累积了足够多的令牌时再传输;(3)它们可以继续发送,但需要做特殊标记,网络过载的时候将这些特殊标记的包丢弃。

实现

本限流器主要是基于令牌桶思想,并将令牌数存储到Redis中,实现在集群模式下对接口的精准限流,实现思想如下:

  • 对于每个限流接口,记录最大存储令牌数maxPermits, 当前存储令牌数storedPermits, 添加令牌时间间隔intervalMillis, 下次请求可以获取令牌的起始时间nextFreeTicketMillis,这些信息都记录在Redis中
  • 响应本次请求之后,动态计算下一次可以服务的时间,如果下一次请求在这个时间之前则需要进行等待。 nextFreeTicketMicros 记录下一次可以响应的时间。例如,如果我们设置QPS为1,本次请求处理完之后,那么下一次最早的能够响应请求的时间一秒钟之后。
  • 限流器支持处理突发流量请求,突发请求允许个数就是最大存储令牌数maxPermits。例如,我们设置QPS为1,在十秒钟之内没有请求,那么令牌桶中会有10个(假设设置的maxPermits为10)空闲令牌,如果下一次请求是 10个令牌,则可以一次性获取10个令牌,因为令牌桶中已经有10个空闲的令牌。 storedPermits 就是用来表示当前令牌桶中的空闲令牌数。
  • 对于令牌的产生有两种方式,一种是通过后台定时任务来不断产生令牌,一种是延迟生成,在每次获取令牌之前先计算在nextFreeTicketMillis到目前这个时间段内应该产生多少令牌,并更新令牌桶。本限流器采用的是后者。

令牌桶

/**
 * Redis令牌桶
 */
@Data
public class RedisPermits implements Serializable {

    private static final long serialVersionUID = 1L;
    /**
     * maxPermits 最大存储令牌数
     */
    private Long maxPermits;
    /**
     * storedPermits 当前存储令牌数
     */
    private Long storedPermits;
    /**
     * intervalMillis 添加令牌时间间隔
     */
    private Long intervalMillis;
    /**
     * nextFreeTicketMillis 下次请求可以获取令牌的起始时间,默认当前系统时间
     */
    private Long nextFreeTicketMillis;

    /**
     * @param permitsPerSecond 每秒放入的令牌数
     * @param maxBurstSeconds  maxPermits由此字段计算,最大存储maxBurstSeconds秒生成的令牌
     */
    public RedisPermits(Double permitsPerSecond, Integer maxBurstSeconds) {
        if (null == maxBurstSeconds) {
            maxBurstSeconds = 60;
        }
        this.maxPermits = (long) (permitsPerSecond * maxBurstSeconds);
        this.storedPermits = permitsPerSecond.longValue();
        this.intervalMillis = (long) (TimeUnit.SECONDS.toMillis(1) / permitsPerSecond);
        this.nextFreeTicketMillis = System.currentTimeMillis();
    }

    /**
     * redis的过期时长
     * @return
     */
    public Long expires() {
        long now = System.currentTimeMillis();
        return 2 * TimeUnit.MINUTES.toSeconds(1)
                + TimeUnit.MILLISECONDS.toSeconds(Math.max(nextFreeTicketMillis, now) - now);
    }

    public Map<String, String> toMap() {
        Map<String, String> resultMap = new HashMap<>();
        resultMap.put("maxPermits", maxPermits.toString());
        resultMap.put("storedPermits", storedPermits.toString());
        resultMap.put("intervalMillis", intervalMillis.toString());
        resultMap.put("nextFreeTicketMillis", nextFreeTicketMillis.toString());
        return resultMap;
    }

该类主要存储了令牌桶核心的四个参数

限流器

主要方法:

@Slf4j
@Data
public class RateLimiter {
    /**
     * 在超时时间内尝试获取{tokenCount}个令牌
     * @param tokenCount
     * @param timeout
     * @param timeUnit
     * @return
     * @throws InterruptedException
     */
    public boolean tryAcquire(Long tokenCount, Long timeout, TimeUnit timeUnit) throws InterruptedException{
        if(checkTokens(tokenCount)) {
            Long timeoutMillis = Math.max(timeUnit.toMillis(timeout), 0);
            Long millisToWait = tryAndGetWaitTime(tokenCount, timeoutMillis);
            if(millisToWait <= timeoutMillis) {
                log.info("tryAcquire for {}ms {}", millisToWait, Thread.currentThread().getName());
                Thread.sleep(millisToWait);
                return true;
            }
        }
        return false;
    }

    /**
     * 等待直到获取指定数量的令牌
     * @param tokenCount
     * @return
     * @throws InterruptedException
     */
    public Long acquire(Long tokenCount) throws InterruptedException {
        long milliToWait = this.reserve(tokenCount);
        log.info("acquire for {}ms {}", milliToWait, Thread.currentThread().getName());
        Thread.sleep(milliToWait);
        return milliToWait;
    }

    /**
     * 获取令牌n个需要等待的时间
     * @param tokenCount
     * @return
     */
    private long reserve(Long tokenCount) {
        if (checkTokens(tokenCount)) {
            return reserveAndGetWaitTime(tokenCount);
        } else {
            return -1;
        }
    }

    /**
     * 预定@{tokenCount}个令牌并返回所需要等待的时间
     * @param tokenCount
     * @return
     */
    private Long reserveAndGetWaitTime(Long tokenCount){
        putDefaultPermits();
        String script = "redis.replicate_commands() " +
                "local redisKey = KEYS[1] " +
                "local timeStrArray = redis.call('time') " +
                "local seconds = tonumber(timeStrArray[1]) " +
                "local microseconds = tonumber(timeStrArray[2]) " +
                "local nowMilliseconds = seconds * 1000 + math.modf(microseconds/1000) " +
                "local redisPermitsValues = redis.call('hmget', redisKey, 'nextFreeTicketMillis', 'maxPermits', 'storedPermits', 'intervalMillis') " +
                "local nextFreeTicketMillis = tonumber(redisPermitsValues[1]) " +
                "local maxPermits = tonumber(redisPermitsValues[2]) " +
                "local storedPermits = tonumber(redisPermitsValues[3]) " +
                "local intervalMillis = tonumber(redisPermitsValues[4]) " +
                "if(nowMilliseconds > nextFreeTicketMillis) " +
                "then " +
                "storedPermits = math.min(maxPermits, storedPermits + math.modf((nowMilliseconds - nextFreeTicketMillis) / intervalMillis)) " +
                "nextFreeTicketMillis = nowMilliseconds " +
                "end " +
                "local tokenCount = tonumber(ARGV[1]) " +
                "local storedPermitsToSpend = math.min(tokenCount, storedPermits) " +
                "local freshPermits = tokenCount - storedPermitsToSpend " +
                "local waitMillis = freshPermits * intervalMillis " +
                "nextFreeTicketMillis = nextFreeTicketMillis + waitMillis " +
                "storedPermits = storedPermits - storedPermitsToSpend " +
                "redis.call('hmset', redisKey, 'nextFreeTicketMillis', nextFreeTicketMillis, 'storedPermits', storedPermits) " +
                "redis.call('expire', redisKey, 120) " +
                "return nextFreeTicketMillis - nowMilliseconds";
        List<String> keys = Collections.singletonList(key);
        List<String> args = Collections.singletonList(tokenCount.toString());
        Object obj = redisUtil.eval(script, keys, args);
        Long result = null;
        if(obj != null) {
            result = (Long) obj;
        }
        return result;
    }

    /**
     * 判断{timeout}时间内能否获取{tokenCount}令牌,如果能获取到则预定令牌
     * @param tokenCount
     * @return 需要等待时长
     */
    private Long tryAndGetWaitTime(Long tokenCount, Long timeoutMillis) {
        putDefaultPermits();
        String script = "redis.replicate_commands() " +
                "local redisKey = KEYS[1] " +
                "local timeStrArray = redis.call('time') " +
                "local seconds = tonumber(timeStrArray[1]) " +
                "local microseconds = tonumber(timeStrArray[2]) " +
                "local nowMilliseconds = seconds * 1000 + math.modf(microseconds/1000) " +
                "local redisPermitsValues = redis.call('hmget', redisKey, 'nextFreeTicketMillis', 'maxPermits', 'storedPermits', 'intervalMillis') " +
                "local nextFreeTicketMillis = tonumber(redisPermitsValues[1]) " +
                "local maxPermits = tonumber(redisPermitsValues[2]) " +
                "local storedPermits = tonumber(redisPermitsValues[3]) " +
                "local intervalMillis = tonumber(redisPermitsValues[4]) " +
                "if(nowMilliseconds > nextFreeTicketMillis) " +
                "then " +
                "storedPermits = math.min(maxPermits, storedPermits + math.modf((nowMilliseconds - nextFreeTicketMillis) / intervalMillis)) " +
                "nextFreeTicketMillis = nowMilliseconds " +
                "end " +
                "local tokenCount = tonumber(ARGV[1]) " +
                "local timeoutMillis = tonumber(ARGV[2]) " +
                "local storedPermitsToSpend = math.min(tokenCount, storedPermits) " +
                "local freshPermits = tokenCount - storedPermitsToSpend " +
                "local waitMillis = freshPermits * intervalMillis " +
                "local actualWaitMillis = nextFreeTicketMillis + waitMillis - nowMilliseconds " +
                "if(actualWaitMillis <= timeoutMillis) " +
                "then " +
                "nextFreeTicketMillis = nextFreeTicketMillis + waitMillis " +
                "storedPermits = storedPermits - storedPermitsToSpend " +
                "redis.call('hmset', redisKey, 'nextFreeTicketMillis', nextFreeTicketMillis, 'storedPermits', storedPermits) " +
                "redis.call('expire', redisKey, 120) " +
                "end " +
                "return actualWaitMillis";
        List<String> keys = Collections.singletonList(key);
        List<String> args = Arrays.asList(tokenCount.toString(), timeoutMillis.toString());
        Object obj = redisUtil.eval(script, keys, args);
        Long result = null;
        if(obj != null) {
            result = (Long) obj;
        }
        return result;
    }
}

可以看到,限流器的主要方法是acquire和tryAcquire,前者是进行线程阻塞以等待令牌桶中达到所需令牌,后者是设定超时时间,并判断在超时时间内能否获取所需令牌,可以的话再进行线程阻塞等待令牌。获取由于存储在Redis中的令牌桶信息在集群环境下会有线程不同步问题,虽然采用Redis分布锁可以解决该问题,但是会造成线程阻塞,降低并发效率。而Redis运行lua脚本是原子性操作,因此本文采用lua脚本执行对令牌桶的计算和更新操作。可以看到核心方法reserveAndGetWaitTime和tryAndGetWaitTime方法都使用了lua脚本,下面简单讲解一下这两个方法的实现逻辑。

reserveAndGetWaitTime

  • 更新令牌桶,这一步操作就是上文讲到的延迟更新令牌
  • 计算所需令牌数与令牌桶中令牌数的插值,确定补全所需令牌数需要等待的时间
  • 取令牌并将令牌桶数据更新到Redis

tryAndGetWaitTime

  • 同样是先更新令牌桶
  • 计算所需令牌数与令牌桶中令牌数的插值,确定补全所需令牌数需要等待的时间
  • 判断等待的时间是否在超时时间内,如果是的话再取令牌将令牌桶数据更新到Redis

文章作者: 小小千千
版权声明: 本博客所有文章除特別声明外,均采用 CC BY 4.0 许可协议。转载请注明来源 小小千千 !
评论
  目录