[Java] 基于 SpringBoot 增加 @BodyParam 注解接收Http请求Body中的Json参数

今天我们来自己实现一个注解,用于读取 Http 请求 Body 中的 Json 参数。

1. 定义注解

import java.lang.annotation.*;

/**
 * Marks a handler-method parameter whose value is read from one top-level
 * field of the JSON object sent in the HTTP request body.
 *
 * <p>The resolver lower-cases {@link #value()} before looking it up, so the
 * JSON field names are expected to be lower case.
 */
@Target({ElementType.PARAMETER})
@Retention(RetentionPolicy.RUNTIME)
@Documented
public @interface BodyParam {
    /** Name of the JSON field to read. */
    String value() default "";

    /** Whether the field must be present; a missing required field is rejected. */
    boolean required() default true;

    /**
     * Error message reported when a required field is missing. When empty, the
     * resolver falls back to a {@code MissingServletRequestParameterException}.
     */
    String message() default "";

    // NOTE(review): not read by the resolver in this file; kept for compatibility.
    Class<? extends Annotation> annotation() default Annotation.class;
}

2. 编写注解的方法参数处理程序

import com.fasterxml.jackson.databind.ObjectMapper;
import graspyun.weshop.manage.filter.BufferedServletRequestWrapper;
import graspyun.weshop.manage.support.annotation.BodyParam;
import graspyun.weshop.manage.support.exceptions.SysException;
import org.springframework.core.MethodParameter;
import org.springframework.stereotype.Component;
import org.springframework.web.bind.MissingServletRequestParameterException;
import org.springframework.web.bind.support.WebDataBinderFactory;
import org.springframework.web.context.request.NativeWebRequest;
import org.springframework.web.method.support.HandlerMethodArgumentResolver;
import org.springframework.web.method.support.ModelAndViewContainer;

import java.lang.reflect.Array;
import java.util.List;
import java.util.Map;

/**
 * 处理 @BodyParam 注解
 * @author yangyxd
 * @date 2020.08.18 16:03
 */
/**
 * Resolves handler-method parameters annotated with {@link BodyParam} by
 * looking up a single field in the JSON object of the request body.
 *
 * <p>The body is obtained from {@link BufferedServletRequestWrapper}, so the
 * request must have been wrapped by the request filter beforehand.
 */
@Component
public class BodyParamResolver implements HandlerMethodArgumentResolver {

    /** Shared mapper; creating one per call is needlessly expensive. */
    private static final ObjectMapper MAPPER = new ObjectMapper();

    /**
     * @param methodParameter the parameter being inspected
     * @return {@code true} when the parameter carries {@link BodyParam}
     */
    @Override
    public boolean supportsParameter(MethodParameter methodParameter) {
        return methodParameter.hasParameterAnnotation(BodyParam.class);
    }

    /**
     * Reads the annotated field from the JSON body and adapts it to the
     * declared parameter type.
     *
     * @return the converted value, or {@code null} when the field is absent
     *         and not required
     * @throws MissingServletRequestParameterException when a required field is
     *         missing and no custom message is configured
     * @throws SysException when a required field is missing and a custom
     *         message is configured
     * @throws IllegalStateException when the request was not wrapped in a
     *         {@link BufferedServletRequestWrapper}
     */
    @Override
    public Object resolveArgument(MethodParameter methodParameter, ModelAndViewContainer modelAndViewContainer, NativeWebRequest nativeWebRequest, WebDataBinderFactory webDataBinderFactory) throws Exception {
        // Unwrap by type instead of casting: other filters may have wrapped our wrapper.
        BufferedServletRequestWrapper request =
                nativeWebRequest.getNativeRequest(BufferedServletRequestWrapper.class);
        if (request == null) {
            throw new IllegalStateException(
                    "@BodyParam requires the request to be wrapped in BufferedServletRequestWrapper");
        }
        Map<String, Object> json = request.getContentBody();

        BodyParam p = methodParameter.getParameterAnnotation(BodyParam.class);
        // Locale.ROOT keeps the key stable under locales with special casing rules (e.g. Turkish).
        Object v = json == null ? null : json.get(p.value().toLowerCase(java.util.Locale.ROOT));

        if (v == null) {
            if (p.required()) {
                if (p.message().isEmpty()) {
                    throw new MissingServletRequestParameterException(p.value(), methodParameter.getParameterType().getTypeName());
                }
                throw new SysException(p.message());
            }
            return null;
        }

        final Class<?> pType = methodParameter.getParameterType();
        if (pType.isInstance(v)) {
            return v;
        }

        if (pType == String.class) {
            // Structured values are handed over as their JSON text.
            if (v instanceof Map || v instanceof List || v.getClass().isArray()) {
                return MAPPER.writeValueAsString(v);
            }
            if (v instanceof Number || v instanceof Boolean) {
                return v.toString();
            }
            return v;
        }
        if (v instanceof Number) {
            return convertNumber((Number) v, pType);
        }
        return v;
    }

