Published on

Spring Security OAuth2实现简单的密钥轮换及配置资源服务器JWK缓存

Authors
  • avatar
    Name
    ReLive27
    Twitter

Spring Security OAuth2实现简单的密钥轮换及配置资源服务器JWK缓存

概述

在OAuth2协议中授权服务器或者OIDC中身份提供服务常使用私钥对JWT令牌进行签名,第三方服务客户端或者资源服务使用已知URL上发布的公钥对令牌进行验证。

这些密钥构成了各方之间安全的基础。为了维护安全性,保持私钥免受任何网络攻击是所必需的。

已确定的用于保护密钥不被泄露的最佳做法之一称为密钥滚动更新密钥轮换。在此方法中,我们丢弃当前密钥并生成一对新密钥,用于对令牌进行签名和验证。

为什么我们需要密钥轮换?

为了确保公钥和私钥对的安全性免受黑客的攻击,建议在一段时间后轮换密钥。必须丢弃以前的密钥,并且必须将新生成的密钥用于进一步的加密操作。根据NIST指南,密钥必须至少每两年轮换一次

如何实现密钥轮换?

所有公钥都由授权服务或身份提供服务在 Web 上发布的URL提供,URL返回一个对象,称为 JSON Web Keys或 JWKS,其中包含多个JSON Web Key (是一种 JSON 数据结构,表示一组公钥),通常称为 JWK。在验证由私钥签名的 JWT时,将使用与私钥相对应的JWK。以下是JWKS 示例,其keys包含一个 JWK 数组。

{
"keys": [
  {
    "alg": "RS256",
    "kty": "RSA",
    "use": "sig",
    "x5c": [
      "MIIC+DCCAeCgAwIBAgIJBIGjYW6hFpn2MA0GCSqGSIb3DQEBBQUAMCMxITAfBgNVBAMTGGN1c3RvbWVyLWRlbW9zLmF1dGgwLmNvbTAeFw0xNjExMjIyMjIyMDVaFw0zMDA4MDEyMjIyMDVaMCMxITAfBgNVBAMTGGN1c3RvbWVyLWRlbW9zLmF1dGgwLmNvbTCCASIwDQYJKoZIhvcNAQEBBQADggEPADCCAQoCggEBAMnjZc5bm/eGIHq09N9HKHahM7Y31P0ul+A2wwP4lSpIwFrWHzxw88/7Dwk9QMc+orGXX95R6av4GF+Es/nG3uK45ooMVMa/hYCh0Mtx3gnSuoTavQEkLzCvSwTqVwzZ+5noukWVqJuMKNwjL77GNcPLY7Xy2/skMCT5bR8UoWaufooQvYq6SyPcRAU4BtdquZRiBT4U5f+4pwNTxSvey7ki50yc1tG49Per/0zA4O6Tlpv8x7Red6m1bCNHt7+Z5nSl3RX/QYyAEUX1a28VcYmR41Osy+o2OUCXYdUAphDaHo4/8rbKTJhlu8jEcc1KoMXAKjgaVZtG/v5ltx6AXY0CAwEAAaMvMC0wDAYDVR0TBAUwAwEB/zAdBgNVHQ4EFgQUQxFG602h1cG+pnyvJoy9pGJJoCswDQYJKoZIhvcNAQEFBQADggEBAGvtCbzGNBUJPLICth3mLsX0Z4z8T8iu4tyoiuAshP/Ry/ZBnFnXmhD8vwgMZ2lTgUWwlrvlgN+fAtYKnwFO2G3BOCFw96Nm8So9sjTda9CCZ3dhoH57F/hVMBB0K6xhklAc0b5ZxUpCIN92v/w+xZoz1XQBHe8ZbRHaP1HpRM4M7DJk2G5cgUCyu3UBvYS41sHvzrxQ3z7vIePRA4WF4bEkfX12gvny0RsPkrbVMXX1Rj9t6V7QXrbPYBAO+43JvDGYawxYVvLhz+BJ45x50GFQmHszfY3BR9TPK8xmMmQwtIvLu1PMttNCs7niCYkSiUv2sc2mlq1i3IashGkkgmo="
    ],
    "n": "yeNlzlub94YgerT030codqEztjfU_S6X4DbDA_iVKkjAWtYfPHDzz_sPCT1Axz6isZdf3lHpq_gYX4Sz-cbe4rjmigxUxr-FgKHQy3HeCdK6hNq9ASQvMK9LBOpXDNn7mei6RZWom4wo3CMvvsY1w8tjtfLb-yQwJPltHxShZq5-ihC9irpLI9xEBTgG12q5lGIFPhTl_7inA1PFK97LuSLnTJzW0bj096v_TMDg7pOWm_zHtF53qbVsI0e3v5nmdKXdFf9BjIARRfVrbxVxiZHjU6zL6jY5QJdh1QCmENoejj_ytspMmGW7yMRxzUqgxcAqOBpVm0b-_mW3HoBdjQ",
    "e": "AQAB",
    "kid": "NjVBRjY5MDlCMUIwNzU4RTA2QzZFMDQ4QzQ2MDAyQjVDNjk1RTM2Qg",
    "x5t": "NjVBRjY5MDlCMUIwNzU4RTA2QzZFMDQ4QzQ2MDAyQjVDNjk1RTM2Qg"
  }
]}

