拦截器从spring mvc 改为servlet

This commit is contained in:
dark 2021-03-18 22:46:22 +08:00
parent fe0018b3a8
commit fe956b71f6
3 changed files with 64 additions and 91 deletions

View File

@ -0,0 +1,64 @@
package cn.iocoder.dashboard.framework.tracer.filter;
import cn.hutool.core.util.StrUtil;
import cn.iocoder.dashboard.framework.tracer.core.util.TracerUtils;
import lombok.extern.slf4j.Slf4j;
import org.springframework.beans.factory.annotation.Value;
import org.springframework.lang.Nullable;
import org.springframework.stereotype.Component;
import org.springframework.web.servlet.HandlerInterceptor;
import javax.servlet.*;
import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletResponse;
import java.io.IOException;
/**
* 对Spring Mvc 的请求拦截, 添加traceId.
*
* @author mashu
*/
@Slf4j
@Component
public class ServletTraceFilter implements Filter {
@Value("${cn.iocoder.tracer.name:global-trace-id}")
private String traceIdName;
@Override
public void doFilter(ServletRequest servletRequest, ServletResponse servletResponse, FilterChain filterChain) throws IOException, ServletException {
HttpServletRequest httpServletRequest = (HttpServletRequest) servletRequest;
HttpServletResponse httpServletResponse = (HttpServletResponse) servletResponse;
try {
// 请求中traceId
String reqTraceId = (String)httpServletRequest.getHeader(traceIdName);
// skywalking中的traceId
String skywalkingTraceId = TracerUtils.getSkywalkingTraceId();
String traceId ;
if (null == reqTraceId && StrUtil.isBlank(skywalkingTraceId)) {
// 两者皆空,添加默认的.
traceId = TracerUtils.getTraceId();
httpServletResponse.setHeader(traceIdName, traceId);
} else if (null == reqTraceId && StrUtil.isNotBlank(skywalkingTraceId)){
// 若请求空,则添加,为没有skywalking的系统添加一个TraceId
traceId = skywalkingTraceId;
httpServletResponse.setHeader(traceIdName, traceId);
} else if (null != reqTraceId && StrUtil.isBlank(skywalkingTraceId)) {
// 请求非空, skywalking为空
traceId = reqTraceId;
} else {
// 两者皆非空,不动请求头
traceId = skywalkingTraceId;
}
TracerUtils.saveThreadTraceId(traceId);
log.debug("请求进入,添加traceId[{}]", traceId);
filterChain.doFilter(httpServletRequest, httpServletResponse);
} finally {
// 请求结束,删除本地的链路流水号
log.debug("请求结束,删除traceId[{}]", TracerUtils.getTraceId());
TracerUtils.deleteThreadTraceId();
}
}
}

View File

@ -1,68 +0,0 @@
package cn.iocoder.dashboard.framework.tracer.filter;
import cn.hutool.core.util.StrUtil;
import cn.iocoder.dashboard.framework.tracer.core.util.TracerUtils;
import lombok.extern.slf4j.Slf4j;
import org.aspectj.lang.annotation.Aspect;
import org.aspectj.lang.annotation.Before;
import org.aspectj.lang.annotation.Pointcut;
import org.springframework.beans.factory.annotation.Value;
import org.springframework.context.annotation.Configuration;
import org.springframework.lang.Nullable;
import org.springframework.stereotype.Component;
import org.springframework.web.servlet.AsyncHandlerInterceptor;
import org.springframework.web.servlet.HandlerInterceptor;
import org.springframework.web.servlet.handler.HandlerInterceptorAdapter;
import javax.servlet.*;
import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletResponse;
import java.io.IOException;
/**
* 对Spring Mvc 的请求拦截, 添加traceId.
*
* @author mashu
*/
@Slf4j
@Component
public class SpringMvcTraceFilter implements HandlerInterceptor {
@Value("${cn.iocoder.tracer.name:global-trace-id}")
private String traceIdName;
public boolean preHandle(HttpServletRequest request, HttpServletResponse response, Object handler) throws Exception {
// 请求中traceId
String reqTraceId = (String)request.getAttribute(traceIdName);
// skywalking中的traceId
String skywalkingTraceId = TracerUtils.getSkywalkingTraceId();
String traceId ;
if (null == reqTraceId && StrUtil.isBlank(skywalkingTraceId)) {
// 两者皆空,添加默认的.
traceId = TracerUtils.getTraceId();
request.setAttribute(traceIdName, traceId);
} else if (null == reqTraceId && StrUtil.isNotBlank(skywalkingTraceId)){
// 若请求空,则添加,为没有skywalking的系统添加一个TraceId
traceId = skywalkingTraceId;
request.setAttribute(traceIdName, traceId);
} else if (null != reqTraceId && StrUtil.isBlank(skywalkingTraceId)) {
// 请求非空, skywalking为空
traceId = reqTraceId;
} else {
// 两者皆非空,不动请求头
traceId = skywalkingTraceId;
}
TracerUtils.saveThreadTraceId(traceId);
log.debug("请求进入,添加traceId[{}]", traceId);
return true;
}
public void afterCompletion(HttpServletRequest request, HttpServletResponse response, Object handler, @Nullable Exception ex) throws Exception {
// 请求结束,删除本地的链路流水号
log.debug("请求结束,删除traceId[{}]", TracerUtils.getTraceId());
TracerUtils.deleteThreadTraceId();
}
}

View File

@ -1,23 +0,0 @@
package cn.iocoder.dashboard.framework.tracer.filter;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.context.annotation.Configuration;
import org.springframework.stereotype.Component;
import org.springframework.web.servlet.config.annotation.InterceptorRegistry;
import org.springframework.web.servlet.config.annotation.WebMvcConfigurer;
import javax.annotation.Resource;
@Configuration
@Component
public class WebConfig implements WebMvcConfigurer {
@Resource
private SpringMvcTraceFilter springMvcTraceFilter;
@Override
public void addInterceptors(InterceptorRegistry registry) {
registry.addInterceptor(this.springMvcTraceFilter).addPathPatterns("/**");
}
}