summaryrefslogtreecommitdiffstats
path: root/src/DotNetOpenAuth.Test/Mocks/CoordinatingOAuthServiceProviderChannel.cs
blob: 9bdbc0408be7c3ae3f629da03a4ad1db1b55ff41 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
//-----------------------------------------------------------------------
// <copyright file="CoordinatingOAuthServiceProviderChannel.cs" company="Andrew Arnott">
//     Copyright (c) Andrew Arnott. All rights reserved.
// </copyright>
//-----------------------------------------------------------------------

namespace DotNetOpenAuth.Test.Mocks {
	using System;
	using System.Diagnostics.Contracts;
	using System.Threading;
	using DotNetOpenAuth.Messaging;
	using DotNetOpenAuth.Messaging.Bindings;
	using DotNetOpenAuth.OAuth.ChannelElements;
	using DotNetOpenAuth.OAuth.Messages;

	/// <summary>
	/// A special channel used in test simulations to pass messages directly between two parties.
	/// </summary>
	/// <remarks>
	/// No HTTP traffic is involved.
	/// Outgoing messages are serialized and re-created through the <see cref="RemoteChannel"/>'s message factory.
	/// Each copy is placed in the remote channel's incoming slot, and the remote channel's wait handle is set
	/// so that it can pick the message up.
	/// </remarks>
	internal class CoordinatingOAuthServiceProviderChannel : OAuthServiceProviderChannel {
		/// <summary>
		/// Set once a message has been placed in <see cref="incomingMessage"/> or a raw response in
		/// <see cref="incomingRawResponse"/>.
		/// It is an auto-reset event, so each Set releases a single waiter.
		/// </summary>
		internal EventWaitHandle incomingMessageSignal = new AutoResetEvent(false);

		/// <summary>
		/// In-slot for a protocol message delivered by the remote party; cleared when consumed.
		/// </summary>
		internal IProtocolMessage incomingMessage;

		/// <summary>
		/// In-slot for a raw response delivered by the remote party; cleared when consumed.
		/// </summary>
		internal OutgoingWebResponse incomingRawResponse;

		/// <summary>
		/// Initializes a new instance of the <see cref="CoordinatingOAuthServiceProviderChannel"/> class.
		/// </summary>
		/// <param name="signingBindingElement">The tamper protection (signing) binding element handed to the base channel.</param>
		/// <param name="tokenManager">The token manager to use; it also backs the Service Provider message factory.</param>
		/// <param name="securitySettings">The security settings.</param>
		internal CoordinatingOAuthServiceProviderChannel(ITamperProtectionChannelBindingElement signingBindingElement, IServiceProviderTokenManager tokenManager, DotNetOpenAuth.OAuth.ServiceProviderSecuritySettings securitySettings)
			: base(
			signingBindingElement,
			new NonceMemoryStore(StandardExpirationBindingElement.MaximumMessageAge),
			tokenManager,
			securitySettings,
			new OAuthServiceProviderMessageFactory(tokenManager)) {
		}

		/// <summary>
		/// Gets or sets the coordinating channel used by the other party.
		/// </summary>
		internal CoordinatingOAuthConsumerChannel RemoteChannel { get; set; }

		/// <summary>
		/// Sends a protected resource request to the remote channel and blocks until a raw response comes back.
		/// The request is first run through the outgoing channel pipeline.
		/// </summary>
		/// <param name="request">The protected resource request to send.</param>
		/// <returns>The raw response the remote party placed in <see cref="incomingRawResponse"/>.</returns>
		internal OutgoingWebResponse RequestProtectedResource(AccessProtectedResourceRequest request) {
			// Set the HTTP method before ProcessOutgoingMessage so the outgoing binding elements see it.
			// OAuth signatures, for example, cover the HTTP method.
			var signedRequest = (ITamperResistantOAuthMessage)request;
			signedRequest.HttpMethod = GetHttpMethod(signedRequest.HttpMethods);
			this.ProcessOutgoingMessage(request);
			HttpRequestInfo requestInfo = this.SpoofHttpMethod(request);
			TestBase.TestLogger.InfoFormat("Sending protected resource request: {0}", requestInfo.Message);
			this.DeliverToRemote(requestInfo.Message);
			return this.AwaitIncomingRawResponse();
		}

		/// <summary>
		/// Hands a raw response to the remote channel and signals that it is available.
		/// </summary>
		/// <param name="response">The raw response to deliver.</param>
		internal void SendDirectRawResponse(OutgoingWebResponse response) {
			this.RemoteChannel.incomingRawResponse = response;
			this.RemoteChannel.incomingMessageSignal.Set();
		}

		/// <summary>
		/// Blocks until the remote party delivers a message.
		/// The message is then wrapped as if it had arrived in an HTTP request.
		/// </summary>
		/// <returns>A request info wrapping the delivered message.</returns>
		protected internal override HttpRequestInfo GetRequestFromContext() {
			var directedMessage = (IDirectedProtocolMessage)this.AwaitIncomingMessage();
			return new HttpRequestInfo(directedMessage, directedMessage.HttpMethods);
		}

		/// <summary>
		/// Sends a copy of <paramref name="request"/> to the remote channel and blocks until it replies.
		/// </summary>
		/// <param name="request">The request to send.</param>
		/// <returns>The message the remote party placed in <see cref="incomingMessage"/>.</returns>
		protected override IProtocolMessage RequestCore(IDirectedProtocolMessage request) {
			HttpRequestInfo requestInfo = this.SpoofHttpMethod(request);
			this.DeliverToRemote(requestInfo.Message);

			// Now wait for a response...
			return this.AwaitIncomingMessage();
		}

