1 /*
2  * JBoss, Home of Professional Open Source.
3  * Copyright 2014 Red Hat, Inc., and individual contributors
4  * as indicated by the @author tags.
5  *
6  * Licensed under the Apache License, Version 2.0 (the "License");
7  * you may not use this file except in compliance with the License.
8  * You may obtain a copy of the License at
9  *
10  *     http://www.apache.org/licenses/LICENSE-2.0
11  *
12  *  Unless required by applicable law or agreed to in writing, software
13  *  distributed under the License is distributed on an "AS IS" BASIS,
14  *  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15  *  See the License for the specific language governing permissions and
16  *  limitations under the License.
17  */

18
19 package io.undertow.websockets.jsr;
20
21 import static io.undertow.websockets.jsr.ServerWebSocketContainer.WebSocketHandshakeHolder;
22
23 import java.io.IOException;
24 import java.security.AccessController;
25 import java.util.ArrayList;
26 import java.util.Collections;
27 import java.util.List;
28 import java.util.Set;
29 import java.util.concurrent.ConcurrentHashMap;
30
31 import javax.servlet.Filter;
32 import javax.servlet.FilterChain;
33 import javax.servlet.FilterConfig;
34 import javax.servlet.ServletException;
35 import javax.servlet.ServletRequest;
36 import javax.servlet.ServletResponse;
37 import javax.servlet.http.HttpServletRequest;
38 import javax.servlet.http.HttpServletResponse;
39 import javax.servlet.http.HttpSessionEvent;
40 import javax.servlet.http.HttpSessionListener;
41 import javax.websocket.CloseReason;
42 import javax.websocket.server.ServerContainer;
43
44 import org.xnio.ChannelListener;
45 import org.xnio.StreamConnection;
46
47 import io.undertow.server.HttpServerExchange;
48 import io.undertow.server.HttpUpgradeListener;
49 import io.undertow.server.session.Session;
50 import io.undertow.servlet.handlers.ServletRequestContext;
51 import io.undertow.servlet.spec.HttpSessionImpl;
52 import io.undertow.servlet.websockets.ServletWebSocketHttpExchange;
53 import io.undertow.util.Headers;
54 import io.undertow.util.PathTemplateMatcher;
55 import io.undertow.util.StatusCodes;
56 import io.undertow.websockets.WebSocketConnectionCallback;
57 import io.undertow.websockets.core.WebSocketChannel;
58 import io.undertow.websockets.core.WebSockets;
59 import io.undertow.websockets.core.protocol.Handshake;
60 import io.undertow.websockets.jsr.handshake.HandshakeUtil;
61
62 /**
63  * Filter that provides HTTP upgrade functionality. This should be run after all user filters, but before any servlets.
64  * <p>
65  * The use of a filter rather than a servlet allows for normal HTTP requests to be served from the same location
66  * as a web socket endpoint if no upgrade header is found.
67  * <p>
68  *
69  * @author Stuart Douglas
70  */

