1
18
19 package io.undertow.servlet.spec;
20
21 import java.io.IOException;
22 import java.io.PrintWriter;
23 import java.security.AccessController;
24 import java.security.PrivilegedActionException;
25 import java.security.PrivilegedExceptionAction;
26 import java.util.Deque;
27 import java.util.Map;
28
29 import javax.servlet.DispatcherType;
30 import javax.servlet.RequestDispatcher;
31 import javax.servlet.ServletException;
32 import javax.servlet.ServletOutputStream;
33 import javax.servlet.ServletRequest;
34 import javax.servlet.ServletRequestWrapper;
35 import javax.servlet.ServletResponse;
36 import javax.servlet.ServletResponseWrapper;
37 import javax.servlet.http.HttpServletRequest;
38 import javax.servlet.http.HttpServletResponse;
39
40 import io.undertow.UndertowLogger;
41 import io.undertow.server.HttpServerExchange;
42 import io.undertow.servlet.UndertowServletLogger;
43 import io.undertow.servlet.UndertowServletMessages;
44 import io.undertow.servlet.api.ThreadSetupAction;
45 import io.undertow.servlet.api.ThreadSetupHandler;
46 import io.undertow.servlet.handlers.ServletRequestContext;
47 import io.undertow.servlet.handlers.ServletChain;
48 import io.undertow.servlet.handlers.ServletPathMatch;
49 import io.undertow.util.QueryParameterUtils;
50
51
54 public class RequestDispatcherImpl implements RequestDispatcher {
55
56 private final String path;
57 private final ServletContextImpl servletContext;
58 private final ServletChain chain;
59 private final ServletPathMatch pathMatch;
60 private final boolean named;
61
62 public RequestDispatcherImpl(final String path, final ServletContextImpl servletContext) {
63 this.path = path;
64 this.servletContext = servletContext;
65 String basePath = path;
66 int qPos = basePath.indexOf("?");
67 if (qPos != -1) {
68 basePath = basePath.substring(0, qPos);
69 }
70 int mPos = basePath.indexOf(";");
71 if(mPos != -1) {
72 basePath = basePath.substring(0, mPos);
73 }
74 this.pathMatch = servletContext.getDeployment().getServletPaths().getServletHandlerByPath(basePath);
75 this.chain = pathMatch.getServletChain();
76 this.named = false;
77 }
78
79 public RequestDispatcherImpl(final ServletChain chain, final ServletContextImpl servletContext) {
80 this.chain = chain;
81 this.named = true;
82 this.servletContext = servletContext;
83 this.path = null;
84 this.pathMatch = null;
85 }
86
87
88 @Override
89 public void forward(final ServletRequest request, final ServletResponse response) throws ServletException, IOException {
90 if(System.getSecurityManager() != null) {
91 try {
92 AccessController.doPrivileged(new PrivilegedExceptionAction<Object>() {
93 @Override
94 public Object run() throws Exception {
95 forwardImplSetup(request, response);
96 return null;
97 }
98 });
99 } catch (PrivilegedActionException e) {
100 if(e.getCause() instanceof ServletException) {
101 throw (ServletException)e.getCause();
102 } else if(e.getCause() instanceof IOException) {
103 throw (IOException)e.getCause();
104 } else if(e.getCause() instanceof RuntimeException) {
105 throw (RuntimeException)e.getCause();
106 } else {
107 throw new RuntimeException(e.getCause());
108 }
109 }
110 } else {
111 forwardImplSetup(request, response);
112 }
113 }
114
115 private void forwardImplSetup(final ServletRequest request, final ServletResponse response) throws ServletException, IOException {
116 final ServletRequestContext servletRequestContext = SecurityActions.currentServletRequestContext();
117 if(servletRequestContext == null) {
118 UndertowLogger.REQUEST_LOGGER.debugf("No servlet request context for %s, dispatching mock request", request);
119 mock(request, response);
120 return;
121 }
122
123 ThreadSetupAction.Handle handle = null;
124 ServletContextImpl oldServletContext = null;
125 HttpSessionImpl oldSession = null;
126 if (servletRequestContext.getCurrentServletContext() != this.servletContext) {
127
128 try {
129
130 oldServletContext = servletRequestContext.getCurrentServletContext();
131 oldSession = servletRequestContext.getSession();
132 servletRequestContext.setSession(null);
133 servletRequestContext.setCurrentServletContext(this.servletContext);
134 this.servletContext.invokeAction(servletRequestContext.getExchange(), new ThreadSetupHandler.Action<Void, Object>() {
135 @Override
136 public Void call(HttpServerExchange exchange, Object context) throws Exception {
137 forwardImpl(request, response, servletRequestContext);
138 return null;
139 }
140 });
141
142 } finally {
143 servletRequestContext.setSession(oldSession);
144 servletRequestContext.setCurrentServletContext(oldServletContext);
145
146 servletRequestContext.getCurrentServletContext().updateSessionAccessTime(servletRequestContext.getExchange());
147 }
148 } else {
149 forwardImpl(request, response, servletRequestContext);
150 }
151
152 }
153
154 private void forwardImpl(ServletRequest request, ServletResponse response, ServletRequestContext servletRequestContext) throws ServletException, IOException {
155 final HttpServletRequestImpl requestImpl = servletRequestContext.getOriginalRequest();
156 final HttpServletResponseImpl responseImpl = servletRequestContext.getOriginalResponse();
157 if (!servletContext.getDeployment().getDeploymentInfo().isAllowNonStandardWrappers()) {
158 if (servletRequestContext.getOriginalRequest() != request) {
159 if (!(request instanceof ServletRequestWrapper)) {
160 throw UndertowServletMessages.MESSAGES.requestWasNotOriginalOrWrapper(request);
161 }
162 }
163 if (servletRequestContext.getOriginalResponse() != response) {
164 if (!(response instanceof ServletResponseWrapper)) {
165 throw UndertowServletMessages.MESSAGES.responseWasNotOriginalOrWrapper(response);
166 }
167 }
168 }
169 response.resetBuffer();
170
171 final ServletRequest oldRequest = servletRequestContext.getServletRequest();
172 final ServletResponse oldResponse = servletRequestContext.getServletResponse();
173
174 Map<String, Deque<String>> queryParameters = requestImpl.getQueryParameters();
175
176 request.removeAttribute(INCLUDE_REQUEST_URI);
177 request.removeAttribute(INCLUDE_CONTEXT_PATH);
178 request.removeAttribute(INCLUDE_SERVLET_PATH);
179 request.removeAttribute(INCLUDE_PATH_INFO);
180 request.removeAttribute(INCLUDE_QUERY_STRING);
181
182 final String oldURI = requestImpl.getExchange().getRequestURI();
183 final String oldRequestPath = requestImpl.getExchange().getRequestPath();
184 final String oldPath = requestImpl.getExchange().getRelativePath();
185 final ServletPathMatch oldServletPathMatch = requestImpl.getExchange().getAttachment(ServletRequestContext.ATTACHMENT_KEY).getServletPathMatch();
186 if (!named) {
187
188
189 if (request.getAttribute(FORWARD_REQUEST_URI) == null) {
190 requestImpl.setAttribute(FORWARD_REQUEST_URI, requestImpl.getRequestURI());
191 requestImpl.setAttribute(FORWARD_CONTEXT_PATH, requestImpl.getContextPath());
192 requestImpl.setAttribute(FORWARD_SERVLET_PATH, requestImpl.getServletPath());
193 requestImpl.setAttribute(FORWARD_PATH_INFO, requestImpl.getPathInfo());
194 requestImpl.setAttribute(FORWARD_QUERY_STRING, requestImpl.getQueryString());
195 }
196
197 int qsPos = path.indexOf("?");
198 String newServletPath = path;
199 if (qsPos != -1) {
200 String newQueryString = newServletPath.substring(qsPos + 1);
201 newServletPath = newServletPath.substring(0, qsPos);
202
203 String encoding = QueryParameterUtils.getQueryParamEncoding(servletRequestContext.getExchange());
204 Map<String, Deque<String>> newQueryParameters = QueryParameterUtils.mergeQueryParametersWithNewQueryString(queryParameters, newQueryString, encoding);
205 requestImpl.getExchange().setQueryString(newQueryString);
206 requestImpl.setQueryParameters(newQueryParameters);
207 }
208 String newRequestUri = servletContext.getContextPath() + newServletPath;
209
210
211
212 requestImpl.getExchange().setRelativePath(newServletPath);
213 requestImpl.getExchange().setRequestPath(newRequestUri);
214 requestImpl.getExchange().setRequestURI(newRequestUri);
215 requestImpl.getExchange().getAttachment(ServletRequestContext.ATTACHMENT_KEY).setServletPathMatch(pathMatch);
216 requestImpl.setServletContext(servletContext);
217 responseImpl.setServletContext(servletContext);
218 }
219
220 try {
221 try {
222 servletRequestContext.setServletRequest(request);
223 servletRequestContext.setServletResponse(response);
224 if (named) {
225 servletContext.getDeployment().getServletDispatcher().dispatchToServlet(requestImpl.getExchange(), chain, DispatcherType.FORWARD);
226 } else {
227 servletContext.getDeployment().getServletDispatcher().dispatchToPath(requestImpl.getExchange(), pathMatch, DispatcherType.FORWARD);
228 }
229
230
231 if (!request.isAsyncStarted()) {
232 if (response instanceof HttpServletResponseImpl) {
233 responseImpl.closeStreamAndWriter();
234 } else {
235 try {
236 final PrintWriter writer = response.getWriter();
237 writer.flush();
238 writer.close();
239 } catch (IllegalStateException e) {
240 final ServletOutputStream outputStream = response.getOutputStream();
241 outputStream.flush();
242 outputStream.close();
243 }
244 }
245 }
246 } catch (ServletException e) {
247 throw e;
248 } catch (IOException e) {
249 throw e;
250 } catch (Exception e) {
251 throw new RuntimeException(e);
252 }
253 } finally {
254 servletRequestContext.setServletRequest(oldRequest);
255 servletRequestContext.setServletResponse(oldResponse);
256 final boolean preservePath = servletRequestContext.getDeployment().getDeploymentInfo().isPreservePathOnForward();
257 if (preservePath) {
258 requestImpl.getExchange().setRelativePath(oldPath);
259 requestImpl.getExchange().getAttachment(ServletRequestContext.ATTACHMENT_KEY).setServletPathMatch(oldServletPathMatch);
260 requestImpl.getExchange().setRequestPath(oldRequestPath);
261 requestImpl.getExchange().setRequestURI(oldURI);
262 }
263 }
264 }
265
266
267 @Override
268 public void include(final ServletRequest request, final ServletResponse response) throws ServletException, IOException {
269 if(System.getSecurityManager() != null) {
270 try {
271 AccessController.doPrivileged(new PrivilegedExceptionAction<Object>() {
272 @Override
273 public Object run() throws Exception {
274 setupIncludeImpl(request, response);
275 return null;
276 }
277 });
278 } catch (PrivilegedActionException e) {
279 if(e.getCause() instanceof ServletException) {
280 throw (ServletException)e.getCause();
281 } else if(e.getCause() instanceof IOException) {
282 throw (IOException)e.getCause();
283 } else if(e.getCause() instanceof RuntimeException) {
284 throw (RuntimeException)e.getCause();
285 } else {
286 throw new RuntimeException(e.getCause());
287 }
288 }
289 } else {
290 setupIncludeImpl(request, response);
291 }
292 }
293
294 private void setupIncludeImpl(final ServletRequest request, final ServletResponse response) throws ServletException, IOException {
295 final ServletRequestContext servletRequestContext = SecurityActions.currentServletRequestContext();
296 if(servletRequestContext == null) {
297 UndertowLogger.REQUEST_LOGGER.debugf("No servlet request context for %s, dispatching mock request", request);
298 mock(request, response);
299 return;
300 }
301 final HttpServletRequestImpl requestImpl = servletRequestContext.getOriginalRequest();
302 final HttpServletResponseImpl responseImpl = servletRequestContext.getOriginalResponse();
303 ServletContextImpl oldServletContext = null;
304 HttpSessionImpl oldSession = null;
305 if (servletRequestContext.getCurrentServletContext() != this.servletContext) {
306
307 oldServletContext = servletRequestContext.getCurrentServletContext();
308 oldSession = servletRequestContext.getSession();
309 servletRequestContext.setSession(null);
310 servletRequestContext.setCurrentServletContext(this.servletContext);
311 try {
312 servletRequestContext.getCurrentServletContext().invokeAction(servletRequestContext.getExchange(), new ThreadSetupHandler.Action<Void, Object>() {
313 @Override
314 public Void call(HttpServerExchange exchange, Object context) throws Exception {
315 includeImpl(request, response, servletRequestContext, requestImpl, responseImpl);
316 return null;
317 }
318 });
319 } finally {
320
321 servletRequestContext.getCurrentServletContext().updateSessionAccessTime(servletRequestContext.getExchange());
322 servletRequestContext.setSession(oldSession);
323 servletRequestContext.setCurrentServletContext(oldServletContext);
324 }
325 } else {
326 includeImpl(request, response, servletRequestContext, requestImpl, responseImpl);
327 }
328 }
329
330 private void includeImpl(ServletRequest request, ServletResponse response, ServletRequestContext servletRequestContext, HttpServletRequestImpl requestImpl, HttpServletResponseImpl responseImpl) throws ServletException, IOException {
331 if (!servletContext.getDeployment().getDeploymentInfo().isAllowNonStandardWrappers()) {
332 if (servletRequestContext.getOriginalRequest() != request) {
333 if (!(request instanceof ServletRequestWrapper)) {
334 throw UndertowServletMessages.MESSAGES.requestWasNotOriginalOrWrapper(request);
335 }
336 }
337 if (servletRequestContext.getOriginalResponse() != response) {
338 if (!(response instanceof ServletResponseWrapper)) {
339 throw UndertowServletMessages.MESSAGES.responseWasNotOriginalOrWrapper(response);
340 }
341 }
342 }
343 final ServletRequest oldRequest = servletRequestContext.getServletRequest();
344 final ServletResponse oldResponse = servletRequestContext.getServletResponse();
345
346 Object requestUri = null;
347 Object contextPath = null;
348 Object servletPath = null;
349 Object pathInfo = null;
350 Object queryString = null;
351 Map<String, Deque<String>> queryParameters = requestImpl.getQueryParameters();
352
353 if (!named) {
354 requestUri = request.getAttribute(INCLUDE_REQUEST_URI);
355 contextPath = request.getAttribute(INCLUDE_CONTEXT_PATH);
356 servletPath = request.getAttribute(INCLUDE_SERVLET_PATH);
357 pathInfo = request.getAttribute(INCLUDE_PATH_INFO);
358 queryString = request.getAttribute(INCLUDE_QUERY_STRING);
359
360 int qsPos = path.indexOf("?");
361 String newServletPath = path;
362 if (qsPos != -1) {
363 String newQueryString = newServletPath.substring(qsPos + 1);
364 newServletPath = newServletPath.substring(0, qsPos);
365
366 String encoding = QueryParameterUtils.getQueryParamEncoding(servletRequestContext.getExchange());
367 Map<String, Deque<String>> newQueryParameters = QueryParameterUtils.mergeQueryParametersWithNewQueryString(queryParameters, newQueryString, encoding);
368 requestImpl.setQueryParameters(newQueryParameters);
369 requestImpl.setAttribute(INCLUDE_QUERY_STRING, newQueryString);
370 } else {
371 requestImpl.setAttribute(INCLUDE_QUERY_STRING, "");
372 }
373 String newRequestUri = servletContext.getContextPath() + newServletPath;
374
375 requestImpl.setAttribute(INCLUDE_REQUEST_URI, newRequestUri);
376 requestImpl.setAttribute(INCLUDE_CONTEXT_PATH, servletContext.getContextPath());
377 requestImpl.setAttribute(INCLUDE_SERVLET_PATH, pathMatch.getMatched());
378 requestImpl.setAttribute(INCLUDE_PATH_INFO, pathMatch.getRemaining());
379 }
380 boolean inInclude = responseImpl.isInsideInclude();
381 responseImpl.setInsideInclude(true);
382 DispatcherType oldDispatcherType = servletRequestContext.getDispatcherType();
383
384 ServletContextImpl oldContext = requestImpl.getServletContext();
385 try {
386 requestImpl.setServletContext(servletContext);
387 responseImpl.setServletContext(servletContext);
388 try {
389 servletRequestContext.setServletRequest(request);
390 servletRequestContext.setServletResponse(response);
391 servletContext.getDeployment().getServletDispatcher().dispatchToServlet(requestImpl.getExchange(), chain, DispatcherType.INCLUDE);
392 } catch (ServletException e) {
393 throw e;
394 } catch (IOException e) {
395 throw e;
396 } catch (Exception e) {
397 throw new RuntimeException(e);
398 }
399 } finally {
400 responseImpl.setInsideInclude(inInclude);
401 requestImpl.setServletContext(oldContext);
402 responseImpl.setServletContext(oldContext);
403
404 servletRequestContext.setServletRequest(oldRequest);
405 servletRequestContext.setServletResponse(oldResponse);
406 servletRequestContext.setDispatcherType(oldDispatcherType);
407 if (!named) {
408 requestImpl.setAttribute(INCLUDE_REQUEST_URI, requestUri);
409 requestImpl.setAttribute(INCLUDE_CONTEXT_PATH, contextPath);
410 requestImpl.setAttribute(INCLUDE_SERVLET_PATH, servletPath);
411 requestImpl.setAttribute(INCLUDE_PATH_INFO, pathInfo);
412 requestImpl.setAttribute(INCLUDE_QUERY_STRING, queryString);
413 requestImpl.setQueryParameters(queryParameters);
414 }
415 }
416 }
417
418 public void error(ServletRequestContext servletRequestContext, final ServletRequest request, final ServletResponse response, final String servletName, final String message) throws ServletException, IOException {
419 error(servletRequestContext, request, response, servletName, null, message);
420 }
421
422 public void error(ServletRequestContext servletRequestContext, final ServletRequest request, final ServletResponse response, final String servletName) throws ServletException, IOException {
423 error(servletRequestContext, request, response, servletName, null, null);
424 }
425
426 public void error(ServletRequestContext servletRequestContext, final ServletRequest request, final ServletResponse response, final String servletName, final Throwable exception) throws ServletException, IOException {
427 error(servletRequestContext, request, response, servletName, exception, exception.getMessage());
428 }
429
430 private void error(ServletRequestContext servletRequestContext, final ServletRequest request, final ServletResponse response, final String servletName, final Throwable exception, final String message) throws ServletException, IOException {
431 if(request.getDispatcherType() == DispatcherType.ERROR) {
432
433
434
435 UndertowServletLogger.REQUEST_LOGGER.errorGeneratingErrorPage(servletRequestContext.getExchange().getRequestPath(), request.getAttribute(ERROR_EXCEPTION), servletRequestContext.getExchange().getStatusCode(), exception);
436 servletRequestContext.getExchange().endExchange();
437 return;
438 }
439
440 final HttpServletRequestImpl requestImpl = servletRequestContext.getOriginalRequest();
441 final HttpServletResponseImpl responseImpl = servletRequestContext.getOriginalResponse();
442 if (!servletContext.getDeployment().getDeploymentInfo().isAllowNonStandardWrappers()) {
443 if (servletRequestContext.getOriginalRequest() != request) {
444 if (!(request instanceof ServletRequestWrapper)) {
445 throw UndertowServletMessages.MESSAGES.requestWasNotOriginalOrWrapper(request);
446 }
447 }
448 if (servletRequestContext.getOriginalResponse() != response) {
449 if (!(response instanceof ServletResponseWrapper)) {
450 throw UndertowServletMessages.MESSAGES.responseWasNotOriginalOrWrapper(response);
451 }
452 }
453 }
454
455 final ServletRequest oldRequest = servletRequestContext.getServletRequest();
456 final ServletResponse oldResponse = servletRequestContext.getServletResponse();
457 servletRequestContext.setDispatcherType(DispatcherType.ERROR);
458
459
460 if (request.getAttribute(FORWARD_REQUEST_URI) == null) {
461 requestImpl.setAttribute(FORWARD_REQUEST_URI, requestImpl.getRequestURI());
462 requestImpl.setAttribute(FORWARD_CONTEXT_PATH, requestImpl.getContextPath());
463 requestImpl.setAttribute(FORWARD_SERVLET_PATH, requestImpl.getServletPath());
464 requestImpl.setAttribute(FORWARD_PATH_INFO, requestImpl.getPathInfo());
465 requestImpl.setAttribute(FORWARD_QUERY_STRING, requestImpl.getQueryString());
466 }
467 requestImpl.setAttribute(ERROR_REQUEST_URI, requestImpl.getRequestURI());
468 requestImpl.setAttribute(ERROR_SERVLET_NAME, servletName);
469 if (exception != null) {
470 requestImpl.setAttribute(ERROR_EXCEPTION, exception);
471 requestImpl.setAttribute(ERROR_EXCEPTION_TYPE, exception.getClass());
472 }
473 requestImpl.setAttribute(ERROR_MESSAGE, message);
474 requestImpl.setAttribute(ERROR_STATUS_CODE, responseImpl.getStatus());
475
476 int qsPos = path.indexOf("?");
477 String newServletPath = path;
478 if (qsPos != -1) {
479 Map<String, Deque<String>> queryParameters = requestImpl.getQueryParameters();
480 String newQueryString = newServletPath.substring(qsPos + 1);
481 newServletPath = newServletPath.substring(0, qsPos);
482
483 String encoding = QueryParameterUtils.getQueryParamEncoding(servletRequestContext.getExchange());
484 Map<String, Deque<String>> newQueryParameters = QueryParameterUtils.mergeQueryParametersWithNewQueryString(queryParameters, newQueryString, encoding);
485 requestImpl.getExchange().setQueryString(newQueryString);
486 requestImpl.setQueryParameters(newQueryParameters);
487 }
488 String newRequestUri = servletContext.getContextPath() + newServletPath;
489
490 requestImpl.getExchange().setRelativePath(newServletPath);
491 requestImpl.getExchange().setRequestPath(newRequestUri);
492 requestImpl.getExchange().setRequestURI(newRequestUri);
493 requestImpl.getExchange().getAttachment(ServletRequestContext.ATTACHMENT_KEY).setServletPathMatch(pathMatch);
494 requestImpl.setServletContext(servletContext);
495 responseImpl.setServletContext(servletContext);
496
497 try {
498 try {
499 servletRequestContext.setServletRequest(request);
500 servletRequestContext.setServletResponse(response);
501 servletContext.getDeployment().getServletDispatcher().dispatchToPath(requestImpl.getExchange(), pathMatch, DispatcherType.ERROR);
502 } catch (ServletException e) {
503 throw e;
504 } catch (IOException e) {
505 throw e;
506 } catch (Exception e) {
507 throw new RuntimeException(e);
508 }
509 } finally {
510 AsyncContextImpl ac = servletRequestContext.getOriginalRequest().getAsyncContextInternal();
511 if(ac != null) {
512 ac.complete();
513 }
514 servletRequestContext.setServletRequest(oldRequest);
515 servletRequestContext.setServletResponse(oldResponse);
516 }
517 }
518
519 public void mock(ServletRequest request, ServletResponse response) throws ServletException, IOException {
520 if (request instanceof HttpServletRequest && response instanceof HttpServletResponse) {
521 HttpServletRequest req = (HttpServletRequest) request;
522 HttpServletResponse resp = (HttpServletResponse) response;
523 servletContext.getDeployment().getServletDispatcher().dispatchMockRequest(req, resp);
524 } else {
525 throw UndertowServletMessages.MESSAGES.invalidRequestResponseType(request, response);
526 }
527 }
528 }
529