//-----------------------------------------------------------------------
//
// Copyright (c) Outercurve Foundation. All rights reserved.
//
//-----------------------------------------------------------------------
namespace DotNetOpenAuth.Test.Messaging {
using System;
using System.Collections.Generic;
using System.IO;
using System.Net;
using System.Net.Http;
using System.Threading;
using System.Threading.Tasks;
using System.Web;
using DotNetOpenAuth.Messaging;
using DotNetOpenAuth.Messaging.Bindings;
using DotNetOpenAuth.OpenId;
using DotNetOpenAuth.Test.Mocks;
using NUnit.Framework;
[TestFixture]
public class ChannelTests : MessagingTestBase {
[Test, ExpectedException(typeof(ArgumentNullException))]
public void CtorNullFirstParameter() {
new TestBadChannel(null, new IChannelBindingElement[0], new DefaultOpenIdHostFactories());
}
[Test, ExpectedException(typeof(ArgumentNullException))]
public void CtorNullSecondParameter() {
new TestBadChannel(new TestMessageFactory(), null, new DefaultOpenIdHostFactories());
}
[Test, ExpectedException(typeof(ArgumentNullException))]
public void CtorNullThirdParameter() {
new TestBadChannel(new TestMessageFactory(), new IChannelBindingElement[0], null);
}
[Test]
public async Task ReadFromRequestQueryString() {
await this.ParameterizedReceiveTestAsync(HttpMethod.Get);
}
[Test]
public async Task ReadFromRequestForm() {
await this.ParameterizedReceiveTestAsync(HttpMethod.Post);
}
///
/// Verifies compliance to OpenID 2.0 section 5.1.1 by verifying the channel
/// will reject messages that come with an unexpected HTTP verb.
///
[Test, ExpectedException(typeof(ProtocolException))]
public async Task ReadFromRequestDisallowedHttpMethod() {
var fields = GetStandardTestFields(FieldFill.CompleteBeforeBindings);
fields["GetOnly"] = "true";
await this.Channel.ReadFromRequestAsync(CreateHttpRequestInfo(HttpMethod.Post, fields), CancellationToken.None);
}
[Test, ExpectedException(typeof(ArgumentNullException))]
public async Task SendNull() {
await this.Channel.PrepareResponseAsync(null);
}
[Test, ExpectedException(typeof(ArgumentException))]
public async Task SendIndirectedUndirectedMessage() {
IProtocolMessage message = new TestDirectedMessage(MessageTransport.Indirect);
await this.Channel.PrepareResponseAsync(message);
}
[Test, ExpectedException(typeof(ArgumentException))]
public async Task SendDirectedNoRecipientMessage() {
IProtocolMessage message = new TestDirectedMessage(MessageTransport.Indirect);
await this.Channel.PrepareResponseAsync(message);
}
[Test, ExpectedException(typeof(ArgumentException))]
public async Task SendInvalidMessageTransport() {
IProtocolMessage message = new TestDirectedMessage((MessageTransport)100);
await this.Channel.PrepareResponseAsync(message);
}
[Test]
public async Task SendIndirectMessage301Get() {
TestDirectedMessage message = new TestDirectedMessage(MessageTransport.Indirect);
GetStandardTestMessage(FieldFill.CompleteBeforeBindings, message);
message.Recipient = new Uri("http://provider/path");
var expected = GetStandardTestFields(FieldFill.CompleteBeforeBindings);
var response = await this.Channel.PrepareResponseAsync(message);
Assert.AreEqual(HttpStatusCode.Redirect, response.StatusCode);
Assert.AreEqual("text/html; charset=utf-8", response.Content.Headers.ContentType.ToString());
Assert.IsTrue(response.Content != null && response.Content.Headers.ContentLength > 0); // a non-empty body helps get passed filters like WebSense
StringAssert.StartsWith("http://provider/path", response.Headers.Location.AbsoluteUri);
foreach (var pair in expected) {
string key = MessagingUtilities.EscapeUriDataStringRfc3986(pair.Key);
string value = MessagingUtilities.EscapeUriDataStringRfc3986(pair.Value);
string substring = string.Format("{0}={1}", key, value);
StringAssert.Contains(substring, response.Headers.Location.AbsoluteUri);
}
}
[Test, ExpectedException(typeof(ArgumentNullException))]
public void SendIndirectMessage301GetNullMessage() {
TestBadChannel badChannel = new TestBadChannel();
badChannel.Create301RedirectResponse(null, new Dictionary());
}
[Test, ExpectedException(typeof(ArgumentException))]
public void SendIndirectMessage301GetEmptyRecipient() {
TestBadChannel badChannel = new TestBadChannel();
var message = new TestDirectedMessage(MessageTransport.Indirect);
badChannel.Create301RedirectResponse(message, new Dictionary());
}
[Test, ExpectedException(typeof(ArgumentNullException))]
public void SendIndirectMessage301GetNullFields() {
TestBadChannel badChannel = new TestBadChannel();
var message = new TestDirectedMessage(MessageTransport.Indirect);
message.Recipient = new Uri("http://someserver");
badChannel.Create301RedirectResponse(message, null);
}
[Test]
public async Task SendIndirectMessageFormPost() {
// We craft a very large message to force fallback to form POST.
// We'll also stick some HTML reserved characters in the string value
// to test proper character escaping.
var message = new TestDirectedMessage(MessageTransport.Indirect) {
Age = 15,
Name = "c", body);
StringAssert.Contains("", body);
StringAssert.Contains("", body);
StringAssert.Contains(".submit()", body, "There should be some javascript to automate form submission.");
}
[Test, ExpectedException(typeof(ArgumentNullException))]
public void SendIndirectMessageFormPostNullMessage() {
TestBadChannel badChannel = new TestBadChannel();
badChannel.CreateFormPostResponse(null, new Dictionary());
}
[Test, ExpectedException(typeof(ArgumentException))]
public void SendIndirectMessageFormPostEmptyRecipient() {
TestBadChannel badChannel = new TestBadChannel();
var message = new TestDirectedMessage(MessageTransport.Indirect);
badChannel.CreateFormPostResponse(message, new Dictionary());
}
[Test, ExpectedException(typeof(ArgumentNullException))]
public void SendIndirectMessageFormPostNullFields() {
TestBadChannel badChannel = new TestBadChannel();
var message = new TestDirectedMessage(MessageTransport.Indirect);
message.Recipient = new Uri("http://someserver");
badChannel.CreateFormPostResponse(message, null);
}
///
/// Tests that a direct message is sent when the appropriate message type is provided.
///
///
/// Since this is a mock channel that doesn't actually formulate a direct message response,
/// we just check that the right method was called.
///
[Test, ExpectedException(typeof(NotImplementedException))]
public async Task SendDirectMessageResponse() {
IProtocolMessage message = new TestDirectedMessage {
Age = 15,
Name = "Andrew",
Location = new Uri("http://host/path"),
};
await this.Channel.PrepareResponseAsync(message);
}
[Test, ExpectedException(typeof(ArgumentNullException))]
public void SendIndirectMessageNull() {
TestBadChannel badChannel = new TestBadChannel();
badChannel.PrepareIndirectResponse(null);
}
[Test, ExpectedException(typeof(ArgumentNullException))]
public void ReceiveNull() {
TestBadChannel badChannel = new TestBadChannel();
badChannel.Receive(null, null);
}
[Test]
public void ReceiveUnrecognizedMessage() {
TestBadChannel badChannel = new TestBadChannel();
Assert.IsNull(badChannel.Receive(new Dictionary(), null));
}
[Test]
public async Task ReadFromRequestWithContext() {
var fields = GetStandardTestFields(FieldFill.AllRequired);
TestMessage expectedMessage = GetStandardTestMessage(FieldFill.AllRequired);
HttpRequest request = new HttpRequest("somefile", "http://someurl", MessagingUtilities.CreateQueryString(fields));
HttpContext.Current = new HttpContext(request, new HttpResponse(new StringWriter()));
var requestBase = this.Channel.GetRequestFromContext();
IProtocolMessage message = await this.Channel.ReadFromRequestAsync(requestBase.AsHttpRequestMessage(), CancellationToken.None);
Assert.IsNotNull(message);
Assert.IsInstanceOf(message);
Assert.AreEqual(expectedMessage.Age, ((TestMessage)message).Age);
}
[Test, ExpectedException(typeof(InvalidOperationException))]
public void GetRequestFromContextNoContext() {
HttpContext.Current = null;
TestBadChannel badChannel = new TestBadChannel();
badChannel.GetRequestFromContext();
}
[Test, ExpectedException(typeof(ArgumentNullException))]
public async Task ReadFromRequestNull() {
TestBadChannel badChannel = new TestBadChannel();
await badChannel.ReadFromRequestAsync(null, CancellationToken.None);
}
[Test]
public async Task SendReplayProtectedMessageSetsNonce() {
TestReplayProtectedMessage message = new TestReplayProtectedMessage(MessageTransport.Indirect);
message.Recipient = new Uri("http://localtest");
this.Channel = CreateChannel(MessageProtections.ReplayProtection);
await this.Channel.PrepareResponseAsync(message);
Assert.IsNotNull(((IReplayProtectedProtocolMessage)message).Nonce);
}
[Test, ExpectedException(typeof(InvalidSignatureException))]
public async Task ReceivedInvalidSignature() {
this.Channel = CreateChannel(MessageProtections.TamperProtection);
await this.ParameterizedReceiveProtectedTestAsync(DateTime.UtcNow, true);
}
[Test]
public async Task ReceivedReplayProtectedMessageJustOnce() {
this.Channel = CreateChannel(MessageProtections.ReplayProtection);
await this.ParameterizedReceiveProtectedTestAsync(DateTime.UtcNow, false);
}
[Test, ExpectedException(typeof(ReplayedMessageException))]
public async Task ReceivedReplayProtectedMessageTwice() {
this.Channel = CreateChannel(MessageProtections.ReplayProtection);
await this.ParameterizedReceiveProtectedTestAsync(DateTime.UtcNow, false);
await this.ParameterizedReceiveProtectedTestAsync(DateTime.UtcNow, false);
}
[Test, ExpectedException(typeof(ProtocolException))]
public void MessageExpirationWithoutTamperResistance() {
new TestChannel(
new TestMessageFactory(),
new IChannelBindingElement[] { new StandardExpirationBindingElement() },
new DefaultOpenIdHostFactories());
}
[Test, ExpectedException(typeof(ProtocolException))]
public async Task TooManyBindingElementsProvidingSameProtection() {
Channel channel = new TestChannel(
new TestMessageFactory(),
new IChannelBindingElement[] { new MockSigningBindingElement(), new MockSigningBindingElement() },
new DefaultOpenIdHostFactories());
await channel.ProcessOutgoingMessageTestHookAsync(new TestSignedDirectedMessage());
}
[Test]
public void BindingElementsOrdering() {
IChannelBindingElement transformA = new MockTransformationBindingElement("a");
IChannelBindingElement transformB = new MockTransformationBindingElement("b");
IChannelBindingElement sign = new MockSigningBindingElement();
IChannelBindingElement replay = new MockReplayProtectionBindingElement();
IChannelBindingElement expire = new StandardExpirationBindingElement();
Channel channel = new TestChannel(
new TestMessageFactory(),
new[] { sign, replay, expire, transformB, transformA },
new DefaultOpenIdHostFactories());
Assert.AreEqual(5, channel.BindingElements.Count);
Assert.AreSame(transformB, channel.BindingElements[0]);
Assert.AreSame(transformA, channel.BindingElements[1]);
Assert.AreSame(replay, channel.BindingElements[2]);
Assert.AreSame(expire, channel.BindingElements[3]);
Assert.AreSame(sign, channel.BindingElements[4]);
}
[Test, ExpectedException(typeof(UnprotectedMessageException))]
public async Task InsufficientlyProtectedMessageSent() {
var message = new TestSignedDirectedMessage(MessageTransport.Direct);
message.Recipient = new Uri("http://localtest");
await this.Channel.PrepareResponseAsync(message);
}
[Test, ExpectedException(typeof(UnprotectedMessageException))]
public async Task InsufficientlyProtectedMessageReceived() {
this.Channel = CreateChannel(MessageProtections.None, MessageProtections.TamperProtection);
await this.ParameterizedReceiveProtectedTestAsync(DateTime.Now, false);
}
[Test, ExpectedException(typeof(ProtocolException))]
public async Task IncomingMessageMissingRequiredParameters() {
var fields = GetStandardTestFields(FieldFill.IdentifiableButNotAllRequired);
await this.Channel.ReadFromRequestAsync(CreateHttpRequestInfo(HttpMethod.Get, fields), CancellationToken.None);
}
}
}