典型的密钥轮换策略可避免客户端发送使用以前颁发的密钥签名的 JWT 的验证失败潜在问题,因此,在令牌完全过期之前,我们需要在一段时间内保持两个密钥(先前和当前)有效 - 刚好足以为客户端提供更新其本地缓存的空间。

先决条件

  • java8+
  • Redis
  • JWT

在阅读文章前,首先说明下文密钥JWK表示相同含义。虽然JWK从规范定义层面表示一组公钥,但是在代码层面JWK所指定的是一组密钥。例如RSAKey,ECKey等等。

授权服务器实现密钥轮换

本节中我们将使用Spring Authorization Server 搭建一个简单的授权服务器,并实现JWKSource自定义密钥轮换逻辑,密钥缓存策略提供本地内存,caffeine,redis三种实现方式。

Maven依赖

<dependency>
  <groupId>org.springframework.boot</groupId>
  <artifactId>spring-boot-starter-security</artifactId>
  <version>2.6.7</version>
</dependency>

<dependency>
  <groupId>org.springframework.security</groupId>
  <artifactId>spring-security-oauth2-authorization-server</artifactId>
  <version>0.3.1</version>
</dependency>

<dependency>
  <groupId>org.springframework.boot</groupId>
  <artifactId>spring-boot-starter-web</artifactId>
  <version>2.6.7</version>
</dependency>

<dependency>
  <groupId>org.springframework.boot</groupId>
  <artifactId>spring-boot-starter-cache</artifactId>
  <version>2.6.7</version>
</dependency>

<dependency>
  <groupId>org.springframework.boot</groupId>
  <artifactId>spring-boot-starter-data-redis</artifactId>
  <version>2.6.7</version>
</dependency>

<dependency>
  <groupId>com.github.ben-manes.caffeine</groupId>
  <artifactId>caffeine</artifactId>
  <version>2.9.3</version>
</dependency>

配置

首先我们从application.yml配置开始,这里我们指定授权服务器端口为8080,并添加redis连接配置信息:

server:
  port: 8080

spring:
  redis:
    host: localhost
    database: 0
    port: 6379
    password: 123456
    timeout: 1800
    lettuce:
      pool:
        max-active: 20
        max-wait: 60
        max-idle: 5
        min-idle: 0
      shutdown-timeout: 100


接下来我们将创建AuthorizationServerConfig配置类,用于配置OAuth2及OIDC所需Bean,首先我们将新增OAuth2客户端信息:

    @Bean
    public RegisteredClientRepository registeredClientRepository() {
        RegisteredClient registeredClient = RegisteredClient.withId(UUID.randomUUID().toString())
                .clientId("relive-client")
                .clientSecret("{noop}relive-client")
                .clientAuthenticationMethods(s -> {
                    s.add(ClientAuthenticationMethod.CLIENT_SECRET_POST);
                    s.add(ClientAuthenticationMethod.CLIENT_SECRET_BASIC);
                })
                .authorizationGrantType(AuthorizationGrantType.AUTHORIZATION_CODE)
                .authorizationGrantType(AuthorizationGrantType.REFRESH_TOKEN)
                .authorizationGrantType(AuthorizationGrantType.CLIENT_CREDENTIALS)
                .authorizationGrantType(AuthorizationGrantType.PASSWORD)
                .redirectUri("http://127.0.0.1:8070/login/oauth2/code/messaging-client-authorization-code")
                .scope("message.read")
                .clientSettings(ClientSettings.builder()
                        .requireAuthorizationConsent(true)
                        .requireProofKey(false)
                        .build())
                .tokenSettings(TokenSettings.builder()
                        .accessTokenFormat(OAuth2TokenFormat.SELF_CONTAINED)
                        .idTokenSignatureAlgorithm(SignatureAlgorithm.RS256)
                        .accessTokenTimeToLive(Duration.ofSeconds(30 * 60))
                        .refreshTokenTimeToLive(Duration.ofSeconds(60 * 60))
                        .reuseRefreshTokens(true)
                        .build())
                .build();


        return new InMemoryRegisteredClientRepository(registeredClient);
    }

和以往文章一样,指定OAuth2客户端信息,并将OAuth2客户端信息存储在内存中,如果你需要配置数据库存储,请参考文章将JWT与Spring Security OAuth2结合使用


简化其他高级配置,使用OAuth2授权服务默认配置,并将未认证的授权请求重定向到登录页面:

    @Bean
    @Order(Ordered.HIGHEST_PRECEDENCE)
    public SecurityFilterChain authorizationServerSecurityFilterChain(HttpSecurity http) throws Exception {
        OAuth2AuthorizationServerConfiguration.applyDefaultSecurity(http);
        return http.exceptionHandling(exceptions -> exceptions.
                        authenticationEntryPoint(new LoginUrlAuthenticationEntryPoint("/login"))).build();
    }

自定义JWKSource实现密钥轮换

之前文章中授权服务启动时随机生成一个2048字节的RSA密钥,用于令牌的签名密钥。本示例中我们将自定义JWKSource并实现密钥轮换策略:

