diff options
7 files changed, 417 insertions, 1 deletions
diff --git a/src/DotNetOpenAuth.Test/DotNetOpenAuth.Test.csproj b/src/DotNetOpenAuth.Test/DotNetOpenAuth.Test.csproj index 6540714..a0c849a 100644 --- a/src/DotNetOpenAuth.Test/DotNetOpenAuth.Test.csproj +++ b/src/DotNetOpenAuth.Test/DotNetOpenAuth.Test.csproj @@ -196,6 +196,7 @@ <Compile Include="Messaging\Bindings\StandardExpirationBindingElementTests.cs" /> <Compile Include="Messaging\Reflection\MessagePartTests.cs" /> <Compile Include="Messaging\Reflection\ValueMappingTests.cs" /> + <Compile Include="Messaging\StandardMessageFactoryTests.cs" /> <Compile Include="Mocks\AssociateUnencryptedRequestNoSslCheck.cs" /> <Compile Include="Mocks\CoordinatingChannel.cs" /> <Compile Include="Mocks\CoordinatingHttpRequestInfo.cs" /> diff --git a/src/DotNetOpenAuth.Test/Messaging/Reflection/MessageDescriptionTests.cs b/src/DotNetOpenAuth.Test/Messaging/Reflection/MessageDescriptionTests.cs index e57df65..92f39cc 100644 --- a/src/DotNetOpenAuth.Test/Messaging/Reflection/MessageDescriptionTests.cs +++ b/src/DotNetOpenAuth.Test/Messaging/Reflection/MessageDescriptionTests.cs @@ -73,6 +73,16 @@ namespace DotNetOpenAuth.Test.Messaging.Reflection { Assert.IsTrue(v30.Mapping["OptionalIn10RequiredIn25AndLater"].IsRequired); } + /// <summary> + /// Verifies that the constructors cache is properly initialized. + /// </summary> + [TestCase] + public void CtorsCache() { + var message = new MessageDescription(typeof(MultiVersionMessage), new Version(1, 0)); + Assert.IsNotNull(message.Constructors); + Assert.AreEqual(1, message.Constructors.Length); + } + private class MultiVersionMessage : Mocks.TestBaseMessage { #pragma warning disable 0649 // these fields are never written to, but part of the test [MessagePart] diff --git a/src/DotNetOpenAuth.Test/Messaging/StandardMessageFactoryTests.cs b/src/DotNetOpenAuth.Test/Messaging/StandardMessageFactoryTests.cs new file mode 100644 index 0000000..2768350 --- /dev/null +++ b/src/DotNetOpenAuth.Test/Messaging/StandardMessageFactoryTests.cs @@ -0,0 +1,170 @@ +//----------------------------------------------------------------------- +// <copyright file="StandardMessageFactoryTests.cs" company="Andrew Arnott"> +// Copyright (c) Andrew Arnott. All rights reserved. +// </copyright> +//----------------------------------------------------------------------- + +namespace DotNetOpenAuth.Test.Messaging { + using System; + using System.Collections.Generic; + using System.Linq; + using System.Text; + using DotNetOpenAuth.Messaging; + using DotNetOpenAuth.Messaging.Reflection; + using DotNetOpenAuth.Test.Mocks; + using NUnit.Framework; + + [TestFixture] + public class StandardMessageFactoryTests : MessagingTestBase { + private static readonly Version V1 = new Version(1, 0); + private static readonly MessageReceivingEndpoint receiver = new MessageReceivingEndpoint("http://receiver", HttpDeliveryMethods.PostRequest); + + /// <summary> + /// Verifies the constructor throws the appropriate exception on null input. + /// </summary> + [TestCase, ExpectedException(typeof(ArgumentNullException))] + public void CtorNull() { + new StandardMessageFactory(null); + } + + /// <summary> + /// Verifies the constructor throws the appropriate exception on null input. + /// </summary> + [TestCase, ExpectedException(typeof(ArgumentException))] + public void CtorNullMessageDescription() { + new StandardMessageFactory(new MessageDescription[] { null }); + } + + /// <summary> + /// Verifies very simple recognition of a single message type + /// </summary> + [TestCase] + public void SingleRequestMessageType() { + var factory = new StandardMessageFactory(new MessageDescription[] { MessageDescriptions.Get(typeof(RequestMessageMock), V1) }); + var fields = new Dictionary<string, string> { + { "random", "bits" }, + }; + Assert.IsNull(factory.GetNewRequestMessage(receiver, fields)); + fields["Age"] = "18"; + Assert.IsInstanceOf(typeof(RequestMessageMock), factory.GetNewRequestMessage(receiver, fields)); + } + + /// <summary> + /// Verifies very simple recognition of a single message type + /// </summary> + [TestCase] + public void SingleResponseMessageType() { + var factory = new StandardMessageFactory(new MessageDescription[] { MessageDescriptions.Get(typeof(DirectResponseMessageMock), V1) }); + var fields = new Dictionary<string, string> { + { "random", "bits" }, + }; + IDirectedProtocolMessage request = new RequestMessageMock(receiver.Location, V1); + Assert.IsNull(factory.GetNewResponseMessage(request, fields)); + fields["Age"] = "18"; + IDirectResponseProtocolMessage response = factory.GetNewResponseMessage(request, fields); + Assert.IsInstanceOf<DirectResponseMessageMock>(response); + Assert.AreSame(request, response.OriginatingRequest); + + // Verify that we can instantiate a response with a derived-type of an expected request message. + request = new TestSignedDirectedMessage(); + response = factory.GetNewResponseMessage(request, fields); + Assert.IsInstanceOf<DirectResponseMessageMock>(response); + Assert.AreSame(request, response.OriginatingRequest); + } + + private class DirectResponseMessageMock : IDirectResponseProtocolMessage { + internal DirectResponseMessageMock(RequestMessageMock request) { + this.OriginatingRequest = request; + } + + internal DirectResponseMessageMock(TestDirectedMessage request) { + this.OriginatingRequest = request; + } + + [MessagePart(IsRequired = true)] + public int Age { get; set; } + + #region IDirectResponseProtocolMessage Members + + public IDirectedProtocolMessage OriginatingRequest { get; private set; } + + #endregion + + #region IProtocolMessage Members + + public MessageProtections RequiredProtection { + get { throw new NotImplementedException(); } + } + + public MessageTransport Transport { + get { throw new NotImplementedException(); } + } + + #endregion + + #region IMessage Members + + public Version Version { + get { throw new NotImplementedException(); } + } + + public System.Collections.Generic.IDictionary<string, string> ExtraData { + get { throw new NotImplementedException(); } + } + + public void EnsureValidMessage() { + throw new NotImplementedException(); + } + + #endregion + } + + private class RequestMessageMock : IDirectedProtocolMessage { + internal RequestMessageMock(Uri recipient, Version version) { + } + + [MessagePart(IsRequired = true)] + public int Age { get; set; } + + #region IDirectedProtocolMessage Members + + public HttpDeliveryMethods HttpMethods { + get { throw new NotImplementedException(); } + } + + public Uri Recipient { + get { throw new NotImplementedException(); } + } + + #endregion + + #region IProtocolMessage Members + + public MessageProtections RequiredProtection { + get { throw new NotImplementedException(); } + } + + public MessageTransport Transport { + get { throw new NotImplementedException(); } + } + + #endregion + + #region IMessage Members + + public Version Version { + get { throw new NotImplementedException(); } + } + + public System.Collections.Generic.IDictionary<string, string> ExtraData { + get { throw new NotImplementedException(); } + } + + public void EnsureValidMessage() { + throw new NotImplementedException(); + } + + #endregion + } + } +} diff --git a/src/DotNetOpenAuth/DotNetOpenAuth.csproj b/src/DotNetOpenAuth/DotNetOpenAuth.csproj index 80487fc..d622a4a 100644 --- a/src/DotNetOpenAuth/DotNetOpenAuth.csproj +++ b/src/DotNetOpenAuth/DotNetOpenAuth.csproj @@ -302,6 +302,7 @@ http://opensource.org/licenses/ms-pl.html <Compile Include="Messaging\Reflection\IMessagePartEncoder.cs" /> <Compile Include="Messaging\Reflection\IMessagePartNullEncoder.cs" /> <Compile Include="Messaging\Reflection\MessageDescriptionCollection.cs" /> + <Compile Include="Messaging\StandardMessageFactory.cs" /> <Compile Include="OAuth\ChannelElements\ICombinedOpenIdProviderTokenManager.cs" /> <Compile Include="OAuth\ChannelElements\IConsumerDescription.cs" /> <Compile Include="OAuth\ChannelElements\IConsumerTokenManager.cs" /> diff --git a/src/DotNetOpenAuth/Messaging/MessagingStrings.Designer.cs b/src/DotNetOpenAuth/Messaging/MessagingStrings.Designer.cs index 0bbac42..ea3bf6b 100644 --- a/src/DotNetOpenAuth/Messaging/MessagingStrings.Designer.cs +++ b/src/DotNetOpenAuth/Messaging/MessagingStrings.Designer.cs @@ -1,7 +1,7 @@ //------------------------------------------------------------------------------ // <auto-generated> // This code was generated by a tool. -// Runtime Version:4.0.30104.0 +// Runtime Version:4.0.30128.0 // // Changes to this file may cause incorrect behavior and will be lost if // the code is regenerated. @@ -412,6 +412,15 @@ namespace DotNetOpenAuth.Messaging { } /// <summary> + /// Looks up a localized string similar to This message factory does not support message type(s): {0}. + /// </summary> + internal static string StandardMessageFactoryUnsupportedMessageType { + get { + return ResourceManager.GetString("StandardMessageFactoryUnsupportedMessageType", resourceCulture); + } + } + + /// <summary> /// Looks up a localized string similar to The stream must have a known length.. /// </summary> internal static string StreamMustHaveKnownLength { diff --git a/src/DotNetOpenAuth/Messaging/MessagingStrings.resx b/src/DotNetOpenAuth/Messaging/MessagingStrings.resx index 34385d4..8b50767 100644 --- a/src/DotNetOpenAuth/Messaging/MessagingStrings.resx +++ b/src/DotNetOpenAuth/Messaging/MessagingStrings.resx @@ -300,4 +300,7 @@ <data name="BinaryDataRequiresMultipart" xml:space="preserve"> <value>Unable to send all message data because some of it requires multi-part POST, but IMessageWithBinaryData.SendAsMultipart was false.</value> </data> + <data name="StandardMessageFactoryUnsupportedMessageType" xml:space="preserve"> + <value>This message factory does not support message type(s): {0}</value> + </data> </root>
\ No newline at end of file diff --git a/src/DotNetOpenAuth/Messaging/StandardMessageFactory.cs b/src/DotNetOpenAuth/Messaging/StandardMessageFactory.cs new file mode 100644 index 0000000..511325d --- /dev/null +++ b/src/DotNetOpenAuth/Messaging/StandardMessageFactory.cs @@ -0,0 +1,222 @@ +//----------------------------------------------------------------------- +// <copyright file="StandardMessageFactory.cs" company="Andrew Arnott"> +// Copyright (c) Andrew Arnott. All rights reserved. +// </copyright> +//----------------------------------------------------------------------- + +namespace DotNetOpenAuth.Messaging { + using System; + using System.Collections.Generic; + using System.Diagnostics.Contracts; + using System.Linq; + using System.Reflection; + using System.Text; + using DotNetOpenAuth.Messaging.Reflection; + + /// <summary> + /// A message factory that automatically selects the message type based on the incoming data. + /// </summary> + internal class StandardMessageFactory : IMessageFactory { + /// <summary> + /// The request message types and their constructors to use for instantiating the messages. + /// </summary> + private readonly Dictionary<MessageDescription, ConstructorInfo> requestMessageTypes = new Dictionary<MessageDescription, ConstructorInfo>(); + + /// <summary> + /// The response message types and their constructors to use for instantiating the messages. + /// </summary> + /// <value> + /// The value is a dictionary, whose key is the type of the constructor's lone parameter. + /// </value> + private readonly Dictionary<MessageDescription, Dictionary<Type, ConstructorInfo>> responseMessageTypes = new Dictionary<MessageDescription, Dictionary<Type, ConstructorInfo>>(); + + /// <summary> + /// Initializes a new instance of the <see cref="StandardMessageFactory"/> class. + /// </summary> + /// <param name="messageTypes">The message types that this factory may instantiate.</param> + internal StandardMessageFactory(IEnumerable<MessageDescription> messageTypes) { + Contract.Requires<ArgumentNullException>(messageTypes != null); + Contract.Requires<ArgumentException>(messageTypes.All(msg => msg != null)); + + var unsupportedMessageTypes = new List<MessageDescription>(0); + foreach (MessageDescription messageDescription in messageTypes) { + bool supportedMessageType = false; + + // First see whether this message fits the recognized pattern for request messages. + if (typeof(IDirectedProtocolMessage).IsAssignableFrom(messageDescription.MessageType)) { + foreach (ConstructorInfo ctor in messageDescription.Constructors) { + ParameterInfo[] parameters = ctor.GetParameters(); + if (parameters.Length == 2 && parameters[0].ParameterType == typeof(Uri) && parameters[1].ParameterType == typeof(Version)) { + supportedMessageType = true; + this.requestMessageTypes.Add(messageDescription, ctor); + break; + } + } + } + + // Also see if this message fits the recognized pattern for response messages. + if (typeof(IDirectResponseProtocolMessage).IsAssignableFrom(messageDescription.MessageType)) { + var responseCtors = new Dictionary<Type, ConstructorInfo>(messageDescription.Constructors.Length); + foreach (ConstructorInfo ctor in messageDescription.Constructors) { + ParameterInfo[] parameters = ctor.GetParameters(); + if (parameters.Length == 1 && typeof(IDirectedProtocolMessage).IsAssignableFrom(parameters[0].ParameterType)) { + responseCtors.Add(parameters[0].ParameterType, ctor); + } + } + + if (responseCtors.Count > 0) { + supportedMessageType = true; + this.responseMessageTypes.Add(messageDescription, responseCtors); + } + } + + if (!supportedMessageType) { + unsupportedMessageTypes.Add(messageDescription); + } + } + + ErrorUtilities.VerifySupported( + !unsupportedMessageTypes.Any(), + MessagingStrings.StandardMessageFactoryUnsupportedMessageType, + unsupportedMessageTypes.ToStringDeferred()); + } + + #region IMessageFactory Members + + /// <summary> + /// Analyzes an incoming request message payload to discover what kind of + /// message is embedded in it and returns the type, or null if no match is found. + /// </summary> + /// <param name="recipient">The intended or actual recipient of the request message.</param> + /// <param name="fields">The name/value pairs that make up the message payload.</param> + /// <returns> + /// A newly instantiated <see cref="IProtocolMessage"/>-derived object that this message can + /// deserialize to. Null if the request isn't recognized as a valid protocol message. + /// </returns> + public virtual IDirectedProtocolMessage GetNewRequestMessage(MessageReceivingEndpoint recipient, IDictionary<string, string> fields) { + MessageDescription matchingType = this.GetMessageDescription(recipient, fields); + if (matchingType != null) { + return this.InstantiateAsRequest(matchingType, recipient); + } else { + return null; + } + } + + /// <summary> + /// Analyzes an incoming request message payload to discover what kind of + /// message is embedded in it and returns the type, or null if no match is found. + /// </summary> + /// <param name="request">The message that was sent as a request that resulted in the response.</param> + /// <param name="fields">The name/value pairs that make up the message payload.</param> + /// <returns> + /// A newly instantiated <see cref="IProtocolMessage"/>-derived object that this message can + /// deserialize to. Null if the request isn't recognized as a valid protocol message. + /// </returns> + public virtual IDirectResponseProtocolMessage GetNewResponseMessage(IDirectedProtocolMessage request, IDictionary<string, string> fields) { + MessageDescription matchingType = this.GetMessageDescription(request, fields); + if (matchingType != null) { + return this.InstantiateAsResponse(matchingType, request); + } else { + return null; + } + } + + #endregion + + /// <summary> + /// Gets the message type that best fits the given incoming request data. + /// </summary> + /// <param name="recipient">The recipient of the incoming data. Typically not used, but included just in case.</param> + /// <param name="fields">The data of the incoming message.</param> + /// <returns> + /// The message type that matches the incoming data; or <c>null</c> if no match. + /// </returns> + /// <exception cref="ProtocolException">May be thrown if the incoming data is ambiguous.</exception> + protected virtual MessageDescription GetMessageDescription(MessageReceivingEndpoint recipient, IDictionary<string, string> fields) { + Contract.Requires<ArgumentNullException>(recipient != null); + Contract.Requires<ArgumentNullException>(fields != null); + + var basicMatches = this.requestMessageTypes.Keys.Where(message => message.CheckMessagePartsPassBasicValidation(fields)); + var match = basicMatches.FirstOrDefault(); + if (match != null) { + if (Logger.Messaging.IsDebugEnabled && basicMatches.Count() > 1) { + Logger.Messaging.DebugFormat( + "Multiple message types seemed to fit the incoming data: {0}", + basicMatches.ToStringDeferred()); + } + + return match; + } else { + // No message type matches the incoming data. + return null; + } + } + + /// <summary> + /// Gets the message type that best fits the given incoming direct response data. + /// </summary> + /// <param name="request">The request message that prompted the response data.</param> + /// <param name="fields">The data of the incoming message.</param> + /// <returns> + /// The message type that matches the incoming data; or <c>null</c> if no match. + /// </returns> + /// <exception cref="ProtocolException">May be thrown if the incoming data is ambiguous.</exception> + protected virtual MessageDescription GetMessageDescription(IDirectedProtocolMessage request, IDictionary<string, string> fields) { + var basicMatches = this.responseMessageTypes.Keys.Where(message => message.CheckMessagePartsPassBasicValidation(fields)).CacheGeneratedResults(); + var match = basicMatches.FirstOrDefault(); + if (match != null) { + if (Logger.Messaging.IsDebugEnabled && basicMatches.Count() > 1) { + Logger.Messaging.DebugFormat( + "Multiple message types seemed to fit the incoming data: {0}", + basicMatches.ToStringDeferred()); + } + + return match; + } else { + // No message type matches the incoming data. + return null; + } + } + + /// <summary> + /// Instantiates the given request message type. + /// </summary> + /// <param name="messageDescription">The message description.</param> + /// <param name="recipient">The recipient.</param> + /// <returns>The instantiated message. Never null.</returns> + protected virtual IDirectedProtocolMessage InstantiateAsRequest(MessageDescription messageDescription, MessageReceivingEndpoint recipient) { + Contract.Requires<ArgumentNullException>(messageDescription != null); + Contract.Requires<ArgumentNullException>(recipient != null); + Contract.Ensures(Contract.Result<IDirectedProtocolMessage>() != null); + + ConstructorInfo ctor = this.requestMessageTypes[messageDescription]; + return (IDirectedProtocolMessage)ctor.Invoke(new object[] { recipient.Location, messageDescription.MessageVersion }); + } + + /// <summary> + /// Instantiates the given request message type. + /// </summary> + /// <param name="messageDescription">The message description.</param> + /// <param name="request">The request that resulted in this response.</param> + /// <returns>The instantiated message. Never null.</returns> + protected virtual IDirectResponseProtocolMessage InstantiateAsResponse(MessageDescription messageDescription, IDirectedProtocolMessage request) { + Contract.Requires<ArgumentNullException>(messageDescription != null); + Contract.Requires<ArgumentNullException>(request != null); + Contract.Ensures(Contract.Result<IDirectResponseProtocolMessage>() != null); + + Type requestType = request.GetType(); + var ctors = this.responseMessageTypes[messageDescription].Where(pair => pair.Key.IsAssignableFrom(requestType)); + ConstructorInfo ctor = null; + try { + ctor = ctors.Single().Value; + } catch (InvalidOperationException) { + if (ctors.Any()) { + ErrorUtilities.ThrowInternal("More than one matching constructor for request type " + requestType.Name + " and response type " + messageDescription.MessageType.Name); + } else { + ErrorUtilities.ThrowInternal("Unexpected request message type " + requestType.FullName + " for response type " + messageDescription.MessageType.Name); + } + } + return (IDirectResponseProtocolMessage)ctor.Invoke(new object[] { request }); + } + } +} |