    /**
     * Narrows or widens a JSON number to the numeric type the parameter declares;
     * returns the number unchanged for any other target type.
     */
    private static Object convertNumber(Number n, Class<?> pType) {
        if (pType == Integer.class || pType == int.class) {
            return n.intValue();
        }
        if (pType == Long.class || pType == long.class) {
            return n.longValue();
        }
        if (pType == Short.class || pType == short.class) {
            return n.shortValue();
        }
        if (pType == Byte.class || pType == byte.class) {
            return n.byteValue();
        }
        if (pType == Double.class || pType == double.class) {
            return n.doubleValue();
        }
        if (pType == Float.class || pType == float.class) {
            return n.floatValue();
        }
        return n;
    }
}

3. 使用带缓冲的 RequestWrapper 替换 SpringBoot 默认的 HttpServletRequestWrapper

Servlet 请求的 getInputStream 默认只能读取一次:如果 @BodyParam 先读取了请求体,其它地方再读就会出错;反之,其它地方先读过了,我们再读也会出错。所以需要用自定义的、带缓存的 RequestWrapper(继承自 HttpServletRequestWrapper)包装原始请求,使请求体可以被重复读取。

import com.fasterxml.jackson.core.type.TypeReference;
import com.fasterxml.jackson.databind.ObjectMapper;

import javax.servlet.ServletInputStream;
import javax.servlet.ServletRequest;
import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletRequestWrapper;
import java.io.BufferedReader;
import java.io.ByteArrayOutputStream;
import java.io.IOException;
import java.io.InputStreamReader;
import java.nio.charset.StandardCharsets;
import java.util.Map;

/**
 * @author yangyxd
 * @date 2020.08.18 14:56
 */
/**
 * Request wrapper that reads the body of the wrapped request once, keeps the
 * bytes in memory, and serves every later read from that copy, so the body
 * can be consumed more than once.
 */
public class BufferedServletRequestWrapper extends HttpServletRequestWrapper {

    /** Shared mapper; creating one per call is needlessly expensive. */
    private static final ObjectMapper MAPPER = new ObjectMapper();

    /** Chunk size used when the body length is not known in advance. */
    private static final int CHUNK_SIZE = 4096;

    /** Raw body bytes; {@code null} until the body has been read. */
    private byte[] body;
    private BufferedServletInputStream stream;
    private Map<String,Object> contentBody;

    public BufferedServletRequestWrapper(HttpServletRequest request) {
        super(request);
    }

    /**
     * Returns a stream over the buffered body, rewound to the start on each call.
     *
     * @throws IOException if reading the wrapped request's body fails
     */
    @Override
    public ServletInputStream getInputStream() throws IOException {
        if (stream == null) {
            this.stream = new BufferedServletInputStream(bufferedBody());
        }
        this.stream.reset();
        return this.stream;
    }