public final class RotateJwkSource<C extends SecurityContext> implements JWKSource<C> {
    private final JWKSource<C> failoverJWKSource;
    private final JWKSetCache jwkSetCache;
    private final JWKGenerator<? extends JWK> jwkGenerator;
    private KeyIDStrategy keyIDStrategy;

    public RotateJwkSource() {
        this(new InMemoryJWKSetCache(), null, null, null);
    }

    public RotateJwkSource(JWKSetCache jwkSetCache) {
        this(jwkSetCache, null, null, null);
    }

    public RotateJwkSource(JWKSetCache jwkSetCache, JWKSource<C> failoverJWKSource) {
        this(jwkSetCache, failoverJWKSource, null, null);
    }

    public RotateJwkSource(JWKSetCache jwkSetCache, JWKGenerator<? extends JWK> jwkGenerator) {
        this(jwkSetCache, null, jwkGenerator, null);
    }

    public RotateJwkSource(JWKSetCache jwkSetCache, JWKSource<C> failoverJWKSource, JWKGenerator<? extends JWK> jwkGenerator, KeyIDStrategy keyIDStrategy) {
        Assert.notNull(jwkSetCache, "jwkSetCache cannot be null");
        this.jwkSetCache = jwkSetCache;
        this.failoverJWKSource = failoverJWKSource;
        if (jwkGenerator == null) {
            this.jwkGenerator = new RSAKeyGenerator(RSAKeyGenerator.MIN_KEY_SIZE_BITS);
        } else {
            this.jwkGenerator = jwkGenerator;
        }
        if (keyIDStrategy == null) {
            this.keyIDStrategy = new TimestampKeyIDStrategy();
        } else {
            this.keyIDStrategy = keyIDStrategy;
        }

    }

    @Override
    public List<JWK> get(JWKSelector jwkSelector, C context) throws RotateKeySourceException {
        JWKSet jwkSet = this.jwkSetCache.get();
        if (this.jwkSetCache.requiresRefresh() || jwkSet == null) {
            try {
                synchronized (this) {
                    jwkSet = this.jwkSetCache.get();
                    if (this.jwkSetCache.requiresRefresh() || jwkSet == null) {
                        jwkSet = this.updateJWKSet(jwkSet);
                    }
                }
            } catch (Exception e) {
                List<JWK> failoverMatches = this.failover(e, jwkSelector, context);
                if (failoverMatches != null) {
                    return failoverMatches;
                }

                if (jwkSet == null) {
                    throw e;
                }
            }
        }
        List<JWK> jwks = jwkSelector.select(jwkSet);
        if (!jwks.isEmpty()) {
            return jwks;
        } else {
            return Collections.emptyList();
        }
    }

    private JWKSet updateJWKSet(JWKSet jwkSet) throws RotateKeySourceException {
        JWK jwk;
        try {
            jwkGenerator.keyID(keyIDStrategy.generateKeyID());
            jwk = jwkGenerator.generate();
        } catch (JOSEException e) {
            throw new RotateKeySourceException("Couldn't generate JWK:" + e.getMessage(), e);
        }
        JWKSet updateJWKSet = new JWKSet(jwk);
        this.jwkSetCache.put(updateJWKSet);
        if (jwkSet != null) {
            List<JWK> keys = jwkSet.getKeys();
            List<JWK> updateJwks = new ArrayList<>(keys);
            updateJwks.add(jwk);
            updateJWKSet = new JWKSet(updateJwks);
        }
        return updateJWKSet;
    }

    private List<JWK> failover(Exception exception, JWKSelector jwkSelector, C context) throws RotateKeySourceException {
        if (this.getFailoverJWKSource() == null) {
            return null;
        } else {
            try {
                return this.getFailoverJWKSource().get(jwkSelector, context);
            } catch (KeySourceException e) {
                throw new RotateKeySourceException(exception.getMessage() + "; Failover JWK source retrieval failed with: " + e.getMessage(), e);
            }
        }
    }

    public void setKeyIDStrategy(KeyIDStrategy keyIDStrategy) {
        this.keyIDStrategy = keyIDStrategy;
    }
}

RotateJwkSource为包含密钥轮换的JWKSource的实现类,遵循以下步骤:

  • 首先从JWKSetCache缓存中获取JWKSet(JWKSet仅包含未过期JWK)。本示例中自定义JWKSetCache实现类有InMemoryJWKSetCacheCaffeineJWKSetCacheRedisJWKSetCache

  • 如果JWKSet不为空或不需要刷新密钥,则通过JWKSelector从指定的 JWK 集中选择与配置的条件匹配的JWK。

  • 否则,执行updateJWKSet(JWKSet jwkSet)生成新的密钥对添加进缓存,并返回新的JWKSet(JWKSet仅包含未过期JWK)。

JWKSetCache定义密钥刷新周期及密钥过期时间。

RotateJwkSource属性介绍:

  • failoverJWKSource:故障转移 JWKSource。
  • jwkSetCache:JWKSet缓存接口类,定义密钥刷新周期,密钥过期时间。本示例中提供三种实现类,InMemoryJWKSetCacheCaffeineJWKSetCacheRedisJWKSetCache
  • jwkGenerator:密钥生成器,RotateJwkSource默认使用RSAKeyGenerator
  • KeyIDStrategykid生成策略,本示例中使用时间戳表示kid

