summaryrefslogtreecommitdiffstats
path: root/src/DotNetOpenAuth.Web/OAuthWebSecurity.cs
blob: 7a3f8641a200a60b893002666646fbeebf00d8f2 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
using System;
using System.Collections.Generic;
using System.Diagnostics.CodeAnalysis;
using System.Globalization;
using System.Threading;
using System.Web;
using DotNetOpenAuth.Messaging;
using DotNetOpenAuth.Web.Clients;
using DotNetOpenAuth.Web.Resources;

namespace DotNetOpenAuth.Web
{
    /// <summary>
    /// Contains APIs to manage authentication against OAuth and OpenID service providers.
    /// </summary>
    public static class OAuthWebSecurity
    {
        /// <summary>
        /// Name of the query string parameter appended to the return URL so that, when the user is
        /// redirected back, we know which provider initiated the login.
        /// </summary>
        private const string ProviderQueryStringName = "__provider__";

        /// <summary>
        /// The registered data provider. Set at most once by <see cref="RegisterDataProvider"/>;
        /// reset only by <see cref="ClearDataProvider"/> (unit tests).
        /// </summary>
        private static IOAuthDataProvider _oAuthDataProvider;

        /// <summary>
        /// Gets the registered data provider, or <c>null</c> if none has been registered.
        /// </summary>
        private static IOAuthDataProvider OAuthDataProvider
        {
            get
            {
                return _oAuthDataProvider;
            }
        }

        // contains all registered authentication clients, looked up by provider name
        private static readonly AuthenticationClientCollection _authenticationClients = new AuthenticationClientCollection();

        /// <summary>
        /// Registers the data provider used to create, look up and delete the associations between
        /// provider accounts and local user names. Can only be called once.
        /// </summary>
        /// <param name="dataProvider">The data provider to register.</param>
        /// <exception cref="ArgumentNullException"><paramref name="dataProvider"/> is <c>null</c>.</exception>
        /// <exception cref="InvalidOperationException">A data provider has already been registered.</exception>
        public static void RegisterDataProvider(IOAuthDataProvider dataProvider)
        {
            if (dataProvider == null)
            {
                throw new ArgumentNullException("dataProvider");
            }

            // Atomically install the provider only if none is set yet; a non-null previous value
            // means someone registered first, and we refuse to silently replace it.
            var originalValue = Interlocked.CompareExchange(ref _oAuthDataProvider, dataProvider, null);
            if (originalValue != null)
            {
                throw new InvalidOperationException(WebResources.OAuthDataProviderRegistered);
            }
        }

        /// <summary>
        /// Gets a value indicating whether a data provider has been registered via
        /// <see cref="RegisterDataProvider"/>.
        /// </summary>
        public static bool IsOAuthDataProviderRegistered
        {
            get
            {
                return OAuthDataProvider != null;
            }
        }

        /// <summary>
        /// Throws if no data provider has been registered yet.
        /// </summary>
        /// <exception cref="InvalidOperationException">No data provider is registered.</exception>
        private static void EnsureDataProvider()
        {
            if (!IsOAuthDataProviderRegistered)
            {
                throw new InvalidOperationException(WebResources.OAuthDataProviderNotRegistered);
            }
        }

        /// <summary>
        /// Registers a supported OAuth client with the specified consumer key and consumer secret.
        /// </summary>
        /// <param name="client">One of the supported OAuth clients.</param>
        /// <param name="consumerKey">The consumer key.</param>
        /// <param name="consumerSecret">The consumer secret.</param>
        /// <exception cref="ArgumentOutOfRangeException"><paramref name="client"/> is not a known built-in client.</exception>
        /// <exception cref="ArgumentException">A client with the same provider name is already registered.</exception>
        public static void RegisterOAuthClient(BuiltInOAuthClient client, string consumerKey, string consumerSecret)
        {
            IAuthenticationClient authenticationClient;
            switch (client)
            {
                case BuiltInOAuthClient.LinkedIn:
                    authenticationClient = new LinkedInClient(consumerKey, consumerSecret);
                    break;

                case BuiltInOAuthClient.Twitter:
                    authenticationClient = new TwitterClient(consumerKey, consumerSecret);
                    break;

                case BuiltInOAuthClient.Facebook:
                    authenticationClient = new FacebookClient(consumerKey, consumerSecret);
                    break;

                case BuiltInOAuthClient.WindowsLive:
                    authenticationClient = new WindowsLiveClient(consumerKey, consumerSecret);
                    break;

                default:
                    throw new ArgumentOutOfRangeException("client");
            }

            RegisterClient(authenticationClient);
        }

