summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
-rw-r--r--src/main.test/Tests/DnsValidationTests/When_resolving_name_servers.cs9
-rw-r--r--src/main/Clients/DNS/LookupClientProvider.cs50
-rw-r--r--src/main/Clients/DNS/LookupClientWrapper.cs40
3 files changed, 70 insertions, 29 deletions
diff --git a/src/main.test/Tests/DnsValidationTests/When_resolving_name_servers.cs b/src/main.test/Tests/DnsValidationTests/When_resolving_name_servers.cs
index e7ccc07..892ac13 100644
--- a/src/main.test/Tests/DnsValidationTests/When_resolving_name_servers.cs
+++ b/src/main.test/Tests/DnsValidationTests/When_resolving_name_servers.cs
@@ -21,12 +21,15 @@ namespace PKISharp.WACS.UnitTests.Tests.DnsValidationTests
[TestMethod]
[DataRow("_acme-challenge.logs.hourstrackercloud.com", "Tx1e8X4LF-c615tnacJeuKmzkRmScZzsU-MJHxdDMhU")]
- [DataRow("_acme-challenge.www2.candell.org", "IpualE-HBtD8bxr60LoyuLw8FxMPOIUgg2XQTR6mSvw")]
- [DataRow("_acme-challenge.www2.candell.org", "I2F57jex1qSMXprwPy0crWFSUe2n5AowLitxU0q_WKM")]
+ //[DataRow("_acme-challenge.www2.candell.org", "IpualE-HBtD8bxr60LoyuLw8FxMPOIUgg2XQTR6mSvw")]
+ //[DataRow("_acme-challenge.www2.candell.org", "I2F57jex1qSMXprwPy0crWFSUe2n5AowLitxU0q_WKM")]
[DataRow("_acme-challenge.wouter.tinus.online", "DHrsG3LudqI9S0jvitp25tDofK1Jf58J08s3c5rIY3k")]
+ //[DataRow("_acme-challenge.www7.candell.org", "xxx")]
public void Should_recursively_follow_cnames(string challengeUri, string expectedToken)
{
- var tokens = _dnsClient.DefaultClient.GetTextRecordValues(challengeUri);
+ //var client = _dnsClient.DefaultClient();
+ var client = _dnsClient.GetClient(challengeUri);
+ var tokens = client.GetTextRecordValues(challengeUri);
Assert.IsTrue(tokens.Contains(expectedToken));
}
}
diff --git a/src/main/Clients/DNS/LookupClientProvider.cs b/src/main/Clients/DNS/LookupClientProvider.cs
index 2d11e9a..e9560a0 100644
--- a/src/main/Clients/DNS/LookupClientProvider.cs
+++ b/src/main/Clients/DNS/LookupClientProvider.cs
@@ -2,6 +2,7 @@
using Nager.PublicSuffix;
using PKISharp.WACS.Extensions;
using PKISharp.WACS.Services;
+using Serilog.Context;
using System;
using System.Collections.Concurrent;
using System.Collections.Generic;
@@ -71,8 +72,53 @@ namespace PKISharp.WACS.Clients.DNS
/// <returns>Returns an <see cref="ILookupClient"/> using a name server associated with the specified domain name.</returns>
public LookupClientWrapper GetClient(string domainName)
{
- IPAddress[] ipAddresses = DefaultClient.GetAuthoritativeNameServers(domainName, out string authoratitiveZone).ToArray();
- return _lookupClients.GetOrAdd(authoratitiveZone, new LookupClientWrapper(DomainParser, _log, new LookupClient(ipAddresses), this));
+ // _acme-challenge.sub.example.co.uk
+ domainName = domainName.TrimEnd('.');
+
+ // First domain we should try to ask
+ var rootDomain = DomainParser.GetRegisterableDomain(domainName);
+ var testZone = rootDomain;
+ var authoritativeZone = testZone;
+ var client = DefaultClient;
+
+ // Other sub domains we should try asking:
+ // 1. sub
+ // 2. _acme-challenge
+ IEnumerable<string> remainingParts = domainName.Substring(0, domainName.LastIndexOf(rootDomain)).Trim('.').Split('.');
+ remainingParts = remainingParts.Reverse();
+
+ var digDeeper = true;
+ IEnumerable<IPAddress> ipSet = null;
+ do
+ {
+ using (LogContext.PushProperty("Domain", testZone))
+ {
+ _log.Debug("Querying name servers for {part}", testZone);
+ var tempResult = client.GetAuthoritativeNameServers(testZone);
+ if (tempResult != null)
+ {
+ ipSet = tempResult;
+ authoritativeZone = testZone;
+ client = GetClient(ipSet.First());
+ }
+ }
+ if (remainingParts.Any())
+ {
+ testZone = $"{remainingParts.First()}.{testZone}";
+ remainingParts = remainingParts.Skip(1).ToArray();
+ }
+ else
+ {
+ digDeeper = false;
+ }
+ } while (digDeeper);
+
+ if (ipSet == null)
+ {
+ throw new Exception($"Unable to determine name servers for domain {domainName}");
+ }
+
+ return _lookupClients.GetOrAdd(authoritativeZone, new LookupClientWrapper(DomainParser, _log, new LookupClient(ipSet.First()), this));
}
}
diff --git a/src/main/Clients/DNS/LookupClientWrapper.cs b/src/main/Clients/DNS/LookupClientWrapper.cs
index 31d343f..bcea03c 100644
--- a/src/main/Clients/DNS/LookupClientWrapper.cs
+++ b/src/main/Clients/DNS/LookupClientWrapper.cs
@@ -27,42 +27,34 @@ namespace PKISharp.WACS.Clients.DNS
public string GetRootDomain(string domainName)
{
- if (domainName.EndsWith("."))
- {
- domainName = domainName.TrimEnd('.');
- }
- return _domainParser.GetRegisterableDomain(domainName);
+ return _domainParser.GetRegisterableDomain(domainName.TrimEnd('.'));
}
- public IEnumerable<IPAddress> GetAuthoritativeNameServers(string domainName, out string authoritativeZone)
+ public IEnumerable<IPAddress> GetAuthoritativeNameServers(string domainName)
{
- var rootDomain = GetRootDomain(domainName);
- authoritativeZone = domainName.TrimEnd('.');
- do
+ domainName = domainName.TrimEnd('.');
+ _log.Debug("Querying name servers for {part}", domainName);
+ var nsResponse = LookupClient.Query(domainName, QueryType.NS);
+ var nsRecords = nsResponse.Answers.NsRecords();
+ if (!nsRecords.Any())
{
- using (LogContext.PushProperty("Domain", authoritativeZone))
- {
- _log.Debug("Querying name servers for {part}", authoritativeZone);
- var nsResponse = LookupClient.Query(authoritativeZone, QueryType.NS);
- if (nsResponse.Answers.NsRecords().Any())
- {
- return GetNameServerIpAddresses(nsResponse.Answers.NsRecords());
- }
- }
- authoritativeZone = authoritativeZone.Substring(authoritativeZone.IndexOf('.') + 1);
+ nsRecords = nsResponse.Authorities.OfType<NsRecord>();
+ }
+ if (nsRecords.Any())
+ {
+ return GetNameServerIpAddresses(nsRecords.Select(n => n.NSDName.Value));
}
- while (authoritativeZone.Length >= rootDomain.Length);
- throw new Exception($"Unable to determine name servers for domain {domainName}");
+ return null;
}
- private IEnumerable<IPAddress> GetNameServerIpAddresses(IEnumerable<NsRecord> nsRecords)
+ private IEnumerable<IPAddress> GetNameServerIpAddresses(IEnumerable<string> nsRecords)
{
foreach (var nsRecord in nsRecords)
{
- using (LogContext.PushProperty("NameServer", nsRecord.NSDName))
+ using (LogContext.PushProperty("NameServer", nsRecord))
{
_log.Debug("Querying IP for name server");
- var aResponse = _provider.DefaultClient.LookupClient.Query(nsRecord.NSDName, QueryType.A);
+ var aResponse = _provider.DefaultClient.LookupClient.Query(nsRecord, QueryType.A);
var nameServerIp = aResponse.Answers.ARecords().FirstOrDefault()?.Address;
_log.Debug("Name server IP {NameServerIpAddress} identified", nameServerIp);
yield return nameServerIp;