Published on

Spring Security OAuth 2.0 授权服务器结合Redis实现获取accessToken速率限制

Authors
  • avatar
    Name
    ReLive27
    Twitter

Spring Security OAuth 2.0授权服务器结合Redis实现获取accessToken速率限制

概述

在生产环境中,我们通常颁发给OAuth2客户端有效期较长的token,但是授权服务无从知晓OAuth2客户端服务是否频繁获取token,便于我们主动控制token的颁发,减少数据库操作,本文我们将结合Redis实现滑动窗口算法限制速率解决此问题。

先决条件

  • java 8+
  • Redis
  • Lua

授权服务器

本节中我们将使用Spring Authorization Server 搭建一个简单的授权服务器,并通过扩展OAuth2TokenCustomizer实现access_token的速率限制。

Maven依赖

        <dependency>
            <groupId>org.springframework.boot</groupId>
            <artifactId>spring-boot-starter-security</artifactId>
            <version>2.6.7</version>
        </dependency>

        <dependency>
            <groupId>org.springframework.security</groupId>
            <artifactId>spring-security-oauth2-authorization-server</artifactId>
            <version>0.3.1</version>
        </dependency>

        <dependency>
            <groupId>org.springframework.boot</groupId>
            <artifactId>spring-boot-starter-web</artifactId>
            <version>2.6.7</version>
        </dependency>

        <dependency>
            <groupId>org.springframework.boot</groupId>
            <artifactId>spring-boot-starter-data-redis</artifactId>
            <version>2.6.7</version>
        </dependency>

配置

首先添加spring.redis配置连接本地Redis服务:

server:
  port: 8080

spring:
  redis:
    host: localhost
    database: 0
    port: 6379
    password: 123456
    timeout: 1800
    lettuce:
      pool:
        max-active: 20
        max-wait: 60
        max-idle: 5
        min-idle: 0
      shutdown-timeout: 100

接下来我们需要注册一个OAuth2客户端,声明客户端如下:

    @Bean
    public RegisteredClientRepository registeredClientRepository() {
        RegisteredClient registeredClient = RegisteredClient.withId(UUID.randomUUID().toString())
                .clientId("relive-client")
                .clientSecret("{noop}relive-client")
                .clientAuthenticationMethods(s -> {
                    s.add(ClientAuthenticationMethod.CLIENT_SECRET_POST);
                    s.add(ClientAuthenticationMethod.CLIENT_SECRET_BASIC);
                })
                .authorizationGrantType(AuthorizationGrantType.CLIENT_CREDENTIALS)
                .redirectUri("http://127.0.0.1:8070/login/oauth2/code/messaging-client-model")
                .scope("message.read")
                .clientSettings(ClientSettings.builder()
                        .requireAuthorizationConsent(false)
                        .requireProofKey(false)
                        .build())
                .tokenSettings(TokenSettings.builder()
                        .accessTokenFormat(OAuth2TokenFormat.SELF_CONTAINED)
                        .idTokenSignatureAlgorithm(SignatureAlgorithm.RS256)
                        .accessTokenTimeToLive(Duration.ofSeconds(30 * 60))
                        .refreshTokenTimeToLive(Duration.ofSeconds(60 * 60))
                        .reuseRefreshTokens(true)
                        .setting("accessTokenLimitTimeSeconds", 5 * 60)
                        .setting("accessTokenLimitRate", 3)
                        .build())
                .build();

        return new InMemoryRegisteredClientRepository(registeredClient);
    }

上述OAuth2客户端信息如下:

特别注意:我们额外添加了两个参数用于控制AccessToken的速率限制,accessTokenLimitTimeSeconds访问限制时间, accessTokenLimitRate访问限制次数。
此外,我们为单个客户端添加限制参数,由此可以针对不同OAuth2客户端设置不同的速率限制或者取消。

