//-----------------------------------------------------------------------
//
//     Copyright (c) Outercurve Foundation. All rights reserved.
//
//-----------------------------------------------------------------------

namespace DotNetOpenAuth.Test {
    using System;
    using System.Collections.Generic;
    using System.Collections.Specialized;
    using System.Globalization;
    using System.IO;
    using System.Net;
    using System.Net.Http;
    using System.Threading;
    using System.Threading.Tasks;
    using System.Xml;
    using DotNetOpenAuth.Messaging;
    using DotNetOpenAuth.Messaging.Bindings;
    using DotNetOpenAuth.Test.Mocks;
    using NUnit.Framework;

    /// <summary>
    /// The base class that all messaging test classes inherit from.
    /// </summary>
    public class MessagingTestBase : TestBase {
        /// <summary>
        /// The URL used for every request built by <see cref="CreateHttpRequestInfo"/>.
        /// </summary>
        protected internal const string DefaultUrlForHttpRequestInfo = "http://localhost/path";

        /// <summary>
        /// How completely a standard test message (or its field dictionary) should be populated.
        /// </summary>
        /// <remarks>
        /// The members are declared in increasing order of completeness; the helpers in this
        /// class compare against them with <c>&gt;=</c>, so the declaration order is significant.
        /// </remarks>
        internal enum FieldFill {
            /// <summary>
            /// An empty dictionary is returned.
            /// </summary>
            None,

            /// <summary>
            /// Only enough fields for the message factory
            /// to identify the message are included.
            /// </summary>
            IdentifiableButNotAllRequired,

            /// <summary>
            /// All fields marked as required are included.
            /// </summary>
            AllRequired,

            /// <summary>
            /// All user-fillable fields in the message, leaving out those whose
            /// values are to be set by channel binding elements.
            /// </summary>
            CompleteBeforeBindings,
        }

        /// <summary>
        /// Gets or sets the channel used by the parameterized receive tests.
        /// </summary>
        internal Channel Channel { get; set; }

        /// <summary>
        /// Runs the base class setup and then assigns a fresh <see cref="TestChannel"/>
        /// to <see cref="Channel"/>.
        /// </summary>
        [SetUp]
        public override void SetUp() {
            base.SetUp();

            this.Channel = new TestChannel(this.HostFactories);
        }

        /// <summary>
        /// Builds an HTTP request aimed at <see cref="DefaultUrlForHttpRequestInfo"/> that carries
        /// the given fields in the form-encoded body (POST) or in the query string (GET).
        /// </summary>
        /// <param name="method">The HTTP method; only POST and GET are accepted.</param>
        /// <param name="fields">The name/value pairs to transmit.</param>
        /// <returns>The request message.</returns>
        /// <exception cref="ArgumentOutOfRangeException">
        /// Thrown when <paramref name="method"/> is neither POST nor GET.
        /// </exception>
        internal static HttpRequestMessage CreateHttpRequestInfo(HttpMethod method, IDictionary<string, string> fields) {
            bool isPost = method == HttpMethod.Post;
            bool isGet = method == HttpMethod.Get;

            // Reject unsupported methods before allocating the (disposable) request message.
            if (!isPost && !isGet) {
                throw new ArgumentOutOfRangeException("method", method, "Expected POST or GET");
            }

            var result = new HttpRequestMessage() { Method = method };
            var requestUri = new UriBuilder(DefaultUrlForHttpRequestInfo);
            if (isPost) {
                result.Content = new FormUrlEncodedContent(fields);
            } else {
                requestUri.AppendQueryArgs(fields);
            }

            result.RequestUri = requestUri.Uri;
            return result;
        }

