github.com/sean-/go@v0.0.0-20151219100004-97f854cd7bb6/src/net/lookup_windows.go (about)

     1  // Copyright 2009 The Go Authors. All rights reserved.
     2  // Use of this source code is governed by a BSD-style
     3  // license that can be found in the LICENSE file.
     4  
     5  package net
     6  
     7  import (
     8  	"os"
     9  	"runtime"
    10  	"syscall"
    11  	"unsafe"
    12  )
    13  
// lookupPort and lookupIP hold the port and IP lookup implementations.
// They are initialised to the "old" variants defined in this file, which
// run the lookup on a goroutine locked to its OS thread.
// NOTE(review): presumably reassigned elsewhere to newLookupPort and
// newLookupIP when GetAddrInfoW is usable — confirm against the package's
// initialisation code.
var (
	lookupPort = oldLookupPort
	lookupIP   = oldLookupIP
)
    18  
    19  func getprotobyname(name string) (proto int, err error) {
    20  	p, err := syscall.GetProtoByName(name)
    21  	if err != nil {
    22  		return 0, os.NewSyscallError("getorotobyname", err)
    23  	}
    24  	return int(p.Proto), nil
    25  }
    26  
    27  // lookupProtocol looks up IP protocol name and returns correspondent protocol number.
    28  func lookupProtocol(name string) (int, error) {
    29  	// GetProtoByName return value is stored in thread local storage.
    30  	// Start new os thread before the call to prevent races.
    31  	type result struct {
    32  		proto int
    33  		err   error
    34  	}
    35  	ch := make(chan result)
    36  	go func() {
    37  		acquireThread()
    38  		defer releaseThread()
    39  		runtime.LockOSThread()
    40  		defer runtime.UnlockOSThread()
    41  		proto, err := getprotobyname(name)
    42  		ch <- result{proto: proto, err: err}
    43  	}()
    44  	r := <-ch
    45  	if r.err != nil {
    46  		if proto, ok := protocols[name]; ok {
    47  			return proto, nil
    48  		}
    49  		r.err = &DNSError{Err: r.err.Error(), Name: name}
    50  	}
    51  	return r.proto, r.err
    52  }
    53  
    54  func lookupHost(name string) ([]string, error) {
    55  	ips, err := LookupIP(name)
    56  	if err != nil {
    57  		return nil, err
    58  	}
    59  	addrs := make([]string, 0, len(ips))
    60  	for _, ip := range ips {
    61  		addrs = append(addrs, ip.String())
    62  	}
    63  	return addrs, nil
    64  }
    65  
    66  func gethostbyname(name string) (addrs []IPAddr, err error) {
    67  	// caller already acquired thread
    68  	h, err := syscall.GetHostByName(name)
    69  	if err != nil {
    70  		return nil, os.NewSyscallError("gethostbyname", err)
    71  	}
    72  	switch h.AddrType {
    73  	case syscall.AF_INET:
    74  		i := 0
    75  		addrs = make([]IPAddr, 100) // plenty of room to grow
    76  		for p := (*[100](*[4]byte))(unsafe.Pointer(h.AddrList)); i < cap(addrs) && p[i] != nil; i++ {
    77  			addrs[i] = IPAddr{IP: IPv4(p[i][0], p[i][1], p[i][2], p[i][3])}
    78  		}
    79  		addrs = addrs[0:i]
    80  	default: // TODO(vcc): Implement non IPv4 address lookups.
    81  		return nil, syscall.EWINDOWS
    82  	}
    83  	return addrs, nil
    84  }
    85  
    86  func oldLookupIP(name string) ([]IPAddr, error) {
    87  	// GetHostByName return value is stored in thread local storage.
    88  	// Start new os thread before the call to prevent races.
    89  	type result struct {
    90  		addrs []IPAddr
    91  		err   error
    92  	}
    93  	ch := make(chan result)
    94  	go func() {
    95  		acquireThread()
    96  		defer releaseThread()
    97  		runtime.LockOSThread()
    98  		defer runtime.UnlockOSThread()
    99  		addrs, err := gethostbyname(name)
   100  		ch <- result{addrs: addrs, err: err}
   101  	}()
   102  	r := <-ch
   103  	if r.err != nil {
   104  		r.err = &DNSError{Err: r.err.Error(), Name: name}
   105  	}
   106  	return r.addrs, r.err
   107  }
   108  
   109  func newLookupIP(name string) ([]IPAddr, error) {
   110  	acquireThread()
   111  	defer releaseThread()
   112  	hints := syscall.AddrinfoW{
   113  		Family:   syscall.AF_UNSPEC,
   114  		Socktype: syscall.SOCK_STREAM,
   115  		Protocol: syscall.IPPROTO_IP,
   116  	}
   117  	var result *syscall.AddrinfoW
   118  	e := syscall.GetAddrInfoW(syscall.StringToUTF16Ptr(name), nil, &hints, &result)
   119  	if e != nil {
   120  		return nil, &DNSError{Err: os.NewSyscallError("getaddrinfow", e).Error(), Name: name}
   121  	}
   122  	defer syscall.FreeAddrInfoW(result)
   123  	addrs := make([]IPAddr, 0, 5)
   124  	for ; result != nil; result = result.Next {
   125  		addr := unsafe.Pointer(result.Addr)
   126  		switch result.Family {
   127  		case syscall.AF_INET:
   128  			a := (*syscall.RawSockaddrInet4)(addr).Addr
   129  			addrs = append(addrs, IPAddr{IP: IPv4(a[0], a[1], a[2], a[3])})
   130  		case syscall.AF_INET6:
   131  			a := (*syscall.RawSockaddrInet6)(addr).Addr
   132  			zone := zoneToString(int((*syscall.RawSockaddrInet6)(addr).Scope_id))
   133  			addrs = append(addrs, IPAddr{IP: IP{a[0], a[1], a[2], a[3], a[4], a[5], a[6], a[7], a[8], a[9], a[10], a[11], a[12], a[13], a[14], a[15]}, Zone: zone})
   134  		default:
   135  			return nil, &DNSError{Err: syscall.EWINDOWS.Error(), Name: name}
   136  		}
   137  	}
   138  	return addrs, nil
   139  }
   140  
   141  func getservbyname(network, service string) (int, error) {
   142  	acquireThread()
   143  	defer releaseThread()
   144  	switch network {
   145  	case "tcp4", "tcp6":
   146  		network = "tcp"
   147  	case "udp4", "udp6":
   148  		network = "udp"
   149  	}
   150  	s, err := syscall.GetServByName(service, network)
   151  	if err != nil {
   152  		return 0, os.NewSyscallError("getservbyname", err)
   153  	}
   154  	return int(syscall.Ntohs(s.Port)), nil
   155  }
   156  
   157  func oldLookupPort(network, service string) (int, error) {
   158  	// GetServByName return value is stored in thread local storage.
   159  	// Start new os thread before the call to prevent races.
   160  	type result struct {
   161  		port int
   162  		err  error
   163  	}
   164  	ch := make(chan result)
   165  	go func() {
   166  		acquireThread()
   167  		defer releaseThread()
   168  		runtime.LockOSThread()
   169  		defer runtime.UnlockOSThread()
   170  		port, err := getservbyname(network, service)
   171  		ch <- result{port: port, err: err}
   172  	}()
   173  	r := <-ch
   174  	if r.err != nil {
   175  		r.err = &DNSError{Err: r.err.Error(), Name: network + "/" + service}
   176  	}
   177  	return r.port, r.err
   178  }
   179  
   180  func newLookupPort(network, service string) (int, error) {
   181  	acquireThread()
   182  	defer releaseThread()
   183  	var stype int32
   184  	switch network {
   185  	case "tcp4", "tcp6":
   186  		stype = syscall.SOCK_STREAM
   187  	case "udp4", "udp6":
   188  		stype = syscall.SOCK_DGRAM
   189  	}
   190  	hints := syscall.AddrinfoW{
   191  		Family:   syscall.AF_UNSPEC,
   192  		Socktype: stype,
   193  		Protocol: syscall.IPPROTO_IP,
   194  	}
   195  	var result *syscall.AddrinfoW
   196  	e := syscall.GetAddrInfoW(nil, syscall.StringToUTF16Ptr(service), &hints, &result)
   197  	if e != nil {
   198  		return 0, &DNSError{Err: os.NewSyscallError("getaddrinfow", e).Error(), Name: network + "/" + service}
   199  	}
   200  	defer syscall.FreeAddrInfoW(result)
   201  	if result == nil {
   202  		return 0, &DNSError{Err: syscall.EINVAL.Error(), Name: network + "/" + service}
   203  	}
   204  	addr := unsafe.Pointer(result.Addr)
   205  	switch result.Family {
   206  	case syscall.AF_INET:
   207  		a := (*syscall.RawSockaddrInet4)(addr)
   208  		return int(syscall.Ntohs(a.Port)), nil
   209  	case syscall.AF_INET6:
   210  		a := (*syscall.RawSockaddrInet6)(addr)
   211  		return int(syscall.Ntohs(a.Port)), nil
   212  	}
   213  	return 0, &DNSError{Err: syscall.EINVAL.Error(), Name: network + "/" + service}
   214  }
   215  
   216  // ensureEndDot adds '.' at the end of name unless it is already there.
   217  func ensureEndDot(name string) string {
   218  	if name == "" {
   219  		return "."
   220  	}
   221  	if name[len(name)-1] == '.' {
   222  		return name
   223  	}
   224  	return name + "."
   225  }
   226  
   227  func lookupCNAME(name string) (string, error) {
   228  	acquireThread()
   229  	defer releaseThread()
   230  	var r *syscall.DNSRecord
   231  	e := syscall.DnsQuery(name, syscall.DNS_TYPE_CNAME, 0, nil, &r, nil)
   232  	// windows returns DNS_INFO_NO_RECORDS if there are no CNAME-s
   233  	if errno, ok := e.(syscall.Errno); ok && errno == syscall.DNS_INFO_NO_RECORDS {
   234  		// if there are no aliases, the canonical name is the input name
   235  		return ensureEndDot(name), nil
   236  	}
   237  	if e != nil {
   238  		return "", &DNSError{Err: os.NewSyscallError("dnsquery", e).Error(), Name: name}
   239  	}
   240  	defer syscall.DnsRecordListFree(r, 1)
   241  
   242  	resolved := resolveCNAME(syscall.StringToUTF16Ptr(name), r)
   243  	cname := syscall.UTF16ToString((*[256]uint16)(unsafe.Pointer(resolved))[:])
   244  	return ensureEndDot(cname), nil
   245  }
   246  
   247  func lookupSRV(service, proto, name string) (string, []*SRV, error) {
   248  	acquireThread()
   249  	defer releaseThread()
   250  	var target string
   251  	if service == "" && proto == "" {
   252  		target = name
   253  	} else {
   254  		target = "_" + service + "._" + proto + "." + name
   255  	}
   256  	var r *syscall.DNSRecord
   257  	e := syscall.DnsQuery(target, syscall.DNS_TYPE_SRV, 0, nil, &r, nil)
   258  	if e != nil {
   259  		return "", nil, &DNSError{Err: os.NewSyscallError("dnsquery", e).Error(), Name: target}
   260  	}
   261  	defer syscall.DnsRecordListFree(r, 1)
   262  
   263  	srvs := make([]*SRV, 0, 10)
   264  	for _, p := range validRecs(r, syscall.DNS_TYPE_SRV, target) {
   265  		v := (*syscall.DNSSRVData)(unsafe.Pointer(&p.Data[0]))
   266  		srvs = append(srvs, &SRV{ensureEndDot(syscall.UTF16ToString((*[256]uint16)(unsafe.Pointer(v.Target))[:])), v.Port, v.Priority, v.Weight})
   267  	}
   268  	byPriorityWeight(srvs).sort()
   269  	return ensureEndDot(target), srvs, nil
   270  }
   271  
   272  func lookupMX(name string) ([]*MX, error) {
   273  	acquireThread()
   274  	defer releaseThread()
   275  	var r *syscall.DNSRecord
   276  	e := syscall.DnsQuery(name, syscall.DNS_TYPE_MX, 0, nil, &r, nil)
   277  	if e != nil {
   278  		return nil, &DNSError{Err: os.NewSyscallError("dnsquery", e).Error(), Name: name}
   279  	}
   280  	defer syscall.DnsRecordListFree(r, 1)
   281  
   282  	mxs := make([]*MX, 0, 10)
   283  	for _, p := range validRecs(r, syscall.DNS_TYPE_MX, name) {
   284  		v := (*syscall.DNSMXData)(unsafe.Pointer(&p.Data[0]))
   285  		mxs = append(mxs, &MX{ensureEndDot(syscall.UTF16ToString((*[256]uint16)(unsafe.Pointer(v.NameExchange))[:])), v.Preference})
   286  	}
   287  	byPref(mxs).sort()
   288  	return mxs, nil
   289  }
   290  
   291  func lookupNS(name string) ([]*NS, error) {
   292  	acquireThread()
   293  	defer releaseThread()
   294  	var r *syscall.DNSRecord
   295  	e := syscall.DnsQuery(name, syscall.DNS_TYPE_NS, 0, nil, &r, nil)
   296  	if e != nil {
   297  		return nil, &DNSError{Err: os.NewSyscallError("dnsquery", e).Error(), Name: name}
   298  	}
   299  	defer syscall.DnsRecordListFree(r, 1)
   300  
   301  	nss := make([]*NS, 0, 10)
   302  	for _, p := range validRecs(r, syscall.DNS_TYPE_NS, name) {
   303  		v := (*syscall.DNSPTRData)(unsafe.Pointer(&p.Data[0]))
   304  		nss = append(nss, &NS{ensureEndDot(syscall.UTF16ToString((*[256]uint16)(unsafe.Pointer(v.Host))[:]))})
   305  	}
   306  	return nss, nil
   307  }
   308  
   309  func lookupTXT(name string) ([]string, error) {
   310  	acquireThread()
   311  	defer releaseThread()
   312  	var r *syscall.DNSRecord
   313  	e := syscall.DnsQuery(name, syscall.DNS_TYPE_TEXT, 0, nil, &r, nil)
   314  	if e != nil {
   315  		return nil, &DNSError{Err: os.NewSyscallError("dnsquery", e).Error(), Name: name}
   316  	}
   317  	defer syscall.DnsRecordListFree(r, 1)
   318  
   319  	txts := make([]string, 0, 10)
   320  	for _, p := range validRecs(r, syscall.DNS_TYPE_TEXT, name) {
   321  		d := (*syscall.DNSTXTData)(unsafe.Pointer(&p.Data[0]))
   322  		for _, v := range (*[1 << 10]*uint16)(unsafe.Pointer(&(d.StringArray[0])))[:d.StringCount] {
   323  			s := syscall.UTF16ToString((*[1 << 20]uint16)(unsafe.Pointer(v))[:])
   324  			txts = append(txts, s)
   325  		}
   326  	}
   327  	return txts, nil
   328  }
   329  
   330  func lookupAddr(addr string) ([]string, error) {
   331  	acquireThread()
   332  	defer releaseThread()
   333  	arpa, err := reverseaddr(addr)
   334  	if err != nil {
   335  		return nil, err
   336  	}
   337  	var r *syscall.DNSRecord
   338  	e := syscall.DnsQuery(arpa, syscall.DNS_TYPE_PTR, 0, nil, &r, nil)
   339  	if e != nil {
   340  		return nil, &DNSError{Err: os.NewSyscallError("dnsquery", e).Error(), Name: addr}
   341  	}
   342  	defer syscall.DnsRecordListFree(r, 1)
   343  
   344  	ptrs := make([]string, 0, 10)
   345  	for _, p := range validRecs(r, syscall.DNS_TYPE_PTR, arpa) {
   346  		v := (*syscall.DNSPTRData)(unsafe.Pointer(&p.Data[0]))
   347  		ptrs = append(ptrs, ensureEndDot(syscall.UTF16ToString((*[256]uint16)(unsafe.Pointer(v.Host))[:])))
   348  	}
   349  	return ptrs, nil
   350  }
   351  
