gitee.com/aurawing/surguard-go@v0.3.1-0.20240409071558-96509a61ecf3/stat.go (about)

     1  package main
     2  
     3  import (
     4  	"context"
     5  	"encoding/json"
     6  	"errors"
     7  	"fmt"
     8  	"net/http"
     9  	"net/http/httputil"
    10  	"net/url"
    11  	"os"
    12  	"os/exec"
    13  	"strconv"
    14  	"strings"
    15  	"sync"
    16  	"time"
    17  
    18  	"gitee.com/aurawing/surguard-go/device"
    19  	"github.com/labstack/echo"
    20  	"github.com/labstack/echo/middleware"
    21  	"github.com/shirou/gopsutil/v3/net"
    22  	"golang.org/x/net/http2"
    23  	"golang.org/x/net/http2/h2c"
    24  )
    25  
// StatResp is the JSON envelope written by every HTTP handler in this file.
// Code is 0 on success and negative on failure (the handlers use -1 when the
// interface counters cannot be read, -2 when the response cannot be
// JSON-encoded, -3 when the route flush fails and -4 when the route add
// fails). Msg carries a status string or the error text, and Data holds the
// payload.
type StatResp struct {
	Code int                    `json:"code"`
	Msg  string                 `json:"msg"`
	Data map[string]interface{} `json:"data"`
}
    31  
    32  var statSrv http.Server
    33  var statWg sync.WaitGroup
    34  var statInterfaceName string
    35  var statGroupName string
    36  var statDev *device.Device
    37  
    38  func startStatServer(interfaceName, portStr string, logger *device.Logger, dev *device.Device) error {
    39  	statWg.Add(1)
    40  	statDev = dev
    41  	groupNameStr := os.Getenv(device.ENV_SG_GROUP_NAME)
    42  	if groupNameStr == "" {
    43  		statGroupName = device.DEFAULT_GROUP_NAME
    44  	} else {
    45  		statGroupName = groupNameStr
    46  	}
    47  	if strings.TrimSpace(interfaceName) == "" {
    48  		logger.Errorf("interface name is empty")
    49  		return errors.New("interface name is empty")
    50  	}
    51  	statInterfaceName = interfaceName
    52  	port := int64(device.DEFAULT_LISTEN_PORT)
    53  	if strings.TrimSpace(portStr) != "" {
    54  		var err error
    55  		port, err = strconv.ParseInt(strings.TrimSpace(portStr), 10, 64)
    56  		if err != nil {
    57  			logger.Errorf("parse stat server port failed %v", err)
    58  			return err
    59  		}
    60  	}
    61  	server := echo.New()
    62  	server.Use(middleware.Logger())
    63  	server.Use(middleware.Recover())
    64  	server.Use(middleware.GzipWithConfig(middleware.GzipConfig{
    65  		Level: 5,
    66  	}))
    67  	server.GET("/surguard/netstat", sgNetStat)
    68  	server.POST("/surguard/bypass", bypass)
    69  	server.POST("/surguard/disableBypass", disableBypass)
    70  
    71  	apiStatPort := port + 1
    72  	apiStatPortStr := os.Getenv("SG_STAT_PORT")
    73  	if apiStatPortStr != "" {
    74  		var err error
    75  		apiStatPort, err = strconv.ParseInt(strings.TrimSpace(apiStatPortStr), 10, 64)
    76  		if err != nil {
    77  			logger.Errorf("parse API stat server port failed %v", err)
    78  			return err
    79  		}
    80  	}
    81  	proxy := httputil.NewSingleHostReverseProxy(&url.URL{
    82  		Scheme: "http",
    83  		Host:   fmt.Sprintf("localhost:%d", apiStatPort),
    84  	})
    85  	server.Any("/surwall/*", echo.WrapHandler(proxy))
    86  
    87  	h2s := &http2.Server{}
    88  	statSrv = http.Server{
    89  		Addr:    fmt.Sprintf(":%d", port),
    90  		Handler: h2c.NewHandler(server, h2s),
    91  		//ReadTimeout: 30 * time.Second, // customize http.Server timeouts
    92  	}
    93  	go func() {
    94  		if err := statSrv.ListenAndServe(); err != http.ErrServerClosed {
    95  			logger.Errorf("stat server error: %v", err)
    96  		}
    97  		statWg.Done()
    98  	}()
    99  	return nil
   100  }
   101  
   102  func getNetStats(interfaceName string) (*net.IOCountersStat, error) {
   103  	netStats, err := net.IOCounters(true)
   104  	if err != nil {
   105  		return nil, err
   106  	}
   107  
   108  	for _, stats := range netStats {
   109  		if stats.Name == interfaceName {
   110  			return &stats, nil
   111  		}
   112  	}
   113  
   114  	return nil, fmt.Errorf("interface '%s' not found", interfaceName)
   115  }
   116  
   117  func sgNetStat(c echo.Context) error {
   118  	resp := new(StatResp)
   119  	stats, err := getNetStats(statInterfaceName)
   120  	if err != nil {
   121  		resp.Code = -1
   122  		resp.Msg = err.Error()
   123  		resp.Data = map[string]interface{}{}
   124  	} else {
   125  		resp.Code = 0
   126  		resp.Msg = "ok"
   127  		resp.Data = map[string]interface{}{
   128  			"InterfaceName":   stats.Name,
   129  			"BytesSent":       stats.BytesSent,
   130  			"BytesReceived":   stats.BytesRecv,
   131  			"PacketsSent":     stats.PacketsSent,
   132  			"PacketsReceived": stats.PacketsRecv,
   133  			"ErrorsSent":      stats.Errout,
   134  			"ErrorsReceived":  stats.Errin,
   135  			"DropsSent":       stats.Dropout,
   136  			"DropReceived":    stats.Dropin,
   137  		}
   138  	}
   139  	b, err := json.Marshal(resp)
   140  	if err != nil {
   141  		resp.Code = -2
   142  		resp.Msg = err.Error()
   143  		resp.Data = map[string]interface{}{}
   144  	}
   145  	return c.JSONBlob(http.StatusOK, b)
   146  }
   147  
   148  func bypass(c echo.Context) error {
   149  	resp := new(StatResp)
   150  	if !device.IfBindInterface() {
   151  		err := runCmd("ip", "route", "flush", "table", statGroupName)
   152  		if err != nil {
   153  			resp.Code = -3
   154  			resp.Msg = err.Error()
   155  			resp.Data = map[string]interface{}{}
   156  		}
   157  	} else {
   158  		statDev.IterPeerEndpoint(func(s string) {
   159  			runCmd("route", "delete", s)
   160  		})
   161  	}
   162  	b, err := json.Marshal(resp)
   163  	if err != nil {
   164  		resp.Code = -2
   165  		resp.Msg = err.Error()
   166  		resp.Data = map[string]interface{}{}
   167  	}
   168  	return c.JSONBlob(http.StatusOK, b)
   169  }
   170  
   171  func disableBypass(c echo.Context) error {
   172  	resp := new(StatResp)
   173  	if !device.IfBindInterface() {
   174  		err := runCmd("ip", "route", "add", "default", "dev", statInterfaceName, "table", statGroupName)
   175  		if err != nil {
   176  			resp.Code = -4
   177  			resp.Msg = err.Error()
   178  			resp.Data = map[string]interface{}{}
   179  		}
   180  	} else {
   181  		statDev.IterPeerEndpoint(func(s string) {
   182  			runCmd("route", "add", s, "mask", "255.255.255.255", "0.0.0.0", "IF", strconv.Itoa(device.GetDeviceIndex()))
   183  		})
   184  	}
   185  	b, err := json.Marshal(resp)
   186  	if err != nil {
   187  		resp.Code = -2
   188  		resp.Msg = err.Error()
   189  		resp.Data = map[string]interface{}{}
   190  	}
   191  	return c.JSONBlob(http.StatusOK, b)
   192  }
   193  
   194  func stopStatServer(logger *device.Logger) {
   195  	timeoutCtx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
   196  	defer cancel()
   197  	defer statWg.Wait()
   198  	// 关闭 HTTP 服务器
   199  	if err := statSrv.Shutdown(timeoutCtx); err != nil {
   200  		logger.Errorf("Stat server shutdown error: %v", err)
   201  	}
   202  }
   203  
   204  func runCmd(args ...string) error {
   205  	cmd := exec.Command(args[0], args[1:]...)
   206  	cmd.Stderr = os.Stderr
   207  	cmd.Stdout = os.Stdout
   208  	cmd.Stdin = os.Stdin
   209  	err := cmd.Run()
   210  	return err
   211  }