        /// <summary>
        /// Gets the serialized name/value pairs that correspond to the message produced by
        /// <see cref="GetStandardTestMessage(FieldFill)"/> for the same fill level.
        /// </summary>
        /// <param name="fill">How many of the fields to include.</param>
        /// <returns>A new dictionary holding the requested fields.</returns>
        internal static IDictionary<string, string> GetStandardTestFields(FieldFill fill) {
            TestMessage expectedMessage = GetStandardTestMessage(fill);

            var fields = new Dictionary<string, string>();
            if (fill >= FieldFill.IdentifiableButNotAllRequired) {
                // Wire values must not depend on the culture of the machine running the tests.
                fields.Add("age", expectedMessage.Age.ToString(CultureInfo.InvariantCulture));
            }

            if (fill >= FieldFill.AllRequired) {
                fields.Add("Timestamp", XmlConvert.ToString(expectedMessage.Timestamp, XmlDateTimeSerializationMode.Utc));
            }

            if (fill >= FieldFill.CompleteBeforeBindings) {
                fields.Add("Name", expectedMessage.Name);
                fields.Add("Location", expectedMessage.Location.AbsoluteUri);
            }

            return fields;
        }

        /// <summary>
        /// Creates a new <see cref="TestDirectedMessage"/> and populates it to the requested level.
        /// </summary>
        /// <param name="fill">How many of the fields to populate.</param>
        /// <returns>The populated message.</returns>
        internal static TestMessage GetStandardTestMessage(FieldFill fill) {
            TestMessage message = new TestDirectedMessage();
            GetStandardTestMessage(fill, message);
            return message;
        }

        /// <summary>
        /// Populates an existing message with the standard test values, to the requested level.
        /// </summary>
        /// <param name="fill">How many of the fields to populate.</param>
        /// <param name="message">The message to populate.</param>
        /// <exception cref="ArgumentNullException">
        /// Thrown when <paramref name="message"/> is <c>null</c>.
        /// </exception>
        internal static void GetStandardTestMessage(FieldFill fill, TestMessage message) {
            if (message == null) {
                throw new ArgumentNullException("message");
            }

            if (fill >= FieldFill.IdentifiableButNotAllRequired) {
                message.Age = 15;
            }

            if (fill >= FieldFill.AllRequired) {
                // "HH" (24-hour clock) is the correct specifier for a time with no AM/PM designator;
                // for "08:00" it yields the same value the 12-hour "hh" specifier did.
                message.Timestamp = DateTime.ParseExact("09/09/2008 08:00", "dd/MM/yyyy HH:mm", CultureInfo.InvariantCulture);
            }

            if (fill >= FieldFill.CompleteBeforeBindings) {
                message.Name = "Andrew";
                message.Location = new Uri("http://localtest/path");
            }
        }

        /// <summary>
        /// Creates a test channel whose binding elements and message recognition are both
        /// driven by the same protection level.
        /// </summary>
        /// <param name="capabilityAndRecognition">
        /// The protection level used for both the binding elements and message recognition.
        /// </param>
        /// <returns>The new channel.</returns>
        internal Channel CreateChannel(MessageProtections capabilityAndRecognition) {
            return this.CreateChannel(capabilityAndRecognition, capabilityAndRecognition);
        }

        /// <summary>
        /// Creates a test channel with mock/standard binding elements selected by
        /// <paramref name="capability"/> and a <see cref="TestMessageFactory"/> configured
        /// by <paramref name="recognition"/>.
        /// </summary>
        /// <param name="capability">Selects which binding elements are added to the channel.</param>
        /// <param name="recognition">
        /// Selects the signing/expiration/replay flags handed to the message factory.
        /// </param>
        /// <returns>The new channel.</returns>
        internal Channel CreateChannel(MessageProtections capability, MessageProtections recognition) {
            // NOTE(review): element type restored after extraction stripped it; presumably
            // IChannelBindingElement, the type TestChannel's constructor takes - confirm.
            var bindingElements = new List<IChannelBindingElement>();
            if (capability >= MessageProtections.TamperProtection) {
                bindingElements.Add(new MockSigningBindingElement());
            }

            if (capability >= MessageProtections.Expiration) {
                bindingElements.Add(new StandardExpirationBindingElement());
            }

            if (capability >= MessageProtections.ReplayProtection) {
                bindingElements.Add(new MockReplayProtectionBindingElement());
            }

            bool signing = recognition >= MessageProtections.TamperProtection;
            bool expiration = recognition >= MessageProtections.Expiration;
            bool replay = recognition >= MessageProtections.ReplayProtection;

            var typeProvider = new TestMessageFactory(signing, expiration, replay);
            return new TestChannel(typeProvider, bindingElements.ToArray(), this.HostFactories);
        }