    /**
     * Returns a reader over the buffered body so that reader-based consumers
     * share the same copy as {@link #getInputStream()}. Uses the request's
     * character encoding, falling back to UTF-8 when none is declared.
     *
     * @throws IOException if reading the body fails or the encoding is unsupported
     */
    @Override
    public BufferedReader getReader() throws IOException {
        String encoding = getCharacterEncoding();
        if (encoding == null) {
            encoding = StandardCharsets.UTF_8.name();
        }
        return new BufferedReader(new InputStreamReader(getInputStream(), encoding));
    }

    /**
     * Parses the buffered body as a JSON object and caches the result.
     *
     * @return the parsed object, or {@code null} when the body is empty
     * @throws IOException if reading or parsing the body fails
     */
    public Map<String,Object> getContentBody() throws IOException {
        if (contentBody == null) {
            // Read from the buffer, never from the wrapped request directly:
            // the underlying stream can only be consumed once.
            String contentString = new String(bufferedBody(), StandardCharsets.UTF_8);
            if (contentString.isEmpty()) {
                return null;
            }
            contentBody = MAPPER.readValue(contentString, new TypeReference<Map<String, Object>>() {});
        }
        return contentBody;
    }

    /** Reads the wrapped request's body on first use and returns the cached bytes. */
    private byte[] bufferedBody() throws IOException {
        if (body == null) {
            ServletRequest request = this.getRequest();
            int len = request.getContentLength();
            ServletInputStream in = request.getInputStream();
            if (len >= 0) {
                byte[] buffer = new byte[len];
                recv(len, in, buffer);
                body = buffer;
            } else {
                // Length unknown (getContentLength() returned -1): read until end of stream.
                ByteArrayOutputStream out = new ByteArrayOutputStream();
                byte[] chunk = new byte[CHUNK_SIZE];
                int n;
                while ((n = in.read(chunk)) != -1) {
                    out.write(chunk, 0, n);
                }
                body = out.toByteArray();
            }
        }
        return body;
    }

    /**
     * Reads exactly {@code len} bytes from {@code stream} into {@code buffer}.
     *
     * @throws IOException if the stream ends before {@code len} bytes were read
     */
    public static void recv(int len, ServletInputStream stream, byte[] buffer) throws IOException {
        int offset = 0;
        while (offset < len) {
            int _len = 4096;
            if (len - offset < _len) {
                _len = len - offset;
            }
            int i = stream.read(buffer, offset, _len);
            if (i < 0) {
                throw new IOException("数据接则不完整");
            }
            offset = offset + i;
        }
    }

    /** Reads the whole body of {@code request} as a UTF-8 string, using its declared content length. */
    public static String recvString(ServletRequest request) throws IOException {
        return recvString(request.getContentLength(), request.getInputStream());
    }

    /**
     * Reads {@code len} bytes from {@code stream} and decodes them as UTF-8.
     *
     * @return the decoded text, or an empty string when {@code len} is not positive
     */
    public static String recvString(int len, ServletInputStream stream) throws IOException {
        if (len > 0) {
            byte[] buffer = new byte[len];
            recv(len, stream, buffer);
            return new String(buffer, StandardCharsets.UTF_8);
        } else {
            return "";
        }
    }

}

BufferedServletInputStream

import javax.servlet.ReadListener;
import javax.servlet.ServletInputStream;
import java.io.ByteArrayInputStream;
import java.io.IOException;

/**
 * {@link ServletInputStream} backed by an in-memory byte array, so the same
 * request body can be replayed after {@link #reset()}.
 */
class BufferedServletInputStream extends ServletInputStream {
    private final ByteArrayInputStream inputStream;

    /**
     * @param buffer the complete request body
     */
    public BufferedServletInputStream(byte[] buffer) {
        this.inputStream = new ByteArrayInputStream(buffer);
    }

    @Override
    public int available() throws IOException {
        return inputStream.available();
    }

