1
18
19 package io.undertow.servlet.handlers;
20
21 import java.io.IOException;
22 import java.util.EnumMap;
23 import java.util.List;
24 import java.util.Map;
25
26 import javax.servlet.DispatcherType;
27 import javax.servlet.FilterChain;
28 import javax.servlet.ServletException;
29 import javax.servlet.ServletRequest;
30 import javax.servlet.ServletRequestWrapper;
31 import javax.servlet.ServletResponse;
32 import javax.servlet.ServletResponseWrapper;
33
34 import io.undertow.server.HttpHandler;
35 import io.undertow.server.HttpServerExchange;
36 import io.undertow.servlet.UndertowServletMessages;
37 import io.undertow.servlet.core.ManagedFilter;
38
39
42 public class FilterHandler implements HttpHandler {
43
44 private final Map<DispatcherType, List<ManagedFilter>> filters;
45 private final Map<DispatcherType, Boolean> asyncSupported;
46 private final boolean allowNonStandardWrappers;
47
48 private final HttpHandler next;
49
50 public FilterHandler(final Map<DispatcherType, List<ManagedFilter>> filters, final boolean allowNonStandardWrappers, final HttpHandler next) {
51 this.allowNonStandardWrappers = allowNonStandardWrappers;
52 this.next = next;
53 this.filters = new EnumMap<>(filters);
54 Map<DispatcherType, Boolean> asyncSupported = new EnumMap<>(DispatcherType.class);
55 for(Map.Entry<DispatcherType, List<ManagedFilter>> entry : filters.entrySet()) {
56 boolean supported = true;
57 for(ManagedFilter i : entry.getValue()) {
58 if(!i.getFilterInfo().isAsyncSupported()) {
59 supported = false;
60 break;
61 }
62 }
63 asyncSupported.put(entry.getKey(), supported);
64 }
65 this.asyncSupported = asyncSupported;
66 }
67
68 @Override
69 public void handleRequest(final HttpServerExchange exchange) throws Exception {
70 final ServletRequestContext servletRequestContext = exchange.getAttachment(ServletRequestContext.ATTACHMENT_KEY);
71 ServletRequest request = servletRequestContext.getServletRequest();
72 ServletResponse response = servletRequestContext.getServletResponse();
73 DispatcherType dispatcher = servletRequestContext.getDispatcherType();
74 Boolean supported = asyncSupported.get(dispatcher);
75 if(supported != null && ! supported) {
76 servletRequestContext.setAsyncSupported(false);
77 }
78
79 final List<ManagedFilter> filters = this.filters.get(dispatcher);
80 if(filters == null) {
81 next.handleRequest(exchange);
82 } else {
83 final FilterChainImpl filterChain = new FilterChainImpl(exchange, filters, next, allowNonStandardWrappers);
84 filterChain.doFilter(request, response);
85 }
86 }
87
88 private static class FilterChainImpl implements FilterChain {
89
90 int location = 0;
91 final HttpServerExchange exchange;
92 final List<ManagedFilter> filters;
93 final HttpHandler next;
94 final boolean allowNonStandardWrappers;
95
96 private FilterChainImpl(final HttpServerExchange exchange, final List<ManagedFilter> filters, final HttpHandler next, final boolean allowNonStandardWrappers) {
97 this.exchange = exchange;
98 this.filters = filters;
99 this.next = next;
100 this.allowNonStandardWrappers = allowNonStandardWrappers;
101 }
102
103 @Override
104 public void doFilter(final ServletRequest request, final ServletResponse response) throws IOException, ServletException {
105
106
107
108 final ServletRequestContext servletRequestContext = exchange.getAttachment(ServletRequestContext.ATTACHMENT_KEY);
109 final ServletRequest oldReq = servletRequestContext.getServletRequest();
110 final ServletResponse oldResp = servletRequestContext.getServletResponse();
111 try {
112
113 if(!allowNonStandardWrappers) {
114 if(oldReq != request) {
115 if(!(request instanceof ServletRequestWrapper)) {
116 throw UndertowServletMessages.MESSAGES.requestWasNotOriginalOrWrapper(request);
117 }
118 }
119 if(oldResp != response) {
120 if(!(response instanceof ServletResponseWrapper)) {
121 throw UndertowServletMessages.MESSAGES.responseWasNotOriginalOrWrapper(response);
122 }
123 }
124 }
125 servletRequestContext.setServletRequest(request);
126 servletRequestContext.setServletResponse(response);
127 int index = location++;
128 if (index >= filters.size()) {
129 next.handleRequest(exchange);
130 } else {
131 filters.get(index).doFilter(request, response, this);
132 }
133 } catch (IOException e) {
134 throw e;
135 } catch (ServletException e) {
136 throw e;
137 } catch (RuntimeException e) {
138 throw e;
139 } catch (Exception e) {
140 throw new RuntimeException(e);
141 } finally {
142 location--;
143 servletRequestContext.setServletRequest(oldReq);
144 servletRequestContext.setServletResponse(oldResp);
145 }
146 }
147 }
148 }
149