github.com/d4l3k/go@v0.0.0-20151015000803-65fc379daeda/src/net/lookup_windows.go (about)

     1  // Copyright 2009 The Go Authors. All rights reserved.
     2  // Use of this source code is governed by a BSD-style
     3  // license that can be found in the LICENSE file.
     4  
     5  package net
     6  
     7  import (
     8  	"os"
     9  	"runtime"
    10  	"syscall"
    11  	"unsafe"
    12  )
    13  
    14  var (
    15  	lookupPort = oldLookupPort
    16  	lookupIP   = oldLookupIP
    17  )
    18  
    19  func getprotobyname(name string) (proto int, err error) {
    20  	p, err := syscall.GetProtoByName(name)
    21  	if err != nil {
    22  		return 0, os.NewSyscallError("getorotobyname", err)
    23  	}
    24  	return int(p.Proto), nil
    25  }
    26  
    27  // lookupProtocol looks up IP protocol name and returns correspondent protocol number.
    28  func lookupProtocol(name string) (int, error) {
    29  	// GetProtoByName return value is stored in thread local storage.
    30  	// Start new os thread before the call to prevent races.
    31  	type result struct {
    32  		proto int
    33  		err   error
    34  	}
    35  	ch := make(chan result)
    36  	go func() {
    37  		acquireThread()
    38  		defer releaseThread()
    39  		runtime.LockOSThread()
    40  		defer runtime.UnlockOSThread()
    41  		proto, err := getprotobyname(name)
    42  		ch <- result{proto: proto, err: err}
    43  	}()
    44  	r := <-ch
    45  	if r.err != nil {
    46  		if proto, ok := protocols[name]; ok {
    47  			return proto, nil
    48  		}
    49  		r.err = &DNSError{Err: r.err.Error(), Name: name}
    50  	}
    51  	return r.proto, r.err
    52  }
    53  
    54  func lookupHost(name string) ([]string, error) {
    55  	ips, err := LookupIP(name)
    56  	if err != nil {
    57  		return nil, err
    58  	}
    59  	addrs := make([]string, 0, len(ips))
    60  	for _, ip := range ips {
    61  		addrs = append(addrs, ip.String())
    62  	}
    63  	return addrs, nil
    64  }
    65  
    66  func gethostbyname(name string) (addrs []IPAddr, err error) {
    67  	// caller already acquired thread
    68  	h, err := syscall.GetHostByName(name)
    69  	if err != nil {
    70  		return nil, os.NewSyscallError("gethostbyname", err)
    71  	}
    72  	switch h.AddrType {
    73  	case syscall.AF_INET:
    74  		i := 0
    75  		addrs = make([]IPAddr, 100) // plenty of room to grow
    76  		for p := (*[100](*[4]byte))(unsafe.Pointer(h.AddrList)); i < cap(addrs) && p[i] != nil; i++ {
    77  			addrs[i] = IPAddr{IP: IPv4(p[i][0], p[i][1], p[i][2], p[i][3])}
    78  		}
    79  		addrs = addrs[0:i]
    80  	default: // TODO(vcc): Implement non IPv4 address lookups.
    81  		return nil, syscall.EWINDOWS
    82  	}
    83  	return addrs, nil
    84  }
    85  
    86  func oldLookupIP(name string) ([]IPAddr, error) {
    87  	// GetHostByName return value is stored in thread local storage.
    88  	// Start new os thread before the call to prevent races.
    89  	type result struct {
    90  		addrs []IPAddr
    91  		err   error
    92  	}
    93  	ch := make(chan result)
    94  	go func() {
    95  		acquireThread()
    96  		defer releaseThread()
    97  		runtime.LockOSThread()
    98  		defer runtime.UnlockOSThread()
    99  		addrs, err := gethostbyname(name)
   100  		ch <- result{addrs: addrs, err: err}
   101  	}()
   102  	r := <-ch
   103  	if r.err != nil {
   104  		r.err = &DNSError{Err: r.err.Error(), Name: name}
   105  	}
   106  	return r.addrs, r.err
   107  }
   108  
   109  func newLookupIP(name string) ([]IPAddr, error) {
   110  	acquireThread()
   111  	defer releaseThread()
   112  	hints := syscall.AddrinfoW{
   113  		Family:   syscall.AF_UNSPEC,
   114  		Socktype: syscall.SOCK_STREAM,
   115  		Protocol: syscall.IPPROTO_IP,
   116  	}
   117  	var result *syscall.AddrinfoW
   118  	e := syscall.GetAddrInfoW(syscall.StringToUTF16Ptr(name), nil, &hints, &result)
   119  	if e != nil {
   120  		return nil, &DNSError{Err: os.NewSyscallError("getaddrinfow", e).Error(), Name: name}
   121  	}
   122  	defer syscall.FreeAddrInfoW(result)
   123  	addrs := make([]IPAddr, 0, 5)
   124  	for ; result != nil; result = result.Next {
   125  		addr := unsafe.Pointer(result.Addr)
   126  		switch result.Family {
   127  		case syscall.AF_INET:
   128  			a := (*syscall.RawSockaddrInet4)(addr).Addr
   129  			addrs = append(addrs, IPAddr{IP: IPv4(a[0], a[1], a[2], a[3])})
   130  		case syscall.AF_INET6:
   131  			a := (*syscall.RawSockaddrInet6)(addr).Addr
   132  			zone := zoneToString(int((*syscall.RawSockaddrInet6)(addr).Scope_id))
   133  			addrs = append(addrs, IPAddr{IP: IP{a[0], a[1], a[2], a[3], a[4], a[5], a[6], a[7], a[8], a[9], a[10], a[11], a[12], a[13], a[14], a[15]}, Zone: zone})
   134  		default:
   135  			return nil, &DNSError{Err: syscall.EWINDOWS.Error(), Name: name}
   136  		}
   137  	}
   138  	return addrs, nil
   139  }
   140  
   141  func getservbyname(network, service string) (int, error) {
   142  	acquireThread()
   143  	defer releaseThread()
   144  	switch network {
   145  	case "tcp4", "tcp6":
   146  		network = "tcp"
   147  	case "udp4", "udp6":
   148  		network = "udp"
   149  	}
   150  	s, err := syscall.GetServByName(service, network)
   151  	if err != nil {
   152  		return 0, os.NewSyscallError("getservbyname", err)
   153  	}
   154  	return int(syscall.Ntohs(s.Port)), nil
   155  }
   156  
   157  func oldLookupPort(network, service string) (int, error) {
   158  	// GetServByName return value is stored in thread local storage.
   159  	// Start new os thread before the call to prevent races.
   160  	type result struct {
   161  		port int
   162  		err  error
   163  	}
   164  	ch := make(chan result)
   165  	go func() {
   166  		acquireThread()
   167  		defer releaseThread()
   168  		runtime.LockOSThread()
   169  		defer runtime.UnlockOSThread()
   170  		port, err := getservbyname(network, service)
   171  		ch <- result{port: port, err: err}
   172  	}()
   173  	r := <-ch
   174  	if r.err != nil {
   175  		r.err = &DNSError{Err: r.err.Error(), Name: network + "/" + service}
   176  	}
   177  	return r.port, r.err
   178  }
   179  
   180  func newLookupPort(network, service string) (int, error) {
   181  	acquireThread()
   182  	defer releaseThread()
   183  	var stype int32
   184  	switch network {
   185  	case "tcp4", "tcp6":
   186  		stype = syscall.SOCK_STREAM
   187  	case "udp4", "udp6":
   188  		stype = syscall.SOCK_DGRAM
   189  	}
   190  	hints := syscall.AddrinfoW{
   191  		Family:   syscall.AF_UNSPEC,
   192  		Socktype: stype,
   193  		Protocol: syscall.IPPROTO_IP,
   194  	}
   195  	var result *syscall.AddrinfoW
   196  	e := syscall.GetAddrInfoW(nil, syscall.StringToUTF16Ptr(service), &hints, &result)
   197  	if e != nil {
   198  		return 0, &DNSError{Err: os.NewSyscallError("getaddrinfow", e).Error(), Name: network + "/" + service}
   199  	}
   200  	defer syscall.FreeAddrInfoW(result)
   201  	if result == nil {
   202  		return 0, &DNSError{Err: syscall.EINVAL.Error(), Name: network + "/" + service}
   203  	}
   204  	addr := unsafe.Pointer(result.Addr)
   205  	switch result.Family {
   206  	case syscall.AF_INET:
   207  		a := (*syscall.RawSockaddrInet4)(addr)
   208  		return int(syscall.Ntohs(a.Port)), nil
   209  	case syscall.AF_INET6:
   210  		a := (*syscall.RawSockaddrInet6)(addr)
   211  		return int(syscall.Ntohs(a.Port)), nil
   212  	}
   213  	return 0, &DNSError{Err: syscall.EINVAL.Error(), Name: network + "/" + service}
   214  }
   215  
   216  func lookupCNAME(name string) (string, error) {
   217  	acquireThread()
   218  	defer releaseThread()
   219  	var r *syscall.DNSRecord
   220  	e := syscall.DnsQuery(name, syscall.DNS_TYPE_CNAME, 0, nil, &r, nil)
   221  	// windows returns DNS_INFO_NO_RECORDS if there are no CNAME-s
   222  	if errno, ok := e.(syscall.Errno); ok && errno == syscall.DNS_INFO_NO_RECORDS {
   223  		// if there are no aliases, the canonical name is the input name
   224  		if name == "" || name[len(name)-1] != '.' {
   225  			return name + ".", nil
   226  		}
   227  		return name, nil
   228  	}
   229  	if e != nil {
   230  		return "", &DNSError{Err: os.NewSyscallError("dnsquery", e).Error(), Name: name}
   231  	}
   232  	defer syscall.DnsRecordListFree(r, 1)
   233  
   234  	resolved := resolveCNAME(syscall.StringToUTF16Ptr(name), r)
   235  	cname := syscall.UTF16ToString((*[256]uint16)(unsafe.Pointer(resolved))[:]) + "."
   236  	return cname, nil
   237  }
   238  
   239  func lookupSRV(service, proto, name string) (string, []*SRV, error) {
   240  	acquireThread()
   241  	defer releaseThread()
   242  	var target string
   243  	if service == "" && proto == "" {
   244  		target = name
   245  	} else {
   246  		target = "_" + service + "._" + proto + "." + name
   247  	}
   248  	var r *syscall.DNSRecord
   249  	e := syscall.DnsQuery(target, syscall.DNS_TYPE_SRV, 0, nil, &r, nil)
   250  	if e != nil {
   251  		return "", nil, &DNSError{Err: os.NewSyscallError("dnsquery", e).Error(), Name: target}
   252  	}
   253  	defer syscall.DnsRecordListFree(r, 1)
   254  
   255  	srvs := make([]*SRV, 0, 10)
   256  	for _, p := range validRecs(r, syscall.DNS_TYPE_SRV, target) {
   257  		v := (*syscall.DNSSRVData)(unsafe.Pointer(&p.Data[0]))
   258  		srvs = append(srvs, &SRV{syscall.UTF16ToString((*[256]uint16)(unsafe.Pointer(v.Target))[:]), v.Port, v.Priority, v.Weight})
   259  	}
   260  	byPriorityWeight(srvs).sort()
   261  	return name, srvs, nil
   262  }
   263  
   264  func lookupMX(name string) ([]*MX, error) {
   265  	acquireThread()
   266  	defer releaseThread()
   267  	var r *syscall.DNSRecord
   268  	e := syscall.DnsQuery(name, syscall.DNS_TYPE_MX, 0, nil, &r, nil)
   269  	if e != nil {
   270  		return nil, &DNSError{Err: os.NewSyscallError("dnsquery", e).Error(), Name: name}
   271  	}
   272  	defer syscall.DnsRecordListFree(r, 1)
   273  
   274  	mxs := make([]*MX, 0, 10)
   275  	for _, p := range validRecs(r, syscall.DNS_TYPE_MX, name) {
   276  		v := (*syscall.DNSMXData)(unsafe.Pointer(&p.Data[0]))
   277  		mxs = append(mxs, &MX{syscall.UTF16ToString((*[256]uint16)(unsafe.Pointer(v.NameExchange))[:]) + ".", v.Preference})
   278  	}
   279  	byPref(mxs).sort()
   280  	return mxs, nil
   281  }
   282  
   283  func lookupNS(name string) ([]*NS, error) {
   284  	acquireThread()
   285  	defer releaseThread()
   286  	var r *syscall.DNSRecord
   287  	e := syscall.DnsQuery(name, syscall.DNS_TYPE_NS, 0, nil, &r, nil)
   288  	if e != nil {
   289  		return nil, &DNSError{Err: os.NewSyscallError("dnsquery", e).Error(), Name: name}
   290  	}
   291  	defer syscall.DnsRecordListFree(r, 1)
   292  
   293  	nss := make([]*NS, 0, 10)
   294  	for _, p := range validRecs(r, syscall.DNS_TYPE_NS, name) {
   295  		v := (*syscall.DNSPTRData)(unsafe.Pointer(&p.Data[0]))
   296  		nss = append(nss, &NS{syscall.UTF16ToString((*[256]uint16)(unsafe.Pointer(v.Host))[:]) + "."})
   297  	}
   298  	return nss, nil
   299  }
   300  
   301  func lookupTXT(name string) ([]string, error) {
   302  	acquireThread()
   303  	defer releaseThread()
   304  	var r *syscall.DNSRecord
   305  	e := syscall.DnsQuery(name, syscall.DNS_TYPE_TEXT, 0, nil, &r, nil)
   306  	if e != nil {
   307  		return nil, &DNSError{Err: os.NewSyscallError("dnsquery", e).Error(), Name: name}
   308  	}
   309  	defer syscall.DnsRecordListFree(r, 1)
   310  
   311  	txts := make([]string, 0, 10)
   312  	for _, p := range validRecs(r, syscall.DNS_TYPE_TEXT, name) {
   313  		d := (*syscall.DNSTXTData)(unsafe.Pointer(&p.Data[0]))
   314  		for _, v := range (*[1 << 10]*uint16)(unsafe.Pointer(&(d.StringArray[0])))[:d.StringCount] {
   315  			s := syscall.UTF16ToString((*[1 << 20]uint16)(unsafe.Pointer(v))[:])
   316  			txts = append(txts, s)
   317  		}
   318  	}
   319  	return txts, nil
   320  }
   321  
   322  func lookupAddr(addr string) ([]string, error) {
   323  	acquireThread()
   324  	defer releaseThread()
   325  	arpa, err := reverseaddr(addr)
   326  	if err != nil {
   327  		return nil, err
   328  	}
   329  	var r *syscall.DNSRecord
   330  	e := syscall.DnsQuery(arpa, syscall.DNS_TYPE_PTR, 0, nil, &r, nil)
   331  	if e != nil {
   332  		return nil, &DNSError{Err: os.NewSyscallError("dnsquery", e).Error(), Name: addr}
   333  	}
   334  	defer syscall.DnsRecordListFree(r, 1)
   335  
   336  	ptrs := make([]string, 0, 10)
   337  	for _, p := range validRecs(r, syscall.DNS_TYPE_PTR, arpa) {
   338  		v := (*syscall.DNSPTRData)(unsafe.Pointer(&p.Data[0]))
   339  		ptrs = append(ptrs, syscall.UTF16ToString((*[256]uint16)(unsafe.Pointer(v.Host))[:]))
   340  	}
   341  	return ptrs, nil
   342  }
   343  
   344  const dnsSectionMask = 0x0003
   345  
   346  // returns only results applicable to name and resolves CNAME entries
   347  func validRecs(r *syscall.DNSRecord, dnstype uint16, name string) []*syscall.DNSRecord {
   348  	cname := syscall.StringToUTF16Ptr(name)
   349  	if dnstype != syscall.DNS_TYPE_CNAME {
   350  		cname = resolveCNAME(cname, r)
   351  	}
   352  	rec := make([]*syscall.DNSRecord, 0, 10)
   353  	for p := r; p != nil; p = p.Next {
   354  		if p.Dw&dnsSectionMask != syscall.DnsSectionAnswer {
   355  			continue
   356  		}
   357  		if p.Type != dnstype {
   358  			continue
   359  		}
   360  		if !syscall.DnsNameCompare(cname, p.Name) {
   361  			continue
   362  		}
   363  		rec = append(rec, p)
   364  	}
   365  	return rec
   366  }
   367  
   368  // returns the last CNAME in chain
   369  func resolveCNAME(name *uint16, r *syscall.DNSRecord) *uint16 {
   370  	// limit cname resolving to 10 in case of a infinite CNAME loop
   371  Cname:
   372  	for cnameloop := 0; cnameloop < 10; cnameloop++ {
   373  		for p := r; p != nil; p = p.Next {
   374  			if p.Dw&dnsSectionMask != syscall.DnsSectionAnswer {
   375  				continue
   376  			}
   377  			if p.Type != syscall.DNS_TYPE_CNAME {
   378  				continue
   379  			}
   380  			if !syscall.DnsNameCompare(name, p.Name) {
   381  				continue
   382  			}
   383  			name = (*syscall.DNSPTRData)(unsafe.Pointer(&r.Data[0])).Host
   384  			continue Cname
   385  		}
   386  		break
   387  	}
   388  	return name
   389  }