//-----------------------------------------------------------------------
//
// Copyright (c) Outercurve Foundation. All rights reserved.
//
//-----------------------------------------------------------------------
namespace DotNetOpenAuth.Test.OAuth.ChannelElements {
using System;
using System.Collections.Generic;
using System.Collections.Specialized;
using System.IO;
using System.Net;
using System.Net.Http;
using System.Text;
using System.Threading;
using System.Threading.Tasks;
using System.Web;
using System.Xml;
using DotNetOpenAuth.Messaging;
using DotNetOpenAuth.Messaging.Bindings;
using DotNetOpenAuth.Messaging.Reflection;
using DotNetOpenAuth.OAuth.ChannelElements;
using DotNetOpenAuth.Test.Mocks;
using NUnit.Framework;
using Validation;
[TestFixture]
public class OAuthChannelTests : TestBase {
// The channel under test; assigned in SetUp.
private OAuthChannel channel;
// The signing binding element that SetUp passes to the channel constructor.
private SigningBindingElementBase signingElement;
// The nonce store that SetUp creates and that the constructor tests reuse.
private INonceStore nonceStore;
// Security settings built from the OAuth configuration element; passed to every channel constructed in this fixture.
private DotNetOpenAuth.OAuth.ServiceProviderSecuritySettings serviceProviderSecuritySettings = DotNetOpenAuth.Configuration.OAuthElement.Configuration.ServiceProvider.SecuritySettings.CreateSecuritySettings();
/// <summary>
/// Builds the signing element, the nonce store and the service provider channel
/// that the tests in this fixture operate on.
/// </summary>
[SetUp]
public override void SetUp() {
    base.SetUp();

    var signerTokenManager = new InMemoryTokenManager();
    this.signingElement = new RsaSha1ServiceProviderSigningBindingElement(signerTokenManager);

    var maximumMessageAge = StandardExpirationBindingElement.MaximumMessageAge;
    this.nonceStore = new MemoryNonceStore(maximumMessageAge);

    // The channel gets its own token manager instance, separate from the signer's.
    var channelTokenManager = new InMemoryTokenManager();
    var messageFactory = new TestMessageFactory();
    this.channel = new OAuthServiceProviderChannel(
        this.signingElement,
        this.nonceStore,
        channelTokenManager,
        this.serviceProviderSecuritySettings,
        messageFactory,
        this.HostFactories);
}
/// <summary>
/// Verifies that constructing the channel with a null signing binding element
/// raises an <see cref="ArgumentException"/>.
/// </summary>
[Test, ExpectedException(typeof(ArgumentException))]
public void CtorNullSigner() {
    var tokenManager = new InMemoryTokenManager();
    var messageFactory = new TestMessageFactory();

    new OAuthServiceProviderChannel(
        null,
        this.nonceStore,
        tokenManager,
        this.serviceProviderSecuritySettings,
        messageFactory);
}
/// <summary>
/// Verifies that constructing the channel with a null nonce store
/// raises an <see cref="ArgumentNullException"/>.
/// </summary>
[Test, ExpectedException(typeof(ArgumentNullException))]
public void CtorNullStore() {
    var signer = new RsaSha1ServiceProviderSigningBindingElement(new InMemoryTokenManager());
    var tokenManager = new InMemoryTokenManager();
    var messageFactory = new TestMessageFactory();

    new OAuthServiceProviderChannel(
        signer,
        null,
        tokenManager,
        this.serviceProviderSecuritySettings,
        messageFactory);
}
/// <summary>
/// Verifies that constructing the channel with a null token manager
/// raises an <see cref="ArgumentNullException"/>.
/// </summary>
[Test, ExpectedException(typeof(ArgumentNullException))]
public void CtorNullTokenManager() {
    var signer = new RsaSha1ServiceProviderSigningBindingElement(new InMemoryTokenManager());
    var messageFactory = new TestMessageFactory();

    new OAuthServiceProviderChannel(
        signer,
        this.nonceStore,
        null,
        this.serviceProviderSecuritySettings,
        messageFactory);
}
/// <summary>
/// Verifies that the channel can be constructed through the overload that takes
/// no message factory, given a service provider token manager.
/// </summary>
[Test]
public void CtorSimpleServiceProvider() {
    var signer = new RsaSha1ServiceProviderSigningBindingElement(new InMemoryTokenManager());
    IServiceProviderTokenManager tokenManager = new InMemoryTokenManager();

    new OAuthServiceProviderChannel(
        signer,
        this.nonceStore,
        tokenManager,
        this.serviceProviderSecuritySettings);
}
/// <summary>
/// Runs the shared receive test with the message carried in the Authorization header.
/// </summary>
[Test]
public async Task ReadFromRequestAuthorization() {
    var deliveryMethod = HttpDeliveryMethods.AuthorizationHeaderRequest;
    await this.ParameterizedReceiveTestAsync(deliveryMethod);
}
/// <summary>
/// Verifies that the OAuth ReadFromRequest method gathers parameters
/// from the Authorization header, the query string and the entity form data.
/// </summary>
[Test]
public async Task ReadFromRequestAuthorizationScattered() {
    // Start by creating a standard POST HTTP request.
    var postedFields = new Dictionary<string, string> {
        { "age", "15" },
    };

    // Now add another field to the request URL
    var builder = new UriBuilder(MessagingTestBase.DefaultUrlForHttpRequestInfo);
    builder.Query = "Name=Andrew";

    // Finally, add an Authorization header
    var authHeaderFields = new Dictionary<string, string> {
        { "Location", "http://hostb/pathB" },
        { "Timestamp", XmlConvert.ToString(DateTime.UtcNow, XmlDateTimeSerializationMode.Utc) },
    };
    var headers = new NameValueCollection();
    headers.Add(HttpRequestHeaders.Authorization, CreateAuthorizationHeader(authHeaderFields));
    headers.Add(HttpRequestHeaders.ContentType, Channel.HttpFormUrlEncoded);

    var requestInfo = new HttpRequestInfo("POST", builder.Uri, form: postedFields.ToNameValueCollection(), headers: headers);
    IDirectedProtocolMessage requestMessage = await this.channel.ReadFromRequestAsync(requestInfo.AsHttpRequestMessage(), CancellationToken.None);

    // Each asserted value arrived through a different part of the request:
    // age from the form, Name from the query string, Location from the header.
    Assert.IsNotNull(requestMessage);
    Assert.IsInstanceOf<TestMessage>(requestMessage);
    TestMessage testMessage = (TestMessage)requestMessage;
    Assert.AreEqual(15, testMessage.Age);
    Assert.AreEqual("Andrew", testMessage.Name);
    Assert.AreEqual("http://hostb/pathB", testMessage.Location.AbsoluteUri);
}
/// <summary>
/// Runs the shared receive test with the message carried in a POST form body.
/// </summary>
[Test]
public async Task ReadFromRequestForm() {
    var deliveryMethod = HttpDeliveryMethods.PostRequest;
    await this.ParameterizedReceiveTestAsync(deliveryMethod);
}
/// <summary>
/// Runs the shared receive test with the message carried in the query string of a GET request.
/// </summary>
[Test]
public async Task ReadFromRequestQueryString() {
    var deliveryMethod = HttpDeliveryMethods.GetRequest;
    await this.ParameterizedReceiveTestAsync(deliveryMethod);
}
/// <summary>
/// Verifies that a prepared direct response has status 200, the form-url-encoded
/// media type, and a body containing the message's parts.
/// </summary>
[Test]
public async Task SendDirectMessageResponse() {
    var outgoing = new TestDirectedMessage();
    outgoing.Age = 15;
    outgoing.Name = "Andrew";
    outgoing.Location = new Uri("http://hostb/pathB");
    IProtocolMessage message = outgoing;

    var response = await this.channel.PrepareResponseAsync(message);

    Assert.AreEqual(HttpStatusCode.OK, response.StatusCode);
    var expectedMediaType = Channel.HttpFormUrlEncodedContentType.MediaType;
    Assert.AreEqual(expectedMediaType, response.Content.Headers.ContentType.MediaType);

    // The body is form-url-encoded, so it can be parsed like a query string.
    string responseText = await response.Content.ReadAsStringAsync();
    NameValueCollection parsedBody = HttpUtility.ParseQueryString(responseText);
    Assert.AreEqual("15", parsedBody["age"]);
    Assert.AreEqual("Andrew", parsedBody["Name"]);
    Assert.AreEqual("http://hostb/pathB", parsedBody["Location"]);
}
/// <summary>
/// Verifies that a form-url-encoded response body is deserialized into the same
/// set of key/value pairs that was written into it.
/// </summary>
[Test]
public async Task ReadFromResponse() {
    var fields = new Dictionary<string, string> {
        { "age", "15" },
        { "Name", "Andrew" },
        { "Location", "http://hostb/pathB" },
        { "Timestamp", XmlConvert.ToString(DateTime.UtcNow, XmlDateTimeSerializationMode.Utc) },
    };

    // Serialize the fields into an in-memory stream and rewind it so it can be read back.
    // The writer is deliberately not disposed: disposing it would close the stream
    // before the channel gets to read it.
    MemoryStream ms = new MemoryStream();
    StreamWriter writer = new StreamWriter(ms);
    writer.Write(MessagingUtilities.CreateQueryString(fields));
    writer.Flush();
    ms.Seek(0, SeekOrigin.Begin);

    // Disposing the response also disposes its StreamContent and the underlying stream.
    using (var response = new HttpResponseMessage { Content = new StreamContent(ms) }) {
        IDictionary<string, string> deserializedFields = await this.channel.ReadFromResponseCoreAsyncTestHook(
            response,
            CancellationToken.None);

        Assert.AreEqual(fields.Count, deserializedFields.Count);
        foreach (string key in fields.Keys) {
            Assert.AreEqual(fields[key], deserializedFields[key]);
        }
    }
}
/// <summary>
/// Verifies that sending a null request raises an <see cref="ArgumentNullException"/>.
/// </summary>
[Test, ExpectedException(typeof(ArgumentNullException))]
public async Task RequestNull() {
    var cancellation = CancellationToken.None;
    await this.channel.RequestAsync(null, cancellation);
}
/// <summary>
/// Verifies that sending a direct message on which no recipient has been set
/// raises an <see cref="ArgumentException"/>.
/// </summary>
[Test, ExpectedException(typeof(ArgumentException))]
public async Task RequestNullRecipient() {
    IDirectedProtocolMessage requestWithoutRecipient = new TestDirectedMessage(MessageTransport.Direct);
    var cancellation = CancellationToken.None;

    await this.channel.RequestAsync(requestWithoutRecipient, cancellation);
}
/// <summary>
/// Verifies that sending a message whose HTTP delivery methods are set to
/// <c>None</c> raises a <see cref="NotSupportedException"/>.
/// </summary>
[Test, ExpectedException(typeof(NotSupportedException))]
public async Task RequestBadPreferredScheme() {
    var message = new TestDirectedMessage(MessageTransport.Direct) {
        Recipient = new Uri("http://localtest"),
        HttpMethods = HttpDeliveryMethods.None,
    };

    await this.channel.RequestAsync(message, CancellationToken.None);
}
/// <summary>
/// Runs the shared request round-trip test using the Authorization header delivery method.
/// </summary>
[Test]
public async Task RequestUsingAuthorizationHeader() {
    var deliveryMethod = HttpDeliveryMethods.AuthorizationHeaderRequest;
    await this.ParameterizedRequestTestAsync(deliveryMethod);
}
/// <summary>
/// Verifies that message parts can be distributed to the query, form, and Authorization header.
/// </summary>
[Test]
public async Task RequestUsingAuthorizationHeaderScattered() {
    TestDirectedMessage request = new TestDirectedMessage(MessageTransport.Direct) {
        Age = 15,
        Name = "Andrew",
        Location = new Uri("http://hostb/pathB"),
        Recipient = new Uri("http://localtest"),
        Timestamp = DateTime.UtcNow,
        HttpMethods = HttpDeliveryMethods.AuthorizationHeaderRequest,
    };

    // ExtraData should appear in the form since this is a POST request,
    // and only official message parts get a place in the Authorization header.
    ((IProtocolMessage)request).ExtraData["appearinform"] = "formish";

    // Replace the initial recipient and delivery methods so the request carries
    // a query string and is allowed to use a POST entity.
    request.Recipient = new Uri("http://localhost/?appearinquery=queryish");
    request.HttpMethods = HttpDeliveryMethods.AuthorizationHeaderRequest | HttpDeliveryMethods.PostRequest;

    var webRequest = await this.channel.InitializeRequestAsync(request, CancellationToken.None);
    Assert.IsNotNull(webRequest);
    Assert.AreEqual(HttpMethod.Post, webRequest.Method);
    Assert.AreEqual(request.Recipient, webRequest.RequestUri);

    // The declared message parts are expected in the Authorization header, in this order.
    var declaredParts = new Dictionary<string, string> {
        { "age", request.Age.ToString() },
        { "Name", request.Name },
        { "Location", request.Location.AbsoluteUri },
        { "Timestamp", XmlConvert.ToString(request.Timestamp, XmlDateTimeSerializationMode.Utc) },
    };
    Assert.AreEqual(CreateAuthorizationHeader(declaredParts), webRequest.Headers.Authorization.ToString());
    Assert.AreEqual("appearinform=formish", await webRequest.Content.ReadAsStringAsync());
}
/// <summary>
/// Runs the shared request round-trip test using the GET delivery method.
/// </summary>
[Test]
public async Task RequestUsingGet() {
    var deliveryMethod = HttpDeliveryMethods.GetRequest;
    await this.ParameterizedRequestTestAsync(deliveryMethod);
}
/// <summary>
/// Runs the shared request round-trip test using the POST delivery method.
/// </summary>
[Test]
public async Task RequestUsingPost() {
    var deliveryMethod = HttpDeliveryMethods.PostRequest;
    await this.ParameterizedRequestTestAsync(deliveryMethod);
}
/// <summary>
/// Runs the shared request round-trip test using the HEAD delivery method.
/// </summary>
[Test]
public async Task RequestUsingHead() {
    var deliveryMethod = HttpDeliveryMethods.HeadRequest;
    await this.ParameterizedRequestTestAsync(deliveryMethod);
}
/// <summary>
/// Verifies that a direct response carries status 200 for a standard test message, and
/// carries the status code requested by a message that specifies one.
/// </summary>
[Test]
public void SendDirectMessageResponseHonorsHttpStatusCodes() {
    // A standard test message yields a 200 response.
    IProtocolMessage plainMessage = MessagingTestBase.GetStandardTestMessage(MessagingTestBase.FieldFill.AllRequired);
    var defaultResponse = this.channel.PrepareDirectResponseTestHook(plainMessage);
    Assert.AreEqual(HttpStatusCode.OK, defaultResponse.StatusCode);

    // A message that asks for a particular status code gets that code on the response.
    var statusMessage = new TestDirectResponseMessageWithHttpStatus();
    MessagingTestBase.GetStandardTestMessage(MessagingTestBase.FieldFill.AllRequired, statusMessage);
    statusMessage.HttpStatusCode = HttpStatusCode.NotAcceptable;
    var customResponse = this.channel.PrepareDirectResponseTestHook(statusMessage);
    Assert.AreEqual(HttpStatusCode.NotAcceptable, customResponse.StatusCode);
}
/// <summary>
/// Builds an OAuth Authorization header value from a set of name/value pairs.
/// </summary>
/// <param name="fields">The name/value pairs to write into the header. Must not be null.</param>
/// <returns>
/// The text "OAuth " followed by comma-separated <c>key="value"</c> entries, with each
/// key and value escaped by <see cref="Uri.EscapeDataString(string)"/>.
/// </returns>
private static string CreateAuthorizationHeader(IDictionary<string, string> fields) {
    Requires.NotNull(fields, "fields");

    var authorization = new StringBuilder("OAuth ");
    foreach (var pair in fields) {
        authorization.Append(Uri.EscapeDataString(pair.Key));
        authorization.Append("=\"");
        authorization.Append(Uri.EscapeDataString(pair.Value));
        authorization.Append("\",");
    }

    // Drop the comma left after the last pair. With no pairs at all, this trims
    // the space after "OAuth" instead.
    authorization.Length--;
    return authorization.ToString();
}
/// <summary>
/// Creates an incoming request that carries the given fields using a single delivery method.
/// </summary>
/// <param name="scheme">
/// Where to place the fields: <c>PostRequest</c> (form body), <c>GetRequest</c> (query string)
/// or <c>AuthorizationHeaderRequest</c> (Authorization header on a GET).
/// </param>
/// <param name="fields">The name/value pairs to place in the request.</param>
/// <returns>The request info addressed to <c>MessagingTestBase.DefaultUrlForHttpRequestInfo</c>.</returns>
/// <exception cref="ArgumentOutOfRangeException">Thrown when <paramref name="scheme"/> is any other value.</exception>
private static HttpRequestInfo CreateHttpRequestInfo(HttpDeliveryMethods scheme, IDictionary<string, string> fields) {
    var requestUri = new UriBuilder(MessagingTestBase.DefaultUrlForHttpRequestInfo);
    var headers = new NameValueCollection();
    NameValueCollection form = null; // only populated for POST
    string method;
    switch (scheme) {
        case HttpDeliveryMethods.PostRequest:
            method = "POST";
            form = fields.ToNameValueCollection();
            headers.Add(HttpRequestHeaders.ContentType, Channel.HttpFormUrlEncoded);
            break;
        case HttpDeliveryMethods.GetRequest:
            method = "GET";
            requestUri.Query = MessagingUtilities.CreateQueryString(fields);
            break;
        case HttpDeliveryMethods.AuthorizationHeaderRequest:
            method = "GET";
            headers.Add(HttpRequestHeaders.Authorization, CreateAuthorizationHeader(fields));
            break;
        default:
            throw new ArgumentOutOfRangeException("scheme", scheme, "Unexpected value");
    }

    return new HttpRequestInfo(method, requestUri.Uri, form: form, headers: headers);
}
/// <summary>
/// Wraps an outgoing <see cref="HttpWebRequest"/> and its entity stream in an <c>HttpRequestInfo</c>.
/// </summary>
/// <param name="request">The web request whose method, URI and headers are copied.</param>
/// <param name="postEntity">The entity stream to pass along with the request.</param>
/// <returns>The newly created request info.</returns>
private static HttpRequestInfo ConvertToRequestInfo(HttpWebRequest request, Stream postEntity) {
    string httpMethod = request.Method;
    Uri requestUri = request.RequestUri;
    WebHeaderCollection requestHeaders = request.Headers;
    return new HttpRequestInfo(httpMethod, requestUri, requestHeaders, postEntity);
}
/// <summary>
/// Sends a test message through the channel using the given delivery method, checks that the
/// mock handler receives the same message, and checks the response the channel hands back.
/// </summary>
/// <param name="scheme">The HTTP delivery method the outgoing message should use.</param>
/// <returns>A task that completes when the round trip has been verified.</returns>
private async Task ParameterizedRequestTestAsync(HttpDeliveryMethods scheme) {
    var request = new TestDirectedMessage(MessageTransport.Direct) {
        Age = 15,
        Name = "Andrew",
        Location = new Uri("http://hostb/pathB"),
        Recipient = new Uri("http://localtest"),
        Timestamp = DateTime.UtcNow,
        HttpMethods = scheme,
    };

    // Register a handler for the recipient URL that checks the incoming request
    // and answers with the same field values.
    Handle(request.Recipient).By(
        async (req, ct) => {
            Assert.IsNotNull(req);
            Assert.AreEqual(MessagingUtilities.GetHttpVerb(scheme), req.Method);
            var incomingMessage = (await this.channel.ReadFromRequestAsync(req, CancellationToken.None)) as TestMessage;
            Assert.IsNotNull(incomingMessage);
            Assert.AreEqual(request.Age, incomingMessage.Age);
            Assert.AreEqual(request.Name, incomingMessage.Name);
            Assert.AreEqual(request.Location, incomingMessage.Location);
            Assert.AreEqual(request.Timestamp, incomingMessage.Timestamp);

            var responseFields = new Dictionary<string, string> {
                { "age", request.Age.ToString() },
                { "Name", request.Name },
                { "Location", request.Location.AbsoluteUri },
                { "Timestamp", XmlConvert.ToString(request.Timestamp, XmlDateTimeSerializationMode.Utc) },
            };
            var rawResponse = new HttpResponseMessage();
            rawResponse.Content = new StringContent(MessagingUtilities.CreateQueryString(responseFields));
            return rawResponse;
        });

    IProtocolMessage response = await this.channel.RequestAsync(request, CancellationToken.None);
    Assert.IsNotNull(response);
    Assert.IsInstanceOf<TestMessage>(response);
    TestMessage responseMessage = (TestMessage)response;
    Assert.AreEqual(request.Age, responseMessage.Age);
    Assert.AreEqual(request.Name, responseMessage.Name);
    Assert.AreEqual(request.Location, responseMessage.Location);
}
/// <summary>
/// Builds an incoming request that carries a fixed set of fields using the given delivery
/// method, reads it through the channel, and checks the resulting message.
/// </summary>
/// <param name="scheme">The HTTP delivery method used to carry the fields in the request.</param>
/// <returns>A task that completes when the received message has been verified.</returns>
private async Task ParameterizedReceiveTestAsync(HttpDeliveryMethods scheme) {
    var fields = new Dictionary<string, string> {
        { "age", "15" },
        { "Name", "Andrew" },
        { "Location", "http://hostb/pathB" },
        { "Timestamp", XmlConvert.ToString(DateTime.UtcNow, XmlDateTimeSerializationMode.Utc) },
        { "realm", "someValue" },
    };

    IProtocolMessage requestMessage = await this.channel.ReadFromRequestAsync(CreateHttpRequestInfo(scheme, fields).AsHttpRequestMessage(), CancellationToken.None);
    Assert.IsNotNull(requestMessage);
    Assert.IsInstanceOf<TestMessage>(requestMessage);
    TestMessage testMessage = (TestMessage)requestMessage;
    Assert.AreEqual(15, testMessage.Age);
    Assert.AreEqual("Andrew", testMessage.Name);
    Assert.AreEqual("http://hostb/pathB", testMessage.Location.AbsoluteUri);

    if (scheme == HttpDeliveryMethods.AuthorizationHeaderRequest) {
        // The realm value should be ignored in the authorization header
        Assert.IsFalse(((IMessage)testMessage).ExtraData.ContainsKey("realm"));
    } else {
        // For the other delivery methods, realm is expected to come through as extra data.
        Assert.AreEqual("someValue", ((IMessage)testMessage).ExtraData["realm"]);
    }
}
}
}