diff options
author | Andrew Arnott <andrewarnott@gmail.com> | 2012-03-05 17:38:00 -0800 |
---|---|---|
committer | Andrew Arnott <andrewarnott@gmail.com> | 2012-03-05 17:38:00 -0800 |
commit | 9a3885e6992462122057f532b7cbcda3695ca6bd (patch) | |
tree | 679678fe05814b95e8aaf8e3ca441c3410f1c8c5 | |
parent | a292822196d0911a68fc56597ed52a8c84a41cbe (diff) | |
download | DotNetOpenAuth-9a3885e6992462122057f532b7cbcda3695ca6bd.zip DotNetOpenAuth-9a3885e6992462122057f532b7cbcda3695ca6bd.tar.gz DotNetOpenAuth-9a3885e6992462122057f532b7cbcda3695ca6bd.tar.bz2 |
Replaced API requirements for HttpRequestInfo with HttpRequestBase (new in .NET 3.5 SP1).
This makes us more friendly to MVC as well as mock-based unit testing.
54 files changed, 447 insertions, 626 deletions
diff --git a/projecttemplates/MvcRelyingParty/Code/OpenIdRelyingPartyService.cs b/projecttemplates/MvcRelyingParty/Code/OpenIdRelyingPartyService.cs index 7931200..30f3fae 100644 --- a/projecttemplates/MvcRelyingParty/Code/OpenIdRelyingPartyService.cs +++ b/projecttemplates/MvcRelyingParty/Code/OpenIdRelyingPartyService.cs @@ -24,7 +24,7 @@ IAuthenticationResponse GetResponse(); - IAuthenticationResponse GetResponse(HttpRequestInfo request); + IAuthenticationResponse GetResponse(HttpRequestBase request); } /// <summary> @@ -101,7 +101,7 @@ return relyingParty.GetResponse(); } - public IAuthenticationResponse GetResponse(HttpRequestInfo request) { + public IAuthenticationResponse GetResponse(HttpRequestBase request) { return relyingParty.GetResponse(request); } diff --git a/projecttemplates/MvcRelyingParty/Controllers/AuthController.cs b/projecttemplates/MvcRelyingParty/Controllers/AuthController.cs index 9cc6e15..446c6ac 100644 --- a/projecttemplates/MvcRelyingParty/Controllers/AuthController.cs +++ b/projecttemplates/MvcRelyingParty/Controllers/AuthController.cs @@ -121,14 +121,9 @@ namespace MvcRelyingParty.Controllers { public ActionResult LogOnPostAssertion(string openid_openidAuthData) { IAuthenticationResponse response; if (!string.IsNullOrEmpty(openid_openidAuthData)) { - var auth = new Uri(openid_openidAuthData); - var headers = new WebHeaderCollection(); - foreach (string header in Request.Headers) { - headers[header] = Request.Headers[header]; - } - // Always say it's a GET since the payload is all in the URL, even the large ones. - HttpRequestInfo clientResponseInfo = new HttpRequestInfo("GET", auth, auth.PathAndQuery, headers, null); + var auth = new Uri(openid_openidAuthData); + HttpRequestBase clientResponseInfo = new HttpRequestInfo("GET", auth, headers: Request.Headers); response = this.RelyingParty.GetResponse(clientResponseInfo); } else { response = this.RelyingParty.GetResponse(); @@ -170,7 +165,7 @@ namespace MvcRelyingParty.Controllers { } // Always say it's a GET since the payload is all in the URL, even the large ones. - HttpRequestInfo clientResponseInfo = new HttpRequestInfo("GET", auth, auth.PathAndQuery, headers, null); + HttpRequestBase clientResponseInfo = new HttpRequestInfo("GET", auth, headers: headers); response = this.RelyingParty.GetResponse(clientResponseInfo); } else { response = this.RelyingParty.GetResponse(); diff --git a/projecttemplates/MvcRelyingParty/OAuthTokenEndpoint.ashx.cs b/projecttemplates/MvcRelyingParty/OAuthTokenEndpoint.ashx.cs index 4c741ae..2c655e0 100644 --- a/projecttemplates/MvcRelyingParty/OAuthTokenEndpoint.ashx.cs +++ b/projecttemplates/MvcRelyingParty/OAuthTokenEndpoint.ashx.cs @@ -41,7 +41,7 @@ namespace MvcRelyingParty { public void ProcessRequest(HttpContext context) { var serviceProvider = OAuthServiceProvider.AuthorizationServer; IDirectResponseProtocolMessage response; - if (serviceProvider.TryPrepareAccessTokenResponse(new HttpRequestInfo(context.Request), out response)) { + if (serviceProvider.TryPrepareAccessTokenResponse(new HttpRequestWrapper(context.Request), out response)) { serviceProvider.Channel.Respond(response); } else { throw new InvalidOperationException(); diff --git a/projecttemplates/RelyingPartyLogic/OAuthAuthenticationModule.cs b/projecttemplates/RelyingPartyLogic/OAuthAuthenticationModule.cs index 581b575..13e725d 100644 --- a/projecttemplates/RelyingPartyLogic/OAuthAuthenticationModule.cs +++ b/projecttemplates/RelyingPartyLogic/OAuthAuthenticationModule.cs @@ -54,7 +54,7 @@ namespace RelyingPartyLogic { var resourceServer = new ResourceServer(tokenAnalyzer); IPrincipal principal; - var errorMessage = resourceServer.VerifyAccess(new HttpRequestInfo(this.application.Context.Request), out principal); + var errorMessage = resourceServer.VerifyAccess(new HttpRequestWrapper(this.application.Context.Request), out principal); if (errorMessage == null) { this.application.Context.User = principal; } diff --git a/projecttemplates/RelyingPartyLogic/RelyingPartyLogic.csproj b/projecttemplates/RelyingPartyLogic/RelyingPartyLogic.csproj index 2880176..58e684e 100644 --- a/projecttemplates/RelyingPartyLogic/RelyingPartyLogic.csproj +++ b/projecttemplates/RelyingPartyLogic/RelyingPartyLogic.csproj @@ -85,6 +85,7 @@ <Reference Include="System.ServiceModel"> <RequiredTargetFramework>3.0</RequiredTargetFramework> </Reference> + <Reference Include="System.Web.Abstractions" /> <Reference Include="System.Web.Entity"> <RequiredTargetFramework>3.5</RequiredTargetFramework> </Reference> diff --git a/projecttemplates/WebFormsRelyingParty/OAuthTokenEndpoint.ashx.cs b/projecttemplates/WebFormsRelyingParty/OAuthTokenEndpoint.ashx.cs index 3402bbe..fd68462 100644 --- a/projecttemplates/WebFormsRelyingParty/OAuthTokenEndpoint.ashx.cs +++ b/projecttemplates/WebFormsRelyingParty/OAuthTokenEndpoint.ashx.cs @@ -41,7 +41,7 @@ namespace WebFormsRelyingParty { public void ProcessRequest(HttpContext context) { var serviceProvider = OAuthServiceProvider.AuthorizationServer; IDirectResponseProtocolMessage response; - if (serviceProvider.TryPrepareAccessTokenResponse(new HttpRequestInfo(context.Request), out response)) { + if (serviceProvider.TryPrepareAccessTokenResponse(new HttpRequestWrapper(context.Request), out response)) { serviceProvider.Channel.Respond(response); } else { throw new InvalidOperationException(); diff --git a/projecttemplates/WebFormsRelyingParty/WebFormsRelyingParty.csproj b/projecttemplates/WebFormsRelyingParty/WebFormsRelyingParty.csproj index 81b2360..1f17837 100644 --- a/projecttemplates/WebFormsRelyingParty/WebFormsRelyingParty.csproj +++ b/projecttemplates/WebFormsRelyingParty/WebFormsRelyingParty.csproj @@ -63,6 +63,7 @@ <Reference Include="System.ServiceModel"> <RequiredTargetFramework>3.0</RequiredTargetFramework> </Reference> + <Reference Include="System.Web.Abstractions" /> <Reference Include="System.Web.DynamicData" /> <Reference Include="System.Web.Entity"> <RequiredTargetFramework>3.5</RequiredTargetFramework> diff --git a/samples/OAuthClient/OAuthClient.csproj b/samples/OAuthClient/OAuthClient.csproj index 2de5915..9aeb0d1 100644 --- a/samples/OAuthClient/OAuthClient.csproj +++ b/samples/OAuthClient/OAuthClient.csproj @@ -46,6 +46,7 @@ <Reference Include="System.Data.DataSetExtensions" /> <Reference Include="System.Runtime.Serialization" /> <Reference Include="System.ServiceModel" /> + <Reference Include="System.Web.Abstractions" /> <Reference Include="System.Web.Extensions" /> <Reference Include="System.Drawing" /> <Reference Include="System.Web" /> diff --git a/samples/OAuthResourceServer/OAuthResourceServer.csproj b/samples/OAuthResourceServer/OAuthResourceServer.csproj index 1d81d85..599727f 100644 --- a/samples/OAuthResourceServer/OAuthResourceServer.csproj +++ b/samples/OAuthResourceServer/OAuthResourceServer.csproj @@ -45,6 +45,7 @@ <Reference Include="System.IdentityModel" /> <Reference Include="System.ServiceModel" /> <Reference Include="System.ServiceModel.Web" /> + <Reference Include="System.Web.Abstractions" /> <Reference Include="System.Web.Extensions" /> <Reference Include="System.Xml.Linq" /> <Reference Include="System.Drawing" /> diff --git a/samples/OpenIdOfflineProvider/HostedProvider.cs b/samples/OpenIdOfflineProvider/HostedProvider.cs index 5e8ef0a..788817d 100644 --- a/samples/OpenIdOfflineProvider/HostedProvider.cs +++ b/samples/OpenIdOfflineProvider/HostedProvider.cs @@ -11,6 +11,8 @@ namespace DotNetOpenAuth.OpenIdOfflineProvider { using System.IO; using System.Linq; using System.Net; + using System.Web; + using DotNetOpenAuth.Messaging; using DotNetOpenAuth.OpenId.Provider; using log4net; @@ -70,7 +72,7 @@ namespace DotNetOpenAuth.OpenIdOfflineProvider { /// <summary> /// Gets or sets the delegate that handles authentication requests. /// </summary> - internal Action<HttpRequestInfo, HttpListenerResponse> ProcessRequest { get; set; } + internal Action<HttpRequestBase, HttpListenerResponse> ProcessRequest { get; set; } /// <summary> /// Gets the provider endpoint. @@ -225,7 +227,7 @@ namespace DotNetOpenAuth.OpenIdOfflineProvider { Uri providerEndpoint = providerEndpointBuilder.Uri; if (context.Request.Url.AbsolutePath == ProviderPath) { - HttpRequestInfo requestInfo = new HttpRequestInfo(context.Request); + HttpRequestBase requestInfo = new HttpRequestInfo(context.Request); this.ProcessRequest(requestInfo, context.Response); } else if (context.Request.Url.AbsolutePath.StartsWith(UserIdentifierPath, StringComparison.Ordinal)) { using (StreamWriter sw = new StreamWriter(outputStream)) { diff --git a/samples/OpenIdOfflineProvider/MainWindow.xaml.cs b/samples/OpenIdOfflineProvider/MainWindow.xaml.cs index 5136c24..6bf7f6a 100644 --- a/samples/OpenIdOfflineProvider/MainWindow.xaml.cs +++ b/samples/OpenIdOfflineProvider/MainWindow.xaml.cs @@ -15,6 +15,7 @@ namespace DotNetOpenAuth.OpenIdOfflineProvider { using System.Net; using System.Runtime.InteropServices; using System.Text; + using System.Web; using System.Windows; using System.Windows.Controls; using System.Windows.Data; @@ -99,7 +100,7 @@ namespace DotNetOpenAuth.OpenIdOfflineProvider { /// </summary> /// <param name="requestInfo">The request info.</param> /// <param name="response">The response.</param> - private void ProcessRequest(HttpRequestInfo requestInfo, HttpListenerResponse response) { + private void ProcessRequest(HttpRequestBase requestInfo, HttpListenerResponse response) { IRequest request = this.hostedProvider.Provider.GetRequest(requestInfo); if (request == null) { App.Logger.Error("A request came in that did not carry an OpenID message."); diff --git a/src/DotNetOpenAuth.AspNet/OpenAuthSecurityManager.cs b/src/DotNetOpenAuth.AspNet/OpenAuthSecurityManager.cs index 6851c6d..01d8c90 100644 --- a/src/DotNetOpenAuth.AspNet/OpenAuthSecurityManager.cs +++ b/src/DotNetOpenAuth.AspNet/OpenAuthSecurityManager.cs @@ -140,7 +140,7 @@ namespace DotNetOpenAuth.AspNet { if (!string.IsNullOrEmpty(returnUrl)) { uri = UriHelper.ConvertToAbsoluteUri(returnUrl, this._requestContext); } else { - uri = HttpRequestInfo.GetPublicFacingUrl(this._requestContext.Request, this._requestContext.Request.ServerVariables); + uri = this._requestContext.Request.GetPublicFacingUrl(); } // attach the provider parameter so that we know which provider initiated diff --git a/src/DotNetOpenAuth.AspNet/UriHelper.cs b/src/DotNetOpenAuth.AspNet/UriHelper.cs index 53a5c7f..2c6e5a9 100644 --- a/src/DotNetOpenAuth.AspNet/UriHelper.cs +++ b/src/DotNetOpenAuth.AspNet/UriHelper.cs @@ -76,7 +76,7 @@ namespace DotNetOpenAuth.AspNet { returnUrl = VirtualPathUtility.ToAbsolute(returnUrl); } - Uri publicUrl = HttpRequestInfo.GetPublicFacingUrl(context.Request, context.Request.ServerVariables); + Uri publicUrl = context.Request.GetPublicFacingUrl(); return new Uri(publicUrl, returnUrl); } diff --git a/src/DotNetOpenAuth.Core/DotNetOpenAuth.Core.csproj b/src/DotNetOpenAuth.Core/DotNetOpenAuth.Core.csproj index f669731..ad17119 100644 --- a/src/DotNetOpenAuth.Core/DotNetOpenAuth.Core.csproj +++ b/src/DotNetOpenAuth.Core/DotNetOpenAuth.Core.csproj @@ -28,6 +28,7 @@ <Compile Include="Messaging\CachedDirectWebResponse.cs" /> <Compile Include="Messaging\ChannelContract.cs" /> <Compile Include="Messaging\DataBagFormatterBase.cs" /> + <Compile Include="Messaging\HttpRequestHeaders.cs" /> <Compile Include="Messaging\IHttpIndirectResponse.cs" /> <Compile Include="Messaging\IMessageOriginalPayload.cs" /> <Compile Include="Messaging\DirectWebRequestOptions.cs" /> diff --git a/src/DotNetOpenAuth.Core/Messaging/Channel.cs b/src/DotNetOpenAuth.Core/Messaging/Channel.cs index 26a8179..0feb999 100644 --- a/src/DotNetOpenAuth.Core/Messaging/Channel.cs +++ b/src/DotNetOpenAuth.Core/Messaging/Channel.cs @@ -409,7 +409,7 @@ namespace DotNetOpenAuth.Messaging { /// <returns>True if the expected message was recognized and deserialized. False otherwise.</returns> /// <exception cref="InvalidOperationException">Thrown when <see cref="HttpContext.Current"/> is null.</exception> /// <exception cref="ProtocolException">Thrown when a request message of an unexpected type is received.</exception> - public bool TryReadFromRequest<TRequest>(HttpRequestInfo httpRequest, out TRequest request) + public bool TryReadFromRequest<TRequest>(HttpRequestBase httpRequest, out TRequest request) where TRequest : class, IProtocolMessage { Requires.NotNull(httpRequest, "httpRequest"); Contract.Ensures(Contract.Result<bool>() == (Contract.ValueAtReturn<TRequest>(out request) != null)); @@ -450,7 +450,7 @@ namespace DotNetOpenAuth.Messaging { /// <returns>The deserialized message. Never null.</returns> /// <exception cref="ProtocolException">Thrown if the expected message was not recognized in the response.</exception> [SuppressMessage("Microsoft.Design", "CA1004:GenericMethodsShouldProvideTypeParameter", Justification = "This returns and verifies the appropriate message type.")] - public TRequest ReadFromRequest<TRequest>(HttpRequestInfo httpRequest) + public TRequest ReadFromRequest<TRequest>(HttpRequestBase httpRequest) where TRequest : class, IProtocolMessage { Requires.NotNull(httpRequest, "httpRequest"); TRequest request; @@ -466,11 +466,11 @@ namespace DotNetOpenAuth.Messaging { /// </summary> /// <param name="httpRequest">The request to search for an embedded message.</param> /// <returns>The deserialized message, if one is found. Null otherwise.</returns> - public IDirectedProtocolMessage ReadFromRequest(HttpRequestInfo httpRequest) { + public IDirectedProtocolMessage ReadFromRequest(HttpRequestBase httpRequest) { Requires.NotNull(httpRequest, "httpRequest"); - if (Logger.Channel.IsInfoEnabled && httpRequest.UrlBeforeRewriting != null) { - Logger.Channel.InfoFormat("Scanning incoming request for messages: {0}", httpRequest.UrlBeforeRewriting.AbsoluteUri); + if (Logger.Channel.IsInfoEnabled && httpRequest.GetPublicFacingUrl() != null) { + Logger.Channel.InfoFormat("Scanning incoming request for messages: {0}", httpRequest.GetPublicFacingUrl().AbsoluteUri); } IDirectedProtocolMessage requestMessage = this.ReadFromRequestCore(httpRequest); if (requestMessage != null) { @@ -607,16 +607,13 @@ namespace DotNetOpenAuth.Messaging { /// </remarks> /// <exception cref="InvalidOperationException">Thrown if <see cref="HttpContext.Current">HttpContext.Current</see> == <c>null</c>.</exception> [SuppressMessage("Microsoft.Design", "CA1024:UsePropertiesWhereAppropriate", Justification = "Costly call should not be a property.")] - protected internal virtual HttpRequestInfo GetRequestFromContext() { + protected internal virtual HttpRequestBase GetRequestFromContext() { Requires.ValidState(HttpContext.Current != null && HttpContext.Current.Request != null, MessagingStrings.HttpContextRequired); - Contract.Ensures(Contract.Result<HttpRequestInfo>() != null); - Contract.Ensures(Contract.Result<HttpRequestInfo>().Url != null); - Contract.Ensures(Contract.Result<HttpRequestInfo>().RawUrl != null); - Contract.Ensures(Contract.Result<HttpRequestInfo>().UrlBeforeRewriting != null); + Contract.Ensures(Contract.Result<HttpRequestBase>() != null); Contract.Assume(HttpContext.Current.Request.Url != null); Contract.Assume(HttpContext.Current.Request.RawUrl != null); - return new HttpRequestInfo(HttpContext.Current.Request); + return new HttpRequestWrapper(HttpContext.Current.Request); } /// <summary> @@ -731,16 +728,16 @@ namespace DotNetOpenAuth.Messaging { /// </summary> /// <param name="request">The request to search for an embedded message.</param> /// <returns>The deserialized message, if one is found. Null otherwise.</returns> - protected virtual IDirectedProtocolMessage ReadFromRequestCore(HttpRequestInfo request) { + protected virtual IDirectedProtocolMessage ReadFromRequestCore(HttpRequestBase request) { Requires.NotNull(request, "request"); - Logger.Channel.DebugFormat("Incoming HTTP request: {0} {1}", request.HttpMethod, request.UrlBeforeRewriting.AbsoluteUri); + Logger.Channel.DebugFormat("Incoming HTTP request: {0} {1}", request.HttpMethod, request.GetPublicFacingUrl().AbsoluteUri); // Search Form data first, and if nothing is there search the QueryString - Contract.Assume(request.Form != null && request.QueryStringBeforeRewriting != null); + Contract.Assume(request.Form != null && request.GetQueryStringBeforeRewriting() != null); var fields = request.Form.ToDictionary(); if (fields.Count == 0 && request.HttpMethod != "POST") { // OpenID 2.0 section 4.1.2 - fields = request.QueryStringBeforeRewriting.ToDictionary(); + fields = request.GetQueryStringBeforeRewriting().ToDictionary(); } MessageReceivingEndpoint recipient; diff --git a/src/DotNetOpenAuth.Core/Messaging/HttpRequestHeaders.cs b/src/DotNetOpenAuth.Core/Messaging/HttpRequestHeaders.cs new file mode 100644 index 0000000..8da8013 --- /dev/null +++ b/src/DotNetOpenAuth.Core/Messaging/HttpRequestHeaders.cs @@ -0,0 +1,27 @@ +// ----------------------------------------------------------------------- +// <copyright file="HttpRequestHeaders.cs" company=""> +// TODO: Update copyright text. +// </copyright> +// ----------------------------------------------------------------------- + +namespace DotNetOpenAuth.Messaging { + using System; + using System.Collections.Generic; + using System.Linq; + using System.Text; + + /// <summary> + /// TODO: Update summary. + /// </summary> + internal static class HttpRequestHeaders { + /// <summary> + /// The Authorization header, which specifies the credentials that the client presents in order to authenticate itself to the server. + /// </summary> + internal const string Authorization = "Authorization"; + + /// <summary> + /// The Content-Type header, which specifies the MIME type of the accompanying body data. + /// </summary> + internal const string ContentType = "Content-Type"; + } +} diff --git a/src/DotNetOpenAuth.Core/Messaging/HttpRequestInfo.cs b/src/DotNetOpenAuth.Core/Messaging/HttpRequestInfo.cs index 579225b..24ca616 100644 --- a/src/DotNetOpenAuth.Core/Messaging/HttpRequestInfo.cs +++ b/src/DotNetOpenAuth.Core/Messaging/HttpRequestInfo.cs @@ -25,82 +25,42 @@ namespace DotNetOpenAuth.Messaging { /// ASP.NET does not let us fully initialize that class, so we have to write one /// of our one. /// </remarks> - public class HttpRequestInfo { - /// <summary> - /// The key/value pairs found in the entity of a POST request. - /// </summary> - private NameValueCollection form; + public class HttpRequestInfo : HttpRequestBase { + private readonly string httpMethod; - /// <summary> - /// The key/value pairs found in the querystring of the incoming request. - /// </summary> - private NameValueCollection queryString; + private readonly Uri requestUri; - /// <summary> - /// Backing field for the <see cref="QueryStringBeforeRewriting"/> property. - /// </summary> - private NameValueCollection queryStringBeforeRewriting; + private readonly NameValueCollection queryString; - /// <summary> - /// Backing field for the <see cref="Message"/> property. - /// </summary> - private IDirectedProtocolMessage message; + private readonly NameValueCollection headers; - /// <summary> - /// Initializes a new instance of the <see cref="HttpRequestInfo"/> class. - /// </summary> - /// <param name="request">The ASP.NET structure to copy from.</param> - public HttpRequestInfo(HttpRequest request) { - Requires.NotNull(request, "request"); - Contract.Ensures(this.HttpMethod == request.HttpMethod); - Contract.Ensures(this.Url == request.Url); - Contract.Ensures(this.RawUrl == request.RawUrl); - Contract.Ensures(this.UrlBeforeRewriting != null); - Contract.Ensures(this.Headers != null); - Contract.Ensures(this.InputStream == request.InputStream); - Contract.Ensures(this.form == request.Form); - Contract.Ensures(this.queryString == request.QueryString); + private readonly NameValueCollection form; - this.HttpMethod = request.HttpMethod; - this.Url = request.Url; - this.UrlBeforeRewriting = GetPublicFacingUrl(request); - this.RawUrl = request.RawUrl; - this.Headers = GetHeaderCollection(request.Headers); - this.InputStream = request.InputStream; + private readonly NameValueCollection serverVariables; + + public HttpRequestInfo(HttpRequestMessageProperty request, Uri requestUri) { + Requires.NotNull(request, "request"); + Requires.NotNull(requestUri, "requestUri"); - // These values would normally be calculated, but we'll reuse them from - // HttpRequest since they're already calculated, and there's a chance (<g>) - // that ASP.NET does a better job of being comprehensive about gathering - // these as well. - this.form = request.Form; - this.queryString = request.QueryString; + this.httpMethod = request.Method; + this.headers = request.Headers; + this.requestUri = requestUri; + this.form = new NameValueCollection(); + this.serverVariables = new NameValueCollection(); Reporting.RecordRequestStatistics(this); } - /// <summary> - /// Initializes a new instance of the <see cref="HttpRequestInfo"/> class. - /// </summary> - /// <param name="httpMethod">The HTTP method (i.e. GET or POST) of the incoming request.</param> - /// <param name="requestUrl">The URL being requested.</param> - /// <param name="rawUrl">The raw URL that appears immediately following the HTTP verb in the request, - /// before any URL rewriting takes place.</param> - /// <param name="headers">Headers in the HTTP request.</param> - /// <param name="inputStream">The entity stream, if any. (POST requests typically have these). Use <c>null</c> for GET requests.</param> - public HttpRequestInfo(string httpMethod, Uri requestUrl, string rawUrl, WebHeaderCollection headers, Stream inputStream) { + public HttpRequestInfo(string httpMethod, Uri requestUri, NameValueCollection form = null, NameValueCollection headers = null) { Requires.NotNullOrEmpty(httpMethod, "httpMethod"); - Requires.NotNull(requestUrl, "requestUrl"); - Requires.NotNull(rawUrl, "rawUrl"); - Requires.NotNull(headers, "headers"); - - this.HttpMethod = httpMethod; - this.Url = requestUrl; - this.UrlBeforeRewriting = requestUrl; - this.RawUrl = rawUrl; - this.Headers = headers; - this.InputStream = inputStream; + Requires.NotNull(requestUri, "requestUri"); - Reporting.RecordRequestStatistics(this); + this.httpMethod = httpMethod; + this.requestUri = requestUri; + this.form = form ?? new NameValueCollection(); + this.queryString = HttpUtility.ParseQueryString(requestUri.Query); + this.headers = headers ?? new NameValueCollection(); + this.serverVariables = new NameValueCollection(); } /// <summary> @@ -110,337 +70,78 @@ namespace DotNetOpenAuth.Messaging { public HttpRequestInfo(HttpListenerRequest listenerRequest) { Requires.NotNull(listenerRequest, "listenerRequest"); - this.HttpMethod = listenerRequest.HttpMethod; - this.Url = listenerRequest.Url; - this.UrlBeforeRewriting = listenerRequest.Url; - this.RawUrl = listenerRequest.RawUrl; - this.Headers = new WebHeaderCollection(); - foreach (string key in listenerRequest.Headers) { - this.Headers[key] = listenerRequest.Headers[key]; - } - - this.InputStream = listenerRequest.InputStream; + this.httpMethod = listenerRequest.HttpMethod; + this.requestUri = listenerRequest.Url; + this.queryString = listenerRequest.QueryString; + this.headers = listenerRequest.Headers; + this.form = ParseFormData(listenerRequest.HttpMethod, listenerRequest.Headers, listenerRequest.InputStream); + this.serverVariables = new NameValueCollection(); Reporting.RecordRequestStatistics(this); } - /// <summary> - /// Initializes a new instance of the <see cref="HttpRequestInfo"/> class. - /// </summary> - /// <param name="request">The WCF incoming request structure to get the HTTP information from.</param> - /// <param name="requestUri">The URI of the service endpoint.</param> - public HttpRequestInfo(HttpRequestMessageProperty request, Uri requestUri) { - Requires.NotNull(request, "request"); + public HttpRequestInfo(string httpMethod, Uri requestUri, NameValueCollection headers, Stream inputStream) { + Requires.NotNullOrEmpty(httpMethod, "httpMethod"); Requires.NotNull(requestUri, "requestUri"); - this.HttpMethod = request.Method; - this.Headers = request.Headers; - this.Url = requestUri; - this.UrlBeforeRewriting = requestUri; - this.RawUrl = MakeUpRawUrlFromUrl(requestUri); + this.httpMethod = httpMethod; + this.requestUri = requestUri; + this.headers = headers; + this.queryString = HttpUtility.ParseQueryString(requestUri.Query); + this.form = ParseFormData(httpMethod, headers, inputStream); + this.serverVariables = new NameValueCollection(); Reporting.RecordRequestStatistics(this); } - /// <summary> - /// Initializes a new instance of the <see cref="HttpRequestInfo"/> class. - /// </summary> - internal HttpRequestInfo() { - Contract.Ensures(this.HttpMethod == "GET"); - Contract.Ensures(this.Headers != null); - - this.HttpMethod = "GET"; - this.Headers = new WebHeaderCollection(); - } - - /// <summary> - /// Initializes a new instance of the <see cref="HttpRequestInfo"/> class. - /// </summary> - /// <param name="request">The HttpWebRequest (that was never used) to copy from.</param> - internal HttpRequestInfo(WebRequest request) { - Requires.NotNull(request, "request"); - - this.HttpMethod = request.Method; - this.Url = request.RequestUri; - this.UrlBeforeRewriting = request.RequestUri; - this.RawUrl = MakeUpRawUrlFromUrl(request.RequestUri); - this.Headers = GetHeaderCollection(request.Headers); - this.InputStream = null; - - Reporting.RecordRequestStatistics(this); - } - - /// <summary> - /// Initializes a new instance of the <see cref="HttpRequestInfo"/> class. - /// </summary> - /// <param name="message">The message being passed in through a mock transport. May be null.</param> - /// <param name="httpMethod">The HTTP method that the incoming request came in on, whether or not <paramref name="message"/> is null.</param> - internal HttpRequestInfo(IDirectedProtocolMessage message, HttpDeliveryMethods httpMethod) { - this.message = message; - this.HttpMethod = MessagingUtilities.GetHttpVerb(httpMethod); - } - - /// <summary> - /// Gets or sets the message that is being sent over a mock transport (for testing). - /// </summary> - internal virtual IDirectedProtocolMessage Message { - get { return this.message; } - set { this.message = value; } - } - - /// <summary> - /// Gets or sets the verb in the request (i.e. GET, POST, etc.) - /// </summary> - internal string HttpMethod { get; set; } - - /// <summary> - /// Gets or sets the entire URL of the request, after any URL rewriting. - /// </summary> - internal Uri Url { get; set; } - - /// <summary> - /// Gets or sets the raw URL that appears immediately following the HTTP verb in the request, - /// before any URL rewriting takes place. - /// </summary> - internal string RawUrl { get; set; } - - /// <summary> - /// Gets or sets the full public URL used by the remote client to initiate this request, - /// before any URL rewriting and before any changes made by web farm load distributors. - /// </summary> - internal Uri UrlBeforeRewriting { get; set; } - - /// <summary> - /// Gets the query part of the URL (The ? and everything after it), after URL rewriting. - /// </summary> - internal string Query { - get { return this.Url != null ? this.Url.Query : null; } - } - - /// <summary> - /// Gets or sets the collection of headers that came in with the request. - /// </summary> - internal WebHeaderCollection Headers { get; set; } - - /// <summary> - /// Gets or sets the entity, or body of the request, if any. - /// </summary> - internal Stream InputStream { get; set; } - - /// <summary> - /// Gets the key/value pairs found in the entity of a POST request. - /// </summary> - internal NameValueCollection Form { - get { - Contract.Ensures(Contract.Result<NameValueCollection>() != null); - if (this.form == null) { - ContentType contentType = string.IsNullOrEmpty(this.Headers[HttpRequestHeader.ContentType]) ? null : new ContentType(this.Headers[HttpRequestHeader.ContentType]); - if (this.HttpMethod == "POST" && contentType != null && string.Equals(contentType.MediaType, Channel.HttpFormUrlEncoded, StringComparison.Ordinal)) { - StreamReader reader = new StreamReader(this.InputStream); - long originalPosition = 0; - if (this.InputStream.CanSeek) { - originalPosition = this.InputStream.Position; - } - this.form = HttpUtility.ParseQueryString(reader.ReadToEnd()); - if (this.InputStream.CanSeek) { - this.InputStream.Seek(originalPosition, SeekOrigin.Begin); - } - } - else { - this.form = new NameValueCollection(); - } - } - - return this.form; - } - } - - /// <summary> - /// Gets the key/value pairs found in the querystring of the incoming request. - /// </summary> - internal NameValueCollection QueryString { - get { - if (this.queryString == null) { - this.queryString = this.Query != null ? HttpUtility.ParseQueryString(this.Query) : new NameValueCollection(); - } - - return this.queryString; - } - } - - /// <summary> - /// Gets the query data from the original request (before any URL rewriting has occurred.) - /// </summary> - /// <returns>A <see cref="NameValueCollection"/> containing all the parameters in the query string.</returns> - internal NameValueCollection QueryStringBeforeRewriting { - get { - if (this.queryStringBeforeRewriting == null) { - // This request URL may have been rewritten by the host site. - // For openid protocol purposes, we really need to look at - // the original query parameters before any rewriting took place. - if (!this.IsUrlRewritten) { - // No rewriting has taken place. - this.queryStringBeforeRewriting = this.QueryString; - } - else { - // Rewriting detected! Recover the original request URI. - ErrorUtilities.VerifyInternal(this.UrlBeforeRewriting != null, "UrlBeforeRewriting is null, so the query string cannot be determined."); - this.queryStringBeforeRewriting = HttpUtility.ParseQueryString(this.UrlBeforeRewriting.Query); - } - } - - return this.queryStringBeforeRewriting; - } + public override string HttpMethod { + get { return this.httpMethod; } } - /// <summary> - /// Gets a value indicating whether the request's URL was rewritten by ASP.NET - /// or some other module. - /// </summary> - /// <value> - /// <c>true</c> if this request's URL was rewritten; otherwise, <c>false</c>. - /// </value> - internal bool IsUrlRewritten { - get { return this.Url != this.UrlBeforeRewriting; } + public override NameValueCollection Headers { + get { return this.headers; } } - /// <summary> - /// Gets the public facing URL for the given incoming HTTP request. - /// </summary> - /// <param name="request">The request.</param> - /// <param name="serverVariables">The server variables to consider part of the request.</param> - /// <returns> - /// The URI that the outside world used to create this request. - /// </returns> - /// <remarks> - /// Although the <paramref name="serverVariables"/> value can be obtained from - /// <see cref="HttpRequest.ServerVariables"/>, it's useful to be able to pass them - /// in so we can simulate injected values from our unit tests since the actual property - /// is a read-only kind of <see cref="NameValueCollection"/>. - /// </remarks> - internal static Uri GetPublicFacingUrl(HttpRequest request, NameValueCollection serverVariables) { - return GetPublicFacingUrl(new HttpRequestWrapper(request), serverVariables); + public override Uri Url { + get { return this.requestUri; } } - /// <summary> - /// Gets the public facing URL for the given incoming HTTP request. - /// </summary> - /// <param name="request">The request.</param> - /// <param name="serverVariables">The server variables to consider part of the request.</param> - /// <returns> - /// The URI that the outside world used to create this request. - /// </returns> - /// <remarks> - /// Although the <paramref name="serverVariables"/> value can be obtained from - /// <see cref="HttpRequest.ServerVariables"/>, it's useful to be able to pass them - /// in so we can simulate injected values from our unit tests since the actual property - /// is a read-only kind of <see cref="NameValueCollection"/>. - /// </remarks> - internal static Uri GetPublicFacingUrl(HttpRequestBase request, NameValueCollection serverVariables) { - Requires.NotNull(request, "request"); - Requires.NotNull(serverVariables, "serverVariables"); - - // Due to URL rewriting, cloud computing (i.e. Azure) - // and web farms, etc., we have to be VERY careful about what - // we consider the incoming URL. We want to see the URL as it would - // appear on the public-facing side of the hosting web site. - // HttpRequest.Url gives us the internal URL in a cloud environment, - // So we use a variable that (at least from what I can tell) gives us - // the public URL: - if (serverVariables["HTTP_HOST"] != null) { - ErrorUtilities.VerifySupported(request.Url.Scheme == Uri.UriSchemeHttps || request.Url.Scheme == Uri.UriSchemeHttp, "Only HTTP and HTTPS are supported protocols."); - string scheme = serverVariables["HTTP_X_FORWARDED_PROTO"] ?? request.Url.Scheme; - Uri hostAndPort = new Uri(scheme + Uri.SchemeDelimiter + serverVariables["HTTP_HOST"]); - UriBuilder publicRequestUri = new UriBuilder(request.Url); - publicRequestUri.Scheme = scheme; - publicRequestUri.Host = hostAndPort.Host; - publicRequestUri.Port = hostAndPort.Port; // CC missing Uri.Port contract that's on UriBuilder.Port - return publicRequestUri.Uri; - } - else { - // Failover to the method that works for non-web farm enviroments. - // We use Request.Url for the full path to the server, and modify it - // with Request.RawUrl to capture both the cookieless session "directory" if it exists - // and the original path in case URL rewriting is going on. We don't want to be - // fooled by URL rewriting because we're comparing the actual URL with what's in - // the return_to parameter in some cases. - // Response.ApplyAppPathModifier(builder.Path) would have worked for the cookieless - // session, but not the URL rewriting problem. - return new Uri(request.Url, request.RawUrl); - } + public override string RawUrl { + get { return this.requestUri.AbsolutePath + this.requestUri.Query; } } - /// <summary> - /// Gets the query or form data from the original request (before any URL rewriting has occurred.) - /// </summary> - /// <returns>A set of name=value pairs.</returns> - [SuppressMessage("Microsoft.Design", "CA1024:UsePropertiesWhereAppropriate", Justification = "Expensive call")] - internal NameValueCollection GetQueryOrFormFromContext() { - NameValueCollection query; - if (this.HttpMethod == "GET") { - query = this.QueryStringBeforeRewriting; - } - else { - query = this.Form; - } - return query; + public override NameValueCollection Form { + get { return this.form; } } - /// <summary> - /// Gets the public facing URL for the given incoming HTTP request. - /// </summary> - /// <param name="request">The request.</param> - /// <returns>The URI that the outside world used to create this request.</returns> - private static Uri GetPublicFacingUrl(HttpRequest request) { - Requires.NotNull(request, "request"); - return GetPublicFacingUrl(request, request.ServerVariables); + public override NameValueCollection QueryString { + get { return this.queryString; } } - /// <summary> - /// Makes up a reasonable guess at the raw URL from the possibly rewritten URL. - /// </summary> - /// <param name="url">A full URL.</param> - /// <returns>A raw URL that might have come in on the HTTP verb.</returns> - private static string MakeUpRawUrlFromUrl(Uri url) { - Requires.NotNull(url, "url"); - return url.AbsolutePath + url.Query + url.Fragment; + public override NameValueCollection ServerVariables { + get { return this.serverVariables; } } - /// <summary> - /// Converts a NameValueCollection to a WebHeaderCollection. - /// </summary> - /// <param name="pairs">The collection a HTTP headers.</param> - /// <returns>A new collection of the given headers.</returns> - private static WebHeaderCollection GetHeaderCollection(NameValueCollection pairs) { - Requires.NotNull(pairs, "pairs"); + private static NameValueCollection ParseFormData(string httpMethod, NameValueCollection headers, Stream inputStream) { + Requires.NotNullOrEmpty(httpMethod, "httpMethod"); + Requires.NotNull(headers, "headers"); - WebHeaderCollection headers = new WebHeaderCollection(); - foreach (string key in pairs) { - try { - headers.Add(key, pairs[key]); + ContentType contentType = string.IsNullOrEmpty(headers[HttpRequestHeaders.ContentType]) ? null : new ContentType(headers[HttpRequestHeaders.ContentType]); + if (inputStream != null && httpMethod == "POST" && contentType != null && string.Equals(contentType.MediaType, Channel.HttpFormUrlEncoded, StringComparison.Ordinal)) { + var reader = new StreamReader(inputStream); + long originalPosition = 0; + if (inputStream.CanSeek) { + originalPosition = inputStream.Position; } - catch (ArgumentException ex) { - Logger.Messaging.WarnFormat( - "{0} thrown when trying to add web header \"{1}: {2}\". {3}", - ex.GetType().Name, - key, - pairs[key], - ex.Message); + string postEntity = reader.ReadToEnd(); + if (inputStream.CanSeek) { + inputStream.Seek(originalPosition, SeekOrigin.Begin); } - } - return headers; - } + return HttpUtility.ParseQueryString(postEntity); + } -#if CONTRACTS_FULL - /// <summary> - /// Verifies conditions that should be true for any valid state of this object. - /// </summary> - [SuppressMessage("Microsoft.Performance", "CA1822:MarkMembersAsStatic", Justification = "Called by code contracts.")] - [SuppressMessage("Microsoft.Performance", "CA1811:AvoidUncalledPrivateCode", Justification = "Called by code contracts.")] - [ContractInvariantMethod] - private void ObjectInvariant() { + return new NameValueCollection(); } -#endif } } diff --git a/src/DotNetOpenAuth.Core/Messaging/MessagingUtilities.cs b/src/DotNetOpenAuth.Core/Messaging/MessagingUtilities.cs index fbf6b4f..bff016b 100644 --- a/src/DotNetOpenAuth.Core/Messaging/MessagingUtilities.cs +++ b/src/DotNetOpenAuth.Core/Messaging/MessagingUtilities.cs @@ -153,9 +153,7 @@ namespace DotNetOpenAuth.Messaging { [SuppressMessage("Microsoft.Design", "CA1024:UsePropertiesWhereAppropriate", Justification = "Expensive call should not be a property.")] public static Uri GetRequestUrlFromContext() { Requires.ValidState(HttpContext.Current != null && HttpContext.Current.Request != null, MessagingStrings.HttpContextRequired); - HttpContext context = HttpContext.Current; - - return HttpRequestInfo.GetPublicFacingUrl(context.Request, context.Request.ServerVariables); + return new HttpRequestWrapper(HttpContext.Current.Request).GetPublicFacingUrl(); } /// <summary> @@ -1352,8 +1350,8 @@ namespace DotNetOpenAuth.Messaging { /// <param name="request">The request to get recipient information from.</param> /// <returns>The recipient.</returns> /// <exception cref="ArgumentException">Thrown if the HTTP request is something we can't handle.</exception> - internal static MessageReceivingEndpoint GetRecipient(this HttpRequestInfo request) { - return new MessageReceivingEndpoint(request.UrlBeforeRewriting, GetHttpDeliveryMethod(request.HttpMethod)); + internal static MessageReceivingEndpoint GetRecipient(this HttpRequestBase request) { + return new MessageReceivingEndpoint(request.GetPublicFacingUrl(), GetHttpDeliveryMethod(request.HttpMethod)); } /// <summary> @@ -1483,6 +1481,15 @@ namespace DotNetOpenAuth.Messaging { return dictionary; } + internal static NameValueCollection ToNameValueCollection(this IDictionary<string,string> data) { + var nvc = new NameValueCollection(); + foreach (var entry in data) { + nvc.Add(entry.Key, entry.Value); + } + + return nvc; + } + /// <summary> /// Sorts the elements of a sequence in ascending order by using a specified comparer. /// </summary> @@ -1663,6 +1670,103 @@ namespace DotNetOpenAuth.Messaging { } /// <summary> + /// Gets the query data from the original request (before any URL rewriting has occurred.) + /// </summary> + /// <returns>A <see cref="NameValueCollection"/> containing all the parameters in the query string.</returns> + internal static NameValueCollection GetQueryStringBeforeRewriting(this HttpRequestBase request) { + // This request URL may have been rewritten by the host site. + // For openid protocol purposes, we really need to look at + // the original query parameters before any rewriting took place. + Uri beforeRewriting = GetPublicFacingUrl(request); + if (beforeRewriting == request.Url) { + // No rewriting has taken place. + return request.QueryString; + } else { + // Rewriting detected! Recover the original request URI. + ErrorUtilities.VerifyInternal(beforeRewriting != null, "UrlBeforeRewriting is null, so the query string cannot be determined."); + return HttpUtility.ParseQueryString(beforeRewriting.Query); + } + } + + /// <summary> + /// Gets a value indicating whether the request's URL was rewritten by ASP.NET + /// or some other module. + /// </summary> + /// <value> + /// <c>true</c> if this request's URL was rewritten; otherwise, <c>false</c>. + /// </value> + internal static bool GetIsUrlRewritten(this HttpRequestBase request) { + return request.Url != GetPublicFacingUrl(request); + } + + /// <summary> + /// Gets the public facing URL for the given incoming HTTP request. + /// </summary> + /// <param name="request">The request.</param> + /// <param name="serverVariables">The server variables to consider part of the request.</param> + /// <returns> + /// The URI that the outside world used to create this request. + /// </returns> + /// <remarks> + /// Although the <paramref name="serverVariables"/> value can be obtained from + /// <see cref="HttpRequest.ServerVariables"/>, it's useful to be able to pass them + /// in so we can simulate injected values from our unit tests since the actual property + /// is a read-only kind of <see cref="NameValueCollection"/>. + /// </remarks> + internal static Uri GetPublicFacingUrl(this HttpRequestBase request, NameValueCollection serverVariables) { + Requires.NotNull(request, "request"); + Requires.NotNull(serverVariables, "serverVariables"); + + // Due to URL rewriting, cloud computing (i.e. Azure) + // and web farms, etc., we have to be VERY careful about what + // we consider the incoming URL. We want to see the URL as it would + // appear on the public-facing side of the hosting web site. + // HttpRequest.Url gives us the internal URL in a cloud environment, + // So we use a variable that (at least from what I can tell) gives us + // the public URL: + if (serverVariables["HTTP_HOST"] != null) { + ErrorUtilities.VerifySupported(request.Url.Scheme == Uri.UriSchemeHttps || request.Url.Scheme == Uri.UriSchemeHttp, "Only HTTP and HTTPS are supported protocols."); + string scheme = serverVariables["HTTP_X_FORWARDED_PROTO"] ?? request.Url.Scheme; + Uri hostAndPort = new Uri(scheme + Uri.SchemeDelimiter + serverVariables["HTTP_HOST"]); + UriBuilder publicRequestUri = new UriBuilder(request.Url); + publicRequestUri.Scheme = scheme; + publicRequestUri.Host = hostAndPort.Host; + publicRequestUri.Port = hostAndPort.Port; // CC missing Uri.Port contract that's on UriBuilder.Port + return publicRequestUri.Uri; + } else { + // Failover to the method that works for non-web farm enviroments. + // We use Request.Url for the full path to the server, and modify it + // with Request.RawUrl to capture both the cookieless session "directory" if it exists + // and the original path in case URL rewriting is going on. We don't want to be + // fooled by URL rewriting because we're comparing the actual URL with what's in + // the return_to parameter in some cases. + // Response.ApplyAppPathModifier(builder.Path) would have worked for the cookieless + // session, but not the URL rewriting problem. + return new Uri(request.Url, request.RawUrl); + } + } + + /// <summary> + /// Gets the public facing URL for the given incoming HTTP request. + /// </summary> + /// <param name="request">The request.</param> + /// <returns>The URI that the outside world used to create this request.</returns> + internal static Uri GetPublicFacingUrl(this HttpRequestBase request) { + Requires.NotNull(request, "request"); + return GetPublicFacingUrl(request, request.ServerVariables); + } + + /// <summary> + /// Gets the query or form data from the original request (before any URL rewriting has occurred.) + /// </summary> + /// <returns>A set of name=value pairs.</returns> + [SuppressMessage("Microsoft.Design", "CA1024:UsePropertiesWhereAppropriate", Justification = "Expensive call")] + internal static NameValueCollection GetQueryOrForm(this HttpRequestBase request) { + Requires.NotNull(request, "request"); + return request.HttpMethod == "GET" ? GetQueryStringBeforeRewriting(request) : request.Form; + } + + /// <summary> /// Creates a symmetric algorithm for use in encryption/decryption. /// </summary> /// <param name="key">The symmetric key to use for encryption/decryption.</param> diff --git a/src/DotNetOpenAuth.Core/Reporting.cs b/src/DotNetOpenAuth.Core/Reporting.cs index a7940b6..310d1ba 100644 --- a/src/DotNetOpenAuth.Core/Reporting.cs +++ b/src/DotNetOpenAuth.Core/Reporting.cs @@ -297,7 +297,7 @@ namespace DotNetOpenAuth { /// Records statistics collected from incoming requests. /// </summary> /// <param name="request">The request.</param> - internal static void RecordRequestStatistics(HttpRequestInfo request) { + internal static void RecordRequestStatistics(HttpRequestBase request) { Contract.Requires(request != null); // In release builds, just quietly return. @@ -311,7 +311,7 @@ namespace DotNetOpenAuth { } if (Configuration.IncludeLocalRequestUris && !observedRequests.IsFull) { - var requestBuilder = new UriBuilder(request.UrlBeforeRewriting); + var requestBuilder = new UriBuilder(request.GetPublicFacingUrl()); requestBuilder.Query = null; requestBuilder.Fragment = null; observedRequests.Add(requestBuilder.Uri.AbsoluteUri); diff --git a/src/DotNetOpenAuth.Core/Requires.cs b/src/DotNetOpenAuth.Core/Requires.cs index 8aa15dd..41720c2 100644 --- a/src/DotNetOpenAuth.Core/Requires.cs +++ b/src/DotNetOpenAuth.Core/Requires.cs @@ -28,12 +28,13 @@ namespace DotNetOpenAuth { [ContractArgumentValidator] #endif [Pure, DebuggerStepThrough] - internal static void NotNull<T>(T value, string parameterName) where T : class { + internal static T NotNull<T>(T value, string parameterName) where T : class { if (value == null) { throw new ArgumentNullException(parameterName); } Contract.EndContractBlock(); + return value; } /// <summary> diff --git a/src/DotNetOpenAuth.InfoCard.UI/InfoCard/InfoCardSelector.cs b/src/DotNetOpenAuth.InfoCard.UI/InfoCard/InfoCardSelector.cs index 756b9a7..c4563f2 100644 --- a/src/DotNetOpenAuth.InfoCard.UI/InfoCard/InfoCardSelector.cs +++ b/src/DotNetOpenAuth.InfoCard.UI/InfoCard/InfoCardSelector.cs @@ -279,7 +279,7 @@ namespace DotNetOpenAuth.InfoCard { if (!string.IsNullOrEmpty(value)) { if (this.Page != null && !this.DesignMode) { // Validate new value by trying to construct a Uri based on it. - new Uri(new HttpRequestInfo(HttpContext.Current.Request).UrlBeforeRewriting, this.Page.ResolveUrl(value)); // throws an exception on failure. + new Uri(new HttpRequestWrapper(HttpContext.Current.Request).GetPublicFacingUrl(), this.Page.ResolveUrl(value)); // throws an exception on failure. } else { // We can't fully test it, but it should start with either ~/ or a protocol. if (Regex.IsMatch(value, @"^https?://")) { diff --git a/src/DotNetOpenAuth.OAuth.Consumer/OAuth/WebConsumer.cs b/src/DotNetOpenAuth.OAuth.Consumer/OAuth/WebConsumer.cs index d599598..086ff7a 100644 --- a/src/DotNetOpenAuth.OAuth.Consumer/OAuth/WebConsumer.cs +++ b/src/DotNetOpenAuth.OAuth.Consumer/OAuth/WebConsumer.cs @@ -40,7 +40,7 @@ namespace DotNetOpenAuth.OAuth { /// Requires HttpContext.Current. /// </remarks> public UserAuthorizationRequest PrepareRequestUserAuthorization() { - Uri callback = this.Channel.GetRequestFromContext().UrlBeforeRewriting.StripQueryArgumentsWithPrefix(Protocol.ParameterPrefix); + Uri callback = this.Channel.GetRequestFromContext().GetPublicFacingUrl().StripQueryArgumentsWithPrefix(Protocol.ParameterPrefix); return this.PrepareRequestUserAuthorization(callback, null, null); } @@ -76,7 +76,7 @@ namespace DotNetOpenAuth.OAuth { /// </summary> /// <param name="request">The incoming HTTP request.</param> /// <returns>The access token, or null if no incoming authorization message was recognized.</returns> - public AuthorizedTokenResponse ProcessUserAuthorization(HttpRequestInfo request) { + public AuthorizedTokenResponse ProcessUserAuthorization(HttpRequestBase request) { Requires.NotNull(request, "request"); UserAuthorizationResponse authorizationMessage; diff --git a/src/DotNetOpenAuth.OAuth.ServiceProvider/OAuth/ServiceProvider.cs b/src/DotNetOpenAuth.OAuth.ServiceProvider/OAuth/ServiceProvider.cs index 9d93e4f..ecfd191 100644 --- a/src/DotNetOpenAuth.OAuth.ServiceProvider/OAuth/ServiceProvider.cs +++ b/src/DotNetOpenAuth.OAuth.ServiceProvider/OAuth/ServiceProvider.cs @@ -216,7 +216,7 @@ namespace DotNetOpenAuth.OAuth { /// </summary> /// <param name="request">The HTTP request to read the message from.</param> /// <returns>The deserialized message.</returns> - public IDirectedProtocolMessage ReadRequest(HttpRequestInfo request) { + public IDirectedProtocolMessage ReadRequest(HttpRequestBase request) { return this.Channel.ReadFromRequest(request); } @@ -238,7 +238,7 @@ namespace DotNetOpenAuth.OAuth { /// <param name="request">The HTTP request to read from.</param> /// <returns>The incoming request, or null if no OAuth message was attached.</returns> /// <exception cref="ProtocolException">Thrown if an unexpected OAuth message is attached to the incoming request.</exception> - public UnauthorizedTokenRequest ReadTokenRequest(HttpRequestInfo request) { + public UnauthorizedTokenRequest ReadTokenRequest(HttpRequestBase request) { UnauthorizedTokenRequest message; if (this.Channel.TryReadFromRequest(request, out message)) { ErrorUtilities.VerifyProtocol(message.Version >= Protocol.Lookup(this.SecuritySettings.MinimumRequiredOAuthVersion).Version, OAuthStrings.MinimumConsumerVersionRequirementNotMet, this.SecuritySettings.MinimumRequiredOAuthVersion, message.Version); @@ -282,7 +282,7 @@ namespace DotNetOpenAuth.OAuth { /// <param name="request">The HTTP request to read from.</param> /// <returns>The incoming request, or null if no OAuth message was attached.</returns> /// <exception cref="ProtocolException">Thrown if an unexpected OAuth message is attached to the incoming request.</exception> - public UserAuthorizationRequest ReadAuthorizationRequest(HttpRequestInfo request) { + public UserAuthorizationRequest ReadAuthorizationRequest(HttpRequestBase request) { UserAuthorizationRequest message; this.Channel.TryReadFromRequest(request, out message); return message; @@ -368,7 +368,7 @@ namespace DotNetOpenAuth.OAuth { /// <param name="request">The HTTP request to read from.</param> /// <returns>The incoming request, or null if no OAuth message was attached.</returns> /// <exception cref="ProtocolException">Thrown if an unexpected OAuth message is attached to the incoming request.</exception> - public AuthorizedTokenRequest ReadAccessTokenRequest(HttpRequestInfo request) { + public AuthorizedTokenRequest ReadAccessTokenRequest(HttpRequestBase request) { AuthorizedTokenRequest message; this.Channel.TryReadFromRequest(request, out message); return message; @@ -436,7 +436,7 @@ namespace DotNetOpenAuth.OAuth { /// to access the resources being requested. /// </remarks> /// <exception cref="ProtocolException">Thrown if an unexpected message is attached to the request.</exception> - public AccessProtectedResourceRequest ReadProtectedResourceAuthorization(HttpRequestInfo request) { + public AccessProtectedResourceRequest ReadProtectedResourceAuthorization(HttpRequestBase request) { Requires.NotNull(request, "request"); AccessProtectedResourceRequest accessMessage; diff --git a/src/DotNetOpenAuth.OAuth/OAuth/ChannelElements/OAuthChannel.cs b/src/DotNetOpenAuth.OAuth/OAuth/ChannelElements/OAuthChannel.cs index 2cbc16b..ace3777 100644 --- a/src/DotNetOpenAuth.OAuth/OAuth/ChannelElements/OAuthChannel.cs +++ b/src/DotNetOpenAuth.OAuth/OAuth/ChannelElements/OAuthChannel.cs @@ -109,15 +109,15 @@ namespace DotNetOpenAuth.OAuth.ChannelElements { /// </summary> /// <param name="request">The HTTP request to search.</param> /// <returns>The deserialized message, if one is found. Null otherwise.</returns> - protected override IDirectedProtocolMessage ReadFromRequestCore(HttpRequestInfo request) { + protected override IDirectedProtocolMessage ReadFromRequestCore(HttpRequestBase request) { // First search the Authorization header. - string authorization = request.Headers[HttpRequestHeader.Authorization]; + string authorization = request.Headers[HttpRequestHeaders.Authorization]; var fields = MessagingUtilities.ParseAuthorizationHeader(Protocol.AuthorizationHeaderScheme, authorization).ToDictionary(); fields.Remove("realm"); // ignore the realm parameter, since we don't use it, and it must be omitted from signature base string. // Scrape the entity - if (!string.IsNullOrEmpty(request.Headers[HttpRequestHeader.ContentType])) { - var contentType = new ContentType(request.Headers[HttpRequestHeader.ContentType]); + if (!string.IsNullOrEmpty(request.Headers[HttpRequestHeaders.ContentType])) { + var contentType = new ContentType(request.Headers[HttpRequestHeaders.ContentType]); if (string.Equals(contentType.MediaType, HttpFormUrlEncoded, StringComparison.Ordinal)) { foreach (string key in request.Form) { if (key != null) { @@ -130,11 +130,12 @@ namespace DotNetOpenAuth.OAuth.ChannelElements { } // Scrape the query string - foreach (string key in request.QueryStringBeforeRewriting) { + var qs = request.GetQueryStringBeforeRewriting(); + foreach (string key in qs) { if (key != null) { - fields.Add(key, request.QueryStringBeforeRewriting[key]); + fields.Add(key, qs[key]); } else { - Logger.OAuth.WarnFormat("Ignoring query string parameter '{0}' since it isn't a standard name=value parameter.", request.QueryStringBeforeRewriting[key]); + Logger.OAuth.WarnFormat("Ignoring query string parameter '{0}' since it isn't a standard name=value parameter.", qs[key]); } } @@ -152,7 +153,7 @@ namespace DotNetOpenAuth.OAuth.ChannelElements { // Add receiving HTTP transport information required for signature generation. var signedMessage = message as ITamperResistantOAuthMessage; if (signedMessage != null) { - signedMessage.Recipient = request.UrlBeforeRewriting; + signedMessage.Recipient = request.GetPublicFacingUrl(); signedMessage.HttpMethod = request.HttpMethod; } diff --git a/src/DotNetOpenAuth.OAuth2.AuthorizationServer/OAuth2/AuthorizationServer.cs b/src/DotNetOpenAuth.OAuth2.AuthorizationServer/OAuth2/AuthorizationServer.cs index 9840218..7fbc56c 100644 --- a/src/DotNetOpenAuth.OAuth2.AuthorizationServer/OAuth2/AuthorizationServer.cs +++ b/src/DotNetOpenAuth.OAuth2.AuthorizationServer/OAuth2/AuthorizationServer.cs @@ -12,6 +12,8 @@ namespace DotNetOpenAuth.OAuth2 { using System.Linq; using System.Security.Cryptography; using System.Text; + using System.Web; + using DotNetOpenAuth.Messaging; using DotNetOpenAuth.OAuth2.ChannelElements; using DotNetOpenAuth.OAuth2.Messages; @@ -51,7 +53,7 @@ namespace DotNetOpenAuth.OAuth2 { /// <returns>The incoming request, or null if no OAuth message was attached.</returns> /// <exception cref="ProtocolException">Thrown if an unexpected OAuth message is attached to the incoming request.</exception> [SuppressMessage("Microsoft.Naming", "CA2204:Literals should be spelled correctly", MessageId = "unauthorizedclient", Justification = "Protocol required.")] - public EndUserAuthorizationRequest ReadAuthorizationRequest(HttpRequestInfo request = null) { + public EndUserAuthorizationRequest ReadAuthorizationRequest(HttpRequestBase request = null) { if (request == null) { request = this.Channel.GetRequestFromContext(); } @@ -117,7 +119,7 @@ namespace DotNetOpenAuth.OAuth2 { /// This method assumes that the authorization server and the resource server are the same and that they share a single /// asymmetric key for signing and encrypting the access token. If this is not true, use the <see cref="ReadAccessTokenRequest"/> method instead. /// </remarks> - public bool TryPrepareAccessTokenResponse(HttpRequestInfo httpRequestInfo, out IDirectResponseProtocolMessage response) { + public bool TryPrepareAccessTokenResponse(HttpRequestBase httpRequestInfo, out IDirectResponseProtocolMessage response) { Requires.NotNull(httpRequestInfo, "httpRequestInfo"); Contract.Ensures(Contract.Result<bool>() == (Contract.ValueAtReturn<IDirectResponseProtocolMessage>(out response) != null)); @@ -136,7 +138,7 @@ namespace DotNetOpenAuth.OAuth2 { /// </summary> /// <param name="requestInfo">The request info.</param> /// <returns>The Client's request for an access token; or <c>null</c> if no such message was found in the request.</returns> - public AccessTokenRequestBase ReadAccessTokenRequest(HttpRequestInfo requestInfo = null) { + public AccessTokenRequestBase ReadAccessTokenRequest(HttpRequestBase requestInfo = null) { if (requestInfo == null) { requestInfo = this.Channel.GetRequestFromContext(); } diff --git a/src/DotNetOpenAuth.OAuth2.Client/OAuth2/UserAgentClient.cs b/src/DotNetOpenAuth.OAuth2.Client/OAuth2/UserAgentClient.cs index 5131b10..cfbc886 100644 --- a/src/DotNetOpenAuth.OAuth2.Client/OAuth2/UserAgentClient.cs +++ b/src/DotNetOpenAuth.OAuth2.Client/OAuth2/UserAgentClient.cs @@ -10,6 +10,8 @@ namespace DotNetOpenAuth.OAuth2 { using System.Diagnostics.Contracts; using System.Linq; using System.Text; + using System.Web; + using DotNetOpenAuth.Messaging; using DotNetOpenAuth.OAuth2.Messages; @@ -93,7 +95,7 @@ namespace DotNetOpenAuth.OAuth2 { authorizationState = new AuthorizationState(); } - var carrier = new HttpRequestInfo("GET", actualRedirectUrl, actualRedirectUrl.PathAndQuery, new System.Net.WebHeaderCollection(), null); + var carrier = new HttpRequestInfo("GET", actualRedirectUrl); IDirectedProtocolMessage response = this.Channel.ReadFromRequest(carrier); if (response == null) { return null; diff --git a/src/DotNetOpenAuth.OAuth2.Client/OAuth2/WebServerClient.cs b/src/DotNetOpenAuth.OAuth2.Client/OAuth2/WebServerClient.cs index ffcc1ee..fe37dc3 100644 --- a/src/DotNetOpenAuth.OAuth2.Client/OAuth2/WebServerClient.cs +++ b/src/DotNetOpenAuth.OAuth2.Client/OAuth2/WebServerClient.cs @@ -75,7 +75,7 @@ namespace DotNetOpenAuth.OAuth2 { Contract.Ensures(Contract.Result<OutgoingWebResponse>() != null); if (authorization.Callback == null) { - authorization.Callback = this.Channel.GetRequestFromContext().UrlBeforeRewriting + authorization.Callback = this.Channel.GetRequestFromContext().GetPublicFacingUrl() .StripMessagePartsFromQueryString(this.Channel.MessageDescriptions.Get(typeof(EndUserAuthorizationSuccessResponseBase), Protocol.Default.Version)) .StripMessagePartsFromQueryString(this.Channel.MessageDescriptions.Get(typeof(EndUserAuthorizationFailedResponse), Protocol.Default.Version)); authorization.SaveChanges(); @@ -96,7 +96,7 @@ namespace DotNetOpenAuth.OAuth2 { /// </summary> /// <param name="request">The incoming HTTP request that may carry an authorization response.</param> /// <returns>The authorization state that contains the details of the authorization.</returns> - public IAuthorizationState ProcessUserAuthorization(HttpRequestInfo request = null) { + public IAuthorizationState ProcessUserAuthorization(HttpRequestBase request = null) { Requires.ValidState(!string.IsNullOrEmpty(this.ClientIdentifier), OAuth2Strings.RequiredPropertyNotYetPreset, "ClientIdentifier"); Requires.ValidState(!string.IsNullOrEmpty(this.ClientSecret), OAuth2Strings.RequiredPropertyNotYetPreset, "ClientSecret"); @@ -106,7 +106,7 @@ namespace DotNetOpenAuth.OAuth2 { IMessageWithClientState response; if (this.Channel.TryReadFromRequest<IMessageWithClientState>(request, out response)) { - Uri callback = MessagingUtilities.StripMessagePartsFromQueryString(request.UrlBeforeRewriting, this.Channel.MessageDescriptions.Get(response)); + Uri callback = MessagingUtilities.StripMessagePartsFromQueryString(request.GetPublicFacingUrl(), this.Channel.MessageDescriptions.Get(response)); IAuthorizationState authorizationState; if (this.AuthorizationTracker != null) { authorizationState = this.AuthorizationTracker.GetAuthorizationState(callback, response.ClientState); diff --git a/src/DotNetOpenAuth.OAuth2.ResourceServer/OAuth2/ResourceServer.cs b/src/DotNetOpenAuth.OAuth2.ResourceServer/OAuth2/ResourceServer.cs index a614219..79cbbd7 100644 --- a/src/DotNetOpenAuth.OAuth2.ResourceServer/OAuth2/ResourceServer.cs +++ b/src/DotNetOpenAuth.OAuth2.ResourceServer/OAuth2/ResourceServer.cs @@ -71,7 +71,7 @@ namespace DotNetOpenAuth.OAuth2 { /// </returns> [SuppressMessage("Microsoft.Design", "CA1021:AvoidOutParameters", MessageId = "1#", Justification = "Try pattern")] [SuppressMessage("Microsoft.Design", "CA1021:AvoidOutParameters", MessageId = "2#", Justification = "Try pattern")] - public virtual OutgoingWebResponse VerifyAccess(HttpRequestInfo httpRequestInfo, out string userName, out HashSet<string> scope) { + public virtual OutgoingWebResponse VerifyAccess(HttpRequestBase httpRequestInfo, out string userName, out HashSet<string> scope) { Requires.NotNull(httpRequestInfo, "httpRequestInfo"); AccessProtectedResourceRequest request = null; @@ -108,7 +108,7 @@ namespace DotNetOpenAuth.OAuth2 { /// An error to return to the client if access is not authorized; <c>null</c> if access is granted. /// </returns> [SuppressMessage("Microsoft.Design", "CA1021:AvoidOutParameters", MessageId = "1#", Justification = "Try pattern")] - public virtual OutgoingWebResponse VerifyAccess(HttpRequestInfo httpRequestInfo, out IPrincipal principal) { + public virtual OutgoingWebResponse VerifyAccess(HttpRequestBase httpRequestInfo, out IPrincipal principal) { string username; HashSet<string> scope; var result = this.VerifyAccess(httpRequestInfo, out username, out scope); diff --git a/src/DotNetOpenAuth.OAuth2/OAuth2/ChannelElements/OAuth2AuthorizationServerChannel.cs b/src/DotNetOpenAuth.OAuth2/OAuth2/ChannelElements/OAuth2AuthorizationServerChannel.cs index 3375328..0e6aa47 100644 --- a/src/DotNetOpenAuth.OAuth2/OAuth2/ChannelElements/OAuth2AuthorizationServerChannel.cs +++ b/src/DotNetOpenAuth.OAuth2/OAuth2/ChannelElements/OAuth2AuthorizationServerChannel.cs @@ -69,7 +69,7 @@ namespace DotNetOpenAuth.OAuth2.ChannelElements { /// <returns> /// The deserialized message, if one is found. Null otherwise. /// </returns> - protected override IDirectedProtocolMessage ReadFromRequestCore(HttpRequestInfo request) { + protected override IDirectedProtocolMessage ReadFromRequestCore(HttpRequestBase request) { if (!string.IsNullOrEmpty(request.Url.Fragment)) { var fields = HttpUtility.ParseQueryString(request.Url.Fragment.Substring(1)).ToDictionary(); diff --git a/src/DotNetOpenAuth.OAuth2/OAuth2/ChannelElements/OAuth2ClientChannel.cs b/src/DotNetOpenAuth.OAuth2/OAuth2/ChannelElements/OAuth2ClientChannel.cs index 3a8a7c0..c9981d3 100644 --- a/src/DotNetOpenAuth.OAuth2/OAuth2/ChannelElements/OAuth2ClientChannel.cs +++ b/src/DotNetOpenAuth.OAuth2/OAuth2/ChannelElements/OAuth2ClientChannel.cs @@ -76,16 +76,16 @@ namespace DotNetOpenAuth.OAuth2.ChannelElements { /// <returns> /// The deserialized message, if one is found. Null otherwise. /// </returns> - protected override IDirectedProtocolMessage ReadFromRequestCore(HttpRequestInfo request) { - Logger.Channel.DebugFormat("Incoming HTTP request: {0} {1}", request.HttpMethod, request.UrlBeforeRewriting.AbsoluteUri); + protected override IDirectedProtocolMessage ReadFromRequestCore(HttpRequestBase request) { + Logger.Channel.DebugFormat("Incoming HTTP request: {0} {1}", request.HttpMethod, request.GetPublicFacingUrl().AbsoluteUri); - var fields = request.QueryStringBeforeRewriting.ToDictionary(); + var fields = request.GetQueryStringBeforeRewriting().ToDictionary(); // Also read parameters from the fragment, if it's available. // Typically the fragment is not available because the browser doesn't send it to a web server // but this request may have been fabricated by an installed desktop app, in which case // the fragment is available. - string fragment = request.UrlBeforeRewriting.Fragment; + string fragment = request.GetPublicFacingUrl().Fragment; if (!string.IsNullOrEmpty(fragment)) { foreach (var pair in HttpUtility.ParseQueryString(fragment.Substring(1)).ToDictionary()) { fields.Add(pair.Key, pair.Value); diff --git a/src/DotNetOpenAuth.OAuth2/OAuth2/ChannelElements/OAuth2ResourceServerChannel.cs b/src/DotNetOpenAuth.OAuth2/OAuth2/ChannelElements/OAuth2ResourceServerChannel.cs index 1c2a080..73d68e4 100644 --- a/src/DotNetOpenAuth.OAuth2/OAuth2/ChannelElements/OAuth2ResourceServerChannel.cs +++ b/src/DotNetOpenAuth.OAuth2/OAuth2/ChannelElements/OAuth2ResourceServerChannel.cs @@ -48,7 +48,7 @@ namespace DotNetOpenAuth.OAuth2.ChannelElements { /// <returns> /// The deserialized message, if one is found. Null otherwise. /// </returns> - protected override IDirectedProtocolMessage ReadFromRequestCore(HttpRequestInfo request) { + protected override IDirectedProtocolMessage ReadFromRequestCore(HttpRequestBase request) { var fields = new Dictionary<string, string>(); string accessToken; if ((accessToken = SearchForBearerAccessTokenInRequest(request)) != null) { @@ -122,18 +122,18 @@ namespace DotNetOpenAuth.OAuth2.ChannelElements { /// </summary> /// <param name="request">The request.</param> /// <returns>The bearer access token, if one exists. Otherwise <c>null</c>.</returns> - private static string SearchForBearerAccessTokenInRequest(HttpRequestInfo request) { + private static string SearchForBearerAccessTokenInRequest(HttpRequestBase request) { Requires.NotNull(request, "request"); // First search the authorization header. - string authorizationHeader = request.Headers[HttpRequestHeader.Authorization]; + string authorizationHeader = request.Headers[HttpRequestHeaders.Authorization]; if (!string.IsNullOrEmpty(authorizationHeader) && authorizationHeader.StartsWith(Protocol.BearerHttpAuthorizationSchemeWithTrailingSpace, StringComparison.OrdinalIgnoreCase)) { return authorizationHeader.Substring(Protocol.BearerHttpAuthorizationSchemeWithTrailingSpace.Length); } // Failing that, scan the entity - if (!string.IsNullOrEmpty(request.Headers[HttpRequestHeader.ContentType])) { - var contentType = new ContentType(request.Headers[HttpRequestHeader.ContentType]); + if (!string.IsNullOrEmpty(request.Headers[HttpRequestHeaders.ContentType])) { + var contentType = new ContentType(request.Headers[HttpRequestHeaders.ContentType]); if (string.Equals(contentType.MediaType, HttpFormUrlEncoded, StringComparison.Ordinal)) { if (request.Form[Protocol.BearerTokenEncodedUrlParameterName] != null) { return request.Form[Protocol.BearerTokenEncodedUrlParameterName]; @@ -142,8 +142,9 @@ namespace DotNetOpenAuth.OAuth2.ChannelElements { } // Finally, check the least desirable location: the query string - if (!String.IsNullOrEmpty(request.QueryStringBeforeRewriting[Protocol.BearerTokenEncodedUrlParameterName])) { - return request.QueryStringBeforeRewriting[Protocol.BearerTokenEncodedUrlParameterName]; + var unrewrittenQuery = request.GetQueryStringBeforeRewriting(); + if (!String.IsNullOrEmpty(unrewrittenQuery[Protocol.BearerTokenEncodedUrlParameterName])) { + return unrewrittenQuery[Protocol.BearerTokenEncodedUrlParameterName]; } return null; diff --git a/src/DotNetOpenAuth.OpenId.Provider/OpenId/Provider/OpenIdProvider.cs b/src/DotNetOpenAuth.OpenId.Provider/OpenId/Provider/OpenIdProvider.cs index f7e49f2..72fdc80 100644 --- a/src/DotNetOpenAuth.OpenId.Provider/OpenId/Provider/OpenIdProvider.cs +++ b/src/DotNetOpenAuth.OpenId.Provider/OpenId/Provider/OpenIdProvider.cs @@ -256,7 +256,7 @@ namespace DotNetOpenAuth.OpenId.Provider { /// </remarks> /// <exception cref="ProtocolException">Thrown if the incoming message is recognized /// but deviates from the protocol specification irrecoverably.</exception> - public IRequest GetRequest(HttpRequestInfo httpRequestInfo) { + public IRequest GetRequest(HttpRequestBase httpRequestInfo) { Requires.NotNull(httpRequestInfo, "httpRequestInfo"); IDirectedProtocolMessage incomingMessage = null; @@ -266,7 +266,7 @@ namespace DotNetOpenAuth.OpenId.Provider { // If the incoming request does not resemble an OpenID message at all, // it's probably a user who just navigated to this URL, and we should // just return null so the host can display a message to the user. - if (httpRequestInfo.HttpMethod == "GET" && !httpRequestInfo.UrlBeforeRewriting.QueryStringContainPrefixedParameters(Protocol.Default.openid.Prefix)) { + if (httpRequestInfo.HttpMethod == "GET" && !httpRequestInfo.GetPublicFacingUrl().QueryStringContainPrefixedParameters(Protocol.Default.openid.Prefix)) { return null; } @@ -533,7 +533,7 @@ namespace DotNetOpenAuth.OpenId.Provider { /// <returns> /// Either the <see cref="IRequest"/> to return to the host site or null to indicate no response could be reasonably created and that the caller should rethrow the exception. /// </returns> - private IRequest GetErrorResponse(ProtocolException ex, HttpRequestInfo httpRequestInfo, IDirectedProtocolMessage incomingMessage) { + private IRequest GetErrorResponse(ProtocolException ex, HttpRequestBase httpRequestInfo, IDirectedProtocolMessage incomingMessage) { Requires.NotNull(ex, "ex"); Requires.NotNull(httpRequestInfo, "httpRequestInfo"); diff --git a/src/DotNetOpenAuth.OpenId.RelyingParty.UI/OpenId/RelyingParty/OpenIdMobileTextBox.cs b/src/DotNetOpenAuth.OpenId.RelyingParty.UI/OpenId/RelyingParty/OpenIdMobileTextBox.cs index baf8b44..b3d208a 100644 --- a/src/DotNetOpenAuth.OpenId.RelyingParty.UI/OpenId/RelyingParty/OpenIdMobileTextBox.cs +++ b/src/DotNetOpenAuth.OpenId.RelyingParty.UI/OpenId/RelyingParty/OpenIdMobileTextBox.cs @@ -325,7 +325,7 @@ namespace DotNetOpenAuth.OpenId.RelyingParty { set { if (Page != null && !DesignMode) { // Validate new value by trying to construct a Uri based on it. - new Uri(this.RelyingParty.Channel.GetRequestFromContext().UrlBeforeRewriting, this.Page.ResolveUrl(value)); // throws an exception on failure. + new Uri(this.RelyingParty.Channel.GetRequestFromContext().GetPublicFacingUrl(), this.Page.ResolveUrl(value)); // throws an exception on failure. } else { // We can't fully test it, but it should start with either ~/ or a protocol. if (Regex.IsMatch(value, @"^https?://")) { @@ -603,7 +603,7 @@ namespace DotNetOpenAuth.OpenId.RelyingParty { if (string.IsNullOrEmpty(this.ReturnToUrl)) { this.Request = this.RelyingParty.CreateRequest(userSuppliedIdentifier, typedRealm); } else { - Uri returnTo = new Uri(this.RelyingParty.Channel.GetRequestFromContext().UrlBeforeRewriting, this.ReturnToUrl); + Uri returnTo = new Uri(this.RelyingParty.Channel.GetRequestFromContext().GetPublicFacingUrl(), this.ReturnToUrl); this.Request = this.RelyingParty.CreateRequest(userSuppliedIdentifier, typedRealm, returnTo); } this.Request.Mode = this.ImmediateMode ? AuthenticationRequestMode.Immediate : AuthenticationRequestMode.Setup; @@ -747,7 +747,7 @@ namespace DotNetOpenAuth.OpenId.RelyingParty { Language = this.RequestLanguage, TimeZone = this.RequestTimeZone, PolicyUrl = string.IsNullOrEmpty(this.PolicyUrl) ? - null : new Uri(this.RelyingParty.Channel.GetRequestFromContext().UrlBeforeRewriting, this.Page.ResolveUrl(this.PolicyUrl)), + null : new Uri(this.RelyingParty.Channel.GetRequestFromContext().GetPublicFacingUrl(), this.Page.ResolveUrl(this.PolicyUrl)), }); } diff --git a/src/DotNetOpenAuth.OpenId.RelyingParty.UI/OpenId/RelyingParty/OpenIdRelyingPartyAjaxControlBase.cs b/src/DotNetOpenAuth.OpenId.RelyingParty.UI/OpenId/RelyingParty/OpenIdRelyingPartyAjaxControlBase.cs index acd8c50..34c4df4 100644 --- a/src/DotNetOpenAuth.OpenId.RelyingParty.UI/OpenId/RelyingParty/OpenIdRelyingPartyAjaxControlBase.cs +++ b/src/DotNetOpenAuth.OpenId.RelyingParty.UI/OpenId/RelyingParty/OpenIdRelyingPartyAjaxControlBase.cs @@ -183,11 +183,7 @@ namespace DotNetOpenAuth.OpenId.RelyingParty { if (!string.IsNullOrEmpty(formAuthData) && !string.Equals(viewstateAuthData, formAuthData, StringComparison.Ordinal)) { this.ViewState[AuthDataViewStateKey] = formAuthData; - Uri authUri = new Uri(formAuthData); - HttpRequestInfo clientResponseInfo = new HttpRequestInfo { - UrlBeforeRewriting = authUri, - }; - + HttpRequestBase clientResponseInfo = new HttpRequestInfo("GET", new Uri(formAuthData)); this.authenticationResponse = this.RelyingParty.GetResponse(clientResponseInfo); Logger.Controls.DebugFormat( "The {0} control checked for an authentication response and found: {1}", diff --git a/src/DotNetOpenAuth.OpenId.RelyingParty.UI/OpenId/RelyingParty/OpenIdRelyingPartyControlBase.cs b/src/DotNetOpenAuth.OpenId.RelyingParty.UI/OpenId/RelyingParty/OpenIdRelyingPartyControlBase.cs index dfac2be..c730dea 100644 --- a/src/DotNetOpenAuth.OpenId.RelyingParty.UI/OpenId/RelyingParty/OpenIdRelyingPartyControlBase.cs +++ b/src/DotNetOpenAuth.OpenId.RelyingParty.UI/OpenId/RelyingParty/OpenIdRelyingPartyControlBase.cs @@ -395,7 +395,7 @@ namespace DotNetOpenAuth.OpenId.RelyingParty { set { if (this.Page != null && !this.DesignMode) { // Validate new value by trying to construct a Uri based on it. - new Uri(this.RelyingParty.Channel.GetRequestFromContext().UrlBeforeRewriting, this.Page.ResolveUrl(value)); // throws an exception on failure. + new Uri(this.RelyingParty.Channel.GetRequestFromContext().GetPublicFacingUrl(), this.Page.ResolveUrl(value)); // throws an exception on failure. } else { // We can't fully test it, but it should start with either ~/ or a protocol. if (Regex.IsMatch(value, @"^https?://")) { @@ -919,7 +919,7 @@ namespace DotNetOpenAuth.OpenId.RelyingParty { Uri returnToApproximation; if (this.ReturnToUrl != null) { string returnToResolvedPath = this.ResolveUrl(this.ReturnToUrl); - returnToApproximation = new Uri(this.RelyingParty.Channel.GetRequestFromContext().UrlBeforeRewriting, returnToResolvedPath); + returnToApproximation = new Uri(this.RelyingParty.Channel.GetRequestFromContext().GetPublicFacingUrl(), returnToResolvedPath); } else { returnToApproximation = this.Page.Request.Url; } diff --git a/src/DotNetOpenAuth.OpenId.RelyingParty.UI/OpenId/RelyingParty/OpenIdTextBox.cs b/src/DotNetOpenAuth.OpenId.RelyingParty.UI/OpenId/RelyingParty/OpenIdTextBox.cs index a8af6e0..8ba689f 100644 --- a/src/DotNetOpenAuth.OpenId.RelyingParty.UI/OpenId/RelyingParty/OpenIdTextBox.cs +++ b/src/DotNetOpenAuth.OpenId.RelyingParty.UI/OpenId/RelyingParty/OpenIdTextBox.cs @@ -696,7 +696,7 @@ namespace DotNetOpenAuth.OpenId.RelyingParty { Language = this.RequestLanguage, TimeZone = this.RequestTimeZone, PolicyUrl = string.IsNullOrEmpty(this.PolicyUrl) ? - null : new Uri(this.RelyingParty.Channel.GetRequestFromContext().UrlBeforeRewriting, this.Page.ResolveUrl(this.PolicyUrl)), + null : new Uri(this.RelyingParty.Channel.GetRequestFromContext().GetPublicFacingUrl(), this.Page.ResolveUrl(this.PolicyUrl)), }; // Only actually add the extension request if fields are actually being requested. diff --git a/src/DotNetOpenAuth.OpenId.RelyingParty/OpenId/Interop/OpenIdRelyingPartyShim.cs b/src/DotNetOpenAuth.OpenId.RelyingParty/OpenId/Interop/OpenIdRelyingPartyShim.cs index 97b3780..7fcac91 100644 --- a/src/DotNetOpenAuth.OpenId.RelyingParty/OpenId/Interop/OpenIdRelyingPartyShim.cs +++ b/src/DotNetOpenAuth.OpenId.RelyingParty/OpenId/Interop/OpenIdRelyingPartyShim.cs @@ -6,6 +6,7 @@ namespace DotNetOpenAuth.OpenId.Interop { using System; + using System.Collections.Specialized; using System.Diagnostics.CodeAnalysis; using System.IO; using System.Runtime.InteropServices; @@ -173,12 +174,14 @@ namespace DotNetOpenAuth.OpenId.Interop { /// <param name="form">The form data that may have been included in the case of a POST request.</param> /// <returns>The Provider's response to a previous authentication request, or null if no response is present.</returns> public AuthenticationResponseShim ProcessAuthentication(string url, string form) { - HttpRequestInfo requestInfo = new HttpRequestInfo { UrlBeforeRewriting = new Uri(url) }; + string method = "GET"; + NameValueCollection formMap = null; if (!string.IsNullOrEmpty(form)) { - requestInfo.HttpMethod = "POST"; - requestInfo.InputStream = new MemoryStream(Encoding.Unicode.GetBytes(form)); + method = "POST"; + formMap = HttpUtility.ParseQueryString(form); } + HttpRequestBase requestInfo = new HttpRequestInfo(method, new Uri(url), form: formMap); var response = relyingParty.GetResponse(requestInfo); if (response != null) { return new AuthenticationResponseShim(response); diff --git a/src/DotNetOpenAuth.OpenId.RelyingParty/OpenId/RelyingParty/OpenIdRelyingParty.cs b/src/DotNetOpenAuth.OpenId.RelyingParty/OpenId/RelyingParty/OpenIdRelyingParty.cs index aa53277..6e991d2 100644 --- a/src/DotNetOpenAuth.OpenId.RelyingParty/OpenId/RelyingParty/OpenIdRelyingParty.cs +++ b/src/DotNetOpenAuth.OpenId.RelyingParty/OpenId/RelyingParty/OpenIdRelyingParty.cs @@ -502,12 +502,12 @@ namespace DotNetOpenAuth.OpenId.RelyingParty { ////Contract.Ensures(Contract.ForAll(Contract.Result<IEnumerable<IAuthenticationRequest>>(), el => el != null)); // Build the return_to URL - UriBuilder returnTo = new UriBuilder(this.Channel.GetRequestFromContext().UrlBeforeRewriting); + UriBuilder returnTo = new UriBuilder(this.Channel.GetRequestFromContext().GetPublicFacingUrl()); // Trim off any parameters with an "openid." prefix, and a few known others // to avoid carrying state from a prior login attempt. returnTo.Query = string.Empty; - NameValueCollection queryParams = this.Channel.GetRequestFromContext().QueryStringBeforeRewriting; + NameValueCollection queryParams = this.Channel.GetRequestFromContext().GetQueryStringBeforeRewriting(); var returnToParams = new Dictionary<string, string>(queryParams.Count); foreach (string key in queryParams) { if (!IsOpenIdSupportingParameter(key) && key != null) { @@ -564,7 +564,7 @@ namespace DotNetOpenAuth.OpenId.RelyingParty { /// </summary> /// <param name="httpRequestInfo">The HTTP request that may be carrying an authentication response from the Provider.</param> /// <returns>The processed authentication response if there is any; <c>null</c> otherwise.</returns> - public IAuthenticationResponse GetResponse(HttpRequestInfo httpRequestInfo) { + public IAuthenticationResponse GetResponse(HttpRequestBase httpRequestInfo) { Requires.NotNull(httpRequestInfo, "httpRequestInfo"); try { var message = this.Channel.ReadFromRequest(httpRequestInfo); @@ -619,7 +619,7 @@ namespace DotNetOpenAuth.OpenId.RelyingParty { /// </summary> /// <param name="request">The incoming HTTP request that is expected to carry an OpenID authentication response.</param> /// <returns>The HTTP response to send to this HTTP request.</returns> - public OutgoingWebResponse ProcessResponseFromPopup(HttpRequestInfo request) { + public OutgoingWebResponse ProcessResponseFromPopup(HttpRequestBase request) { Requires.NotNull(request, "request"); Contract.Ensures(Contract.Result<OutgoingWebResponse>() != null); @@ -706,7 +706,7 @@ namespace DotNetOpenAuth.OpenId.RelyingParty { /// The HTTP response to send to this HTTP request. /// </returns> [SuppressMessage("Microsoft.Naming", "CA2204:Literals should be spelled correctly", MessageId = "OpenID", Justification = "real word"), SuppressMessage("Microsoft.Naming", "CA2204:Literals should be spelled correctly", MessageId = "iframe", Justification = "Code contracts")] - internal OutgoingWebResponse ProcessResponseFromPopup(HttpRequestInfo request, Action<AuthenticationStatus> callback) { + internal OutgoingWebResponse ProcessResponseFromPopup(HttpRequestBase request, Action<AuthenticationStatus> callback) { Requires.NotNull(request, "request"); Contract.Ensures(Contract.Result<OutgoingWebResponse>() != null); diff --git a/src/DotNetOpenAuth.OpenId/OpenId/OpenIdUtilities.cs b/src/DotNetOpenAuth.OpenId/OpenId/OpenIdUtilities.cs index 1b848f9..d136289 100644 --- a/src/DotNetOpenAuth.OpenId/OpenId/OpenIdUtilities.cs +++ b/src/DotNetOpenAuth.OpenId/OpenId/OpenIdUtilities.cs @@ -13,6 +13,7 @@ namespace DotNetOpenAuth.OpenId { using System.Linq; using System.Text; using System.Text.RegularExpressions; + using System.Web; using System.Web.UI; using DotNetOpenAuth.Messaging; using DotNetOpenAuth.OpenId.ChannelElements; @@ -115,7 +116,7 @@ namespace DotNetOpenAuth.OpenId { /// <param name="requestContext">The request context.</param> /// <returns>The fully-qualified realm.</returns> [SuppressMessage("Microsoft.Usage", "CA1806:DoNotIgnoreMethodResults", MessageId = "DotNetOpenAuth.OpenId.Realm", Justification = "Using ctor for validation.")] - internal static UriBuilder GetResolvedRealm(Page page, string realm, HttpRequestInfo requestContext) { + internal static UriBuilder GetResolvedRealm(Page page, string realm, HttpRequestBase requestContext) { Requires.NotNull(page, "page"); Requires.NotNull(requestContext, "requestContext"); @@ -134,7 +135,7 @@ namespace DotNetOpenAuth.OpenId { string realmNoWildcard = Regex.Replace(realm, @"^(\w+://)\*\.", matchDelegate); UriBuilder fullyQualifiedRealm = new UriBuilder( - new Uri(requestContext.UrlBeforeRewriting, page.ResolveUrl(realmNoWildcard))); + new Uri(requestContext.GetPublicFacingUrl(), page.ResolveUrl(realmNoWildcard))); if (foundWildcard) { fullyQualifiedRealm.Host = "*." + fullyQualifiedRealm.Host; diff --git a/src/DotNetOpenAuth.OpenId/OpenId/Realm.cs b/src/DotNetOpenAuth.OpenId/OpenId/Realm.cs index 5c2ff8b..d682542 100644 --- a/src/DotNetOpenAuth.OpenId/OpenId/Realm.cs +++ b/src/DotNetOpenAuth.OpenId/OpenId/Realm.cs @@ -126,8 +126,8 @@ namespace DotNetOpenAuth.OpenId { Requires.ValidState(HttpContext.Current != null && HttpContext.Current.Request != null, MessagingStrings.HttpContextRequired); Contract.Ensures(Contract.Result<Realm>() != null); - HttpRequestInfo requestInfo = new HttpRequestInfo(HttpContext.Current.Request); - UriBuilder realmUrl = new UriBuilder(requestInfo.UrlBeforeRewriting); + HttpRequestBase requestInfo = new HttpRequestWrapper(HttpContext.Current.Request); + UriBuilder realmUrl = new UriBuilder(requestInfo.GetPublicFacingUrl()); realmUrl.Path = HttpContext.Current.Request.ApplicationPath; realmUrl.Query = null; realmUrl.Fragment = null; diff --git a/src/DotNetOpenAuth.Test/Messaging/HttpRequestInfoTests.cs b/src/DotNetOpenAuth.Test/Messaging/HttpRequestInfoTests.cs index b2f2b14..fbe1d6b 100644 --- a/src/DotNetOpenAuth.Test/Messaging/HttpRequestInfoTests.cs +++ b/src/DotNetOpenAuth.Test/Messaging/HttpRequestInfoTests.cs @@ -13,25 +13,6 @@ namespace DotNetOpenAuth.Test.Messaging { [TestFixture] public class HttpRequestInfoTests : TestBase { - [Test] - public void CtorDefault() { - HttpRequestInfo info = new HttpRequestInfo(); - Assert.AreEqual("GET", info.HttpMethod); - } - - [Test] - public void CtorRequest() { - HttpRequest request = new HttpRequest("file", "http://someserver?a=b", "a=b"); - ////request.Headers["headername"] = "headervalue"; // PlatformNotSupportedException prevents us mocking this up - HttpRequestInfo info = new HttpRequestInfo(request); - Assert.AreEqual(request.Headers["headername"], info.Headers["headername"]); - Assert.AreEqual(request.Url.Query, info.Query); - Assert.AreEqual(request.QueryString["a"], info.QueryString["a"]); - Assert.AreEqual(request.Url, info.Url); - Assert.AreEqual(request.Url, info.UrlBeforeRewriting); - Assert.AreEqual(request.HttpMethod, info.HttpMethod); - } - // All these tests are ineffective because ServerVariables[] cannot be set. ////[Test] ////public void CtorRequestWithDifferentPublicHttpHost() { @@ -77,21 +58,11 @@ namespace DotNetOpenAuth.Test.Messaging { ////} /// <summary> - /// Checks that a property dependent on another null property - /// doesn't generate a NullReferenceException. - /// </summary> - [Test] - public void QueryBeforeSettingUrl() { - HttpRequestInfo info = new HttpRequestInfo(); - Assert.IsNull(info.Query); - } - - /// <summary> /// Verifies that looking up a querystring variable is gracefully handled without a query in the URL. /// </summary> [Test] public void QueryStringLookupWithoutQuery() { - HttpRequestInfo info = new HttpRequestInfo(); + var info = new HttpRequestInfo("GET", new Uri("http://somehost/somepath")); Assert.IsNull(info.QueryString["hi"]); } @@ -104,7 +75,7 @@ namespace DotNetOpenAuth.Test.Messaging { var serverVariables = new NameValueCollection(); serverVariables["HTTP_X_FORWARDED_PROTO"] = "https"; serverVariables["HTTP_HOST"] = "somehost"; - Uri actual = HttpRequestInfo.GetPublicFacingUrl(req, serverVariables); + Uri actual = new HttpRequestWrapper(req).GetPublicFacingUrl(serverVariables); Uri expected = new Uri("https://somehost/a.aspx?a=b"); Assert.AreEqual(expected, actual); } @@ -118,7 +89,7 @@ namespace DotNetOpenAuth.Test.Messaging { var serverVariables = new NameValueCollection(); serverVariables["HTTP_X_FORWARDED_PROTO"] = "https"; serverVariables["HTTP_HOST"] = "somehost:999"; - Uri actual = HttpRequestInfo.GetPublicFacingUrl(req, serverVariables); + Uri actual = new HttpRequestWrapper(req).GetPublicFacingUrl(serverVariables); Uri expected = new Uri("https://somehost:999/a.aspx?a=b"); Assert.AreEqual(expected, actual); } @@ -131,7 +102,7 @@ namespace DotNetOpenAuth.Test.Messaging { HttpRequest req = new HttpRequest("a.aspx", "http://someinternalhost/a.aspx?a=b", "a=b"); var serverVariables = new NameValueCollection(); serverVariables["HTTP_HOST"] = "somehost"; - Uri actual = HttpRequestInfo.GetPublicFacingUrl(req, serverVariables); + Uri actual = new HttpRequestWrapper(req).GetPublicFacingUrl(serverVariables); Uri expected = new Uri("http://somehost/a.aspx?a=b"); Assert.AreEqual(expected, actual); } @@ -144,7 +115,7 @@ namespace DotNetOpenAuth.Test.Messaging { HttpRequest req = new HttpRequest("a.aspx", "http://someinternalhost/a.aspx?a=b", "a=b"); var serverVariables = new NameValueCollection(); serverVariables["HTTP_HOST"] = "somehost:79"; - Uri actual = HttpRequestInfo.GetPublicFacingUrl(req, serverVariables); + Uri actual = new HttpRequestWrapper(req).GetPublicFacingUrl(serverVariables); Uri expected = new Uri("http://somehost:79/a.aspx?a=b"); Assert.AreEqual(expected, actual); } diff --git a/src/DotNetOpenAuth.Test/Messaging/MessagingTestBase.cs b/src/DotNetOpenAuth.Test/Messaging/MessagingTestBase.cs index e3700b8..b7c0980 100644 --- a/src/DotNetOpenAuth.Test/Messaging/MessagingTestBase.cs +++ b/src/DotNetOpenAuth.Test/Messaging/MessagingTestBase.cs @@ -7,6 +7,7 @@ namespace DotNetOpenAuth.Test { using System; using System.Collections.Generic; + using System.Collections.Specialized; using System.IO; using System.Net; using System.Xml; @@ -19,6 +20,8 @@ namespace DotNetOpenAuth.Test { /// The base class that all messaging test classes inherit from. /// </summary> public class MessagingTestBase : TestBase { + protected internal const string DefaultUrlForHttpRequestInfo = "http://localhost/path"; + internal enum FieldFill { /// <summary> /// An empty dictionary is returned. @@ -53,29 +56,19 @@ namespace DotNetOpenAuth.Test { } internal static HttpRequestInfo CreateHttpRequestInfo(string method, IDictionary<string, string> fields) { - string query = MessagingUtilities.CreateQueryString(fields); - UriBuilder requestUri = new UriBuilder("http://localhost/path"); - WebHeaderCollection headers = new WebHeaderCollection(); - MemoryStream ms = new MemoryStream(); + var requestUri = new UriBuilder(DefaultUrlForHttpRequestInfo); + var headers = new NameValueCollection(); + NameValueCollection form = null; if (method == "POST") { - headers.Add(HttpRequestHeader.ContentType, "application/x-www-form-urlencoded"); - StreamWriter sw = new StreamWriter(ms); - sw.Write(query); - sw.Flush(); - ms.Position = 0; + form = fields.ToNameValueCollection(); + headers.Add(HttpRequestHeaders.ContentType, Channel.HttpFormUrlEncoded); } else if (method == "GET") { - requestUri.Query = query; + requestUri.Query = MessagingUtilities.CreateQueryString(fields); } else { throw new ArgumentOutOfRangeException("method", method, "Expected POST or GET"); } - HttpRequestInfo request = new HttpRequestInfo { - HttpMethod = method, - UrlBeforeRewriting = requestUri.Uri, - Headers = headers, - InputStream = ms, - }; - return request; + return new HttpRequestInfo(method, requestUri.Uri, form: form, headers: headers); } internal static Channel CreateChannel(MessageProtections capabilityAndRecognition) { diff --git a/src/DotNetOpenAuth.Test/Mocks/CoordinatingChannel.cs b/src/DotNetOpenAuth.Test/Mocks/CoordinatingChannel.cs index 8d5295b..10bd59a 100644 --- a/src/DotNetOpenAuth.Test/Mocks/CoordinatingChannel.cs +++ b/src/DotNetOpenAuth.Test/Mocks/CoordinatingChannel.cs @@ -11,6 +11,8 @@ namespace DotNetOpenAuth.Test.Mocks { using System.Linq; using System.Text; using System.Threading; + using System.Web; + using DotNetOpenAuth.Messaging; using DotNetOpenAuth.Messaging.Reflection; using DotNetOpenAuth.Test.OpenId; @@ -146,7 +148,7 @@ namespace DotNetOpenAuth.Test.Mocks { this.incomingMessageSignal.Set(); } - protected internal override HttpRequestInfo GetRequestFromContext() { + protected internal override HttpRequestBase GetRequestFromContext() { MessageReceivingEndpoint recipient; var messageData = this.AwaitIncomingMessage(out recipient); if (messageData != null) { @@ -191,12 +193,13 @@ namespace DotNetOpenAuth.Test.Mocks { return this.PrepareDirectResponse(message); } - protected override IDirectedProtocolMessage ReadFromRequestCore(HttpRequestInfo request) { - if (request.Message != null) { - this.ProcessMessageFilter(request.Message, false); + protected override IDirectedProtocolMessage ReadFromRequestCore(HttpRequestBase request) { + var mockRequest = (CoordinatingHttpRequestInfo)request; + if (mockRequest.Message != null) { + this.ProcessMessageFilter(mockRequest.Message, false); } - return request.Message; + return mockRequest.Message; } protected override IDictionary<string, string> ReadFromResponseCore(IncomingWebResponse response) { diff --git a/src/DotNetOpenAuth.Test/Mocks/CoordinatingHttpRequestInfo.cs b/src/DotNetOpenAuth.Test/Mocks/CoordinatingHttpRequestInfo.cs index bfb9017..1917ce6 100644 --- a/src/DotNetOpenAuth.Test/Mocks/CoordinatingHttpRequestInfo.cs +++ b/src/DotNetOpenAuth.Test/Mocks/CoordinatingHttpRequestInfo.cs @@ -5,15 +5,21 @@ //----------------------------------------------------------------------- namespace DotNetOpenAuth.Test.Mocks { + using System; using System.Collections.Generic; using System.Diagnostics.Contracts; using DotNetOpenAuth.Messaging; internal class CoordinatingHttpRequestInfo : HttpRequestInfo { - private IDictionary<string, string> messageData; - private IMessageFactory messageFactory; - private MessageReceivingEndpoint recipient; - private Channel channel; + private readonly Channel channel; + + private readonly IDictionary<string, string> messageData; + + private readonly IMessageFactory messageFactory; + + private readonly MessageReceivingEndpoint recipient; + + private IDirectedProtocolMessage message; /// <summary> /// Initializes a new instance of the <see cref="CoordinatingHttpRequestInfo"/> class @@ -23,14 +29,18 @@ namespace DotNetOpenAuth.Test.Mocks { /// <param name="messageFactory">The message factory.</param> /// <param name="messageData">The message data.</param> /// <param name="recipient">The recipient.</param> - internal CoordinatingHttpRequestInfo(Channel channel, IMessageFactory messageFactory, IDictionary<string, string> messageData, MessageReceivingEndpoint recipient) + internal CoordinatingHttpRequestInfo( + Channel channel, + IMessageFactory messageFactory, + IDictionary<string, string> messageData, + MessageReceivingEndpoint recipient) : this(recipient) { Contract.Requires(channel != null); Contract.Requires(messageFactory != null); Contract.Requires(messageData != null); this.channel = channel; - this.messageFactory = messageFactory; this.messageData = messageData; + this.messageFactory = messageFactory; } /// <summary> @@ -38,35 +48,56 @@ namespace DotNetOpenAuth.Test.Mocks { /// that will not generate any message. /// </summary> /// <param name="recipient">The recipient.</param> - internal CoordinatingHttpRequestInfo(MessageReceivingEndpoint recipient) { + internal CoordinatingHttpRequestInfo(MessageReceivingEndpoint recipient) + : base(GetHttpVerb(recipient), recipient != null ? recipient.Location : new Uri("http://host/path")) { this.recipient = recipient; - if (recipient != null) { - this.UrlBeforeRewriting = recipient.Location; - } + } - if (recipient == null || (recipient.AllowedMethods & HttpDeliveryMethods.GetRequest) != 0) { - this.HttpMethod = "GET"; - } else if ((recipient.AllowedMethods & HttpDeliveryMethods.PostRequest) != 0) { - this.HttpMethod = "POST"; - } + /// <summary> + /// Initializes a new instance of the <see cref="CoordinatingHttpRequestInfo"/> class. + /// </summary> + /// <param name="message">The message being passed in through a mock transport. May be null.</param> + /// <param name="httpMethod">The HTTP method that the incoming request came in on, whether or not <paramref name="message"/> is null.</param> + internal CoordinatingHttpRequestInfo(IDirectedProtocolMessage message, HttpDeliveryMethods httpMethod) + : base(GetHttpVerb(httpMethod), message.Recipient) { + this.message = message; } - internal override IDirectedProtocolMessage Message { + /// <summary> + /// Gets the message deserialized from the remote channel. + /// </summary> + internal IDirectedProtocolMessage Message { get { - if (base.Message == null && this.messageData != null) { - IDirectedProtocolMessage message = this.messageFactory.GetNewRequestMessage(this.recipient, this.messageData); + if (this.message == null && this.messageData != null) { + var message = messageFactory.GetNewRequestMessage(recipient, this.messageData); if (message != null) { this.channel.MessageDescriptions.GetAccessor(message).Deserialize(this.messageData); + this.message = message; } - base.Message = message; } - return base.Message; + return this.message; + } + } + + private static string GetHttpVerb(MessageReceivingEndpoint recipient) { + if (recipient == null) { + return "GET"; } - set { - base.Message = value; + return GetHttpVerb(recipient.AllowedMethods); + } + + private static string GetHttpVerb(HttpDeliveryMethods httpMethod) { + if ((httpMethod & HttpDeliveryMethods.GetRequest) != 0) { + return "GET"; } + + if ((httpMethod & HttpDeliveryMethods.PostRequest) != 0) { + return "POST"; + } + + throw new ArgumentOutOfRangeException(); } } } diff --git a/src/DotNetOpenAuth.Test/Mocks/CoordinatingOAuthConsumerChannel.cs b/src/DotNetOpenAuth.Test/Mocks/CoordinatingOAuthConsumerChannel.cs index 6cc5819..e145952 100644 --- a/src/DotNetOpenAuth.Test/Mocks/CoordinatingOAuthConsumerChannel.cs +++ b/src/DotNetOpenAuth.Test/Mocks/CoordinatingOAuthConsumerChannel.cs @@ -8,6 +8,8 @@ namespace DotNetOpenAuth.Test.Mocks { using System; using System.Diagnostics.Contracts; using System.Threading; + using System.Web; + using DotNetOpenAuth.Messaging; using DotNetOpenAuth.Messaging.Bindings; using DotNetOpenAuth.OAuth.ChannelElements; @@ -49,7 +51,7 @@ namespace DotNetOpenAuth.Test.Mocks { internal OutgoingWebResponse RequestProtectedResource(AccessProtectedResourceRequest request) { ((ITamperResistantOAuthMessage)request).HttpMethod = this.GetHttpMethod(((ITamperResistantOAuthMessage)request).HttpMethods); this.ProcessOutgoingMessage(request); - HttpRequestInfo requestInfo = this.SpoofHttpMethod(request); + var requestInfo = this.SpoofHttpMethod(request); TestBase.TestLogger.InfoFormat("Sending protected resource request: {0}", requestInfo.Message); // Drop the outgoing message in the other channel's in-slot and let them know it's there. this.RemoteChannel.IncomingMessage = requestInfo.Message; @@ -57,13 +59,13 @@ namespace DotNetOpenAuth.Test.Mocks { return this.AwaitIncomingRawResponse(); } - protected internal override HttpRequestInfo GetRequestFromContext() { + protected internal override HttpRequestBase GetRequestFromContext() { var directedMessage = (IDirectedProtocolMessage)this.AwaitIncomingMessage(); - return new HttpRequestInfo(directedMessage, directedMessage.HttpMethods); + return new CoordinatingHttpRequestInfo(directedMessage, directedMessage.HttpMethods); } protected override IProtocolMessage RequestCore(IDirectedProtocolMessage request) { - HttpRequestInfo requestInfo = this.SpoofHttpMethod(request); + var requestInfo = this.SpoofHttpMethod(request); // Drop the outgoing message in the other channel's in-slot and let them know it's there. this.RemoteChannel.IncomingMessage = requestInfo.Message; this.RemoteChannel.IncomingMessageSignal.Set(); @@ -72,7 +74,7 @@ namespace DotNetOpenAuth.Test.Mocks { } protected override OutgoingWebResponse PrepareDirectResponse(IProtocolMessage response) { - this.RemoteChannel.IncomingMessage = CloneSerializedParts(response, null); + this.RemoteChannel.IncomingMessage = this.CloneSerializedParts(response); this.RemoteChannel.IncomingMessageSignal.Set(); return new OutgoingWebResponse(); // not used, but returning null is not allowed } @@ -82,8 +84,9 @@ namespace DotNetOpenAuth.Test.Mocks { return this.PrepareDirectResponse(message); } - protected override IDirectedProtocolMessage ReadFromRequestCore(HttpRequestInfo request) { - return request.Message; + protected override IDirectedProtocolMessage ReadFromRequestCore(HttpRequestBase request) { + var mockRequest = (CoordinatingHttpRequestInfo)request; + return mockRequest.Message; } /// <summary> @@ -91,19 +94,14 @@ namespace DotNetOpenAuth.Test.Mocks { /// </summary> /// <param name="message">The message to add a pretend HTTP method to.</param> /// <returns>A spoofed HttpRequestInfo that wraps the new message.</returns> - private HttpRequestInfo SpoofHttpMethod(IDirectedProtocolMessage message) { - HttpRequestInfo requestInfo = new HttpRequestInfo(message, message.HttpMethods); - + private CoordinatingHttpRequestInfo SpoofHttpMethod(IDirectedProtocolMessage message) { var signedMessage = message as ITamperResistantOAuthMessage; if (signedMessage != null) { string httpMethod = this.GetHttpMethod(signedMessage.HttpMethods); - requestInfo.HttpMethod = httpMethod; - requestInfo.UrlBeforeRewriting = message.Recipient; signedMessage.HttpMethod = httpMethod; } - requestInfo.Message = this.CloneSerializedParts(message, requestInfo); - + var requestInfo = new CoordinatingHttpRequestInfo(this.CloneSerializedParts(message), message.HttpMethods); return requestInfo; } @@ -121,7 +119,7 @@ namespace DotNetOpenAuth.Test.Mocks { return response; } - private T CloneSerializedParts<T>(T message, HttpRequestInfo requestInfo) where T : class, IProtocolMessage { + private T CloneSerializedParts<T>(T message) where T : class, IProtocolMessage { Requires.NotNull(message, "message"); IProtocolMessage clonedMessage; diff --git a/src/DotNetOpenAuth.Test/Mocks/CoordinatingOAuthServiceProviderChannel.cs b/src/DotNetOpenAuth.Test/Mocks/CoordinatingOAuthServiceProviderChannel.cs index ad5c695..012173c 100644 --- a/src/DotNetOpenAuth.Test/Mocks/CoordinatingOAuthServiceProviderChannel.cs +++ b/src/DotNetOpenAuth.Test/Mocks/CoordinatingOAuthServiceProviderChannel.cs @@ -8,10 +8,13 @@ namespace DotNetOpenAuth.Test.Mocks { using System; using System.Diagnostics.Contracts; using System.Threading; + using System.Web; + using DotNetOpenAuth.Messaging; using DotNetOpenAuth.Messaging.Bindings; using DotNetOpenAuth.OAuth.ChannelElements; using DotNetOpenAuth.OAuth.Messages; + using NUnit.Framework; /// <summary> /// A special channel used in test simulations to pass messages directly between two parties. @@ -48,9 +51,9 @@ namespace DotNetOpenAuth.Test.Mocks { internal CoordinatingOAuthConsumerChannel RemoteChannel { get; set; } internal OutgoingWebResponse RequestProtectedResource(AccessProtectedResourceRequest request) { - ((ITamperResistantOAuthMessage)request).HttpMethod = this.GetHttpMethod(((ITamperResistantOAuthMessage)request).HttpMethods); + ((ITamperResistantOAuthMessage)request).HttpMethod = GetHttpMethod(((ITamperResistantOAuthMessage)request).HttpMethods); this.ProcessOutgoingMessage(request); - HttpRequestInfo requestInfo = this.SpoofHttpMethod(request); + var requestInfo = this.SpoofHttpMethod(request); TestBase.TestLogger.InfoFormat("Sending protected resource request: {0}", requestInfo.Message); // Drop the outgoing message in the other channel's in-slot and let them know it's there. this.RemoteChannel.IncomingMessage = requestInfo.Message; @@ -63,13 +66,13 @@ namespace DotNetOpenAuth.Test.Mocks { this.RemoteChannel.IncomingMessageSignal.Set(); } - protected internal override HttpRequestInfo GetRequestFromContext() { + protected internal override HttpRequestBase GetRequestFromContext() { var directedMessage = (IDirectedProtocolMessage)this.AwaitIncomingMessage(); - return new HttpRequestInfo(directedMessage, directedMessage.HttpMethods); + return new CoordinatingHttpRequestInfo(directedMessage, directedMessage.HttpMethods); } protected override IProtocolMessage RequestCore(IDirectedProtocolMessage request) { - HttpRequestInfo requestInfo = this.SpoofHttpMethod(request); + var requestInfo = this.SpoofHttpMethod(request); // Drop the outgoing message in the other channel's in-slot and let them know it's there. this.RemoteChannel.IncomingMessage = requestInfo.Message; this.RemoteChannel.IncomingMessageSignal.Set(); @@ -78,7 +81,7 @@ namespace DotNetOpenAuth.Test.Mocks { } protected override OutgoingWebResponse PrepareDirectResponse(IProtocolMessage response) { - this.RemoteChannel.IncomingMessage = CloneSerializedParts(response, null); + this.RemoteChannel.IncomingMessage = this.CloneSerializedParts(response); this.RemoteChannel.IncomingMessageSignal.Set(); return new OutgoingWebResponse(); // not used, but returning null is not allowed } @@ -88,8 +91,13 @@ namespace DotNetOpenAuth.Test.Mocks { return this.PrepareDirectResponse(message); } - protected override IDirectedProtocolMessage ReadFromRequestCore(HttpRequestInfo request) { - return request.Message; + protected override IDirectedProtocolMessage ReadFromRequestCore(HttpRequestBase request) { + var mockRequest = (CoordinatingHttpRequestInfo)request; + return mockRequest.Message; + } + + private static string GetHttpMethod(HttpDeliveryMethods methods) { + return (methods & HttpDeliveryMethods.PostRequest) != 0 ? "POST" : "GET"; } /// <summary> @@ -97,24 +105,20 @@ namespace DotNetOpenAuth.Test.Mocks { /// </summary> /// <param name="message">The message to add a pretend HTTP method to.</param> /// <returns>A spoofed HttpRequestInfo that wraps the new message.</returns> - private HttpRequestInfo SpoofHttpMethod(IDirectedProtocolMessage message) { - HttpRequestInfo requestInfo = new HttpRequestInfo(message, message.HttpMethods); - + private CoordinatingHttpRequestInfo SpoofHttpMethod(IDirectedProtocolMessage message) { var signedMessage = message as ITamperResistantOAuthMessage; if (signedMessage != null) { - string httpMethod = this.GetHttpMethod(signedMessage.HttpMethods); - requestInfo.HttpMethod = httpMethod; - requestInfo.UrlBeforeRewriting = message.Recipient; + string httpMethod = GetHttpMethod(signedMessage.HttpMethods); signedMessage.HttpMethod = httpMethod; } - requestInfo.Message = this.CloneSerializedParts(message, requestInfo); - + var requestInfo = new CoordinatingHttpRequestInfo(this.CloneSerializedParts(message), message.HttpMethods); return requestInfo; } private IProtocolMessage AwaitIncomingMessage() { this.IncomingMessageSignal.WaitOne(); + Assert.That(this.IncomingMessage, Is.Not.Null, "Incoming message signaled, but none supplied."); IProtocolMessage response = this.IncomingMessage; this.IncomingMessage = null; return response; @@ -127,7 +131,7 @@ namespace DotNetOpenAuth.Test.Mocks { return response; } - private T CloneSerializedParts<T>(T message, HttpRequestInfo requestInfo) where T : class, IProtocolMessage { + private T CloneSerializedParts<T>(T message) where T : class, IProtocolMessage { Requires.NotNull(message, "message"); IProtocolMessage clonedMessage; @@ -155,9 +159,5 @@ namespace DotNetOpenAuth.Test.Mocks { return (T)clonedMessage; } - - private string GetHttpMethod(HttpDeliveryMethods methods) { - return (methods & HttpDeliveryMethods.PostRequest) != 0 ? "POST" : "GET"; - } } } diff --git a/src/DotNetOpenAuth.Test/Mocks/TestBadChannel.cs b/src/DotNetOpenAuth.Test/Mocks/TestBadChannel.cs index 5344304..263f0fd 100644 --- a/src/DotNetOpenAuth.Test/Mocks/TestBadChannel.cs +++ b/src/DotNetOpenAuth.Test/Mocks/TestBadChannel.cs @@ -7,6 +7,7 @@ namespace DotNetOpenAuth.Test.Mocks { using System; using System.Collections.Generic; + using System.Web; using DotNetOpenAuth.Messaging; /// <summary> @@ -33,7 +34,7 @@ namespace DotNetOpenAuth.Test.Mocks { return base.Receive(fields, recipient); } - internal new IProtocolMessage ReadFromRequest(HttpRequestInfo request) { + internal new IProtocolMessage ReadFromRequest(HttpRequestBase request) { return base.ReadFromRequest(request); } diff --git a/src/DotNetOpenAuth.Test/OAuth/ChannelElements/OAuthChannelTests.cs b/src/DotNetOpenAuth.Test/OAuth/ChannelElements/OAuthChannelTests.cs index dda5452..7999a44 100644 --- a/src/DotNetOpenAuth.Test/OAuth/ChannelElements/OAuthChannelTests.cs +++ b/src/DotNetOpenAuth.Test/OAuth/ChannelElements/OAuthChannelTests.cs @@ -78,23 +78,24 @@ namespace DotNetOpenAuth.Test.OAuth.ChannelElements { [Test] public void ReadFromRequestAuthorizationScattered() { // Start by creating a standard POST HTTP request. - var fields = new Dictionary<string, string> { + var postedFields = new Dictionary<string, string> { { "age", "15" }, }; - HttpRequestInfo requestInfo = CreateHttpRequestInfo(HttpDeliveryMethods.PostRequest, fields); // Now add another field to the request URL - UriBuilder builder = new UriBuilder(requestInfo.UrlBeforeRewriting); + var builder = new UriBuilder(MessagingTestBase.DefaultUrlForHttpRequestInfo); builder.Query = "Name=Andrew"; - requestInfo.UrlBeforeRewriting = builder.Uri; - requestInfo.RawUrl = builder.Path + builder.Query + builder.Fragment; // Finally, add an Authorization header - fields = new Dictionary<string, string> { + var authHeaderFields = new Dictionary<string, string> { { "Location", "http://hostb/pathB" }, { "Timestamp", XmlConvert.ToString(DateTime.UtcNow, XmlDateTimeSerializationMode.Utc) }, }; - requestInfo.Headers.Add(HttpRequestHeader.Authorization, CreateAuthorizationHeader(fields)); + var headers = new NameValueCollection(); + headers.Add(HttpRequestHeaders.Authorization, CreateAuthorizationHeader(authHeaderFields)); + headers.Add(HttpRequestHeaders.ContentType, Channel.HttpFormUrlEncoded); + + var requestInfo = new HttpRequestInfo("POST", builder.Uri, form: postedFields.ToNameValueCollection(), headers: headers); IDirectedProtocolMessage requestMessage = this.channel.ReadFromRequest(requestInfo); @@ -266,51 +267,33 @@ namespace DotNetOpenAuth.Test.OAuth.ChannelElements { } private static HttpRequestInfo CreateHttpRequestInfo(HttpDeliveryMethods scheme, IDictionary<string, string> fields) { - string query = MessagingUtilities.CreateQueryString(fields); - UriBuilder requestUri = new UriBuilder("http://localhost/path"); - WebHeaderCollection headers = new WebHeaderCollection(); - MemoryStream ms = new MemoryStream(); + var requestUri = new UriBuilder(MessagingTestBase.DefaultUrlForHttpRequestInfo); + var headers = new NameValueCollection(); + NameValueCollection form = null; string method; switch (scheme) { case HttpDeliveryMethods.PostRequest: method = "POST"; - headers.Add(HttpRequestHeader.ContentType, "application/x-www-form-urlencoded"); - StreamWriter sw = new StreamWriter(ms); - sw.Write(query); - sw.Flush(); - ms.Position = 0; + form = fields.ToNameValueCollection(); + headers.Add(HttpRequestHeaders.ContentType, Channel.HttpFormUrlEncoded); break; case HttpDeliveryMethods.GetRequest: method = "GET"; - requestUri.Query = query; + requestUri.Query = MessagingUtilities.CreateQueryString(fields); break; case HttpDeliveryMethods.AuthorizationHeaderRequest: method = "GET"; - headers.Add(HttpRequestHeader.Authorization, CreateAuthorizationHeader(fields)); + headers.Add(HttpRequestHeaders.Authorization, CreateAuthorizationHeader(fields)); break; default: throw new ArgumentOutOfRangeException("scheme", scheme, "Unexpected value"); } - HttpRequestInfo request = new HttpRequestInfo { - HttpMethod = method, - UrlBeforeRewriting = requestUri.Uri, - RawUrl = requestUri.Path + requestUri.Query + requestUri.Fragment, - Headers = headers, - InputStream = ms, - }; - return request; + return new HttpRequestInfo(method, requestUri.Uri, form: form, headers: headers); } private static HttpRequestInfo ConvertToRequestInfo(HttpWebRequest request, Stream postEntity) { - HttpRequestInfo info = new HttpRequestInfo { - HttpMethod = request.Method, - UrlBeforeRewriting = request.RequestUri, - RawUrl = request.RequestUri.AbsolutePath + request.RequestUri.Query + request.RequestUri.Fragment, - Headers = request.Headers, - InputStream = postEntity, - }; - return info; + return new HttpRequestInfo(request.Method, request.RequestUri, request.Headers, postEntity); } private void ParameterizedRequestTest(HttpDeliveryMethods scheme) { diff --git a/src/DotNetOpenAuth.Test/OAuth2/UserAgentClientAuthorizeTests.cs b/src/DotNetOpenAuth.Test/OAuth2/UserAgentClientAuthorizeTests.cs index 1f56b32..b00cd8e 100644 --- a/src/DotNetOpenAuth.Test/OAuth2/UserAgentClientAuthorizeTests.cs +++ b/src/DotNetOpenAuth.Test/OAuth2/UserAgentClientAuthorizeTests.cs @@ -39,6 +39,7 @@ namespace DotNetOpenAuth.Test.OAuth2 { }, server => { var request = server.ReadAuthorizationRequest(); + Assert.That(request, Is.Not.Null); server.ApproveAuthorizationRequest(request, ResourceOwnerUsername); var tokenRequest = server.ReadAccessTokenRequest(); IAccessTokenRequest accessTokenRequest = tokenRequest; @@ -70,6 +71,7 @@ namespace DotNetOpenAuth.Test.OAuth2 { }, server => { var request = server.ReadAuthorizationRequest(); + Assert.That(request, Is.Not.Null); IAccessTokenRequest accessTokenRequest = (EndUserAuthorizationImplicitRequest)request; Assert.That(accessTokenRequest.ClientAuthenticated, Is.False); server.ApproveAuthorizationRequest(request, ResourceOwnerUsername); diff --git a/src/DotNetOpenAuth.Test/OAuth2/WebServerClientAuthorizeTests.cs b/src/DotNetOpenAuth.Test/OAuth2/WebServerClientAuthorizeTests.cs index d7439d9..0bb4378 100644 --- a/src/DotNetOpenAuth.Test/OAuth2/WebServerClientAuthorizeTests.cs +++ b/src/DotNetOpenAuth.Test/OAuth2/WebServerClientAuthorizeTests.cs @@ -35,6 +35,7 @@ namespace DotNetOpenAuth.Test.OAuth2 { }, server => { var request = server.ReadAuthorizationRequest(); + Assert.That(request, Is.Not.Null); server.ApproveAuthorizationRequest(request, ResourceOwnerUsername); var tokenRequest = server.ReadAccessTokenRequest(); IAccessTokenRequest accessTokenRequest = tokenRequest; diff --git a/src/DotNetOpenAuth.Test/OpenId/AssociationHandshakeTests.cs b/src/DotNetOpenAuth.Test/OpenId/AssociationHandshakeTests.cs index e8c955e..029447d 100644 --- a/src/DotNetOpenAuth.Test/OpenId/AssociationHandshakeTests.cs +++ b/src/DotNetOpenAuth.Test/OpenId/AssociationHandshakeTests.cs @@ -82,6 +82,7 @@ namespace DotNetOpenAuth.Test.OpenId { // Receive initial request for an HMAC-SHA256 association. AutoResponsiveRequest req = (AutoResponsiveRequest)op.GetRequest(); AssociateRequest associateRequest = (AssociateRequest)req.RequestMessage; + Assert.That(associateRequest, Is.Not.Null); Assert.AreEqual(protocol.Args.SignatureAlgorithm.HMAC_SHA256, associateRequest.AssociationType); // Ensure that the response is a suggestion that the RP try again with HMAC-SHA1 diff --git a/src/DotNetOpenAuth.Test/OpenId/Provider/AuthenticationRequestTest.cs b/src/DotNetOpenAuth.Test/OpenId/Provider/AuthenticationRequestTest.cs index 2819e40..8cc7116 100644 --- a/src/DotNetOpenAuth.Test/OpenId/Provider/AuthenticationRequestTest.cs +++ b/src/DotNetOpenAuth.Test/OpenId/Provider/AuthenticationRequestTest.cs @@ -34,7 +34,7 @@ namespace DotNetOpenAuth.Test.OpenId.Provider { Assert.IsNotNull(userSetupUrl); // Now construct a new request as if it had just come in. - HttpRequestInfo httpRequest = new HttpRequestInfo { UrlBeforeRewriting = userSetupUrl }; + HttpRequestInfo httpRequest = new HttpRequestInfo("GET", userSetupUrl); var setupRequest = (AuthenticationRequest)provider.GetRequest(httpRequest); var setupRequestMessage = (CheckIdRequest)setupRequest.RequestMessage; diff --git a/src/DotNetOpenAuth.Test/OpenId/Provider/OpenIdProviderTests.cs b/src/DotNetOpenAuth.Test/OpenId/Provider/OpenIdProviderTests.cs index d981e71..598aeb7 100644 --- a/src/DotNetOpenAuth.Test/OpenId/Provider/OpenIdProviderTests.cs +++ b/src/DotNetOpenAuth.Test/OpenId/Provider/OpenIdProviderTests.cs @@ -92,10 +92,9 @@ namespace DotNetOpenAuth.Test.OpenId.Provider { /// </summary> [Test] public void GetRequest() { - HttpRequestInfo httpInfo = new HttpRequestInfo(); - httpInfo.UrlBeforeRewriting = new Uri("http://someUri"); + var httpInfo = new HttpRequestInfo("GET", new Uri("http://someUri")); Assert.IsNull(this.provider.GetRequest(httpInfo), "An irrelevant request should return null."); - var providerDescription = new ProviderEndpointDescription(OpenIdTestBase.OPUri, Protocol.Default.Version); + var providerDescription = new ProviderEndpointDescription(OPUri, Protocol.Default.Version); // Test some non-empty request scenario. OpenIdCoordinator coordinator = new OpenIdCoordinator( diff --git a/src/DotNetOpenAuth.Test/OpenId/Provider/PerformanceTests.cs b/src/DotNetOpenAuth.Test/OpenId/Provider/PerformanceTests.cs index 27e65cc..e2c719d 100644 --- a/src/DotNetOpenAuth.Test/OpenId/Provider/PerformanceTests.cs +++ b/src/DotNetOpenAuth.Test/OpenId/Provider/PerformanceTests.cs @@ -102,7 +102,7 @@ namespace DotNetOpenAuth.Test.OpenId.Provider { ms.Position = 0; var headers = new WebHeaderCollection(); headers.Add(HttpRequestHeader.ContentType, Channel.HttpFormUrlEncoded); - var httpRequest = new HttpRequestInfo("POST", opEndpoint, opEndpoint.PathAndQuery, headers, ms); + var httpRequest = new HttpRequestInfo("POST", opEndpoint, headers, ms); return httpRequest; } @@ -122,8 +122,7 @@ namespace DotNetOpenAuth.Test.OpenId.Provider { Channel rpChannel = rp.Channel; UriBuilder receiver = new UriBuilder(OPUri); receiver.Query = MessagingUtilities.CreateQueryString(rpChannel.MessageDescriptions.GetAccessor(checkidMessage)); - var headers = new WebHeaderCollection(); - var httpRequest = new HttpRequestInfo("GET", receiver.Uri, receiver.Uri.PathAndQuery, headers, null); + var httpRequest = new HttpRequestInfo("GET", receiver.Uri); return httpRequest; } } |