//-----------------------------------------------------------------------
//
//     Copyright (c) Andrew Arnott. All rights reserved.
//
//-----------------------------------------------------------------------

namespace DotNetOpenAuth.Test.Mocks {
    using System;
    using System.Diagnostics.Contracts;
    using System.Threading;
    using DotNetOpenAuth.Messaging;
    using DotNetOpenAuth.Messaging.Bindings;
    using DotNetOpenAuth.OAuth.ChannelElements;
    using DotNetOpenAuth.OAuth.Messages;

    /// <summary>
    /// A special channel used in test simulations to pass messages directly between two parties.
    /// </summary>
    /// <remarks>
    /// Instead of travelling over HTTP, each outgoing message is re-created (serialized, then
    /// deserialized through the remote channel's message factory), dropped into the remote
    /// channel's in-slot, and the remote channel's wait handle is set to wake it up.
    /// </remarks>
    internal class CoordinatingOAuthServiceProviderChannel : OAuthServiceProviderChannel {
        /// <summary>
        /// Auto-reset event this channel blocks on while waiting for an incoming message or raw
        /// response. Presumably set by the coordinating remote channel after it fills one of the
        /// slots below (this class does the same to its <see cref="RemoteChannel"/>).
        /// </summary>
        internal EventWaitHandle incomingMessageSignal = new AutoResetEvent(false);

        /// <summary>
        /// In-slot for a protocol message delivered by the remote channel; cleared once consumed.
        /// </summary>
        internal IProtocolMessage incomingMessage;

        /// <summary>
        /// In-slot for a raw web response delivered by the remote channel; cleared once consumed.
        /// </summary>
        internal OutgoingWebResponse incomingRawResponse;

        /// <summary>
        /// Initializes a new instance of the <see cref="CoordinatingOAuthServiceProviderChannel"/> class for Service Providers.
        /// </summary>
        /// <param name="signingBindingElement">The tamper-protection (signing) binding element passed to the base channel.</param>
        /// <param name="tokenManager">The token manager to use.</param>
        /// <param name="securitySettings">The security settings.</param>
        internal CoordinatingOAuthServiceProviderChannel(ITamperProtectionChannelBindingElement signingBindingElement, IServiceProviderTokenManager tokenManager, DotNetOpenAuth.OAuth.ServiceProviderSecuritySettings securitySettings)
            : base(
            signingBindingElement,
            new NonceMemoryStore(StandardExpirationBindingElement.MaximumMessageAge),
            tokenManager,
            securitySettings,
            new OAuthServiceProviderMessageFactory(tokenManager)) {
        }

        /// <summary>
        /// Gets or sets the coordinating channel used by the other party.
        /// </summary>
        internal CoordinatingOAuthConsumerChannel RemoteChannel { get; set; }

        /// <summary>
        /// Sends a protected resource request to the remote party, then blocks until a raw
        /// response arrives in this channel's raw-response slot.
        /// </summary>
        /// <param name="request">The request to send.</param>
        /// <returns>The raw response delivered by the remote party.</returns>
        internal OutgoingWebResponse RequestProtectedResource(AccessProtectedResourceRequest request) {
            var signedRequest = (ITamperResistantOAuthMessage)request;
            signedRequest.HttpMethod = GetHttpMethod(signedRequest.HttpMethods);
            this.ProcessOutgoingMessage(request);
            HttpRequestInfo requestInfo = this.SpoofHttpMethod(request);
            TestBase.TestLogger.InfoFormat("Sending protected resource request: {0}", requestInfo.Message);

            // Drop the outgoing message in the other channel's in-slot and let them know it's there.
            this.DeliverToRemote(requestInfo.Message);
            return this.AwaitIncomingRawResponse();
        }

        /// <summary>
        /// Places a raw web response in the remote channel's raw-response slot and signals it.
        /// </summary>
        /// <param name="response">The response to deliver.</param>
        internal void SendDirectRawResponse(OutgoingWebResponse response) {
            this.RemoteChannel.incomingRawResponse = response;
            this.RemoteChannel.incomingMessageSignal.Set();
        }

        /// <summary>
        /// Blocks until the remote party delivers a message, then wraps it as an incoming request.
        /// </summary>
        /// <returns>The simulated HTTP request carrying the delivered message.</returns>
        protected internal override HttpRequestInfo GetRequestFromContext() {
            var directedMessage = (IDirectedProtocolMessage)this.AwaitIncomingMessage();
            return new HttpRequestInfo(directedMessage, directedMessage.HttpMethods);
        }

        /// <summary>
        /// Sends a request to the remote party and blocks until it delivers a response message.
        /// </summary>
        /// <param name="request">The request to send.</param>
        /// <returns>The response message delivered by the remote party.</returns>
        protected override IProtocolMessage RequestCore(IDirectedProtocolMessage request) {
            HttpRequestInfo requestInfo = this.SpoofHttpMethod(request);

            // Drop the outgoing message in the other channel's in-slot and let them know it's there.
            this.DeliverToRemote(requestInfo.Message);

            // Now wait for a response...
            return this.AwaitIncomingMessage();
        }

        /// <summary>
        /// Delivers a clone of <paramref name="response"/> to the remote channel.
        /// </summary>
        /// <param name="response">The response message.</param>
        /// <returns>A placeholder response; the real payload travels through the remote in-slot.</returns>
        protected override OutgoingWebResponse PrepareDirectResponse(IProtocolMessage response) {
            this.DeliverToRemote(this.CloneSerializedParts(response, null));
            return new OutgoingWebResponse(); // not used, but returning null is not allowed
        }

