github.com/sagernet/netlink@v0.0.0-20240612041022-b9a21c07ac6a/route_linux.go (about)

     1  package netlink
     2  
     3  import (
     4  	"bytes"
     5  	"encoding/binary"
     6  	"fmt"
     7  	"net"
     8  	"strconv"
     9  	"strings"
    10  	"syscall"
    11  
    12  	"github.com/sagernet/netlink/nl"
    13  	"github.com/vishvananda/netns"
    14  	"golang.org/x/sys/unix"
    15  )
    16  
    17  // RtAttr is shared so it is in netlink_linux.go
    18  
    19  const (
    20  	SCOPE_UNIVERSE Scope = unix.RT_SCOPE_UNIVERSE
    21  	SCOPE_SITE     Scope = unix.RT_SCOPE_SITE
    22  	SCOPE_LINK     Scope = unix.RT_SCOPE_LINK
    23  	SCOPE_HOST     Scope = unix.RT_SCOPE_HOST
    24  	SCOPE_NOWHERE  Scope = unix.RT_SCOPE_NOWHERE
    25  )
    26  
    27  func (s Scope) String() string {
    28  	switch s {
    29  	case SCOPE_UNIVERSE:
    30  		return "universe"
    31  	case SCOPE_SITE:
    32  		return "site"
    33  	case SCOPE_LINK:
    34  		return "link"
    35  	case SCOPE_HOST:
    36  		return "host"
    37  	case SCOPE_NOWHERE:
    38  		return "nowhere"
    39  	default:
    40  		return "unknown"
    41  	}
    42  }
    43  
    44  const (
    45  	FLAG_ONLINK    NextHopFlag = unix.RTNH_F_ONLINK
    46  	FLAG_PERVASIVE NextHopFlag = unix.RTNH_F_PERVASIVE
    47  )
    48  
    49  var testFlags = []flagString{
    50  	{f: FLAG_ONLINK, s: "onlink"},
    51  	{f: FLAG_PERVASIVE, s: "pervasive"},
    52  }
    53  
    54  func listFlags(flag int) []string {
    55  	var flags []string
    56  	for _, tf := range testFlags {
    57  		if flag&int(tf.f) != 0 {
    58  			flags = append(flags, tf.s)
    59  		}
    60  	}
    61  	return flags
    62  }
    63  
    64  func (r *Route) ListFlags() []string {
    65  	return listFlags(r.Flags)
    66  }
    67  
    68  func (n *NexthopInfo) ListFlags() []string {
    69  	return listFlags(n.Flags)
    70  }
    71  
    72  type MPLSDestination struct {
    73  	Labels []int
    74  }
    75  
    76  func (d *MPLSDestination) Family() int {
    77  	return nl.FAMILY_MPLS
    78  }
    79  
    80  func (d *MPLSDestination) Decode(buf []byte) error {
    81  	d.Labels = nl.DecodeMPLSStack(buf)
    82  	return nil
    83  }
    84  
    85  func (d *MPLSDestination) Encode() ([]byte, error) {
    86  	return nl.EncodeMPLSStack(d.Labels...), nil
    87  }
    88  
    89  func (d *MPLSDestination) String() string {
    90  	s := make([]string, 0, len(d.Labels))
    91  	for _, l := range d.Labels {
    92  		s = append(s, fmt.Sprintf("%d", l))
    93  	}
    94  	return strings.Join(s, "/")
    95  }
    96  
    97  func (d *MPLSDestination) Equal(x Destination) bool {
    98  	o, ok := x.(*MPLSDestination)
    99  	if !ok {
   100  		return false
   101  	}
   102  	if d == nil && o == nil {
   103  		return true
   104  	}
   105  	if d == nil || o == nil {
   106  		return false
   107  	}
   108  	if d.Labels == nil && o.Labels == nil {
   109  		return true
   110  	}
   111  	if d.Labels == nil || o.Labels == nil {
   112  		return false
   113  	}
   114  	if len(d.Labels) != len(o.Labels) {
   115  		return false
   116  	}
   117  	for i := range d.Labels {
   118  		if d.Labels[i] != o.Labels[i] {
   119  			return false
   120  		}
   121  	}
   122  	return true
   123  }
   124  
   125  type MPLSEncap struct {
   126  	Labels []int
   127  }
   128  
   129  func (e *MPLSEncap) Type() int {
   130  	return nl.LWTUNNEL_ENCAP_MPLS
   131  }
   132  
   133  func (e *MPLSEncap) Decode(buf []byte) error {
   134  	if len(buf) < 4 {
   135  		return fmt.Errorf("lack of bytes")
   136  	}
   137  	l := native.Uint16(buf)
   138  	if len(buf) < int(l) {
   139  		return fmt.Errorf("lack of bytes")
   140  	}
   141  	buf = buf[:l]
   142  	typ := native.Uint16(buf[2:])
   143  	if typ != nl.MPLS_IPTUNNEL_DST {
   144  		return fmt.Errorf("unknown MPLS Encap Type: %d", typ)
   145  	}
   146  	e.Labels = nl.DecodeMPLSStack(buf[4:])
   147  	return nil
   148  }
   149  
   150  func (e *MPLSEncap) Encode() ([]byte, error) {
   151  	s := nl.EncodeMPLSStack(e.Labels...)
   152  	hdr := make([]byte, 4)
   153  	native.PutUint16(hdr, uint16(len(s)+4))
   154  	native.PutUint16(hdr[2:], nl.MPLS_IPTUNNEL_DST)
   155  	return append(hdr, s...), nil
   156  }
   157  
   158  func (e *MPLSEncap) String() string {
   159  	s := make([]string, 0, len(e.Labels))
   160  	for _, l := range e.Labels {
   161  		s = append(s, fmt.Sprintf("%d", l))
   162  	}
   163  	return strings.Join(s, "/")
   164  }
   165  
   166  func (e *MPLSEncap) Equal(x Encap) bool {
   167  	o, ok := x.(*MPLSEncap)
   168  	if !ok {
   169  		return false
   170  	}
   171  	if e == nil && o == nil {
   172  		return true
   173  	}
   174  	if e == nil || o == nil {
   175  		return false
   176  	}
   177  	if e.Labels == nil && o.Labels == nil {
   178  		return true
   179  	}
   180  	if e.Labels == nil || o.Labels == nil {
   181  		return false
   182  	}
   183  	if len(e.Labels) != len(o.Labels) {
   184  		return false
   185  	}
   186  	for i := range e.Labels {
   187  		if e.Labels[i] != o.Labels[i] {
   188  			return false
   189  		}
   190  	}
   191  	return true
   192  }
   193  
   194  // SEG6 definitions
   195  type SEG6Encap struct {
   196  	Mode     int
   197  	Segments []net.IP
   198  }
   199  
   200  func (e *SEG6Encap) Type() int {
   201  	return nl.LWTUNNEL_ENCAP_SEG6
   202  }
   203  func (e *SEG6Encap) Decode(buf []byte) error {
   204  	if len(buf) < 4 {
   205  		return fmt.Errorf("lack of bytes")
   206  	}
   207  	// Get Length(l) & Type(typ) : 2 + 2 bytes
   208  	l := native.Uint16(buf)
   209  	if len(buf) < int(l) {
   210  		return fmt.Errorf("lack of bytes")
   211  	}
   212  	buf = buf[:l] // make sure buf size upper limit is Length
   213  	typ := native.Uint16(buf[2:])
   214  	// LWTUNNEL_ENCAP_SEG6 has only one attr type SEG6_IPTUNNEL_SRH
   215  	if typ != nl.SEG6_IPTUNNEL_SRH {
   216  		return fmt.Errorf("unknown SEG6 Type: %d", typ)
   217  	}
   218  
   219  	var err error
   220  	e.Mode, e.Segments, err = nl.DecodeSEG6Encap(buf[4:])
   221  
   222  	return err
   223  }
   224  func (e *SEG6Encap) Encode() ([]byte, error) {
   225  	s, err := nl.EncodeSEG6Encap(e.Mode, e.Segments)
   226  	hdr := make([]byte, 4)
   227  	native.PutUint16(hdr, uint16(len(s)+4))
   228  	native.PutUint16(hdr[2:], nl.SEG6_IPTUNNEL_SRH)
   229  	return append(hdr, s...), err
   230  }
   231  func (e *SEG6Encap) String() string {
   232  	segs := make([]string, 0, len(e.Segments))
   233  	// append segment backwards (from n to 0) since seg#0 is the last segment.
   234  	for i := len(e.Segments); i > 0; i-- {
   235  		segs = append(segs, e.Segments[i-1].String())
   236  	}
   237  	str := fmt.Sprintf("mode %s segs %d [ %s ]", nl.SEG6EncapModeString(e.Mode),
   238  		len(e.Segments), strings.Join(segs, " "))
   239  	return str
   240  }
   241  func (e *SEG6Encap) Equal(x Encap) bool {
   242  	o, ok := x.(*SEG6Encap)
   243  	if !ok {
   244  		return false
   245  	}
   246  	if e == o {
   247  		return true
   248  	}
   249  	if e == nil || o == nil {
   250  		return false
   251  	}
   252  	if e.Mode != o.Mode {
   253  		return false
   254  	}
   255  	if len(e.Segments) != len(o.Segments) {
   256  		return false
   257  	}
   258  	for i := range e.Segments {
   259  		if !e.Segments[i].Equal(o.Segments[i]) {
   260  			return false
   261  		}
   262  	}
   263  	return true
   264  }
   265  
   266  // SEG6LocalEncap definitions
   267  type SEG6LocalEncap struct {
   268  	Flags    [nl.SEG6_LOCAL_MAX]bool
   269  	Action   int
   270  	Segments []net.IP // from SRH in seg6_local_lwt
   271  	Table    int      // table id for End.T and End.DT6
   272  	InAddr   net.IP
   273  	In6Addr  net.IP
   274  	Iif      int
   275  	Oif      int
   276  }
   277  
   278  func (e *SEG6LocalEncap) Type() int {
   279  	return nl.LWTUNNEL_ENCAP_SEG6_LOCAL
   280  }
   281  func (e *SEG6LocalEncap) Decode(buf []byte) error {
   282  	attrs, err := nl.ParseRouteAttr(buf)
   283  	if err != nil {
   284  		return err
   285  	}
   286  	for _, attr := range attrs {
   287  		switch attr.Attr.Type {
   288  		case nl.SEG6_LOCAL_ACTION:
   289  			e.Action = int(native.Uint32(attr.Value[0:4]))
   290  			e.Flags[nl.SEG6_LOCAL_ACTION] = true
   291  		case nl.SEG6_LOCAL_SRH:
   292  			e.Segments, err = nl.DecodeSEG6Srh(attr.Value[:])
   293  			e.Flags[nl.SEG6_LOCAL_SRH] = true
   294  		case nl.SEG6_LOCAL_TABLE:
   295  			e.Table = int(native.Uint32(attr.Value[0:4]))
   296  			e.Flags[nl.SEG6_LOCAL_TABLE] = true
   297  		case nl.SEG6_LOCAL_NH4:
   298  			e.InAddr = net.IP(attr.Value[0:4])
   299  			e.Flags[nl.SEG6_LOCAL_NH4] = true
   300  		case nl.SEG6_LOCAL_NH6:
   301  			e.In6Addr = net.IP(attr.Value[0:16])
   302  			e.Flags[nl.SEG6_LOCAL_NH6] = true
   303  		case nl.SEG6_LOCAL_IIF:
   304  			e.Iif = int(native.Uint32(attr.Value[0:4]))
   305  			e.Flags[nl.SEG6_LOCAL_IIF] = true
   306  		case nl.SEG6_LOCAL_OIF:
   307  			e.Oif = int(native.Uint32(attr.Value[0:4]))
   308  			e.Flags[nl.SEG6_LOCAL_OIF] = true
   309  		}
   310  	}
   311  	return err
   312  }
   313  func (e *SEG6LocalEncap) Encode() ([]byte, error) {
   314  	var err error
   315  	res := make([]byte, 8)
   316  	native.PutUint16(res, 8) // length
   317  	native.PutUint16(res[2:], nl.SEG6_LOCAL_ACTION)
   318  	native.PutUint32(res[4:], uint32(e.Action))
   319  	if e.Flags[nl.SEG6_LOCAL_SRH] {
   320  		srh, err := nl.EncodeSEG6Srh(e.Segments)
   321  		if err != nil {
   322  			return nil, err
   323  		}
   324  		attr := make([]byte, 4)
   325  		native.PutUint16(attr, uint16(len(srh)+4))
   326  		native.PutUint16(attr[2:], nl.SEG6_LOCAL_SRH)
   327  		attr = append(attr, srh...)
   328  		res = append(res, attr...)
   329  	}
   330  	if e.Flags[nl.SEG6_LOCAL_TABLE] {
   331  		attr := make([]byte, 8)
   332  		native.PutUint16(attr, 8)
   333  		native.PutUint16(attr[2:], nl.SEG6_LOCAL_TABLE)
   334  		native.PutUint32(attr[4:], uint32(e.Table))
   335  		res = append(res, attr...)
   336  	}
   337  	if e.Flags[nl.SEG6_LOCAL_NH4] {
   338  		attr := make([]byte, 4)
   339  		native.PutUint16(attr, 8)
   340  		native.PutUint16(attr[2:], nl.SEG6_LOCAL_NH4)
   341  		ipv4 := e.InAddr.To4()
   342  		if ipv4 == nil {
   343  			err = fmt.Errorf("SEG6_LOCAL_NH4 has invalid IPv4 address")
   344  			return nil, err
   345  		}
   346  		attr = append(attr, ipv4...)
   347  		res = append(res, attr...)
   348  	}
   349  	if e.Flags[nl.SEG6_LOCAL_NH6] {
   350  		attr := make([]byte, 4)
   351  		native.PutUint16(attr, 20)
   352  		native.PutUint16(attr[2:], nl.SEG6_LOCAL_NH6)
   353  		attr = append(attr, e.In6Addr...)
   354  		res = append(res, attr...)
   355  	}
   356  	if e.Flags[nl.SEG6_LOCAL_IIF] {
   357  		attr := make([]byte, 8)
   358  		native.PutUint16(attr, 8)
   359  		native.PutUint16(attr[2:], nl.SEG6_LOCAL_IIF)
   360  		native.PutUint32(attr[4:], uint32(e.Iif))
   361  		res = append(res, attr...)
   362  	}
   363  	if e.Flags[nl.SEG6_LOCAL_OIF] {
   364  		attr := make([]byte, 8)
   365  		native.PutUint16(attr, 8)
   366  		native.PutUint16(attr[2:], nl.SEG6_LOCAL_OIF)
   367  		native.PutUint32(attr[4:], uint32(e.Oif))
   368  		res = append(res, attr...)
   369  	}
   370  	return res, err
   371  }
   372  func (e *SEG6LocalEncap) String() string {
   373  	strs := make([]string, 0, nl.SEG6_LOCAL_MAX)
   374  	strs = append(strs, fmt.Sprintf("action %s", nl.SEG6LocalActionString(e.Action)))
   375  
   376  	if e.Flags[nl.SEG6_LOCAL_TABLE] {
   377  		strs = append(strs, fmt.Sprintf("table %d", e.Table))
   378  	}
   379  	if e.Flags[nl.SEG6_LOCAL_NH4] {
   380  		strs = append(strs, fmt.Sprintf("nh4 %s", e.InAddr))
   381  	}
   382  	if e.Flags[nl.SEG6_LOCAL_NH6] {
   383  		strs = append(strs, fmt.Sprintf("nh6 %s", e.In6Addr))
   384  	}
   385  	if e.Flags[nl.SEG6_LOCAL_IIF] {
   386  		link, err := LinkByIndex(e.Iif)
   387  		if err != nil {
   388  			strs = append(strs, fmt.Sprintf("iif %d", e.Iif))
   389  		} else {
   390  			strs = append(strs, fmt.Sprintf("iif %s", link.Attrs().Name))
   391  		}
   392  	}
   393  	if e.Flags[nl.SEG6_LOCAL_OIF] {
   394  		link, err := LinkByIndex(e.Oif)
   395  		if err != nil {
   396  			strs = append(strs, fmt.Sprintf("oif %d", e.Oif))
   397  		} else {
   398  			strs = append(strs, fmt.Sprintf("oif %s", link.Attrs().Name))
   399  		}
   400  	}
   401  	if e.Flags[nl.SEG6_LOCAL_SRH] {
   402  		segs := make([]string, 0, len(e.Segments))
   403  		//append segment backwards (from n to 0) since seg#0 is the last segment.
   404  		for i := len(e.Segments); i > 0; i-- {
   405  			segs = append(segs, e.Segments[i-1].String())
   406  		}
   407  		strs = append(strs, fmt.Sprintf("segs %d [ %s ]", len(e.Segments), strings.Join(segs, " ")))
   408  	}
   409  	return strings.Join(strs, " ")
   410  }
   411  func (e *SEG6LocalEncap) Equal(x Encap) bool {
   412  	o, ok := x.(*SEG6LocalEncap)
   413  	if !ok {
   414  		return false
   415  	}
   416  	if e == o {
   417  		return true
   418  	}
   419  	if e == nil || o == nil {
   420  		return false
   421  	}
   422  	// compare all arrays first
   423  	for i := range e.Flags {
   424  		if e.Flags[i] != o.Flags[i] {
   425  			return false
   426  		}
   427  	}
   428  	if len(e.Segments) != len(o.Segments) {
   429  		return false
   430  	}
   431  	for i := range e.Segments {
   432  		if !e.Segments[i].Equal(o.Segments[i]) {
   433  			return false
   434  		}
   435  	}
   436  	// compare values
   437  	if !e.InAddr.Equal(o.InAddr) || !e.In6Addr.Equal(o.In6Addr) {
   438  		return false
   439  	}
   440  	if e.Action != o.Action || e.Table != o.Table || e.Iif != o.Iif || e.Oif != o.Oif {
   441  		return false
   442  	}
   443  	return true
   444  }
   445  
   446  // Encap BPF definitions
   447  type bpfObj struct {
   448  	progFd   int
   449  	progName string
   450  }
   451  type BpfEncap struct {
   452  	progs    [nl.LWT_BPF_MAX]bpfObj
   453  	headroom int
   454  }
   455  
   456  // SetProg adds a bpf function to the route via netlink RTA_ENCAP. The fd must be a bpf
   457  // program loaded with bpf(type=BPF_PROG_TYPE_LWT_*) matching the direction the program should
   458  // be applied to (LWT_BPF_IN, LWT_BPF_OUT, LWT_BPF_XMIT).
   459  func (e *BpfEncap) SetProg(mode, progFd int, progName string) error {
   460  	if progFd <= 0 {
   461  		return fmt.Errorf("lwt bpf SetProg: invalid fd")
   462  	}
   463  	if mode <= nl.LWT_BPF_UNSPEC || mode >= nl.LWT_BPF_XMIT_HEADROOM {
   464  		return fmt.Errorf("lwt bpf SetProg:invalid mode")
   465  	}
   466  	e.progs[mode].progFd = progFd
   467  	e.progs[mode].progName = fmt.Sprintf("%s[fd:%d]", progName, progFd)
   468  	return nil
   469  }
   470  
   471  // SetXmitHeadroom sets the xmit headroom (LWT_BPF_MAX_HEADROOM) via netlink RTA_ENCAP.
   472  // maximum headroom is LWT_BPF_MAX_HEADROOM
   473  func (e *BpfEncap) SetXmitHeadroom(headroom int) error {
   474  	if headroom > nl.LWT_BPF_MAX_HEADROOM || headroom < 0 {
   475  		return fmt.Errorf("invalid headroom size. range is 0 - %d", nl.LWT_BPF_MAX_HEADROOM)
   476  	}
   477  	e.headroom = headroom
   478  	return nil
   479  }
   480  
   481  func (e *BpfEncap) Type() int {
   482  	return nl.LWTUNNEL_ENCAP_BPF
   483  }
   484  func (e *BpfEncap) Decode(buf []byte) error {
   485  	if len(buf) < 4 {
   486  		return fmt.Errorf("lwt bpf decode: lack of bytes")
   487  	}
   488  	native := nl.NativeEndian()
   489  	attrs, err := nl.ParseRouteAttr(buf)
   490  	if err != nil {
   491  		return fmt.Errorf("lwt bpf decode: failed parsing attribute. err: %v", err)
   492  	}
   493  	for _, attr := range attrs {
   494  		if int(attr.Attr.Type) < 1 {
   495  			// nl.LWT_BPF_UNSPEC
   496  			continue
   497  		}
   498  		if int(attr.Attr.Type) > nl.LWT_BPF_MAX {
   499  			return fmt.Errorf("lwt bpf decode: received unknown attribute type: %d", attr.Attr.Type)
   500  		}
   501  		switch int(attr.Attr.Type) {
   502  		case nl.LWT_BPF_MAX_HEADROOM:
   503  			e.headroom = int(native.Uint32(attr.Value))
   504  		default:
   505  			bpfO := bpfObj{}
   506  			parsedAttrs, err := nl.ParseRouteAttr(attr.Value)
   507  			if err != nil {
   508  				return fmt.Errorf("lwt bpf decode: failed parsing route attribute")
   509  			}
   510  			for _, parsedAttr := range parsedAttrs {
   511  				switch int(parsedAttr.Attr.Type) {
   512  				case nl.LWT_BPF_PROG_FD:
   513  					bpfO.progFd = int(native.Uint32(parsedAttr.Value))
   514  				case nl.LWT_BPF_PROG_NAME:
   515  					bpfO.progName = string(parsedAttr.Value)
   516  				default:
   517  					return fmt.Errorf("lwt bpf decode: received unknown attribute: type: %d, len: %d", parsedAttr.Attr.Type, parsedAttr.Attr.Len)
   518  				}
   519  			}
   520  			e.progs[attr.Attr.Type] = bpfO
   521  		}
   522  	}
   523  	return nil
   524  }
   525  
   526  func (e *BpfEncap) Encode() ([]byte, error) {
   527  	buf := make([]byte, 0)
   528  	native = nl.NativeEndian()
   529  	for index, attr := range e.progs {
   530  		nlMsg := nl.NewRtAttr(index, []byte{})
   531  		if attr.progFd != 0 {
   532  			nlMsg.AddRtAttr(nl.LWT_BPF_PROG_FD, nl.Uint32Attr(uint32(attr.progFd)))
   533  		}
   534  		if attr.progName != "" {
   535  			nlMsg.AddRtAttr(nl.LWT_BPF_PROG_NAME, nl.ZeroTerminated(attr.progName))
   536  		}
   537  		if nlMsg.Len() > 4 {
   538  			buf = append(buf, nlMsg.Serialize()...)
   539  		}
   540  	}
   541  	if len(buf) <= 4 {
   542  		return nil, fmt.Errorf("lwt bpf encode: bpf obj definitions returned empty buffer")
   543  	}
   544  	if e.headroom > 0 {
   545  		hRoom := nl.NewRtAttr(nl.LWT_BPF_XMIT_HEADROOM, nl.Uint32Attr(uint32(e.headroom)))
   546  		buf = append(buf, hRoom.Serialize()...)
   547  	}
   548  	return buf, nil
   549  }
   550  
   551  func (e *BpfEncap) String() string {
   552  	progs := make([]string, 0)
   553  	for index, obj := range e.progs {
   554  		empty := bpfObj{}
   555  		switch index {
   556  		case nl.LWT_BPF_IN:
   557  			if obj != empty {
   558  				progs = append(progs, fmt.Sprintf("in: %s", obj.progName))
   559  			}
   560  		case nl.LWT_BPF_OUT:
   561  			if obj != empty {
   562  				progs = append(progs, fmt.Sprintf("out: %s", obj.progName))
   563  			}
   564  		case nl.LWT_BPF_XMIT:
   565  			if obj != empty {
   566  				progs = append(progs, fmt.Sprintf("xmit: %s", obj.progName))
   567  			}
   568  		}
   569  	}
   570  	if e.headroom > 0 {
   571  		progs = append(progs, fmt.Sprintf("xmit headroom: %d", e.headroom))
   572  	}
   573  	return strings.Join(progs, " ")
   574  }
   575  
   576  func (e *BpfEncap) Equal(x Encap) bool {
   577  	o, ok := x.(*BpfEncap)
   578  	if !ok {
   579  		return false
   580  	}
   581  	if e.headroom != o.headroom {
   582  		return false
   583  	}
   584  	for i := range o.progs {
   585  		if o.progs[i] != e.progs[i] {
   586  			return false
   587  		}
   588  	}
   589  	return true
   590  }
   591  
   592  type Via struct {
   593  	AddrFamily int
   594  	Addr       net.IP
   595  }
   596  
   597  func (v *Via) Equal(x Destination) bool {
   598  	o, ok := x.(*Via)
   599  	if !ok {
   600  		return false
   601  	}
   602  	if v.AddrFamily == x.Family() && v.Addr.Equal(o.Addr) {
   603  		return true
   604  	}
   605  	return false
   606  }
   607  
   608  func (v *Via) String() string {
   609  	return fmt.Sprintf("Family: %d, Address: %s", v.AddrFamily, v.Addr.String())
   610  }
   611  
   612  func (v *Via) Family() int {
   613  	return v.AddrFamily
   614  }
   615  
   616  func (v *Via) Encode() ([]byte, error) {
   617  	buf := &bytes.Buffer{}
   618  	err := binary.Write(buf, native, uint16(v.AddrFamily))
   619  	if err != nil {
   620  		return nil, err
   621  	}
   622  	err = binary.Write(buf, native, v.Addr)
   623  	if err != nil {
   624  		return nil, err
   625  	}
   626  	return buf.Bytes(), nil
   627  }
   628  
   629  func (v *Via) Decode(b []byte) error {
   630  	if len(b) < 6 {
   631  		return fmt.Errorf("decoding failed: buffer too small (%d bytes)", len(b))
   632  	}
   633  	v.AddrFamily = int(native.Uint16(b[0:2]))
   634  	if v.AddrFamily == nl.FAMILY_V4 {
   635  		v.Addr = net.IP(b[2:6])
   636  		return nil
   637  	} else if v.AddrFamily == nl.FAMILY_V6 {
   638  		if len(b) < 18 {
   639  			return fmt.Errorf("decoding failed: buffer too small (%d bytes)", len(b))
   640  		}
   641  		v.Addr = net.IP(b[2:])
   642  		return nil
   643  	}
   644  	return fmt.Errorf("decoding failed: address family %d unknown", v.AddrFamily)
   645  }
   646  
   647  // RouteAdd will add a route to the system.
   648  // Equivalent to: `ip route add $route`
   649  func RouteAdd(route *Route) error {
   650  	return pkgHandle.RouteAdd(route)
   651  }
   652  
   653  // RouteAdd will add a route to the system.
   654  // Equivalent to: `ip route add $route`
   655  func (h *Handle) RouteAdd(route *Route) error {
   656  	flags := unix.NLM_F_CREATE | unix.NLM_F_EXCL | unix.NLM_F_ACK
   657  	req := h.newNetlinkRequest(unix.RTM_NEWROUTE, flags)
   658  	_, err := h.routeHandle(route, req, nl.NewRtMsg())
   659  	return err
   660  }
   661  
   662  // RouteAppend will append a route to the system.
   663  // Equivalent to: `ip route append $route`
   664  func RouteAppend(route *Route) error {
   665  	return pkgHandle.RouteAppend(route)
   666  }
   667  
   668  // RouteAppend will append a route to the system.
   669  // Equivalent to: `ip route append $route`
   670  func (h *Handle) RouteAppend(route *Route) error {
   671  	flags := unix.NLM_F_CREATE | unix.NLM_F_APPEND | unix.NLM_F_ACK
   672  	req := h.newNetlinkRequest(unix.RTM_NEWROUTE, flags)
   673  	_, err := h.routeHandle(route, req, nl.NewRtMsg())
   674  	return err
   675  }
   676  
   677  // RouteAddEcmp will add a route to the system.
   678  func RouteAddEcmp(route *Route) error {
   679  	return pkgHandle.RouteAddEcmp(route)
   680  }
   681  
   682  // RouteAddEcmp will add a route to the system.
   683  func (h *Handle) RouteAddEcmp(route *Route) error {
   684  	flags := unix.NLM_F_CREATE | unix.NLM_F_ACK
   685  	req := h.newNetlinkRequest(unix.RTM_NEWROUTE, flags)
   686  	_, err := h.routeHandle(route, req, nl.NewRtMsg())
   687  	return err
   688  }
   689  
   690  // RouteReplace will add a route to the system.
   691  // Equivalent to: `ip route replace $route`
   692  func RouteReplace(route *Route) error {
   693  	return pkgHandle.RouteReplace(route)
   694  }
   695  
   696  // RouteReplace will add a route to the system.
   697  // Equivalent to: `ip route replace $route`
   698  func (h *Handle) RouteReplace(route *Route) error {
   699  	flags := unix.NLM_F_CREATE | unix.NLM_F_REPLACE | unix.NLM_F_ACK
   700  	req := h.newNetlinkRequest(unix.RTM_NEWROUTE, flags)
   701  	_, err := h.routeHandle(route, req, nl.NewRtMsg())
   702  	return err
   703  }
   704  
   705  // RouteDel will delete a route from the system.
   706  // Equivalent to: `ip route del $route`
   707  func RouteDel(route *Route) error {
   708  	return pkgHandle.RouteDel(route)
   709  }
   710  
   711  // RouteDel will delete a route from the system.
   712  // Equivalent to: `ip route del $route`
   713  func (h *Handle) RouteDel(route *Route) error {
   714  	req := h.newNetlinkRequest(unix.RTM_DELROUTE, unix.NLM_F_ACK)
   715  	_, err := h.routeHandle(route, req, nl.NewRtDelMsg())
   716  	return err
   717  }
   718  
   719  func (h *Handle) routeHandle(route *Route, req *nl.NetlinkRequest, msg *nl.RtMsg) ([][]byte, error) {
   720  	if req.NlMsghdr.Type != unix.RTM_GETROUTE && (route.Dst == nil || route.Dst.IP == nil) && route.Src == nil && route.Gw == nil && route.MPLSDst == nil {
   721  		return nil, fmt.Errorf("Either Dst.IP, Src.IP or Gw must be set")
   722  	}
   723  
   724  	family := -1
   725  	var rtAttrs []*nl.RtAttr
   726  
   727  	if route.Dst != nil && route.Dst.IP != nil {
   728  		dstLen, _ := route.Dst.Mask.Size()
   729  		msg.Dst_len = uint8(dstLen)
   730  		dstFamily := nl.GetIPFamily(route.Dst.IP)
   731  		family = dstFamily
   732  		var dstData []byte
   733  		if dstFamily == FAMILY_V4 {
   734  			dstData = route.Dst.IP.To4()
   735  		} else {
   736  			dstData = route.Dst.IP.To16()
   737  		}
   738  		rtAttrs = append(rtAttrs, nl.NewRtAttr(unix.RTA_DST, dstData))
   739  	} else if route.MPLSDst != nil {
   740  		family = nl.FAMILY_MPLS
   741  		msg.Dst_len = uint8(20)
   742  		msg.Type = unix.RTN_UNICAST
   743  		rtAttrs = append(rtAttrs, nl.NewRtAttr(unix.RTA_DST, nl.EncodeMPLSStack(*route.MPLSDst)))
   744  	}
   745  
   746  	if route.NewDst != nil {
   747  		if family != -1 && family != route.NewDst.Family() {
   748  			return nil, fmt.Errorf("new destination and destination are not the same address family")
   749  		}
   750  		buf, err := route.NewDst.Encode()
   751  		if err != nil {
   752  			return nil, err
   753  		}
   754  		rtAttrs = append(rtAttrs, nl.NewRtAttr(unix.RTA_NEWDST, buf))
   755  	}
   756  
   757  	if route.Encap != nil {
   758  		buf := make([]byte, 2)
   759  		native.PutUint16(buf, uint16(route.Encap.Type()))
   760  		rtAttrs = append(rtAttrs, nl.NewRtAttr(unix.RTA_ENCAP_TYPE, buf))
   761  		buf, err := route.Encap.Encode()
   762  		if err != nil {
   763  			return nil, err
   764  		}
   765  		switch route.Encap.Type() {
   766  		case nl.LWTUNNEL_ENCAP_BPF:
   767  			rtAttrs = append(rtAttrs, nl.NewRtAttr(unix.RTA_ENCAP|unix.NLA_F_NESTED, buf))
   768  		default:
   769  			rtAttrs = append(rtAttrs, nl.NewRtAttr(unix.RTA_ENCAP, buf))
   770  		}
   771  
   772  	}
   773  
   774  	if route.Src != nil {
   775  		srcFamily := nl.GetIPFamily(route.Src)
   776  		if family != -1 && family != srcFamily {
   777  			return nil, fmt.Errorf("source and destination ip are not the same IP family")
   778  		}
   779  		family = srcFamily
   780  		var srcData []byte
   781  		if srcFamily == FAMILY_V4 {
   782  			srcData = route.Src.To4()
   783  		} else {
   784  			srcData = route.Src.To16()
   785  		}
   786  		// The commonly used src ip for routes is actually PREFSRC
   787  		rtAttrs = append(rtAttrs, nl.NewRtAttr(unix.RTA_PREFSRC, srcData))
   788  	}
   789  
   790  	if route.Gw != nil {
   791  		gwFamily := nl.GetIPFamily(route.Gw)
   792  		if family != -1 && family != gwFamily {
   793  			return nil, fmt.Errorf("gateway, source, and destination ip are not the same IP family")
   794  		}
   795  		family = gwFamily
   796  		var gwData []byte
   797  		if gwFamily == FAMILY_V4 {
   798  			gwData = route.Gw.To4()
   799  		} else {
   800  			gwData = route.Gw.To16()
   801  		}
   802  		rtAttrs = append(rtAttrs, nl.NewRtAttr(unix.RTA_GATEWAY, gwData))
   803  	}
   804  
   805  	if route.Via != nil {
   806  		buf, err := route.Via.Encode()
   807  		if err != nil {
   808  			return nil, fmt.Errorf("failed to encode RTA_VIA: %v", err)
   809  		}
   810  		rtAttrs = append(rtAttrs, nl.NewRtAttr(unix.RTA_VIA, buf))
   811  	}
   812  
   813  	if len(route.MultiPath) > 0 {
   814  		buf := []byte{}
   815  		for _, nh := range route.MultiPath {
   816  			rtnh := &nl.RtNexthop{
   817  				RtNexthop: unix.RtNexthop{
   818  					Hops:    uint8(nh.Hops),
   819  					Ifindex: int32(nh.LinkIndex),
   820  					Flags:   uint8(nh.Flags),
   821  				},
   822  			}
   823  			children := []nl.NetlinkRequestData{}
   824  			if nh.Gw != nil {
   825  				gwFamily := nl.GetIPFamily(nh.Gw)
   826  				if family != -1 && family != gwFamily {
   827  					return nil, fmt.Errorf("gateway, source, and destination ip are not the same IP family")
   828  				}
   829  				if gwFamily == FAMILY_V4 {
   830  					children = append(children, nl.NewRtAttr(unix.RTA_GATEWAY, []byte(nh.Gw.To4())))
   831  				} else {
   832  					children = append(children, nl.NewRtAttr(unix.RTA_GATEWAY, []byte(nh.Gw.To16())))
   833  				}
   834  			}
   835  			if nh.NewDst != nil {
   836  				if family != -1 && family != nh.NewDst.Family() {
   837  					return nil, fmt.Errorf("new destination and destination are not the same address family")
   838  				}
   839  				buf, err := nh.NewDst.Encode()
   840  				if err != nil {
   841  					return nil, err
   842  				}
   843  				children = append(children, nl.NewRtAttr(unix.RTA_NEWDST, buf))
   844  			}
   845  			if nh.Encap != nil {
   846  				buf := make([]byte, 2)
   847  				native.PutUint16(buf, uint16(nh.Encap.Type()))
   848  				children = append(children, nl.NewRtAttr(unix.RTA_ENCAP_TYPE, buf))
   849  				buf, err := nh.Encap.Encode()
   850  				if err != nil {
   851  					return nil, err
   852  				}
   853  				children = append(children, nl.NewRtAttr(unix.RTA_ENCAP, buf))
   854  			}
   855  			if nh.Via != nil {
   856  				buf, err := nh.Via.Encode()
   857  				if err != nil {
   858  					return nil, err
   859  				}
   860  				children = append(children, nl.NewRtAttr(unix.RTA_VIA, buf))
   861  			}
   862  			rtnh.Children = children
   863  			buf = append(buf, rtnh.Serialize()...)
   864  		}
   865  		rtAttrs = append(rtAttrs, nl.NewRtAttr(unix.RTA_MULTIPATH, buf))
   866  	}
   867  
   868  	if route.Table > 0 {
   869  		if route.Table >= 256 {
   870  			msg.Table = unix.RT_TABLE_UNSPEC
   871  			b := make([]byte, 4)
   872  			native.PutUint32(b, uint32(route.Table))
   873  			rtAttrs = append(rtAttrs, nl.NewRtAttr(unix.RTA_TABLE, b))
   874  		} else {
   875  			msg.Table = uint8(route.Table)
   876  		}
   877  	}
   878  
   879  	if route.Priority > 0 {
   880  		b := make([]byte, 4)
   881  		native.PutUint32(b, uint32(route.Priority))
   882  		rtAttrs = append(rtAttrs, nl.NewRtAttr(unix.RTA_PRIORITY, b))
   883  	}
   884  	if route.Realm > 0 {
   885  		b := make([]byte, 4)
   886  		native.PutUint32(b, uint32(route.Realm))
   887  		rtAttrs = append(rtAttrs, nl.NewRtAttr(unix.RTA_FLOW, b))
   888  	}
   889  	if route.Tos > 0 {
   890  		msg.Tos = uint8(route.Tos)
   891  	}
   892  	if route.Protocol > 0 {
   893  		msg.Protocol = uint8(route.Protocol)
   894  	}
   895  	if route.Type > 0 {
   896  		msg.Type = uint8(route.Type)
   897  	}
   898  
   899  	var metrics []*nl.RtAttr
   900  	if route.MTU > 0 {
   901  		b := nl.Uint32Attr(uint32(route.MTU))
   902  		metrics = append(metrics, nl.NewRtAttr(unix.RTAX_MTU, b))
   903  	}
   904  	if route.Window > 0 {
   905  		b := nl.Uint32Attr(uint32(route.Window))
   906  		metrics = append(metrics, nl.NewRtAttr(unix.RTAX_WINDOW, b))
   907  	}
   908  	if route.Rtt > 0 {
   909  		b := nl.Uint32Attr(uint32(route.Rtt))
   910  		metrics = append(metrics, nl.NewRtAttr(unix.RTAX_RTT, b))
   911  	}
   912  	if route.RttVar > 0 {
   913  		b := nl.Uint32Attr(uint32(route.RttVar))
   914  		metrics = append(metrics, nl.NewRtAttr(unix.RTAX_RTTVAR, b))
   915  	}
   916  	if route.Ssthresh > 0 {
   917  		b := nl.Uint32Attr(uint32(route.Ssthresh))
   918  		metrics = append(metrics, nl.NewRtAttr(unix.RTAX_SSTHRESH, b))
   919  	}
   920  	if route.Cwnd > 0 {
   921  		b := nl.Uint32Attr(uint32(route.Cwnd))
   922  		metrics = append(metrics, nl.NewRtAttr(unix.RTAX_CWND, b))
   923  	}
   924  	if route.AdvMSS > 0 {
   925  		b := nl.Uint32Attr(uint32(route.AdvMSS))
   926  		metrics = append(metrics, nl.NewRtAttr(unix.RTAX_ADVMSS, b))
   927  	}
   928  	if route.Reordering > 0 {
   929  		b := nl.Uint32Attr(uint32(route.Reordering))
   930  		metrics = append(metrics, nl.NewRtAttr(unix.RTAX_REORDERING, b))
   931  	}
   932  	if route.Hoplimit > 0 {
   933  		b := nl.Uint32Attr(uint32(route.Hoplimit))
   934  		metrics = append(metrics, nl.NewRtAttr(unix.RTAX_HOPLIMIT, b))
   935  	}
   936  	if route.InitCwnd > 0 {
   937  		b := nl.Uint32Attr(uint32(route.InitCwnd))
   938  		metrics = append(metrics, nl.NewRtAttr(unix.RTAX_INITCWND, b))
   939  	}
   940  	if route.Features > 0 {
   941  		b := nl.Uint32Attr(uint32(route.Features))
   942  		metrics = append(metrics, nl.NewRtAttr(unix.RTAX_FEATURES, b))
   943  	}
   944  	if route.RtoMin > 0 {
   945  		b := nl.Uint32Attr(uint32(route.RtoMin))
   946  		metrics = append(metrics, nl.NewRtAttr(unix.RTAX_RTO_MIN, b))
   947  	}
   948  	if route.InitRwnd > 0 {
   949  		b := nl.Uint32Attr(uint32(route.InitRwnd))
   950  		metrics = append(metrics, nl.NewRtAttr(unix.RTAX_INITRWND, b))
   951  	}
   952  	if route.QuickACK > 0 {
   953  		b := nl.Uint32Attr(uint32(route.QuickACK))
   954  		metrics = append(metrics, nl.NewRtAttr(unix.RTAX_QUICKACK, b))
   955  	}
   956  	if route.Congctl != "" {
   957  		b := nl.ZeroTerminated(route.Congctl)
   958  		metrics = append(metrics, nl.NewRtAttr(unix.RTAX_CC_ALGO, b))
   959  	}
   960  	if route.FastOpenNoCookie > 0 {
   961  		b := nl.Uint32Attr(uint32(route.FastOpenNoCookie))
   962  		metrics = append(metrics, nl.NewRtAttr(unix.RTAX_FASTOPEN_NO_COOKIE, b))
   963  	}
   964  
   965  	if metrics != nil {
   966  		attr := nl.NewRtAttr(unix.RTA_METRICS, nil)
   967  		for _, metric := range metrics {
   968  			attr.AddChild(metric)
   969  		}
   970  		rtAttrs = append(rtAttrs, attr)
   971  	}
   972  
   973  	msg.Flags = uint32(route.Flags)
   974  	msg.Scope = uint8(route.Scope)
   975  	// only overwrite family if it was not set in msg
   976  	if msg.Family == 0 {
   977  		msg.Family = uint8(family)
   978  	}
   979  	req.AddData(msg)
   980  	for _, attr := range rtAttrs {
   981  		req.AddData(attr)
   982  	}
   983  
   984  	if (req.NlMsghdr.Type != unix.RTM_GETROUTE) || (req.NlMsghdr.Type == unix.RTM_GETROUTE && route.LinkIndex > 0) {
   985  		b := make([]byte, 4)
   986  		native.PutUint32(b, uint32(route.LinkIndex))
   987  		req.AddData(nl.NewRtAttr(unix.RTA_OIF, b))
   988  	}
   989  
   990  	return req.Execute(unix.NETLINK_ROUTE, 0)
   991  }
   992  
   993  // RouteList gets a list of routes in the system.
   994  // Equivalent to: `ip route show`.
   995  // The list can be filtered by link and ip family.
   996  func RouteList(link Link, family int) ([]Route, error) {
   997  	return pkgHandle.RouteList(link, family)
   998  }
   999  
  1000  // RouteList gets a list of routes in the system.
  1001  // Equivalent to: `ip route show`.
  1002  // The list can be filtered by link and ip family.
  1003  func (h *Handle) RouteList(link Link, family int) ([]Route, error) {
  1004  	routeFilter := &Route{}
  1005  	if link != nil {
  1006  		routeFilter.LinkIndex = link.Attrs().Index
  1007  	}
  1008  	return h.RouteListFiltered(family, routeFilter, RT_FILTER_OIF)
  1009  }
  1010  
  1011  // RouteListFiltered gets a list of routes in the system filtered with specified rules.
  1012  // All rules must be defined in RouteFilter struct
  1013  func RouteListFiltered(family int, filter *Route, filterMask uint64) ([]Route, error) {
  1014  	return pkgHandle.RouteListFiltered(family, filter, filterMask)
  1015  }
  1016  
  1017  // RouteListFiltered gets a list of routes in the system filtered with specified rules.
  1018  // All rules must be defined in RouteFilter struct
  1019  func (h *Handle) RouteListFiltered(family int, filter *Route, filterMask uint64) ([]Route, error) {
  1020  	req := h.newNetlinkRequest(unix.RTM_GETROUTE, unix.NLM_F_DUMP)
  1021  	rtmsg := &nl.RtMsg{}
  1022  	rtmsg.Family = uint8(family)
  1023  	msgs, err := h.routeHandle(filter, req, rtmsg)
  1024  	if err != nil {
  1025  		return nil, err
  1026  	}
  1027  
  1028  	var res []Route
  1029  	for _, m := range msgs {
  1030  		msg := nl.DeserializeRtMsg(m)
  1031  		if msg.Flags&unix.RTM_F_CLONED != 0 {
  1032  			// Ignore cloned routes
  1033  			continue
  1034  		}
  1035  		if msg.Table != unix.RT_TABLE_MAIN {
  1036  			if filter == nil || filter != nil && filterMask&RT_FILTER_TABLE == 0 {
  1037  				// Ignore non-main tables
  1038  				continue
  1039  			}
  1040  		}
  1041  		route, err := deserializeRoute(m)
  1042  		if err != nil {
  1043  			return nil, err
  1044  		}
  1045  		if filter != nil {
  1046  			switch {
  1047  			case filterMask&RT_FILTER_TABLE != 0 && filter.Table != unix.RT_TABLE_UNSPEC && route.Table != filter.Table:
  1048  				continue
  1049  			case filterMask&RT_FILTER_PROTOCOL != 0 && route.Protocol != filter.Protocol:
  1050  				continue
  1051  			case filterMask&RT_FILTER_SCOPE != 0 && route.Scope != filter.Scope:
  1052  				continue
  1053  			case filterMask&RT_FILTER_TYPE != 0 && route.Type != filter.Type:
  1054  				continue
  1055  			case filterMask&RT_FILTER_TOS != 0 && route.Tos != filter.Tos:
  1056  				continue
  1057  			case filterMask&RT_FILTER_REALM != 0 && route.Realm != filter.Realm:
  1058  				continue
  1059  			case filterMask&RT_FILTER_OIF != 0 && route.LinkIndex != filter.LinkIndex:
  1060  				continue
  1061  			case filterMask&RT_FILTER_IIF != 0 && route.ILinkIndex != filter.ILinkIndex:
  1062  				continue
  1063  			case filterMask&RT_FILTER_GW != 0 && !route.Gw.Equal(filter.Gw):
  1064  				continue
  1065  			case filterMask&RT_FILTER_SRC != 0 && !route.Src.Equal(filter.Src):
  1066  				continue
  1067  			case filterMask&RT_FILTER_DST != 0:
  1068  				if filter.MPLSDst == nil || route.MPLSDst == nil || (*filter.MPLSDst) != (*route.MPLSDst) {
  1069  					if !ipNetEqual(route.Dst, filter.Dst) {
  1070  						continue
  1071  					}
  1072  				}
  1073  			case filterMask&RT_FILTER_HOPLIMIT != 0 && route.Hoplimit != filter.Hoplimit:
  1074  				continue
  1075  			}
  1076  		}
  1077  		res = append(res, route)
  1078  	}
  1079  	return res, nil
  1080  }
  1081  
  1082  // deserializeRoute decodes a binary netlink message into a Route struct
  1083  func deserializeRoute(m []byte) (Route, error) {
  1084  	msg := nl.DeserializeRtMsg(m)
  1085  	attrs, err := nl.ParseRouteAttr(m[msg.Len():])
  1086  	if err != nil {
  1087  		return Route{}, err
  1088  	}
  1089  	route := Route{
  1090  		Scope:    Scope(msg.Scope),
  1091  		Protocol: RouteProtocol(int(msg.Protocol)),
  1092  		Table:    int(msg.Table),
  1093  		Type:     int(msg.Type),
  1094  		Tos:      int(msg.Tos),
  1095  		Flags:    int(msg.Flags),
  1096  		Family:   int(msg.Family),
  1097  	}
  1098  
  1099  	var encap, encapType syscall.NetlinkRouteAttr
  1100  	for _, attr := range attrs {
  1101  		switch attr.Attr.Type {
  1102  		case unix.RTA_GATEWAY:
  1103  			route.Gw = net.IP(attr.Value)
  1104  		case unix.RTA_PREFSRC:
  1105  			route.Src = net.IP(attr.Value)
  1106  		case unix.RTA_DST:
  1107  			if msg.Family == nl.FAMILY_MPLS {
  1108  				stack := nl.DecodeMPLSStack(attr.Value)
  1109  				if len(stack) == 0 || len(stack) > 1 {
  1110  					return route, fmt.Errorf("invalid MPLS RTA_DST")
  1111  				}
  1112  				route.MPLSDst = &stack[0]
  1113  			} else {
  1114  				route.Dst = &net.IPNet{
  1115  					IP:   attr.Value,
  1116  					Mask: net.CIDRMask(int(msg.Dst_len), 8*len(attr.Value)),
  1117  				}
  1118  			}
  1119  		case unix.RTA_OIF:
  1120  			route.LinkIndex = int(native.Uint32(attr.Value[0:4]))
  1121  		case unix.RTA_IIF:
  1122  			route.ILinkIndex = int(native.Uint32(attr.Value[0:4]))
  1123  		case unix.RTA_PRIORITY:
  1124  			route.Priority = int(native.Uint32(attr.Value[0:4]))
  1125  		case unix.RTA_FLOW:
  1126  			route.Realm = int(native.Uint32(attr.Value[0:4]))
  1127  		case unix.RTA_TABLE:
  1128  			route.Table = int(native.Uint32(attr.Value[0:4]))
  1129  		case unix.RTA_MULTIPATH:
  1130  			parseRtNexthop := func(value []byte) (*NexthopInfo, []byte, error) {
  1131  				if len(value) < unix.SizeofRtNexthop {
  1132  					return nil, nil, fmt.Errorf("lack of bytes")
  1133  				}
  1134  				nh := nl.DeserializeRtNexthop(value)
  1135  				if len(value) < int(nh.RtNexthop.Len) {
  1136  					return nil, nil, fmt.Errorf("lack of bytes")
  1137  				}
  1138  				info := &NexthopInfo{
  1139  					LinkIndex: int(nh.RtNexthop.Ifindex),
  1140  					Hops:      int(nh.RtNexthop.Hops),
  1141  					Flags:     int(nh.RtNexthop.Flags),
  1142  				}
  1143  				attrs, err := nl.ParseRouteAttr(value[unix.SizeofRtNexthop:int(nh.RtNexthop.Len)])
  1144  				if err != nil {
  1145  					return nil, nil, err
  1146  				}
  1147  				var encap, encapType syscall.NetlinkRouteAttr
  1148  				for _, attr := range attrs {
  1149  					switch attr.Attr.Type {
  1150  					case unix.RTA_GATEWAY:
  1151  						info.Gw = net.IP(attr.Value)
  1152  					case unix.RTA_NEWDST:
  1153  						var d Destination
  1154  						switch msg.Family {
  1155  						case nl.FAMILY_MPLS:
  1156  							d = &MPLSDestination{}
  1157  						}
  1158  						if err := d.Decode(attr.Value); err != nil {
  1159  							return nil, nil, err
  1160  						}
  1161  						info.NewDst = d
  1162  					case unix.RTA_ENCAP_TYPE:
  1163  						encapType = attr
  1164  					case unix.RTA_ENCAP:
  1165  						encap = attr
  1166  					case unix.RTA_VIA:
  1167  						d := &Via{}
  1168  						if err := d.Decode(attr.Value); err != nil {
  1169  							return nil, nil, err
  1170  						}
  1171  						info.Via = d
  1172  					}
  1173  				}
  1174  
  1175  				if len(encap.Value) != 0 && len(encapType.Value) != 0 {
  1176  					typ := int(native.Uint16(encapType.Value[0:2]))
  1177  					var e Encap
  1178  					switch typ {
  1179  					case nl.LWTUNNEL_ENCAP_MPLS:
  1180  						e = &MPLSEncap{}
  1181  						if err := e.Decode(encap.Value); err != nil {
  1182  							return nil, nil, err
  1183  						}
  1184  					}
  1185  					info.Encap = e
  1186  				}
  1187  
  1188  				return info, value[int(nh.RtNexthop.Len):], nil
  1189  			}
  1190  			rest := attr.Value
  1191  			for len(rest) > 0 {
  1192  				info, buf, err := parseRtNexthop(rest)
  1193  				if err != nil {
  1194  					return route, err
  1195  				}
  1196  				route.MultiPath = append(route.MultiPath, info)
  1197  				rest = buf
  1198  			}
  1199  		case unix.RTA_NEWDST:
  1200  			var d Destination
  1201  			switch msg.Family {
  1202  			case nl.FAMILY_MPLS:
  1203  				d = &MPLSDestination{}
  1204  			}
  1205  			if err := d.Decode(attr.Value); err != nil {
  1206  				return route, err
  1207  			}
  1208  			route.NewDst = d
  1209  		case unix.RTA_VIA:
  1210  			v := &Via{}
  1211  			if err := v.Decode(attr.Value); err != nil {
  1212  				return route, err
  1213  			}
  1214  			route.Via = v
  1215  		case unix.RTA_ENCAP_TYPE:
  1216  			encapType = attr
  1217  		case unix.RTA_ENCAP:
  1218  			encap = attr
  1219  		case unix.RTA_METRICS:
  1220  			metrics, err := nl.ParseRouteAttr(attr.Value)
  1221  			if err != nil {
  1222  				return route, err
  1223  			}
  1224  			for _, metric := range metrics {
  1225  				switch metric.Attr.Type {
  1226  				case unix.RTAX_MTU:
  1227  					route.MTU = int(native.Uint32(metric.Value[0:4]))
  1228  				case unix.RTAX_WINDOW:
  1229  					route.Window = int(native.Uint32(metric.Value[0:4]))
  1230  				case unix.RTAX_RTT:
  1231  					route.Rtt = int(native.Uint32(metric.Value[0:4]))
  1232  				case unix.RTAX_RTTVAR:
  1233  					route.RttVar = int(native.Uint32(metric.Value[0:4]))
  1234  				case unix.RTAX_SSTHRESH:
  1235  					route.Ssthresh = int(native.Uint32(metric.Value[0:4]))
  1236  				case unix.RTAX_CWND:
  1237  					route.Cwnd = int(native.Uint32(metric.Value[0:4]))
  1238  				case unix.RTAX_ADVMSS:
  1239  					route.AdvMSS = int(native.Uint32(metric.Value[0:4]))
  1240  				case unix.RTAX_REORDERING:
  1241  					route.Reordering = int(native.Uint32(metric.Value[0:4]))
  1242  				case unix.RTAX_HOPLIMIT:
  1243  					route.Hoplimit = int(native.Uint32(metric.Value[0:4]))
  1244  				case unix.RTAX_INITCWND:
  1245  					route.InitCwnd = int(native.Uint32(metric.Value[0:4]))
  1246  				case unix.RTAX_FEATURES:
  1247  					route.Features = int(native.Uint32(metric.Value[0:4]))
  1248  				case unix.RTAX_RTO_MIN:
  1249  					route.RtoMin = int(native.Uint32(metric.Value[0:4]))
  1250  				case unix.RTAX_INITRWND:
  1251  					route.InitRwnd = int(native.Uint32(metric.Value[0:4]))
  1252  				case unix.RTAX_QUICKACK:
  1253  					route.QuickACK = int(native.Uint32(metric.Value[0:4]))
  1254  				case unix.RTAX_CC_ALGO:
  1255  					route.Congctl = nl.BytesToString(metric.Value)
  1256  				case unix.RTAX_FASTOPEN_NO_COOKIE:
  1257  					route.FastOpenNoCookie = int(native.Uint32(metric.Value[0:4]))
  1258  				}
  1259  			}
  1260  		}
  1261  	}
  1262  
  1263  	if len(encap.Value) != 0 && len(encapType.Value) != 0 {
  1264  		typ := int(native.Uint16(encapType.Value[0:2]))
  1265  		var e Encap
  1266  		switch typ {
  1267  		case nl.LWTUNNEL_ENCAP_MPLS:
  1268  			e = &MPLSEncap{}
  1269  			if err := e.Decode(encap.Value); err != nil {
  1270  				return route, err
  1271  			}
  1272  		case nl.LWTUNNEL_ENCAP_SEG6:
  1273  			e = &SEG6Encap{}
  1274  			if err := e.Decode(encap.Value); err != nil {
  1275  				return route, err
  1276  			}
  1277  		case nl.LWTUNNEL_ENCAP_SEG6_LOCAL:
  1278  			e = &SEG6LocalEncap{}
  1279  			if err := e.Decode(encap.Value); err != nil {
  1280  				return route, err
  1281  			}
  1282  		case nl.LWTUNNEL_ENCAP_BPF:
  1283  			e = &BpfEncap{}
  1284  			if err := e.Decode(encap.Value); err != nil {
  1285  				return route, err
  1286  			}
  1287  		}
  1288  		route.Encap = e
  1289  	}
  1290  
  1291  	return route, nil
  1292  }
  1293  