使用Spring Authorization Server提供的授权服务默认配置,并将未认证的授权请求重定向到登录页面:

    @Bean
    @Order(Ordered.HIGHEST_PRECEDENCE)
    public SecurityFilterChain authorizationServerSecurityFilterChain(HttpSecurity http) throws Exception {
        OAuth2AuthorizationServerConfiguration.applyDefaultSecurity(http);
        return http.exceptionHandling(exceptions -> exceptions.
                authenticationEntryPoint(new LoginUrlAuthenticationEntryPoint("/login"))).build();
    }

其余常规配置本文将不再赘述,您可以参考以往文章或从文末链接中获取源码。


接下来我们将利用Redis sorted set数据结构实现滑动窗口算法用于access_token速率限制,我们将利用Lua脚本保证Redis操作的原子性,节省网络开销。

redis.replicate_commands()

local key = KEYS[1]

local windowSize = tonumber(ARGV[1])
local rate = tonumber(ARGV[2])
local now = tonumber(redis.call("TIME")[1])

redis.call("zadd", key, now, now)
local start = math.max(0, now - windowSize)

local requestRate = tonumber(redis.call("zcount", key, start, now))

local result = true
if requestRate > rate then
  result = false
end

redis.call("zremrangebyscore", key, "-inf", "("..start)

return result

上述Lua脚本遵循以下步骤:

  • 将当前时间(秒)作为value和score 添加进有序集合(sorted set)中
  • 计算窗口长度,统计窗口中成员总数,该总数表示该窗口长度中已请求次数
  • 判断请求次数是否超过阈值
  • 移除已失效成员

RedisAccessTokenLimiterTokenSettings获取参数accessTokenLimitTimeSeconds,accessTokenLimitRate,由RedisTemplate执行Lua脚本, 并传递参数信息。

@Slf4j
public class RedisAccessTokenLimiter implements AccessTokenLimiter {
    private static final String ACCESS_TOKEN_LIMIT_TIME_SECONDS = "accessTokenLimitTimeSeconds";
    private static final String ACCESS_TOKEN_LIMIT_RATE = "accessTokenLimitRate";
    private final RedisTemplate<String, Object> redisTemplate;
    private final RedisScript<Boolean> script;

    public RedisAccessTokenLimiter(RedisTemplate<String, Object> redisTemplate, RedisScript<Boolean> script) {
        Assert.notNull(redisTemplate, "redisTemplate can not be null");
        Assert.notNull(script, "script can not be null");
        this.redisTemplate = redisTemplate;
        this.script = script;
    }


    @Override
    public boolean isAllowed(RegisteredClient registeredClient) {

        TokenSettings tokenSettings = registeredClient.getTokenSettings();
        if (tokenSettings == null || tokenSettings.getSetting(ACCESS_TOKEN_LIMIT_TIME_SECONDS) == null ||
                tokenSettings.getSetting(ACCESS_TOKEN_LIMIT_RATE) == null) {
            return true;
        }
        int accessTokenLimitTimeSeconds = tokenSettings.getSetting(ACCESS_TOKEN_LIMIT_TIME_SECONDS);

        int accessTokenLimitRate = tokenSettings.getSetting(ACCESS_TOKEN_LIMIT_RATE);

        String clientId = registeredClient.getClientId();

        try {
            List<String> keys = getKeys(clientId);

            return redisTemplate.execute(this.script, keys, accessTokenLimitTimeSeconds, accessTokenLimitRate);
        } catch (Exception e) {
            /*
             * 我们不希望硬依赖 Redis 来允许访问。 确保设置
             * 一个警报,知道发生了许多次。
             */
            log.error("Error determining if user allowed from redis", e);
        }
        return true;
    }

    static List<String> getKeys(String id) {
        // 在key周围使用 `{}` 以使用 Redis Key hash tag
        // 这允许使用 redis 集群
        String prefix = "access_token_rate_limiter.{" + id;

        String key = prefix + "}.client";
        return Arrays.asList(key);
    }

}

已知OAuth2TokenCustomizer提供了自定义OAuth2Token的属性的能力,但是在本示例中我们将使用OAuth2TokenCustomizer作为扩展点, 使用AccessTokenLimiter提供了速率限制,当请求超过阈值时,将抛出OAuth2AuthenticationException异常。


