//-----------------------------------------------------------------------
// <copyright file="MessagingTestBase.cs" company="Andrew Arnott">
//     Copyright (c) Andrew Arnott. All rights reserved.
// </copyright>
//-----------------------------------------------------------------------

namespace DotNetOpenAuth.Test {
	using System;
	using System.Collections.Generic;
	using System.Globalization;
	using System.IO;
	using System.Net;
	using System.Text;
	using System.Xml;
	using DotNetOpenAuth.Messaging;
	using DotNetOpenAuth.Messaging.Bindings;
	using DotNetOpenAuth.Test.Mocks;
	using Microsoft.VisualStudio.TestTools.UnitTesting;

	/// <summary>
	/// The base class that all messaging test classes inherit from.
	/// </summary>
	public class MessagingTestBase : TestBase {
		/// <summary>
		/// Controls how many fields of the standard test message are populated by
		/// <see cref="GetStandardTestFields"/> and <see cref="GetStandardTestMessage(FieldFill)"/>.
		/// </summary>
		/// <remarks>
		/// Callers compare members with <c>&gt;=</c>, so declaration order matters:
		/// each member includes every field populated by the members declared before it.
		/// </remarks>
		internal enum FieldFill {
			/// <summary>
			/// An empty dictionary is returned.
			/// </summary>
			None,

			/// <summary>
			/// Only enough fields for the <see cref="TestMessageFactory"/>
			/// to identify the message are included.
			/// </summary>
			IdentifiableButNotAllRequired,

			/// <summary>
			/// All fields marked as required are included.
			/// </summary>
			AllRequired,

			/// <summary>
			/// All user-fillable fields in the message, leaving out those whose
			/// values are to be set by channel binding elements.
			/// </summary>
			CompleteBeforeBindings,
		}

		/// <summary>
		/// Gets or sets the channel under test.
		/// <see cref="SetUp"/> replaces it with a new <see cref="TestChannel"/> before each test.
		/// </summary>
		internal Channel Channel { get; set; }

		/// <summary>
		/// Per-test initialization: runs the base setup, then creates a fresh <see cref="TestChannel"/>.
		/// </summary>
		[TestInitialize]
		public override void SetUp() {
			base.SetUp();

			this.Channel = new TestChannel();
		}

		/// <summary>
		/// Builds a simulated incoming HTTP request that carries the given fields.
		/// </summary>
		/// <param name="method">
		/// "GET" to carry the fields in the query string, or "POST" to carry them
		/// in an application/x-www-form-urlencoded body. Compared case-sensitively.
		/// </param>
		/// <param name="fields">The message fields to encode into the request.</param>
		/// <returns>A request addressed to http://localhost/path.</returns>
		/// <exception cref="ArgumentOutOfRangeException">
		/// Thrown when <paramref name="method"/> is neither "GET" nor "POST".
		/// </exception>
		internal static HttpRequestInfo CreateHttpRequestInfo(string method, IDictionary<string, string> fields) {
			string query = MessagingUtilities.CreateQueryString(fields);
			UriBuilder requestUri = new UriBuilder("http://localhost/path");
			WebHeaderCollection headers = new WebHeaderCollection();
			MemoryStream ms = new MemoryStream();
			if (method == "POST") {
				headers.Add(HttpRequestHeader.ContentType, "application/x-www-form-urlencoded");

				// Write the body as UTF-8 bytes directly (GetBytes emits no byte-order mark),
				// so no writer object is left undisposed around the stream we hand out.
				byte[] body = Encoding.UTF8.GetBytes(query);
				ms.Write(body, 0, body.Length);
				ms.Position = 0;
			} else if (method == "GET") {
				requestUri.Query = query;
			} else {
				throw new ArgumentOutOfRangeException("method", method, "Expected POST or GET");
			}

			HttpRequestInfo request = new HttpRequestInfo {
				HttpMethod = method,
				Url = requestUri.Uri,
				Headers = headers,
				InputStream = ms,
			};

			return request;
		}