// RouteGetOptions contains a set of options to use with
// RouteGetWithOptions. Every field is optional; zero values are not sent.
type RouteGetOptions struct {
	Iif     string  // input interface name; its index is sent as RTA_IIF
	Oif     string  // output interface name; its index is sent as RTA_OIF
	VrfName string  // VRF device name; its index is also sent as RTA_OIF
	SrcAddr net.IP  // if non-nil, sent as RTA_SRC with a full-length prefix
	UID     *uint32 // if non-nil, sent as RTA_UID
}
  1303  
  1304  // RouteGetWithOptions gets a route to a specific destination from the host system.
  1305  // Equivalent to: 'ip route get <> vrf <VrfName>'.
  1306  func RouteGetWithOptions(destination net.IP, options *RouteGetOptions) ([]Route, error) {
  1307  	return pkgHandle.RouteGetWithOptions(destination, options)
  1308  }
  1309  
  1310  // RouteGet gets a route to a specific destination from the host system.
  1311  // Equivalent to: 'ip route get'.
  1312  func RouteGet(destination net.IP) ([]Route, error) {
  1313  	return pkgHandle.RouteGet(destination)
  1314  }
  1315  
  1316  // RouteGetWithOptions gets a route to a specific destination from the host system.
  1317  // Equivalent to: 'ip route get <> vrf <VrfName>'.
  1318  func (h *Handle) RouteGetWithOptions(destination net.IP, options *RouteGetOptions) ([]Route, error) {
  1319  	req := h.newNetlinkRequest(unix.RTM_GETROUTE, unix.NLM_F_REQUEST)
  1320  	family := nl.GetIPFamily(destination)
  1321  	var destinationData []byte
  1322  	var bitlen uint8
  1323  	if family == FAMILY_V4 {
  1324  		destinationData = destination.To4()
  1325  		bitlen = 32
  1326  	} else {
  1327  		destinationData = destination.To16()
  1328  		bitlen = 128
  1329  	}
  1330  	msg := &nl.RtMsg{}
  1331  	msg.Family = uint8(family)
  1332  	msg.Dst_len = bitlen
  1333  	if options != nil && options.SrcAddr != nil {
  1334  		msg.Src_len = bitlen
  1335  	}
  1336  	msg.Flags = unix.RTM_F_LOOKUP_TABLE
  1337  	req.AddData(msg)
  1338  
  1339  	rtaDst := nl.NewRtAttr(unix.RTA_DST, destinationData)
  1340  	req.AddData(rtaDst)
  1341  
  1342  	if options != nil {
  1343  		if options.VrfName != "" {
  1344  			link, err := LinkByName(options.VrfName)
  1345  			if err != nil {
  1346  				return nil, err
  1347  			}
  1348  			b := make([]byte, 4)
  1349  			native.PutUint32(b, uint32(link.Attrs().Index))
  1350  
  1351  			req.AddData(nl.NewRtAttr(unix.RTA_OIF, b))
  1352  		}
  1353  
  1354  		if len(options.Iif) > 0 {
  1355  			link, err := LinkByName(options.Iif)
  1356  			if err != nil {
  1357  				return nil, err
  1358  			}
  1359  
  1360  			b := make([]byte, 4)
  1361  			native.PutUint32(b, uint32(link.Attrs().Index))
  1362  
  1363  			req.AddData(nl.NewRtAttr(unix.RTA_IIF, b))
  1364  		}
  1365  
  1366  		if len(options.Oif) > 0 {
  1367  			link, err := LinkByName(options.Oif)
  1368  			if err != nil {
  1369  				return nil, err
  1370  			}
  1371  
  1372  			b := make([]byte, 4)
  1373  			native.PutUint32(b, uint32(link.Attrs().Index))
  1374  
  1375  			req.AddData(nl.NewRtAttr(unix.RTA_OIF, b))
  1376  		}
  1377  
  1378  		if options.SrcAddr != nil {
  1379  			var srcAddr []byte
  1380  			if family == FAMILY_V4 {
  1381  				srcAddr = options.SrcAddr.To4()
  1382  			} else {
  1383  				srcAddr = options.SrcAddr.To16()
  1384  			}
  1385  
  1386  			req.AddData(nl.NewRtAttr(unix.RTA_SRC, srcAddr))
  1387  		}
  1388  
  1389  		if options.UID != nil {
  1390  			uid := *options.UID
  1391  			b := make([]byte, 4)
  1392  			native.PutUint32(b, uid)
  1393  			req.AddData(nl.NewRtAttr(unix.RTA_UID, b))
  1394  		}
  1395  	}
  1396  
  1397  	msgs, err := req.Execute(unix.NETLINK_ROUTE, unix.RTM_NEWROUTE)
  1398  	if err != nil {
  1399  		return nil, err
  1400  	}
  1401  
  1402  	var res []Route
  1403  	for _, m := range msgs {
  1404  		route, err := deserializeRoute(m)
  1405  		if err != nil {
  1406  			return nil, err
  1407  		}
  1408  		res = append(res, route)
  1409  	}
  1410  	return res, nil
  1411  }
  1412  
  1413  // RouteGet gets a route to a specific destination from the host system.
  1414  // Equivalent to: 'ip route get'.
  1415  func (h *Handle) RouteGet(destination net.IP) ([]Route, error) {
  1416  	return h.RouteGetWithOptions(destination, nil)
  1417  }
  1418  
  1419  // RouteSubscribe takes a chan down which notifications will be sent
  1420  // when routes are added or deleted. Close the 'done' chan to stop subscription.
  1421  func RouteSubscribe(ch chan<- RouteUpdate, done <-chan struct{}) error {
  1422  	return routeSubscribeAt(netns.None(), netns.None(), ch, done, nil, false)
  1423  }
  1424  
  1425  // RouteSubscribeAt works like RouteSubscribe plus it allows the caller
  1426  // to choose the network namespace in which to subscribe (ns).
  1427  func RouteSubscribeAt(ns netns.NsHandle, ch chan<- RouteUpdate, done <-chan struct{}) error {
  1428  	return routeSubscribeAt(ns, netns.None(), ch, done, nil, false)
  1429  }
  1430  