基于本地内存,caffeine,redis的JWKSetCache

本示例用于测试需要,密钥刷新周期定为5分钟,密钥过期时间定为15分钟,实际应用中请根据需要修改。

InMemoryJWKSetCache实现方式相对简单。由JWKWithTimestamp存储密钥对,lifespan为密钥过期时间,refreshTime为密钥刷新周期。为确保密钥轮换正常使用,建议 lifespan >= refreshTime + exp

public class InMemoryJWKSetCache implements JWKSetCache {
    private final long lifespan;
    private final long refreshTime;
    private final TimeUnit timeUnit;
    private volatile Set<JWKWithTimestamp> jwkWithTimestamps;


    public InMemoryJWKSetCache() {
        this(15L, 5L, TimeUnit.MINUTES);
    }

    public InMemoryJWKSetCache(long lifespan, long refreshTime, TimeUnit timeUnit) {
        this.lifespan = lifespan;
        this.refreshTime = refreshTime;
        if ((lifespan > -1L || refreshTime > -1L) && timeUnit == null) {
            throw new IllegalArgumentException("A time unit must be specified for non-negative lifespans or refresh times");
        } else {
            this.timeUnit = timeUnit;
        }
        this.jwkWithTimestamps = new LinkedHashSet<>();
    }

    @Override
    public void put(JWKSet jwkSet) {
        if (jwkSet != null) {
            if (!CollectionUtils.isEmpty(jwkSet.getKeys())) {
                List<JWKWithTimestamp> updateJWKWithTs = jwkSet.getKeys().stream().map(JWKWithTimestamp::new)
                        .collect(Collectors.toList());
                this.jwkWithTimestamps.addAll(updateJWKWithTs);
            }
        }
    }

    @Override
    public JWKSet get() {
        return !CollectionUtils.isEmpty(this.jwkWithTimestamps) && !this.isExpired() ? new JWKSet(this.jwkWithTimestamps.stream()
                .filter(t -> t.getDate().getTime() + TimeUnit.MILLISECONDS.convert(this.lifespan, this.timeUnit) > (new Date()).getTime())
                .map(JWKWithTimestamp::getJwk).collect(Collectors.toList())) : null;
    }

    @Override
    public boolean requiresRefresh() {
        return !CollectionUtils.isEmpty(this.jwkWithTimestamps) && this.refreshTime > -1L && this.jwkWithTimestamps.stream().map(jwkWithTimestamp -> jwkWithTimestamp.getDate().getTime())
                .max(Long::compareTo)
                .filter(time -> (new Date()).getTime() > time + TimeUnit.MILLISECONDS.convert(this.refreshTime, this.timeUnit))
                .isPresent();
    }

    public boolean isExpired() {
        return !CollectionUtils.isEmpty(this.jwkWithTimestamps) && this.lifespan > -1L && this.jwkWithTimestamps.stream().map(jwkWithTimestamp -> jwkWithTimestamp.getDate().getTime())
                .max(Long::compareTo)
                .filter(time -> (new Date()).getTime() > time + TimeUnit.MILLISECONDS.convert(this.lifespan, this.timeUnit))
                .isPresent();
    }

    public long getLifespan(TimeUnit timeUnit) {
        return this.lifespan < 0L ? this.lifespan : timeUnit.convert(this.lifespan, this.timeUnit);
    }

    public long getRefreshTime(TimeUnit timeUnit) {
        return this.refreshTime < 0L ? this.refreshTime : timeUnit.convert(this.refreshTime, this.timeUnit);
    }
}

InMemoryJWKSetCache中put方法将密钥对及当前时间封装为JWKWithTimestamp并添加到LinkedHashSet。get方法从LinkedHashSet过滤获取未过期JWK并返回。

此方式中使用ScheduledFuture开启单独任务清除过期密钥。


Caffeine — 一个用于 Java 的高性能缓存库CaffeineJWKSetCache基于Caffeine实现密钥存储。lifespan为密钥过期时间,refreshTime为密钥刷新周期。建议 lifespan >= refreshTime + exp

Caffeine三种缓存填充策略:手动、同步加载和异步加载。其中我们选用手动填充将密钥放入缓存中,并在get()中检索它们。

public class CaffeineJWKSetCache implements JWKSetCache {
    private final long lifespan;
    private final long refreshTime;
    private final TimeUnit timeUnit;
    private final Cache<Long, JWK> cache;

    public CaffeineJWKSetCache() {
        this(15L, 5L, TimeUnit.MINUTES);
    }

    public CaffeineJWKSetCache(long lifespan, long refreshTime, TimeUnit timeUnit) {
        this.lifespan = lifespan;
        this.refreshTime = refreshTime;
        if ((lifespan > -1L || refreshTime > -1L) && timeUnit == null) {
            throw new IllegalArgumentException("A time unit must be specified for non-negative lifespans or refresh times");
        } else {
            this.timeUnit = timeUnit;
        }
        Caffeine<Object, Object> caffeine = Caffeine.newBuilder().maximumSize(10);
        if (lifespan > -1L) {
            caffeine.expireAfterWrite(this.lifespan, this.timeUnit);
        }
        this.cache = caffeine.build();

    }

