//-----------------------------------------------------------------------
//
// Copyright (c) Andrew Arnott. All rights reserved.
//
//-----------------------------------------------------------------------
namespace DotNetOpenAuth.Test.Mocks {
using System;
using System.Diagnostics.Contracts;
using System.Threading;
using DotNetOpenAuth.Messaging;
using DotNetOpenAuth.Messaging.Bindings;
using DotNetOpenAuth.OAuth.ChannelElements;
using DotNetOpenAuth.OAuth.Messages;
/// <summary>
/// A special channel used in test simulations to pass messages directly between two parties.
/// </summary>
/// <remarks>
/// No HTTP traffic takes place. An outgoing message is cloned (serialized, then deserialized into a
/// fresh instance built by the remote channel's message factory). The clone is stored in the
/// <see cref="RemoteChannel"/>'s incoming slot, and the remote channel's wait handle is set so it
/// can pick the message up.
/// </remarks>
internal class CoordinatingOAuthServiceProviderChannel : OAuthServiceProviderChannel {
	/// <summary>
	/// The handle this channel waits on before reading <see cref="incomingMessage"/> or
	/// <see cref="incomingRawResponse"/>. The other party is expected to set it after filling one of those slots.
	/// </summary>
	internal EventWaitHandle incomingMessageSignal = new AutoResetEvent(false);

	/// <summary>
	/// The in-slot for a protocol message delivered by the other party; cleared once consumed.
	/// </summary>
	internal IProtocolMessage incomingMessage;

	/// <summary>
	/// The in-slot for a raw response delivered by the other party; cleared once consumed.
	/// </summary>
	internal OutgoingWebResponse incomingRawResponse;

	/// <summary>
	/// Initializes a new instance of the <see cref="CoordinatingOAuthServiceProviderChannel"/> class for Service Providers.
	/// </summary>
	/// <param name="signingBindingElement">The tamper-protection (signing) binding element handed to the base channel.</param>
	/// <param name="tokenManager">The token manager to use.</param>
	/// <param name="securitySettings">The security settings.</param>
	internal CoordinatingOAuthServiceProviderChannel(ITamperProtectionChannelBindingElement signingBindingElement, IServiceProviderTokenManager tokenManager, DotNetOpenAuth.OAuth.ServiceProviderSecuritySettings securitySettings)
		: base(
			signingBindingElement,
			// Each channel instance gets its own in-memory nonce store.
			new NonceMemoryStore(StandardExpirationBindingElement.MaximumMessageAge),
			tokenManager,
			securitySettings,
			new OAuthServiceProviderMessageFactory(tokenManager)) {
	}

	/// <summary>
	/// Gets or sets the coordinating channel used by the other party.
	/// </summary>
	internal CoordinatingOAuthConsumerChannel RemoteChannel { get; set; }

	/// <summary>
	/// Prepares and sends a protected resource request to the remote channel.
	/// It then blocks until the remote channel supplies a raw response.
	/// </summary>
	/// <param name="request">The request to prepare and send.</param>
	/// <returns>The raw response delivered by the remote channel.</returns>
	internal OutgoingWebResponse RequestProtectedResource(AccessProtectedResourceRequest request) {
		// Assign the HTTP method before ProcessOutgoingMessage runs the outgoing binding
		// elements (signing among them), since they may depend on it.
		var tamperResistantRequest = (ITamperResistantOAuthMessage)request;
		tamperResistantRequest.HttpMethod = this.GetHttpMethod(tamperResistantRequest.HttpMethods);
		this.ProcessOutgoingMessage(request);
		HttpRequestInfo requestInfo = this.SpoofHttpMethod(request);
		TestBase.TestLogger.InfoFormat("Sending protected resource request: {0}", requestInfo.Message);
		this.PostToRemoteChannel(requestInfo.Message);
		return this.AwaitIncomingRawResponse();
	}

	/// <summary>
	/// Delivers a raw response to the remote channel and signals it.
	/// </summary>
	/// <param name="response">The response to deliver.</param>
	internal void SendDirectRawResponse(OutgoingWebResponse response) {
		this.RemoteChannel.incomingRawResponse = response;
		this.RemoteChannel.incomingMessageSignal.Set();
	}

	/// <summary>
	/// Blocks until the remote channel delivers a message.
	/// The message is then wrapped as though it had arrived in an HTTP request.
	/// </summary>
	/// <returns>The request info wrapping the delivered message.</returns>
	protected internal override HttpRequestInfo GetRequestFromContext() {
		var directedMessage = (IDirectedProtocolMessage)this.AwaitIncomingMessage();
		return new HttpRequestInfo(directedMessage, directedMessage.HttpMethods);
	}

	/// <summary>
	/// Sends a request to the remote channel and blocks until it delivers a reply.
	/// </summary>
	/// <param name="request">The request to send.</param>
	/// <returns>The reply delivered by the remote channel.</returns>
	protected override IProtocolMessage RequestCore(IDirectedProtocolMessage request) {
		HttpRequestInfo requestInfo = this.SpoofHttpMethod(request);
		this.PostToRemoteChannel(requestInfo.Message);

		// Now wait for a response...
		return this.AwaitIncomingMessage();
	}

	/// <summary>
	/// Delivers a clone of the response directly to the remote channel.
	/// </summary>
	/// <param name="response">The response to deliver.</param>
	/// <returns>An empty placeholder response; it is not used, but the base class does not allow null.</returns>
	protected override OutgoingWebResponse PrepareDirectResponse(IProtocolMessage response) {
		this.PostToRemoteChannel(this.CloneSerializedParts(response, null));
		return new OutgoingWebResponse(); // not used, but returning null is not allowed
	}