public class AccessTokenRestrictionCustomizer implements OAuth2TokenCustomizer<JwtEncodingContext> {
    private static final String DEFAULT_ERROR_URI = "https://datatracker.ietf.org/doc/html/rfc6749#section-4.1.2.1";
    private final AccessTokenLimiter tokenLimiter;

    public AccessTokenRestrictionCustomizer(AccessTokenLimiter tokenLimiter) {
        Assert.notNull(tokenLimiter, "accessTokenLimiter can not be null");
        this.tokenLimiter = tokenLimiter;
    }

    /**
     * 通过{@link AccessTokenLimiter} 为OAuth2 客户端模式访问令牌添加访问限制
     *
     * @param context
     */
    @Override
    public void customize(JwtEncodingContext context) {
        if (AuthorizationGrantType.CLIENT_CREDENTIALS.equals(context.getAuthorizationGrantType())) {
            RegisteredClient registeredClient = context.getRegisteredClient();
            if (registeredClient == null) {
                OAuth2Error error = new OAuth2Error(OAuth2ErrorCodes.INVALID_REQUEST, "OAuth 2.0 Parameter: " + OAuth2ParameterNames.CLIENT_ID, DEFAULT_ERROR_URI);
                throw new OAuth2AuthenticationException(error);
            }


            boolean requiresGenerateToken = this.tokenLimiter.isAllowed(registeredClient);
            if (!requiresGenerateToken) {
                OAuth2Error error = new OAuth2Error(OAuth2ErrorCodes.ACCESS_DENIED,
                        "The token generation fails, and the same client is prohibited from repeatedly obtaining the token within a short period of time.", null);
                throw new OAuth2AuthenticationException(error);
            }
        }

    }
}


注意:上述示例中我们使用OAuth 2.0 客户端模式。

测试

本示例中我们限制access_token请求5分钟响应3次,我们将使用以下单元测试简单测试。

    @Test
    public void authorizationWhenObtainingTheAccessTokenSucceeds() throws Exception {
        MultiValueMap<String, String> parameters = new LinkedMultiValueMap<>();
        parameters.set(OAuth2ParameterNames.GRANT_TYPE, AuthorizationGrantType.CLIENT_CREDENTIALS.getValue());
        parameters.set(OAuth2ParameterNames.CLIENT_ID, "relive-client");
        parameters.set(OAuth2ParameterNames.CLIENT_SECRET, "relive-client");
        this.mockMvc.perform(post("/oauth2/token")
                .params(parameters))
                .andExpect(status().is2xxSuccessful());


    }

    @Test
    public void authorizationWhenTokenAccessRestrictionIsTriggeredThrowOAuth2AuthenticationException() throws Exception {
        MultiValueMap<String, String> parameters = new LinkedMultiValueMap<>();
        parameters.set(OAuth2ParameterNames.GRANT_TYPE, AuthorizationGrantType.CLIENT_CREDENTIALS.getValue());
        parameters.set(OAuth2ParameterNames.CLIENT_ID, "relive-client");
        parameters.set(OAuth2ParameterNames.CLIENT_SECRET, "relive-client");
        this.mockMvc.perform(post("/oauth2/token")
                .params(parameters))
                .andExpect(status().isBadRequest())
                .andExpect(result -> assertEquals("{\"error_description\":\"The token generation fails, and the same client is prohibited from repeatedly obtaining the token within a short period of time.\",\"error\":\"access_denied\"}", result.getResponse().getContentAsString()));
    }

结论

可能有人会有疑问,一般服务都会由网关限流,为什么使用本示例中方式。当然,从实现上并不妨碍我们在网关中进行限制,这只是一个选择问题。后续文章中我将会介绍如何通过Spring Cloud Gateway结合授权服务 对OAuth2客户端进行速率限制。

与往常一样,本文中使用的源代码可在 GitHub 上获得。