diff --git a/src/main/java/me/zhyd/oauth/request/AuthDefaultRequest.java b/src/main/java/me/zhyd/oauth/request/AuthDefaultRequest.java index 7ca3e81..bd30cfa 100644 --- a/src/main/java/me/zhyd/oauth/request/AuthDefaultRequest.java +++ b/src/main/java/me/zhyd/oauth/request/AuthDefaultRequest.java @@ -72,6 +72,7 @@ public abstract class AuthDefaultRequest implements AuthRequest { throw new AuthException(AuthResponseStatus.ILLEGAL_REQUEST); } AuthChecker.checkCode(source == AuthSource.ALIPAY ? authCallback.getAuth_code() : authCallback.getCode()); + AuthChecker.checkState(authCallback); AuthToken authToken = this.getAccessToken(authCallback); AuthUser user = this.getUserInfo(authToken); diff --git a/src/main/java/me/zhyd/oauth/utils/AuthChecker.java b/src/main/java/me/zhyd/oauth/utils/AuthChecker.java index a62eedc..2d1c8c3 100644 --- a/src/main/java/me/zhyd/oauth/utils/AuthChecker.java +++ b/src/main/java/me/zhyd/oauth/utils/AuthChecker.java @@ -3,7 +3,8 @@ package me.zhyd.oauth.utils; import me.zhyd.oauth.config.AuthConfig; import me.zhyd.oauth.config.AuthSource; import me.zhyd.oauth.exception.AuthException; -import me.zhyd.oauth.enums.AuthResponseStatus; +import me.zhyd.oauth.model.AuthCallback; +import me.zhyd.oauth.model.AuthResponseStatus; /** * 授权配置类的校验器 @@ -65,4 +66,15 @@ public class AuthChecker { throw new AuthException(AuthResponseStatus.ILLEGAL_CODE); } } + + /** + * 校验回调传回的state + * + * @param authCallback 回调 + */ + public static void checkState(AuthCallback authCallback) { + if (!authCallback.checkState()) { + throw new AuthException(AuthResponseStatus.ILLEGAL_REQUEST); + } + } }