71 public class JsrWebSocketFilter implements Filter {
72
73     private WebSocketConnectionCallback callback;
74     private PathTemplateMatcher<WebSocketHandshakeHolder> pathTemplateMatcher;
75     private Set<WebSocketChannel> peerConnections;
76     private ServerWebSocketContainer container;
77
78     private static final String SESSION_ATTRIBUTE = "io.undertow.websocket.current-connections";
79
80
81     @Override
82     public void init(final FilterConfig filterConfig) throws ServletException {
83         peerConnections = Collections.newSetFromMap(new ConcurrentHashMap<WebSocketChannel, Boolean>());
84         container = (ServerWebSocketContainer) filterConfig.getServletContext().getAttribute(ServerContainer.class.getName());
85         container.deploymentComplete();
86         pathTemplateMatcher = new PathTemplateMatcher<>();
87         WebSocketDeploymentInfo info = (WebSocketDeploymentInfo)filterConfig.getServletContext().getAttribute(WebSocketDeploymentInfo.ATTRIBUTE_NAME);
88         for (ConfiguredServerEndpoint endpoint : container.getConfiguredServerEndpoints()) {
89             if (info == null || info.getExtensions().isEmpty()) {
90                 pathTemplateMatcher.add(endpoint.getPathTemplate(), ServerWebSocketContainer.handshakes(endpoint));
91             } else {
92                 pathTemplateMatcher.add(endpoint.getPathTemplate(), ServerWebSocketContainer.handshakes(endpoint, info.getExtensions()));
93             }
94         }
95         this.callback = new EndpointSessionHandler(container);
96     }
97
98     @Override
99     public void doFilter(final ServletRequest request, final ServletResponse response, final FilterChain chain) throws IOException, ServletException {
100         HttpServletRequest req = (HttpServletRequest) request;
101         HttpServletResponse resp = (HttpServletResponse) response;
102         if (req.getHeader(Headers.UPGRADE_STRING) != null) {
103             final ServletWebSocketHttpExchange facade = new ServletWebSocketHttpExchange(req, resp, peerConnections);
104
105             String path;
106             if (req.getPathInfo() == null) {
107                 path = req.getServletPath();
108             } else {
109                 path = req.getServletPath() + req.getPathInfo();
110             }
111             if (!path.startsWith("/")) {
112                 path = "/" + path;
113             }
114             PathTemplateMatcher.PathMatchResult<WebSocketHandshakeHolder> matchResult = pathTemplateMatcher.match(path);
115             if (matchResult != null) {
116                 Handshake handshaker = null;
117                 for (Handshake method : matchResult.getValue().handshakes) {
118                     if (method.matches(facade)) {
119                         handshaker = method;
120                         break;
121                     }
122                 }
123
124                 if (handshaker != null) {
125                     if(container.isClosed()) {
126                         resp.sendError(StatusCodes.SERVICE_UNAVAILABLE);
127                         return;
128                     }
129                     facade.putAttachment(HandshakeUtil.PATH_PARAMS, matchResult.getParameters());
130                     facade.putAttachment(HandshakeUtil.PRINCIPAL, req.getUserPrincipal());
131                     final Handshake selected = handshaker;
132                     ServletRequestContext src = ServletRequestContext.requireCurrent();
133                     final HttpSessionImpl session = src.getCurrentServletContext().getSession(src.getExchange(), false);
134                     facade.upgradeChannel(new HttpUpgradeListener() {
135                         @Override
136                         public void handleUpgrade(StreamConnection streamConnection, HttpServerExchange exchange) {
137
138                             WebSocketChannel channel = selected.createChannel(facade, streamConnection, facade.getBufferPool());
139                             peerConnections.add(channel);
140                             if(session != null) {
141                                 final Session underlying;
142                                 if (System.getSecurityManager() == null) {
143                                     underlying = session.getSession();
144                                 } else {
145                                     underlying = AccessController.doPrivileged(new HttpSessionImpl.UnwrapSessionAction(session));
146                                 }
147                                 List<WebSocketChannel> connections;
148                                 synchronized (underlying) {
149                                     connections = (List<WebSocketChannel>) underlying.getAttribute(SESSION_ATTRIBUTE);
150                                     if(connections == null) {
151                                         underlying.setAttribute(SESSION_ATTRIBUTE, connections = new ArrayList<>());
152                                     }
153                                     connections.add(channel);
154                                 }
155                                 final List<WebSocketChannel> finalConnections = connections;
156                                 channel.addCloseTask(new ChannelListener<WebSocketChannel>() {
157                                     @Override
158                                     public void handleEvent(WebSocketChannel channel) {
159                                         synchronized (underlying) {
160                                             finalConnections.remove(channel);
161                                         }
162                                     }
163                                 });
164                             }
165                             callback.onConnect(facade, channel);
166                         }
167                     });
168                     handshaker.handshake(facade);
169                     return;
170                 }
171             }
172         }
173         chain.doFilter(request, response);
174     }
175
176     @Override
177     public void destroy() {
178
179     }
180
181
182     public static class LogoutListener implements HttpSessionListener {
183
184         @Override
185         public void sessionCreated(HttpSessionEvent se) {
186
187         }
188
189         @Override
190         public void sessionDestroyed(HttpSessionEvent se) {
191             HttpSessionImpl session = (HttpSessionImpl) se.getSession();
192             final Session underlying;
193             if (System.getSecurityManager() == null) {
194                 underlying = session.getSession();
195             } else {
196                 underlying = AccessController.doPrivileged(new HttpSessionImpl.UnwrapSessionAction(session));
197             }
198             List<WebSocketChannel> connections = (List<WebSocketChannel>) underlying.getAttribute(SESSION_ATTRIBUTE);
199             if(connections != null) {
200                 synchronized (underlying) {
201                     for(WebSocketChannel c : connections) {
202                         WebSockets.sendClose(CloseReason.CloseCodes.VIOLATED_POLICY.getCode(), "", c, null);
203                     }
204                 }
205             }
206         }
207     }
208
209 }
210