        /// <summary>
        /// Registers a supported OpenID client.
        /// </summary>
        /// <param name="openIDClient">One of the supported OpenID clients.</param>
        /// <exception cref="ArgumentOutOfRangeException"><paramref name="openIDClient"/> is not a known built-in client.</exception>
        /// <exception cref="ArgumentException">A client with the same provider name is already registered.</exception>
        public static void RegisterOpenIDClient(BuiltInOpenIDClient openIDClient)
        {
            IAuthenticationClient client;
            switch (openIDClient)
            {
                case BuiltInOpenIDClient.Google:
                    client = new GoogleOpenIdClient();
                    break;

                case BuiltInOpenIDClient.Yahoo:
                    client = new YahooOpenIdClient();
                    break;

                default:
                    throw new ArgumentOutOfRangeException("openIDClient");
            }

            RegisterClient(client);
        }

        /// <summary>
        /// Registers an authentication client.
        /// </summary>
        /// <param name="client">The client to register; its <c>ProviderName</c> must be non-empty and unique.</param>
        /// <exception cref="ArgumentNullException"><paramref name="client"/> is <c>null</c>.</exception>
        /// <exception cref="ArgumentException">
        /// The client's provider name is null or empty, or a client with the same provider name is already registered.
        /// </exception>
        public static void RegisterClient(IAuthenticationClient client)
        {
            if (client == null)
            {
                throw new ArgumentNullException("client");
            }

            if (String.IsNullOrEmpty(client.ProviderName))
            {
                throw new ArgumentException(WebResources.InvalidServiceProviderName, "client");
            }

            // Check by provider name (the key the collection is looked up by), not by instance:
            // a different client object with an already-registered name must be rejected here with
            // the descriptive message, rather than slipping past and failing inside Add().
            if (_authenticationClients.Contains(client.ProviderName))
            {
                throw new ArgumentException(WebResources.ServiceProviderNameExists, "client");
            }

            _authenticationClients.Add(client);
        }

        /// <summary>
        /// Requests the specified provider to start the authentication by directing users to an external website.
        /// After authentication the user is sent back to the current page.
        /// </summary>
        /// <param name="provider">The provider.</param>
        public static void RequestAuthentication(string provider)
        {
            RequestAuthentication(provider, returnUrl: null);
        }

        /// <summary>
        /// Requests the specified provider to start the authentication by directing users to an external website.
        /// </summary>
        /// <param name="provider">The provider.</param>
        /// <param name="returnUrl">
        /// The return url after user is authenticated. If null or empty, the public-facing URL of the current request is used.
        /// </param>
        /// <exception cref="InvalidOperationException">There is no current HTTP context.</exception>
        [SuppressMessage(
            "Microsoft.Design",
            "CA1054:UriParametersShouldNotBeStrings",
            MessageId = "1#",
            Justification = "We want to allow relative app path, and support ~/")]
        public static void RequestAuthentication(string provider, string returnUrl)
        {
            if (HttpContext.Current == null)
            {
                throw new InvalidOperationException(WebResources.HttpContextNotAvailable);
            }

            RequestAuthenticationCore(new HttpContextWrapper(HttpContext.Current), provider, returnUrl);
        }

        /// <summary>
        /// Testable core of <see cref="RequestAuthentication(string, string)"/>.
        /// </summary>
        /// <param name="context">The HTTP context of the current request.</param>
        /// <param name="provider">The registered provider name.</param>
        /// <param name="returnUrl">The return url, or null/empty to use the current request's public-facing URL.</param>
        /// <exception cref="ArgumentException">
        /// <paramref name="provider"/> is null or empty, or no client is registered under that name.
        /// </exception>
        internal static void RequestAuthenticationCore(HttpContextBase context, string provider, string returnUrl)
        {
            if (String.IsNullOrEmpty(provider))
            {
                throw new ArgumentException(
                    String.Format(CultureInfo.CurrentCulture, WebResources.Argument_Cannot_Be_Null_Or_Empty, "provider"),
                    "provider");
            }

            IAuthenticationClient client = GetOAuthClient(provider);

            // resolve the return location to an absolute Uri
            Uri uri;
            if (!String.IsNullOrEmpty(returnUrl))
            {
                uri = UriHelper.ConvertToAbsoluteUri(returnUrl);
            }
            else
            {
                uri = UriHelper.GetPublicFacingUrl(context.Request);
            }

            // attach the provider parameter so that we know which provider initiated
            // the login when user is redirected back to this page
            uri = uri.AttachQueryStringParameter(ProviderQueryStringName, provider);
            client.RequestAuthentication(context, uri);
        }

