diff options
Diffstat (limited to 'src/DotNetOpenAuth.OAuth/OAuth/ChannelElements/OAuthChannel.cs')
-rw-r--r-- | src/DotNetOpenAuth.OAuth/OAuth/ChannelElements/OAuthChannel.cs | 96 |
1 files changed, 46 insertions, 50 deletions
diff --git a/src/DotNetOpenAuth.OAuth/OAuth/ChannelElements/OAuthChannel.cs b/src/DotNetOpenAuth.OAuth/OAuth/ChannelElements/OAuthChannel.cs index 1362ca9..f483ade 100644 --- a/src/DotNetOpenAuth.OAuth/OAuth/ChannelElements/OAuthChannel.cs +++ b/src/DotNetOpenAuth.OAuth/OAuth/ChannelElements/OAuthChannel.cs @@ -13,8 +13,12 @@ namespace DotNetOpenAuth.OAuth.ChannelElements { using System.IO; using System.Linq; using System.Net; + using System.Net.Http; + using System.Net.Http.Headers; using System.Net.Mime; using System.Text; + using System.Threading; + using System.Threading.Tasks; using System.Web; using DotNetOpenAuth.Messaging; using DotNetOpenAuth.Messaging.Bindings; @@ -22,12 +26,14 @@ namespace DotNetOpenAuth.OAuth.ChannelElements { using DotNetOpenAuth.OAuth.Messages; using Validation; + using HttpRequestHeaders = DotNetOpenAuth.Messaging.HttpRequestHeaders; + /// <summary> /// An OAuth-specific implementation of the <see cref="Channel"/> class. /// </summary> internal abstract class OAuthChannel : Channel { /// <summary> - /// Initializes a new instance of the <see cref="OAuthChannel"/> class. + /// Initializes a new instance of the <see cref="OAuthChannel" /> class. /// </summary> /// <param name="signingBindingElement">The binding element to use for signing.</param> /// <param name="tokenManager">The ITokenManager instance to use.</param> @@ -36,9 +42,10 @@ namespace DotNetOpenAuth.OAuth.ChannelElements { /// Except for mock testing, this should always be one of /// OAuthConsumerMessageFactory or OAuthServiceProviderMessageFactory.</param> /// <param name="bindingElements">The binding elements.</param> + /// <param name="hostFactories">The host factories.</param> [SuppressMessage("Microsoft.Globalization", "CA1303:Do not pass literals as localized parameters", MessageId = "System.Diagnostics.Contracts.__ContractsRuntime.Requires<System.ArgumentNullException>(System.Boolean,System.String,System.String)", Justification = "Code contracts"), SuppressMessage("Microsoft.Naming", "CA2204:Literals should be spelled correctly", MessageId = "securitySettings", Justification = "Code contracts")] - protected OAuthChannel(ITamperProtectionChannelBindingElement signingBindingElement, ITokenManager tokenManager, SecuritySettings securitySettings, IMessageFactory messageTypeProvider, IChannelBindingElement[] bindingElements) - : base(messageTypeProvider, bindingElements) { + protected OAuthChannel(ITamperProtectionChannelBindingElement signingBindingElement, ITokenManager tokenManager, SecuritySettings securitySettings, IMessageFactory messageTypeProvider, IChannelBindingElement[] bindingElements, IHostFactories hostFactories = null) + : base(messageTypeProvider, bindingElements, hostFactories ?? new DefaultOAuthHostFactories()) { Requires.NotNull(tokenManager, "tokenManager"); Requires.NotNull(securitySettings, "securitySettings"); Requires.NotNull(signingBindingElement, "signingBindingElement"); @@ -76,11 +83,12 @@ namespace DotNetOpenAuth.OAuth.ChannelElements { /// expect an OAuth message response to. /// </summary> /// <param name="request">The message to attach.</param> + /// <param name="cancellationToken">The cancellation token.</param> /// <returns>The initialized web request.</returns> - internal HttpWebRequest InitializeRequest(IDirectedProtocolMessage request) { + internal async Task<HttpRequestMessage> InitializeRequestAsync(IDirectedProtocolMessage request, CancellationToken cancellationToken) { Requires.NotNull(request, "request"); - ProcessOutgoingMessage(request); + await this.ProcessOutgoingMessageAsync(request, cancellationToken); return this.CreateHttpRequest(request); } @@ -108,29 +116,25 @@ namespace DotNetOpenAuth.OAuth.ChannelElements { /// a protocol request message. /// </summary> /// <param name="request">The HTTP request to search.</param> + /// <param name="cancellationToken">The cancellation token.</param> /// <returns>The deserialized message, if one is found. Null otherwise.</returns> - protected override IDirectedProtocolMessage ReadFromRequestCore(HttpRequestBase request) { + protected override async Task<IDirectedProtocolMessage> ReadFromRequestCoreAsync(HttpRequestMessage request, CancellationToken cancellationToken) { // First search the Authorization header. - string authorization = request.Headers[HttpRequestHeaders.Authorization]; + var authorization = request.Headers.Authorization; var fields = MessagingUtilities.ParseAuthorizationHeader(Protocol.AuthorizationHeaderScheme, authorization).ToDictionary(); fields.Remove("realm"); // ignore the realm parameter, since we don't use it, and it must be omitted from signature base string. // Scrape the entity - if (!string.IsNullOrEmpty(request.Headers[HttpRequestHeaders.ContentType])) { - var contentType = new ContentType(request.Headers[HttpRequestHeaders.ContentType]); - if (string.Equals(contentType.MediaType, HttpFormUrlEncoded, StringComparison.Ordinal)) { - foreach (string key in request.Form) { - if (key != null) { - fields.Add(key, request.Form[key]); - } else { - Logger.OAuth.WarnFormat("Ignoring query string parameter '{0}' since it isn't a standard name=value parameter.", request.Form[key]); - } - } + foreach (var pair in await ParseUrlEncodedFormContentAsync(request, cancellationToken)) { + if (pair.Key != null) { + fields.Add(pair.Key, pair.Value); + } else { + Logger.OAuth.WarnFormat("Ignoring query string parameter '{0}' since it isn't a standard name=value parameter.", pair.Value); } } // Scrape the query string - var qs = request.GetQueryStringBeforeRewriting(); + var qs = HttpUtility.ParseQueryString(request.RequestUri.Query); foreach (string key in qs) { if (key != null) { fields.Add(key, qs[key]); @@ -153,8 +157,8 @@ namespace DotNetOpenAuth.OAuth.ChannelElements { // Add receiving HTTP transport information required for signature generation. var signedMessage = message as ITamperResistantOAuthMessage; if (signedMessage != null) { - signedMessage.Recipient = request.GetPublicFacingUrl(); - signedMessage.HttpMethod = request.HttpMethod; + signedMessage.Recipient = request.RequestUri; + signedMessage.HttpMethod = request.Method; } return message; @@ -167,8 +171,8 @@ namespace DotNetOpenAuth.OAuth.ChannelElements { /// <returns> /// The deserialized message parts, if found. Null otherwise. /// </returns> - protected override IDictionary<string, string> ReadFromResponseCore(IncomingWebResponse response) { - string body = response.GetResponseReader().ReadToEnd(); + protected override async Task<IDictionary<string, string>> ReadFromResponseCoreAsync(HttpResponseMessage response, CancellationToken cancellationToken) { + string body = await response.Content.ReadAsStringAsync(); return HttpUtility.ParseQueryString(body).ToDictionary(); } @@ -179,8 +183,8 @@ namespace DotNetOpenAuth.OAuth.ChannelElements { /// <returns> /// The <see cref="HttpRequest"/> prepared to send the request. /// </returns> - protected override HttpWebRequest CreateHttpRequest(IDirectedProtocolMessage request) { - HttpWebRequest httpRequest; + protected override HttpRequestMessage CreateHttpRequest(IDirectedProtocolMessage request) { + HttpRequestMessage httpRequest; HttpDeliveryMethods transmissionMethod = request.HttpMethods; if ((transmissionMethod & HttpDeliveryMethods.AuthorizationHeaderRequest) != 0) { @@ -200,6 +204,7 @@ namespace DotNetOpenAuth.OAuth.ChannelElements { } else { throw new NotSupportedException(); } + return httpRequest; } @@ -212,16 +217,12 @@ namespace DotNetOpenAuth.OAuth.ChannelElements { /// <remarks> /// This method implements spec V1.0 section 5.3. /// </remarks> - protected override OutgoingWebResponse PrepareDirectResponse(IProtocolMessage response) { + protected override HttpResponseMessage PrepareDirectResponse(IProtocolMessage response) { var messageAccessor = this.MessageDescriptions.GetAccessor(response); var fields = messageAccessor.Serialize(); - string responseBody = MessagingUtilities.CreateQueryString(fields); - OutgoingWebResponse encodedResponse = new OutgoingWebResponse { - Body = responseBody, - OriginalMessage = response, - Status = HttpStatusCode.OK, - Headers = new System.Net.WebHeaderCollection(), + var encodedResponse = new HttpResponseMessage { + Content = new FormUrlEncodedContent(fields), }; ApplyMessageTemplate(response, encodedResponse); @@ -256,7 +257,7 @@ namespace DotNetOpenAuth.OAuth.ChannelElements { /// </summary> /// <param name="message">The message.</param> /// <returns>"POST", "GET" or some other similar http verb.</returns> - private static string GetHttpMethod(IDirectedProtocolMessage message) { + private static HttpMethod GetHttpMethod(IDirectedProtocolMessage message) { Requires.NotNull(message, "message"); var signedMessage = message as ITamperResistantOAuthMessage; @@ -273,11 +274,9 @@ namespace DotNetOpenAuth.OAuth.ChannelElements { /// <param name="requestMessage">The message to be transmitted to the ServiceProvider.</param> /// <returns>The web request ready to send.</returns> /// <remarks> - /// <para>If the message has non-empty ExtraData in it, the request stream is sent to - /// the server automatically. If it is empty, the request stream must be sent by the caller.</para> - /// <para>This method implements OAuth 1.0 section 5.2, item #1 (described in section 5.4).</para> + /// This method implements OAuth 1.0 section 5.2, item #1 (described in section 5.4). /// </remarks> - private HttpWebRequest InitializeRequestAsAuthHeader(IDirectedProtocolMessage requestMessage) { + private HttpRequestMessage InitializeRequestAsAuthHeader(IDirectedProtocolMessage requestMessage) { var dictionary = this.MessageDescriptions.GetAccessor(requestMessage); // copy so as to not modify original @@ -285,41 +284,38 @@ namespace DotNetOpenAuth.OAuth.ChannelElements { foreach (string key in dictionary.DeclaredKeys) { fields.Add(key, dictionary[key]); } + if (this.Realm != null) { fields.Add("realm", this.Realm.AbsoluteUri); } - HttpWebRequest httpRequest; UriBuilder recipientBuilder = new UriBuilder(requestMessage.Recipient); bool hasEntity = HttpMethodHasEntity(GetHttpMethod(requestMessage)); if (!hasEntity) { MessagingUtilities.AppendQueryArgs(recipientBuilder, requestMessage.ExtraData); } - httpRequest = (HttpWebRequest)WebRequest.Create(recipientBuilder.Uri); + + var httpRequest = new HttpRequestMessage(GetHttpMethod(requestMessage), recipientBuilder.Uri); this.PrepareHttpWebRequest(httpRequest); - httpRequest.Method = GetHttpMethod(requestMessage); - httpRequest.Headers.Add(HttpRequestHeader.Authorization, MessagingUtilities.AssembleAuthorizationHeader(Protocol.AuthorizationHeaderScheme, fields)); + httpRequest.Headers.Authorization = new AuthenticationHeaderValue(Protocol.AuthorizationHeaderScheme, MessagingUtilities.AssembleAuthorizationHeader(fields)); if (hasEntity) { - // WARNING: We only set up the request stream for the caller if there is - // extra data. If there isn't any extra data, the caller must do this themselves. var requestMessageWithBinaryData = requestMessage as IMessageWithBinaryData; if (requestMessageWithBinaryData != null && requestMessageWithBinaryData.SendAsMultipart) { // Include the binary data in the multipart entity, and any standard text extra message data. // The standard declared message parts are included in the authorization header. - var multiPartFields = new List<MultipartPostPart>(requestMessageWithBinaryData.BinaryData); - multiPartFields.AddRange(requestMessage.ExtraData.Select(field => MultipartPostPart.CreateFormPart(field.Key, field.Value))); - this.SendParametersInEntityAsMultipart(httpRequest, multiPartFields); + var content = InitializeMultipartFormDataContent(requestMessageWithBinaryData); + httpRequest.Content = content; + + foreach (var extraData in requestMessage.ExtraData) { + content.Add(new StringContent(extraData.Value), extraData.Key); + } } else { ErrorUtilities.VerifyProtocol(requestMessageWithBinaryData == null || requestMessageWithBinaryData.BinaryData.Count == 0, MessagingStrings.BinaryDataRequiresMultipart); if (requestMessage.ExtraData.Count > 0) { - this.SendParametersInEntity(httpRequest, requestMessage.ExtraData); - } else { - // We'll assume the content length is zero since the caller may not have - // anything. They're responsible to change it when the add the payload if they have one. - httpRequest.ContentLength = 0; + httpRequest.Content = new FormUrlEncodedContent(requestMessage.ExtraData); } } } |