github.com/dolthub/go-mysql-server@v0.18.0/internal/sockstate/netstat_linux.go (about)

     1  // Copyright 2020-2021 Dolthub, Inc.
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //     http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License.
    14  
    15  //go:build linux
    16  // +build linux
    17  
    18  package sockstate
    19  
    20  // Taken (simplified and with utility functions added) from https://github.com/cakturk/go-netstat
    21  
    22  import (
    23  	"bufio"
    24  	"bytes"
    25  	"encoding/binary"
    26  	"fmt"
    27  	"io"
    28  	"log"
    29  	"net"
    30  	"os"
    31  	"path"
    32  	"strconv"
    33  	"strings"
    34  
    35  	"github.com/sirupsen/logrus"
    36  )
    37  
    38  const (
    39  	pathTCP4Tab = "/proc/net/tcp"
    40  	pathTCP6Tab = "/proc/net/tcp6"
    41  	ipv4StrLen  = 8
    42  	ipv6StrLen  = 32
    43  )
    44  
    45  type procFd struct {
    46  	base  string
    47  	pid   int
    48  	sktab []sockTabEntry
    49  	p     *process
    50  }
    51  
    52  const sockPrefix = "socket:["
    53  
    54  func getProcName(s []byte) string {
    55  	i := bytes.Index(s, []byte("("))
    56  	if i < 0 {
    57  		return ""
    58  	}
    59  	j := bytes.LastIndex(s, []byte(")"))
    60  	if i < 0 {
    61  		return ""
    62  	}
    63  	if i > j {
    64  		return ""
    65  	}
    66  	return string(s[i+1 : j])
    67  }
    68  
    69  func (p *procFd) iterFdDir() {
    70  	// link name is of the form socket:[5860846]
    71  	fddir := path.Join(p.base, "/fd")
    72  	fi, err := os.ReadDir(fddir)
    73  	if err != nil {
    74  		return
    75  	}
    76  	var buf [128]byte
    77  
    78  	for _, file := range fi {
    79  		fd := path.Join(fddir, file.Name())
    80  		lname, err := os.Readlink(fd)
    81  		if err != nil {
    82  			continue
    83  		}
    84  
    85  		for i := range p.sktab {
    86  			sk := &p.sktab[i]
    87  			ss := sockPrefix + sk.Ino + "]"
    88  			if ss != lname {
    89  				continue
    90  			}
    91  			if p.p == nil {
    92  				stat, err := os.Open(path.Join(p.base, "stat"))
    93  				if err != nil {
    94  					return
    95  				}
    96  				n, err := stat.Read(buf[:])
    97  				_ = stat.Close()
    98  				if err != nil {
    99  					return
   100  				}
   101  				z := bytes.SplitN(buf[:n], []byte(" "), 3)
   102  				name := getProcName(z[1])
   103  				p.p = &process{p.pid, name}
   104  			}
   105  			sk.Process = p.p
   106  		}
   107  	}
   108  }
   109  
   110  func extractProcInfo(sktab []sockTabEntry) {
   111  	const basedir = "/proc"
   112  	fi, err := os.ReadDir(basedir)
   113  	if err != nil {
   114  		return
   115  	}
   116  
   117  	for _, file := range fi {
   118  		if !file.IsDir() {
   119  			continue
   120  		}
   121  		pid, err := strconv.Atoi(file.Name())
   122  		if err != nil {
   123  			continue
   124  		}
   125  		base := path.Join(basedir, file.Name())
   126  		proc := procFd{base: base, pid: pid, sktab: sktab}
   127  		proc.iterFdDir()
   128  	}
   129  }
   130  
   131  func parseIPv4(s string) (net.IP, error) {
   132  	v, err := strconv.ParseUint(s, 16, 32)
   133  	if err != nil {
   134  		return nil, err
   135  	}
   136  	ip := make(net.IP, net.IPv4len)
   137  	binary.LittleEndian.PutUint32(ip, uint32(v))
   138  	return ip, nil
   139  }
   140  
   141  func parseIPv6(s string) (net.IP, error) {
   142  	ip := make(net.IP, net.IPv6len)
   143  	const grpLen = 4
   144  	i, j := 0, 4
   145  	for len(s) != 0 {
   146  		grp := s[0:8]
   147  		u, err := strconv.ParseUint(grp, 16, 32)
   148  		binary.LittleEndian.PutUint32(ip[i:j], uint32(u))
   149  		if err != nil {
   150  			return nil, err
   151  		}
   152  		i, j = i+grpLen, j+grpLen
   153  		s = s[8:]
   154  	}
   155  	return ip, nil
   156  }
   157  
   158  func parseAddr(s string) (*sockAddr, error) {
   159  	fields := strings.Split(s, ":")
   160  	if len(fields) < 2 {
   161  		return nil, fmt.Errorf("sockstate: not enough fields: %v", s)
   162  	}
   163  	var ip net.IP
   164  	var err error
   165  	switch len(fields[0]) {
   166  	case ipv4StrLen:
   167  		ip, err = parseIPv4(fields[0])
   168  	case ipv6StrLen:
   169  		ip, err = parseIPv6(fields[0])
   170  	default:
   171  		log.Fatal("Badly formatted connection address:", s)
   172  	}
   173  	if err != nil {
   174  		return nil, err
   175  	}
   176  	v, err := strconv.ParseUint(fields[1], 16, 16)
   177  	if err != nil {
   178  		return nil, err
   179  	}
   180  	return &sockAddr{IP: ip, Port: uint16(v)}, nil
   181  }
   182  
   183  func parseSocktab(r io.Reader, accept AcceptFn) ([]sockTabEntry, error) {
   184  	br := bufio.NewScanner(r)
   185  	tab := make([]sockTabEntry, 0, 4)
   186  
   187  	// Discard title
   188  	br.Scan()
   189  
   190  	for br.Scan() {
   191  		var e sockTabEntry
   192  		line := br.Text()
   193  		// Skip comments
   194  		if i := strings.Index(line, "#"); i >= 0 {
   195  			line = line[:i]
   196  		}
   197  		fields := strings.Fields(line)
   198  		if len(fields) < 12 {
   199  			return nil, fmt.Errorf("sockstate: not enough fields: %v, %v", len(fields), fields)
   200  		}
   201  		addr, err := parseAddr(fields[1])
   202  		if err != nil {
   203  			return nil, err
   204  		}
   205  		e.LocalAddr = addr
   206  		addr, err = parseAddr(fields[2])
   207  		if err != nil {
   208  			return nil, err
   209  		}
   210  		e.RemoteAddr = addr
   211  		u, err := strconv.ParseUint(fields[3], 16, 8)
   212  		if err != nil {
   213  			return nil, err
   214  		}
   215  		e.State = skState(u)
   216  		u, err = strconv.ParseUint(fields[7], 10, 32)
   217  		if err != nil {
   218  			return nil, err
   219  		}
   220  		e.UID = uint32(u)
   221  		e.Ino = fields[9]
   222  		if accept(&e) {
   223  			tab = append(tab, e)
   224  		}
   225  	}
   226  	return tab, br.Err()
   227  }
   228  
   229  // This net stat code appears to be broken when running a linux binary under the Windows Subsystem for Linux (WSL). If
   230  // we detect we are running on WSL, disable the TCP socket check, as we do on Windows and Darwin.
   231  var isWSL = false
   232  var isProcBlocked = false
   233  
   234  func init() {
   235  	osRelease, err := os.ReadFile("/proc/sys/kernel/osrelease")
   236  	if err == nil {
   237  		osReleaseString := strings.ToLower(string(osRelease))
   238  		if strings.Contains(osReleaseString, "microsoft") {
   239  			isWSL = true
   240  		}
   241  	} else {
   242  		logrus.Warnf("Could not read /proc/sys/kernel/osrelease: %s", err.Error())
   243  		isProcBlocked = true
   244  	}
   245  }
   246  
   247  // tcpSocks returns a slice of active TCP sockets containing only those
   248  // elements that satisfy the accept function
   249  func tcpSocks(accept AcceptFn) ([]sockTabEntry, error) {
   250  	if isWSL || isProcBlocked {
   251  		logrus.Warn("Connection checking not implemented for WSL")
   252  		return nil, ErrSocketCheckNotImplemented.New()
   253  	}
   254  
   255  	paths := [2]string{pathTCP4Tab, pathTCP6Tab}
   256  	var allTabs []sockTabEntry
   257  	for _, p := range paths {
   258  		f, err := os.Open(p)
   259  		defer func() {
   260  			_ = f.Close()
   261  		}()
   262  		if os.IsNotExist(err) {
   263  			continue
   264  		}
   265  		if err != nil {
   266  			return nil, err
   267  		}
   268  
   269  		t, err := parseSocktab(f, accept)
   270  		if err != nil {
   271  			return nil, err
   272  		}
   273  		allTabs = append(allTabs, t...)
   274  
   275  	}
   276  	extractProcInfo(allTabs)
   277  	return allTabs, nil
   278  }
   279  
   280  // GetConnInode returns the inode number of an fd.
   281  func GetConnInode(conn *net.TCPConn) (n uint64, err error) {
   282  	fd, err := getConnFd(conn)
   283  	if err != nil {
   284  		return 0, err
   285  	}
   286  
   287  	if isWSL || isProcBlocked {
   288  		return 0, ErrSocketCheckNotImplemented.New()
   289  	}
   290  
   291  	socketStr := fmt.Sprintf("/proc/%d/fd/%d", os.Getpid(), fd)
   292  	socketLnk, err := os.Readlink(socketStr)
   293  	if err != nil {
   294  		return
   295  	}
   296  
   297  	if strings.HasPrefix(socketLnk, sockPrefix) {
   298  		_, err = fmt.Sscanf(socketLnk, sockPrefix+"%d]", &n)
   299  		if err != nil {
   300  			return
   301  		}
   302  	} else {
   303  		err = ErrNoSocketLink.New()
   304  	}
   305  	return
   306  }