Skip to content

Commit 08a3b6a

Browse files
authored
Merge pull request #80 from anywherelan/dns-case-insensitive
awldns: make hosts case-insensitive, add tests
2 parents 87d3996 + 5439f1e commit 08a3b6a

File tree

3 files changed

+140
-15
lines changed

3 files changed

+140
-15
lines changed

application.go

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -301,7 +301,10 @@ func NewDNSService(conf *config.Config, eventbus awlevent.Bus, ctx context.Conte
301301

302302
func (a *DNSService) initDNS(interfaceName string) {
303303
var err error
304-
a.dnsResolver = awldns.NewResolver()
304+
a.dnsResolver = awldns.NewResolver(awldns.DNSAddress)
305+
a.upstreamDNS = awldns.DefaultUpstreamDNSAddress
306+
a.refreshDNSConfig()
307+
305308
awlevent.WrapSubscriptionToCallback(a.ctx, func(_ interface{}) {
306309
a.refreshDNSConfig()
307310
}, a.eventbus, new(awlevent.KnownPeerChanged))

awldns/awldns.go

Lines changed: 38 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,8 @@ type Resolver struct {
3535

3636
udpServerWorking bool
3737
tcpServerWorking bool
38+
39+
dnsAddress string
3840
}
3941

4042
type config struct {
@@ -43,40 +45,39 @@ type config struct {
4345
reverseMapping map[string]string
4446
}
4547

46-
func NewResolver() *Resolver {
48+
func NewResolver(dnsAddress string) *Resolver {
4749
r := &Resolver{
4850
logger: log.Logger("awl/dns"),
4951
udpClient: &dns.Client{
50-
Net: "udp",
51-
SingleInflight: true,
52+
Net: "udp",
5253
},
5354
tcpClient: &dns.Client{
54-
Net: "tcp",
55-
SingleInflight: true,
55+
Net: "tcp",
5656
},
57+
dnsAddress: dnsAddress,
5758
}
58-
r.cfg.Store(&config{upstreamDNS: "127.0.0.1:53"})
59+
r.cfg.Store(&config{})
5960

6061
mux := dns.NewServeMux()
6162
mux.HandleFunc(LocalDomain, r.dnsLocalDomainHandler)
6263
mux.HandleFunc(strings.TrimPrefix(ptrV4Suffix, "."), r.ptrv4Handler)
6364
mux.HandleFunc(".", r.dnsProxyHandler)
6465

6566
r.udpServer = &dns.Server{
66-
Addr: DNSAddress,
67+
Addr: dnsAddress,
6768
Net: "udp",
6869
Handler: mux,
6970
NotifyStartedFunc: func() {
70-
r.logger.Infof("udp server has started on %s", DNSAddress)
71+
r.logger.Infof("udp server has started on %s", dnsAddress)
7172
r.udpServerWorking = true
7273
},
7374
}
7475
r.tcpServer = &dns.Server{
75-
Addr: DNSAddress,
76+
Addr: dnsAddress,
7677
Net: "tcp",
7778
Handler: mux,
7879
NotifyStartedFunc: func() {
79-
r.logger.Infof("tcp server has started on %s", DNSAddress)
80+
r.logger.Infof("tcp server has started on %s", dnsAddress)
8081
r.tcpServerWorking = true
8182
},
8283
}
@@ -127,7 +128,7 @@ func (r *Resolver) DNSAddress() string {
127128
return ""
128129
}
129130

130-
return DNSAddress
131+
return r.dnsAddress
131132
}
132133

133134
func (r *Resolver) Close() {
@@ -149,17 +150,18 @@ func (r *Resolver) dnsLocalDomainHandler(resp dns.ResponseWriter, req *dns.Msg)
149150

150151
m := new(dns.Msg)
151152
m.SetReply(req)
152-
m.Authoritative = true
153153

154154
for _, question := range req.Question {
155155
hostname := question.Name
156156
qtype := question.Qtype
157-
mappedIP, found := cfg.directMapping[hostname]
157+
hostnameLower := strings.ToLower(hostname)
158+
mappedIP, found := cfg.directMapping[hostnameLower]
158159

159160
switch qtype {
160161
case dns.TypeA, dns.TypeAAAA, dns.TypeANY:
161162
aRec := &dns.A{
162163
Hdr: dns.RR_Header{
164+
// we should return original name from the request as some clients expect that
163165
Name: hostname,
164166
Rrtype: dns.TypeA,
165167
Class: dns.ClassINET,
@@ -176,6 +178,8 @@ func (r *Resolver) dnsLocalDomainHandler(resp dns.ResponseWriter, req *dns.Msg)
176178
}
177179
}
178180

181+
processOwnResponse(req, resp, m)
182+
179183
_ = resp.WriteMsg(m)
180184
}
181185

@@ -201,7 +205,6 @@ func (r *Resolver) ptrv4Handler(resp dns.ResponseWriter, req *dns.Msg) {
201205

202206
m := new(dns.Msg)
203207
m.SetReply(req)
204-
m.Authoritative = true
205208

206209
ptr := &dns.PTR{
207210
Hdr: dns.RR_Header{
@@ -213,6 +216,9 @@ func (r *Resolver) ptrv4Handler(resp dns.ResponseWriter, req *dns.Msg) {
213216
Ptr: mappedName,
214217
}
215218
m.Answer = append(m.Answer, ptr)
219+
220+
processOwnResponse(req, resp, m)
221+
216222
_ = resp.WriteMsg(m)
217223
}
218224

@@ -244,6 +250,24 @@ func (r *Resolver) loadConfig() config {
244250
return *cfg
245251
}
246252

253+
func processOwnResponse(req *dns.Msg, respWriter dns.ResponseWriter, resp *dns.Msg) {
254+
maxSize := dns.MinMsgSize
255+
if respWriter.LocalAddr().Network() == "tcp" {
256+
maxSize = dns.MaxMsgSize
257+
} else {
258+
if optRR := req.IsEdns0(); optRR != nil {
259+
udpsize := int(optRR.UDPSize())
260+
if udpsize > maxSize {
261+
maxSize = udpsize
262+
}
263+
}
264+
}
265+
resp.Truncate(maxSize)
266+
267+
resp.Authoritative = true
268+
resp.RecursionAvailable = true
269+
}
270+
247271
func TrimDomainName(domain string) string {
248272
domain = strings.TrimSpace(domain)
249273
domain = strings.Map(func(r rune) rune {

awldns/awldns_test.go

Lines changed: 98 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,98 @@
1+
package awldns
2+
3+
import (
4+
"context"
5+
"fmt"
6+
"net"
7+
"runtime"
8+
"testing"
9+
"time"
10+
11+
"github.com/miekg/dns"
12+
"github.com/stretchr/testify/require"
13+
)
14+
15+
func TestDNS(t *testing.T) {
16+
ctx := context.Background()
17+
a := require.New(t)
18+
port := FindFreePort()
19+
addr := fmt.Sprintf("127.0.0.1:%d", port)
20+
21+
resolver := NewResolver(addr)
22+
defer resolver.Close()
23+
// TODO: remove sleep. We need it because NewResolver starts servers in goroutines
24+
time.Sleep(50 * time.Millisecond)
25+
26+
name1 := "peer_id"
27+
name1Capitalized := "pEEr_Id"
28+
addr1 := "123.4.5.6"
29+
name2 := "laptop.office"
30+
name2Capitalized := "LAPTOP.office"
31+
addr2 := "10.66.0.2"
32+
33+
namesMapping := map[string]string{
34+
name1: addr1,
35+
name2: addr2,
36+
}
37+
resolver.ReceiveConfiguration("", namesMapping)
38+
39+
client := NewResolverClient(addr)
40+
41+
assertAddr := func(host, addr string) {
42+
addrs, err := client.LookupHost(ctx, host)
43+
a.NoError(err)
44+
a.Len(addrs, 1)
45+
a.Equal(addr, addrs[0])
46+
47+
hosts, err := client.LookupAddr(ctx, addr)
48+
a.NoError(err)
49+
a.Len(hosts, 1)
50+
a.Equal(dns.CanonicalName(host), hosts[0])
51+
}
52+
53+
assertAddr(name1+".awl", addr1)
54+
assertAddr(name1+".AWL", addr1)
55+
assertAddr(name1Capitalized+".awl", addr1)
56+
57+
assertAddr(name2+".awl", addr2)
58+
assertAddr(name2Capitalized+".awl", addr2)
59+
60+
addrs, err := client.LookupHost(ctx, "unknown.awl")
61+
a.Error(err)
62+
a.Empty(addrs)
63+
dnsErr := err.(*net.DNSError)
64+
// TODO: investigate why macos and linux in CI return `lookup unknown.awl on 127.0.0.53:53: server misbehaving`
65+
// it should use only our resolver, but somehow it tries to use system resolver afterwards
66+
if runtime.GOOS == "windows" {
67+
a.Equalf(true, dnsErr.IsNotFound, "actual error: %v", err)
68+
}
69+
}
70+
71+
func NewResolverClient(address string) *net.Resolver {
72+
dialer := &net.Dialer{Timeout: time.Second}
73+
return &net.Resolver{
74+
StrictErrors: true,
75+
PreferGo: true,
76+
Dial: func(ctx context.Context, network, _ string) (net.Conn, error) {
77+
return dialer.DialContext(ctx, network, address)
78+
},
79+
}
80+
}
81+
82+
func FindFreePort() int {
83+
l, err := net.Listen("tcp", "127.0.0.1:0")
84+
if err != nil {
85+
panic(fmt.Sprintf("failed to listen on a port: %v", err))
86+
}
87+
defer l.Close()
88+
89+
port := l.Addr().(*net.TCPAddr).Port
90+
91+
u, err := net.ListenPacket("udp", l.Addr().String())
92+
if err != nil {
93+
panic(fmt.Sprintf("failed to listen on a udp port: %v", err))
94+
}
95+
defer u.Close()
96+
97+
return port
98+
}

0 commit comments

Comments
 (0)