        /// <summary>
        /// Checks if user is successfully authenticated when user is redirected back to this page.
        /// </summary>
        /// <returns>
        /// The authentication result. A failed result is returned when the request carries no provider
        /// name or when the provider reports failure.
        /// </returns>
        /// <exception cref="InvalidOperationException">
        /// There is no current HTTP context, or the provider named in the request is not registered.
        /// </exception>
        public static AuthenticationResult VerifyAuthentication()
        {
            if (HttpContext.Current == null)
            {
                throw new InvalidOperationException(WebResources.HttpContextNotAvailable);
            }

            return VerifyAuthenticationCore(new HttpContextWrapper(HttpContext.Current));
        }

        /// <summary>
        /// Testable core of <see cref="VerifyAuthentication"/>.
        /// </summary>
        /// <param name="context">The HTTP context of the redirected-back request.</param>
        /// <returns>
        /// <c>AuthenticationResult.Failed</c> if the request has no provider query parameter; otherwise the
        /// client's result, with a failed result replaced by one that carries the provider name.
        /// </returns>
        /// <exception cref="InvalidOperationException">The provider named in the request is not registered.</exception>
        internal static AuthenticationResult VerifyAuthenticationCore(HttpContextBase context)
        {
            string providerName = context.Request.QueryString[ProviderQueryStringName];
            if (String.IsNullOrEmpty(providerName))
            {
                return AuthenticationResult.Failed;
            }

            IAuthenticationClient client;
            if (!TryGetOAuthClient(providerName, out client))
            {
                throw new InvalidOperationException(WebResources.InvalidServiceProviderName);
            }

            AuthenticationResult result = client.VerifyAuthentication(context);
            if (!result.IsSuccessful)
            {
                // if the result is a Failed result, create a new Failed response which has providerName info.
                result = new AuthenticationResult(isSuccessful: false, provider: providerName, providerUserId: null,
                                                  userName: null, extraData: null);
            }

            return result;
        }

        /// <summary>
        /// Checks if the specified provider user id represents a valid account.
        /// If it does, log user in.
        /// </summary>
        /// <param name="providerName">Name of the provider.</param>
        /// <param name="providerUserId">The provider user id.</param>
        /// <param name="createPersistentCookie">Whether the authentication ticket should be persisted across browser sessions.</param>
        /// <returns><c>true</c> if the login is successful.</returns>
        /// <exception cref="InvalidOperationException">
        /// There is no current HTTP context, or no data provider is registered.
        /// </exception>
        [SuppressMessage("Microsoft.Naming", "CA1726:UsePreferredTerms", MessageId = "Login", Justification = "Login is used more consistently in ASP.Net")]
        public static bool Login(string providerName, string providerUserId, bool createPersistentCookie)
        {
            if (HttpContext.Current == null)
            {
                throw new InvalidOperationException(WebResources.HttpContextNotAvailable);
            }

            return LoginCore(new HttpContextWrapper(HttpContext.Current), providerName, providerUserId, createPersistentCookie);
        }

        /// <summary>
        /// Testable core of <see cref="Login"/>. Looks up the local user name for the provider account and,
        /// if one exists, sets the authentication ticket for it.
        /// </summary>
        /// <param name="context">The HTTP context of the current request.</param>
        /// <param name="providerName">Name of the provider.</param>
        /// <param name="providerUserId">The provider user id.</param>
        /// <param name="createPersistentCookie">Whether the authentication ticket should be persisted across browser sessions.</param>
        /// <returns><c>true</c> if a user name was found and the ticket was set; otherwise <c>false</c>.</returns>
        /// <exception cref="InvalidOperationException">No data provider is registered.</exception>
        internal static bool LoginCore(HttpContextBase context, string providerName, string providerUserId, bool createPersistentCookie)
        {
            EnsureDataProvider();

            string userName = OAuthDataProvider.GetUserNameFromOAuth(providerName, providerUserId);
            if (String.IsNullOrEmpty(userName))
            {
                return false;
            }

            OAuthAuthenticationTicketHelper.SetAuthenticationTicket(
                   context,
                   userName,
                   createPersistentCookie);
            return true;
        }

        /// <summary>
        /// Gets a value indicating whether the current user is authenticated by an OAuth provider.
        /// </summary>
        /// <exception cref="InvalidOperationException">There is no current HTTP context.</exception>
        public static bool IsAuthenticatedViaOAuth
        {
            get
            {
                if (HttpContext.Current == null)
                {
                    throw new InvalidOperationException(WebResources.HttpContextNotAvailable);
                }

                return GetIsAuthenticatedViaOAuthCore(new HttpContextWrapper(HttpContext.Current));
            }
        }

