diff options
Diffstat (limited to 'src')
-rw-r--r-- | src/main.test/Tests/DnsValidationTests/When_resolving_name_servers.cs | 9 | ||||
-rw-r--r-- | src/main/Clients/DNS/LookupClientProvider.cs | 50 | ||||
-rw-r--r-- | src/main/Clients/DNS/LookupClientWrapper.cs | 40 |
3 files changed, 70 insertions, 29 deletions
diff --git a/src/main.test/Tests/DnsValidationTests/When_resolving_name_servers.cs b/src/main.test/Tests/DnsValidationTests/When_resolving_name_servers.cs index e7ccc07..892ac13 100644 --- a/src/main.test/Tests/DnsValidationTests/When_resolving_name_servers.cs +++ b/src/main.test/Tests/DnsValidationTests/When_resolving_name_servers.cs @@ -21,12 +21,15 @@ namespace PKISharp.WACS.UnitTests.Tests.DnsValidationTests [TestMethod] [DataRow("_acme-challenge.logs.hourstrackercloud.com", "Tx1e8X4LF-c615tnacJeuKmzkRmScZzsU-MJHxdDMhU")] - [DataRow("_acme-challenge.www2.candell.org", "IpualE-HBtD8bxr60LoyuLw8FxMPOIUgg2XQTR6mSvw")] - [DataRow("_acme-challenge.www2.candell.org", "I2F57jex1qSMXprwPy0crWFSUe2n5AowLitxU0q_WKM")] + //[DataRow("_acme-challenge.www2.candell.org", "IpualE-HBtD8bxr60LoyuLw8FxMPOIUgg2XQTR6mSvw")] + //[DataRow("_acme-challenge.www2.candell.org", "I2F57jex1qSMXprwPy0crWFSUe2n5AowLitxU0q_WKM")] [DataRow("_acme-challenge.wouter.tinus.online", "DHrsG3LudqI9S0jvitp25tDofK1Jf58J08s3c5rIY3k")] + //[DataRow("_acme-challenge.www7.candell.org", "xxx")] public void Should_recursively_follow_cnames(string challengeUri, string expectedToken) { - var tokens = _dnsClient.DefaultClient.GetTextRecordValues(challengeUri); + //var client = _dnsClient.DefaultClient(); + var client = _dnsClient.GetClient(challengeUri); + var tokens = client.GetTextRecordValues(challengeUri); Assert.IsTrue(tokens.Contains(expectedToken)); } } diff --git a/src/main/Clients/DNS/LookupClientProvider.cs b/src/main/Clients/DNS/LookupClientProvider.cs index 2d11e9a..e9560a0 100644 --- a/src/main/Clients/DNS/LookupClientProvider.cs +++ b/src/main/Clients/DNS/LookupClientProvider.cs @@ -2,6 +2,7 @@ using Nager.PublicSuffix; using PKISharp.WACS.Extensions; using PKISharp.WACS.Services; +using Serilog.Context; using System; using System.Collections.Concurrent; using System.Collections.Generic; @@ -71,8 +72,53 @@ namespace PKISharp.WACS.Clients.DNS /// <returns>Returns an <see cref="ILookupClient"/> using a name server associated with the specified domain name.</returns> public LookupClientWrapper GetClient(string domainName) { - IPAddress[] ipAddresses = DefaultClient.GetAuthoritativeNameServers(domainName, out string authoratitiveZone).ToArray(); - return _lookupClients.GetOrAdd(authoratitiveZone, new LookupClientWrapper(DomainParser, _log, new LookupClient(ipAddresses), this)); + // _acme-challenge.sub.example.co.uk + domainName = domainName.TrimEnd('.'); + + // First domain we should try to ask + var rootDomain = DomainParser.GetRegisterableDomain(domainName); + var testZone = rootDomain; + var authoritativeZone = testZone; + var client = DefaultClient; + + // Other sub domains we should try asking: + // 1. sub + // 2. _acme-challenge + IEnumerable<string> remainingParts = domainName.Substring(0, domainName.LastIndexOf(rootDomain)).Trim('.').Split('.'); + remainingParts = remainingParts.Reverse(); + + var digDeeper = true; + IEnumerable<IPAddress> ipSet = null; + do + { + using (LogContext.PushProperty("Domain", testZone)) + { + _log.Debug("Querying name servers for {part}", testZone); + var tempResult = client.GetAuthoritativeNameServers(testZone); + if (tempResult != null) + { + ipSet = tempResult; + authoritativeZone = testZone; + client = GetClient(ipSet.First()); + } + } + if (remainingParts.Any()) + { + testZone = $"{remainingParts.First()}.{testZone}"; + remainingParts = remainingParts.Skip(1).ToArray(); + } + else + { + digDeeper = false; + } + } while (digDeeper); + + if (ipSet == null) + { + throw new Exception($"Unable to determine name servers for domain {domainName}"); + } + + return _lookupClients.GetOrAdd(authoritativeZone, new LookupClientWrapper(DomainParser, _log, new LookupClient(ipSet.First()), this)); } } diff --git a/src/main/Clients/DNS/LookupClientWrapper.cs b/src/main/Clients/DNS/LookupClientWrapper.cs index 31d343f..bcea03c 100644 --- a/src/main/Clients/DNS/LookupClientWrapper.cs +++ b/src/main/Clients/DNS/LookupClientWrapper.cs @@ -27,42 +27,34 @@ namespace PKISharp.WACS.Clients.DNS public string GetRootDomain(string domainName) { - if (domainName.EndsWith(".")) - { - domainName = domainName.TrimEnd('.'); - } - return _domainParser.GetRegisterableDomain(domainName); + return _domainParser.GetRegisterableDomain(domainName.TrimEnd('.')); } - public IEnumerable<IPAddress> GetAuthoritativeNameServers(string domainName, out string authoritativeZone) + public IEnumerable<IPAddress> GetAuthoritativeNameServers(string domainName) { - var rootDomain = GetRootDomain(domainName); - authoritativeZone = domainName.TrimEnd('.'); - do + domainName = domainName.TrimEnd('.'); + _log.Debug("Querying name servers for {part}", domainName); + var nsResponse = LookupClient.Query(domainName, QueryType.NS); + var nsRecords = nsResponse.Answers.NsRecords(); + if (!nsRecords.Any()) { - using (LogContext.PushProperty("Domain", authoritativeZone)) - { - _log.Debug("Querying name servers for {part}", authoritativeZone); - var nsResponse = LookupClient.Query(authoritativeZone, QueryType.NS); - if (nsResponse.Answers.NsRecords().Any()) - { - return GetNameServerIpAddresses(nsResponse.Answers.NsRecords()); - } - } - authoritativeZone = authoritativeZone.Substring(authoritativeZone.IndexOf('.') + 1); + nsRecords = nsResponse.Authorities.OfType<NsRecord>(); + } + if (nsRecords.Any()) + { + return GetNameServerIpAddresses(nsRecords.Select(n => n.NSDName.Value)); } - while (authoritativeZone.Length >= rootDomain.Length); - throw new Exception($"Unable to determine name servers for domain {domainName}"); + return null; } - private IEnumerable<IPAddress> GetNameServerIpAddresses(IEnumerable<NsRecord> nsRecords) + private IEnumerable<IPAddress> GetNameServerIpAddresses(IEnumerable<string> nsRecords) { foreach (var nsRecord in nsRecords) { - using (LogContext.PushProperty("NameServer", nsRecord.NSDName)) + using (LogContext.PushProperty("NameServer", nsRecord)) { _log.Debug("Querying IP for name server"); - var aResponse = _provider.DefaultClient.LookupClient.Query(nsRecord.NSDName, QueryType.A); + var aResponse = _provider.DefaultClient.LookupClient.Query(nsRecord, QueryType.A); var nameServerIp = aResponse.Answers.ARecords().FirstOrDefault()?.Address; _log.Debug("Name server IP {NameServerIpAddress} identified", nameServerIp); yield return nameServerIp; |