github.com/gopacket/gopacket@v1.1.0/routing/routing.go (about)

     1  // Copyright 2012 Google, Inc. All rights reserved.
     2  //
     3  // Use of this source code is governed by a BSD-style license
     4  // that can be found in the LICENSE file in the root of the source
     5  // tree.
     6  
     7  //go:build linux
     8  // +build linux
     9  
    10  // Package routing provides a very basic but mostly functional implementation of
    11  // a routing table for IPv4/IPv6 addresses.  It uses a routing table pulled from
    12  // the kernel via netlink to find the correct interface, gateway, and preferred
    13  // source IP address for packets destined to a particular location.
    14  //
    15  // The routing package is meant to be used with applications that are sending
    16  // raw packet data, which don't have the benefit of having the kernel route
    17  // packets for them.
    18  package routing
    19  
    20  import (
    21  	"bytes"
    22  	"errors"
    23  	"fmt"
    24  	"net"
    25  	"sort"
    26  	"strings"
    27  	"syscall"
    28  	"unsafe"
    29  )
    30  
// routeInfoInMemory mirrors the fixed-size header found at the start of an
// RTM_NEWROUTE netlink payload.  New casts the raw message bytes to this type
// with unsafe, so the order and size of the fields must not be changed.
//
// Pulled from http://man7.org/linux/man-pages/man7/rtnetlink.7.html
// See the section on RTM_NEWROUTE, specifically 'struct rtmsg'.
type routeInfoInMemory struct {
	Family byte // rtm_family: address family; New keeps AF_INET and AF_INET6 routes only.
	DstLen byte // rtm_dst_len: destination prefix length in bits, used as the RTA_DST mask.
	SrcLen byte // rtm_src_len: source prefix length in bits, used as the RTA_SRC mask.
	TOS    byte // rtm_tos per the man page; not read in this file.

	// The remaining fields correspond to rtm_table, rtm_protocol, rtm_scope
	// and rtm_type in the man page.  They are not read in this file and are
	// present to keep the layout identical to struct rtmsg.
	Table    byte
	Protocol byte
	Scope    byte
	Type     byte

	Flags uint32 // rtm_flags per the man page; not read in this file.
}
    46  
    47  // rtInfo contains information on a single route.
    48  type rtInfo struct {
    49  	Src, Dst         *net.IPNet
    50  	Gateway, PrefSrc net.IP
    51  	// We currently ignore the InputIface.
    52  	InputIface, OutputIface uint32
    53  	Priority                uint32
    54  }
    55  
    56  // routeSlice implements sort.Interface to sort routes by Priority.
    57  type routeSlice []*rtInfo
    58  
    59  func (r routeSlice) Len() int {
    60  	return len(r)
    61  }
    62  func (r routeSlice) Less(i, j int) bool {
    63  	return r[i].Priority < r[j].Priority
    64  }
    65  func (r routeSlice) Swap(i, j int) {
    66  	r[i], r[j] = r[j], r[i]
    67  }
    68  
    69  type router struct {
    70  	ifaces map[int]*net.Interface
    71  	addrs  map[int]ipAddrs
    72  	v4, v6 routeSlice
    73  }
    74  
    75  func (r *router) String() string {
    76  	strs := []string{"ROUTER", "--- V4 ---"}
    77  	for _, route := range r.v4 {
    78  		strs = append(strs, fmt.Sprintf("%+v", *route))
    79  	}
    80  	strs = append(strs, "--- V6 ---")
    81  	for _, route := range r.v6 {
    82  		strs = append(strs, fmt.Sprintf("%+v", *route))
    83  	}
    84  	return strings.Join(strs, "\n")
    85  }
    86  
    87  type ipAddrs struct {
    88  	v4, v6 net.IP
    89  }
    90  
    91  func (r *router) Route(dst net.IP) (iface *net.Interface, gateway, preferredSrc net.IP, err error) {
    92  	return r.RouteWithSrc(nil, nil, dst)
    93  }
    94  
    95  func (r *router) RouteWithSrc(input net.HardwareAddr, src, dst net.IP) (iface *net.Interface, gateway, preferredSrc net.IP, err error) {
    96  	var ifaceIndex int
    97  	switch {
    98  	case dst.To4() != nil:
    99  		ifaceIndex, gateway, preferredSrc, err = r.route(r.v4, input, src, dst)
   100  	case dst.To16() != nil:
   101  		ifaceIndex, gateway, preferredSrc, err = r.route(r.v6, input, src, dst)
   102  	default:
   103  		err = errors.New("IP is not valid as IPv4 or IPv6")
   104  	}
   105  
   106  	if err != nil {
   107  		return
   108  	}
   109  
   110  	iface = r.ifaces[ifaceIndex]
   111  
   112  	if preferredSrc == nil {
   113  		switch {
   114  		case dst.To4() != nil:
   115  			preferredSrc = r.addrs[ifaceIndex].v4
   116  		case dst.To16() != nil:
   117  			preferredSrc = r.addrs[ifaceIndex].v6
   118  		}
   119  	}
   120  	return
   121  }
   122  
   123  func (r *router) route(routes routeSlice, input net.HardwareAddr, src, dst net.IP) (iface int, gateway, preferredSrc net.IP, err error) {
   124  	var inputIndex uint32
   125  	if input != nil {
   126  		for i, iface := range r.ifaces {
   127  			if bytes.Equal(input, iface.HardwareAddr) {
   128  				inputIndex = uint32(i)
   129  				break
   130  			}
   131  		}
   132  	}
   133  	var defaultGateway *rtInfo = nil
   134  	for _, rt := range routes {
   135  		if rt.InputIface != 0 && rt.InputIface != inputIndex {
   136  			continue
   137  		}
   138  		if rt.Src == nil && rt.Dst == nil {
   139  			defaultGateway = rt
   140  			continue
   141  		}
   142  		if rt.Src != nil && !rt.Src.Contains(src) {
   143  			continue
   144  		}
   145  		if rt.Dst != nil && !rt.Dst.Contains(dst) {
   146  			continue
   147  		}
   148  		return int(rt.OutputIface), rt.Gateway, rt.PrefSrc, nil
   149  	}
   150  
   151  	if defaultGateway != nil {
   152  		return int(defaultGateway.OutputIface), defaultGateway.Gateway, defaultGateway.PrefSrc, nil
   153  	}
   154  	err = fmt.Errorf("no route found for %v", dst)
   155  	return
   156  }
   157  
   158  // New creates a new router object.  The router returned by New currently does
   159  // not update its routes after construction... care should be taken for
   160  // long-running programs to call New() regularly to take into account any
   161  // changes to the routing table which have occurred since the last New() call.
   162  func New() (Router, error) {
   163  	rtr := &router{
   164  		ifaces: make(map[int]*net.Interface),
   165  		addrs:  make(map[int]ipAddrs),
   166  	}
   167  	tab, err := syscall.NetlinkRIB(syscall.RTM_GETROUTE, syscall.AF_UNSPEC)
   168  	if err != nil {
   169  		return nil, err
   170  	}
   171  	msgs, err := syscall.ParseNetlinkMessage(tab)
   172  	if err != nil {
   173  		return nil, err
   174  	}
   175  loop:
   176  	for _, m := range msgs {
   177  		switch m.Header.Type {
   178  		case syscall.NLMSG_DONE:
   179  			break loop
   180  		case syscall.RTM_NEWROUTE:
   181  			rt := (*routeInfoInMemory)(unsafe.Pointer(&m.Data[0]))
   182  			routeInfo := rtInfo{}
   183  			attrs, err := syscall.ParseNetlinkRouteAttr(&m)
   184  			if err != nil {
   185  				return nil, err
   186  			}
   187  			switch rt.Family {
   188  			case syscall.AF_INET:
   189  				rtr.v4 = append(rtr.v4, &routeInfo)
   190  			case syscall.AF_INET6:
   191  				rtr.v6 = append(rtr.v6, &routeInfo)
   192  			default:
   193  				continue loop
   194  			}
   195  			for _, attr := range attrs {
   196  				switch attr.Attr.Type {
   197  				case syscall.RTA_DST:
   198  					routeInfo.Dst = &net.IPNet{
   199  						IP:   net.IP(attr.Value),
   200  						Mask: net.CIDRMask(int(rt.DstLen), len(attr.Value)*8),
   201  					}
   202  				case syscall.RTA_SRC:
   203  					routeInfo.Src = &net.IPNet{
   204  						IP:   net.IP(attr.Value),
   205  						Mask: net.CIDRMask(int(rt.SrcLen), len(attr.Value)*8),
   206  					}
   207  				case syscall.RTA_GATEWAY:
   208  					routeInfo.Gateway = net.IP(attr.Value)
   209  				case syscall.RTA_PREFSRC:
   210  					routeInfo.PrefSrc = net.IP(attr.Value)
   211  				case syscall.RTA_IIF:
   212  					routeInfo.InputIface = *(*uint32)(unsafe.Pointer(&attr.Value[0]))
   213  				case syscall.RTA_OIF:
   214  					routeInfo.OutputIface = *(*uint32)(unsafe.Pointer(&attr.Value[0]))
   215  				case syscall.RTA_PRIORITY:
   216  					routeInfo.Priority = *(*uint32)(unsafe.Pointer(&attr.Value[0]))
   217  				}
   218  			}
   219  		}
   220  	}
   221  	sort.Sort(rtr.v4)
   222  	sort.Sort(rtr.v6)
   223  	ifaces, err := net.Interfaces()
   224  	if err != nil {
   225  		return nil, err
   226  	}
   227  	for _, tmp := range ifaces {
   228  		iface := tmp
   229  		rtr.ifaces[iface.Index] = &iface
   230  		var addrs ipAddrs
   231  		ifaceAddrs, err := iface.Addrs()
   232  		if err != nil {
   233  			return nil, err
   234  		}
   235  		for _, addr := range ifaceAddrs {
   236  			if inet, ok := addr.(*net.IPNet); ok {
   237  				// Go has a nasty habit of giving you IPv4s as ::ffff:1.2.3.4 instead of 1.2.3.4.
   238  				// We want to use mapped v4 addresses as v4 preferred addresses, never as v6
   239  				// preferred addresses.
   240  				if v4 := inet.IP.To4(); v4 != nil {
   241  					if addrs.v4 == nil {
   242  						addrs.v4 = v4
   243  					}
   244  				} else if addrs.v6 == nil {
   245  					addrs.v6 = inet.IP
   246  				}
   247  			}
   248  		}
   249  		rtr.addrs[iface.Index] = addrs
   250  	}
   251  	return rtr, nil
   252  }