    @Override
    public void put(JWKSet jwkSet) {
        if (jwkSet != null) {
            if (!CollectionUtils.isEmpty(jwkSet.getKeys())) {
                jwkSet.getKeys().forEach(jwk -> cache.put(new Date().getTime(), jwk));
            }
        }
    }

    @Override
    public JWKSet get() {
        List<@NonNull JWK> jwks = new ArrayList<>(cache.asMap().values());
        return CollectionUtils.isEmpty(jwks) ? null : new JWKSet(jwks);
    }

    @Override
    public boolean requiresRefresh() {
        return this.refreshTime > -1L && cache.asMap().keySet().stream()
                .max(Long::compareTo)
                .filter(time -> (new Date()).getTime() > time + TimeUnit.MILLISECONDS.convert(this.refreshTime, this.timeUnit))
                .isPresent();
    }
}

Redis — 流行的内存数据结构存储。RedisJWKSetCache使用Redis有序集合(sorted set)存储密钥,scope为密钥放进缓存的时间。

Redis有序集合不支持对单个元素设置过期时间,所以我们将通过使用scope存储密钥缓存时间,并在每次更新缓存时计算已过期密钥,使用zRemRangeByScore命令移除已过期密钥。建议 lifespan >= refreshTime + exp

public class RedisJWKSetCache implements JWKSetCache {
    private static final boolean springDataRedis_2_0 = ClassUtils.isPresent("org.springframework.data.redis.connection.RedisStandaloneConfiguration", RedisJWKSetCache.class.getClassLoader());
    private final RedisConnectionFactory connectionFactory;
    private final String JWK_KEY = "jwks";
    private String prefix = "";
    private RedisSerializer<String> redisSerializeKey = new StringRedisSerializer();
    private RedisSerializer<String> redisSerializerValue = new Jackson2JsonRedisSerializer<>(String.class);
    private Method redisConnectionSet_2_0;
    private final long lifespan;
    private final long refreshTime;
    private final TimeUnit timeUnit;

    public RedisJWKSetCache(RedisConnectionFactory connectionFactory) {
        this(15L, 5L, TimeUnit.MINUTES, connectionFactory);
    }

    public RedisJWKSetCache(long lifespan, long refreshTime, TimeUnit timeUnit, RedisConnectionFactory connectionFactory) {
        this.lifespan = lifespan;
        this.refreshTime = refreshTime;
        if ((lifespan > -1L || refreshTime > -1L) && timeUnit == null) {
            throw new IllegalArgumentException("A time unit must be specified for non-negative lifespans or refresh times");
        } else {
            this.timeUnit = timeUnit;
        }
        Assert.notNull(connectionFactory, "redisConnectionFactory cannot be null");
        this.connectionFactory = connectionFactory;
        if (springDataRedis_2_0) {
            this.loadRedisConnectionMethods_2_0();
        }

    }


    @Override
    public void put(JWKSet jwkSet) {
        if (jwkSet != null) {
            if (!CollectionUtils.isEmpty(jwkSet.getKeys())) {
                RedisConnection connection = this.getConnection();
                byte[] key = this.serializeKey(JWK_KEY);

                connection.openPipeline();

                if (this.lifespan > -1) {
                    long max = new Date().getTime() - TimeUnit.MILLISECONDS.convert(this.lifespan, this.timeUnit);
                    connection.zRemRangeByScore(key, Range.range().lte(max));
                }

                List<JWK> keys = jwkSet.getKeys();
                try {
                    for (JWK jwk : keys) {
                        byte[] value = this.serialize(jwk.toJSONString());

                        if (springDataRedis_2_0) {
                            try {
                                this.redisConnectionSet_2_0.invoke(connection, key, new Date().getTime(), value);
                            } catch (Exception e) {
                                throw new RuntimeException(e);
                            }
                        } else {
                            connection.zAdd(key, new Date().getTime(), value);
                        }
                    }
                    connection.closePipeline();
                } finally {
                    connection.close();
                }
            }
        }
    }

    @Override
    public JWKSet get() {
        RedisConnection connection = this.getConnection();
        byte[] key = this.serializeKey(JWK_KEY);
        try {
            Long efficientCount = Optional.ofNullable(connection.zCard(key)).orElse(0L);
            if (efficientCount > 0) {
                Set<byte[]> jwkBytes = connection.zRevRangeByScore(key, Range.range());
                List<JWK> jwks = jwkBytes.stream().map(this::deserialize).map(this::parse).collect(Collectors.toList());
                return new JWKSet(jwks);
            }

            return null;
        } finally {
            connection.close();
        }
    }

    private JWK parse(String jwkJsonString) {
        try {
            return JWK.parse(jwkJsonString);
        } catch (ParseException e) {
            throw new RuntimeException(e);
        }
    }