// RouteSubscribeOptions contains a set of options to use with
// RouteSubscribeWithOptions.
type RouteSubscribeOptions struct {
	Namespace     *netns.NsHandle // namespace to subscribe in; nil means netns.None()
	ErrorCallback func(error)     // if non-nil, called with errors seen by the receive loop
	ListExisting  bool            // if true, a dump of existing routes is requested on subscription
}
  1436  	ListExisting  bool
  1437  }
  1438  
  1439  // RouteSubscribeWithOptions work like RouteSubscribe but enable to
  1440  // provide additional options to modify the behavior. Currently, the
  1441  // namespace can be provided as well as an error callback.
  1442  func RouteSubscribeWithOptions(ch chan<- RouteUpdate, done <-chan struct{}, options RouteSubscribeOptions) error {
  1443  	if options.Namespace == nil {
  1444  		none := netns.None()
  1445  		options.Namespace = &none
  1446  	}
  1447  	return routeSubscribeAt(*options.Namespace, netns.None(), ch, done, options.ErrorCallback, options.ListExisting)
  1448  }
  1449  
  1450  func routeSubscribeAt(newNs, curNs netns.NsHandle, ch chan<- RouteUpdate, done <-chan struct{}, cberr func(error), listExisting bool) error {
  1451  	s, err := nl.SubscribeAt(newNs, curNs, unix.NETLINK_ROUTE, unix.RTNLGRP_IPV4_ROUTE, unix.RTNLGRP_IPV6_ROUTE)
  1452  	if err != nil {
  1453  		return err
  1454  	}
  1455  	if done != nil {
  1456  		go func() {
  1457  			<-done
  1458  			s.Close()
  1459  		}()
  1460  	}
  1461  	if listExisting {
  1462  		req := pkgHandle.newNetlinkRequest(unix.RTM_GETROUTE,
  1463  			unix.NLM_F_DUMP)
  1464  		infmsg := nl.NewIfInfomsg(unix.AF_UNSPEC)
  1465  		req.AddData(infmsg)
  1466  		if err := s.Send(req); err != nil {
  1467  			return err
  1468  		}
  1469  	}
  1470  	go func() {
  1471  		defer close(ch)
  1472  		for {
  1473  			msgs, from, err := s.Receive()
  1474  			if err != nil {
  1475  				if cberr != nil {
  1476  					cberr(fmt.Errorf("Receive failed: %v",
  1477  						err))
  1478  				}
  1479  				return
  1480  			}
  1481  			if from.Pid != nl.PidKernel {
  1482  				if cberr != nil {
  1483  					cberr(fmt.Errorf("Wrong sender portid %d, expected %d", from.Pid, nl.PidKernel))
  1484  				}
  1485  				continue
  1486  			}
  1487  			for _, m := range msgs {
  1488  				if m.Header.Type == unix.NLMSG_DONE {
  1489  					continue
  1490  				}
  1491  				if m.Header.Type == unix.NLMSG_ERROR {
  1492  					error := int32(native.Uint32(m.Data[0:4]))
  1493  					if error == 0 {
  1494  						continue
  1495  					}
  1496  					if cberr != nil {
  1497  						cberr(fmt.Errorf("error message: %v",
  1498  							syscall.Errno(-error)))
  1499  					}
  1500  					continue
  1501  				}
  1502  				route, err := deserializeRoute(m.Data)
  1503  				if err != nil {
  1504  					if cberr != nil {
  1505  						cberr(err)
  1506  					}
  1507  					continue
  1508  				}
  1509  				ch <- RouteUpdate{Type: m.Header.Type, Route: route}
  1510  			}
  1511  		}
  1512  	}()
  1513  
  1514  	return nil
  1515  }
  1516  
  1517  func (p RouteProtocol) String() string {
  1518  	switch int(p) {
  1519  	case unix.RTPROT_BABEL:
  1520  		return "babel"
  1521  	case unix.RTPROT_BGP:
  1522  		return "bgp"
  1523  	case unix.RTPROT_BIRD:
  1524  		return "bird"
  1525  	case unix.RTPROT_BOOT:
  1526  		return "boot"
  1527  	case unix.RTPROT_DHCP:
  1528  		return "dhcp"
  1529  	case unix.RTPROT_DNROUTED:
  1530  		return "dnrouted"
  1531  	case unix.RTPROT_EIGRP:
  1532  		return "eigrp"
  1533  	case unix.RTPROT_GATED:
  1534  		return "gated"
  1535  	case unix.RTPROT_ISIS:
  1536  		return "isis"
  1537  	//case unix.RTPROT_KEEPALIVED:
  1538  	//	return "keepalived"
  1539  	case unix.RTPROT_KERNEL:
  1540  		return "kernel"
  1541  	case unix.RTPROT_MROUTED:
  1542  		return "mrouted"
  1543  	case unix.RTPROT_MRT:
  1544  		return "mrt"
  1545  	case unix.RTPROT_NTK:
  1546  		return "ntk"
  1547  	case unix.RTPROT_OSPF:
  1548  		return "ospf"
  1549  	case unix.RTPROT_RA:
  1550  		return "ra"
  1551  	case unix.RTPROT_REDIRECT:
  1552  		return "redirect"
  1553  	case unix.RTPROT_RIP:
  1554  		return "rip"
  1555  	case unix.RTPROT_STATIC:
  1556  		return "static"
  1557  	case unix.RTPROT_UNSPEC:
  1558  		return "unspec"
  1559  	case unix.RTPROT_XORP:
  1560  		return "xorp"
  1561  	case unix.RTPROT_ZEBRA:
  1562  		return "zebra"
  1563  	default:
  1564  		return strconv.Itoa(int(p))
  1565  	}
  1566  }