1
18
19 package io.undertow.servlet.handlers.security;
20
21 import static io.undertow.util.StatusCodes.OK;
22
23 import io.undertow.security.api.AuthenticationMechanism;
24 import io.undertow.security.api.AuthenticationMechanismFactory;
25 import io.undertow.security.idm.IdentityManager;
26 import io.undertow.security.impl.FormAuthenticationMechanism;
27 import io.undertow.server.HttpServerExchange;
28 import io.undertow.server.handlers.form.FormParserFactory;
29 import io.undertow.server.session.Session;
30 import io.undertow.server.session.SessionListener;
31 import io.undertow.server.session.SessionManager;
32 import io.undertow.servlet.handlers.ServletRequestContext;
33 import io.undertow.servlet.spec.HttpSessionImpl;
34 import io.undertow.servlet.util.SavedRequest;
35 import io.undertow.util.Headers;
36 import io.undertow.util.RedirectBuilder;
37
38 import javax.servlet.RequestDispatcher;
39 import javax.servlet.ServletException;
40 import javax.servlet.ServletRequest;
41 import javax.servlet.ServletResponse;
42 import javax.servlet.http.HttpServletResponse;
43 import javax.servlet.http.HttpServletResponseWrapper;
44
45 import java.io.IOException;
46 import java.security.AccessController;
47 import java.util.Collections;
48 import java.util.Map;
49 import java.util.Set;
50 import java.util.WeakHashMap;
51
52
58 public class ServletFormAuthenticationMechanism extends FormAuthenticationMechanism {
59
60 public static final AuthenticationMechanismFactory FACTORY = new Factory();
61
62 private static final String SESSION_KEY = "io.undertow.servlet.form.auth.redirect.location";
63
64 public static final String SAVE_ORIGINAL_REQUEST = "save-original-request";
65
66 private final boolean saveOriginalRequest;
67
68 private final Set<SessionManager> seenSessionManagers = Collections.synchronizedSet(Collections.newSetFromMap(new WeakHashMap<SessionManager, Boolean>()));
69
70 private final String defaultPage;
71
72 private final boolean overrideInitial;
73
74 private static final SessionListener LISTENER = new SessionListener() {
75 @Override
76 public void sessionCreated(Session session, HttpServerExchange exchange) { }
77
78 @Override
79 public void sessionDestroyed(Session session, HttpServerExchange exchange, SessionDestroyedReason reason) { }
80
81 @Override
82 public void attributeAdded(Session session, String name, Object value) { }
83
84 @Override
85 public void attributeUpdated(Session session, String name, Object newValue, Object oldValue) { }
86
87 @Override
88 public void attributeRemoved(Session session, String name, Object oldValue) { }
89
90 @Override
91 public void sessionIdChanged(Session session, String oldSessionId) {
92 String oldLocation = (String)session.getAttribute(SESSION_KEY);
93 if(oldLocation != null) {
94
95
96 String oldPart = ";jsessionid=" + oldSessionId;
97 if (oldLocation.contains(oldPart)) {
98 session.setAttribute(ServletFormAuthenticationMechanism.SESSION_KEY, oldLocation.replace(oldPart, ";jsessionid=" + session.getId()));
99 }
100 }
101 }
102 };
103
104 @Deprecated
105 public ServletFormAuthenticationMechanism(final String name, final String loginPage, final String errorPage) {
106 super(name, loginPage, errorPage);
107 this.saveOriginalRequest = true;
108 this.defaultPage = null;
109 this.overrideInitial = false;
110 }
111
112 @Deprecated
113 public ServletFormAuthenticationMechanism(final String name, final String loginPage, final String errorPage, final String postLocation) {
114 super(name, loginPage, errorPage, postLocation);
115 this.saveOriginalRequest = true;
116 this.defaultPage = null;
117 this.overrideInitial = false;
118 }
119
120 public ServletFormAuthenticationMechanism(FormParserFactory formParserFactory, String name, String loginPage, String errorPage, String postLocation) {
121 super(formParserFactory, name, loginPage, errorPage, postLocation);
122 this.saveOriginalRequest = true;
123 this.defaultPage = null;
124 this.overrideInitial = false;
125 }
126
127 public ServletFormAuthenticationMechanism(FormParserFactory formParserFactory, String name, String loginPage, String errorPage) {
128 super(formParserFactory, name, loginPage, errorPage);
129 this.saveOriginalRequest = true;
130 this.defaultPage = null;
131 this.overrideInitial = false;
132 }
133
134 public ServletFormAuthenticationMechanism(FormParserFactory formParserFactory, String name, String loginPage, String errorPage, IdentityManager identityManager) {
135 super(formParserFactory, name, loginPage, errorPage, identityManager);
136 this.saveOriginalRequest = true;
137 this.defaultPage = null;
138 this.overrideInitial = false;
139 }
140
141 public ServletFormAuthenticationMechanism(FormParserFactory formParserFactory, String name, String loginPage, String errorPage, IdentityManager identityManager, boolean saveOriginalRequest) {
142 super(formParserFactory, name, loginPage, errorPage, identityManager);
143 this.saveOriginalRequest = true;
144 this.defaultPage = null;
145 this.overrideInitial = false;
146 }
147
148 public ServletFormAuthenticationMechanism(FormParserFactory formParserFactory, String name, String loginPage, String errorPage, String defaultPage, boolean overrideInitial, IdentityManager identityManager, boolean saveOriginalRequest) {
149 super(formParserFactory, name, loginPage, errorPage, identityManager);
150 this.saveOriginalRequest = saveOriginalRequest;
151 this.defaultPage = defaultPage;
152 this.overrideInitial = overrideInitial;
153 }
154
155 @Override
156 protected Integer servePage(final HttpServerExchange exchange, final String location) {
157 final ServletRequestContext servletRequestContext = exchange.getAttachment(ServletRequestContext.ATTACHMENT_KEY);
158 ServletRequest req = servletRequestContext.getServletRequest();
159 ServletResponse resp = servletRequestContext.getServletResponse();
160 RequestDispatcher disp = req.getRequestDispatcher(location);
161
162 exchange.getResponseHeaders().add(Headers.CACHE_CONTROL, "no-cache, no-store, must-revalidate");
163 exchange.getResponseHeaders().add(Headers.PRAGMA, "no-cache");
164 exchange.getResponseHeaders().add(Headers.EXPIRES, "0");
165
166 final FormResponseWrapper respWrapper = exchange.getStatusCode() != OK && resp instanceof HttpServletResponse
167 ? new FormResponseWrapper((HttpServletResponse) resp) : null;
168
169 try {
170 disp.forward(req, respWrapper != null ? respWrapper : resp);
171 } catch (ServletException e) {
172 throw new RuntimeException(e);
173 } catch (IOException e) {
174 throw new RuntimeException(e);
175 }
176
177 return respWrapper != null ? respWrapper.getStatus() : null;
178 }
179
180 @Override
181 protected void storeInitialLocation(final HttpServerExchange exchange) {
182 storeInitialLocation(exchange, null, 0);
183 }
184
185
193 protected void storeInitialLocation(final HttpServerExchange exchange, byte[] bytes, int contentLength) {
194 if(!saveOriginalRequest) {
195 return;
196 }
197 final ServletRequestContext servletRequestContext = exchange.getAttachment(ServletRequestContext.ATTACHMENT_KEY);
198 HttpSessionImpl httpSession = servletRequestContext.getCurrentServletContext().getSession(exchange, true);
199 Session session;
200 if (System.getSecurityManager() == null) {
201 session = httpSession.getSession();
202 } else {
203 session = AccessController.doPrivileged(new HttpSessionImpl.UnwrapSessionAction(httpSession));
204 }
205 SessionManager manager = session.getSessionManager();
206 if (seenSessionManagers.add(manager)) {
207 manager.registerSessionListener(LISTENER);
208 }
209 session.setAttribute(SESSION_KEY, RedirectBuilder.redirect(exchange, exchange.getRelativePath()));
210 if(bytes == null) {
211 SavedRequest.trySaveRequest(exchange);
212 } else {
213 SavedRequest.trySaveRequest(exchange, bytes, contentLength);
214 }
215 }
216
217 @Override
218 protected void handleRedirectBack(final HttpServerExchange exchange) {
219 final ServletRequestContext servletRequestContext = exchange.getAttachment(ServletRequestContext.ATTACHMENT_KEY);
220 HttpServletResponse resp = (HttpServletResponse) servletRequestContext.getServletResponse();
221 HttpSessionImpl httpSession = servletRequestContext.getCurrentServletContext().getSession(exchange, false);
222 if (httpSession != null) {
223 Session session;
224 if (System.getSecurityManager() == null) {
225 session = httpSession.getSession();
226 } else {
227 session = AccessController.doPrivileged(new HttpSessionImpl.UnwrapSessionAction(httpSession));
228 }
229 String path = (String) session.getAttribute(SESSION_KEY);
230 if ((path == null || overrideInitial) && defaultPage != null) {
231 path = defaultPage;
232 }
233 if (path != null) {
234 try {
235 resp.sendRedirect(path);
236 } catch (IOException e) {
237 throw new RuntimeException(e);
238 }
239 }
240 }
241
242 }
243
244 private static class FormResponseWrapper extends HttpServletResponseWrapper {
245
246 private int status = OK;
247
248 private FormResponseWrapper(final HttpServletResponse wrapped) {
249 super(wrapped);
250 }
251
252 @Override
253 public void setStatus(int sc, String sm) {
254 status = sc;
255 }
256
257 @Override
258 public void setStatus(int sc) {
259 status = sc;
260 }
261
262 @Override
263 public int getStatus() {
264 return status;
265 }
266
267 }
268
269 public static class Factory implements AuthenticationMechanismFactory {
270
271 @Deprecated
272 public Factory(IdentityManager identityManager) {}
273
274 public Factory() {}
275
276 @Override
277 public AuthenticationMechanism create(String mechanismName, IdentityManager identityManager, FormParserFactory formParserFactory, Map<String, String> properties) {
278 final String loginPage = properties.get(LOGIN_PAGE);
279 final String errorPage = properties.get(ERROR_PAGE);
280 final String defaultPage = properties.get(DEFAULT_PAGE);
281 final boolean overrideInitial = properties.containsKey(OVERRIDE_INITIAL) ?
282 Boolean.parseBoolean(properties.get(OVERRIDE_INITIAL)): false;
283 boolean saveOriginal = true;
284 if (properties.containsKey(SAVE_ORIGINAL_REQUEST)) {
285 saveOriginal = Boolean.parseBoolean(properties.get(SAVE_ORIGINAL_REQUEST));
286 }
287 return new ServletFormAuthenticationMechanism(formParserFactory, mechanismName, loginPage, errorPage, defaultPage, overrideInitial, identityManager, saveOriginal);
288 }
289 }
290
291 }
292