//-----------------------------------------------------------------------
//
// Copyright (c) Outercurve Foundation. All rights reserved.
//
//-----------------------------------------------------------------------
namespace DotNetOpenAuth.OAuth2.ChannelElements {
using System;
using System.Collections.Generic;
using System.Linq;
using System.Net;
using System.Net.Http;
using System.Net.Http.Headers;
using System.Net.Mime;
using System.Text;
using System.Threading;
using System.Threading.Tasks;
using System.Web;
using DotNetOpenAuth.Messaging;
using DotNetOpenAuth.Messaging.Reflection;
using DotNetOpenAuth.OAuth2.Messages;
using Validation;
using HttpRequestHeaders = DotNetOpenAuth.Messaging.HttpRequestHeaders;
/// <summary>
/// The channel for the OAuth 2.0 protocol as used by a resource server: it reads
/// bearer-access-token requests and emits authorization error responses.
/// </summary>
internal class OAuth2ResourceServerChannel : StandardMessageFactoryChannel {
    /// <summary>
    /// The messages receivable by this channel.
    /// </summary>
    private static readonly Type[] MessageTypes = new Type[] {
        typeof(Messages.AccessProtectedResourceRequest),
    };

    /// <summary>
    /// The protocol versions supported by this channel.
    /// </summary>
    private static readonly Version[] Versions = Protocol.AllVersions.Select(v => v.Version).ToArray();

    /// <summary>
    /// Initializes a new instance of the <see cref="OAuth2ResourceServerChannel"/> class.
    /// </summary>
    /// <param name="hostFactories">
    /// The host factories. When <c>null</c>, a new <c>OAuth.DefaultOAuthHostFactories</c> instance is used.
    /// </param>
    protected internal OAuth2ResourceServerChannel(IHostFactories hostFactories = null)
        : base(MessageTypes, Versions, hostFactories ?? new OAuth.DefaultOAuthHostFactories()) {
        // TODO: add signing (authenticated request) binding element.
    }

    /// <summary>
    /// Gets the protocol message that may be embedded in the given HTTP request.
    /// </summary>
    /// <param name="request">The request to search for an embedded message.</param>
    /// <param name="cancellationToken">The cancellation token. Not observed; the lookup is synchronous.</param>
    /// <returns>
    /// The deserialized message, if a bearer access token is found and the request's recipient
    /// can be determined. Null otherwise.
    /// </returns>
    protected override IDirectedProtocolMessage ReadFromRequestCore(HttpRequestBase request, CancellationToken cancellationToken) {
        string accessToken = SearchForBearerAccessTokenInRequest(request);
        if (accessToken == null) {
            // No token anywhere in the request: this is not a protected-resource request we understand.
            return null;
        }

        var fields = new Dictionary<string, string>();
        fields[Protocol.token_type] = Protocol.AccessTokenTypes.Bearer;
        fields[Protocol.access_token] = accessToken;

        MessageReceivingEndpoint recipient;
        try {
            recipient = request.GetRecipient();
        } catch (ArgumentException ex) {
            // The exception goes in as a format argument, not concatenated into the format string:
            // braces inside the exception text would otherwise make WarnFormat itself throw.
            Logger.OAuth.WarnFormat("Unrecognized HTTP request: {0}", ex);
            return null;
        }

        // Deserialize the message using all the data we've collected.
        return (IDirectedProtocolMessage)this.Receive(fields, recipient);
    }

    /// <summary>
    /// Gets the protocol message that may be in the given HTTP response.
    /// </summary>
    /// <param name="response">The response that is anticipated to contain a protocol message.</param>
    /// <returns>Never returns normally; see the exception documentation.</returns>
    /// <exception cref="NotImplementedException">
    /// Always thrown. Resource servers never send direct requests, and therefore never receive direct responses.
    /// </exception>
    protected override Task<IDictionary<string, string>> ReadFromResponseCoreAsync(HttpResponseMessage response) {
        // We never expect resource servers to send out direct requests,
        // and therefore won't have direct responses.
        throw new NotImplementedException();
    }