		/// <summary>
		/// Delivers a copy of <paramref name="response"/> straight to the remote channel.
		/// </summary>
		/// <param name="response">The response message.</param>
		/// <returns>An empty placeholder; it is not used, but the base contract forbids returning null.</returns>
		protected override OutgoingWebResponse PrepareDirectResponse(IProtocolMessage response) {
			this.DeliverToRemote(this.CloneSerializedParts(response));
			return new OutgoingWebResponse(); // not used, but returning null is not allowed
		}

		/// <summary>
		/// Delivers an indirect message exactly like a direct response.
		/// </summary>
		/// <param name="message">The message to deliver.</param>
		/// <returns>An empty placeholder response.</returns>
		protected override OutgoingWebResponse PrepareIndirectResponse(IDirectedProtocolMessage message) {
			// In this mock transport, direct and indirect messages are the same.
			return this.PrepareDirectResponse(message);
		}

		/// <summary>
		/// Returns the message already carried by the spoofed request; no parsing is needed.
		/// </summary>
		/// <param name="request">The spoofed request.</param>
		/// <returns>The message inside <paramref name="request"/>.</returns>
		protected override IDirectedProtocolMessage ReadFromRequestCore(HttpRequestInfo request) {
			return request.Message;
		}

		/// <summary>
		/// Determines the HTTP method to pretend was used for a message.
		/// </summary>
		/// <param name="methods">The delivery methods allowed for the message.</param>
		/// <returns>"POST" if POST is among <paramref name="methods"/>; otherwise "GET".</returns>
		private static string GetHttpMethod(HttpDeliveryMethods methods) {
			return (methods & HttpDeliveryMethods.PostRequest) != 0 ? "POST" : "GET";
		}

		/// <summary>
		/// Drops a message in the remote channel's in-slot and lets it know the message is there.
		/// </summary>
		/// <param name="message">The message to deliver.</param>
		private void DeliverToRemote(IProtocolMessage message) {
			this.RemoteChannel.incomingMessage = message;
			this.RemoteChannel.incomingMessageSignal.Set();
		}

		/// <summary>
		/// Spoof HTTP request information for signing/verification purposes.
		/// </summary>
		/// <param name="message">The message to add a pretend HTTP method to.</param>
		/// <returns>A spoofed HttpRequestInfo that wraps a serialized copy of the message.</returns>
		private HttpRequestInfo SpoofHttpMethod(IDirectedProtocolMessage message) {
			HttpRequestInfo requestInfo = new HttpRequestInfo(message, message.HttpMethods);

			var signedMessage = message as ITamperResistantOAuthMessage;
			if (signedMessage != null) {
				string httpMethod = GetHttpMethod(signedMessage.HttpMethods);
				requestInfo.HttpMethod = httpMethod;
				requestInfo.UrlBeforeRewriting = message.Recipient;
				signedMessage.HttpMethod = httpMethod;
			}

			requestInfo.Message = this.CloneSerializedParts(message);

			return requestInfo;
		}

		/// <summary>
		/// Blocks until <see cref="incomingMessageSignal"/> is set, then takes the message out of the in-slot.
		/// </summary>
		/// <returns>The delivered message; it may be null if the signal was set for a raw response instead.</returns>
		private IProtocolMessage AwaitIncomingMessage() {
			this.incomingMessageSignal.WaitOne();
			IProtocolMessage response = this.incomingMessage;
			this.incomingMessage = null;
			return response;
		}

		/// <summary>
		/// Blocks until <see cref="incomingMessageSignal"/> is set, then takes the raw response out of the in-slot.
		/// </summary>
		/// <returns>The delivered raw response; it may be null if the signal was set for a message instead.</returns>
		private OutgoingWebResponse AwaitIncomingRawResponse() {
			this.incomingMessageSignal.WaitOne();
			OutgoingWebResponse response = this.incomingRawResponse;
			this.incomingRawResponse = null;
			return response;
		}

		/// <summary>
		/// Copies a message by serializing its parts.
		/// The parts are deserialized into a new instance created by the remote channel's message factory,
		/// which simulates a trip over the wire.
		/// </summary>
		/// <typeparam name="T">The static type of the message.</typeparam>
		/// <param name="message">The message to copy. Must not be null.</param>
		/// <returns>The new message instance carrying the serialized parts.</returns>
		/// <exception cref="InvalidOperationException">
		/// Thrown when <paramref name="message"/> is neither a directed request nor a direct response.
		/// </exception>
		private T CloneSerializedParts<T>(T message) where T : class, IProtocolMessage {
			Contract.Requires<ArgumentNullException>(message != null);

			IProtocolMessage clonedMessage;
			var messageAccessor = this.MessageDescriptions.GetAccessor(message);
			var fields = messageAccessor.Serialize();

			MessageReceivingEndpoint recipient = null;
			var directedMessage = message as IDirectedProtocolMessage;
			var directResponse = message as IDirectResponseProtocolMessage;
			if (directedMessage != null && directedMessage.IsRequest()) {
				if (directedMessage.Recipient != null) {
					recipient = new MessageReceivingEndpoint(directedMessage.Recipient, directedMessage.HttpMethods);
				}

				clonedMessage = this.RemoteChannel.MessageFactoryTestHook.GetNewRequestMessage(recipient, fields);
			} else if (directResponse != null && directResponse.IsDirectResponse()) {
				clonedMessage = this.RemoteChannel.MessageFactoryTestHook.GetNewResponseMessage(directResponse.OriginatingRequest, fields);
			} else {
				throw new InvalidOperationException("Totally expected a message to implement one of the two derived interface types.");
			}

			// Fill the cloned message with data.
			var clonedMessageAccessor = this.MessageDescriptions.GetAccessor(clonedMessage);
			clonedMessageAccessor.Deserialize(fields);

			return (T)clonedMessage;
		}
	}
}