summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorAndrew Arnott <andrewarnott@gmail.com>2010-02-24 17:10:21 -0800
committerAndrew Arnott <andrewarnott@gmail.com>2010-02-24 17:10:21 -0800
commit2ed543b9b5058d80c255600b5fe37f79d8eb501d (patch)
treeaede6f851c118382bfa674912f28d2f82da84484
parentc728427c61f32b4e017f834e5acc34204b600c50 (diff)
downloadDotNetOpenAuth-2ed543b9b5058d80c255600b5fe37f79d8eb501d.zip
DotNetOpenAuth-2ed543b9b5058d80c255600b5fe37f79d8eb501d.tar.gz
DotNetOpenAuth-2ed543b9b5058d80c255600b5fe37f79d8eb501d.tar.bz2
Improved precision of calculating which message type to instantiate.
-rw-r--r--src/DotNetOpenAuth.Test/OAuthWrap/OAuthWrapChannelTests.cs15
-rw-r--r--src/DotNetOpenAuth/Messaging/StandardMessageFactory.cs41
2 files changed, 43 insertions, 13 deletions
diff --git a/src/DotNetOpenAuth.Test/OAuthWrap/OAuthWrapChannelTests.cs b/src/DotNetOpenAuth.Test/OAuthWrap/OAuthWrapChannelTests.cs
index 1f76e8f..433035c 100644
--- a/src/DotNetOpenAuth.Test/OAuthWrap/OAuthWrapChannelTests.cs
+++ b/src/DotNetOpenAuth.Test/OAuthWrap/OAuthWrapChannelTests.cs
@@ -10,17 +10,22 @@ namespace DotNetOpenAuth.Test.OAuthWrap {
using System.Linq;
using System.Text;
using DotNetOpenAuth.Messaging;
+ using DotNetOpenAuth.OAuthWrap;
using DotNetOpenAuth.OAuthWrap.ChannelElements;
+ using DotNetOpenAuth.OAuthWrap.Messages;
using NUnit.Framework;
[TestFixture]
public class OAuthWrapChannelTests : OAuthWrapTestBase {
private OAuthWrapChannel channel;
+ private IMessageFactory messageFactory;
+ private MessageReceivingEndpoint recipient = new MessageReceivingEndpoint("http://who", HttpDeliveryMethods.PostRequest);
public override void SetUp() {
base.SetUp();
this.channel = new OAuthWrapChannel();
+ this.messageFactory = OAuthWrapChannel_Accessor.AttachShadow(this.channel).MessageFactory;
}
/// <summary>
@@ -28,7 +33,15 @@ namespace DotNetOpenAuth.Test.OAuthWrap {
/// </summary>
[TestCase]
public void MessageFactory() {
- // TODO: code here
+ var fields = new Dictionary<string, string> {
+ { Protocol.wrap_refresh_token, "abc" },
+ };
+ IDirectedProtocolMessage request = messageFactory.GetNewRequestMessage(recipient, fields);
+ Assert.IsInstanceOf(typeof(RefreshAccessTokenRequest), request);
+
+ fields.Clear();
+ fields[Protocol.wrap_access_token] = "abc";
+ Assert.IsInstanceOf(typeof(RefreshAccessTokenSuccessResponse), messageFactory.GetNewResponseMessage(request, fields));
}
}
}
diff --git a/src/DotNetOpenAuth/Messaging/StandardMessageFactory.cs b/src/DotNetOpenAuth/Messaging/StandardMessageFactory.cs
index 670d750..68432e8 100644
--- a/src/DotNetOpenAuth/Messaging/StandardMessageFactory.cs
+++ b/src/DotNetOpenAuth/Messaging/StandardMessageFactory.cs
@@ -142,13 +142,16 @@ namespace DotNetOpenAuth.Messaging {
Contract.Requires<ArgumentNullException>(recipient != null);
Contract.Requires<ArgumentNullException>(fields != null);
- var basicMatches = this.requestMessageTypes.Keys.Where(message => message.CheckMessagePartsPassBasicValidation(fields));
- var match = basicMatches.FirstOrDefault();
+ var matches = this.requestMessageTypes.Keys
+ .Where(message => message.CheckMessagePartsPassBasicValidation(fields))
+ .OrderByDescending(message => message.Mapping.Count)
+ .CacheGeneratedResults();
+ var match = matches.FirstOrDefault();
if (match != null) {
- if (Logger.Messaging.IsDebugEnabled && basicMatches.Count() > 1) {
- Logger.Messaging.DebugFormat(
+ if (Logger.Messaging.IsWarnEnabled && matches.Count() > 1) {
+ Logger.Messaging.WarnFormat(
"Multiple message types seemed to fit the incoming data: {0}",
- basicMatches.ToStringDeferred());
+ matches.ToStringDeferred());
}
return match;
@@ -168,13 +171,20 @@ namespace DotNetOpenAuth.Messaging {
/// </returns>
/// <exception cref="ProtocolException">May be thrown if the incoming data is ambiguous.</exception>
protected virtual MessageDescription GetMessageDescription(IDirectedProtocolMessage request, IDictionary<string, string> fields) {
- var basicMatches = this.responseMessageTypes.Keys.Where(message => message.CheckMessagePartsPassBasicValidation(fields)).CacheGeneratedResults();
- var match = basicMatches.FirstOrDefault();
+ Contract.Requires<ArgumentNullException>(request != null);
+ Contract.Requires<ArgumentNullException>(fields != null);
+
+ var matches = this.responseMessageTypes.Keys
+ .Where(message => message.CheckMessagePartsPassBasicValidation(fields))
+ .Where(message => FindMatchingResponseConstructors(message, request.GetType()).Any())
+ .OrderByDescending(message => message.Mapping.Count)
+ .CacheGeneratedResults();
+ var match = matches.FirstOrDefault();
if (match != null) {
- if (Logger.Messaging.IsDebugEnabled && basicMatches.Count() > 1) {
- Logger.Messaging.DebugFormat(
+ if (Logger.Messaging.IsWarnEnabled && matches.Count() > 1) {
+ Logger.Messaging.WarnFormat(
"Multiple message types seemed to fit the incoming data: {0}",
- basicMatches.ToStringDeferred());
+ matches.ToStringDeferred());
}
return match;
@@ -211,10 +221,10 @@ namespace DotNetOpenAuth.Messaging {
Contract.Ensures(Contract.Result<IDirectResponseProtocolMessage>() != null);
Type requestType = request.GetType();
- var ctors = this.responseMessageTypes[messageDescription].Where(pair => pair.Key.IsAssignableFrom(requestType));
+ var ctors = this.FindMatchingResponseConstructors(messageDescription, requestType);
ConstructorInfo ctor = null;
try {
- ctor = ctors.Single().Value;
+ ctor = ctors.Single();
} catch (InvalidOperationException) {
if (ctors.Any()) {
ErrorUtilities.ThrowInternal("More than one matching constructor for request type " + requestType.Name + " and response type " + messageDescription.MessageType.Name);
@@ -224,5 +234,12 @@ namespace DotNetOpenAuth.Messaging {
}
return (IDirectResponseProtocolMessage)ctor.Invoke(new object[] { request });
}
+
+ private IEnumerable<ConstructorInfo> FindMatchingResponseConstructors(MessageDescription messageDescription, Type requestType) {
+ Contract.Requires<ArgumentNullException>(messageDescription != null);
+ Contract.Requires<ArgumentNullException>(requestType != null);
+
+ return this.responseMessageTypes[messageDescription].Where(pair => pair.Key.IsAssignableFrom(requestType)).Select(pair => pair.Value);
+ }
}
}