1 /*
2  * JBoss, Home of Professional Open Source.
3  * Copyright 2014 Red Hat, Inc., and individual contributors
4  * as indicated by the @author tags.
5  *
6  * Licensed under the Apache License, Version 2.0 (the "License");
7  * you may not use this file except in compliance with the License.
8  * You may obtain a copy of the License at
9  *
10  *     http://www.apache.org/licenses/LICENSE-2.0
11  *
12  *  Unless required by applicable law or agreed to in writing, software
13  *  distributed under the License is distributed on an "AS IS" BASIS,
14  *  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15  *  See the License for the specific language governing permissions and
16  *  limitations under the License.
17  */

18
19 package io.undertow.servlet.handlers;
20
21 import java.io.IOException;
22 import java.util.EnumMap;
23 import java.util.List;
24 import java.util.Map;
25
26 import javax.servlet.DispatcherType;
27 import javax.servlet.FilterChain;
28 import javax.servlet.ServletException;
29 import javax.servlet.ServletRequest;
30 import javax.servlet.ServletRequestWrapper;
31 import javax.servlet.ServletResponse;
32 import javax.servlet.ServletResponseWrapper;
33
34 import io.undertow.server.HttpHandler;
35 import io.undertow.server.HttpServerExchange;
36 import io.undertow.servlet.UndertowServletMessages;
37 import io.undertow.servlet.core.ManagedFilter;
38
39 /**
40  * @author Stuart Douglas
41  */

42 public class FilterHandler implements HttpHandler {
43
44     private final Map<DispatcherType, List<ManagedFilter>> filters;
45     private final Map<DispatcherType, Boolean> asyncSupported;
46     private final boolean allowNonStandardWrappers;
47
48     private final HttpHandler next;
49
50     public FilterHandler(final Map<DispatcherType, List<ManagedFilter>> filters, final boolean allowNonStandardWrappers, final HttpHandler next) {
51         this.allowNonStandardWrappers = allowNonStandardWrappers;
52         this.next = next;
53         this.filters = new EnumMap<>(filters);
54         Map<DispatcherType, Boolean> asyncSupported = new EnumMap<>(DispatcherType.class);
55         for(Map.Entry<DispatcherType, List<ManagedFilter>> entry : filters.entrySet()) {
56             boolean supported = true;
57             for(ManagedFilter i : entry.getValue()) {
58                 if(!i.getFilterInfo().isAsyncSupported()) {
59                     supported = false;
60                     break;
61                 }
62             }
63             asyncSupported.put(entry.getKey(), supported);
64         }
65         this.asyncSupported = asyncSupported;
66     }
67
68     @Override
69     public void handleRequest(final HttpServerExchange exchange) throws Exception {
70         final ServletRequestContext servletRequestContext = exchange.getAttachment(ServletRequestContext.ATTACHMENT_KEY);
71         ServletRequest request = servletRequestContext.getServletRequest();
72         ServletResponse response = servletRequestContext.getServletResponse();
73         DispatcherType dispatcher = servletRequestContext.getDispatcherType();
74         Boolean supported = asyncSupported.get(dispatcher);
75         if(supported != null && ! supported) {
76             servletRequestContext.setAsyncSupported(false);
77         }
78
79         final List<ManagedFilter> filters = this.filters.get(dispatcher);
80         if(filters == null) {
81             next.handleRequest(exchange);
82         } else {
83             final FilterChainImpl filterChain = new FilterChainImpl(exchange, filters, next, allowNonStandardWrappers);
84             filterChain.doFilter(request, response);
85         }
86     }
87
88     private static class FilterChainImpl implements FilterChain {
89
90         int location = 0;
91         final HttpServerExchange exchange;
92         final List<ManagedFilter> filters;
93         final HttpHandler next;
94         final boolean allowNonStandardWrappers;
95
96         private FilterChainImpl(final HttpServerExchange exchange, final List<ManagedFilter> filters, final HttpHandler next, final boolean allowNonStandardWrappers) {
97             this.exchange = exchange;
98             this.filters = filters;
99             this.next = next;
100             this.allowNonStandardWrappers = allowNonStandardWrappers;
101         }
102
103         @Override
104         public void doFilter(final ServletRequest request, final ServletResponse response) throws IOException, ServletException {
105
106
107
108             final ServletRequestContext servletRequestContext = exchange.getAttachment(ServletRequestContext.ATTACHMENT_KEY);
109             final ServletRequest oldReq = servletRequestContext.getServletRequest();
110             final ServletResponse oldResp = servletRequestContext.getServletResponse();
111             try {
112
113                 if(!allowNonStandardWrappers) {
114                     if(oldReq != request) {
115                         if(!(request instanceof ServletRequestWrapper)) {
116                             throw UndertowServletMessages.MESSAGES.requestWasNotOriginalOrWrapper(request);
117                         }
118                     }
119                     if(oldResp != response) {
120                         if(!(response instanceof ServletResponseWrapper)) {
121                             throw UndertowServletMessages.MESSAGES.responseWasNotOriginalOrWrapper(response);
122                         }
123                     }
124                 }
125                 servletRequestContext.setServletRequest(request);
126                 servletRequestContext.setServletResponse(response);
127                 int index = location++;
128                 if (index >= filters.size()) {
129                     next.handleRequest(exchange);
130                 } else {
131                     filters.get(index).doFilter(request, response, this);
132                 }
133             } catch (IOException e) {
134                 throw e;
135             } catch (ServletException e) {
136                 throw e;
137             } catch (RuntimeException e) {
138                 throw e;
139             } catch (Exception e) {
140                 throw new RuntimeException(e);
141             } finally {
142                 location--;
143                 servletRequestContext.setServletRequest(oldReq);
144                 servletRequestContext.setServletResponse(oldResp);
145             }
146         }
147     }
148 }
149