github.com/shijuvar/go@v0.0.0-20141209052335-e8f13700b70c/src/net/lookup_windows.go (about)

     1  // Copyright 2009 The Go Authors. All rights reserved.
     2  // Use of this source code is governed by a BSD-style
     3  // license that can be found in the LICENSE file.
     4  
     5  package net
     6  
     7  import (
     8  	"os"
     9  	"runtime"
    10  	"syscall"
    11  	"unsafe"
    12  )
    13  
    14  var (
    15  	lookupPort = oldLookupPort
    16  	lookupIP   = oldLookupIP
    17  )
    18  
    19  func getprotobyname(name string) (proto int, err error) {
    20  	p, err := syscall.GetProtoByName(name)
    21  	if err != nil {
    22  		return 0, os.NewSyscallError("GetProtoByName", err)
    23  	}
    24  	return int(p.Proto), nil
    25  }
    26  
    27  // lookupProtocol looks up IP protocol name and returns correspondent protocol number.
    28  func lookupProtocol(name string) (proto int, err error) {
    29  	// GetProtoByName return value is stored in thread local storage.
    30  	// Start new os thread before the call to prevent races.
    31  	type result struct {
    32  		proto int
    33  		err   error
    34  	}
    35  	ch := make(chan result)
    36  	go func() {
    37  		acquireThread()
    38  		defer releaseThread()
    39  		runtime.LockOSThread()
    40  		defer runtime.UnlockOSThread()
    41  		proto, err := getprotobyname(name)
    42  		ch <- result{proto: proto, err: err}
    43  	}()
    44  	r := <-ch
    45  	if r.err != nil {
    46  		if proto, ok := protocols[name]; ok {
    47  			return proto, nil
    48  		}
    49  	}
    50  	return r.proto, r.err
    51  }
    52  
    53  func lookupHost(name string) (addrs []string, err error) {
    54  	ips, err := LookupIP(name)
    55  	if err != nil {
    56  		return
    57  	}
    58  	addrs = make([]string, 0, len(ips))
    59  	for _, ip := range ips {
    60  		addrs = append(addrs, ip.String())
    61  	}
    62  	return
    63  }
    64  
    65  func gethostbyname(name string) (addrs []IP, err error) {
    66  	// caller already acquired thread
    67  	h, err := syscall.GetHostByName(name)
    68  	if err != nil {
    69  		return nil, os.NewSyscallError("GetHostByName", err)
    70  	}
    71  	switch h.AddrType {
    72  	case syscall.AF_INET:
    73  		i := 0
    74  		addrs = make([]IP, 100) // plenty of room to grow
    75  		for p := (*[100](*[4]byte))(unsafe.Pointer(h.AddrList)); i < cap(addrs) && p[i] != nil; i++ {
    76  			addrs[i] = IPv4(p[i][0], p[i][1], p[i][2], p[i][3])
    77  		}
    78  		addrs = addrs[0:i]
    79  	default: // TODO(vcc): Implement non IPv4 address lookups.
    80  		return nil, os.NewSyscallError("LookupIP", syscall.EWINDOWS)
    81  	}
    82  	return addrs, nil
    83  }
    84  
    85  func oldLookupIP(name string) (addrs []IP, err error) {
    86  	// GetHostByName return value is stored in thread local storage.
    87  	// Start new os thread before the call to prevent races.
    88  	type result struct {
    89  		addrs []IP
    90  		err   error
    91  	}
    92  	ch := make(chan result)
    93  	go func() {
    94  		acquireThread()
    95  		defer releaseThread()
    96  		runtime.LockOSThread()
    97  		defer runtime.UnlockOSThread()
    98  		addrs, err := gethostbyname(name)
    99  		ch <- result{addrs: addrs, err: err}
   100  	}()
   101  	r := <-ch
   102  	return r.addrs, r.err
   103  }
   104  
   105  func newLookupIP(name string) (addrs []IP, err error) {
   106  	acquireThread()
   107  	defer releaseThread()
   108  	hints := syscall.AddrinfoW{
   109  		Family:   syscall.AF_UNSPEC,
   110  		Socktype: syscall.SOCK_STREAM,
   111  		Protocol: syscall.IPPROTO_IP,
   112  	}
   113  	var result *syscall.AddrinfoW
   114  	e := syscall.GetAddrInfoW(syscall.StringToUTF16Ptr(name), nil, &hints, &result)
   115  	if e != nil {
   116  		return nil, os.NewSyscallError("GetAddrInfoW", e)
   117  	}
   118  	defer syscall.FreeAddrInfoW(result)
   119  	addrs = make([]IP, 0, 5)
   120  	for ; result != nil; result = result.Next {
   121  		addr := unsafe.Pointer(result.Addr)
   122  		switch result.Family {
   123  		case syscall.AF_INET:
   124  			a := (*syscall.RawSockaddrInet4)(addr).Addr
   125  			addrs = append(addrs, IPv4(a[0], a[1], a[2], a[3]))
   126  		case syscall.AF_INET6:
   127  			a := (*syscall.RawSockaddrInet6)(addr).Addr
   128  			addrs = append(addrs, IP{a[0], a[1], a[2], a[3], a[4], a[5], a[6], a[7], a[8], a[9], a[10], a[11], a[12], a[13], a[14], a[15]})
   129  		default:
   130  			return nil, os.NewSyscallError("LookupIP", syscall.EWINDOWS)
   131  		}
   132  	}
   133  	return addrs, nil
   134  }
   135  
   136  func getservbyname(network, service string) (port int, err error) {
   137  	acquireThread()
   138  	defer releaseThread()
   139  	switch network {
   140  	case "tcp4", "tcp6":
   141  		network = "tcp"
   142  	case "udp4", "udp6":
   143  		network = "udp"
   144  	}
   145  	s, err := syscall.GetServByName(service, network)
   146  	if err != nil {
   147  		return 0, os.NewSyscallError("GetServByName", err)
   148  	}
   149  	return int(syscall.Ntohs(s.Port)), nil
   150  }
   151  
   152  func oldLookupPort(network, service string) (port int, err error) {
   153  	// GetServByName return value is stored in thread local storage.
   154  	// Start new os thread before the call to prevent races.
   155  	type result struct {
   156  		port int
   157  		err  error
   158  	}
   159  	ch := make(chan result)
   160  	go func() {
   161  		acquireThread()
   162  		defer releaseThread()
   163  		runtime.LockOSThread()
   164  		defer runtime.UnlockOSThread()
   165  		port, err := getservbyname(network, service)
   166  		ch <- result{port: port, err: err}
   167  	}()
   168  	r := <-ch
   169  	return r.port, r.err
   170  }
   171  
   172  func newLookupPort(network, service string) (port int, err error) {
   173  	acquireThread()
   174  	defer releaseThread()
   175  	var stype int32
   176  	switch network {
   177  	case "tcp4", "tcp6":
   178  		stype = syscall.SOCK_STREAM
   179  	case "udp4", "udp6":
   180  		stype = syscall.SOCK_DGRAM
   181  	}
   182  	hints := syscall.AddrinfoW{
   183  		Family:   syscall.AF_UNSPEC,
   184  		Socktype: stype,
   185  		Protocol: syscall.IPPROTO_IP,
   186  	}
   187  	var result *syscall.AddrinfoW
   188  	e := syscall.GetAddrInfoW(nil, syscall.StringToUTF16Ptr(service), &hints, &result)
   189  	if e != nil {
   190  		return 0, os.NewSyscallError("GetAddrInfoW", e)
   191  	}
   192  	defer syscall.FreeAddrInfoW(result)
   193  	if result == nil {
   194  		return 0, os.NewSyscallError("LookupPort", syscall.EINVAL)
   195  	}
   196  	addr := unsafe.Pointer(result.Addr)
   197  	switch result.Family {
   198  	case syscall.AF_INET:
   199  		a := (*syscall.RawSockaddrInet4)(addr)
   200  		return int(syscall.Ntohs(a.Port)), nil
   201  	case syscall.AF_INET6:
   202  		a := (*syscall.RawSockaddrInet6)(addr)
   203  		return int(syscall.Ntohs(a.Port)), nil
   204  	}
   205  	return 0, os.NewSyscallError("LookupPort", syscall.EINVAL)
   206  }
   207  
   208  func lookupCNAME(name string) (cname string, err error) {
   209  	acquireThread()
   210  	defer releaseThread()
   211  	var r *syscall.DNSRecord
   212  	e := syscall.DnsQuery(name, syscall.DNS_TYPE_CNAME, 0, nil, &r, nil)
   213  	// windows returns DNS_INFO_NO_RECORDS if there are no CNAME-s
   214  	if errno, ok := e.(syscall.Errno); ok && errno == syscall.DNS_INFO_NO_RECORDS {
   215  		// if there are no aliases, the canonical name is the input name
   216  		if name == "" || name[len(name)-1] != '.' {
   217  			return name + ".", nil
   218  		}
   219  		return name, nil
   220  	}
   221  	if e != nil {
   222  		return "", os.NewSyscallError("LookupCNAME", e)
   223  	}
   224  	defer syscall.DnsRecordListFree(r, 1)
   225  
   226  	resolved := resolveCNAME(syscall.StringToUTF16Ptr(name), r)
   227  	cname = syscall.UTF16ToString((*[256]uint16)(unsafe.Pointer(resolved))[:]) + "."
   228  	return
   229  }
   230  
   231  func lookupSRV(service, proto, name string) (cname string, addrs []*SRV, err error) {
   232  	acquireThread()
   233  	defer releaseThread()
   234  	var target string
   235  	if service == "" && proto == "" {
   236  		target = name
   237  	} else {
   238  		target = "_" + service + "._" + proto + "." + name
   239  	}
   240  	var r *syscall.DNSRecord
   241  	e := syscall.DnsQuery(target, syscall.DNS_TYPE_SRV, 0, nil, &r, nil)
   242  	if e != nil {
   243  		return "", nil, os.NewSyscallError("LookupSRV", e)
   244  	}
   245  	defer syscall.DnsRecordListFree(r, 1)
   246  
   247  	addrs = make([]*SRV, 0, 10)
   248  	for _, p := range validRecs(r, syscall.DNS_TYPE_SRV, target) {
   249  		v := (*syscall.DNSSRVData)(unsafe.Pointer(&p.Data[0]))
   250  		addrs = append(addrs, &SRV{syscall.UTF16ToString((*[256]uint16)(unsafe.Pointer(v.Target))[:]), v.Port, v.Priority, v.Weight})
   251  	}
   252  	byPriorityWeight(addrs).sort()
   253  	return name, addrs, nil
   254  }
   255  
   256  func lookupMX(name string) (mx []*MX, err error) {
   257  	acquireThread()
   258  	defer releaseThread()
   259  	var r *syscall.DNSRecord
   260  	e := syscall.DnsQuery(name, syscall.DNS_TYPE_MX, 0, nil, &r, nil)
   261  	if e != nil {
   262  		return nil, os.NewSyscallError("LookupMX", e)
   263  	}
   264  	defer syscall.DnsRecordListFree(r, 1)
   265  
   266  	mx = make([]*MX, 0, 10)
   267  	for _, p := range validRecs(r, syscall.DNS_TYPE_MX, name) {
   268  		v := (*syscall.DNSMXData)(unsafe.Pointer(&p.Data[0]))
   269  		mx = append(mx, &MX{syscall.UTF16ToString((*[256]uint16)(unsafe.Pointer(v.NameExchange))[:]) + ".", v.Preference})
   270  	}
   271  	byPref(mx).sort()
   272  	return mx, nil
   273  }
   274  
   275  func lookupNS(name string) (ns []*NS, err error) {
   276  	acquireThread()
   277  	defer releaseThread()
   278  	var r *syscall.DNSRecord
   279  	e := syscall.DnsQuery(name, syscall.DNS_TYPE_NS, 0, nil, &r, nil)
   280  	if e != nil {
   281  		return nil, os.NewSyscallError("LookupNS", e)
   282  	}
   283  	defer syscall.DnsRecordListFree(r, 1)
   284  
   285  	ns = make([]*NS, 0, 10)
   286  	for _, p := range validRecs(r, syscall.DNS_TYPE_NS, name) {
   287  		v := (*syscall.DNSPTRData)(unsafe.Pointer(&p.Data[0]))
   288  		ns = append(ns, &NS{syscall.UTF16ToString((*[256]uint16)(unsafe.Pointer(v.Host))[:]) + "."})
   289  	}
   290  	return ns, nil
   291  }
   292  
   293  func lookupTXT(name string) (txt []string, err error) {
   294  	acquireThread()
   295  	defer releaseThread()
   296  	var r *syscall.DNSRecord
   297  	e := syscall.DnsQuery(name, syscall.DNS_TYPE_TEXT, 0, nil, &r, nil)
   298  	if e != nil {
   299  		return nil, os.NewSyscallError("LookupTXT", e)
   300  	}
   301  	defer syscall.DnsRecordListFree(r, 1)
   302  
   303  	txt = make([]string, 0, 10)
   304  	for _, p := range validRecs(r, syscall.DNS_TYPE_TEXT, name) {
   305  		d := (*syscall.DNSTXTData)(unsafe.Pointer(&p.Data[0]))
   306  		for _, v := range (*[1 << 10]*uint16)(unsafe.Pointer(&(d.StringArray[0])))[:d.StringCount] {
   307  			s := syscall.UTF16ToString((*[1 << 20]uint16)(unsafe.Pointer(v))[:])
   308  			txt = append(txt, s)
   309  		}
   310  	}
   311  	return
   312  }
   313  
   314  func lookupAddr(addr string) (name []string, err error) {
   315  	acquireThread()
   316  	defer releaseThread()
   317  	arpa, err := reverseaddr(addr)
   318  	if err != nil {
   319  		return nil, err
   320  	}
   321  	var r *syscall.DNSRecord
   322  	e := syscall.DnsQuery(arpa, syscall.DNS_TYPE_PTR, 0, nil, &r, nil)
   323  	if e != nil {
   324  		return nil, os.NewSyscallError("LookupAddr", e)
   325  	}
   326  	defer syscall.DnsRecordListFree(r, 1)
   327  
   328  	name = make([]string, 0, 10)
   329  	for _, p := range validRecs(r, syscall.DNS_TYPE_PTR, arpa) {
   330  		v := (*syscall.DNSPTRData)(unsafe.Pointer(&p.Data[0]))
   331  		name = append(name, syscall.UTF16ToString((*[256]uint16)(unsafe.Pointer(v.Host))[:]))
   332  	}
   333  	return name, nil
   334  }
   335  
   336  const dnsSectionMask = 0x0003
   337  
   338  // returns only results applicable to name and resolves CNAME entries
   339  func validRecs(r *syscall.DNSRecord, dnstype uint16, name string) []*syscall.DNSRecord {
   340  	cname := syscall.StringToUTF16Ptr(name)
   341  	if dnstype != syscall.DNS_TYPE_CNAME {
   342  		cname = resolveCNAME(cname, r)
   343  	}
   344  	rec := make([]*syscall.DNSRecord, 0, 10)
   345  	for p := r; p != nil; p = p.Next {
   346  		if p.Dw&dnsSectionMask != syscall.DnsSectionAnswer {
   347  			continue
   348  		}
   349  		if p.Type != dnstype {
   350  			continue
   351  		}
   352  		if !syscall.DnsNameCompare(cname, p.Name) {
   353  			continue
   354  		}
   355  		rec = append(rec, p)
   356  	}
   357  	return rec
   358  }
   359  
   360  // returns the last CNAME in chain
   361  func resolveCNAME(name *uint16, r *syscall.DNSRecord) *uint16 {
   362  	// limit cname resolving to 10 in case of a infinite CNAME loop
   363  Cname:
   364  	for cnameloop := 0; cnameloop < 10; cnameloop++ {
   365  		for p := r; p != nil; p = p.Next {
   366  			if p.Dw&dnsSectionMask != syscall.DnsSectionAnswer {
   367  				continue
   368  			}
   369  			if p.Type != syscall.DNS_TYPE_CNAME {
   370  				continue
   371  			}
   372  			if !syscall.DnsNameCompare(name, p.Name) {
   373  				continue
   374  			}
   375  			name = (*syscall.DNSPTRData)(unsafe.Pointer(&r.Data[0])).Host
   376  			continue Cname
   377  		}
   378  		break
   379  	}
   380  	return name
   381  }