diff options
7 files changed, 154 insertions, 81 deletions
diff --git a/projecttemplates/RelyingPartyLogic/OAuthAuthenticationModule.cs b/projecttemplates/RelyingPartyLogic/OAuthAuthenticationModule.cs index 13e725d..148af91 100644 --- a/projecttemplates/RelyingPartyLogic/OAuthAuthenticationModule.cs +++ b/projecttemplates/RelyingPartyLogic/OAuthAuthenticationModule.cs @@ -53,10 +53,11 @@ namespace RelyingPartyLogic { var tokenAnalyzer = new SpecialAccessTokenAnalyzer(crypto, crypto); var resourceServer = new ResourceServer(tokenAnalyzer); - IPrincipal principal; - var errorMessage = resourceServer.VerifyAccess(new HttpRequestWrapper(this.application.Context.Request), out principal); - if (errorMessage == null) { + try { + IPrincipal principal = resourceServer.GetPrincipal(new HttpRequestWrapper(this.application.Context.Request)); this.application.Context.User = principal; + } catch (ProtocolFaultResponseException ex) { + ex.ErrorResponse.Send(); } } } diff --git a/projecttemplates/RelyingPartyLogic/OAuthAuthorizationManager.cs b/projecttemplates/RelyingPartyLogic/OAuthAuthorizationManager.cs index 1a3a0f0..e38d955 100644 --- a/projecttemplates/RelyingPartyLogic/OAuthAuthorizationManager.cs +++ b/projecttemplates/RelyingPartyLogic/OAuthAuthorizationManager.cs @@ -14,6 +14,7 @@ namespace RelyingPartyLogic { using System.ServiceModel.Channels; using System.ServiceModel.Security; using DotNetOpenAuth; + using DotNetOpenAuth.Messaging; using DotNetOpenAuth.OAuth; using DotNetOpenAuth.OAuth2; @@ -37,33 +38,33 @@ namespace RelyingPartyLogic { var resourceServer = new ResourceServer(tokenAnalyzer); try { - IPrincipal principal; - var errorResponse = resourceServer.VerifyAccess(httpDetails, requestUri, out principal); - if (errorResponse == null) { - var policy = new OAuthPrincipalAuthorizationPolicy(principal); - var policies = new List<IAuthorizationPolicy> { + IPrincipal principal = resourceServer.GetPrincipal(httpDetails, requestUri); + var policy = new OAuthPrincipalAuthorizationPolicy(principal); + var policies = new List<IAuthorizationPolicy> { policy, }; - var securityContext = new ServiceSecurityContext(policies.AsReadOnly()); - if (operationContext.IncomingMessageProperties.Security != null) { - operationContext.IncomingMessageProperties.Security.ServiceSecurityContext = securityContext; - } else { - operationContext.IncomingMessageProperties.Security = new SecurityMessageProperty { - ServiceSecurityContext = securityContext, - }; - } + var securityContext = new ServiceSecurityContext(policies.AsReadOnly()); + if (operationContext.IncomingMessageProperties.Security != null) { + operationContext.IncomingMessageProperties.Security.ServiceSecurityContext = securityContext; + } else { + operationContext.IncomingMessageProperties.Security = new SecurityMessageProperty { + ServiceSecurityContext = securityContext, + }; + } - securityContext.AuthorizationContext.Properties["Identities"] = new List<IIdentity> { + securityContext.AuthorizationContext.Properties["Identities"] = new List<IIdentity> { principal.Identity, }; - // Only allow this method call if the access token scope permits it. - if (principal.IsInRole(operationContext.IncomingMessageHeaders.Action)) { - return true; - } + // Only allow this method call if the access token scope permits it. + if (principal.IsInRole(operationContext.IncomingMessageHeaders.Action)) { + return true; } - } catch (ProtocolException /*ex*/) { + } catch (ProtocolFaultResponseException ex) { + // Return the appropriate unauthorized response to the client. + ex.ErrorResponse.Send(); + } catch (DotNetOpenAuth.Messaging.ProtocolException/* ex*/) { ////Logger.Error("Error processing OAuth messages.", ex); } } diff --git a/samples/OAuthResourceServer/Code/OAuthAuthorizationManager.cs b/samples/OAuthResourceServer/Code/OAuthAuthorizationManager.cs index 8d0c13d..353e838 100644 --- a/samples/OAuthResourceServer/Code/OAuthAuthorizationManager.cs +++ b/samples/OAuthResourceServer/Code/OAuthAuthorizationManager.cs @@ -54,6 +54,11 @@ } else { return false; } + } catch (ProtocolFaultResponseException ex) { + Global.Logger.Error("Error processing OAuth messages.", ex); + + // Return the appropriate unauthorized response to the client. + ex.ErrorResponse.Send(); } catch (ProtocolException ex) { Global.Logger.Error("Error processing OAuth messages.", ex); } @@ -67,12 +72,7 @@ using (var signing = Global.CreateAuthorizationServerSigningServiceProvider()) { using (var encrypting = Global.CreateResourceServerEncryptionServiceProvider()) { var resourceServer = new ResourceServer(new StandardAccessTokenAnalyzer(signing, encrypting)); - - IPrincipal result; - var error = resourceServer.VerifyAccess(HttpRequestInfo.Create(httpDetails, requestUri), out result); - - // TODO: return the prepared error code. - return error != null ? null : result; + return resourceServer.GetPrincipal(httpDetails, requestUri); } } } diff --git a/src/DotNetOpenAuth.Core/DotNetOpenAuth.Core.csproj b/src/DotNetOpenAuth.Core/DotNetOpenAuth.Core.csproj index 65dee44..c0aef30 100644 --- a/src/DotNetOpenAuth.Core/DotNetOpenAuth.Core.csproj +++ b/src/DotNetOpenAuth.Core/DotNetOpenAuth.Core.csproj @@ -55,6 +55,7 @@ <Compile Include="Messaging\MultipartPostPart.cs" /> <Compile Include="Messaging\NetworkDirectWebResponse.cs" /> <Compile Include="Messaging\OutgoingWebResponseActionResult.cs" /> + <Compile Include="Messaging\ProtocolFaultResponseException.cs" /> <Compile Include="Messaging\Reflection\IMessagePartEncoder.cs" /> <Compile Include="Messaging\Reflection\IMessagePartNullEncoder.cs" /> <Compile Include="Messaging\Reflection\IMessagePartOriginalEncoder.cs" /> diff --git a/src/DotNetOpenAuth.Core/Messaging/ProtocolException.cs b/src/DotNetOpenAuth.Core/Messaging/ProtocolException.cs index e26d15e..982e1c0 100644 --- a/src/DotNetOpenAuth.Core/Messaging/ProtocolException.cs +++ b/src/DotNetOpenAuth.Core/Messaging/ProtocolException.cs @@ -42,10 +42,10 @@ namespace DotNetOpenAuth.Messaging { /// such that it can be sent as a protocol message response to a remote caller. /// </summary> /// <param name="message">The human-readable exception message.</param> - /// <param name="faultedMessage">The message that was the cause of the exception. Must not be null.</param> - protected internal ProtocolException(string message, IProtocolMessage faultedMessage) - : base(message) { - Requires.NotNull(faultedMessage, "faultedMessage"); + /// <param name="faultedMessage">The message that was the cause of the exception. May be null.</param> + /// <param name="innerException">The inner exception to include.</param> + protected internal ProtocolException(string message, IProtocolMessage faultedMessage, Exception innerException = null) + : base(message, innerException) { this.FaultedMessage = faultedMessage; } diff --git a/src/DotNetOpenAuth.Core/Messaging/ProtocolFaultResponseException.cs b/src/DotNetOpenAuth.Core/Messaging/ProtocolFaultResponseException.cs new file mode 100644 index 0000000..515414b --- /dev/null +++ b/src/DotNetOpenAuth.Core/Messaging/ProtocolFaultResponseException.cs @@ -0,0 +1,78 @@ +//----------------------------------------------------------------------- +// <copyright file="ProtocolFaultResponseException.cs" company="Andrew Arnott"> +// Copyright (c) Andrew Arnott. All rights reserved. +// </copyright> +//----------------------------------------------------------------------- + +namespace DotNetOpenAuth.Messaging { + using System; + using System.Collections.Generic; + using System.Linq; + using System.Text; + + /// <summary> + /// An exception to represent errors in the local or remote implementation of the protocol + /// that includes the response message that should be returned to the HTTP client to comply + /// with the protocol specification. + /// </summary> + public class ProtocolFaultResponseException : ProtocolException { + /// <summary> + /// The channel that produced the error response message, to be used in constructing the actual HTTP response. + /// </summary> + private readonly Channel channel; + + /// <summary> + /// A cached value for the <see cref="ErrorResponse"/> property. + /// </summary> + private OutgoingWebResponse response; + + /// <summary> + /// Initializes a new instance of the <see cref="ProtocolFaultResponseException"/> class + /// such that it can be sent as a protocol message response to a remote caller. + /// </summary> + /// <param name="channel">The channel to use when encoding the response message.</param> + /// <param name="errorResponse">The message to send back to the HTTP client.</param> + /// <param name="faultedMessage">The message that was the cause of the exception. May be null.</param> + /// <param name="innerException">The inner exception.</param> + /// <param name="message">The message for the exception.</param> + protected internal ProtocolFaultResponseException(Channel channel, IDirectResponseProtocolMessage errorResponse, IProtocolMessage faultedMessage = null, Exception innerException = null, string message = null) + : base(message ?? (innerException != null ? innerException.Message : null), faultedMessage, innerException) { + Requires.NotNull(channel, "channel"); + Requires.NotNull(errorResponse, "errorResponse"); + this.channel = channel; + this.ErrorResponseMessage = errorResponse; + } + + /// <summary> + /// Initializes a new instance of the <see cref="ProtocolFaultResponseException"/> class. + /// </summary> + /// <param name="info">The <see cref="System.Runtime.Serialization.SerializationInfo"/> + /// that holds the serialized object data about the exception being thrown.</param> + /// <param name="context">The System.Runtime.Serialization.StreamingContext + /// that contains contextual information about the source or destination.</param> + protected ProtocolFaultResponseException( + System.Runtime.Serialization.SerializationInfo info, + System.Runtime.Serialization.StreamingContext context) + : base(info, context) { + throw new NotImplementedException(); + } + + /// <summary> + /// Gets the protocol message to send back to the client to report the error. + /// </summary> + public IDirectResponseProtocolMessage ErrorResponseMessage { get; private set; } + + /// <summary> + /// Gets the HTTP response to forward to the client to report the error. + /// </summary> + public OutgoingWebResponse ErrorResponse { + get { + if (this.response == null) { + this.response = this.channel.PrepareResponse(this.ErrorResponseMessage); + } + + return this.response; + } + } + } +} diff --git a/src/DotNetOpenAuth.OAuth2.ResourceServer/OAuth2/ResourceServer.cs b/src/DotNetOpenAuth.OAuth2.ResourceServer/OAuth2/ResourceServer.cs index 9540d10..ba332fe 100644 --- a/src/DotNetOpenAuth.OAuth2.ResourceServer/OAuth2/ResourceServer.cs +++ b/src/DotNetOpenAuth.OAuth2.ResourceServer/OAuth2/ResourceServer.cs @@ -65,27 +65,20 @@ namespace DotNetOpenAuth.OAuth2 { /// <summary> /// Discovers what access the client should have considering the access token in the current request. /// </summary> - /// <param name="accessToken">Receives the access token describing the authorization the client has.</param> - /// <returns>An error to return to the client if access is not authorized; <c>null</c> if access is granted.</returns> - [SuppressMessage("Microsoft.Design", "CA1021:AvoidOutParameters", MessageId = "0#", Justification = "Try pattern")] - [SuppressMessage("Microsoft.Design", "CA1021:AvoidOutParameters", MessageId = "1#", Justification = "Try pattern")] - public OutgoingWebResponse VerifyAccess(out AccessToken accessToken) { - return this.VerifyAccess(this.Channel.GetRequestFromContext(), out accessToken); - } - - /// <summary> - /// Discovers what access the client should have considering the access token in the current request. - /// </summary> /// <param name="httpRequestInfo">The HTTP request info.</param> - /// <param name="accessToken">Receives the access token describing the authorization the client has.</param> /// <returns> - /// An error to return to the client if access is not authorized; <c>null</c> if access is granted. + /// The access token describing the authorization the client has. Never <c>null</c>. /// </returns> - [SuppressMessage("Microsoft.Design", "CA1021:AvoidOutParameters", MessageId = "1#", Justification = "Try pattern")] - [SuppressMessage("Microsoft.Design", "CA1021:AvoidOutParameters", MessageId = "2#", Justification = "Try pattern")] - public virtual OutgoingWebResponse VerifyAccess(HttpRequestBase httpRequestInfo, out AccessToken accessToken) { - Requires.NotNull(httpRequestInfo, "httpRequestInfo"); + /// <exception cref="ProtocolFaultResponseException"> + /// Thrown when the client is not authorized. This exception should be caught and the + /// <see cref="ProtocolFaultResponseException.ErrorResponse"/> message should be returned to the client. + /// </exception> + public virtual AccessToken GetAccessToken(HttpRequestBase httpRequestInfo = null) { + if (httpRequestInfo == null) { + httpRequestInfo = this.Channel.GetRequestFromContext(); + } + AccessToken accessToken; AccessProtectedResourceRequest request = null; try { if (this.Channel.TryReadFromRequest<AccessProtectedResourceRequest>(httpRequestInfo, out request)) { @@ -96,18 +89,15 @@ namespace DotNetOpenAuth.OAuth2 { ErrorUtilities.ThrowProtocol(ResourceServerStrings.InvalidAccessToken); } - return null; + return accessToken; } else { - var response = new UnauthorizedResponse(new ProtocolException(ResourceServerStrings.MissingAccessToken)); - - accessToken = null; - return this.Channel.PrepareResponse(response); + var ex = new ProtocolException(ResourceServerStrings.MissingAccessToken); + var response = new UnauthorizedResponse(ex); + throw new ProtocolFaultResponseException(this.Channel, response, innerException: ex); } } catch (ProtocolException ex) { var response = request != null ? new UnauthorizedResponse(request, ex) : new UnauthorizedResponse(ex); - - accessToken = null; - return this.Channel.PrepareResponse(response); + throw new ProtocolFaultResponseException(this.Channel, response, innerException: ex); } } @@ -115,30 +105,29 @@ namespace DotNetOpenAuth.OAuth2 { /// Discovers what access the client should have considering the access token in the current request. /// </summary> /// <param name="httpRequestInfo">The HTTP request info.</param> - /// <param name="principal">The principal that contains the user and roles that the access token is authorized for.</param> /// <returns> - /// An error to return to the client if access is not authorized; <c>null</c> if access is granted. + /// The principal that contains the user and roles that the access token is authorized for. Never <c>null</c>. /// </returns> + /// <exception cref="ProtocolFaultResponseException"> + /// Thrown when the client is not authorized. This exception should be caught and the + /// <see cref="ProtocolFaultResponseException.ErrorResponse"/> message should be returned to the client. + /// </exception> [SuppressMessage("Microsoft.Design", "CA1021:AvoidOutParameters", MessageId = "1#", Justification = "Try pattern")] - public virtual OutgoingWebResponse VerifyAccess(HttpRequestBase httpRequestInfo, out IPrincipal principal) { - AccessToken accessToken; - var result = this.VerifyAccess(httpRequestInfo, out accessToken); - if (result == null) { - // Mitigates attacks on this approach of differentiating clients from resource owners - // by checking that a username doesn't look suspiciously engineered to appear like the other type. - ErrorUtilities.VerifyProtocol(accessToken.User == null || string.IsNullOrEmpty(this.ClientPrincipalPrefix) || !accessToken.User.StartsWith(this.ClientPrincipalPrefix, StringComparison.OrdinalIgnoreCase), ResourceServerStrings.ResourceOwnerNameLooksLikeClientIdentifier); - ErrorUtilities.VerifyProtocol(accessToken.ClientIdentifier == null || string.IsNullOrEmpty(this.ResourceOwnerPrincipalPrefix) || !accessToken.ClientIdentifier.StartsWith(this.ResourceOwnerPrincipalPrefix, StringComparison.OrdinalIgnoreCase), ResourceServerStrings.ClientIdentifierLooksLikeResourceOwnerName); - - string principalUserName = !string.IsNullOrEmpty(accessToken.User) - ? this.ResourceOwnerPrincipalPrefix + accessToken.User - : this.ClientPrincipalPrefix + accessToken.ClientIdentifier; - string[] principalScope = accessToken.Scope != null ? accessToken.Scope.ToArray() : new string[0]; - principal = new OAuthPrincipal(principalUserName, principalScope); - } else { - principal = null; - } + public virtual IPrincipal GetPrincipal(HttpRequestBase httpRequestInfo = null) { + AccessToken accessToken = this.GetAccessToken(httpRequestInfo); + + // Mitigates attacks on this approach of differentiating clients from resource owners + // by checking that a username doesn't look suspiciously engineered to appear like the other type. + ErrorUtilities.VerifyProtocol(accessToken.User == null || string.IsNullOrEmpty(this.ClientPrincipalPrefix) || !accessToken.User.StartsWith(this.ClientPrincipalPrefix, StringComparison.OrdinalIgnoreCase), ResourceServerStrings.ResourceOwnerNameLooksLikeClientIdentifier); + ErrorUtilities.VerifyProtocol(accessToken.ClientIdentifier == null || string.IsNullOrEmpty(this.ResourceOwnerPrincipalPrefix) || !accessToken.ClientIdentifier.StartsWith(this.ResourceOwnerPrincipalPrefix, StringComparison.OrdinalIgnoreCase), ResourceServerStrings.ClientIdentifierLooksLikeResourceOwnerName); + + string principalUserName = !string.IsNullOrEmpty(accessToken.User) + ? this.ResourceOwnerPrincipalPrefix + accessToken.User + : this.ClientPrincipalPrefix + accessToken.ClientIdentifier; + string[] principalScope = accessToken.Scope != null ? accessToken.Scope.ToArray() : new string[0]; + var principal = new OAuthPrincipal(principalUserName, principalScope); - return result; + return principal; } /// <summary> @@ -146,17 +135,20 @@ namespace DotNetOpenAuth.OAuth2 { /// </summary> /// <param name="request">HTTP details from an incoming WCF message.</param> /// <param name="requestUri">The URI of the WCF service endpoint.</param> - /// <param name="principal">The principal that contains the user and roles that the access token is authorized for.</param> /// <returns> - /// An error to return to the client if access is not authorized; <c>null</c> if access is granted. + /// The principal that contains the user and roles that the access token is authorized for. Never <c>null</c>. /// </returns> + /// <exception cref="ProtocolFaultResponseException"> + /// Thrown when the client is not authorized. This exception should be caught and the + /// <see cref="ProtocolFaultResponseException.ErrorResponse"/> message should be returned to the client. + /// </exception> [SuppressMessage("Microsoft.Design", "CA1021:AvoidOutParameters", MessageId = "1#", Justification = "Try pattern")] [SuppressMessage("Microsoft.Design", "CA1021:AvoidOutParameters", MessageId = "2#", Justification = "Try pattern")] - public virtual OutgoingWebResponse VerifyAccess(HttpRequestMessageProperty request, Uri requestUri, out IPrincipal principal) { + public virtual IPrincipal GetPrincipal(HttpRequestMessageProperty request, Uri requestUri) { Requires.NotNull(request, "request"); Requires.NotNull(requestUri, "requestUri"); - return this.VerifyAccess(new HttpRequestInfo(request, requestUri), out principal); + return this.GetPrincipal(new HttpRequestInfo(request, requestUri)); } } } |