If anyone is still interested in the answer, here's how I solved it:
Create a custom type that extends HttpServletRequestWrapper. Make sure you override
- getInputStream()
- getReader()
- getParameter()
- getParameterMap()
- getParameterNames()
- getParameterValues()
This is because when Resteasy tries to bind using @Form, @FormParam, @QueryParam, etc., it calls the getParameter() method on the Resteasy class, which then delegates to the underlying request (in my case, Apache's Coyote servlet request). So overriding getInputStream() and getReader() alone is not enough; you must make sure that getParameter() and its sibling methods use the new input stream as well.
If you want to store the body for later use, you must then construct the parameter map yourself by parsing the query string and the URL-encoded form body. It's quite straightforward to implement, but it carries its own risks. I recommend reading Coyote's implementation of the same methods.
import java.io.BufferedReader;
import java.io.ByteArrayInputStream;
import java.io.ByteArrayOutputStream;
import java.io.IOException;
import java.io.InputStreamReader;
import java.net.URLDecoder;
import java.util.Collections;
import java.util.Enumeration;
import java.util.LinkedHashMap;
import java.util.Locale;
import java.util.Map;
import java.util.regex.Matcher;
import java.util.regex.Pattern;
import javax.servlet.ServletInputStream;
import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletRequestWrapper;
import javax.ws.rs.core.MediaType;
import org.apache.commons.io.IOUtils;
import org.apache.commons.lang.StringUtils;
import org.apache.log4j.Logger;
/**
 * Request wrapper that supports repeated reads of the request body and parameters.
 *
 * <p>On construction the wrapper copies the entire request body into memory and builds its own
 * parameter map from the query string and, when the content type is
 * {@code application/x-www-form-urlencoded}, from the body. All {@code getParameter*} methods
 * answer from that map instead of delegating to the wrapped request, and
 * {@link #getInputStream()} / {@link #getReader()} hand out a fresh stream over the cached bytes
 * on every call.
 *
 * <p>Because the whole body is buffered, memory use grows with the size of the request body.
 */
public class CustomHttpServletRequest extends HttpServletRequestWrapper {

    private static final Logger logger = Logger.getLogger(CustomHttpServletRequest.class);

    // A typical url encoded form is "key1=value&key2=some%20value".
    // This pattern cannot match a parameter that has no '=' (e.g. "flag"), so the parser in
    // addParameters() no longer uses it; the constant is left in place because it is public.
    public final static Pattern urlStrPattern = Pattern.compile("([^=&]+)=([^&]*)[&]?");

    // Cached request body
    protected ByteArrayOutputStream cachedBytes;

    // Character encoding used for the body and for URL-decoding parameters
    protected String encoding;

    // Body decoded with 'encoding'; null when the body is empty
    protected String requestBody;

    // Cached form parameters
    protected Map<String, String[]> paramMap = new LinkedHashMap<String, String[]>();

    // Cached header names, including extra headers we injected.
    protected Enumeration<?> headerNames = null;

    /**
     * Wraps the given request, eagerly reading its body and parsing its parameters.
     *
     * <p>The encoding is taken from the request and falls back to UTF-8 when the request does
     * not declare one. Query-string parameters are added first, followed by body parameters
     * when the content type contains {@code application/x-www-form-urlencoded}.
     *
     * @param request the request to wrap
     * @throws RuntimeException wrapping the {@link IOException} if the body cannot be read
     * @throws IllegalArgumentException if a parameter contains a malformed percent-escape
     *         (propagated from {@link URLDecoder#decode(String, String)})
     */
    public CustomHttpServletRequest(HttpServletRequest request) {
        super(request);
        // Read the body and construct parameters
        try {
            String declaredEncoding = request.getCharacterEncoding();
            encoding = (declaredEncoding == null) ? "UTF-8" : declaredEncoding;

            // Parameters in query strings must be added to paramMap
            String queryString = request.getQueryString();
            if (logger.isDebugEnabled()) {
                logger.debug("Extracted HTTP query string: "+queryString);
            }
            if (queryString != null && !queryString.isEmpty()) {
                addParameters(queryString, encoding);
            }

            // Parse the request body if this is a form submission.
            // Clients must set content-type to "x-www-form-urlencoded".
            requestBody = IOUtils.toString(this.getInputStream(), encoding);
            if (StringUtils.isEmpty(requestBody)) {
                requestBody = null;
            }
            // NOTE(review): the body may carry credentials or other sensitive data; confirm
            // that debug logging of it is acceptable in the target environment.
            if (logger.isDebugEnabled()) {
                logger.debug("Extracted HTTP request body: "+requestBody);
            }

            String contentType = request.getContentType();
            // Locale.ROOT keeps the case folding independent of the JVM's default locale
            // (a Turkish default locale would otherwise map 'I' to a dotless 'i').
            if (contentType != null
                    && contentType.toLowerCase(Locale.ROOT).contains(MediaType.APPLICATION_FORM_URLENCODED)) {
                addParameters(requestBody, encoding);
            }
        }
        catch (IOException ex) {
            throw new RuntimeException(ex);
        }
    }

    /**
     * Parses an {@code &}-separated list of URL-encoded {@code name=value} pairs and adds each
     * decoded pair to {@link #paramMap}.
     *
     * <p>A segment without {@code '='} is treated as a parameter with an empty value. Segments
     * with an empty name (for example {@code "=x"}) and empty segments are ignored.
     *
     * @param requestBody the query string or form body to parse; {@code null} is a no-op
     * @param encoding character encoding used to URL-decode names and values
     * @throws IOException if {@code encoding} is not supported
     */
    private void addParameters(String requestBody, String encoding) throws IOException {
        if (requestBody == null) {
            return;
        }
        final int length = requestBody.length();
        int start = 0;
        while (start < length) {
            int end = requestBody.indexOf('&', start);
            if (end < 0) {
                end = length;
            }
            if (end > start) {
                String segment = requestBody.substring(start, end);
                // Split on the first '=' only, so that values may themselves contain '='.
                int eq = segment.indexOf('=');
                String rawName = (eq < 0) ? segment : segment.substring(0, eq);
                String rawValue = (eq < 0) ? "" : segment.substring(eq + 1);
                if (!rawName.isEmpty()) {
                    String decodedName = URLDecoder.decode(rawName, encoding);
                    String decodedValue = URLDecoder.decode(rawValue, encoding);
                    addParameter(decodedName, decodedValue);
                    if (logger.isDebugEnabled()) {
                        logger.debug("Parsed form parameter, name = "+decodedName+", value = "+decodedValue);
                    }
                }
            }
            start = end + 1;
        }
    }

    /**
     * Appends a value to the list of values stored for a parameter, creating the entry if the
     * parameter has not been seen before.
     *
     * @param name decoded parameter name
     * @param value decoded parameter value
     */
    private void addParameter(String name, String value) {
        String[] existing = paramMap.get(name);
        if (existing == null) {
            paramMap.put(name, new String[]{value});
            return;
        }
        String[] extended = new String[existing.length + 1];
        System.arraycopy(existing, 0, extended, 0, existing.length);
        extended[existing.length] = value;
        paramMap.put(name, extended);
    }

    /**
     * Returns a new stream over the cached request body. The first call copies the wrapped
     * request's input stream into memory; later calls reuse those bytes.
     *
     * @see javax.servlet.ServletRequestWrapper#getInputStream()
     */
    @Override
    public ServletInputStream getInputStream() throws IOException {
        if (cachedBytes == null) {
            cachedBytes = new ByteArrayOutputStream();
            IOUtils.copy(super.getInputStream(), cachedBytes);
        }
        // Return a inner class that references cachedBytes
        return new CachedServletInputStream();
    }

    /**
     * Returns a new reader over the cached request body, decoding it with the same encoding
     * that was used to build {@link #getRequestBody()}.
     *
     * @see javax.servlet.ServletRequestWrapper#getReader()
     */
    @Override
    public BufferedReader getReader() throws IOException {
        // Pass the encoding explicitly; the single-argument InputStreamReader constructor
        // would decode with the platform default charset.
        return new BufferedReader(new InputStreamReader(getInputStream(), encoding));
    }

    /**
     * Returns the request body as a string.
     *
     * @return the decoded body, or {@code null} if the body was empty
     */
    public String getRequestBody() {
        return requestBody;
    }

    /**
     * Returns the first cached value of the named parameter.
     *
     * @return the first value, or {@code null} if the parameter is absent
     * @see javax.servlet.ServletRequestWrapper#getParameter(java.lang.String)
     */
    @Override
    public String getParameter(String name) {
        String[] values = paramMap.get(name);
        if (values == null || values.length == 0) {
            return null;
        }
        return values[0];
    }

    /**
     * Returns a read-only view of the cached parameter map.
     *
     * @see javax.servlet.ServletRequestWrapper#getParameterMap()
     */
    @Override
    public Map<String, String[]> getParameterMap() {
        return Collections.unmodifiableMap(paramMap);
    }

    /**
     * Returns the cached parameter names in the order they were first encountered.
     *
     * @see javax.servlet.ServletRequestWrapper#getParameterNames()
     */
    @Override
    public Enumeration<?> getParameterNames() {
        return Collections.enumeration(paramMap.keySet());
    }

    /**
     * Returns all cached values of the named parameter.
     *
     * @return the values, or {@code null} if the parameter is absent
     * @see javax.servlet.ServletRequestWrapper#getParameterValues(java.lang.String)
     */
    @Override
    public String[] getParameterValues(String name) {
        return paramMap.get(name);
    }

    /**
     * Inner class that reads from stored byte array. Each instance takes its own snapshot of
     * the cached bytes when it is constructed, so every stream starts at the beginning of the
     * body.
     */
    public class CachedServletInputStream extends ServletInputStream {

        private ByteArrayInputStream input;

        public CachedServletInputStream() {
            input = new ByteArrayInputStream(cachedBytes.toByteArray());
        }

        @Override
        public int read() throws IOException {
            return input.read();
        }

        @Override
        public int read(byte[] b) throws IOException {
            return input.read(b);
        }

        @Override
        public int read(byte[] b, int off, int len) {
            return input.read(b, off, len);
        }
    }
}
And add a filter to wrap the original request:
public class CustomFilter implements Filter {
private static final Logger logger = Logger.getLogger(CustomFilter.class);
@Override
public void doFilter(ServletRequest request, ServletResponse response, FilterChain chain) throws IOException, ServletException {
if(request!=null && request instanceof HttpServletRequest){
HttpServletRequest httpRequest = (HttpServletRequest) request;
logger.debug("Wrapping HTTP request");
request = new CustomHttpServletRequest(httpRequest);
}
chain.doFilter(request, response);
}
}