        /// <summary>
        /// Testable core of <see cref="IsAuthenticatedViaOAuth"/>.
        /// </summary>
        /// <param name="context">The HTTP context of the current request.</param>
        /// <returns>
        /// <c>true</c> if the request is authenticated and its authentication ticket was issued by this class's OAuth login.
        /// </returns>
        internal static bool GetIsAuthenticatedViaOAuthCore(HttpContextBase context)
        {
            if (!context.Request.IsAuthenticated)
            {
                return false;
            }

            return OAuthAuthenticationTicketHelper.IsOAuthAuthenticationTicket(context);
        }

        /// <summary>
        /// Creates or updates the account with the specified provider and provider user id, and associates it
        /// with the specified user name.
        /// </summary>
        /// <param name="providerName">Name of the provider.</param>
        /// <param name="providerUserId">The provider user id.</param>
        /// <param name="userName">The user name.</param>
        /// <exception cref="InvalidOperationException">No data provider is registered.</exception>
        public static void CreateOrUpdateAccount(string providerName, string providerUserId, string userName)
        {
            EnsureDataProvider();
            OAuthDataProvider.CreateOrUpdateOAuthAccount(providerName, providerUserId, userName);
        }

        /// <summary>
        /// Gets the user name associated with the specified provider account, as returned by the registered data provider.
        /// </summary>
        /// <param name="providerName">Name of the provider.</param>
        /// <param name="providerUserId">The provider user id.</param>
        /// <returns>
        /// The associated user name. <see cref="Login"/> treats a null or empty value as "no associated account".
        /// </returns>
        /// <exception cref="InvalidOperationException">No data provider is registered.</exception>
        public static string GetUsername(string providerName, string providerUserId)
        {
            EnsureDataProvider();
            return OAuthDataProvider.GetUserNameFromOAuth(providerName, providerUserId);
        }

        /// <summary>
        /// Gets all OAuth and OpenID accounts which are associated with the specified user name.
        /// </summary>
        /// <param name="userName">The user name.</param>
        /// <returns>The accounts returned by the registered data provider.</returns>
        /// <exception cref="ArgumentException"><paramref name="userName"/> is null or empty.</exception>
        /// <exception cref="InvalidOperationException">No data provider is registered.</exception>
        public static ICollection<OAuthAccount> GetAccountsFromUserName(string userName)
        {
            if (String.IsNullOrEmpty(userName))
            {
                throw new ArgumentException(
                    String.Format(CultureInfo.CurrentCulture, WebResources.Argument_Cannot_Be_Null_Or_Empty, "userName"),
                    "userName");
            }

            EnsureDataProvider();

            return OAuthDataProvider.GetOAuthAccountsFromUserName(userName);
        }

        /// <summary>
        /// Deletes the specified OAuth or OpenID account.
        /// </summary>
        /// <param name="providerName">Name of the provider.</param>
        /// <param name="providerUserId">The provider user id.</param>
        /// <exception cref="InvalidOperationException">No data provider is registered.</exception>
        public static void DeleteAccount(string providerName, string providerUserId)
        {
            EnsureDataProvider();

            OAuthDataProvider.DeleteOAuthAccount(providerName, providerUserId);
        }

        /// <summary>
        /// Gets the registered client for the specified provider name.
        /// </summary>
        /// <param name="providerName">The provider name.</param>
        /// <returns>The registered client.</returns>
        /// <exception cref="ArgumentException">No client is registered under <paramref name="providerName"/>.</exception>
        internal static IAuthenticationClient GetOAuthClient(string providerName)
        {
            if (!_authenticationClients.Contains(providerName))
            {
                throw new ArgumentException(WebResources.ServiceProviderNotFound, "providerName");
            }

            return _authenticationClients[providerName];
        }

        /// <summary>
        /// Tries to get the registered client for the specified provider name.
        /// </summary>
        /// <param name="provider">The provider name.</param>
        /// <param name="client">The registered client, or <c>null</c> if none is registered under that name.</param>
        /// <returns><c>true</c> if a client was found; otherwise <c>false</c>.</returns>
        internal static bool TryGetOAuthClient(string provider, out IAuthenticationClient client)
        {
            if (_authenticationClients.Contains(provider))
            {
                client = _authenticationClients[provider];
                return true;
            }

            client = null;
            return false;
        }

        /// <summary>
        /// Removes all registered authentication clients. For unit tests.
        /// </summary>
        internal static void ClearProviders()
        {
            _authenticationClients.Clear();
        }

        /// <summary>
        /// Unregisters the data provider. For unit tests.
        /// </summary>
        internal static void ClearDataProvider()
        {
            _oAuthDataProvider = null;
        }
    }
}