diff options
Diffstat (limited to 'src')
11 files changed, 166 insertions, 34 deletions
diff --git a/src/DotNetOpenAuth.Test/Messaging/Reflection/ValueMappingTests.cs b/src/DotNetOpenAuth.Test/Messaging/Reflection/ValueMappingTests.cs index d556b11..60c8bc3 100644 --- a/src/DotNetOpenAuth.Test/Messaging/Reflection/ValueMappingTests.cs +++ b/src/DotNetOpenAuth.Test/Messaging/Reflection/ValueMappingTests.cs @@ -13,12 +13,12 @@ namespace DotNetOpenAuth.Test.Messaging.Reflection { public class ValueMappingTests { [TestCase, ExpectedException(typeof(ArgumentNullException))] public void CtorNullToString() { - new ValueMapping(null, str => new object()); + new ValueMapping(null, null, str => new object()); } [TestCase, ExpectedException(typeof(ArgumentNullException))] public void CtorNullToObject() { - new ValueMapping(obj => obj.ToString(), null); + new ValueMapping(obj => obj.ToString(), null, null); } } } diff --git a/src/DotNetOpenAuth.Test/OpenId/Messages/CheckAuthenticationRequestTests.cs b/src/DotNetOpenAuth.Test/OpenId/Messages/CheckAuthenticationRequestTests.cs index cf6b814..d2d2cc4 100644 --- a/src/DotNetOpenAuth.Test/OpenId/Messages/CheckAuthenticationRequestTests.cs +++ b/src/DotNetOpenAuth.Test/OpenId/Messages/CheckAuthenticationRequestTests.cs @@ -6,12 +6,37 @@ namespace DotNetOpenAuth.Test.OpenId.Messages { using System; - using System.Collections.Generic; - using System.Linq; - using System.Text; + using DotNetOpenAuth.OpenId; + using DotNetOpenAuth.OpenId.Messages; using NUnit.Framework; [TestFixture] public class CheckAuthenticationRequestTests : OpenIdTestBase { + /// <summary> + /// Verifies that the check_auth request is sent preserving EXACTLY (non-normalized) + /// what is in the positive assertion. + /// </summary> + /// <remarks> + /// This is very important because any normalization + /// (like changing https://host:443/ to https://host/) in the message will invalidate the signature + /// and cause the authentication to inappropriately fail. + /// Designed to verify fix to Trac #198. + /// </remarks> + [TestCase] + public void ExactPositiveAssertionPreservation() { + var rp = CreateRelyingParty(true); + + // Initialize the positive assertion response with some data that is NOT in normalized form. + var positiveAssertion = new PositiveAssertionResponse(Protocol.Default.Version, RPUri) + { + ClaimedIdentifier = "https://HOST:443/a", + ProviderEndpoint = new Uri("https://anotherHOST:443/b"), + }; + + var checkAuth = new CheckAuthenticationRequest(positiveAssertion, rp.Channel); + var actual = rp.Channel.MessageDescriptions.GetAccessor(checkAuth); + Assert.AreEqual("https://HOST:443/a", actual["openid.claimed_id"]); + Assert.AreEqual("https://anotherHOST:443/b", actual["openid.op_endpoint"]); + } } } diff --git a/src/DotNetOpenAuth/DotNetOpenAuth.csproj b/src/DotNetOpenAuth/DotNetOpenAuth.csproj index 70e0c8d..919d1f2 100644 --- a/src/DotNetOpenAuth/DotNetOpenAuth.csproj +++ b/src/DotNetOpenAuth/DotNetOpenAuth.csproj @@ -298,6 +298,7 @@ http://opensource.org/licenses/ms-pl.html <Compile Include="Messaging\OutgoingWebResponseActionResult.cs" /> <Compile Include="Messaging\Reflection\IMessagePartEncoder.cs" /> <Compile Include="Messaging\Reflection\IMessagePartNullEncoder.cs" /> + <Compile Include="Messaging\Reflection\IMessagePartOriginalEncoder.cs" /> <Compile Include="Messaging\Reflection\MessageDescriptionCollection.cs" /> <Compile Include="Mvc\OpenIdHelper.cs" /> <Compile Include="Mvc\OpenIdAjaxOptions.cs" /> diff --git a/src/DotNetOpenAuth/Messaging/Reflection/IMessagePartOriginalEncoder.cs b/src/DotNetOpenAuth/Messaging/Reflection/IMessagePartOriginalEncoder.cs new file mode 100644 index 0000000..9ad55c9 --- /dev/null +++ b/src/DotNetOpenAuth/Messaging/Reflection/IMessagePartOriginalEncoder.cs @@ -0,0 +1,22 @@ +//----------------------------------------------------------------------- +// <copyright file="IMessagePartOriginalEncoder.cs" company="Andrew Arnott"> +// Copyright (c) Andrew Arnott. All rights reserved. +// </copyright> +//----------------------------------------------------------------------- + +namespace DotNetOpenAuth.Messaging.Reflection { + /// <summary> + /// An interface describing how various objects can be serialized and deserialized between their object and string forms. + /// </summary> + /// <remarks> + /// Implementations of this interface must include a default constructor and must be thread-safe. + /// </remarks> + public interface IMessagePartOriginalEncoder : IMessagePartEncoder { + /// <summary> + /// Encodes the specified value as the original value that was formerly decoded. + /// </summary> + /// <param name="value">The value. Guaranteed to never be null.</param> + /// <returns>The <paramref name="value"/> in string form, ready for message transport.</returns> + string EncodeAsOriginalString(object value); + } +} diff --git a/src/DotNetOpenAuth/Messaging/Reflection/MessageDescription.cs b/src/DotNetOpenAuth/Messaging/Reflection/MessageDescription.cs index 5493ba6..3b41b35 100644 --- a/src/DotNetOpenAuth/Messaging/Reflection/MessageDescription.cs +++ b/src/DotNetOpenAuth/Messaging/Reflection/MessageDescription.cs @@ -66,7 +66,20 @@ namespace DotNetOpenAuth.Messaging.Reflection { internal MessageDictionary GetDictionary(IMessage message) { Contract.Requires<ArgumentNullException>(message != null); Contract.Ensures(Contract.Result<MessageDictionary>() != null); - return new MessageDictionary(message, this); + return this.GetDictionary(message, false); + } + + /// <summary> + /// Gets a dictionary that provides read/write access to a message. + /// </summary> + /// <param name="message">The message the dictionary should provide access to.</param> + /// <param name="getOriginalValues">A value indicating whether this message dictionary will retrieve original values instead of normalized ones.</param> + /// <returns>The dictionary accessor to the message</returns> + [Pure] + internal MessageDictionary GetDictionary(IMessage message, bool getOriginalValues) { + Contract.Requires<ArgumentNullException>(message != null); + Contract.Ensures(Contract.Result<MessageDictionary>() != null); + return new MessageDictionary(message, this, getOriginalValues); } /// <summary> diff --git a/src/DotNetOpenAuth/Messaging/Reflection/MessageDescriptionCollection.cs b/src/DotNetOpenAuth/Messaging/Reflection/MessageDescriptionCollection.cs index ff8b74b..125742c 100644 --- a/src/DotNetOpenAuth/Messaging/Reflection/MessageDescriptionCollection.cs +++ b/src/DotNetOpenAuth/Messaging/Reflection/MessageDescriptionCollection.cs @@ -78,7 +78,19 @@ namespace DotNetOpenAuth.Messaging.Reflection { [Pure] internal MessageDictionary GetAccessor(IMessage message) { Contract.Requires<ArgumentNullException>(message != null); - return this.Get(message).GetDictionary(message); + return this.GetAccessor(message, false); + } + + /// <summary> + /// Gets the dictionary that provides read/write access to a message. + /// </summary> + /// <param name="message">The message.</param> + /// <param name="getOriginalValues">A value indicating whether this message dictionary will retrieve original values instead of normalized ones.</param> + /// <returns>The dictionary.</returns> + [Pure] + internal MessageDictionary GetAccessor(IMessage message, bool getOriginalValues) { + Contract.Requires<ArgumentNullException>(message != null); + return this.Get(message).GetDictionary(message, getOriginalValues); } /// <summary> diff --git a/src/DotNetOpenAuth/Messaging/Reflection/MessageDictionary.cs b/src/DotNetOpenAuth/Messaging/Reflection/MessageDictionary.cs index f6eed24..2b60a9c 100644 --- a/src/DotNetOpenAuth/Messaging/Reflection/MessageDictionary.cs +++ b/src/DotNetOpenAuth/Messaging/Reflection/MessageDictionary.cs @@ -30,17 +30,24 @@ namespace DotNetOpenAuth.Messaging.Reflection { private readonly MessageDescription description; /// <summary> + /// Whether original string values should be retrieved instead of normalized ones. + /// </summary> + private readonly bool getOriginalValues; + + /// <summary> /// Initializes a new instance of the <see cref="MessageDictionary"/> class. /// </summary> /// <param name="message">The message instance whose values will be manipulated by this dictionary.</param> /// <param name="description">The message description.</param> + /// <param name="getOriginalValues">A value indicating whether this message dictionary will retrieve original values instead of normalized ones.</param> [Pure] - internal MessageDictionary(IMessage message, MessageDescription description) { + internal MessageDictionary(IMessage message, MessageDescription description, bool getOriginalValues) { Contract.Requires<ArgumentNullException>(message != null); Contract.Requires<ArgumentNullException>(description != null); this.message = message; this.description = description; + this.getOriginalValues = getOriginalValues; } /// <summary> @@ -103,7 +110,7 @@ namespace DotNetOpenAuth.Messaging.Reflection { List<string> keys = new List<string>(this.description.Mapping.Count); foreach (var pair in this.description.Mapping) { // Don't include keys with null values, but default values for structs is ok - if (pair.Value.GetValue(this.message) != null) { + if (pair.Value.GetValue(this.message, this.getOriginalValues) != null) { keys.Add(pair.Key); } } @@ -126,8 +133,8 @@ namespace DotNetOpenAuth.Messaging.Reflection { get { List<string> values = new List<string>(this.message.ExtraData.Count + this.description.Mapping.Count); foreach (MessagePart part in this.description.Mapping.Values) { - if (part.GetValue(this.message) != null) { - values.Add(part.GetValue(this.message)); + if (part.GetValue(this.message, this.getOriginalValues) != null) { + values.Add(part.GetValue(this.message, this.getOriginalValues)); } } @@ -167,7 +174,7 @@ namespace DotNetOpenAuth.Messaging.Reflection { MessagePart part; if (this.description.Mapping.TryGetValue(key, out part)) { // Never throw KeyNotFoundException for declared properties. - return part.GetValue(this.message); + return part.GetValue(this.message, this.getOriginalValues); } else { return this.message.ExtraData[key]; } @@ -223,7 +230,7 @@ namespace DotNetOpenAuth.Messaging.Reflection { /// <returns>True if the parameter by the given name has a set value. False otherwise.</returns> public bool ContainsKey(string key) { return this.message.ExtraData.ContainsKey(key) || - (this.description.Mapping.ContainsKey(key) && this.description.Mapping[key].GetValue(this.message) != null); + (this.description.Mapping.ContainsKey(key) && this.description.Mapping[key].GetValue(this.message, this.getOriginalValues) != null); } /// <summary> @@ -237,7 +244,7 @@ namespace DotNetOpenAuth.Messaging.Reflection { } else { MessagePart part; if (this.description.Mapping.TryGetValue(key, out part)) { - if (part.GetValue(this.message) != null) { + if (part.GetValue(this.message, this.getOriginalValues) != null) { part.SetValue(this.message, null); return true; } @@ -255,7 +262,7 @@ namespace DotNetOpenAuth.Messaging.Reflection { public bool TryGetValue(string key, out string value) { MessagePart part; if (this.description.Mapping.TryGetValue(key, out part)) { - value = part.GetValue(this.message); + value = part.GetValue(this.message, this.getOriginalValues); return value != null; } return this.message.ExtraData.TryGetValue(key, out value); @@ -306,7 +313,7 @@ namespace DotNetOpenAuth.Messaging.Reflection { public bool Contains(KeyValuePair<string, string> item) { MessagePart part; if (this.description.Mapping.TryGetValue(item.Key, out part)) { - return string.Equals(part.GetValue(this.message), item.Value, StringComparison.Ordinal); + return string.Equals(part.GetValue(this.message, this.getOriginalValues), item.Value, StringComparison.Ordinal); } else { return this.message.ExtraData.Contains(item); } diff --git a/src/DotNetOpenAuth/Messaging/Reflection/MessagePart.cs b/src/DotNetOpenAuth/Messaging/Reflection/MessagePart.cs index 4590c44..a791e69 100644 --- a/src/DotNetOpenAuth/Messaging/Reflection/MessagePart.cs +++ b/src/DotNetOpenAuth/Messaging/Reflection/MessagePart.cs @@ -90,14 +90,14 @@ namespace DotNetOpenAuth.Messaging.Reflection { Contract.Assume(str != null); return new Realm(str); }; - Map<Uri>(uri => uri.AbsoluteUri, safeUri); - Map<DateTime>(dt => XmlConvert.ToString(dt, XmlDateTimeSerializationMode.Utc), str => XmlConvert.ToDateTime(str, XmlDateTimeSerializationMode.Utc)); - Map<byte[]>(safeFromByteArray, safeToByteArray); - Map<Realm>(realm => realm.ToString(), safeRealm); - Map<Identifier>(id => id.SerializedString, safeIdentifier); - Map<bool>(value => value.ToString().ToLowerInvariant(), safeBool); - Map<CultureInfo>(c => c.Name, str => new CultureInfo(str)); - Map<CultureInfo[]>(cs => string.Join(",", cs.Select(c => c.Name).ToArray()), str => str.Split(',').Select(s => new CultureInfo(s)).ToArray()); + Map<Uri>(uri => uri.AbsoluteUri, uri => uri.OriginalString, safeUri); + Map<DateTime>(dt => XmlConvert.ToString(dt, XmlDateTimeSerializationMode.Utc), null, str => XmlConvert.ToDateTime(str, XmlDateTimeSerializationMode.Utc)); + Map<byte[]>(safeFromByteArray, null, safeToByteArray); + Map<Realm>(realm => realm.ToString(), realm => realm.OriginalString, safeRealm); + Map<Identifier>(id => id.SerializedString, id => id.OriginalString, safeIdentifier); + Map<bool>(value => value.ToString().ToLowerInvariant(), null, safeBool); + Map<CultureInfo>(c => c.Name, null, str => new CultureInfo(str)); + Map<CultureInfo[]>(cs => string.Join(",", cs.Select(c => c.Name).ToArray()), null, str => str.Split(',').Select(s => new CultureInfo(s)).ToArray()); } /// <summary> @@ -137,15 +137,18 @@ namespace DotNetOpenAuth.Messaging.Reflection { if (converters.TryGetValue(underlyingType, out underlyingMapping)) { this.converter = new ValueMapping( underlyingMapping.ValueToString, + null, str => str != null ? underlyingMapping.StringToValue(str) : null); } else { this.converter = new ValueMapping( obj => obj != null ? obj.ToString() : null, + null, str => str != null ? Convert.ChangeType(str, underlyingType, CultureInfo.InvariantCulture) : null); } } else { this.converter = new ValueMapping( obj => obj != null ? obj.ToString() : null, + null, str => str != null ? Convert.ChangeType(str, this.memberDeclaredType, CultureInfo.InvariantCulture) : null); } } @@ -227,7 +230,7 @@ namespace DotNetOpenAuth.Messaging.Reflection { } /// <summary> - /// Gets the value of a member of a given message. + /// Gets the normalized form of a value of a member of a given message. /// Used in serialization. /// </summary> /// <param name="message">The message instance to read the value from.</param> @@ -235,7 +238,23 @@ namespace DotNetOpenAuth.Messaging.Reflection { internal string GetValue(IMessage message) { try { object value = this.GetValueAsObject(message); - return this.ToString(value); + return this.ToString(value, false); + } catch (FormatException ex) { + throw ErrorUtilities.Wrap(ex, MessagingStrings.MessagePartWriteFailure, message.GetType(), this.Name); + } + } + + /// <summary> + /// Gets the value of a member of a given message. + /// Used in serialization. + /// </summary> + /// <param name="message">The message instance to read the value from.</param> + /// <param name="originalValue">A value indicating whether the original value should be retrieved (as opposed to a normalized form of it).</param> + /// <returns>The string representation of the member's value.</returns> + internal string GetValue(IMessage message, bool originalValue) { + try { + object value = this.GetValueAsObject(message); + return this.ToString(value, originalValue); } catch (FormatException ex) { throw ErrorUtilities.Wrap(ex, MessagingStrings.MessagePartWriteFailure, message.GetType(), this.Name); } @@ -272,11 +291,20 @@ namespace DotNetOpenAuth.Messaging.Reflection { /// </summary> /// <typeparam name="T">The custom type to convert to and from strings.</typeparam> /// <param name="toString">The function to convert the custom type to a string.</param> + /// <param name="toOriginalString">The mapping function that converts some custom value to its original (non-normalized) string. May be null if the same as the <paramref name="toString"/> function.</param> /// <param name="toValue">The function to convert a string to the custom type.</param> - private static void Map<T>(Func<T, string> toString, Func<string, T> toValue) { + private static void Map<T>(Func<T, string> toString, Func<T, string> toOriginalString, Func<string, T> toValue) { + Contract.Requires<ArgumentNullException>(toString != null, "toString"); + Contract.Requires<ArgumentNullException>(toValue != null, "toValue"); + + if (toOriginalString == null) { + toOriginalString = toString; + } + Func<object, string> safeToString = obj => obj != null ? toString((T)obj) : null; + Func<object, string> safeToOriginalString = obj => obj != null ? toOriginalString((T)obj) : null; Func<string, object> safeToT = str => str != null ? toValue(str) : default(T); - converters.Add(typeof(T), new ValueMapping(safeToString, safeToT)); + converters.Add(typeof(T), new ValueMapping(safeToString, safeToOriginalString, safeToT)); } /// <summary> @@ -328,11 +356,12 @@ namespace DotNetOpenAuth.Messaging.Reflection { /// Converts the member's value to its string representation. /// </summary> /// <param name="value">The value of the member.</param> + /// <param name="originalString">A value indicating whether a string matching the originally decoded string should be returned (as opposed to a normalized string).</param> /// <returns> /// The string representation of the member's value. /// </returns> - private string ToString(object value) { - return this.converter.ValueToString(value); + private string ToString(object value, bool originalString) { + return originalString ? this.converter.ValueToOriginalString(value) : this.converter.ValueToString(value); } /// <summary> diff --git a/src/DotNetOpenAuth/Messaging/Reflection/ValueMapping.cs b/src/DotNetOpenAuth/Messaging/Reflection/ValueMapping.cs index 1c7631e..b0b8b47 100644 --- a/src/DotNetOpenAuth/Messaging/Reflection/ValueMapping.cs +++ b/src/DotNetOpenAuth/Messaging/Reflection/ValueMapping.cs @@ -19,6 +19,12 @@ namespace DotNetOpenAuth.Messaging.Reflection { internal readonly Func<object, string> ValueToString; /// <summary> + /// The mapping function that converts some custom type to the original string + /// (possibly non-normalized) that represents it. + /// </summary> + internal readonly Func<object, string> ValueToOriginalString; + + /// <summary> /// The mapping function that converts a string to some custom type. /// </summary> internal readonly Func<string, object> StringToValue; @@ -26,13 +32,15 @@ namespace DotNetOpenAuth.Messaging.Reflection { /// <summary> /// Initializes a new instance of the <see cref="ValueMapping"/> struct. /// </summary> - /// <param name="toString">The mapping function that converts some custom type to a string.</param> - /// <param name="toValue">The mapping function that converts a string to some custom type.</param> - internal ValueMapping(Func<object, string> toString, Func<string, object> toValue) { + /// <param name="toString">The mapping function that converts some custom value to a string.</param> + /// <param name="toOriginalString">The mapping function that converts some custom value to its original (non-normalized) string. May be null if the same as the <paramref name="toString"/> function.</param> + /// <param name="toValue">The mapping function that converts a string to some custom value.</param> + internal ValueMapping(Func<object, string> toString, Func<object, string> toOriginalString, Func<string, object> toValue) { Contract.Requires<ArgumentNullException>(toString != null); Contract.Requires<ArgumentNullException>(toValue != null); this.ValueToString = toString; + this.ValueToOriginalString = toOriginalString ?? toString; this.StringToValue = toValue; } @@ -45,8 +53,15 @@ namespace DotNetOpenAuth.Messaging.Reflection { var nullEncoder = encoder as IMessagePartNullEncoder; string nullString = nullEncoder != null ? nullEncoder.EncodedNullValue : null; + var originalStringEncoder = encoder as IMessagePartOriginalEncoder; + Func<object, string> originalStringEncode = encoder.Encode; + if (originalStringEncoder != null) { + originalStringEncode = originalStringEncoder.EncodeAsOriginalString; + } + this.ValueToString = obj => (obj != null) ? encoder.Encode(obj) : nullString; this.StringToValue = str => (str != null) ? encoder.Decode(str) : null; + this.ValueToOriginalString = obj => (obj != null) ? originalStringEncode(obj) : nullString; } } } diff --git a/src/DotNetOpenAuth/OpenId/Messages/CheckAuthenticationRequest.cs b/src/DotNetOpenAuth/OpenId/Messages/CheckAuthenticationRequest.cs index 5306c54..db69d3d 100644 --- a/src/DotNetOpenAuth/OpenId/Messages/CheckAuthenticationRequest.cs +++ b/src/DotNetOpenAuth/OpenId/Messages/CheckAuthenticationRequest.cs @@ -45,7 +45,7 @@ namespace DotNetOpenAuth.OpenId.Messages { // Copy all message parts from the id_res message into this one, // except for the openid.mode parameter. - MessageDictionary checkPayload = channel.MessageDescriptions.GetAccessor(message); + MessageDictionary checkPayload = channel.MessageDescriptions.GetAccessor(message, true); MessageDictionary thisPayload = channel.MessageDescriptions.GetAccessor(this); foreach (var pair in checkPayload) { if (!string.Equals(pair.Key, this.Protocol.openid.mode)) { diff --git a/src/DotNetOpenAuth/OpenId/Realm.cs b/src/DotNetOpenAuth/OpenId/Realm.cs index 4137c72..98e3598 100644 --- a/src/DotNetOpenAuth/OpenId/Realm.cs +++ b/src/DotNetOpenAuth/OpenId/Realm.cs @@ -180,6 +180,14 @@ namespace DotNetOpenAuth.OpenId { } /// <summary> + /// Gets the original string. + /// </summary> + /// <value>The original string.</value> + internal string OriginalString { + get { return this.uri.OriginalString; } + } + + /// <summary> /// Gets the realm URL. If the realm includes a wildcard, it is not included here. /// </summary> internal Uri NoWildcardUri { |