    @Override
    public boolean requiresRefresh() {
        RedisConnection connection = this.getConnection();
        byte[] key = this.serializeKey("jwks");
        try {
            Long efficientCount = Optional.ofNullable(connection.zCard(key)).orElse(0L);
            Set<Tuple> maximumScoreTuple = connection.zRevRangeByScoreWithScores(key, Range.range(), Limit.limit().count(1));

            long lastRefreshTime = 0L;
            if (!CollectionUtils.isEmpty(maximumScoreTuple)) {
                lastRefreshTime = maximumScoreTuple.stream().findFirst().orElse(new DefaultTuple(null, 0.0)).getScore().longValue();
            }
            return efficientCount > 0 && this.refreshTime > -1L && (new Date()).getTime() > lastRefreshTime + TimeUnit.MILLISECONDS.convert(this.refreshTime, this.timeUnit);
        } finally {
            connection.close();
        }

    }

    private byte[] serializeKey(String key) {
        return this.redisSerializeKey.serialize(this.prefix + key);
    }

    private byte[] serialize(String value) {
        return this.redisSerializerValue.serialize(value);
    }

    private String deserialize(byte[] bytes) {
        return this.redisSerializerValue.deserialize(bytes);
    }

    private void loadRedisConnectionMethods_2_0() {
        this.redisConnectionSet_2_0 = ReflectionUtils.findMethod(RedisConnection.class, "zAdd", new Class[]{byte[].class, double.class, byte[].class});
    }

    private RedisConnection getConnection() {
        return this.connectionFactory.getConnection();
    }

    public void setPrefix(String prefix) {
        this.prefix = prefix;
    }

    public void setRedisSerializerKey(RedisSerializer<String> redisSerializer) {
        this.redisSerializeKey = redisSerializer;
    }

    public void setRedisSerializerValue(RedisSerializer<String> redisSerializer) {
        this.redisSerializerValue = redisSerializer;
    }
}


介绍完本示例中密钥轮换实现逻辑,接下来让我们配置RotateJwkSource应用于授权服务:

    @Bean
    public JWKSource<SecurityContext> jwkSource(RedisConnectionFactory connectionFactory) {
        RedisJWKSetCache redisJWKSetCache = new RedisJWKSetCache(connectionFactory);
        redisJWKSetCache.setPrefix("auth-server");

        return new RotateJwkSource<>(redisJWKSetCache);
    }

是否记得前面提到的避免客户端发送使用以前颁发的密钥签名的 JWT 的验证失败潜在问题,在令牌完全过期之前,我们需要在一段时间内保持两个密钥。所以授权服务在签发JWT令牌时,由于某一段时间存在多个密钥,因此在JwtEncoder生成JWT时将提示以下错误信息:

org.springframework.security.oauth2.jwt.JwtEncodingException: An error occurred while attempting to encode the Jwt: Found multiple JWK signing keys for algorithm 'RS256'

所以我们需要生成JWT前指定kid属性, JWKSelector将从指定的 JWKS 中选择与kid相对应的JWK用于生成JWT。Spring Authorization Server中OAuth2TokenCustomizer提供了自定义属性的能力,根据密钥轮换策略,我们需要使用最新密钥生成JWT,RotateJwkSource中kid生成策略由时间戳定义,所以JWKS中最新的密钥将会是最大值kid对应的密钥,我们将获取最大值kid放入JWT的Header中。

    @Bean
    public OAuth2TokenCustomizer<JwtEncodingContext> tokenCustomizer(JWKSource<SecurityContext> jwkSource) {
        return (context) -> {
            if (OAuth2TokenType.ACCESS_TOKEN.equals(context.getTokenType()) ||
                    OidcParameterNames.ID_TOKEN.equals(context.getTokenType().getValue())) {

                JWKSelector jwkSelector = new JWKSelector(new JWKMatcher.Builder().build());
                List<JWK> jwks;
                try {
                    jwks = jwkSource.get(jwkSelector, null);
                } catch (KeySourceException e) {
                    throw new IllegalStateException("Failed to select the JWK(s) -> " + e.getMessage(), e);
                }
                String kid = jwks.stream().map(JWK::getKeyID)
                        .max(String::compareTo)
                        .orElseThrow(() -> new IllegalArgumentException("kid not found"));
                context.getHeaders().keyId(kid);
            }
        };
    }

本示例中kid由时间戳定义,所以确保密钥轮换后使用最新密钥,我们将获取最大值kid所对应的密钥进行签名。但是kid若不使用类似于时间戳的递增值,将建议按照FIFO(先进先出)结构,其格式是将新生成的密钥推送到末尾。


最后让我们配置Form表单认证方式,并设置用户名和密码:

    @Bean
    SecurityFilterChain defaultSecurityFilterChain(HttpSecurity http) throws Exception {
        http.authorizeHttpRequests((authorize) -> authorize.anyRequest().authenticated())
                .formLogin(Customizer.withDefaults());

        return http.build();
    }

    @Bean
    UserDetailsService userDetailsService() {
        UserDetails userDetails = User.withUsername("admin")
                .password("{noop}password")
                .roles("ADMIN")
                .build();
        return new InMemoryUserDetailsManager(userDetails);
    }

配置资源服务

本节中我们将使用Spring Security搭建OAuth2资源服务,并且我们将为JwtDecoder配置redis缓存。

Maven依赖