        /// <summary>
        /// Delivers an indirect message exactly like a direct response.
        /// </summary>
        /// <param name="message">The message to deliver.</param>
        /// <returns>A placeholder response.</returns>
        protected override OutgoingWebResponse PrepareIndirectResponse(IDirectedProtocolMessage message) {
            // In this mock transport, direct and indirect messages are the same.
            return this.PrepareDirectResponse(message);
        }

        /// <summary>
        /// Returns the message already attached to the simulated request.
        /// </summary>
        /// <param name="request">The simulated request.</param>
        /// <returns>The message carried by <paramref name="request"/>.</returns>
        protected override IDirectedProtocolMessage ReadFromRequestCore(HttpRequestInfo request) {
            return request.Message;
        }

        /// <summary>
        /// Chooses the HTTP verb to pretend was used: POST if allowed, otherwise GET.
        /// </summary>
        /// <param name="methods">The delivery methods the message permits.</param>
        /// <returns><c>"POST"</c> or <c>"GET"</c>.</returns>
        private static string GetHttpMethod(HttpDeliveryMethods methods) {
            return (methods & HttpDeliveryMethods.PostRequest) != 0 ? "POST" : "GET";
        }

        /// <summary>
        /// Places a message in the remote channel's in-slot and signals the remote channel.
        /// </summary>
        /// <param name="message">The message to deliver.</param>
        private void DeliverToRemote(IProtocolMessage message) {
            this.RemoteChannel.incomingMessage = message;
            this.RemoteChannel.incomingMessageSignal.Set();
        }

        /// <summary>
        /// Spoof HTTP request information for signing/verification purposes.
        /// </summary>
        /// <param name="message">The message to add a pretend HTTP method to.</param>
        /// <returns>A spoofed HttpRequestInfo that wraps the new message.</returns>
        private HttpRequestInfo SpoofHttpMethod(IDirectedProtocolMessage message) {
            HttpRequestInfo requestInfo = new HttpRequestInfo(message, message.HttpMethods);

            var signedMessage = message as ITamperResistantOAuthMessage;
            if (signedMessage != null) {
                string httpMethod = GetHttpMethod(signedMessage.HttpMethods);
                requestInfo.HttpMethod = httpMethod;
                requestInfo.UrlBeforeRewriting = message.Recipient;
                signedMessage.HttpMethod = httpMethod;
            }

            requestInfo.Message = this.CloneSerializedParts(message, requestInfo);

            return requestInfo;
        }

        /// <summary>
        /// Blocks until signaled, then takes (and clears) the incoming message slot.
        /// </summary>
        /// <returns>The delivered message; may be null if the signal was for a raw response.</returns>
        private IProtocolMessage AwaitIncomingMessage() {
            this.incomingMessageSignal.WaitOne();
            IProtocolMessage response = this.incomingMessage;
            this.incomingMessage = null;
            return response;
        }

        /// <summary>
        /// Blocks until signaled, then takes (and clears) the incoming raw-response slot.
        /// </summary>
        /// <returns>The delivered raw response; may be null if the signal was for a message.</returns>
        private OutgoingWebResponse AwaitIncomingRawResponse() {
            this.incomingMessageSignal.WaitOne();
            OutgoingWebResponse response = this.incomingRawResponse;
            this.incomingRawResponse = null;
            return response;
        }

        /// <summary>
        /// Creates a copy of a message by serializing its parts and deserializing them into a new
        /// instance produced by the remote channel's message factory.
        /// </summary>
        /// <typeparam name="T">The message type.</typeparam>
        /// <param name="message">The message to clone.</param>
        /// <param name="requestInfo">Currently unused by this method.</param>
        /// <returns>The cloned message.</returns>
        /// <exception cref="InvalidOperationException">
        /// Thrown if <paramref name="message"/> is neither a directed request nor a direct response.
        /// </exception>
        private T CloneSerializedParts<T>(T message, HttpRequestInfo requestInfo) where T : class, IProtocolMessage {
            Contract.Requires(message != null);

            IProtocolMessage clonedMessage;
            var messageAccessor = this.MessageDescriptions.GetAccessor(message);
            var fields = messageAccessor.Serialize();

            MessageReceivingEndpoint recipient = null;
            var directedMessage = message as IDirectedProtocolMessage;
            var directResponse = message as IDirectResponseProtocolMessage;
            if (directedMessage != null && directedMessage.IsRequest()) {
                if (directedMessage.Recipient != null) {
                    recipient = new MessageReceivingEndpoint(directedMessage.Recipient, directedMessage.HttpMethods);
                }

                clonedMessage = this.RemoteChannel.MessageFactoryTestHook.GetNewRequestMessage(recipient, fields);
            } else if (directResponse != null && directResponse.IsDirectResponse()) {
                clonedMessage = this.RemoteChannel.MessageFactoryTestHook.GetNewResponseMessage(directResponse.OriginatingRequest, fields);
            } else {
                throw new InvalidOperationException("Totally expected a message to implement one of the two derived interface types.");
            }

            // Fill the cloned message with data.
            var clonedMessageAccessor = this.MessageDescriptions.GetAccessor(clonedMessage);
            clonedMessageAccessor.Deserialize(fields);

            return (T)clonedMessage;
        }
    }
}