        /// <summary>
        /// Sends the complete set of standard test fields to <see cref="Channel"/> using the given
        /// HTTP method and asserts that the message read back carries the expected
        /// Age, Name and Location.
        /// </summary>
        /// <param name="method">The HTTP method used to carry the fields (POST or GET).</param>
        /// <returns>A task that completes when the assertions have run.</returns>
        internal async Task ParameterizedReceiveTestAsync(HttpMethod method) {
            var fields = GetStandardTestFields(FieldFill.CompleteBeforeBindings);
            TestMessage expectedMessage = GetStandardTestMessage(FieldFill.CompleteBeforeBindings);

            IDirectedProtocolMessage requestMessage = await this.Channel.ReadFromRequestAsync(CreateHttpRequestInfo(method, fields), CancellationToken.None);
            Assert.IsNotNull(requestMessage);
            Assert.IsInstanceOf<TestMessage>(requestMessage);

            TestMessage actualMessage = (TestMessage)requestMessage;
            Assert.AreEqual(expectedMessage.Age, actualMessage.Age);
            Assert.AreEqual(expectedMessage.Name, actualMessage.Name);
            Assert.AreEqual(expectedMessage.Location, actualMessage.Location);
        }

        /// <summary>
        /// Sends the standard test fields plus "Signature", "Nonce" and (optionally) "created_on"
        /// to <see cref="Channel"/> in a GET request and asserts that the message read back is a
        /// <see cref="TestSignedDirectedMessage"/> carrying the expected values.
        /// </summary>
        /// <param name="utcCreatedDate">
        /// The creation date to transmit, or <c>null</c> to omit the "created_on" field.
        /// </param>
        /// <param name="invalidSignature">
        /// <c>true</c> to send the signature "badsig" instead of
        /// <see cref="MockSigningBindingElement.MessageSignature"/>.
        /// </param>
        /// <returns>A task that completes when the assertions have run.</returns>
        internal async Task ParameterizedReceiveProtectedTestAsync(DateTime? utcCreatedDate, bool invalidSignature) {
            TestMessage expectedMessage = GetStandardTestMessage(FieldFill.CompleteBeforeBindings);
            var fields = GetStandardTestFields(FieldFill.CompleteBeforeBindings);
            fields.Add("Signature", invalidSignature ? "badsig" : MockSigningBindingElement.MessageSignature);
            fields.Add("Nonce", "someNonce");
            if (utcCreatedDate.HasValue) {
                // Round off the sub-second part so comparisons work later. Truncating ticks
                // avoids a culture-dependent ToString()/Parse() round trip.
                DateTime universal = utcCreatedDate.Value.ToUniversalTime();
                utcCreatedDate = new DateTime(universal.Ticks - (universal.Ticks % TimeSpan.TicksPerSecond), DateTimeKind.Utc);
                fields.Add("created_on", XmlConvert.ToString(utcCreatedDate.Value, XmlDateTimeSerializationMode.Utc));
            }

            IProtocolMessage requestMessage = await this.Channel.ReadFromRequestAsync(CreateHttpRequestInfo(HttpMethod.Get, fields), CancellationToken.None);
            Assert.IsNotNull(requestMessage);
            Assert.IsInstanceOf<TestSignedDirectedMessage>(requestMessage);

            TestSignedDirectedMessage actualMessage = (TestSignedDirectedMessage)requestMessage;
            Assert.AreEqual(expectedMessage.Age, actualMessage.Age);
            Assert.AreEqual(expectedMessage.Name, actualMessage.Name);
            Assert.AreEqual(expectedMessage.Location, actualMessage.Location);
            if (utcCreatedDate.HasValue) {
                IExpiringProtocolMessage expiringMessage = (IExpiringProtocolMessage)requestMessage;
                Assert.AreEqual(utcCreatedDate.Value, expiringMessage.UtcCreationDate);
            }
        }
    }
}