diff options
Diffstat (limited to 'src')
53 files changed, 650 insertions, 335 deletions
diff --git a/src/DotNetOpenAuth.Test/DotNetOpenAuth.Test.csproj b/src/DotNetOpenAuth.Test/DotNetOpenAuth.Test.csproj index d16b6c4..4469c17 100644 --- a/src/DotNetOpenAuth.Test/DotNetOpenAuth.Test.csproj +++ b/src/DotNetOpenAuth.Test/DotNetOpenAuth.Test.csproj @@ -88,7 +88,7 @@ <Compile Include="Mocks\TestWebRequestHandler.cs" /> <Compile Include="Mocks\TestChannel.cs" /> <Compile Include="Mocks\TestMessage.cs" /> - <Compile Include="Mocks\TestMessageTypeProvider.cs" /> + <Compile Include="Mocks\TestMessageFactory.cs" /> <Compile Include="OAuth\ChannelElements\HmacSha1SigningBindingElementTests.cs" /> <Compile Include="OAuth\ChannelElements\OAuthChannelTests.cs" /> <Compile Include="OAuth\ChannelElements\PlaintextSigningBindingElementTest.cs" /> diff --git a/src/DotNetOpenAuth.Test/Messaging/ChannelTests.cs b/src/DotNetOpenAuth.Test/Messaging/ChannelTests.cs index 4127b5a..0176164 100644 --- a/src/DotNetOpenAuth.Test/Messaging/ChannelTests.cs +++ b/src/DotNetOpenAuth.Test/Messaging/ChannelTests.cs @@ -41,7 +41,7 @@ namespace DotNetOpenAuth.Test.Messaging { [TestMethod, ExpectedException(typeof(ArgumentException))] public void SendIndirectedUndirectedMessage() { - IProtocolMessage message = new TestMessage(MessageTransport.Indirect); + IProtocolMessage message = new TestDirectedMessage(MessageTransport.Indirect); this.Channel.Send(message); } @@ -150,7 +150,7 @@ namespace DotNetOpenAuth.Test.Messaging { /// </remarks> [TestMethod, ExpectedException(typeof(NotImplementedException), "SendDirectMessageResponse")] public void SendDirectMessageResponse() { - IProtocolMessage message = new TestMessage { + IProtocolMessage message = new TestDirectedMessage { Age = 15, Name = "Andrew", Location = new Uri("http://host/path"), @@ -232,14 +232,14 @@ namespace DotNetOpenAuth.Test.Messaging { [TestMethod, ExpectedException(typeof(ProtocolException))] public void MessageExpirationWithoutTamperResistance() { new TestChannel( - new TestMessageTypeProvider(), + new TestMessageFactory(), new StandardExpirationBindingElement()); } [TestMethod, ExpectedException(typeof(ProtocolException))] public void TooManyBindingElementsProvidingSameProtection() { new TestChannel( - new TestMessageTypeProvider(), + new TestMessageFactory(), new MockSigningBindingElement(), new MockSigningBindingElement()); } @@ -253,7 +253,7 @@ namespace DotNetOpenAuth.Test.Messaging { IChannelBindingElement expire = new StandardExpirationBindingElement(); Channel channel = new TestChannel( - new TestMessageTypeProvider(), + new TestMessageFactory(), sign, replay, expire, diff --git a/src/DotNetOpenAuth.Test/Messaging/MessageSerializerTests.cs b/src/DotNetOpenAuth.Test/Messaging/MessageSerializerTests.cs index 62b6393..3bcca10 100644 --- a/src/DotNetOpenAuth.Test/Messaging/MessageSerializerTests.cs +++ b/src/DotNetOpenAuth.Test/Messaging/MessageSerializerTests.cs @@ -65,7 +65,8 @@ namespace DotNetOpenAuth.Test.Messaging { fields["Name"] = "Andrew"; fields["age"] = "15"; fields["Timestamp"] = "1990-01-01T00:00:00"; - var actual = (Mocks.TestMessage)serializer.Deserialize(fields, null); + var actual = new Mocks.TestDirectedMessage(); + serializer.Deserialize(fields, actual); Assert.AreEqual(15, actual.Age); Assert.AreEqual("Andrew", actual.Name); Assert.AreEqual(DateTime.Parse("1/1/1990"), actual.Timestamp); @@ -94,7 +95,8 @@ namespace DotNetOpenAuth.Test.Messaging { fields["SecondDerivedElement"] = "second"; fields["explicit"] = "explicitValue"; fields["private"] = "privateValue"; - var actual = (Mocks.TestDerivedMessage)serializer.Deserialize(fields, null); + var actual = new Mocks.TestDerivedMessage(); + serializer.Deserialize(fields, actual); Assert.AreEqual(15, actual.Age); Assert.AreEqual("Andrew", actual.Name); Assert.AreEqual("first", actual.TheFirstDerivedElement); @@ -113,7 +115,8 @@ namespace DotNetOpenAuth.Test.Messaging { // Add some field that is not recognized by the class. This simulates a querystring with // more parameters than are actually interesting to the protocol message. fields["someExtraField"] = "asdf"; - var actual = (Mocks.TestMessage)serializer.Deserialize(fields, null); + var actual = new Mocks.TestDirectedMessage(); + serializer.Deserialize(fields, actual); Assert.AreEqual(15, actual.Age); Assert.AreEqual("Andrew", actual.Name); Assert.IsNull(actual.EmptyMember); @@ -121,10 +124,11 @@ namespace DotNetOpenAuth.Test.Messaging { [TestMethod, ExpectedException(typeof(ProtocolException))] public void DeserializeInvalidMessage() { - var serializer = MessageSerializer.Get(typeof(Mocks.TestMessage)); + IProtocolMessage message = new Mocks.TestDirectedMessage(); + var serializer = MessageSerializer.Get(message.GetType()); var fields = GetStandardTestFields(FieldFill.AllRequired); fields["age"] = "-1"; // Set an disallowed value. - serializer.Deserialize(fields, null); + serializer.Deserialize(fields, message); } } } diff --git a/src/DotNetOpenAuth.Test/Messaging/MessagingTestBase.cs b/src/DotNetOpenAuth.Test/Messaging/MessagingTestBase.cs index acaf4a0..0a11a75 100644 --- a/src/DotNetOpenAuth.Test/Messaging/MessagingTestBase.cs +++ b/src/DotNetOpenAuth.Test/Messaging/MessagingTestBase.cs @@ -26,7 +26,7 @@ namespace DotNetOpenAuth.Test { None, /// <summary> - /// Only enough fields for the <see cref="TestMessageTypeProvider"/> + /// Only enough fields for the <see cref="TestMessageFactory"/> /// to identify the message are included. /// </summary> IdentifiableButNotAllRequired, @@ -105,7 +105,7 @@ namespace DotNetOpenAuth.Test { replay = true; } - var typeProvider = new TestMessageTypeProvider(signing, expiration, replay); + var typeProvider = new TestMessageFactory(signing, expiration, replay); return new TestChannel(typeProvider, bindingElements.ToArray()); } @@ -128,7 +128,7 @@ namespace DotNetOpenAuth.Test { } internal static TestMessage GetStandardTestMessage(FieldFill fill) { - TestMessage message = new TestMessage(); + TestMessage message = new TestDirectedMessage(); GetStandardTestMessage(fill, message); return message; } diff --git a/src/DotNetOpenAuth.Test/Messaging/ProtocolExceptionTests.cs b/src/DotNetOpenAuth.Test/Messaging/ProtocolExceptionTests.cs index a2e3eaa..02aea64 100644 --- a/src/DotNetOpenAuth.Test/Messaging/ProtocolExceptionTests.cs +++ b/src/DotNetOpenAuth.Test/Messaging/ProtocolExceptionTests.cs @@ -32,7 +32,7 @@ namespace DotNetOpenAuth.Test.Messaging { [TestMethod] public void CtorWithProtocolMessage() { - IProtocolMessage request = new Mocks.TestMessage(); + IProtocolMessage request = new Mocks.TestDirectedMessage(); Uri receiver = new Uri("http://receiver"); ProtocolException ex = new ProtocolException("some error occurred", request, receiver); IDirectedProtocolMessage msg = (IDirectedProtocolMessage)ex; diff --git a/src/DotNetOpenAuth.Test/Messaging/Reflection/MessageDictionaryTests.cs b/src/DotNetOpenAuth.Test/Messaging/Reflection/MessageDictionaryTests.cs index 039743e..0175173 100644 --- a/src/DotNetOpenAuth.Test/Messaging/Reflection/MessageDictionaryTests.cs +++ b/src/DotNetOpenAuth.Test/Messaging/Reflection/MessageDictionaryTests.cs @@ -22,7 +22,7 @@ namespace DotNetOpenAuth.Test.Messaging.Reflection { public override void SetUp() { base.SetUp(); - this.message = new Mocks.TestMessage(); + this.message = new Mocks.TestDirectedMessage(); } [TestMethod, ExpectedException(typeof(ArgumentNullException))] diff --git a/src/DotNetOpenAuth.Test/Mocks/CoordinatingChannel.cs b/src/DotNetOpenAuth.Test/Mocks/CoordinatingChannel.cs index bb094af..2f82c06 100644 --- a/src/DotNetOpenAuth.Test/Mocks/CoordinatingChannel.cs +++ b/src/DotNetOpenAuth.Test/Mocks/CoordinatingChannel.cs @@ -21,7 +21,7 @@ namespace DotNetOpenAuth.Test.Mocks { private Action<IProtocolMessage> outgoingMessageFilter; internal CoordinatingChannel(Channel wrappedChannel, Action<IProtocolMessage> incomingMessageFilter, Action<IProtocolMessage> outgoingMessageFilter) - : base(GetMessageTypeProvider(wrappedChannel), wrappedChannel.BindingElements.ToArray()) { + : base(GetMessageFactory(wrappedChannel), wrappedChannel.BindingElements.ToArray()) { ErrorUtilities.VerifyArgumentNotNull(wrappedChannel, "wrappedChannel"); this.wrappedChannel = wrappedChannel; @@ -87,25 +87,38 @@ namespace DotNetOpenAuth.Test.Mocks { } protected virtual T CloneSerializedParts<T>(T message, HttpRequestInfo requestInfo) where T : class, IProtocolMessage { - if (message == null) { - throw new ArgumentNullException("message"); - } + ErrorUtilities.VerifyArgumentNotNull(message, "message"); + + IProtocolMessage clonedMessage; + MessageSerializer serializer = MessageSerializer.Get(message.GetType()); + var fields = serializer.Serialize(message); MessageReceivingEndpoint recipient = null; - IDirectedProtocolMessage directedMessage = message as IDirectedProtocolMessage; - if (directedMessage != null && directedMessage.Recipient != null) { - recipient = new MessageReceivingEndpoint(directedMessage.Recipient, directedMessage.HttpMethods); + var directedMessage = message as IDirectedProtocolMessage; + var directResponse = message as IDirectResponseProtocolMessage; + if (directedMessage != null && directedMessage.IsRequest()) { + if (directedMessage.Recipient != null) { + recipient = new MessageReceivingEndpoint(directedMessage.Recipient, directedMessage.HttpMethods); + } + + clonedMessage = this.RemoteChannel.MessageFactory.GetNewRequestMessage(recipient, fields); + } else if (directResponse != null && directResponse.IsDirectResponse()) { + clonedMessage = this.RemoteChannel.MessageFactory.GetNewResponseMessage(directResponse.OriginatingRequest, fields); + } else { + throw new InvalidOperationException("Totally expected a message to implement one of the two derived interface types."); } - MessageSerializer serializer = MessageSerializer.Get(message.GetType()); - return (T)serializer.Deserialize(serializer.Serialize(message), recipient); + // Fill the cloned message with data. + serializer.Deserialize(fields, clonedMessage); + + return (T)clonedMessage; } - private static IMessageTypeProvider GetMessageTypeProvider(Channel channel) { + private static IMessageFactory GetMessageFactory(Channel channel) { ErrorUtilities.VerifyArgumentNotNull(channel, "channel"); Channel_Accessor accessor = Channel_Accessor.AttachShadow(channel); - return accessor.MessageTypeProvider; + return accessor.MessageFactory; } private IProtocolMessage AwaitIncomingMessage() { diff --git a/src/DotNetOpenAuth.Test/Mocks/CoordinatingOAuthChannel.cs b/src/DotNetOpenAuth.Test/Mocks/CoordinatingOAuthChannel.cs index 148d2da..10a8d7e 100644 --- a/src/DotNetOpenAuth.Test/Mocks/CoordinatingOAuthChannel.cs +++ b/src/DotNetOpenAuth.Test/Mocks/CoordinatingOAuthChannel.cs @@ -33,7 +33,7 @@ namespace DotNetOpenAuth.Test.Mocks { signingBindingElement, new NonceMemoryStore(StandardExpirationBindingElement.DefaultMaximumMessageAge), tokenManager, - isConsumer ? (IMessageTypeProvider)new OAuthConsumerMessageTypeProvider() : new OAuthServiceProviderMessageTypeProvider(tokenManager)) { + isConsumer ? (IMessageFactory)new OAuthConsumerMessageFactory() : new OAuthServiceProviderMessageFactory(tokenManager)) { } /// <summary> @@ -121,18 +121,31 @@ namespace DotNetOpenAuth.Test.Mocks { } private T CloneSerializedParts<T>(T message, HttpRequestInfo requestInfo) where T : class, IProtocolMessage { - if (message == null) { - throw new ArgumentNullException("message"); - } + ErrorUtilities.VerifyArgumentNotNull(message, "message"); + + IProtocolMessage clonedMessage; + MessageSerializer serializer = MessageSerializer.Get(message.GetType()); + var fields = serializer.Serialize(message); MessageReceivingEndpoint recipient = null; - IDirectedProtocolMessage directedMessage = message as IDirectedProtocolMessage; - if (directedMessage != null && directedMessage.Recipient != null) { - recipient = new MessageReceivingEndpoint(directedMessage.Recipient, directedMessage.HttpMethods); + var directedMessage = message as IDirectedProtocolMessage; + var directResponse = message as IDirectResponseProtocolMessage; + if (directedMessage != null && directedMessage.IsRequest()) { + if (directedMessage.Recipient != null) { + recipient = new MessageReceivingEndpoint(directedMessage.Recipient, directedMessage.HttpMethods); + } + + clonedMessage = this.RemoteChannel.MessageFactory.GetNewRequestMessage(recipient, fields); + } else if (directResponse != null && directResponse.IsDirectResponse()) { + clonedMessage = this.RemoteChannel.MessageFactory.GetNewResponseMessage(directResponse.OriginatingRequest, fields); + } else { + throw new InvalidOperationException("Totally expected a message to implement one of the two derived interface types."); } - MessageSerializer serializer = MessageSerializer.Get(message.GetType()); - return (T)serializer.Deserialize(serializer.Serialize(message), recipient); + // Fill the cloned message with data. + serializer.Deserialize(fields, clonedMessage); + + return (T)clonedMessage; } private string GetHttpMethod(HttpDeliveryMethods methods) { diff --git a/src/DotNetOpenAuth.Test/Mocks/TestBadChannel.cs b/src/DotNetOpenAuth.Test/Mocks/TestBadChannel.cs index 9d4712c..515766e 100644 --- a/src/DotNetOpenAuth.Test/Mocks/TestBadChannel.cs +++ b/src/DotNetOpenAuth.Test/Mocks/TestBadChannel.cs @@ -14,7 +14,7 @@ namespace DotNetOpenAuth.Test.Mocks { /// </summary> internal class TestBadChannel : Channel { internal TestBadChannel(bool badConstructorParam) - : base(badConstructorParam ? null : new TestMessageTypeProvider()) { + : base(badConstructorParam ? null : new TestMessageFactory()) { } internal new void Create301RedirectResponse(IDirectedProtocolMessage message, IDictionary<string, string> fields) { diff --git a/src/DotNetOpenAuth.Test/Mocks/TestBaseMessage.cs b/src/DotNetOpenAuth.Test/Mocks/TestBaseMessage.cs index bd17dd6..212907f 100644 --- a/src/DotNetOpenAuth.Test/Mocks/TestBaseMessage.cs +++ b/src/DotNetOpenAuth.Test/Mocks/TestBaseMessage.cs @@ -17,6 +17,7 @@ namespace DotNetOpenAuth.Test.Mocks { internal class TestBaseMessage : IProtocolMessage, IBaseMessageExplicitMembers { private Dictionary<string, string> extraData = new Dictionary<string, string>(); + private bool incoming; [MessagePart("age", IsRequired = true)] public int Age { get; set; } @@ -43,7 +44,9 @@ namespace DotNetOpenAuth.Test.Mocks { get { return this.extraData; } } - bool IProtocolMessage.Incoming { get; set; } + bool IProtocolMessage.Incoming { + get { return this.incoming; } + } internal string PrivatePropertyAccessor { get { return this.PrivateProperty; } @@ -54,5 +57,9 @@ namespace DotNetOpenAuth.Test.Mocks { private string PrivateProperty { get; set; } void IProtocolMessage.EnsureValidMessage() { } + + internal void SetAsIncoming() { + this.incoming = true; + } } } diff --git a/src/DotNetOpenAuth.Test/Mocks/TestChannel.cs b/src/DotNetOpenAuth.Test/Mocks/TestChannel.cs index 9c11589..0f7f4b8 100644 --- a/src/DotNetOpenAuth.Test/Mocks/TestChannel.cs +++ b/src/DotNetOpenAuth.Test/Mocks/TestChannel.cs @@ -12,10 +12,10 @@ namespace DotNetOpenAuth.Test.Mocks { internal class TestChannel : Channel { internal TestChannel() - : this(new TestMessageTypeProvider()) { + : this(new TestMessageFactory()) { } - internal TestChannel(IMessageTypeProvider messageTypeProvider, params IChannelBindingElement[] bindingElements) + internal TestChannel(IMessageFactory messageTypeProvider, params IChannelBindingElement[] bindingElements) : base(messageTypeProvider, bindingElements) { } diff --git a/src/DotNetOpenAuth.Test/Mocks/TestMessage.cs b/src/DotNetOpenAuth.Test/Mocks/TestMessage.cs index 55e44ba..4a168a1 100644 --- a/src/DotNetOpenAuth.Test/Mocks/TestMessage.cs +++ b/src/DotNetOpenAuth.Test/Mocks/TestMessage.cs @@ -11,15 +11,16 @@ namespace DotNetOpenAuth.Test.Mocks { using DotNetOpenAuth.Messaging; using DotNetOpenAuth.Messaging.Reflection; - internal class TestMessage : IProtocolMessage { + internal abstract class TestMessage : IDirectResponseProtocolMessage { private MessageTransport transport; private Dictionary<string, string> extraData = new Dictionary<string, string>(); + private bool incoming; - internal TestMessage() + protected TestMessage() : this(MessageTransport.Direct) { } - internal TestMessage(MessageTransport transport) { + protected TestMessage(MessageTransport transport) { this.transport = transport; } @@ -34,7 +35,7 @@ namespace DotNetOpenAuth.Test.Mocks { [MessagePart(IsRequired = true)] public DateTime Timestamp { get; set; } - #region IProtocolMessage Members + #region IProtocolMessage Properties Version IProtocolMessage.ProtocolVersion { get { return new Version(1, 0); } @@ -52,7 +53,19 @@ namespace DotNetOpenAuth.Test.Mocks { get { return this.extraData; } } - bool IProtocolMessage.Incoming { get; set; } + bool IProtocolMessage.Incoming { + get { return this.incoming; } + } + + #endregion + + #region IDirectResponseProtocolMessage Members + + public IDirectedProtocolMessage OriginatingRequest { get; set; } + + #endregion + + #region IProtocolMessage Methods void IProtocolMessage.EnsureValidMessage() { if (this.EmptyMember != null || this.Age < 0) { @@ -61,5 +74,9 @@ namespace DotNetOpenAuth.Test.Mocks { } #endregion + + internal void SetAsIncoming() { + this.incoming = true; + } } } diff --git a/src/DotNetOpenAuth.Test/Mocks/TestMessageTypeProvider.cs b/src/DotNetOpenAuth.Test/Mocks/TestMessageFactory.cs index 8f075f7..7c88898 100644 --- a/src/DotNetOpenAuth.Test/Mocks/TestMessageTypeProvider.cs +++ b/src/DotNetOpenAuth.Test/Mocks/TestMessageFactory.cs @@ -1,5 +1,5 @@ //----------------------------------------------------------------------- -// <copyright file="TestMessageTypeProvider.cs" company="Andrew Arnott"> +// <copyright file="TestMessageFactory.cs" company="Andrew Arnott"> // Copyright (c) Andrew Arnott. All rights reserved. // </copyright> //----------------------------------------------------------------------- @@ -11,16 +11,16 @@ namespace DotNetOpenAuth.Test.Mocks { using System.Text; using DotNetOpenAuth.Messaging; - internal class TestMessageTypeProvider : IMessageTypeProvider { + internal class TestMessageFactory : IMessageFactory { private bool signedMessages; private bool expiringMessages; private bool replayMessages; - internal TestMessageTypeProvider() + internal TestMessageFactory() : this(false, false, false) { } - internal TestMessageTypeProvider(bool signed, bool expiring, bool replay) { + internal TestMessageFactory(bool signed, bool expiring, bool replay) { if ((!signed && expiring) || (!expiring && replay)) { throw new ArgumentException("Invalid combination of protection."); } @@ -29,26 +29,30 @@ namespace DotNetOpenAuth.Test.Mocks { this.replayMessages = replay; } - #region IMessageTypeProvider Members + #region IMessageFactory Members + + public IDirectedProtocolMessage GetNewRequestMessage(MessageReceivingEndpoint recipient, IDictionary<string, string> fields) { + ErrorUtilities.VerifyArgumentNotNull(fields, "fields"); - public Type GetRequestMessageType(IDictionary<string, string> fields) { if (fields.ContainsKey("age")) { if (this.signedMessages) { if (this.expiringMessages) { if (this.replayMessages) { - return typeof(TestReplayProtectedMessage); + return new TestReplayProtectedMessage(); } - return typeof(TestExpiringMessage); + return new TestExpiringMessage(); } - return typeof(TestSignedDirectedMessage); + return new TestSignedDirectedMessage(); } - return typeof(TestDirectedMessage); + return new TestDirectedMessage(); } return null; } - public Type GetResponseMessageType(IProtocolMessage request, IDictionary<string, string> fields) { - return this.GetRequestMessageType(fields); + public IDirectResponseProtocolMessage GetNewResponseMessage(IDirectedProtocolMessage request, IDictionary<string, string> fields) { + TestMessage message = (TestMessage)this.GetNewRequestMessage(null, fields); + message.OriginatingRequest = request; + return message; } #endregion diff --git a/src/DotNetOpenAuth.Test/OAuth/ChannelElements/OAuthChannelTests.cs b/src/DotNetOpenAuth.Test/OAuth/ChannelElements/OAuthChannelTests.cs index fbbd6a6..b1fe7c4 100644 --- a/src/DotNetOpenAuth.Test/OAuth/ChannelElements/OAuthChannelTests.cs +++ b/src/DotNetOpenAuth.Test/OAuth/ChannelElements/OAuthChannelTests.cs @@ -33,23 +33,23 @@ namespace DotNetOpenAuth.Test.ChannelElements { this.webRequestHandler = new TestWebRequestHandler(); this.signingElement = new RsaSha1SigningBindingElement(); this.nonceStore = new NonceMemoryStore(StandardExpirationBindingElement.DefaultMaximumMessageAge); - this.channel = new OAuthChannel(this.signingElement, this.nonceStore, new InMemoryTokenManager(), new TestMessageTypeProvider()); + this.channel = new OAuthChannel(this.signingElement, this.nonceStore, new InMemoryTokenManager(), new TestMessageFactory()); this.channel.WebRequestHandler = this.webRequestHandler; } [TestMethod, ExpectedException(typeof(ArgumentException))] public void CtorNullSigner() { - new OAuthChannel(null, this.nonceStore, new InMemoryTokenManager(), new TestMessageTypeProvider()); + new OAuthChannel(null, this.nonceStore, new InMemoryTokenManager(), new TestMessageFactory()); } [TestMethod, ExpectedException(typeof(ArgumentNullException))] public void CtorNullStore() { - new OAuthChannel(new RsaSha1SigningBindingElement(), null, new InMemoryTokenManager(), new TestMessageTypeProvider()); + new OAuthChannel(new RsaSha1SigningBindingElement(), null, new InMemoryTokenManager(), new TestMessageFactory()); } [TestMethod, ExpectedException(typeof(ArgumentNullException))] public void CtorNullTokenManager() { - new OAuthChannel(new RsaSha1SigningBindingElement(), this.nonceStore, null, new TestMessageTypeProvider()); + new OAuthChannel(new RsaSha1SigningBindingElement(), this.nonceStore, null, new TestMessageFactory()); } [TestMethod] @@ -79,7 +79,7 @@ namespace DotNetOpenAuth.Test.ChannelElements { [TestMethod] public void SendDirectMessageResponse() { - IProtocolMessage message = new TestMessage { + IProtocolMessage message = new TestDirectedMessage { Age = 15, Name = "Andrew", Location = new Uri("http://hostb/pathB"), diff --git a/src/DotNetOpenAuth.Test/OpenId/AssociationHandshakeTests.cs b/src/DotNetOpenAuth.Test/OpenId/AssociationHandshakeTests.cs index eb0bccc..6209fac 100644 --- a/src/DotNetOpenAuth.Test/OpenId/AssociationHandshakeTests.cs +++ b/src/DotNetOpenAuth.Test/OpenId/AssociationHandshakeTests.cs @@ -20,34 +20,38 @@ namespace DotNetOpenAuth.Test.OpenId { [TestMethod] public void DHv2() { - var opDescription = new ProviderEndpointDescription(new Uri("http://host"), Protocol.V20); + Protocol protocol = Protocol.V20; + var opDescription = new ProviderEndpointDescription(new Uri("http://host"), protocol.Version); this.ParameterizedAssociationTest( opDescription, - Protocol.V20.Args.SignatureAlgorithm.HMAC_SHA256); + protocol.Args.SignatureAlgorithm.HMAC_SHA256); } [TestMethod] public void DHv1() { - var opDescription = new ProviderEndpointDescription(new Uri("http://host"), Protocol.V10); + Protocol protocol = Protocol.V11; + var opDescription = new ProviderEndpointDescription(new Uri("http://host"), protocol.Version); this.ParameterizedAssociationTest( opDescription, - Protocol.V20.Args.SignatureAlgorithm.HMAC_SHA1); + protocol.Args.SignatureAlgorithm.HMAC_SHA1); } [TestMethod] public void PTv2() { - var opDescription = new ProviderEndpointDescription(new Uri("https://host"), Protocol.V20); + Protocol protocol = Protocol.V20; + var opDescription = new ProviderEndpointDescription(new Uri("https://host"), protocol.Version); this.ParameterizedAssociationTest( opDescription, - Protocol.V20.Args.SignatureAlgorithm.HMAC_SHA256); + protocol.Args.SignatureAlgorithm.HMAC_SHA256); } [TestMethod] public void PTv1() { - var opDescription = new ProviderEndpointDescription(new Uri("https://host"), Protocol.V11); + Protocol protocol = Protocol.V11; + var opDescription = new ProviderEndpointDescription(new Uri("https://host"), protocol.Version); this.ParameterizedAssociationTest( opDescription, - Protocol.V20.Args.SignatureAlgorithm.HMAC_SHA1); + protocol.Args.SignatureAlgorithm.HMAC_SHA1); } /// <summary> @@ -64,6 +68,7 @@ namespace DotNetOpenAuth.Test.OpenId { private void ParameterizedAssociationTest( ProviderEndpointDescription opDescription, string expectedAssociationType) { + Protocol protocol = Protocol.Lookup(opDescription.ProtocolVersion); bool expectSuccess = expectedAssociationType != null; bool expectDiffieHellman = !opDescription.Endpoint.IsTransportSecure(); Association rpAssociation = null, opAssociation; @@ -79,7 +84,7 @@ namespace DotNetOpenAuth.Test.OpenId { op.AutoRespond(); }); coordinator.IncomingMessageFilter = message => { - Assert.AreSame(opDescription.Protocol.Version, message.ProtocolVersion, "The message was for version {0} but was expected to be for {1}.", message.ProtocolVersion, opDescription.Protocol.Version); + Assert.AreSame(opDescription.ProtocolVersion, message.ProtocolVersion, "The message was recognized as version {0} but was expected to be {1}.", message.ProtocolVersion, opDescription.ProtocolVersion); var associateSuccess = message as AssociateSuccessfulResponse; var associateFailed = message as AssociateUnsuccessfulResponse; if (associateSuccess != null) { @@ -90,7 +95,7 @@ namespace DotNetOpenAuth.Test.OpenId { } }; coordinator.OutgoingMessageFilter = message => { - Assert.AreSame(opDescription.Protocol.Version, message.ProtocolVersion, "The message was for version {0} but was expected to be for {1}.", message.ProtocolVersion, opDescription.Protocol.Version); + Assert.AreSame(opDescription.ProtocolVersion, message.ProtocolVersion, "The message was for version {0} but was expected to be for {1}.", message.ProtocolVersion, opDescription.ProtocolVersion); }; coordinator.Run(); @@ -101,8 +106,8 @@ namespace DotNetOpenAuth.Test.OpenId { Assert.IsNotNull(opAssociation, "The Provider should have stored the association."); Assert.AreEqual(opAssociation.Handle, rpAssociation.Handle); - Assert.AreEqual(expectedAssociationType, rpAssociation.GetAssociationType(opDescription.Protocol)); - Assert.AreEqual(expectedAssociationType, opAssociation.GetAssociationType(opDescription.Protocol)); + Assert.AreEqual(expectedAssociationType, rpAssociation.GetAssociationType(protocol)); + Assert.AreEqual(expectedAssociationType, opAssociation.GetAssociationType(protocol)); Assert.IsTrue(Math.Abs(opAssociation.SecondsTillExpiration - rpAssociation.SecondsTillExpiration) < 60); Assert.IsTrue(MessagingUtilities.AreEquivalent(opAssociation.SecretKey, rpAssociation.SecretKey)); diff --git a/src/DotNetOpenAuth.Test/OpenId/Messages/AssociateDiffieHellmanRequestTests.cs b/src/DotNetOpenAuth.Test/OpenId/Messages/AssociateDiffieHellmanRequestTests.cs index 81014f2..a8648ac 100644 --- a/src/DotNetOpenAuth.Test/OpenId/Messages/AssociateDiffieHellmanRequestTests.cs +++ b/src/DotNetOpenAuth.Test/OpenId/Messages/AssociateDiffieHellmanRequestTests.cs @@ -6,6 +6,7 @@ namespace DotNetOpenAuth.Test.OpenId.Messages { using System; + using DotNetOpenAuth.OpenId; using DotNetOpenAuth.OpenId.Messages; using Microsoft.VisualStudio.TestTools.UnitTesting; @@ -16,7 +17,7 @@ namespace DotNetOpenAuth.Test.OpenId.Messages { [TestInitialize] public void Setup() { - this.request = new AssociateDiffieHellmanRequest(Recipient); + this.request = new AssociateDiffieHellmanRequest(Protocol.V20.Version, Recipient); } [TestMethod] diff --git a/src/DotNetOpenAuth.Test/OpenId/Messages/AssociateRequestTests.cs b/src/DotNetOpenAuth.Test/OpenId/Messages/AssociateRequestTests.cs index ca9b6d6..db73a8c 100644 --- a/src/DotNetOpenAuth.Test/OpenId/Messages/AssociateRequestTests.cs +++ b/src/DotNetOpenAuth.Test/OpenId/Messages/AssociateRequestTests.cs @@ -21,7 +21,7 @@ namespace DotNetOpenAuth.Test.OpenId.Messages { [TestInitialize] public void Setup() { - this.request = new AssociateUnencryptedRequest(this.secureRecipient); + this.request = new AssociateUnencryptedRequest(this.protocol.Version, this.secureRecipient); } [TestMethod] @@ -52,14 +52,14 @@ namespace DotNetOpenAuth.Test.OpenId.Messages { [TestMethod] public void ValidMessageTest() { - this.request = new AssociateUnencryptedRequest(this.secureRecipient); + this.request = new AssociateUnencryptedRequest(Protocol.V20.Version, this.secureRecipient); this.request.AssociationType = this.protocol.Args.SignatureAlgorithm.HMAC_SHA1; this.request.EnsureValidMessage(); } [TestMethod, ExpectedException(typeof(ProtocolException))] public void InvalidMessageTest() { - this.request = new AssociateUnencryptedRequest(this.insecureRecipient); + this.request = new AssociateUnencryptedRequest(Protocol.V20.Version, this.insecureRecipient); this.request.AssociationType = this.protocol.Args.SignatureAlgorithm.HMAC_SHA1; this.request.EnsureValidMessage(); // no-encryption only allowed for secure channels. } diff --git a/src/DotNetOpenAuth.Test/OpenId/Messages/AssociateUnencryptedResponseTests.cs b/src/DotNetOpenAuth.Test/OpenId/Messages/AssociateUnencryptedResponseTests.cs index 7455cce..16f76cf 100644 --- a/src/DotNetOpenAuth.Test/OpenId/Messages/AssociateUnencryptedResponseTests.cs +++ b/src/DotNetOpenAuth.Test/OpenId/Messages/AssociateUnencryptedResponseTests.cs @@ -5,6 +5,7 @@ //----------------------------------------------------------------------- namespace DotNetOpenAuth.Test.OpenId.Messages { + using System; using DotNetOpenAuth.Messaging; using DotNetOpenAuth.OpenId; using DotNetOpenAuth.OpenId.Messages; @@ -16,7 +17,8 @@ namespace DotNetOpenAuth.Test.OpenId.Messages { [TestInitialize] public void Setup() { - this.response = new AssociateUnencryptedResponse(); + var request = new AssociateUnencryptedRequest(Protocol.V20.Version, new Uri("http://host")); + this.response = new AssociateUnencryptedResponse(request); } [TestMethod] diff --git a/src/DotNetOpenAuth.Test/OpenId/Messages/AssociateUnsuccessfulResponseTests.cs b/src/DotNetOpenAuth.Test/OpenId/Messages/AssociateUnsuccessfulResponseTests.cs index 588ea76..b6f6914 100644 --- a/src/DotNetOpenAuth.Test/OpenId/Messages/AssociateUnsuccessfulResponseTests.cs +++ b/src/DotNetOpenAuth.Test/OpenId/Messages/AssociateUnsuccessfulResponseTests.cs @@ -5,6 +5,7 @@ //----------------------------------------------------------------------- namespace DotNetOpenAuth.Test.OpenId.Messages { + using System; using DotNetOpenAuth.Messaging; using DotNetOpenAuth.OpenId; using DotNetOpenAuth.OpenId.Messages; @@ -16,7 +17,8 @@ namespace DotNetOpenAuth.Test.OpenId.Messages { [TestInitialize] public void Setup() { - this.response = new AssociateUnsuccessfulResponse(); + var request = new AssociateUnencryptedRequest(Protocol.V20.Version, new Uri("http://host")); + this.response = new AssociateUnsuccessfulResponse(request); } [TestMethod] diff --git a/src/DotNetOpenAuth.Test/OpenId/Messages/DirectErrorResponseTests.cs b/src/DotNetOpenAuth.Test/OpenId/Messages/DirectErrorResponseTests.cs index 0cdc3aa..6fd5602 100644 --- a/src/DotNetOpenAuth.Test/OpenId/Messages/DirectErrorResponseTests.cs +++ b/src/DotNetOpenAuth.Test/OpenId/Messages/DirectErrorResponseTests.cs @@ -5,18 +5,22 @@ //----------------------------------------------------------------------- namespace DotNetOpenAuth.Test.OpenId.Messages { + using System; using DotNetOpenAuth.Messaging; using DotNetOpenAuth.OpenId; using DotNetOpenAuth.OpenId.Messages; using Microsoft.VisualStudio.TestTools.UnitTesting; [TestClass] - public class DirectErrorResponseTests { + public class DirectErrorResponseTests : OpenIdTestBase { private DirectErrorResponse response; [TestInitialize] - public void Setup() { - this.response = new DirectErrorResponse(); + public override void SetUp() { + base.SetUp(); + + var request = new AssociateUnencryptedRequest(Protocol.V20.Version, new Uri("http://host")); + this.response = new DirectErrorResponse(request); } [TestMethod] diff --git a/src/DotNetOpenAuth.Test/OpenId/Messages/IndirectErrorResponseTests.cs b/src/DotNetOpenAuth.Test/OpenId/Messages/IndirectErrorResponseTests.cs index a899d89..7794970 100644 --- a/src/DotNetOpenAuth.Test/OpenId/Messages/IndirectErrorResponseTests.cs +++ b/src/DotNetOpenAuth.Test/OpenId/Messages/IndirectErrorResponseTests.cs @@ -18,7 +18,7 @@ namespace DotNetOpenAuth.Test.OpenId.Messages { [TestInitialize] public void Setup() { - this.response = new IndirectErrorResponse(this.recipient); + this.response = new IndirectErrorResponse(Protocol.V20.Version, this.recipient); } [TestMethod] diff --git a/src/DotNetOpenAuth/DotNetOpenAuth.csproj b/src/DotNetOpenAuth/DotNetOpenAuth.csproj index 441b6a4..0da6de6 100644 --- a/src/DotNetOpenAuth/DotNetOpenAuth.csproj +++ b/src/DotNetOpenAuth/DotNetOpenAuth.csproj @@ -72,13 +72,14 @@ <Compile Include="Configuration\UntrustedWebRequestSection.cs" /> <Compile Include="Configuration\HostNameOrRegexCollection.cs" /> <Compile Include="Configuration\HostNameElement.cs" /> + <Compile Include="Messaging\IDirectResponseProtocolMessage.cs" /> <Compile Include="Messaging\EmptyDictionary.cs" /> <Compile Include="Messaging\EmptyEnumerator.cs" /> <Compile Include="Messaging\EmptyList.cs" /> <Compile Include="Messaging\ErrorUtilities.cs" /> <Compile Include="Messaging\InternalErrorException.cs" /> <Compile Include="Messaging\Reflection\IMessagePartEncoder.cs" /> - <Compile Include="OAuth\ChannelElements\OAuthConsumerMessageTypeProvider.cs" /> + <Compile Include="OAuth\ChannelElements\OAuthConsumerMessageFactory.cs" /> <Compile Include="OAuth\ChannelElements\ITokenGenerator.cs" /> <Compile Include="OAuth\ChannelElements\ITokenManager.cs" /> <Compile Include="OAuth\ChannelElements\OAuthHttpMethodBindingElement.cs" /> @@ -126,7 +127,7 @@ <Compile Include="Messaging\Channel.cs" /> <Compile Include="Messaging\HttpRequestInfo.cs" /> <Compile Include="Messaging\IDirectedProtocolMessage.cs" /> - <Compile Include="Messaging\IMessageTypeProvider.cs" /> + <Compile Include="Messaging\IMessageFactory.cs" /> <Compile Include="Messaging\ITamperResistantProtocolMessage.cs" /> <Compile Include="Messaging\MessageSerializer.cs" /> <Compile Include="Messaging\MessagingStrings.Designer.cs"> @@ -151,7 +152,7 @@ <Compile Include="Loggers\TraceLogger.cs" /> <Compile Include="Messaging\HttpDeliveryMethods.cs" /> <Compile Include="Messaging\MessageTransport.cs" /> - <Compile Include="OAuth\ChannelElements\OAuthServiceProviderMessageTypeProvider.cs" /> + <Compile Include="OAuth\ChannelElements\OAuthServiceProviderMessageFactory.cs" /> <Compile Include="Messaging\ProtocolException.cs" /> <Compile Include="OpenId\Association.cs" /> <Compile Include="OpenId\AssociationMemoryStore.cs" /> @@ -160,7 +161,7 @@ <Compile Include="OpenId\ChannelElements\SigningBindingElement.cs" /> <Compile Include="OpenId\ChannelElements\KeyValueFormEncoding.cs" /> <Compile Include="OpenId\ChannelElements\OpenIdChannel.cs" /> - <Compile Include="OpenId\ChannelElements\OpenIdMessageTypeProvider.cs" /> + <Compile Include="OpenId\ChannelElements\OpenIdMessageFactory.cs" /> <Compile Include="OpenId\Configuration.cs" /> <Compile Include="OpenId\RelyingPartyDescription.cs" /> <Compile Include="OpenId\DiffieHellmanUtilities.cs" /> diff --git a/src/DotNetOpenAuth/Messaging/Channel.cs b/src/DotNetOpenAuth/Messaging/Channel.cs index 190fd36..0f020cc 100644 --- a/src/DotNetOpenAuth/Messaging/Channel.cs +++ b/src/DotNetOpenAuth/Messaging/Channel.cs @@ -53,7 +53,7 @@ namespace DotNetOpenAuth.Messaging { /// A tool that can figure out what kind of message is being received /// so it can be deserialized. /// </summary> - private IMessageTypeProvider messageTypeProvider; + private IMessageFactory messageTypeProvider; /// <summary> /// A list of binding elements in the order they must be applied to outgoing messages. @@ -72,7 +72,7 @@ namespace DotNetOpenAuth.Messaging { /// message types can deserialize from it. /// </param> /// <param name="bindingElements">The binding elements to use in sending and receiving messages.</param> - protected Channel(IMessageTypeProvider messageTypeProvider, params IChannelBindingElement[] bindingElements) { + protected Channel(IMessageFactory messageTypeProvider, params IChannelBindingElement[] bindingElements) { if (messageTypeProvider == null) { throw new ArgumentNullException("messageTypeProvider"); } @@ -113,7 +113,7 @@ namespace DotNetOpenAuth.Messaging { /// Gets a tool that can figure out what kind of message is being received /// so it can be deserialized. /// </summary> - protected IMessageTypeProvider MessageTypeProvider { + protected IMessageFactory MessageFactory { get { return this.messageTypeProvider; } } @@ -365,13 +365,13 @@ namespace DotNetOpenAuth.Messaging { } var responseFields = this.ReadFromResponseInternal(response); - Type messageType = this.MessageTypeProvider.GetResponseMessageType(request, responseFields); - if (messageType == null) { + IDirectResponseProtocolMessage responseMessage = this.MessageFactory.GetNewResponseMessage(request, responseFields); + if (responseMessage == null) { return null; } - var responseSerialize = MessageSerializer.Get(messageType); - var responseMessage = responseSerialize.Deserialize(responseFields, null); + var responseSerialize = MessageSerializer.Get(responseMessage.GetType()); + responseSerialize.Deserialize(responseFields, responseMessage); return responseMessage; } @@ -406,16 +406,16 @@ namespace DotNetOpenAuth.Messaging { throw new ArgumentNullException("fields"); } - Type messageType = this.MessageTypeProvider.GetRequestMessageType(fields); + IProtocolMessage message = this.MessageFactory.GetNewRequestMessage(recipient, fields); // If there was no data, or we couldn't recognize it as a message, abort. - if (messageType == null) { + if (message == null) { return null; } // We have a message! Assemble it. - var serializer = MessageSerializer.Get(messageType); - IProtocolMessage message = serializer.Deserialize(fields, recipient); + var serializer = MessageSerializer.Get(message.GetType()); + serializer.Deserialize(fields, message); return message; } diff --git a/src/DotNetOpenAuth/Messaging/ErrorUtilities.cs b/src/DotNetOpenAuth/Messaging/ErrorUtilities.cs index 5e89a77..9068601 100644 --- a/src/DotNetOpenAuth/Messaging/ErrorUtilities.cs +++ b/src/DotNetOpenAuth/Messaging/ErrorUtilities.cs @@ -63,6 +63,15 @@ namespace DotNetOpenAuth.Messaging { } /// <summary> + /// Throws a <see cref="ProtocolException"/>. + /// </summary> + /// <param name="message">The message to set in the exception.</param> + /// <param name="args">The formatting arguments of the message.</param> + internal static void ThrowProtocol(string message, params object[] args) { + VerifyProtocol(false, message, args); + } + + /// <summary> /// Verifies something about the argument supplied to a method. /// </summary> /// <param name="condition">The condition that must evaluate to true to avoid an exception.</param> diff --git a/src/DotNetOpenAuth/Messaging/IDirectResponseProtocolMessage.cs b/src/DotNetOpenAuth/Messaging/IDirectResponseProtocolMessage.cs new file mode 100644 index 0000000..3b4da6c --- /dev/null +++ b/src/DotNetOpenAuth/Messaging/IDirectResponseProtocolMessage.cs @@ -0,0 +1,17 @@ +//----------------------------------------------------------------------- +// <copyright file="IDirectResponseProtocolMessage.cs" company="Andrew Arnott"> +// Copyright (c) Andrew Arnott. All rights reserved. +// </copyright> +//----------------------------------------------------------------------- + +namespace DotNetOpenAuth.Messaging { + /// <summary> + /// Undirected messages that serve as direct responses to direct requests. + /// </summary> + public interface IDirectResponseProtocolMessage : IProtocolMessage { + /// <summary> + /// Gets the originating request message that caused this response to be formed. + /// </summary> + IDirectedProtocolMessage OriginatingRequest { get; } + } +} diff --git a/src/DotNetOpenAuth/Messaging/IMessageTypeProvider.cs b/src/DotNetOpenAuth/Messaging/IMessageFactory.cs index ea7fcc7..9ce5f89 100644 --- a/src/DotNetOpenAuth/Messaging/IMessageTypeProvider.cs +++ b/src/DotNetOpenAuth/Messaging/IMessageFactory.cs @@ -1,5 +1,5 @@ //----------------------------------------------------------------------- -// <copyright file="IMessageTypeProvider.cs" company="Andrew Arnott"> +// <copyright file="IMessageFactory.cs" company="Andrew Arnott"> // Copyright (c) Andrew Arnott. All rights reserved. // </copyright> //----------------------------------------------------------------------- @@ -10,19 +10,20 @@ namespace DotNetOpenAuth.Messaging { /// <summary> /// A tool to analyze an incoming message to figure out what concrete class - /// is designed to deserialize it. + /// is designed to deserialize it and instantiates that class. /// </summary> - public interface IMessageTypeProvider { + public interface IMessageFactory { /// <summary> /// Analyzes an incoming request message payload to discover what kind of /// message is embedded in it and returns the type, or null if no match is found. /// </summary> + /// <param name="recipient">The intended or actual recipient of the request message.</param> /// <param name="fields">The name/value pairs that make up the message payload.</param> /// <returns> - /// The <see cref="IProtocolMessage"/>-derived concrete class that this message can + /// A newly instantiated <see cref="IProtocolMessage"/>-derived object that this message can /// deserialize to. Null if the request isn't recognized as a valid protocol message. /// </returns> - Type GetRequestMessageType(IDictionary<string, string> fields); + IDirectedProtocolMessage GetNewRequestMessage(MessageReceivingEndpoint recipient, IDictionary<string, string> fields); /// <summary> /// Analyzes an incoming request message payload to discover what kind of @@ -33,9 +34,9 @@ namespace DotNetOpenAuth.Messaging { /// </param> /// <param name="fields">The name/value pairs that make up the message payload.</param> /// <returns> - /// The <see cref="IProtocolMessage"/>-derived concrete class that this message can + /// A newly instantiated <see cref="IProtocolMessage"/>-derived object that this message can /// deserialize to. Null if the request isn't recognized as a valid protocol message. /// </returns> - Type GetResponseMessageType(IProtocolMessage request, IDictionary<string, string> fields); + IDirectResponseProtocolMessage GetNewResponseMessage(IDirectedProtocolMessage request, IDictionary<string, string> fields); } } diff --git a/src/DotNetOpenAuth/Messaging/IProtocolMessage.cs b/src/DotNetOpenAuth/Messaging/IProtocolMessage.cs index 9e3be80..9060a1e 100644 --- a/src/DotNetOpenAuth/Messaging/IProtocolMessage.cs +++ b/src/DotNetOpenAuth/Messaging/IProtocolMessage.cs @@ -38,13 +38,13 @@ namespace DotNetOpenAuth.Messaging { IDictionary<string, string> ExtraData { get; } /// <summary> - /// Gets or sets a value indicating whether this message was deserialized as an incoming message. + /// Gets a value indicating whether this message was deserialized as an incoming message. /// </summary> /// <remarks> /// In message type implementations, this property should default to false and will be set /// to true by the messaging system when the message is deserialized as an incoming message. /// </remarks> - bool Incoming { get; set; } + bool Incoming { get; } /// <summary> /// Checks the message state for conformity to the protocol specification diff --git a/src/DotNetOpenAuth/Messaging/MessageSerializer.cs b/src/DotNetOpenAuth/Messaging/MessageSerializer.cs index 5e86949..1db0c57 100644 --- a/src/DotNetOpenAuth/Messaging/MessageSerializer.cs +++ b/src/DotNetOpenAuth/Messaging/MessageSerializer.cs @@ -77,66 +77,24 @@ namespace DotNetOpenAuth.Messaging { /// Reads name=value pairs into an OAuth message. /// </summary> /// <param name="fields">The name=value pairs that were read in from the transport.</param> - /// <param name="recipient">The recipient of the message.</param> - /// <returns>The instantiated and initialized <see cref="IProtocolMessage"/> instance.</returns> + /// <param name="message">The message to deserialize into.</param> /// <exception cref="ProtocolException">Thrown when protocol rules are broken by the incoming message.</exception> - internal IProtocolMessage Deserialize(IDictionary<string, string> fields, MessageReceivingEndpoint recipient) { - if (fields == null) { - throw new ArgumentNullException("fields"); - } + internal void Deserialize(IDictionary<string, string> fields, IProtocolMessage message) { + ErrorUtilities.VerifyArgumentNotNull(fields, "fields"); + ErrorUtilities.VerifyArgumentNotNull(message, "message"); // Before we deserialize the message, make sure all the required parts are present. MessageDescription.Get(this.messageType).EnsureMessagePartsPassBasicValidation(fields); - IProtocolMessage result = this.CreateMessage(recipient); try { foreach (var pair in fields) { - IDictionary<string, string> dictionary = new MessageDictionary(result); + IDictionary<string, string> dictionary = new MessageDictionary(message); dictionary[pair.Key] = pair.Value; } } catch (ArgumentException ex) { throw ErrorUtilities.Wrap(ex, MessagingStrings.ErrorDeserializingMessage, this.messageType.Name); } - result.EnsureValidMessage(); - return result; - } - - /// <summary> - /// Instantiates a new message to deserialize data into. - /// </summary> - /// <param name="recipient">The recipient this message is directed to, if any.</param> - /// <returns>The newly created message object.</returns> - private IProtocolMessage CreateMessage(MessageReceivingEndpoint recipient) { - IProtocolMessage result; - BindingFlags bindingFlags = BindingFlags.Public | BindingFlags.NonPublic | BindingFlags.Instance; - if (typeof(IDirectedProtocolMessage).IsAssignableFrom(this.messageType)) { - // Some directed messages take just the recipient, while others take the whole endpoint - ConstructorInfo ctor; - if ((ctor = this.messageType.GetConstructor(bindingFlags, null, new Type[] { typeof(Uri) }, null)) != null) { - if (recipient == null) { - // We need a recipient to deserialize directed messages. - throw new ArgumentNullException("recipient"); - } - - result = (IProtocolMessage)ctor.Invoke(new object[] { recipient.Location }); - } else if ((ctor = this.messageType.GetConstructor(bindingFlags, null, new Type[] { typeof(MessageReceivingEndpoint) }, null)) != null) { - if (recipient == null) { - // We need a recipient to deserialize directed messages. - throw new ArgumentNullException("recipient"); - } - - result = (IProtocolMessage)ctor.Invoke(new object[] { recipient }); - } else if ((ctor = this.messageType.GetConstructor(bindingFlags, null, new Type[0], null)) != null) { - result = (IProtocolMessage)ctor.Invoke(new object[0]); - } else { - throw new InvalidOperationException("Unrecognized constructor signature on type " + this.messageType); - } - } else { - result = (IProtocolMessage)Activator.CreateInstance(this.messageType, true); - } - - result.Incoming = true; - return result; + message.EnsureValidMessage(); } } } diff --git a/src/DotNetOpenAuth/Messaging/MessagingUtilities.cs b/src/DotNetOpenAuth/Messaging/MessagingUtilities.cs index 042d531..f23aea8 100644 --- a/src/DotNetOpenAuth/Messaging/MessagingUtilities.cs +++ b/src/DotNetOpenAuth/Messaging/MessagingUtilities.cs @@ -290,6 +290,41 @@ namespace DotNetOpenAuth.Messaging { } /// <summary> + /// Determines whether the specified message is a request (indirect message or direct request). + /// </summary> + /// <param name="message">The message in question.</param> + /// <returns> + /// <c>true</c> if the specified message is a request; otherwise, <c>false</c>. + /// </returns> + /// <remarks> + /// Although an <see cref="IProtocolMessage"/> may implement the <see cref="IDirectedProtocolMessage"/> + /// interface, it may only be doing that for its derived classes. These objects are only requests + /// if their <see cref="IDirectedProtocolMessage.Recipient"/> property is non-null. + /// </remarks> + internal static bool IsRequest(this IDirectedProtocolMessage message) { + ErrorUtilities.VerifyArgumentNotNull(message, "message"); + return message.Recipient != null; + } + + /// <summary> + /// Determines whether the specified message is a direct response. + /// </summary> + /// <param name="message">The message in question.</param> + /// <returns> + /// <c>true</c> if the specified message is a direct response; otherwise, <c>false</c>. + /// </returns> + /// <remarks> + /// Although an <see cref="IProtocolMessage"/> may implement the + /// <see cref="IDirectResponseProtocolMessage"/> interface, it may only be doing + /// that for its derived classes. These objects are only requests if their + /// <see cref="IDirectResponseProtocolMessage.OriginatingRequest"/> property is non-null. + /// </remarks> + internal static bool IsDirectResponse(this IDirectResponseProtocolMessage message) { + ErrorUtilities.VerifyArgumentNotNull(message, "message"); + return message.OriginatingRequest != null; + } + + /// <summary> /// A class to convert a <see cref="Comparison<T>"/> into an <see cref="IComparer<T>"/>. /// </summary> /// <typeparam name="T">The type of objects being compared.</typeparam> diff --git a/src/DotNetOpenAuth/Messaging/ProtocolException.cs b/src/DotNetOpenAuth/Messaging/ProtocolException.cs index db79b00..365a689 100644 --- a/src/DotNetOpenAuth/Messaging/ProtocolException.cs +++ b/src/DotNetOpenAuth/Messaging/ProtocolException.cs @@ -130,9 +130,14 @@ namespace DotNetOpenAuth.Messaging { } /// <summary> - /// Gets or sets a value indicating whether this message was deserialized as an incoming message. + /// Gets a value indicating whether this message was deserialized as an incoming message. /// </summary> - bool IProtocolMessage.Incoming { get; set; } + /// <remarks> + /// Always false because exceptions are not a valid message to deserialize. + /// </remarks> + bool IProtocolMessage.Incoming { + get { return false; } + } #endregion @@ -237,11 +242,10 @@ namespace DotNetOpenAuth.Messaging { } /// <summary> - /// Gets or sets a value indicating whether this message was deserialized as an incoming message. + /// Gets a value indicating whether this message was deserialized as an incoming message. /// </summary> protected bool Incoming { - get { return ((IProtocolMessage)this).Incoming; } - set { ((IProtocolMessage)this).Incoming = value; } + get { return false; } } /// <summary> diff --git a/src/DotNetOpenAuth/OAuth/ChannelElements/OAuthChannel.cs b/src/DotNetOpenAuth/OAuth/ChannelElements/OAuthChannel.cs index fafefde..d827ed3 100644 --- a/src/DotNetOpenAuth/OAuth/ChannelElements/OAuthChannel.cs +++ b/src/DotNetOpenAuth/OAuth/ChannelElements/OAuthChannel.cs @@ -33,7 +33,7 @@ namespace DotNetOpenAuth.OAuth.ChannelElements { signingBindingElement, store, tokenManager, - isConsumer ? (IMessageTypeProvider)new OAuthConsumerMessageTypeProvider() : new OAuthServiceProviderMessageTypeProvider(tokenManager)) { + isConsumer ? (IMessageFactory)new OAuthConsumerMessageFactory() : new OAuthServiceProviderMessageFactory(tokenManager)) { } /// <summary> @@ -45,12 +45,12 @@ namespace DotNetOpenAuth.OAuth.ChannelElements { /// <param name="messageTypeProvider"> /// An injected message type provider instance. /// Except for mock testing, this should always be one of - /// <see cref="OAuthConsumerMessageTypeProvider"/> or <see cref="OAuthServiceProviderMessageTypeProvider"/>. + /// <see cref="OAuthConsumerMessageFactory"/> or <see cref="OAuthServiceProviderMessageFactory"/>. /// </param> /// <remarks> /// This overload for testing purposes only. /// </remarks> - internal OAuthChannel(ITamperProtectionChannelBindingElement signingBindingElement, INonceStore store, ITokenManager tokenManager, IMessageTypeProvider messageTypeProvider) + internal OAuthChannel(ITamperProtectionChannelBindingElement signingBindingElement, INonceStore store, ITokenManager tokenManager, IMessageFactory messageTypeProvider) : base(messageTypeProvider, new OAuthHttpMethodBindingElement(), signingBindingElement, new StandardExpirationBindingElement(), new StandardReplayProtectionBindingElement(store)) { if (tokenManager == null) { throw new ArgumentNullException("tokenManager"); diff --git a/src/DotNetOpenAuth/OAuth/ChannelElements/OAuthConsumerMessageTypeProvider.cs b/src/DotNetOpenAuth/OAuth/ChannelElements/OAuthConsumerMessageFactory.cs index 6305326..5bcac58 100644 --- a/src/DotNetOpenAuth/OAuth/ChannelElements/OAuthConsumerMessageTypeProvider.cs +++ b/src/DotNetOpenAuth/OAuth/ChannelElements/OAuthConsumerMessageFactory.cs @@ -1,5 +1,5 @@ //----------------------------------------------------------------------- -// <copyright file="OAuthConsumerMessageTypeProvider.cs" company="Andrew Arnott"> +// <copyright file="OAuthConsumerMessageFactory.cs" company="Andrew Arnott"> // Copyright (c) Andrew Arnott. All rights reserved. // </copyright> //----------------------------------------------------------------------- @@ -11,41 +11,47 @@ namespace DotNetOpenAuth.OAuth.ChannelElements { using DotNetOpenAuth.OAuth.Messages; /// <summary> - /// An OAuth-protocol specific implementation of the <see cref="IMessageTypeProvider"/> + /// An OAuth-protocol specific implementation of the <see cref="IMessageFactory"/> /// interface. /// </summary> - public class OAuthConsumerMessageTypeProvider : IMessageTypeProvider { + public class OAuthConsumerMessageFactory : IMessageFactory { /// <summary> - /// Initializes a new instance of the <see cref="OAuthConsumerMessageTypeProvider"/> class. + /// Initializes a new instance of the <see cref="OAuthConsumerMessageFactory"/> class. /// </summary> - protected internal OAuthConsumerMessageTypeProvider() { + protected internal OAuthConsumerMessageFactory() { } - #region IMessageTypeProvider Members + #region IMessageFactory Members /// <summary> - /// Analyzes an incoming request message payload to discover what kind of + /// Analyzes an incoming request message payload to discover what kind of /// message is embedded in it and returns the type, or null if no match is found. /// </summary> + /// <param name="recipient">The intended or actual recipient of the request message.</param> /// <param name="fields">The name/value pairs that make up the message payload.</param> + /// <returns> + /// A newly instantiated <see cref="IProtocolMessage"/>-derived object that this message can + /// deserialize to. Null if the request isn't recognized as a valid protocol message. + /// </returns> /// <remarks> /// The request messages are: /// UserAuthorizationResponse /// </remarks> - /// <returns> - /// The <see cref="IProtocolMessage"/>-derived concrete class that this message can - /// deserialize to. Null if the request isn't recognized as a valid protocol message. - /// </returns> - public virtual Type GetRequestMessageType(IDictionary<string, string> fields) { - if (fields == null) { - throw new ArgumentNullException("fields"); - } + public virtual IDirectedProtocolMessage GetNewRequestMessage(MessageReceivingEndpoint recipient, IDictionary<string, string> fields) { + ErrorUtilities.VerifyArgumentNotNull(recipient, "recipient"); + ErrorUtilities.VerifyArgumentNotNull(fields, "fields"); + + MessageBase message = null; if (fields.ContainsKey("oauth_token")) { - return typeof(UserAuthorizationResponse); + message = new UserAuthorizationResponse(recipient.Location); + } + + if (message != null) { + message.SetAsIncoming(); } - return null; + return message; } /// <summary> @@ -58,7 +64,7 @@ namespace DotNetOpenAuth.OAuth.ChannelElements { /// </param> /// <param name="fields">The name/value pairs that make up the message payload.</param> /// <returns> - /// The <see cref="IProtocolMessage"/>-derived concrete class that this message can + /// A newly instantiated <see cref="IProtocolMessage"/>-derived object that this message can /// deserialize to. Null if the request isn't recognized as a valid protocol message. /// </returns> /// <remarks> @@ -66,10 +72,11 @@ namespace DotNetOpenAuth.OAuth.ChannelElements { /// UnauthorizedTokenResponse /// AuthorizedTokenResponse /// </remarks> - public virtual Type GetResponseMessageType(IProtocolMessage request, IDictionary<string, string> fields) { - if (fields == null) { - throw new ArgumentNullException("fields"); - } + public virtual IDirectResponseProtocolMessage GetNewResponseMessage(IDirectedProtocolMessage request, IDictionary<string, string> fields) { + ErrorUtilities.VerifyArgumentNotNull(request, "request"); + ErrorUtilities.VerifyArgumentNotNull(fields, "fields"); + + MessageBase message = null; // All response messages have the oauth_token field. if (!fields.ContainsKey("oauth_token")) { @@ -82,14 +89,22 @@ namespace DotNetOpenAuth.OAuth.ChannelElements { return null; } - if (request is UnauthorizedTokenRequest) { - return typeof(UnauthorizedTokenResponse); - } else if (request is AuthorizedTokenRequest) { - return typeof(AuthorizedTokenResponse); + var unauthorizedTokenRequest = request as UnauthorizedTokenRequest; + var authorizedTokenRequest = request as AuthorizedTokenRequest; + if (unauthorizedTokenRequest != null) { + message = new UnauthorizedTokenResponse(unauthorizedTokenRequest); + } else if (authorizedTokenRequest != null) { + message = new AuthorizedTokenResponse(authorizedTokenRequest); } else { Logger.ErrorFormat("Unexpected response message given the request type {0}", request.GetType().Name); throw new ProtocolException(OAuthStrings.InvalidIncomingMessage); } + + if (message != null) { + message.SetAsIncoming(); + } + + return message; } #endregion diff --git a/src/DotNetOpenAuth/OAuth/ChannelElements/OAuthServiceProviderMessageTypeProvider.cs b/src/DotNetOpenAuth/OAuth/ChannelElements/OAuthServiceProviderMessageFactory.cs index ee509c6..cd900cf 100644 --- a/src/DotNetOpenAuth/OAuth/ChannelElements/OAuthServiceProviderMessageTypeProvider.cs +++ b/src/DotNetOpenAuth/OAuth/ChannelElements/OAuthServiceProviderMessageFactory.cs @@ -1,5 +1,5 @@ //----------------------------------------------------------------------- -// <copyright file="OAuthServiceProviderMessageTypeProvider.cs" company="Andrew Arnott"> +// <copyright file="OAuthServiceProviderMessageFactory.cs" company="Andrew Arnott"> // Copyright (c) Andrew Arnott. All rights reserved. // </copyright> //----------------------------------------------------------------------- @@ -11,34 +11,37 @@ namespace DotNetOpenAuth.OAuth.ChannelElements { using DotNetOpenAuth.OAuth.Messages; /// <summary> - /// An OAuth-protocol specific implementation of the <see cref="IMessageTypeProvider"/> + /// An OAuth-protocol specific implementation of the <see cref="IMessageFactory"/> /// interface. /// </summary> - public class OAuthServiceProviderMessageTypeProvider : IMessageTypeProvider { + public class OAuthServiceProviderMessageFactory : IMessageFactory { /// <summary> /// The token manager to use for discerning between request and access tokens. /// </summary> private ITokenManager tokenManager; /// <summary> - /// Initializes a new instance of the <see cref="OAuthServiceProviderMessageTypeProvider"/> class. + /// Initializes a new instance of the <see cref="OAuthServiceProviderMessageFactory"/> class. /// </summary> /// <param name="tokenManager">The token manager instance to use.</param> - protected internal OAuthServiceProviderMessageTypeProvider(ITokenManager tokenManager) { - if (tokenManager == null) { - throw new ArgumentNullException("tokenManager"); - } + protected internal OAuthServiceProviderMessageFactory(ITokenManager tokenManager) { + ErrorUtilities.VerifyArgumentNotNull(tokenManager, "tokenManager"); this.tokenManager = tokenManager; } - #region IMessageTypeProvider Members + #region IMessageFactory Members /// <summary> - /// Analyzes an incoming request message payload to discover what kind of + /// Analyzes an incoming request message payload to discover what kind of /// message is embedded in it and returns the type, or null if no match is found. /// </summary> + /// <param name="recipient">The intended or actual recipient of the request message.</param> /// <param name="fields">The name/value pairs that make up the message payload.</param> + /// <returns> + /// A newly instantiated <see cref="IProtocolMessage"/>-derived object that this message can + /// deserialize to. Null if the request isn't recognized as a valid protocol message. + /// </returns> /// <remarks> /// The request messages are: /// UnauthorizedTokenRequest @@ -46,33 +49,34 @@ namespace DotNetOpenAuth.OAuth.ChannelElements { /// UserAuthorizationRequest /// AccessProtectedResourceRequest /// </remarks> - /// <returns> - /// The <see cref="IProtocolMessage"/>-derived concrete class that this message can - /// deserialize to. Null if the request isn't recognized as a valid protocol message. - /// </returns> - public virtual Type GetRequestMessageType(IDictionary<string, string> fields) { - if (fields == null) { - throw new ArgumentNullException("fields"); - } + public virtual IDirectedProtocolMessage GetNewRequestMessage(MessageReceivingEndpoint recipient, IDictionary<string, string> fields) { + ErrorUtilities.VerifyArgumentNotNull(recipient, "recipient"); + ErrorUtilities.VerifyArgumentNotNull(fields, "fields"); - if (fields.ContainsKey("oauth_consumer_key") && - !fields.ContainsKey("oauth_token")) { - return typeof(UnauthorizedTokenRequest); - } + MessageBase message = null; if (fields.ContainsKey("oauth_consumer_key") && + !fields.ContainsKey("oauth_token")) { + message = new UnauthorizedTokenRequest(recipient); + } else if (fields.ContainsKey("oauth_consumer_key") && fields.ContainsKey("oauth_token")) { // Discern between RequestAccessToken and AccessProtectedResources, // which have all the same parameters, by figuring out what type of token // is in the token parameter. bool tokenTypeIsAccessToken = this.tokenManager.GetTokenType(fields["oauth_token"]) == TokenType.AccessToken; - return tokenTypeIsAccessToken ? typeof(AccessProtectedResourceRequest) : - typeof(AuthorizedTokenRequest); + message = tokenTypeIsAccessToken ? (MessageBase)new AccessProtectedResourceRequest(recipient) : + new AuthorizedTokenRequest(recipient); + } else { + // fail over to the message with no required fields at all. + message = new UserAuthorizationRequest(recipient); + } + + if (message != null) { + message.SetAsIncoming(); } - // fail over to the message with no required fields at all. - return typeof(UserAuthorizationRequest); + return message; } /// <summary> @@ -92,10 +96,9 @@ namespace DotNetOpenAuth.OAuth.ChannelElements { /// The response messages are: /// None. /// </remarks> - public virtual Type GetResponseMessageType(IProtocolMessage request, IDictionary<string, string> fields) { - if (fields == null) { - throw new ArgumentNullException("fields"); - } + public virtual IDirectResponseProtocolMessage GetNewResponseMessage(IDirectedProtocolMessage request, IDictionary<string, string> fields) { + ErrorUtilities.VerifyArgumentNotNull(request, "request"); + ErrorUtilities.VerifyArgumentNotNull(fields, "fields"); Logger.Error("Service Providers are not expected to ever receive responses."); return null; diff --git a/src/DotNetOpenAuth/OAuth/ConsumerBase.cs b/src/DotNetOpenAuth/OAuth/ConsumerBase.cs index 22e9bdc..7c634b7 100644 --- a/src/DotNetOpenAuth/OAuth/ConsumerBase.cs +++ b/src/DotNetOpenAuth/OAuth/ConsumerBase.cs @@ -33,7 +33,7 @@ namespace DotNetOpenAuth.OAuth { ITamperProtectionChannelBindingElement signingElement = serviceDescription.CreateTamperProtectionElement(); INonceStore store = new NonceMemoryStore(StandardExpirationBindingElement.DefaultMaximumMessageAge); - this.OAuthChannel = new OAuthChannel(signingElement, store, tokenManager, new OAuthConsumerMessageTypeProvider()); + this.OAuthChannel = new OAuthChannel(signingElement, store, tokenManager, new OAuthConsumerMessageFactory()); this.ServiceProvider = serviceDescription; } diff --git a/src/DotNetOpenAuth/OAuth/Messages/AuthorizedTokenResponse.cs b/src/DotNetOpenAuth/OAuth/Messages/AuthorizedTokenResponse.cs index 3030c01..14413a5 100644 --- a/src/DotNetOpenAuth/OAuth/Messages/AuthorizedTokenResponse.cs +++ b/src/DotNetOpenAuth/OAuth/Messages/AuthorizedTokenResponse.cs @@ -17,8 +17,9 @@ namespace DotNetOpenAuth.OAuth.Messages { /// <summary> /// Initializes a new instance of the <see cref="AuthorizedTokenResponse"/> class. /// </summary> - protected internal AuthorizedTokenResponse() - : base(MessageProtections.None, MessageTransport.Direct) { + /// <param name="originatingRequest">The originating request.</param> + protected internal AuthorizedTokenResponse(AuthorizedTokenRequest originatingRequest) + : base(MessageProtections.None, originatingRequest) { } /// <summary> diff --git a/src/DotNetOpenAuth/OAuth/Messages/MessageBase.cs b/src/DotNetOpenAuth/OAuth/Messages/MessageBase.cs index 4cf04db..f835efd 100644 --- a/src/DotNetOpenAuth/OAuth/Messages/MessageBase.cs +++ b/src/DotNetOpenAuth/OAuth/Messages/MessageBase.cs @@ -16,7 +16,7 @@ namespace DotNetOpenAuth.OAuth.Messages { /// <summary> /// A base class for all OAuth messages. /// </summary> - public abstract class MessageBase : IDirectedProtocolMessage { + public abstract class MessageBase : IDirectedProtocolMessage, IDirectResponseProtocolMessage { /// <summary> /// A store for extra name/value data pairs that are attached to this message. /// </summary> @@ -37,6 +37,16 @@ namespace DotNetOpenAuth.OAuth.Messages { /// </summary> private MessageReceivingEndpoint recipient; + /// <summary> + /// Backing store for the <see cref="OriginatingRequest"/> properties. + /// </summary> + private IDirectedProtocolMessage originatingRequest; + + /// <summary> + /// Backing store for the <see cref="Incoming"/> properties. + /// </summary> + private bool incoming; + #if DEBUG /// <summary> /// Initializes static members of the <see cref="MessageBase"/> class. @@ -47,17 +57,20 @@ namespace DotNetOpenAuth.OAuth.Messages { #endif /// <summary> - /// Initializes a new instance of the <see cref="MessageBase"/> class. + /// Initializes a new instance of the <see cref="MessageBase"/> class for direct response messages. /// </summary> /// <param name="protectionRequired">The level of protection the message requires.</param> - /// <param name="transport">A value indicating whether this message requires a direct or indirect transport.</param> - protected MessageBase(MessageProtections protectionRequired, MessageTransport transport) { + /// <param name="originatingRequest">The request that asked for this direct response.</param> + protected MessageBase(MessageProtections protectionRequired, IDirectedProtocolMessage originatingRequest) { + ErrorUtilities.VerifyArgumentNotNull(originatingRequest, "originatingRequest"); + this.protectionRequired = protectionRequired; - this.transport = transport; + this.transport = MessageTransport.Direct; + this.originatingRequest = originatingRequest; } /// <summary> - /// Initializes a new instance of the <see cref="MessageBase"/> class. + /// Initializes a new instance of the <see cref="MessageBase"/> class for direct requests or indirect messages. /// </summary> /// <param name="protectionRequired">The level of protection the message requires.</param> /// <param name="transport">A value indicating whether this message requires a direct or indirect transport.</param> @@ -103,9 +116,11 @@ namespace DotNetOpenAuth.OAuth.Messages { } /// <summary> - /// Gets or sets a value indicating whether this message was deserialized as an incoming message. + /// Gets a value indicating whether this message was deserialized as an incoming message. /// </summary> - bool IProtocolMessage.Incoming { get; set; } + bool IProtocolMessage.Incoming { + get { return this.incoming; } + } #endregion @@ -127,6 +142,17 @@ namespace DotNetOpenAuth.OAuth.Messages { #endregion + #region IDirectResponseProtocolMessage Members + + /// <summary> + /// Gets the originating request message that caused this response to be formed. + /// </summary> + IDirectedProtocolMessage IDirectResponseProtocolMessage.OriginatingRequest { + get { return this.originatingRequest; } + } + + #endregion + /// <summary> /// Gets or sets a value indicating whether security sensitive strings are /// emitted from the ToString() method. @@ -162,11 +188,10 @@ namespace DotNetOpenAuth.OAuth.Messages { } /// <summary> - /// Gets or sets a value indicating whether this message was deserialized as an incoming message. + /// Gets a value indicating whether this message was deserialized as an incoming message. /// </summary> protected bool Incoming { - get { return ((IProtocolMessage)this).Incoming; } - set { ((IProtocolMessage)this).Incoming = value; } + get { return this.incoming; } } /// <summary> @@ -193,6 +218,13 @@ namespace DotNetOpenAuth.OAuth.Messages { } } + /// <summary> + /// Gets the originating request message that caused this response to be formed. + /// </summary> + protected IDirectedProtocolMessage OriginatingRequest { + get { return this.originatingRequest; } + } + #region IProtocolMessage Methods /// <summary> @@ -232,6 +264,13 @@ namespace DotNetOpenAuth.OAuth.Messages { } /// <summary> + /// Sets a flag indicating that this message is received (as opposed to sent). + /// </summary> + internal void SetAsIncoming() { + this.incoming = true; + } + + /// <summary> /// Checks the message state for conformity to the protocol specification /// and throws an exception if the message is invalid. /// </summary> diff --git a/src/DotNetOpenAuth/OAuth/Messages/UnauthorizedTokenResponse.cs b/src/DotNetOpenAuth/OAuth/Messages/UnauthorizedTokenResponse.cs index 39f8fd0..285dec7 100644 --- a/src/DotNetOpenAuth/OAuth/Messages/UnauthorizedTokenResponse.cs +++ b/src/DotNetOpenAuth/OAuth/Messages/UnauthorizedTokenResponse.cs @@ -24,18 +24,11 @@ namespace DotNetOpenAuth.OAuth.Messages { /// <remarks> /// This constructor is used by the Service Provider to send the message. /// </remarks> - protected internal UnauthorizedTokenResponse(UnauthorizedTokenRequest requestMessage, string requestToken, string tokenSecret) : this() { - if (requestMessage == null) { - throw new ArgumentNullException("requestMessage"); - } - if (string.IsNullOrEmpty(requestToken)) { - throw new ArgumentNullException("requestToken"); - } - if (string.IsNullOrEmpty(tokenSecret)) { - throw new ArgumentNullException("tokenSecret"); - } + protected internal UnauthorizedTokenResponse(UnauthorizedTokenRequest requestMessage, string requestToken, string tokenSecret) + : this(requestMessage) { + ErrorUtilities.VerifyArgumentNotNull(requestToken, "requestToken"); + ErrorUtilities.VerifyArgumentNotNull(tokenSecret, "tokenSecret"); - this.RequestMessage = requestMessage; this.RequestToken = requestToken; this.TokenSecret = tokenSecret; } @@ -43,9 +36,10 @@ namespace DotNetOpenAuth.OAuth.Messages { /// <summary> /// Initializes a new instance of the <see cref="UnauthorizedTokenResponse"/> class. /// </summary> + /// <param name="originatingRequest">The originating request.</param> /// <remarks>This constructor is used by the consumer to deserialize the message.</remarks> - protected internal UnauthorizedTokenResponse() - : base(MessageProtections.None, MessageTransport.Direct) { + protected internal UnauthorizedTokenResponse(UnauthorizedTokenRequest originatingRequest) + : base(MessageProtections.None, originatingRequest) { } /// <summary> @@ -81,7 +75,9 @@ namespace DotNetOpenAuth.OAuth.Messages { /// <summary> /// Gets the original request for an unauthorized token. /// </summary> - internal UnauthorizedTokenRequest RequestMessage { get; private set; } + internal UnauthorizedTokenRequest RequestMessage { + get { return (UnauthorizedTokenRequest)this.OriginatingRequest; } + } /// <summary> /// Gets or sets the Token Secret. diff --git a/src/DotNetOpenAuth/OAuth/ServiceProvider.cs b/src/DotNetOpenAuth/OAuth/ServiceProvider.cs index e35aa7c..824aa20 100644 --- a/src/DotNetOpenAuth/OAuth/ServiceProvider.cs +++ b/src/DotNetOpenAuth/OAuth/ServiceProvider.cs @@ -37,7 +37,7 @@ namespace DotNetOpenAuth.OAuth { /// <param name="serviceDescription">The endpoints and behavior on the Service Provider.</param> /// <param name="tokenManager">The host's method of storing and recalling tokens and secrets.</param> public ServiceProvider(ServiceProviderDescription serviceDescription, ITokenManager tokenManager) - : this(serviceDescription, tokenManager, new OAuthServiceProviderMessageTypeProvider(tokenManager)) { + : this(serviceDescription, tokenManager, new OAuthServiceProviderMessageFactory(tokenManager)) { } /// <summary> @@ -46,7 +46,7 @@ namespace DotNetOpenAuth.OAuth { /// <param name="serviceDescription">The endpoints and behavior on the Service Provider.</param> /// <param name="tokenManager">The host's method of storing and recalling tokens and secrets.</param> /// <param name="messageTypeProvider">An object that can figure out what type of message is being received for deserialization.</param> - public ServiceProvider(ServiceProviderDescription serviceDescription, ITokenManager tokenManager, OAuthServiceProviderMessageTypeProvider messageTypeProvider) { + public ServiceProvider(ServiceProviderDescription serviceDescription, ITokenManager tokenManager, OAuthServiceProviderMessageFactory messageTypeProvider) { if (serviceDescription == null) { throw new ArgumentNullException("serviceDescription"); } @@ -297,7 +297,7 @@ namespace DotNetOpenAuth.OAuth { string accessToken = this.TokenGenerator.GenerateAccessToken(request.ConsumerKey); string tokenSecret = this.TokenGenerator.GenerateSecret(); this.TokenManager.ExpireRequestTokenAndStoreNewAccessToken(request.ConsumerKey, request.RequestToken, accessToken, tokenSecret); - var grantAccess = new AuthorizedTokenResponse { + var grantAccess = new AuthorizedTokenResponse(request) { AccessToken = accessToken, TokenSecret = tokenSecret, }; diff --git a/src/DotNetOpenAuth/OpenId/ChannelElements/OpenIdChannel.cs b/src/DotNetOpenAuth/OpenId/ChannelElements/OpenIdChannel.cs index 7850e75..023cce6 100644 --- a/src/DotNetOpenAuth/OpenId/ChannelElements/OpenIdChannel.cs +++ b/src/DotNetOpenAuth/OpenId/ChannelElements/OpenIdChannel.cs @@ -37,14 +37,14 @@ namespace DotNetOpenAuth.OpenId.ChannelElements { /// Initializes a new instance of the <see cref="OpenIdChannel"/> class. /// </summary> internal OpenIdChannel() - : this(new OpenIdMessageTypeProvider()) { + : this(new OpenIdMessageFactory()) { } /// <summary> /// Initializes a new instance of the <see cref="OpenIdChannel"/> class. /// </summary> /// <param name="messageTypeProvider">An object that knows how to distinguish the various OpenID message types for deserialization purposes.</param> - private OpenIdChannel(IMessageTypeProvider messageTypeProvider) : + private OpenIdChannel(IMessageFactory messageTypeProvider) : base(messageTypeProvider) { } diff --git a/src/DotNetOpenAuth/OpenId/ChannelElements/OpenIdMessageFactory.cs b/src/DotNetOpenAuth/OpenId/ChannelElements/OpenIdMessageFactory.cs new file mode 100644 index 0000000..6d564b5 --- /dev/null +++ b/src/DotNetOpenAuth/OpenId/ChannelElements/OpenIdMessageFactory.cs @@ -0,0 +1,113 @@ +//----------------------------------------------------------------------- +// <copyright file="OpenIdMessageFactory.cs" company="Andrew Arnott"> +// Copyright (c) Andrew Arnott. All rights reserved. +// </copyright> +//----------------------------------------------------------------------- + +namespace DotNetOpenAuth.OpenId.ChannelElements { + using System; + using System.Collections.Generic; + using System.Linq; + using System.Text; + using DotNetOpenAuth.Messaging; + using DotNetOpenAuth.OpenId.Messages; + + /// <summary> + /// Distinguishes the various OpenID message types for deserialization purposes. + /// </summary> + internal class OpenIdMessageFactory : IMessageFactory { + #region IMessageFactory Members + + /// <summary> + /// Analyzes an incoming request message payload to discover what kind of + /// message is embedded in it and returns the type, or null if no match is found. + /// </summary> + /// <param name="recipient">The intended or actual recipient of the request message.</param> + /// <param name="fields">The name/value pairs that make up the message payload.</param> + /// <returns> + /// A newly instantiated <see cref="IProtocolMessage"/>-derived object that this message can + /// deserialize to. Null if the request isn't recognized as a valid protocol message. + /// </returns> + public IDirectedProtocolMessage GetNewRequestMessage(MessageReceivingEndpoint recipient, IDictionary<string, string> fields) { + ErrorUtilities.VerifyArgumentNotNull(recipient, "recipient"); + ErrorUtilities.VerifyArgumentNotNull(fields, "fields"); + + RequestBase message = null; + + // Discern the OpenID version of the message. + Protocol protocol = Protocol.V11; + string ns; + if (fields.TryGetValue(Protocol.V20.openid.ns, out ns)) { + ErrorUtilities.VerifyProtocol(string.Equals(ns, Protocol.OpenId2Namespace, StringComparison.Ordinal), MessagingStrings.UnexpectedMessagePartValue, Protocol.V20.openid.ns, ns); + protocol = Protocol.V20; + } + + string mode; + if (fields.TryGetValue(protocol.openid.mode, out mode)) { + if (string.Equals(mode, protocol.Args.Mode.associate)) { + if (fields.ContainsKey(protocol.openid.dh_consumer_public)) { + message = new AssociateDiffieHellmanRequest(protocol.Version, recipient.Location); + } else { + message = new AssociateUnencryptedRequest(protocol.Version, recipient.Location); + } + } else { + ErrorUtilities.ThrowProtocol(MessagingStrings.UnexpectedMessagePartValue, protocol.openid.mode, mode); + } + + // TODO: handle more modes + } + + if (message != null) { + message.SetAsIncoming(); + } + + return message; + } + + /// <summary> + /// Analyzes an incoming request message payload to discover what kind of + /// message is embedded in it and returns the type, or null if no match is found. + /// </summary> + /// <param name="request">The message that was sent as a request that resulted in the response.</param> + /// <param name="fields">The name/value pairs that make up the message payload.</param> + /// <returns> + /// A newly instantiated <see cref="IProtocolMessage"/>-derived object that this message can + /// deserialize to. Null if the request isn't recognized as a valid protocol message. + /// </returns> + public IDirectResponseProtocolMessage GetNewResponseMessage(IDirectedProtocolMessage request, IDictionary<string, string> fields) { + ErrorUtilities.VerifyArgumentNotNull(request, "request"); + ErrorUtilities.VerifyArgumentNotNull(fields, "fields"); + + DirectResponseBase message = null; + + // Discern the OpenID version of the message. + Protocol protocol = Protocol.V11; + string ns; + if (fields.TryGetValue(Protocol.V20.openidnp.ns, out ns)) { + ErrorUtilities.VerifyProtocol(string.Equals(ns, Protocol.OpenId2Namespace, StringComparison.Ordinal), MessagingStrings.UnexpectedMessagePartValue, Protocol.V20.openidnp.ns, ns); + protocol = Protocol.V20; + } + + var associateDiffieHellmanRequest = request as AssociateDiffieHellmanRequest; + var associateUnencryptedRequest = request as AssociateUnencryptedRequest; + + if (associateDiffieHellmanRequest != null) { + message = new AssociateDiffieHellmanResponse(associateDiffieHellmanRequest); + } + + if (associateUnencryptedRequest != null) { + message = new AssociateUnencryptedResponse(associateUnencryptedRequest); + } + + // TODO: recognize more message types here + + if (message != null) { + message.SetAsIncoming(); + } + + return message; + } + + #endregion + } +} diff --git a/src/DotNetOpenAuth/OpenId/ChannelElements/OpenIdMessageTypeProvider.cs b/src/DotNetOpenAuth/OpenId/ChannelElements/OpenIdMessageTypeProvider.cs deleted file mode 100644 index d811d61..0000000 --- a/src/DotNetOpenAuth/OpenId/ChannelElements/OpenIdMessageTypeProvider.cs +++ /dev/null @@ -1,49 +0,0 @@ -//----------------------------------------------------------------------- -// <copyright file="OpenIdMessageTypeProvider.cs" company="Andrew Arnott"> -// Copyright (c) Andrew Arnott. All rights reserved. -// </copyright> -//----------------------------------------------------------------------- - -namespace DotNetOpenAuth.OpenId.ChannelElements { - using System; - using System.Collections.Generic; - using System.Linq; - using System.Text; - using DotNetOpenAuth.Messaging; - - /// <summary> - /// Distinguishes the various OpenID message types for deserialization purposes. - /// </summary> - internal class OpenIdMessageTypeProvider : IMessageTypeProvider { - #region IMessageTypeProvider Members - - /// <summary> - /// Analyzes an incoming request message payload to discover what kind of - /// message is embedded in it and returns the type, or null if no match is found. - /// </summary> - /// <param name="fields">The name/value pairs that make up the message payload.</param> - /// <returns> - /// The <see cref="IProtocolMessage"/>-derived concrete class that this message can - /// deserialize to. Null if the request isn't recognized as a valid protocol message. - /// </returns> - public Type GetRequestMessageType(IDictionary<string, string> fields) { - throw new NotImplementedException(); - } - - /// <summary> - /// Analyzes an incoming request message payload to discover what kind of - /// message is embedded in it and returns the type, or null if no match is found. - /// </summary> - /// <param name="request">The message that was sent as a request that resulted in the response.</param> - /// <param name="fields">The name/value pairs that make up the message payload.</param> - /// <returns> - /// The <see cref="IProtocolMessage"/>-derived concrete class that this message can - /// deserialize to. Null if the request isn't recognized as a valid protocol message. - /// </returns> - public Type GetResponseMessageType(IProtocolMessage request, IDictionary<string, string> fields) { - throw new NotImplementedException(); - } - - #endregion - } -} diff --git a/src/DotNetOpenAuth/OpenId/Messages/AssociateDiffieHellmanRequest.cs b/src/DotNetOpenAuth/OpenId/Messages/AssociateDiffieHellmanRequest.cs index 611936d..b984b30 100644 --- a/src/DotNetOpenAuth/OpenId/Messages/AssociateDiffieHellmanRequest.cs +++ b/src/DotNetOpenAuth/OpenId/Messages/AssociateDiffieHellmanRequest.cs @@ -45,9 +45,10 @@ namespace DotNetOpenAuth.OpenId.Messages { /// <summary> /// Initializes a new instance of the <see cref="AssociateDiffieHellmanRequest"/> class. /// </summary> + /// <param name="version">The OpenID version this message must comply with.</param> /// <param name="providerEndpoint">The OpenID Provider endpoint.</param> - internal AssociateDiffieHellmanRequest(Uri providerEndpoint) - : base(providerEndpoint) { + internal AssociateDiffieHellmanRequest(Version version, Uri providerEndpoint) + : base(version, providerEndpoint) { this.DiffieHellmanModulus = DefaultMod; this.DiffieHellmanGen = DefaultGen; } @@ -110,7 +111,7 @@ namespace DotNetOpenAuth.OpenId.Messages { /// Failed association response messages will derive from <see cref="AssociateUnsuccessfulResponse"/>.</para> /// </remarks> protected override IProtocolMessage CreateResponseCore() { - var response = new AssociateDiffieHellmanResponse(); + var response = new AssociateDiffieHellmanResponse(this); response.AssociationType = this.AssociationType; return response; } diff --git a/src/DotNetOpenAuth/OpenId/Messages/AssociateDiffieHellmanResponse.cs b/src/DotNetOpenAuth/OpenId/Messages/AssociateDiffieHellmanResponse.cs index 432595b..d486854 100644 --- a/src/DotNetOpenAuth/OpenId/Messages/AssociateDiffieHellmanResponse.cs +++ b/src/DotNetOpenAuth/OpenId/Messages/AssociateDiffieHellmanResponse.cs @@ -19,6 +19,14 @@ namespace DotNetOpenAuth.OpenId.Messages { /// </remarks> internal class AssociateDiffieHellmanResponse : AssociateSuccessfulResponse { /// <summary> + /// Initializes a new instance of the <see cref="AssociateDiffieHellmanResponse"/> class. + /// </summary> + /// <param name="originatingRequest">The originating request.</param> + internal AssociateDiffieHellmanResponse(AssociateDiffieHellmanRequest originatingRequest) + : base(originatingRequest) { + } + + /// <summary> /// Gets or sets the Provider's Diffie-Hellman public key. /// </summary> /// <value>btwoc(g ^ xb mod p)</value> diff --git a/src/DotNetOpenAuth/OpenId/Messages/AssociateRequest.cs b/src/DotNetOpenAuth/OpenId/Messages/AssociateRequest.cs index 656ec68..020f2dc 100644 --- a/src/DotNetOpenAuth/OpenId/Messages/AssociateRequest.cs +++ b/src/DotNetOpenAuth/OpenId/Messages/AssociateRequest.cs @@ -21,9 +21,10 @@ namespace DotNetOpenAuth.OpenId.Messages { /// <summary> /// Initializes a new instance of the <see cref="AssociateRequest"/> class. /// </summary> + /// <param name="version">The OpenID version this message must comply with.</param> /// <param name="providerEndpoint">The OpenID Provider endpoint.</param> - protected AssociateRequest(Uri providerEndpoint) - : base(providerEndpoint, "associate", MessageTransport.Direct) { + protected AssociateRequest(Version version, Uri providerEndpoint) + : base(version, providerEndpoint, "associate", MessageTransport.Direct) { } /// <summary> @@ -83,17 +84,17 @@ namespace DotNetOpenAuth.OpenId.Messages { bool unencryptedAllowed = provider.Endpoint.IsTransportSecure(); bool useDiffieHellman = !unencryptedAllowed; string associationType, sessionType; - if (!HmacShaAssociation.TryFindBestAssociation(provider.Protocol, securityRequirements, useDiffieHellman, out associationType, out sessionType)) { + if (!HmacShaAssociation.TryFindBestAssociation(Protocol.Lookup(provider.ProtocolVersion), securityRequirements, useDiffieHellman, out associationType, out sessionType)) { // There are no associations that meet all requirements. Logger.Warn("Security requirements and protocol combination knock out all possible association types. Dumb mode forced."); return null; } if (unencryptedAllowed) { - associateRequest = new AssociateUnencryptedRequest(provider.Endpoint); + associateRequest = new AssociateUnencryptedRequest(provider.ProtocolVersion, provider.Endpoint); associateRequest.AssociationType = associationType; } else { - var diffieHellmanAssociateRequest = new AssociateDiffieHellmanRequest(provider.Endpoint); + var diffieHellmanAssociateRequest = new AssociateDiffieHellmanRequest(provider.ProtocolVersion, provider.Endpoint); diffieHellmanAssociateRequest.AssociationType = associationType; diffieHellmanAssociateRequest.SessionType = sessionType; diffieHellmanAssociateRequest.InitializeRequest(); diff --git a/src/DotNetOpenAuth/OpenId/Messages/AssociateSuccessfulResponse.cs b/src/DotNetOpenAuth/OpenId/Messages/AssociateSuccessfulResponse.cs index 7d4ee04..043a98c 100644 --- a/src/DotNetOpenAuth/OpenId/Messages/AssociateSuccessfulResponse.cs +++ b/src/DotNetOpenAuth/OpenId/Messages/AssociateSuccessfulResponse.cs @@ -26,6 +26,14 @@ namespace DotNetOpenAuth.OpenId.Messages { private bool associationCreated; /// <summary> + /// Initializes a new instance of the <see cref="AssociateSuccessfulResponse"/> class. + /// </summary> + /// <param name="originatingRequest">The originating request.</param> + internal AssociateSuccessfulResponse(AssociateRequest originatingRequest) + : base(originatingRequest) { + } + + /// <summary> /// Gets or sets the association handle is used as a key to refer to this association in subsequent messages. /// </summary> /// <value>A string 255 characters or less in length. It MUST consist only of ASCII characters in the range 33-126 inclusive (printable non-whitespace characters). </value> diff --git a/src/DotNetOpenAuth/OpenId/Messages/AssociateUnencryptedRequest.cs b/src/DotNetOpenAuth/OpenId/Messages/AssociateUnencryptedRequest.cs index 1b30ccd..1c77bbe 100644 --- a/src/DotNetOpenAuth/OpenId/Messages/AssociateUnencryptedRequest.cs +++ b/src/DotNetOpenAuth/OpenId/Messages/AssociateUnencryptedRequest.cs @@ -16,9 +16,10 @@ namespace DotNetOpenAuth.OpenId.Messages { /// <summary> /// Initializes a new instance of the <see cref="AssociateUnencryptedRequest"/> class. /// </summary> + /// <param name="version">The OpenID version this message must comply with.</param> /// <param name="providerEndpoint">The OpenID Provider endpoint.</param> - internal AssociateUnencryptedRequest(Uri providerEndpoint) - : base(providerEndpoint) { + internal AssociateUnencryptedRequest(Version version, Uri providerEndpoint) + : base(version, providerEndpoint) { SessionType = Protocol.Args.SessionType.NoEncryption; } @@ -60,7 +61,7 @@ namespace DotNetOpenAuth.OpenId.Messages { /// Failed association response messages will derive from <see cref="AssociateUnsuccessfulResponse"/>.</para> /// </remarks> protected override IProtocolMessage CreateResponseCore() { - var response = new AssociateUnencryptedResponse(); + var response = new AssociateUnencryptedResponse(this); response.AssociationType = this.AssociationType; return response; } diff --git a/src/DotNetOpenAuth/OpenId/Messages/AssociateUnencryptedResponse.cs b/src/DotNetOpenAuth/OpenId/Messages/AssociateUnencryptedResponse.cs index 2cc2d0e..0c91091 100644 --- a/src/DotNetOpenAuth/OpenId/Messages/AssociateUnencryptedResponse.cs +++ b/src/DotNetOpenAuth/OpenId/Messages/AssociateUnencryptedResponse.cs @@ -19,7 +19,9 @@ namespace DotNetOpenAuth.OpenId.Messages { /// <summary> /// Initializes a new instance of the <see cref="AssociateUnencryptedResponse"/> class. /// </summary> - internal AssociateUnencryptedResponse() { + /// <param name="originatingRequest">The originating request.</param> + internal AssociateUnencryptedResponse(AssociateUnencryptedRequest originatingRequest) + : base(originatingRequest) { SessionType = Protocol.Args.SessionType.NoEncryption; } diff --git a/src/DotNetOpenAuth/OpenId/Messages/AssociateUnsuccessfulResponse.cs b/src/DotNetOpenAuth/OpenId/Messages/AssociateUnsuccessfulResponse.cs index ad79a87..44b3eb5 100644 --- a/src/DotNetOpenAuth/OpenId/Messages/AssociateUnsuccessfulResponse.cs +++ b/src/DotNetOpenAuth/OpenId/Messages/AssociateUnsuccessfulResponse.cs @@ -30,7 +30,9 @@ namespace DotNetOpenAuth.OpenId.Messages { /// <summary> /// Initializes a new instance of the <see cref="AssociateUnsuccessfulResponse"/> class. /// </summary> - internal AssociateUnsuccessfulResponse() { + /// <param name="originatingRequest">The originating request.</param> + internal AssociateUnsuccessfulResponse(AssociateRequest originatingRequest) + : base(originatingRequest) { } /// <summary> diff --git a/src/DotNetOpenAuth/OpenId/Messages/DirectErrorResponse.cs b/src/DotNetOpenAuth/OpenId/Messages/DirectErrorResponse.cs index 0d27569..a4c4d3e 100644 --- a/src/DotNetOpenAuth/OpenId/Messages/DirectErrorResponse.cs +++ b/src/DotNetOpenAuth/OpenId/Messages/DirectErrorResponse.cs @@ -16,6 +16,14 @@ namespace DotNetOpenAuth.OpenId.Messages { /// </remarks> internal class DirectErrorResponse : DirectResponseBase { /// <summary> + /// Initializes a new instance of the <see cref="DirectErrorResponse"/> class. + /// </summary> + /// <param name="originatingRequest">The originating request.</param> + internal DirectErrorResponse(IDirectedProtocolMessage originatingRequest) + : base(originatingRequest) { + } + + /// <summary> /// Gets or sets a human-readable message indicating why the request failed. /// </summary> [MessagePart("error", IsRequired = true, AllowEmpty = true)] diff --git a/src/DotNetOpenAuth/OpenId/Messages/DirectResponseBase.cs b/src/DotNetOpenAuth/OpenId/Messages/DirectResponseBase.cs index 38ec16a..ed83ff4 100644 --- a/src/DotNetOpenAuth/OpenId/Messages/DirectResponseBase.cs +++ b/src/DotNetOpenAuth/OpenId/Messages/DirectResponseBase.cs @@ -15,7 +15,7 @@ namespace DotNetOpenAuth.OpenId.Messages { /// A common base class for OpenID direct message responses. /// </summary> [DebuggerDisplay("OpenID {ProtocolVersion} response")] - internal class DirectResponseBase : IProtocolMessage { + internal class DirectResponseBase : IDirectResponseProtocolMessage { /// <summary> /// The openid.ns parameter in the message. /// </summary> @@ -33,9 +33,24 @@ namespace DotNetOpenAuth.OpenId.Messages { #pragma warning restore 0414 /// <summary> + /// Backing store for the <see cref="OriginatingRequest"/> properties. + /// </summary> + private IDirectedProtocolMessage originatingRequest; + + /// <summary> + /// Backing store for the <see cref="Incoming"/> properties. + /// </summary> + private bool incoming; + + /// <summary> /// Initializes a new instance of the <see cref="DirectResponseBase"/> class. /// </summary> - protected DirectResponseBase() { + /// <param name="originatingRequest">The originating request.</param> + protected DirectResponseBase(IDirectedProtocolMessage originatingRequest) { + ErrorUtilities.VerifyArgumentNotNull(originatingRequest, "originatingRequest"); + + this.originatingRequest = originatingRequest; + this.ProtocolVersion = originatingRequest.ProtocolVersion; } #region IProtocolMessage Properties @@ -44,9 +59,7 @@ namespace DotNetOpenAuth.OpenId.Messages { /// Gets the version of the protocol this message is prepared to implement. /// </summary> /// <value>Version 2.0</value> - public Version ProtocolVersion { - get { return new Version(2, 0); } - } + public Version ProtocolVersion { get; private set; } /// <summary> /// Gets the level of protection this message requires. @@ -73,9 +86,22 @@ namespace DotNetOpenAuth.OpenId.Messages { } /// <summary> - /// Gets or sets a value indicating whether this message was deserialized as an incoming message. + /// Gets a value indicating whether this message was deserialized as an incoming message. + /// </summary> + bool IProtocolMessage.Incoming { + get { return this.incoming; } + } + + #endregion + + #region IDirectResponseProtocolMessage Members + + /// <summary> + /// Gets the originating request message that caused this response to be formed. /// </summary> - public bool Incoming { get; set; } + IDirectedProtocolMessage IDirectResponseProtocolMessage.OriginatingRequest { + get { return this.originatingRequest; } + } #endregion @@ -86,6 +112,20 @@ namespace DotNetOpenAuth.OpenId.Messages { get { return Protocol.Lookup(this.ProtocolVersion); } } + /// <summary> + /// Gets the originating request message that caused this response to be formed. + /// </summary> + protected IDirectedProtocolMessage OriginatingRequest { + get { return this.originatingRequest; } + } + + /// <summary> + /// Gets a value indicating whether this message was deserialized as an incoming message. + /// </summary> + protected bool Incoming { + get { return this.incoming; } + } + #region IProtocolMessage methods /// <summary> @@ -104,5 +144,12 @@ namespace DotNetOpenAuth.OpenId.Messages { } #endregion + + /// <summary> + /// Sets a flag indicating that this message is received (as opposed to sent). + /// </summary> + internal void SetAsIncoming() { + this.incoming = true; + } } } diff --git a/src/DotNetOpenAuth/OpenId/Messages/IndirectErrorResponse.cs b/src/DotNetOpenAuth/OpenId/Messages/IndirectErrorResponse.cs index 22c8777..4b038b5 100644 --- a/src/DotNetOpenAuth/OpenId/Messages/IndirectErrorResponse.cs +++ b/src/DotNetOpenAuth/OpenId/Messages/IndirectErrorResponse.cs @@ -21,9 +21,10 @@ namespace DotNetOpenAuth.OpenId.Messages { /// <summary> /// Initializes a new instance of the <see cref="IndirectErrorResponse"/> class. /// </summary> + /// <param name="version">The OpenID version this message must comply with.</param> /// <param name="relyingPartyReturnTo">The value of the Relying Party's openid.return_to argument.</param> - internal IndirectErrorResponse(Uri relyingPartyReturnTo) - : base(relyingPartyReturnTo, "error", MessageTransport.Indirect) { + internal IndirectErrorResponse(Version version, Uri relyingPartyReturnTo) + : base(version, relyingPartyReturnTo, "error", MessageTransport.Indirect) { } /// <summary> diff --git a/src/DotNetOpenAuth/OpenId/Messages/RequestBase.cs b/src/DotNetOpenAuth/OpenId/Messages/RequestBase.cs index f2c3fc4..e4ae233 100644 --- a/src/DotNetOpenAuth/OpenId/Messages/RequestBase.cs +++ b/src/DotNetOpenAuth/OpenId/Messages/RequestBase.cs @@ -30,12 +30,18 @@ namespace DotNetOpenAuth.OpenId.Messages { #pragma warning restore 0414 /// <summary> + /// Backing store for the <see cref="Incoming"/> properties. + /// </summary> + private bool incoming; + + /// <summary> /// Initializes a new instance of the <see cref="RequestBase"/> class. /// </summary> + /// <param name="version">The OpenID version this message must comply with.</param> /// <param name="providerEndpoint">The OpenID Provider endpoint.</param> /// <param name="mode">The value for the openid.mode parameter.</param> /// <param name="transport">A value indicating whether the message will be transmitted directly or indirectly.</param> - protected RequestBase(Uri providerEndpoint, string mode, MessageTransport transport) { + protected RequestBase(Version version, Uri providerEndpoint, string mode, MessageTransport transport) { if (providerEndpoint == null) { throw new ArgumentNullException("providerEndpoint"); } @@ -46,6 +52,7 @@ namespace DotNetOpenAuth.OpenId.Messages { this.Recipient = providerEndpoint; this.Mode = mode; this.Transport = transport; + this.ProtocolVersion = version; } /// <summary> @@ -90,9 +97,7 @@ namespace DotNetOpenAuth.OpenId.Messages { /// Gets the version of the protocol this message is prepared to implement. /// </summary> /// <value>Version 2.0</value> - public Version ProtocolVersion { - get { return new Version(2, 0); } - } + public Version ProtocolVersion { get; private set; } /// <summary> /// Gets the level of protection this message requires. @@ -117,9 +122,11 @@ namespace DotNetOpenAuth.OpenId.Messages { } /// <summary> - /// Gets or sets a value indicating whether this message was deserialized as an incoming message. + /// Gets a value indicating whether this message was deserialized as an incoming message. /// </summary> - public bool Incoming { get; set; } + bool IProtocolMessage.Incoming { + get { return this.incoming; } + } #endregion @@ -130,6 +137,13 @@ namespace DotNetOpenAuth.OpenId.Messages { get { return Protocol.Lookup(this.ProtocolVersion); } } + /// <summary> + /// Gets a value indicating whether this message was deserialized as an incoming message. + /// </summary> + protected bool Incoming { + get { return this.incoming; } + } + #region IProtocolMessage Methods /// <summary> @@ -148,5 +162,12 @@ namespace DotNetOpenAuth.OpenId.Messages { } #endregion + + /// <summary> + /// Sets a flag indicating that this message is received (as opposed to sent). + /// </summary> + internal void SetAsIncoming() { + this.incoming = true; + } } } diff --git a/src/DotNetOpenAuth/OpenId/ProviderDescription.cs b/src/DotNetOpenAuth/OpenId/ProviderDescription.cs index 99faad1..48f8c49 100644 --- a/src/DotNetOpenAuth/OpenId/ProviderDescription.cs +++ b/src/DotNetOpenAuth/OpenId/ProviderDescription.cs @@ -22,13 +22,13 @@ namespace DotNetOpenAuth.OpenId { /// Initializes a new instance of the <see cref="ProviderEndpointDescription"/> class. /// </summary> /// <param name="providerEndpoint">The OpenID Provider endpoint URL.</param> - /// <param name="version">The OpenID version supported by this particular endpoint.</param> - internal ProviderEndpointDescription(Uri providerEndpoint, Protocol version) { + /// <param name="openIdVersion">The OpenID version supported by this particular endpoint.</param> + internal ProviderEndpointDescription(Uri providerEndpoint, Version openIdVersion) { ErrorUtilities.VerifyArgumentNotNull(providerEndpoint, "providerEndpoint"); - ErrorUtilities.VerifyArgumentNotNull(version, "version"); + ErrorUtilities.VerifyArgumentNotNull(openIdVersion, "version"); this.Endpoint = providerEndpoint; - this.Protocol = version; + this.ProtocolVersion = openIdVersion; } /// <summary> @@ -43,6 +43,6 @@ namespace DotNetOpenAuth.OpenId { /// If an endpoint supports multiple versions, each version must be represented /// by its own <see cref="ProviderEndpointDescription"/> object. /// </remarks> - internal Protocol Protocol { get; private set; } + internal Version ProtocolVersion { get; private set; } } } |