<dependency>
  <groupId>org.springframework.boot</groupId>
  <artifactId>spring-boot-starter-web</artifactId>
  <version>2.6.7</version>
</dependency>
<dependency>
  <groupId>org.springframework.boot</groupId>
  <artifactId>spring-boot-starter-security</artifactId>
  <version>2.6.7</version>
</dependency>
<dependency>
  <groupId>org.springframework.boot</groupId>
  <artifactId>spring-boot-starter-oauth2-resource-server</artifactId>
  <version>2.6.7</version>
</dependency>

<dependency>
  <groupId>org.springframework.boot</groupId>
  <artifactId>spring-boot-starter-cache</artifactId>
  <version>2.6.7</version>
</dependency>

<dependency>
  <groupId>org.springframework.boot</groupId>
  <artifactId>spring-boot-starter-data-redis</artifactId>
  <version>2.6.7</version>
</dependency>

配置

首先我们从application.yml文件配置开始,指定端口8090,并添加redis配置和OAuth2配置信息:

server:
  port: 8090

spring:
  redis:
    host: localhost
    database: 0
    port: 6379
    password: 123456
    timeout: 1800
    lettuce:
      pool:
        max-active: 20
        max-wait: 60
        max-idle: 5
        min-idle: 0
      shutdown-timeout: 100
  security:
    oauth2:
      resourceserver:
        jwt:
          jwk-set-uri: http://127.0.0.1:8080/oauth2/jwks


Spring Boot 自动配置一个具有默认缓存配置的RedisCacheManager 。但是,我们可以在缓存管理器初始化之前修改此配置,将缓存过期时间设置为5分钟

    @Bean
    public CacheManager cacheManager(RedisConnectionFactory factory) {
        RedisSerializer<String> redisSerializer = new StringRedisSerializer();
        Jackson2JsonRedisSerializer jackson2JsonRedisSerializer = new Jackson2JsonRedisSerializer(Object.class);
        ObjectMapper om = new ObjectMapper();
        om.setVisibility(PropertyAccessor.ALL, JsonAutoDetect.Visibility.ANY);
        om.activateDefaultTyping(LaissezFaireSubTypeValidator.instance, ObjectMapper.DefaultTyping.NON_FINAL);
        jackson2JsonRedisSerializer.setObjectMapper(om);
        // 配置序列化(解决乱码的问题),过期时间5分钟
        RedisCacheConfiguration config = RedisCacheConfiguration.defaultCacheConfig()
                .entryTtl(Duration.ofSeconds(5*60))
                .serializeKeysWith(RedisSerializationContext.SerializationPair.fromSerializer(redisSerializer))
                .serializeValuesWith(RedisSerializationContext.SerializationPair.fromSerializer(jackson2JsonRedisSerializer))
                .disableCachingNullValues();
        RedisCacheManager cacheManager = RedisCacheManager.builder(factory)
                .cacheDefaults(config)
                .build();
        return cacheManager;
    }

阅读到这里,你是否有疑问授权服务轮换密钥后资源服务如何获取最新密钥验证JWT。

在此之前让我们先了解下JwtDecoder工作原理,以下是简单声明JwtDecoder的示例:

@Bean
public JwtDecoder jwtDecoder() {
    return NimbusJwtDecoder.withJwkSetUri(jwkSetUri).build();
}

为了清楚起见,部分源码细节已被省略。

当我们查看源码NimbusJwtDecoder.JwkSetUriJwtDecoderBuilder构建器中,可以看到内部创建了JWKSource的实现类RemoteJWKSet,注意我们没有配置Cache,所以最终执行return new RemoteJWKSet(toURL(this.jwkSetUri), jwkSetRetriever);。这里ResourceRetriever实现类为RestOperationsResourceRetriever

        JWKSource<SecurityContext> jwkSource(ResourceRetriever jwkSetRetriever) {
            if (this.cache == null) {
                return new RemoteJWKSet(toURL(this.jwkSetUri), jwkSetRetriever);
            } else {
                ResourceRetriever cachingJwkSetRetriever = new NimbusJwtDecoder.JwkSetUriJwtDecoderBuilder.CachingResourceRetriever(this.cache, jwkSetRetriever);
                return new RemoteJWKSet(toURL(this.jwkSetUri), cachingJwkSetRetriever, new NimbusJwtDecoder.JwkSetUriJwtDecoderBuilder.NoOpJwkSetCache());
            }
        }

JwtDecoder验证JWT需要通过RemoteJWKSet获取JWK, RemoteJWKSet由 JWKS URL 指定的远程 JSON Web KEY (JWK) 端点。检索到的 JWKS将被缓存以最小化网络调用。每当JWKSelector尝试获取具有未知 kid时,都会更新缓存。以下为RemoteJWKSet核心方法:

    public List<JWK> get(JWKSelector jwkSelector, C context) throws RemoteKeySourceException {
        JWKSet jwkSet = this.jwkSetCache.get();
        if (this.jwkSetCache.requiresRefresh() || jwkSet == null) {
            try {
                jwkSet = this.updateJWKSetFromURL();
            } catch (Exception var6) {
                if (jwkSet == null) {
                    throw var6;
                }
            }
        }

        List<JWK> matches = jwkSelector.select(jwkSet);
        if (!matches.isEmpty()) {
            return matches;
        } else {
            String soughtKeyID = getFirstSpecifiedKeyID(jwkSelector.getMatcher());
            if (soughtKeyID == null) {
                return Collections.emptyList();
            } else if (jwkSet.getKeyByKeyId(soughtKeyID) != null) {
                return Collections.emptyList();
            } else {
                jwkSet = this.updateJWKSetFromURL();
                return jwkSet == null ? Collections.emptyList() : jwkSelector.select(jwkSet);
            }
        }
    }

