1
18
19 package io.undertow.websockets.jsr;
20
21 import static io.undertow.websockets.jsr.ServerWebSocketContainer.WebSocketHandshakeHolder;
22
23 import java.io.IOException;
24 import java.security.AccessController;
25 import java.util.ArrayList;
26 import java.util.Collections;
27 import java.util.List;
28 import java.util.Set;
29 import java.util.concurrent.ConcurrentHashMap;
30
31 import javax.servlet.Filter;
32 import javax.servlet.FilterChain;
33 import javax.servlet.FilterConfig;
34 import javax.servlet.ServletException;
35 import javax.servlet.ServletRequest;
36 import javax.servlet.ServletResponse;
37 import javax.servlet.http.HttpServletRequest;
38 import javax.servlet.http.HttpServletResponse;
39 import javax.servlet.http.HttpSessionEvent;
40 import javax.servlet.http.HttpSessionListener;
41 import javax.websocket.CloseReason;
42 import javax.websocket.server.ServerContainer;
43
44 import org.xnio.ChannelListener;
45 import org.xnio.StreamConnection;
46
47 import io.undertow.server.HttpServerExchange;
48 import io.undertow.server.HttpUpgradeListener;
49 import io.undertow.server.session.Session;
50 import io.undertow.servlet.handlers.ServletRequestContext;
51 import io.undertow.servlet.spec.HttpSessionImpl;
52 import io.undertow.servlet.websockets.ServletWebSocketHttpExchange;
53 import io.undertow.util.Headers;
54 import io.undertow.util.PathTemplateMatcher;
55 import io.undertow.util.StatusCodes;
56 import io.undertow.websockets.WebSocketConnectionCallback;
57 import io.undertow.websockets.core.WebSocketChannel;
58 import io.undertow.websockets.core.WebSockets;
59 import io.undertow.websockets.core.protocol.Handshake;
60 import io.undertow.websockets.jsr.handshake.HandshakeUtil;
61
62
71 public class JsrWebSocketFilter implements Filter {
72
73 private WebSocketConnectionCallback callback;
74 private PathTemplateMatcher<WebSocketHandshakeHolder> pathTemplateMatcher;
75 private Set<WebSocketChannel> peerConnections;
76 private ServerWebSocketContainer container;
77
78 private static final String SESSION_ATTRIBUTE = "io.undertow.websocket.current-connections";
79
80
81 @Override
82 public void init(final FilterConfig filterConfig) throws ServletException {
83 peerConnections = Collections.newSetFromMap(new ConcurrentHashMap<WebSocketChannel, Boolean>());
84 container = (ServerWebSocketContainer) filterConfig.getServletContext().getAttribute(ServerContainer.class.getName());
85 container.deploymentComplete();
86 pathTemplateMatcher = new PathTemplateMatcher<>();
87 WebSocketDeploymentInfo info = (WebSocketDeploymentInfo)filterConfig.getServletContext().getAttribute(WebSocketDeploymentInfo.ATTRIBUTE_NAME);
88 for (ConfiguredServerEndpoint endpoint : container.getConfiguredServerEndpoints()) {
89 if (info == null || info.getExtensions().isEmpty()) {
90 pathTemplateMatcher.add(endpoint.getPathTemplate(), ServerWebSocketContainer.handshakes(endpoint));
91 } else {
92 pathTemplateMatcher.add(endpoint.getPathTemplate(), ServerWebSocketContainer.handshakes(endpoint, info.getExtensions()));
93 }
94 }
95 this.callback = new EndpointSessionHandler(container);
96 }
97
98 @Override
99 public void doFilter(final ServletRequest request, final ServletResponse response, final FilterChain chain) throws IOException, ServletException {
100 HttpServletRequest req = (HttpServletRequest) request;
101 HttpServletResponse resp = (HttpServletResponse) response;
102 if (req.getHeader(Headers.UPGRADE_STRING) != null) {
103 final ServletWebSocketHttpExchange facade = new ServletWebSocketHttpExchange(req, resp, peerConnections);
104
105 String path;
106 if (req.getPathInfo() == null) {
107 path = req.getServletPath();
108 } else {
109 path = req.getServletPath() + req.getPathInfo();
110 }
111 if (!path.startsWith("/")) {
112 path = "/" + path;
113 }
114 PathTemplateMatcher.PathMatchResult<WebSocketHandshakeHolder> matchResult = pathTemplateMatcher.match(path);
115 if (matchResult != null) {
116 Handshake handshaker = null;
117 for (Handshake method : matchResult.getValue().handshakes) {
118 if (method.matches(facade)) {
119 handshaker = method;
120 break;
121 }
122 }
123
124 if (handshaker != null) {
125 if(container.isClosed()) {
126 resp.sendError(StatusCodes.SERVICE_UNAVAILABLE);
127 return;
128 }
129 facade.putAttachment(HandshakeUtil.PATH_PARAMS, matchResult.getParameters());
130 facade.putAttachment(HandshakeUtil.PRINCIPAL, req.getUserPrincipal());
131 final Handshake selected = handshaker;
132 ServletRequestContext src = ServletRequestContext.requireCurrent();
133 final HttpSessionImpl session = src.getCurrentServletContext().getSession(src.getExchange(), false);
134 facade.upgradeChannel(new HttpUpgradeListener() {
135 @Override
136 public void handleUpgrade(StreamConnection streamConnection, HttpServerExchange exchange) {
137
138 WebSocketChannel channel = selected.createChannel(facade, streamConnection, facade.getBufferPool());
139 peerConnections.add(channel);
140 if(session != null) {
141 final Session underlying;
142 if (System.getSecurityManager() == null) {
143 underlying = session.getSession();
144 } else {
145 underlying = AccessController.doPrivileged(new HttpSessionImpl.UnwrapSessionAction(session));
146 }
147 List<WebSocketChannel> connections;
148 synchronized (underlying) {
149 connections = (List<WebSocketChannel>) underlying.getAttribute(SESSION_ATTRIBUTE);
150 if(connections == null) {
151 underlying.setAttribute(SESSION_ATTRIBUTE, connections = new ArrayList<>());
152 }
153 connections.add(channel);
154 }
155 final List<WebSocketChannel> finalConnections = connections;
156 channel.addCloseTask(new ChannelListener<WebSocketChannel>() {
157 @Override
158 public void handleEvent(WebSocketChannel channel) {
159 synchronized (underlying) {
160 finalConnections.remove(channel);
161 }
162 }
163 });
164 }
165 callback.onConnect(facade, channel);
166 }
167 });
168 handshaker.handshake(facade);
169 return;
170 }
171 }
172 }
173 chain.doFilter(request, response);
174 }
175
176 @Override
177 public void destroy() {
178
179 }
180
181
182 public static class LogoutListener implements HttpSessionListener {
183
184 @Override
185 public void sessionCreated(HttpSessionEvent se) {
186
187 }
188
189 @Override
190 public void sessionDestroyed(HttpSessionEvent se) {
191 HttpSessionImpl session = (HttpSessionImpl) se.getSession();
192 final Session underlying;
193 if (System.getSecurityManager() == null) {
194 underlying = session.getSession();
195 } else {
196 underlying = AccessController.doPrivileged(new HttpSessionImpl.UnwrapSessionAction(session));
197 }
198 List<WebSocketChannel> connections = (List<WebSocketChannel>) underlying.getAttribute(SESSION_ATTRIBUTE);
199 if(connections != null) {
200 synchronized (underlying) {
201 for(WebSocketChannel c : connections) {
202 WebSockets.sendClose(CloseReason.CloseCodes.VIOLATED_POLICY.getCode(), "", c, null);
203 }
204 }
205 }
206 }
207 }
208
209 }
210