		/// <summary>
		/// Creates a test channel whose binding-element capabilities and whose
		/// message factory's protection expectations are the same level.
		/// </summary>
		/// <param name="capabilityAndRecognition">The protection level used for both capability and recognition.</param>
		/// <returns>The configured channel.</returns>
		internal static Channel CreateChannel(MessageProtections capabilityAndRecognition) {
			return CreateChannel(capabilityAndRecognition, capabilityAndRecognition);
		}

		/// <summary>
		/// Creates a test channel with independently chosen binding-element capabilities
		/// and message-factory protection expectations.
		/// </summary>
		/// <param name="capability">
		/// Which binding elements the channel gets: at or above TamperProtection adds a mock signer,
		/// at or above Expiration adds the standard expiration element, and at or above
		/// ReplayProtection adds a mock replay-protection element.
		/// </param>
		/// <param name="recognition">
		/// The thresholds that decide the signing, expiration and replay flags passed to
		/// <see cref="TestMessageFactory"/>.
		/// </param>
		/// <returns>The configured channel.</returns>
		internal static Channel CreateChannel(MessageProtections capability, MessageProtections recognition) {
			var bindingElements = new List<IChannelBindingElement>();
			if (capability >= MessageProtections.TamperProtection) {
				bindingElements.Add(new MockSigningBindingElement());
			}
			if (capability >= MessageProtections.Expiration) {
				bindingElements.Add(new StandardExpirationBindingElement());
			}
			if (capability >= MessageProtections.ReplayProtection) {
				bindingElements.Add(new MockReplayProtectionBindingElement());
			}

			bool signing = recognition >= MessageProtections.TamperProtection;
			bool expiration = recognition >= MessageProtections.Expiration;
			bool replay = recognition >= MessageProtections.ReplayProtection;

			var typeProvider = new TestMessageFactory(signing, expiration, replay);
			return new TestChannel(typeProvider, bindingElements.ToArray());
		}

		/// <summary>
		/// Gets the wire-format field dictionary matching the standard test message.
		/// </summary>
		/// <param name="fill">How many fields to include.</param>
		/// <returns>
		/// A new dictionary. It contains "age" from IdentifiableButNotAllRequired, "Timestamp"
		/// from AllRequired, and "Name" and "Location" from CompleteBeforeBindings.
		/// </returns>
		internal static IDictionary<string, string> GetStandardTestFields(FieldFill fill) {
			TestMessage expectedMessage = GetStandardTestMessage(fill);

			var fields = new Dictionary<string, string>();
			if (fill >= FieldFill.IdentifiableButNotAllRequired) {
				// Invariant culture: this is wire data and must not vary with the machine's locale.
				fields.Add("age", expectedMessage.Age.ToString(CultureInfo.InvariantCulture));
			}
			if (fill >= FieldFill.AllRequired) {
				fields.Add("Timestamp", XmlConvert.ToString(expectedMessage.Timestamp, XmlDateTimeSerializationMode.Utc));
			}
			if (fill >= FieldFill.CompleteBeforeBindings) {
				fields.Add("Name", expectedMessage.Name);
				fields.Add("Location", expectedMessage.Location.AbsoluteUri);
			}

			return fields;
		}

		/// <summary>
		/// Creates a new <see cref="TestDirectedMessage"/> populated with the standard test values.
		/// </summary>
		/// <param name="fill">How many fields to populate.</param>
		/// <returns>The populated message.</returns>
		internal static TestMessage GetStandardTestMessage(FieldFill fill) {
			TestMessage message = new TestDirectedMessage();
			GetStandardTestMessage(fill, message);
			return message;
		}

		/// <summary>
		/// Populates an existing message with the standard test values, in place.
		/// </summary>
		/// <param name="fill">How many fields to populate.</param>
		/// <param name="message">The message to fill in.</param>
		/// <exception cref="ArgumentNullException">Thrown when <paramref name="message"/> is null.</exception>
		internal static void GetStandardTestMessage(FieldFill fill, TestMessage message) {
			if (message == null) {
				throw new ArgumentNullException("message");
			}

			if (fill >= FieldFill.IdentifiableButNotAllRequired) {
				message.Age = 15;
			}
			if (fill >= FieldFill.AllRequired) {
				// Built from components rather than parsed, because parsing "9/19/2008"
				// depends on the current culture's date order and fails in day-first locales.
				message.Timestamp = new DateTime(2008, 9, 19, 8, 0, 0, DateTimeKind.Utc);
			}
			if (fill >= FieldFill.CompleteBeforeBindings) {
				message.Name = "Andrew";
				message.Location = new Uri("http://localtest/path");
			}
		}

