gitee.com/aurawing/surguard-go@v0.3.1-0.20240409071558-96509a61ecf3/stat.go (about) 1 package main 2 3 import ( 4 "context" 5 "encoding/json" 6 "errors" 7 "fmt" 8 "net/http" 9 "net/http/httputil" 10 "net/url" 11 "os" 12 "os/exec" 13 "strconv" 14 "strings" 15 "sync" 16 "time" 17 18 "gitee.com/aurawing/surguard-go/device" 19 "github.com/labstack/echo" 20 "github.com/labstack/echo/middleware" 21 "github.com/shirou/gopsutil/v3/net" 22 "golang.org/x/net/http2" 23 "golang.org/x/net/http2/h2c" 24 ) 25 26 type StatResp struct { 27 Code int `json:"code"` 28 Msg string `json:"msg"` 29 Data map[string]interface{} `json:"data"` 30 } 31 32 var statSrv http.Server 33 var statWg sync.WaitGroup 34 var statInterfaceName string 35 var statGroupName string 36 var statDev *device.Device 37 38 func startStatServer(interfaceName, portStr string, logger *device.Logger, dev *device.Device) error { 39 statWg.Add(1) 40 statDev = dev 41 groupNameStr := os.Getenv(device.ENV_SG_GROUP_NAME) 42 if groupNameStr == "" { 43 statGroupName = device.DEFAULT_GROUP_NAME 44 } else { 45 statGroupName = groupNameStr 46 } 47 if strings.TrimSpace(interfaceName) == "" { 48 logger.Errorf("interface name is empty") 49 return errors.New("interface name is empty") 50 } 51 statInterfaceName = interfaceName 52 port := int64(device.DEFAULT_LISTEN_PORT) 53 if strings.TrimSpace(portStr) != "" { 54 var err error 55 port, err = strconv.ParseInt(strings.TrimSpace(portStr), 10, 64) 56 if err != nil { 57 logger.Errorf("parse stat server port failed %v", err) 58 return err 59 } 60 } 61 server := echo.New() 62 server.Use(middleware.Logger()) 63 server.Use(middleware.Recover()) 64 server.Use(middleware.GzipWithConfig(middleware.GzipConfig{ 65 Level: 5, 66 })) 67 server.GET("/surguard/netstat", sgNetStat) 68 server.POST("/surguard/bypass", bypass) 69 server.POST("/surguard/disableBypass", disableBypass) 70 71 apiStatPort := port + 1 72 apiStatPortStr := os.Getenv("SG_STAT_PORT") 73 if apiStatPortStr != "" { 74 var err error 75 apiStatPort, err = strconv.ParseInt(strings.TrimSpace(apiStatPortStr), 10, 64) 76 if err != nil { 77 logger.Errorf("parse API stat server port failed %v", err) 78 return err 79 } 80 } 81 proxy := httputil.NewSingleHostReverseProxy(&url.URL{ 82 Scheme: "http", 83 Host: fmt.Sprintf("localhost:%d", apiStatPort), 84 }) 85 server.Any("/surwall/*", echo.WrapHandler(proxy)) 86 87 h2s := &http2.Server{} 88 statSrv = http.Server{ 89 Addr: fmt.Sprintf(":%d", port), 90 Handler: h2c.NewHandler(server, h2s), 91 //ReadTimeout: 30 * time.Second, // customize http.Server timeouts 92 } 93 go func() { 94 if err := statSrv.ListenAndServe(); err != http.ErrServerClosed { 95 logger.Errorf("stat server error: %v", err) 96 } 97 statWg.Done() 98 }() 99 return nil 100 } 101 102 func getNetStats(interfaceName string) (*net.IOCountersStat, error) { 103 netStats, err := net.IOCounters(true) 104 if err != nil { 105 return nil, err 106 } 107 108 for _, stats := range netStats { 109 if stats.Name == interfaceName { 110 return &stats, nil 111 } 112 } 113 114 return nil, fmt.Errorf("interface '%s' not found", interfaceName) 115 } 116 117 func sgNetStat(c echo.Context) error { 118 resp := new(StatResp) 119 stats, err := getNetStats(statInterfaceName) 120 if err != nil { 121 resp.Code = -1 122 resp.Msg = err.Error() 123 resp.Data = map[string]interface{}{} 124 } else { 125 resp.Code = 0 126 resp.Msg = "ok" 127 resp.Data = map[string]interface{}{ 128 "InterfaceName": stats.Name, 129 "BytesSent": stats.BytesSent, 130 "BytesReceived": stats.BytesRecv, 131 "PacketsSent": stats.PacketsSent, 132 "PacketsReceived": stats.PacketsRecv, 133 "ErrorsSent": stats.Errout, 134 "ErrorsReceived": stats.Errin, 135 "DropsSent": stats.Dropout, 136 "DropReceived": stats.Dropin, 137 } 138 } 139 b, err := json.Marshal(resp) 140 if err != nil { 141 resp.Code = -2 142 resp.Msg = err.Error() 143 resp.Data = map[string]interface{}{} 144 } 145 return c.JSONBlob(http.StatusOK, b) 146 } 147 148 func bypass(c echo.Context) error { 149 resp := new(StatResp) 150 if !device.IfBindInterface() { 151 err := runCmd("ip", "route", "flush", "table", statGroupName) 152 if err != nil { 153 resp.Code = -3 154 resp.Msg = err.Error() 155 resp.Data = map[string]interface{}{} 156 } 157 } else { 158 statDev.IterPeerEndpoint(func(s string) { 159 runCmd("route", "delete", s) 160 }) 161 } 162 b, err := json.Marshal(resp) 163 if err != nil { 164 resp.Code = -2 165 resp.Msg = err.Error() 166 resp.Data = map[string]interface{}{} 167 } 168 return c.JSONBlob(http.StatusOK, b) 169 } 170 171 func disableBypass(c echo.Context) error { 172 resp := new(StatResp) 173 if !device.IfBindInterface() { 174 err := runCmd("ip", "route", "add", "default", "dev", statInterfaceName, "table", statGroupName) 175 if err != nil { 176 resp.Code = -4 177 resp.Msg = err.Error() 178 resp.Data = map[string]interface{}{} 179 } 180 } else { 181 statDev.IterPeerEndpoint(func(s string) { 182 runCmd("route", "add", s, "mask", "255.255.255.255", "0.0.0.0", "IF", strconv.Itoa(device.GetDeviceIndex())) 183 }) 184 } 185 b, err := json.Marshal(resp) 186 if err != nil { 187 resp.Code = -2 188 resp.Msg = err.Error() 189 resp.Data = map[string]interface{}{} 190 } 191 return c.JSONBlob(http.StatusOK, b) 192 } 193 194 func stopStatServer(logger *device.Logger) { 195 timeoutCtx, cancel := context.WithTimeout(context.Background(), 5*time.Second) 196 defer cancel() 197 defer statWg.Wait() 198 // 关闭 HTTP 服务器 199 if err := statSrv.Shutdown(timeoutCtx); err != nil { 200 logger.Errorf("Stat server shutdown error: %v", err) 201 } 202 } 203 204 func runCmd(args ...string) error { 205 cmd := exec.Command(args[0], args[1:]...) 206 cmd.Stderr = os.Stderr 207 cmd.Stdout = os.Stdout 208 cmd.Stdin = os.Stdin 209 err := cmd.Run() 210 return err 211 }