RemoteJWKSet遵循以下步骤:

  • JWKSetCache获取JWKSet,RemoteJWKSet中默认实现为DefaultJWKSetCache,默认情况DefaultJWKSetCache将授权服务器的 JWKS 缓存 5 分钟。
  • JWKSetCache中JWKSet为空或者需要刷新JWK更新缓存时,RestOperationsResourceRetriever将发起HTTP请求向授权服务获取JWKS。
  • JWKSelector从指定的 JWKS中选择与配置的条件匹配的JWK。若匹配为空则将重新通过RestOperationsResourceRetrieve向授权服务请求获取JWKS,再次匹配结果为空则返回空值。

通过简要了解RemoteJWKSet执行过程,我相信对于之前授权服务器轮换密钥后资源服务如何获取最新密钥已经有了答案。

在授权服务密钥轮换后生成JWT的Header中kid使用的是当前最新密钥所对应的kid,此时资源服务收到JWT,通过RemoteJWKSet获取JWK用于验证JWT时,JWKSelectorJWKSetCache返回的JWKS中并没有匹配到条件相符的JWK,所以将会使用RestOperationsResourceRetrieve重新向授权服务获取最新JWKS,JWKSelector将再次选择与条件相符的JWK。

但是分布式系统中协作服务器数量的增加,授权服务密钥轮换后,涉及资源服务都要重新请求授权服务获取最新JWK,当然这并不会对授权服务造成太大压力。但是为了最小化网络调用,本示例使用共享缓存解决此问题。


接下来我们将为JwtDecoder配置Redis缓存,Redis将使用 JWKS Uri 作为键,并使用 JWKS JSON 作为值:

    @Bean
    JwtDecoder jwtDecoder(OAuth2ResourceServerProperties properties, RestOperations restOperations, CacheManager cacheManager) {
        NimbusJwtDecoder jwtDecoder = NimbusJwtDecoder.withJwkSetUri(properties.getJwt().getJwkSetUri())
                .restOperations(restOperations)
                .cache(cacheManager.getCache("jwks"))
                .jwsAlgorithms(algorithms -> {
                    algorithms.add(RS256);
                }).build();

        //自定义时间戳验证
        OAuth2TokenValidator<Jwt> withClockSkew = new DelegatingOAuth2TokenValidator<>(
                new JwtTimestampValidator(Duration.ofSeconds(60)));

        jwtDecoder.setJwtValidator(withClockSkew);

        return jwtDecoder;
    }

此时RemoteJWKSet的ResourceRetriever属性实际赋值为CachingResourceRetriever, 我们使用的是Redis缓存,CachingResourceRetriever中更新JWKS会先从Redis缓存中获取,若Redis缓存为空则将请求授权服务,部分源码如下:

public Resource retrieveResource(URL url) throws IOException {
    String jwkSet = (String)this.cache.get(url.toString(), () -> {
      return this.resourceRetriever.retrieveResource(url).getContent();
    });
    return new Resource(jwkSet, "UTF-8");
}

这里引出一个新的问题,Redis缓存为空才会重新请求授权服务JWKS Uri,如果某个时刻授权服务密钥轮换后,资源服务Redis缓存此时存在值,则不会重新向授权服务发起请求来更新资源服务JWKS缓存,此时资源服务验证轮换后的密钥生成的JWT将会失败。

解决此问题我们可以在授权服务密钥轮换后清除Redis中资源服务JWKS缓存信息。


最后我们将使用Spring Security保护资源服务端点,指定受保护端点访问权限:

    @Bean
    SecurityFilterChain securityFilterChain(HttpSecurity http) throws Exception {
        http.authorizeHttpRequests((authorize) -> authorize
                .antMatchers("/resource/article").hasAuthority("SCOPE_message.read")
                .anyRequest().authenticated())
                .oauth2ResourceServer(OAuth2ResourceServerConfigurer::jwt);

        return http.build();
    }
@RestController
public class ArticleController {


    @GetMapping("/resource/article")
    public Map<String, Object> getArticle(@AuthenticationPrincipal Jwt jwt) {
        Map<String, Object> result = new HashMap<>();
        result.put("principal", jwt.getClaims());
        result.put("article", Arrays.asList("article1", "article2", "article3"));
        return result;
    }
}

测试我们的应用程序

本文中并没有说明客户端服务创建,客户端服务并不是本文介绍重点,若有疑问可以参考之前文章或者通过文末链接获取源码。

接下来我们将启动所有服务并访问http://127.0.0.1:8070/client/article 。在等待授权服务轮换密钥后,访问依旧正常。

结论

与往常一样,本文中使用的源代码可在 GitHub 上获得。