github.com/zhuohuang-hust/src-cbuild@v0.0.0-20230105071821-c7aab3e7c840/mergeCode/libnetwork/ipvs/netlink.go (about)

     1  // +build linux
     2  
     3  package ipvs
     4  
     5  import (
     6  	"bytes"
     7  	"encoding/binary"
     8  	"fmt"
     9  	"net"
    10  	"os/exec"
    11  	"strings"
    12  	"sync"
    13  	"syscall"
    14  	"unsafe"
    15  
    16  	"github.com/Sirupsen/logrus"
    17  	"github.com/vishvananda/netlink/nl"
    18  	"github.com/vishvananda/netns"
    19  )
    20  
    21  var (
    22  	native     = nl.NativeEndian()
    23  	ipvsFamily int
    24  	ipvsOnce   sync.Once
    25  )
    26  
    27  type genlMsgHdr struct {
    28  	cmd      uint8
    29  	version  uint8
    30  	reserved uint16
    31  }
    32  
    33  type ipvsFlags struct {
    34  	flags uint32
    35  	mask  uint32
    36  }
    37  
    38  func deserializeGenlMsg(b []byte) (hdr *genlMsgHdr) {
    39  	return (*genlMsgHdr)(unsafe.Pointer(&b[0:unsafe.Sizeof(*hdr)][0]))
    40  }
    41  
    42  func (hdr *genlMsgHdr) Serialize() []byte {
    43  	return (*(*[unsafe.Sizeof(*hdr)]byte)(unsafe.Pointer(hdr)))[:]
    44  }
    45  
    46  func (hdr *genlMsgHdr) Len() int {
    47  	return int(unsafe.Sizeof(*hdr))
    48  }
    49  
    50  func (f *ipvsFlags) Serialize() []byte {
    51  	return (*(*[unsafe.Sizeof(*f)]byte)(unsafe.Pointer(f)))[:]
    52  }
    53  
    54  func (f *ipvsFlags) Len() int {
    55  	return int(unsafe.Sizeof(*f))
    56  }
    57  
    58  func setup() {
    59  	ipvsOnce.Do(func() {
    60  		var err error
    61  		if out, err := exec.Command("modprobe", "-va", "ip_vs").CombinedOutput(); err != nil {
    62  			logrus.Warnf("Running modprobe ip_vs failed with message: `%s`, error: %v", strings.TrimSpace(string(out)), err)
    63  		}
    64  
    65  		ipvsFamily, err = getIPVSFamily()
    66  		if err != nil {
    67  			logrus.Errorf("Could not get ipvs family information from the kernel. It is possible that ipvs is not enabled in your kernel. Native loadbalancing will not work until this is fixed.")
    68  		}
    69  	})
    70  }
    71  
    72  func fillService(s *Service) nl.NetlinkRequestData {
    73  	cmdAttr := nl.NewRtAttr(ipvsCmdAttrService, nil)
    74  	nl.NewRtAttrChild(cmdAttr, ipvsSvcAttrAddressFamily, nl.Uint16Attr(s.AddressFamily))
    75  	if s.FWMark != 0 {
    76  		nl.NewRtAttrChild(cmdAttr, ipvsSvcAttrFWMark, nl.Uint32Attr(s.FWMark))
    77  	} else {
    78  		nl.NewRtAttrChild(cmdAttr, ipvsSvcAttrProtocol, nl.Uint16Attr(s.Protocol))
    79  		nl.NewRtAttrChild(cmdAttr, ipvsSvcAttrAddress, rawIPData(s.Address))
    80  
    81  		// Port needs to be in network byte order.
    82  		portBuf := new(bytes.Buffer)
    83  		binary.Write(portBuf, binary.BigEndian, s.Port)
    84  		nl.NewRtAttrChild(cmdAttr, ipvsSvcAttrPort, portBuf.Bytes())
    85  	}
    86  
    87  	nl.NewRtAttrChild(cmdAttr, ipvsSvcAttrSchedName, nl.ZeroTerminated(s.SchedName))
    88  	if s.PEName != "" {
    89  		nl.NewRtAttrChild(cmdAttr, ipvsSvcAttrPEName, nl.ZeroTerminated(s.PEName))
    90  	}
    91  
    92  	f := &ipvsFlags{
    93  		flags: s.Flags,
    94  		mask:  0xFFFFFFFF,
    95  	}
    96  	nl.NewRtAttrChild(cmdAttr, ipvsSvcAttrFlags, f.Serialize())
    97  	nl.NewRtAttrChild(cmdAttr, ipvsSvcAttrTimeout, nl.Uint32Attr(s.Timeout))
    98  	nl.NewRtAttrChild(cmdAttr, ipvsSvcAttrNetmask, nl.Uint32Attr(s.Netmask))
    99  	return cmdAttr
   100  }
   101  
   102  func fillDestinaton(d *Destination) nl.NetlinkRequestData {
   103  	cmdAttr := nl.NewRtAttr(ipvsCmdAttrDest, nil)
   104  
   105  	nl.NewRtAttrChild(cmdAttr, ipvsDestAttrAddress, rawIPData(d.Address))
   106  	// Port needs to be in network byte order.
   107  	portBuf := new(bytes.Buffer)
   108  	binary.Write(portBuf, binary.BigEndian, d.Port)
   109  	nl.NewRtAttrChild(cmdAttr, ipvsDestAttrPort, portBuf.Bytes())
   110  
   111  	nl.NewRtAttrChild(cmdAttr, ipvsDestAttrForwardingMethod, nl.Uint32Attr(d.ConnectionFlags&ConnectionFlagFwdMask))
   112  	nl.NewRtAttrChild(cmdAttr, ipvsDestAttrWeight, nl.Uint32Attr(uint32(d.Weight)))
   113  	nl.NewRtAttrChild(cmdAttr, ipvsDestAttrUpperThreshold, nl.Uint32Attr(d.UpperThreshold))
   114  	nl.NewRtAttrChild(cmdAttr, ipvsDestAttrLowerThreshold, nl.Uint32Attr(d.LowerThreshold))
   115  
   116  	return cmdAttr
   117  }
   118  
   119  func (i *Handle) doCmd(s *Service, d *Destination, cmd uint8) error {
   120  	req := newIPVSRequest(cmd)
   121  	req.AddData(fillService(s))
   122  
   123  	if d != nil {
   124  		req.AddData(fillDestinaton(d))
   125  	}
   126  
   127  	if _, err := execute(i.sock, req, 0); err != nil {
   128  		return err
   129  	}
   130  
   131  	return nil
   132  }
   133  
   134  func getIPVSFamily() (int, error) {
   135  	sock, err := nl.GetNetlinkSocketAt(netns.None(), netns.None(), syscall.NETLINK_GENERIC)
   136  	if err != nil {
   137  		return 0, err
   138  	}
   139  
   140  	req := newGenlRequest(genlCtrlID, genlCtrlCmdGetFamily)
   141  	req.AddData(nl.NewRtAttr(genlCtrlAttrFamilyName, nl.ZeroTerminated("IPVS")))
   142  
   143  	msgs, err := execute(sock, req, 0)
   144  	if err != nil {
   145  		return 0, err
   146  	}
   147  
   148  	for _, m := range msgs {
   149  		hdr := deserializeGenlMsg(m)
   150  		attrs, err := nl.ParseRouteAttr(m[hdr.Len():])
   151  		if err != nil {
   152  			return 0, err
   153  		}
   154  
   155  		for _, attr := range attrs {
   156  			switch int(attr.Attr.Type) {
   157  			case genlCtrlAttrFamilyID:
   158  				return int(native.Uint16(attr.Value[0:2])), nil
   159  			}
   160  		}
   161  	}
   162  
   163  	return 0, fmt.Errorf("no family id in the netlink response")
   164  }
   165  
   166  func rawIPData(ip net.IP) []byte {
   167  	family := nl.GetIPFamily(ip)
   168  	if family == nl.FAMILY_V4 {
   169  		return ip.To4()
   170  	}
   171  
   172  	return ip
   173  }
   174  
   175  func newIPVSRequest(cmd uint8) *nl.NetlinkRequest {
   176  	return newGenlRequest(ipvsFamily, cmd)
   177  }
   178  
   179  func newGenlRequest(familyID int, cmd uint8) *nl.NetlinkRequest {
   180  	req := nl.NewNetlinkRequest(familyID, syscall.NLM_F_ACK)
   181  	req.AddData(&genlMsgHdr{cmd: cmd, version: 1})
   182  	return req
   183  }
   184  
   185  func execute(s *nl.NetlinkSocket, req *nl.NetlinkRequest, resType uint16) ([][]byte, error) {
   186  	var (
   187  		err error
   188  	)
   189  
   190  	if err := s.Send(req); err != nil {
   191  		return nil, err
   192  	}
   193  
   194  	pid, err := s.GetPid()
   195  	if err != nil {
   196  		return nil, err
   197  	}
   198  
   199  	var res [][]byte
   200  
   201  done:
   202  	for {
   203  		msgs, err := s.Receive()
   204  		if err != nil {
   205  			return nil, err
   206  		}
   207  		for _, m := range msgs {
   208  			if m.Header.Seq != req.Seq {
   209  				return nil, fmt.Errorf("Wrong Seq nr %d, expected %d", m.Header.Seq, req.Seq)
   210  			}
   211  			if m.Header.Pid != pid {
   212  				return nil, fmt.Errorf("Wrong pid %d, expected %d", m.Header.Pid, pid)
   213  			}
   214  			if m.Header.Type == syscall.NLMSG_DONE {
   215  				break done
   216  			}
   217  			if m.Header.Type == syscall.NLMSG_ERROR {
   218  				error := int32(native.Uint32(m.Data[0:4]))
   219  				if error == 0 {
   220  					break done
   221  				}
   222  				return nil, syscall.Errno(-error)
   223  			}
   224  			if resType != 0 && m.Header.Type != resType {
   225  				continue
   226  			}
   227  			res = append(res, m.Data)
   228  			if m.Header.Flags&syscall.NLM_F_MULTI == 0 {
   229  				break done
   230  			}
   231  		}
   232  	}
   233  	return res, nil
   234  }