// dnsSectionMask is applied to a DNSRecord's Dw field before that value is
// compared with syscall.DnsSectionAnswer; it keeps only the low two bits.
const dnsSectionMask = 0x0003
   353  
   354  // returns only results applicable to name and resolves CNAME entries
   355  func validRecs(r *syscall.DNSRecord, dnstype uint16, name string) []*syscall.DNSRecord {
   356  	cname := syscall.StringToUTF16Ptr(name)
   357  	if dnstype != syscall.DNS_TYPE_CNAME {
   358  		cname = resolveCNAME(cname, r)
   359  	}
   360  	rec := make([]*syscall.DNSRecord, 0, 10)
   361  	for p := r; p != nil; p = p.Next {
   362  		if p.Dw&dnsSectionMask != syscall.DnsSectionAnswer {
   363  			continue
   364  		}
   365  		if p.Type != dnstype {
   366  			continue
   367  		}
   368  		if !syscall.DnsNameCompare(cname, p.Name) {
   369  			continue
   370  		}
   371  		rec = append(rec, p)
   372  	}
   373  	return rec
   374  }
   375  
   376  // returns the last CNAME in chain
   377  func resolveCNAME(name *uint16, r *syscall.DNSRecord) *uint16 {
   378  	// limit cname resolving to 10 in case of a infinite CNAME loop
   379  Cname:
   380  	for cnameloop := 0; cnameloop < 10; cnameloop++ {
   381  		for p := r; p != nil; p = p.Next {
   382  			if p.Dw&dnsSectionMask != syscall.DnsSectionAnswer {
   383  				continue
   384  			}
   385  			if p.Type != syscall.DNS_TYPE_CNAME {
   386  				continue
   387  			}
   388  			if !syscall.DnsNameCompare(name, p.Name) {
   389  				continue
   390  			}
   391  			name = (*syscall.DNSPTRData)(unsafe.Pointer(&r.Data[0])).Host
   392  			continue Cname
   393  		}
   394  		break
   395  	}
   396  	return name
   397  }