summaryrefslogtreecommitdiffstats
path: root/src/DotNetOpenAuth.Core/Messaging/StandardMessageFactoryChannel.cs
blob: acfc0047f77cbbe680e440c8bf9395fae98ec772 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
//-----------------------------------------------------------------------
// <copyright file="StandardMessageFactoryChannel.cs" company="Andrew Arnott">
//     Copyright (c) Andrew Arnott. All rights reserved.
// </copyright>
//-----------------------------------------------------------------------

namespace DotNetOpenAuth.Messaging {
	using System;
	using System.Collections.Generic;
	using System.Diagnostics.Contracts;
	using System.Linq;
	using System.Text;
	using Reflection;

	/// <summary>
	/// A channel that uses the standard message factory.
	/// </summary>
	/// <remarks>
	/// The channel remembers the message types and protocol versions it was constructed with
	/// so that the message factory can be rebuilt whenever the message descriptions are replaced.
	/// </remarks>
	public abstract class StandardMessageFactoryChannel : Channel {
		/// <summary>
		/// The message types receivable by this channel.
		/// </summary>
		private readonly ICollection<Type> messageTypes;

		/// <summary>
		/// The protocol versions supported by this channel.
		/// </summary>
		private readonly ICollection<Version> versions;

		/// <summary>
		/// Initializes a new instance of the <see cref="StandardMessageFactoryChannel"/> class.
		/// </summary>
		/// <param name="messageTypes">The message types that might be encountered.  Must not be null.</param>
		/// <param name="versions">All the possible message versions that might be encountered.  Must not be null.</param>
		/// <param name="bindingElements">The binding elements to apply to the channel.</param>
		/// <remarks>
		/// The base constructor runs (with a fresh, empty <see cref="StandardMessageFactory"/>)
		/// before the arguments are validated here.  The collections are stored by reference,
		/// not copied, and are read again each time <see cref="MessageDescriptions"/> is set.
		/// </remarks>
		protected StandardMessageFactoryChannel(ICollection<Type> messageTypes, ICollection<Version> versions, params IChannelBindingElement[] bindingElements)
			: base(new StandardMessageFactory(), bindingElements) {
			Requires.NotNull(messageTypes, "messageTypes");
			Requires.NotNull(versions, "versions");

			this.messageTypes = messageTypes;
			this.versions = versions;

			// Teach the factory created above about every (message type, version) pair.
			this.StandardMessageFactory.AddMessageTypes(GetMessageDescriptions(this.messageTypes, this.versions, this.MessageDescriptions));
		}

		/// <summary>
		/// Gets or sets the message factory used by this channel, typed as
		/// <see cref="StandardMessageFactory"/> so callers need not cast.
		/// </summary>
		/// <remarks>
		/// This is a strongly-typed view over <see cref="MessageFactory"/>; both accessors delegate to it.
		/// </remarks>
		internal StandardMessageFactory StandardMessageFactory {
			get { return (Messaging.StandardMessageFactory)this.MessageFactory; }
			set { this.MessageFactory = value; }
		}

		/// <summary>
		/// Gets or sets the message descriptions.
		/// </summary>
		/// <remarks>
		/// Setting this property replaces the channel's message factory with a new
		/// <see cref="StandardMessageFactory"/> populated from the new descriptions.
		/// </remarks>
		internal sealed override MessageDescriptionCollection MessageDescriptions {
			get {
				return base.MessageDescriptions;
			}

			set {
				base.MessageDescriptions = value;

				// We must reinitialize the message factory so it can use the new message descriptions.
				var factory = new StandardMessageFactory();
				factory.AddMessageTypes(GetMessageDescriptions(this.messageTypes, this.versions, value));
				this.MessageFactory = factory;
			}
		}

		/// <summary>
		/// Gets or sets a tool that can figure out what kind of message is being received
		/// so it can be deserialized.
		/// </summary>
		/// <remarks>
		/// Both accessors cast to <see cref="StandardMessageFactory"/>, so a factory of any
		/// other type fails the cast rather than being silently accepted.
		/// </remarks>
		protected sealed override IMessageFactory MessageFactory {
			get {
				return (StandardMessageFactory)base.MessageFactory;
			}

			set {
				// The cast is deliberate: it enforces the factory type this channel relies on.
				StandardMessageFactory newValue = (StandardMessageFactory)value;
				base.MessageFactory = newValue;
			}
		}

		/// <summary>
		/// Generates all the message descriptions for a given set of message types and versions.
		/// </summary>
		/// <param name="messageTypes">The message types.  Must not be null.</param>
		/// <param name="versions">The message versions.  Must not be null.</param>
		/// <param name="descriptionsCache">The cache to use when obtaining the message descriptions.  Must not be null.</param>
		/// <returns>
		/// The generated/retrieved message descriptions: one per (version, message type) pair,
		/// ordered by version first and then by message type, following the enumeration order
		/// of the supplied collections.
		/// </returns>
		private static IEnumerable<MessageDescription> GetMessageDescriptions(ICollection<Type> messageTypes, ICollection<Version> versions, MessageDescriptionCollection descriptionsCache) {
			Requires.NotNull(messageTypes, "messageTypes");

			// versions is dereferenced below (Count and enumeration), so validate it alongside
			// the other arguments instead of letting a null surface as a NullReferenceException.
			Requires.NotNull(versions, "versions");
			Requires.NotNull(descriptionsCache, "descriptionsCache");
			Contract.Ensures(Contract.Result<IEnumerable<MessageDescription>>() != null);

			// Get all the MessageDescription objects through the standard cache,
			// so that perhaps it will be a quick lookup, or at least it will be
			// stored there for a quick lookup later.
			// The list is presized to the exact number of (version, type) combinations.
			var messageDescriptions = new List<MessageDescription>(messageTypes.Count * versions.Count);
			messageDescriptions.AddRange(
				from version in versions
				from messageType in messageTypes
				select descriptionsCache.Get(messageType, version));

			return messageDescriptions;
		}
	}
}