Wednesday, August 15, 2018

How to read request body (ServletInputStream) multiple times

Preface

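To have enough information to debug when a program fails, we often need to log the request body. Unfortunately, the stream returned by HttpServletRequest getInputStream() can be read only once; after that the data is gone. So we need some way to keep the body in a buffer.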
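
The read-once behaviour can be sketched with the JDK alone. This is only an illustration, not servlet code: a ByteArrayInputStream stands in for the request body stream, and the class name ReadOnceDemo and the helper drain() are hypothetical names chosen for this example.

```java
import java.io.ByteArrayInputStream;
import java.io.ByteArrayOutputStream;
import java.io.IOException;
import java.io.InputStream;
import java.nio.charset.StandardCharsets;

class ReadOnceDemo {
    /** Reads the stream to EOF and returns everything that was left in it. */
    static byte[] drain(InputStream in) throws IOException {
        ByteArrayOutputStream out = new ByteArrayOutputStream();
        byte[] chunk = new byte[4096];
        int n;
        while ((n = in.read(chunk)) != -1) {
            out.write(chunk, 0, n);
        }
        return out.toByteArray();
    }

    public static void main(String[] args) throws IOException {
        // Stand-in for the request body stream (assumption: not a real ServletInputStream).
        InputStream once = new ByteArrayInputStream("a=1&b=2".getBytes(StandardCharsets.UTF_8));

        byte[] first = drain(once);
        byte[] second = drain(once); // the stream is already at EOF
        System.out.println(first.length);  // 7
        System.out.println(second.length); // 0

        // The buffered bytes can be replayed as often as needed.
        byte[] replay = drain(new ByteArrayInputStream(first));
        System.out.println(new String(replay, StandardCharsets.UTF_8)); // a=1&b=2
    }
}
```

Keeping the first result around as a byte[] is the whole trick; everything else in this post is about making the servlet API hand out that buffer again.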

Working out a solution

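In principle, it should be enough to override getInputStream(), keep a byte[] body as the buffer, and return a new ServletInputStream whose read() is overridden to pull its data from byte[] body.
But it turned out not to be that simple. My project runs on Tomcat 8 with servlet-api 3.1.0, so a new ServletInputStream must also implement isFinished(), isReady(), and setReadListener(ReadListener readListener).
Then I found that request.getParameter(...), which used to work, no longer returned anything...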
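Apparently every method that relied on getInputStream() stopped getting data. Because my class extends the wrapper, those methods are still served by the original request (super), which reads the original input stream and not my overridden getInputStream(). That stream has already been consumed, so of course they get nothing, just as described in the preface.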
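
The delegation problem can be reproduced without any servlet classes. In the sketch below, Request, ContainerRequest, and BufferingWrapper are hypothetical stand-ins for HttpServletRequest, the container's request object, and a buffering wrapper; none of them are real servlet types.

```java
class DelegationDemo {
    /** Hypothetical stand-in for HttpServletRequest. */
    interface Request {
        String body();
        String parameter(String name);
    }

    /** Hypothetical stand-in for the container's request: its body can be read once. */
    static class ContainerRequest implements Request {
        private String remaining;

        ContainerRequest(String body) {
            this.remaining = body;
        }

        public String body() {
            String b = remaining;
            remaining = ""; // consumed: later reads see nothing
            return b;
        }

        public String parameter(String name) {
            // Parses its OWN body, not whatever a wrapper has buffered.
            for (String pair : body().split("&")) {
                String[] kv = pair.split("=", 2);
                if (kv.length == 2 && kv[0].equals(name)) {
                    return kv[1];
                }
            }
            return null;
        }
    }

    /** Hypothetical wrapper that buffers the body but still delegates parameter lookup. */
    static class BufferingWrapper implements Request {
        private final Request wrapped;
        private final String buffered;

        BufferingWrapper(Request wrapped) {
            this.wrapped = wrapped;
            this.buffered = wrapped.body(); // consumes the original body
        }

        public String body() {
            return buffered; // replayable
        }

        public String parameter(String name) {
            return wrapped.parameter(name); // delegates to the already-drained request
        }
    }

    public static void main(String[] args) {
        Request r = new BufferingWrapper(new ContainerRequest("a=1"));
        System.out.println(r.body());         // a=1
        System.out.println(r.body());         // a=1
        System.out.println(r.parameter("a")); // null
    }
}
```

The body is replayable through the wrapper, yet the parameter lookup comes back empty because it never touches the buffer.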
To fix this, I also had to override getParameter(String key), getParameterValues(String key), getParameterMap(), and getReader().
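
What such an override has to do with the buffered body can be sketched as follows. This is a simplified illustration that uses java.net.URLDecoder from the JDK rather than the URLEncodedUtils used in the full listing, so that it runs without extra libraries; FormBodyParser and parse() are hypothetical names, and UTF-8 is assumed as the encoding.

```java
import java.io.UnsupportedEncodingException;
import java.net.URLDecoder;
import java.util.Arrays;
import java.util.LinkedHashMap;
import java.util.Map;

class FormBodyParser {
    /** Parses "k1=v1&k2=v2" (query string or form body) into a multi-value map. */
    static Map<String, String[]> parse(String encoded) throws UnsupportedEncodingException {
        Map<String, String[]> result = new LinkedHashMap<>();
        if (encoded == null || encoded.isEmpty()) {
            return result;
        }
        for (String pair : encoded.split("&")) {
            if (pair.isEmpty()) {
                continue;
            }
            int eq = pair.indexOf('=');
            String rawKey = eq >= 0 ? pair.substring(0, eq) : pair;
            String rawValue = eq >= 0 ? pair.substring(eq + 1) : "";
            String key = URLDecoder.decode(rawKey, "UTF-8");     // assumption: UTF-8 body
            String value = URLDecoder.decode(rawValue, "UTF-8");

            String[] old = result.get(key);
            if (old == null) {
                result.put(key, new String[] {value});
            } else {
                // Repeated key: grow the array by one and append the new value.
                String[] grown = Arrays.copyOf(old, old.length + 1);
                grown[old.length] = value;
                result.put(key, grown);
            }
        }
        return result;
    }

    public static void main(String[] args) throws UnsupportedEncodingException {
        Map<String, String[]> params = parse("a=1&a=2&b=hello+world");
        System.out.println(Arrays.toString(params.get("a"))); // [1, 2]
        System.out.println(Arrays.toString(params.get("b"))); // [hello world]
    }
}
```

A String[] per key matches the shape that getParameterMap() returns, which is why getParameter() and getParameterValues() can both be built on top of it.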

Complete code example


import java.io.BufferedReader;
import java.io.IOException;
import java.io.InputStreamReader;
import java.nio.charset.Charset;
import java.util.Collections;
import java.util.LinkedHashMap;
import java.util.Map;

import javax.servlet.ReadListener;
import javax.servlet.ServletInputStream;
import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletRequestWrapper;

import org.apache.commons.io.IOUtils;
import org.apache.http.NameValuePair;
import org.apache.http.client.utils.URLEncodedUtils;
import org.apache.http.entity.ContentType;

import com.google.common.collect.ObjectArrays;

public class BufferHttpServletRequestWrapper extends HttpServletRequestWrapper {
    public static final Charset UTF8_CHARSET = Charset.forName("UTF-8");
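    private Map<String, String[]> parameterMap;
    // Buffered copy of the request body; stays empty if the request has no input stream.
    private byte[] body = new byte[0];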

    public BufferHttpServletRequestWrapper(HttpServletRequest request) throws IOException {
        super(request);

        ServletInputStream in = request.getInputStream();

        if (in != null) {
            body = IOUtils.toByteArray(in);
        }
    }

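    @Override
    public BufferedReader getReader() throws IOException {
        // Decode with the request's declared encoding, falling back to UTF-8
        // instead of the platform default charset.
        String encoding = getCharacterEncoding();
        Charset charset = encoding != null ? Charset.forName(encoding) : UTF8_CHARSET;
        return new BufferedReader(new InputStreamReader(getInputStream(), charset));
    }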

    @Override
    public ServletInputStream getInputStream() {
        return new ServletInputStream() {

            private int lastIndexRetrieved = -1;
            private ReadListener readListener = null;

            @Override
            public boolean isFinished() {
                return (lastIndexRetrieved == body.length - 1);
            }

            @Override
            public boolean isReady() {
                // The whole body is already in memory, so a read never blocks.
                // We also never need to call the readListener from this method,
                // as this method never returns false.
                return true;
            }

            @Override
            public void setReadListener(ReadListener readListener) {
                this.readListener = readListener;
                if (!isFinished()) {
                    try {
                        readListener.onDataAvailable();
                    } catch (IOException e) {
                        readListener.onError(e);
                    }
                } else {
                    try {
                        readListener.onAllDataRead();
                    } catch (IOException e) {
                        readListener.onError(e);
                    }
                }
            }

            @Override
            public int read() throws IOException {
                int i;
                if (!isFinished()) {
                    // Mask to 0..255: a raw byte is sign-extended, so 0xFF would look like EOF (-1).
                    i = body[lastIndexRetrieved + 1] & 0xFF;
                    lastIndexRetrieved++;
                    if (isFinished() && (readListener != null)) {
                        try {
                            readListener.onAllDataRead();
                        } catch (IOException ex) {
                            readListener.onError(ex);
                            throw ex;
                        }
                    }
                    return i;
                } else {
                    return -1;
                }
            }
        };
    }

    @Override
    public String getParameter(String key) {
        Map<String, String[]> parameterMap = getParameterMap();
        String[] values = parameterMap.get(key);
        return values != null && values.length > 0 ? values[0] : null;
    }

    @Override
    public String[] getParameterValues(String key) {
        Map<String, String[]> parameterMap = getParameterMap();
        return parameterMap.get(key);
    }

    @Override
    public Map<String, String[]> getParameterMap() {
        if (parameterMap == null) {
            Map<String, String[]> result = new LinkedHashMap<>();

            String queryString = getQueryString();
            if (queryString != null) {
                toMap(URLEncodedUtils.parse(queryString, UTF8_CHARSET), result);
            }

            String cts = getContentType();
            if (cts != null) {
                ContentType ct = ContentType.parse(cts);
                if (ct.getMimeType().equals(ContentType.APPLICATION_FORM_URLENCODED.getMimeType())) {
                    try {
                        toMap(URLEncodedUtils.parse(IOUtils.toString(getReader()), UTF8_CHARSET), result);
                    } catch (IOException e) {
                        throw new IllegalStateException(e);
                    }
                }
            }
            parameterMap = Collections.unmodifiableMap(result);
        }
        return parameterMap;
    }

    public static void toMap(Iterable<NameValuePair> inputParams, Map<String, String[]> toMap) {
        for (NameValuePair e : inputParams) {
            String key = e.getName();
            String value = e.getValue();
            String[] existing = toMap.get(key);
            if (existing != null) {
                // Repeated key: append the new value to the existing array.
                toMap.put(key, ObjectArrays.concat(existing, value));
            } else {
                toMap.put(key, new String[]{value});
            }
        }
    }

}
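
One detail matters in any read() backed by a byte[]: InputStream.read() is specified to return a value from 0 to 255, or -1 at end of stream, while a Java byte is signed. The sketch below shows the difference in isolation; UnsignedReadDemo, readSigned(), and readUnsigned() are hypothetical names used only for this illustration.

```java
class UnsignedReadDemo {
    /** Returns the byte as-is; values of 0x80 and above come out negative. */
    static int readSigned(byte[] body, int index) {
        return body[index];
    }

    /** Returns the byte as an unsigned value in the range 0..255. */
    static int readUnsigned(byte[] body, int index) {
        return body[index] & 0xFF;
    }

    public static void main(String[] args) {
        // 0xE4 is a typical lead byte of a UTF-8 encoded CJK character.
        byte[] body = {(byte) 0xE4, (byte) 0xFF};

        System.out.println(readSigned(body, 0));   // -28
        System.out.println(readUnsigned(body, 0)); // 228
        System.out.println(readSigned(body, 1));   // -1, indistinguishable from EOF
        System.out.println(readUnsigned(body, 1)); // 255
    }
}
```

Without the mask, a body containing the byte 0xFF would appear to end early, and other non-ASCII bytes would be handed to callers as negative values.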

Afterword

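Later I found a simpler solution: just use the utility class that ships with spring-web, ContentCachingRequestWrapper.
The concept is exactly the same... so it turns out I was reinventing a wheel that someone else had already built. 囧rz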
