diff options
author | Andrew Arnott <andrewarnott@gmail.com> | 2010-02-24 17:10:21 -0800 |
---|---|---|
committer | Andrew Arnott <andrewarnott@gmail.com> | 2010-02-24 17:10:21 -0800 |
commit | 2ed543b9b5058d80c255600b5fe37f79d8eb501d (patch) | |
tree | aede6f851c118382bfa674912f28d2f82da84484 | |
parent | c728427c61f32b4e017f834e5acc34204b600c50 (diff) | |
download | DotNetOpenAuth-2ed543b9b5058d80c255600b5fe37f79d8eb501d.zip DotNetOpenAuth-2ed543b9b5058d80c255600b5fe37f79d8eb501d.tar.gz DotNetOpenAuth-2ed543b9b5058d80c255600b5fe37f79d8eb501d.tar.bz2 |
Improved precision of calculating which message type to instantiate.
-rw-r--r-- | src/DotNetOpenAuth.Test/OAuthWrap/OAuthWrapChannelTests.cs | 15 | ||||
-rw-r--r-- | src/DotNetOpenAuth/Messaging/StandardMessageFactory.cs | 41 |
2 files changed, 43 insertions, 13 deletions
diff --git a/src/DotNetOpenAuth.Test/OAuthWrap/OAuthWrapChannelTests.cs b/src/DotNetOpenAuth.Test/OAuthWrap/OAuthWrapChannelTests.cs index 1f76e8f..433035c 100644 --- a/src/DotNetOpenAuth.Test/OAuthWrap/OAuthWrapChannelTests.cs +++ b/src/DotNetOpenAuth.Test/OAuthWrap/OAuthWrapChannelTests.cs @@ -10,17 +10,22 @@ namespace DotNetOpenAuth.Test.OAuthWrap { using System.Linq; using System.Text; using DotNetOpenAuth.Messaging; + using DotNetOpenAuth.OAuthWrap; using DotNetOpenAuth.OAuthWrap.ChannelElements; + using DotNetOpenAuth.OAuthWrap.Messages; using NUnit.Framework; [TestFixture] public class OAuthWrapChannelTests : OAuthWrapTestBase { private OAuthWrapChannel channel; + private IMessageFactory messageFactory; + private MessageReceivingEndpoint recipient = new MessageReceivingEndpoint("http://who", HttpDeliveryMethods.PostRequest); public override void SetUp() { base.SetUp(); this.channel = new OAuthWrapChannel(); + this.messageFactory = OAuthWrapChannel_Accessor.AttachShadow(this.channel).MessageFactory; } /// <summary> @@ -28,7 +33,15 @@ namespace DotNetOpenAuth.Test.OAuthWrap { /// </summary> [TestCase] public void MessageFactory() { - // TODO: code here + var fields = new Dictionary<string, string> { + { Protocol.wrap_refresh_token, "abc" }, + }; + IDirectedProtocolMessage request = messageFactory.GetNewRequestMessage(recipient, fields); + Assert.IsInstanceOf(typeof(RefreshAccessTokenRequest), request); + + fields.Clear(); + fields[Protocol.wrap_access_token] = "abc"; + Assert.IsInstanceOf(typeof(RefreshAccessTokenSuccessResponse), messageFactory.GetNewResponseMessage(request, fields)); } } } diff --git a/src/DotNetOpenAuth/Messaging/StandardMessageFactory.cs b/src/DotNetOpenAuth/Messaging/StandardMessageFactory.cs index 670d750..68432e8 100644 --- a/src/DotNetOpenAuth/Messaging/StandardMessageFactory.cs +++ b/src/DotNetOpenAuth/Messaging/StandardMessageFactory.cs @@ -142,13 +142,16 @@ namespace DotNetOpenAuth.Messaging { Contract.Requires<ArgumentNullException>(recipient != null); Contract.Requires<ArgumentNullException>(fields != null); - var basicMatches = this.requestMessageTypes.Keys.Where(message => message.CheckMessagePartsPassBasicValidation(fields)); - var match = basicMatches.FirstOrDefault(); + var matches = this.requestMessageTypes.Keys + .Where(message => message.CheckMessagePartsPassBasicValidation(fields)) + .OrderByDescending(message => message.Mapping.Count) + .CacheGeneratedResults(); + var match = matches.FirstOrDefault(); if (match != null) { - if (Logger.Messaging.IsDebugEnabled && basicMatches.Count() > 1) { - Logger.Messaging.DebugFormat( + if (Logger.Messaging.IsWarnEnabled && matches.Count() > 1) { + Logger.Messaging.WarnFormat( "Multiple message types seemed to fit the incoming data: {0}", - basicMatches.ToStringDeferred()); + matches.ToStringDeferred()); } return match; @@ -168,13 +171,20 @@ namespace DotNetOpenAuth.Messaging { /// </returns> /// <exception cref="ProtocolException">May be thrown if the incoming data is ambiguous.</exception> protected virtual MessageDescription GetMessageDescription(IDirectedProtocolMessage request, IDictionary<string, string> fields) { - var basicMatches = this.responseMessageTypes.Keys.Where(message => message.CheckMessagePartsPassBasicValidation(fields)).CacheGeneratedResults(); - var match = basicMatches.FirstOrDefault(); + Contract.Requires<ArgumentNullException>(request != null); + Contract.Requires<ArgumentNullException>(fields != null); + + var matches = this.responseMessageTypes.Keys + .Where(message => message.CheckMessagePartsPassBasicValidation(fields)) + .Where(message => FindMatchingResponseConstructors(message, request.GetType()).Any()) + .OrderByDescending(message => message.Mapping.Count) + .CacheGeneratedResults(); + var match = matches.FirstOrDefault(); if (match != null) { - if (Logger.Messaging.IsDebugEnabled && basicMatches.Count() > 1) { - Logger.Messaging.DebugFormat( + if (Logger.Messaging.IsWarnEnabled && matches.Count() > 1) { + Logger.Messaging.WarnFormat( "Multiple message types seemed to fit the incoming data: {0}", - basicMatches.ToStringDeferred()); + matches.ToStringDeferred()); } return match; @@ -211,10 +221,10 @@ namespace DotNetOpenAuth.Messaging { Contract.Ensures(Contract.Result<IDirectResponseProtocolMessage>() != null); Type requestType = request.GetType(); - var ctors = this.responseMessageTypes[messageDescription].Where(pair => pair.Key.IsAssignableFrom(requestType)); + var ctors = this.FindMatchingResponseConstructors(messageDescription, requestType); ConstructorInfo ctor = null; try { - ctor = ctors.Single().Value; + ctor = ctors.Single(); } catch (InvalidOperationException) { if (ctors.Any()) { ErrorUtilities.ThrowInternal("More than one matching constructor for request type " + requestType.Name + " and response type " + messageDescription.MessageType.Name); @@ -224,5 +234,12 @@ namespace DotNetOpenAuth.Messaging { } return (IDirectResponseProtocolMessage)ctor.Invoke(new object[] { request }); } + + private IEnumerable<ConstructorInfo> FindMatchingResponseConstructors(MessageDescription messageDescription, Type requestType) { + Contract.Requires<ArgumentNullException>(messageDescription != null); + Contract.Requires<ArgumentNullException>(requestType != null); + + return this.responseMessageTypes[messageDescription].Where(pair => pair.Key.IsAssignableFrom(requestType)).Select(pair => pair.Value); + } } } |