    /// <summary>
    /// Prepares an HTTP response carrying the given authorization error message, with the
    /// message parts serialized into the WWW-Authenticate header.
    /// </summary>
    /// <param name="response">The message to send as a response. Must be an <see cref="UnauthorizedResponse"/>.</param>
    /// <returns>
    /// The HTTP response, whose WWW-Authenticate header uses the message's scheme and carries its serialized parts.
    /// </returns>
    /// <remarks>
    /// Messages that implement <see cref="IHttpDirectResponse"/> keep whatever status code
    /// <c>ApplyMessageTemplate</c> assigns; all other messages get 401 Unauthorized.
    /// See RFC 6750 section 3 for the WWW-Authenticate header format used with bearer tokens.
    /// </remarks>
    protected override HttpResponseMessage PrepareDirectResponse(IProtocolMessage response) {
        var webResponse = new HttpResponseMessage();

        // The only direct response from a resource server is some authorization error (400, 401, 403).
        var unauthorizedResponse = response as UnauthorizedResponse;
        ErrorUtilities.VerifyInternal(unauthorizedResponse != null, "Only unauthorized responses are expected.");

        // First initialize based on the specifics within the message.
        ApplyMessageTemplate(response, webResponse);
        if (!(response is IHttpDirectResponse)) {
            webResponse.StatusCode = HttpStatusCode.Unauthorized;
        }

        // Now serialize all the message parts into the WWW-Authenticate header.
        var fields = this.MessageDescriptions.GetAccessor(response);
        webResponse.Headers.WwwAuthenticate.Add(new AuthenticationHeaderValue(unauthorizedResponse.Scheme, MessagingUtilities.AssembleAuthorizationHeader(fields)));
        return webResponse;
    }

    /// <summary>
    /// Searches for a bearer access token in the request.
    /// </summary>
    /// <param name="request">The request.</param>
    /// <returns>The bearer access token, if one exists. Otherwise null.</returns>
    /// <remarks>
    /// Locations are checked in order of preference: the Authorization header, then a
    /// form-urlencoded body, then the (unrewritten) query string.
    /// </remarks>
    private static string SearchForBearerAccessTokenInRequest(HttpRequestBase request) {
        Requires.NotNull(request, "request");

        // First search the authorization header.
        string authorizationHeader = request.Headers[HttpRequestHeaders.Authorization];
        if (!string.IsNullOrEmpty(authorizationHeader) && authorizationHeader.StartsWith(Protocol.BearerHttpAuthorizationSchemeWithTrailingSpace, StringComparison.OrdinalIgnoreCase)) {
            // The credentials grammar allows one or more spaces after the scheme name, so trim any extras.
            return authorizationHeader.Substring(Protocol.BearerHttpAuthorizationSchemeWithTrailingSpace.Length).Trim();
        }

        // Failing that, scan the entity.
        string contentTypeHeader = request.Headers[HttpRequestHeaders.ContentType];
        if (!string.IsNullOrEmpty(contentTypeHeader)) {
            ContentType contentType = null;
            try {
                contentType = new ContentType(contentTypeHeader);
            } catch (FormatException) {
                // The header is client-supplied; a malformed value just means "no usable form body".
                Logger.OAuth.WarnFormat("Ignoring malformed Content-Type header: {0}", contentTypeHeader);
            }

            // Media types are case-insensitive, so compare ignoring case.
            if (contentType != null && string.Equals(contentType.MediaType, HttpFormUrlEncoded, StringComparison.OrdinalIgnoreCase)) {
                string formToken = request.Form[Protocol.BearerTokenEncodedUrlParameterName];

                // Treat an empty value as absent, consistent with the query string check below.
                if (!string.IsNullOrEmpty(formToken)) {
                    return formToken;
                }
            }
        }

        // Finally, check the least desirable location: the query string.
        var unrewrittenQuery = request.GetQueryStringBeforeRewriting();
        string queryToken = unrewrittenQuery[Protocol.BearerTokenEncodedUrlParameterName];
        return string.IsNullOrEmpty(queryToken) ? null : queryToken;
    }
}
}