		/// <summary>
		/// Sends the standard (complete-before-bindings) fields through <see cref="Channel"/>
		/// using the given HTTP method. Asserts that a <see cref="TestMessage"/> comes back with
		/// the expected Age, Name and Location.
		/// </summary>
		/// <param name="method">"GET" or "POST".</param>
		internal void ParameterizedReceiveTest(string method) {
			var fields = GetStandardTestFields(FieldFill.CompleteBeforeBindings);
			TestMessage expectedMessage = GetStandardTestMessage(FieldFill.CompleteBeforeBindings);

			IDirectedProtocolMessage requestMessage = this.Channel.ReadFromRequest(CreateHttpRequestInfo(method, fields));
			Assert.IsNotNull(requestMessage);
			Assert.IsInstanceOfType(requestMessage, typeof(TestMessage));
			TestMessage actualMessage = (TestMessage)requestMessage;
			Assert.AreEqual(expectedMessage.Age, actualMessage.Age);
			Assert.AreEqual(expectedMessage.Name, actualMessage.Name);
			Assert.AreEqual(expectedMessage.Location, actualMessage.Location);
		}

		/// <summary>
		/// Sends the standard fields through <see cref="Channel"/> via GET, adding a signature,
		/// a nonce and optionally a creation date. Asserts that a
		/// <see cref="TestSignedDirectedMessage"/> comes back with the expected values.
		/// </summary>
		/// <param name="utcCreatedDate">
		/// The creation date to send as "created_on", or null to omit it. The value is converted
		/// to UTC and truncated to whole seconds before being sent and compared.
		/// </param>
		/// <param name="invalidSignature">True to send "badsig" instead of the mock signer's valid signature.</param>
		internal void ParameterizedReceiveProtectedTest(DateTime? utcCreatedDate, bool invalidSignature) {
			TestMessage expectedMessage = GetStandardTestMessage(FieldFill.CompleteBeforeBindings);
			var fields = GetStandardTestFields(FieldFill.CompleteBeforeBindings);
			fields.Add("Signature", invalidSignature ? "badsig" : MockSigningBindingElement.MessageSignature);
			fields.Add("Nonce", "someNonce");
			if (utcCreatedDate.HasValue) {
				// Drop sub-second precision so the comparison below works. Tick arithmetic is used
				// instead of a culture-dependent ToString/Parse round trip, which can misread the
				// value in some locales.
				DateTime utc = utcCreatedDate.Value.ToUniversalTime();
				utcCreatedDate = new DateTime(utc.Ticks - (utc.Ticks % TimeSpan.TicksPerSecond), DateTimeKind.Utc);
				fields.Add("created_on", XmlConvert.ToString(utcCreatedDate.Value, XmlDateTimeSerializationMode.Utc));
			}

			IProtocolMessage requestMessage = this.Channel.ReadFromRequest(CreateHttpRequestInfo("GET", fields));
			Assert.IsNotNull(requestMessage);
			Assert.IsInstanceOfType(requestMessage, typeof(TestSignedDirectedMessage));
			TestSignedDirectedMessage actualMessage = (TestSignedDirectedMessage)requestMessage;
			Assert.AreEqual(expectedMessage.Age, actualMessage.Age);
			Assert.AreEqual(expectedMessage.Name, actualMessage.Name);
			Assert.AreEqual(expectedMessage.Location, actualMessage.Location);
			if (utcCreatedDate.HasValue) {
				IExpiringProtocolMessage expiringMessage = (IExpiringProtocolMessage)requestMessage;
				Assert.AreEqual(utcCreatedDate.Value, expiringMessage.UtcCreationDate);
			}
		}
	}
}