rsc.io/go@v0.0.0-20150416155037-e040fd465409/src/net/lookup_windows.go (about)

     1  // Copyright 2009 The Go Authors. All rights reserved.
     2  // Use of this source code is governed by a BSD-style
     3  // license that can be found in the LICENSE file.
     4  
     5  package net
     6  
     7  import (
     8  	"os"
     9  	"runtime"
    10  	"syscall"
    11  	"unsafe"
    12  )
    13  
    14  var (
    15  	lookupPort = oldLookupPort
    16  	lookupIP   = oldLookupIP
    17  )
    18  
    19  func getprotobyname(name string) (proto int, err error) {
    20  	p, err := syscall.GetProtoByName(name)
    21  	if err != nil {
    22  		return 0, os.NewSyscallError("GetProtoByName", err)
    23  	}
    24  	return int(p.Proto), nil
    25  }
    26  
    27  // lookupProtocol looks up IP protocol name and returns correspondent protocol number.
    28  func lookupProtocol(name string) (proto int, err error) {
    29  	// GetProtoByName return value is stored in thread local storage.
    30  	// Start new os thread before the call to prevent races.
    31  	type result struct {
    32  		proto int
    33  		err   error
    34  	}
    35  	ch := make(chan result)
    36  	go func() {
    37  		acquireThread()
    38  		defer releaseThread()
    39  		runtime.LockOSThread()
    40  		defer runtime.UnlockOSThread()
    41  		proto, err := getprotobyname(name)
    42  		ch <- result{proto: proto, err: err}
    43  	}()
    44  	r := <-ch
    45  	if r.err != nil {
    46  		if proto, ok := protocols[name]; ok {
    47  			return proto, nil
    48  		}
    49  	}
    50  	return r.proto, r.err
    51  }
    52  
    53  func lookupHost(name string) (addrs []string, err error) {
    54  	ips, err := LookupIP(name)
    55  	if err != nil {
    56  		return
    57  	}
    58  	addrs = make([]string, 0, len(ips))
    59  	for _, ip := range ips {
    60  		addrs = append(addrs, ip.String())
    61  	}
    62  	return
    63  }
    64  
    65  func gethostbyname(name string) (addrs []IPAddr, err error) {
    66  	// caller already acquired thread
    67  	h, err := syscall.GetHostByName(name)
    68  	if err != nil {
    69  		return nil, os.NewSyscallError("GetHostByName", err)
    70  	}
    71  	switch h.AddrType {
    72  	case syscall.AF_INET:
    73  		i := 0
    74  		addrs = make([]IPAddr, 100) // plenty of room to grow
    75  		for p := (*[100](*[4]byte))(unsafe.Pointer(h.AddrList)); i < cap(addrs) && p[i] != nil; i++ {
    76  			addrs[i] = IPAddr{IP: IPv4(p[i][0], p[i][1], p[i][2], p[i][3])}
    77  		}
    78  		addrs = addrs[0:i]
    79  	default: // TODO(vcc): Implement non IPv4 address lookups.
    80  		return nil, os.NewSyscallError("LookupIP", syscall.EWINDOWS)
    81  	}
    82  	return addrs, nil
    83  }
    84  
    85  func oldLookupIP(name string) (addrs []IPAddr, err error) {
    86  	// GetHostByName return value is stored in thread local storage.
    87  	// Start new os thread before the call to prevent races.
    88  	type result struct {
    89  		addrs []IPAddr
    90  		err   error
    91  	}
    92  	ch := make(chan result)
    93  	go func() {
    94  		acquireThread()
    95  		defer releaseThread()
    96  		runtime.LockOSThread()
    97  		defer runtime.UnlockOSThread()
    98  		addrs, err := gethostbyname(name)
    99  		ch <- result{addrs: addrs, err: err}
   100  	}()
   101  	r := <-ch
   102  	return addrs, r.err
   103  }
   104  
   105  func newLookupIP(name string) (addrs []IPAddr, err error) {
   106  	acquireThread()
   107  	defer releaseThread()
   108  	hints := syscall.AddrinfoW{
   109  		Family:   syscall.AF_UNSPEC,
   110  		Socktype: syscall.SOCK_STREAM,
   111  		Protocol: syscall.IPPROTO_IP,
   112  	}
   113  	var result *syscall.AddrinfoW
   114  	e := syscall.GetAddrInfoW(syscall.StringToUTF16Ptr(name), nil, &hints, &result)
   115  	if e != nil {
   116  		return nil, os.NewSyscallError("GetAddrInfoW", e)
   117  	}
   118  	defer syscall.FreeAddrInfoW(result)
   119  	addrs = make([]IPAddr, 0, 5)
   120  	for ; result != nil; result = result.Next {
   121  		addr := unsafe.Pointer(result.Addr)
   122  		switch result.Family {
   123  		case syscall.AF_INET:
   124  			a := (*syscall.RawSockaddrInet4)(addr).Addr
   125  			addrs = append(addrs, IPAddr{IP: IPv4(a[0], a[1], a[2], a[3])})
   126  		case syscall.AF_INET6:
   127  			a := (*syscall.RawSockaddrInet6)(addr).Addr
   128  			zone := zoneToString(int((*syscall.RawSockaddrInet6)(addr).Scope_id))
   129  			addrs = append(addrs, IPAddr{IP: IP{a[0], a[1], a[2], a[3], a[4], a[5], a[6], a[7], a[8], a[9], a[10], a[11], a[12], a[13], a[14], a[15]}, Zone: zone})
   130  		default:
   131  			return nil, os.NewSyscallError("LookupIP", syscall.EWINDOWS)
   132  		}
   133  	}
   134  	return addrs, nil
   135  }
   136  
   137  func getservbyname(network, service string) (port int, err error) {
   138  	acquireThread()
   139  	defer releaseThread()
   140  	switch network {
   141  	case "tcp4", "tcp6":
   142  		network = "tcp"
   143  	case "udp4", "udp6":
   144  		network = "udp"
   145  	}
   146  	s, err := syscall.GetServByName(service, network)
   147  	if err != nil {
   148  		return 0, os.NewSyscallError("GetServByName", err)
   149  	}
   150  	return int(syscall.Ntohs(s.Port)), nil
   151  }
   152  
   153  func oldLookupPort(network, service string) (port int, err error) {
   154  	// GetServByName return value is stored in thread local storage.
   155  	// Start new os thread before the call to prevent races.
   156  	type result struct {
   157  		port int
   158  		err  error
   159  	}
   160  	ch := make(chan result)
   161  	go func() {
   162  		acquireThread()
   163  		defer releaseThread()
   164  		runtime.LockOSThread()
   165  		defer runtime.UnlockOSThread()
   166  		port, err := getservbyname(network, service)
   167  		ch <- result{port: port, err: err}
   168  	}()
   169  	r := <-ch
   170  	return r.port, r.err
   171  }
   172  
   173  func newLookupPort(network, service string) (port int, err error) {
   174  	acquireThread()
   175  	defer releaseThread()
   176  	var stype int32
   177  	switch network {
   178  	case "tcp4", "tcp6":
   179  		stype = syscall.SOCK_STREAM
   180  	case "udp4", "udp6":
   181  		stype = syscall.SOCK_DGRAM
   182  	}
   183  	hints := syscall.AddrinfoW{
   184  		Family:   syscall.AF_UNSPEC,
   185  		Socktype: stype,
   186  		Protocol: syscall.IPPROTO_IP,
   187  	}
   188  	var result *syscall.AddrinfoW
   189  	e := syscall.GetAddrInfoW(nil, syscall.StringToUTF16Ptr(service), &hints, &result)
   190  	if e != nil {
   191  		return 0, os.NewSyscallError("GetAddrInfoW", e)
   192  	}
   193  	defer syscall.FreeAddrInfoW(result)
   194  	if result == nil {
   195  		return 0, os.NewSyscallError("LookupPort", syscall.EINVAL)
   196  	}
   197  	addr := unsafe.Pointer(result.Addr)
   198  	switch result.Family {
   199  	case syscall.AF_INET:
   200  		a := (*syscall.RawSockaddrInet4)(addr)
   201  		return int(syscall.Ntohs(a.Port)), nil
   202  	case syscall.AF_INET6:
   203  		a := (*syscall.RawSockaddrInet6)(addr)
   204  		return int(syscall.Ntohs(a.Port)), nil
   205  	}
   206  	return 0, os.NewSyscallError("LookupPort", syscall.EINVAL)
   207  }
   208  
   209  func lookupCNAME(name string) (cname string, err error) {
   210  	acquireThread()
   211  	defer releaseThread()
   212  	var r *syscall.DNSRecord
   213  	e := syscall.DnsQuery(name, syscall.DNS_TYPE_CNAME, 0, nil, &r, nil)
   214  	// windows returns DNS_INFO_NO_RECORDS if there are no CNAME-s
   215  	if errno, ok := e.(syscall.Errno); ok && errno == syscall.DNS_INFO_NO_RECORDS {
   216  		// if there are no aliases, the canonical name is the input name
   217  		if name == "" || name[len(name)-1] != '.' {
   218  			return name + ".", nil
   219  		}
   220  		return name, nil
   221  	}
   222  	if e != nil {
   223  		return "", os.NewSyscallError("LookupCNAME", e)
   224  	}
   225  	defer syscall.DnsRecordListFree(r, 1)
   226  
   227  	resolved := resolveCNAME(syscall.StringToUTF16Ptr(name), r)
   228  	cname = syscall.UTF16ToString((*[256]uint16)(unsafe.Pointer(resolved))[:]) + "."
   229  	return
   230  }
   231  
   232  func lookupSRV(service, proto, name string) (cname string, addrs []*SRV, err error) {
   233  	acquireThread()
   234  	defer releaseThread()
   235  	var target string
   236  	if service == "" && proto == "" {
   237  		target = name
   238  	} else {
   239  		target = "_" + service + "._" + proto + "." + name
   240  	}
   241  	var r *syscall.DNSRecord
   242  	e := syscall.DnsQuery(target, syscall.DNS_TYPE_SRV, 0, nil, &r, nil)
   243  	if e != nil {
   244  		return "", nil, os.NewSyscallError("LookupSRV", e)
   245  	}
   246  	defer syscall.DnsRecordListFree(r, 1)
   247  
   248  	addrs = make([]*SRV, 0, 10)
   249  	for _, p := range validRecs(r, syscall.DNS_TYPE_SRV, target) {
   250  		v := (*syscall.DNSSRVData)(unsafe.Pointer(&p.Data[0]))
   251  		addrs = append(addrs, &SRV{syscall.UTF16ToString((*[256]uint16)(unsafe.Pointer(v.Target))[:]), v.Port, v.Priority, v.Weight})
   252  	}
   253  	byPriorityWeight(addrs).sort()
   254  	return name, addrs, nil
   255  }
   256  
   257  func lookupMX(name string) (mx []*MX, err error) {
   258  	acquireThread()
   259  	defer releaseThread()
   260  	var r *syscall.DNSRecord
   261  	e := syscall.DnsQuery(name, syscall.DNS_TYPE_MX, 0, nil, &r, nil)
   262  	if e != nil {
   263  		return nil, os.NewSyscallError("LookupMX", e)
   264  	}
   265  	defer syscall.DnsRecordListFree(r, 1)
   266  
   267  	mx = make([]*MX, 0, 10)
   268  	for _, p := range validRecs(r, syscall.DNS_TYPE_MX, name) {
   269  		v := (*syscall.DNSMXData)(unsafe.Pointer(&p.Data[0]))
   270  		mx = append(mx, &MX{syscall.UTF16ToString((*[256]uint16)(unsafe.Pointer(v.NameExchange))[:]) + ".", v.Preference})
   271  	}
   272  	byPref(mx).sort()
   273  	return mx, nil
   274  }
   275  
   276  func lookupNS(name string) (ns []*NS, err error) {
   277  	acquireThread()
   278  	defer releaseThread()
   279  	var r *syscall.DNSRecord
   280  	e := syscall.DnsQuery(name, syscall.DNS_TYPE_NS, 0, nil, &r, nil)
   281  	if e != nil {
   282  		return nil, os.NewSyscallError("LookupNS", e)
   283  	}
   284  	defer syscall.DnsRecordListFree(r, 1)
   285  
   286  	ns = make([]*NS, 0, 10)
   287  	for _, p := range validRecs(r, syscall.DNS_TYPE_NS, name) {
   288  		v := (*syscall.DNSPTRData)(unsafe.Pointer(&p.Data[0]))
   289  		ns = append(ns, &NS{syscall.UTF16ToString((*[256]uint16)(unsafe.Pointer(v.Host))[:]) + "."})
   290  	}
   291  	return ns, nil
   292  }
   293  
   294  func lookupTXT(name string) (txt []string, err error) {
   295  	acquireThread()
   296  	defer releaseThread()
   297  	var r *syscall.DNSRecord
   298  	e := syscall.DnsQuery(name, syscall.DNS_TYPE_TEXT, 0, nil, &r, nil)
   299  	if e != nil {
   300  		return nil, os.NewSyscallError("LookupTXT", e)
   301  	}
   302  	defer syscall.DnsRecordListFree(r, 1)
   303  
   304  	txt = make([]string, 0, 10)
   305  	for _, p := range validRecs(r, syscall.DNS_TYPE_TEXT, name) {
   306  		d := (*syscall.DNSTXTData)(unsafe.Pointer(&p.Data[0]))
   307  		for _, v := range (*[1 << 10]*uint16)(unsafe.Pointer(&(d.StringArray[0])))[:d.StringCount] {
   308  			s := syscall.UTF16ToString((*[1 << 20]uint16)(unsafe.Pointer(v))[:])
   309  			txt = append(txt, s)
   310  		}
   311  	}
   312  	return
   313  }
   314  
   315  func lookupAddr(addr string) (name []string, err error) {
   316  	acquireThread()
   317  	defer releaseThread()
   318  	arpa, err := reverseaddr(addr)
   319  	if err != nil {
   320  		return nil, err
   321  	}
   322  	var r *syscall.DNSRecord
   323  	e := syscall.DnsQuery(arpa, syscall.DNS_TYPE_PTR, 0, nil, &r, nil)
   324  	if e != nil {
   325  		return nil, os.NewSyscallError("LookupAddr", e)
   326  	}
   327  	defer syscall.DnsRecordListFree(r, 1)
   328  
   329  	name = make([]string, 0, 10)
   330  	for _, p := range validRecs(r, syscall.DNS_TYPE_PTR, arpa) {
   331  		v := (*syscall.DNSPTRData)(unsafe.Pointer(&p.Data[0]))
   332  		name = append(name, syscall.UTF16ToString((*[256]uint16)(unsafe.Pointer(v.Host))[:]))
   333  	}
   334  	return name, nil
   335  }
   336  
   337  const dnsSectionMask = 0x0003
   338  
   339  // returns only results applicable to name and resolves CNAME entries
   340  func validRecs(r *syscall.DNSRecord, dnstype uint16, name string) []*syscall.DNSRecord {
   341  	cname := syscall.StringToUTF16Ptr(name)
   342  	if dnstype != syscall.DNS_TYPE_CNAME {
   343  		cname = resolveCNAME(cname, r)
   344  	}
   345  	rec := make([]*syscall.DNSRecord, 0, 10)
   346  	for p := r; p != nil; p = p.Next {
   347  		if p.Dw&dnsSectionMask != syscall.DnsSectionAnswer {
   348  			continue
   349  		}
   350  		if p.Type != dnstype {
   351  			continue
   352  		}
   353  		if !syscall.DnsNameCompare(cname, p.Name) {
   354  			continue
   355  		}
   356  		rec = append(rec, p)
   357  	}
   358  	return rec
   359  }
   360  
   361  // returns the last CNAME in chain
   362  func resolveCNAME(name *uint16, r *syscall.DNSRecord) *uint16 {
   363  	// limit cname resolving to 10 in case of a infinite CNAME loop
   364  Cname:
   365  	for cnameloop := 0; cnameloop < 10; cnameloop++ {
   366  		for p := r; p != nil; p = p.Next {
   367  			if p.Dw&dnsSectionMask != syscall.DnsSectionAnswer {
   368  				continue
   369  			}
   370  			if p.Type != syscall.DNS_TYPE_CNAME {
   371  				continue
   372  			}
   373  			if !syscall.DnsNameCompare(name, p.Name) {
   374  				continue
   375  			}
   376  			name = (*syscall.DNSPTRData)(unsafe.Pointer(&r.Data[0])).Host
   377  			continue Cname
   378  		}
   379  		break
   380  	}
   381  	return name
   382  }