    @Override
    public int read() throws IOException {
        return inputStream.read();
    }

    @Override
    public int read(byte[] b, int off, int len) throws IOException {
        return inputStream.read(b, off, len);
    }

    /**
     * @return {@code true} only once every buffered byte has been read
     */
    @Override
    public boolean isFinished() {
        // Reporting "finished" before anything was read would make callers skip the body.
        return inputStream.available() == 0;
    }

    /**
     * @return always {@code true}: the data is already in memory, so reads never block
     */
    @Override
    public boolean isReady() {
        return true;
    }

    /** Intentionally a no-op: the listener is never stored or notified. */
    @Override
    public void setReadListener(ReadListener readListener) {}

    /** Rewinds the underlying stream; no mark is ever set, so this returns to the start. */
    @Override
    public synchronized void reset() throws IOException {
        this.inputStream.reset();
    }
}

增加过滤器,在里面替换掉 httpServletRequest

import org.springframework.stereotype.Component;
import javax.servlet.*;
import javax.servlet.annotation.WebFilter;
import javax.servlet.http.HttpServletRequest;
import java.io.IOException;
import java.util.UUID;

/**
 * Servlet filter that wraps every HTTP request in a
 * {@link BufferedServletRequestWrapper} before passing it down the chain, so
 * later components can read the request body repeatedly.
 */
@Component
// Servlet URL patterns use "/*" to match every path; "/**" is Ant-style and is not a Servlet wildcard.
@WebFilter(urlPatterns = "/*", filterName = "RequestFilter")
public class RequestFilter implements Filter {

    private static final Logger logger = LoggerFactory.getLogger(RequestFilter.class);

    @Override
    public void init(FilterConfig filterConfig) {
        logger.debug("初始化请求过滤器");
    }

    /**
     * Replaces the request with a buffering wrapper and continues the chain.
     * Non-HTTP requests and requests that are already wrapped are passed
     * through unchanged.
     *
     * @throws IOException if a downstream filter or servlet fails on I/O
     * @throws ServletException if a downstream filter or servlet fails
     */
    @Override
    public void doFilter(ServletRequest servletRequest, ServletResponse servletResponse, FilterChain filterChain) throws IOException, ServletException {
        logger.debug("执行过滤器前");

        ServletRequest next = servletRequest;
        // Guard the cast, and do not wrap twice if the filter runs more than once.
        if (servletRequest instanceof HttpServletRequest
                && !(servletRequest instanceof BufferedServletRequestWrapper)) {
            next = new BufferedServletRequestWrapper((HttpServletRequest) servletRequest);
        }
        // Continue with the next filter in the chain.
        filterChain.doFilter(next, servletResponse);
        logger.debug("执行过滤器后");
    }

    @Override
    public void destroy() {
        logger.debug("销毁过滤器");
    }
}

4. 使用示例

/**
 * Usage example: a controller whose handler parameters are bound from
 * individual fields of the JSON request body through {@code @BodyParam}.
 */
@Controller
@RequestMapping("/manage/xxx")
@Validated
public class XXXController {

    final XXXService service;

    public XXXController(XXXService service) {
        this.service = service;
    }

    /**
     * Product-list endpoint. Every field except {@code deleted} is optional.
     * In this example the arguments are not used and a success response is
     * returned directly.
     */
    @PostMapping("/list")
    @ResponseBody
    ResponseDTO<PageDataDTO<XXXItemDTO>> getList(
            @BodyParam(value = "from", required = false) Long from,
            @BodyParam(value = "to", required = false) Long to,
            @BodyParam(value = "name", required = false) String name,
            @BodyParam(value = "pagesize", required = false) Integer pageSize,
            @BodyParam(value = "pageindex", required = false) Integer pageIndex,
            @BodyParam(value = "categoryid", required = false) String categoryId,
            @BodyParam("deleted") Integer deleted
    ) throws SysException {
        return ResponseDTO.success();
    }

}