	/// <summary>
	/// Delivers an indirect message to the remote channel.
	/// </summary>
	/// <param name="message">The message to deliver.</param>
	/// <returns>An empty placeholder response.</returns>
	protected override OutgoingWebResponse PrepareIndirectResponse(IDirectedProtocolMessage message) {
		// In this mock transport, direct and indirect messages are the same.
		return this.PrepareDirectResponse(message);
	}

	/// <summary>
	/// Extracts the message that <see cref="GetRequestFromContext"/> already attached to the request.
	/// </summary>
	/// <param name="request">The spoofed request info.</param>
	/// <returns>The message carried by <paramref name="request"/>.</returns>
	protected override IDirectedProtocolMessage ReadFromRequestCore(HttpRequestInfo request) {
		return request.Message;
	}

	/// <summary>
	/// Spoof HTTP request information for signing/verification purposes.
	/// </summary>
	/// <param name="message">The message to add a pretend HTTP method to.</param>
	/// <returns>A spoofed <see cref="HttpRequestInfo"/> that wraps a clone of the message.</returns>
	private HttpRequestInfo SpoofHttpMethod(IDirectedProtocolMessage message) {
		HttpRequestInfo requestInfo = new HttpRequestInfo(message, message.HttpMethods);

		var signedMessage = message as ITamperResistantOAuthMessage;
		if (signedMessage != null) {
			// Keep the request info and the message itself in agreement on method and URL.
			string httpMethod = this.GetHttpMethod(signedMessage.HttpMethods);
			requestInfo.HttpMethod = httpMethod;
			requestInfo.UrlBeforeRewriting = message.Recipient;
			signedMessage.HttpMethod = httpMethod;
		}

		requestInfo.Message = this.CloneSerializedParts(message, requestInfo);
		return requestInfo;
	}

	/// <summary>
	/// Stores a message in the remote channel's incoming slot and signals that it is there.
	/// </summary>
	/// <param name="message">The message to hand over.</param>
	private void PostToRemoteChannel(IProtocolMessage message) {
		this.RemoteChannel.incomingMessage = message;
		this.RemoteChannel.incomingMessageSignal.Set();
	}

	/// <summary>
	/// Blocks until signaled, then takes (and clears) the message in the incoming slot.
	/// </summary>
	/// <returns>The delivered message, or null if the signal was raised for a raw response instead.</returns>
	private IProtocolMessage AwaitIncomingMessage() {
		this.incomingMessageSignal.WaitOne();
		IProtocolMessage response = this.incomingMessage;
		this.incomingMessage = null;
		return response;
	}

	/// <summary>
	/// Blocks until signaled, then takes (and clears) the raw response in the incoming slot.
	/// </summary>
	/// <returns>The delivered raw response, or null if the signal was raised for a message instead.</returns>
	private OutgoingWebResponse AwaitIncomingRawResponse() {
		this.incomingMessageSignal.WaitOne();
		OutgoingWebResponse response = this.incomingRawResponse;
		this.incomingRawResponse = null;
		return response;
	}

	/// <summary>
	/// Creates a copy of a message by serializing its parts.
	/// The parts are deserialized into a new instance built by the remote channel's message
	/// factory, so the receiver never shares an object reference with the sender.
	/// </summary>
	/// <typeparam name="T">The static type of the message being cloned.</typeparam>
	/// <param name="message">The message to clone.</param>
	/// <param name="requestInfo">Not currently read; callers pass the spoofed request info, or null for responses.</param>
	/// <returns>The cloned message.</returns>
	/// <exception cref="InvalidOperationException">
	/// Thrown when <paramref name="message"/> is neither a directed request nor a direct response.
	/// </exception>
	private T CloneSerializedParts<T>(T message, HttpRequestInfo requestInfo) where T : class, IProtocolMessage {
		Contract.Requires(message != null);

		IProtocolMessage clonedMessage;
		var messageAccessor = this.MessageDescriptions.GetAccessor(message);
		var fields = messageAccessor.Serialize();

		MessageReceivingEndpoint recipient = null;
		var directedMessage = message as IDirectedProtocolMessage;
		var directResponse = message as IDirectResponseProtocolMessage;
		if (directedMessage != null && directedMessage.IsRequest()) {
			if (directedMessage.Recipient != null) {
				recipient = new MessageReceivingEndpoint(directedMessage.Recipient, directedMessage.HttpMethods);
			}

			clonedMessage = this.RemoteChannel.MessageFactoryTestHook.GetNewRequestMessage(recipient, fields);
		} else if (directResponse != null && directResponse.IsDirectResponse()) {
			clonedMessage = this.RemoteChannel.MessageFactoryTestHook.GetNewResponseMessage(directResponse.OriginatingRequest, fields);
		} else {
			throw new InvalidOperationException("Totally expected a message to implement one of the two derived interface types.");
		}

		// Fill the cloned message with data.
		var clonedMessageAccessor = this.MessageDescriptions.GetAccessor(clonedMessage);
		clonedMessageAccessor.Deserialize(fields);

		return (T)clonedMessage;
	}

	/// <summary>
	/// Chooses the HTTP verb to pretend to use for a message.
	/// </summary>
	/// <param name="methods">The delivery methods the message allows.</param>
	/// <returns>"POST" if <see cref="HttpDeliveryMethods.PostRequest"/> is allowed; otherwise "GET".</returns>
	private string GetHttpMethod(HttpDeliveryMethods methods) {
		return (methods & HttpDeliveryMethods.PostRequest) != 0 ? "POST" : "GET";
	}
}
}