github.com/TeaOSLab/EdgeNode@v1.3.8/internal/nodes/api_stream.go (about)

     1  package nodes
     2  
     3  import (
     4  	"context"
     5  	"encoding/json"
     6  	"fmt"
     7  	"github.com/TeaOSLab/EdgeCommon/pkg/messageconfigs"
     8  	"github.com/TeaOSLab/EdgeCommon/pkg/rpc/pb"
     9  	"github.com/TeaOSLab/EdgeCommon/pkg/serverconfigs"
    10  	"github.com/TeaOSLab/EdgeNode/internal/caches"
    11  	"github.com/TeaOSLab/EdgeNode/internal/configs"
    12  	teaconst "github.com/TeaOSLab/EdgeNode/internal/const"
    13  	"github.com/TeaOSLab/EdgeNode/internal/errors"
    14  	"github.com/TeaOSLab/EdgeNode/internal/events"
    15  	"github.com/TeaOSLab/EdgeNode/internal/firewalls"
    16  	"github.com/TeaOSLab/EdgeNode/internal/goman"
    17  	"github.com/TeaOSLab/EdgeNode/internal/remotelogs"
    18  	"github.com/TeaOSLab/EdgeNode/internal/rpc"
    19  	executils "github.com/TeaOSLab/EdgeNode/internal/utils/exec"
    20  	"github.com/iwind/TeaGo/Tea"
    21  	"github.com/iwind/TeaGo/maps"
    22  	"net/url"
    23  	"regexp"
    24  	"runtime"
    25  	"strconv"
    26  	"time"
    27  )
    28  
    29  type APIStream struct {
    30  	stream pb.NodeService_NodeStreamClient
    31  
    32  	isQuiting  bool
    33  	cancelFunc context.CancelFunc
    34  }
    35  
    36  func NewAPIStream() *APIStream {
    37  	return &APIStream{}
    38  }
    39  
    40  func (this *APIStream) Start() {
    41  	events.OnKey(events.EventQuit, this, func() {
    42  		this.isQuiting = true
    43  		if this.cancelFunc != nil {
    44  			this.cancelFunc()
    45  		}
    46  	})
    47  	for {
    48  		if this.isQuiting {
    49  			return
    50  		}
    51  		err := this.loop()
    52  		if err != nil {
    53  			if rpc.IsConnError(err) {
    54  				remotelogs.Debug("API_STREAM", err.Error())
    55  			} else {
    56  				remotelogs.Warn("API_STREAM", err.Error())
    57  			}
    58  			time.Sleep(10 * time.Second)
    59  			continue
    60  		}
    61  		time.Sleep(1 * time.Second)
    62  	}
    63  }
    64  
    65  func (this *APIStream) loop() error {
    66  	rpcClient, err := rpc.SharedRPC()
    67  	if err != nil {
    68  		return errors.Wrap(err)
    69  	}
    70  
    71  	ctx, cancelFunc := context.WithCancel(rpcClient.Context())
    72  	this.cancelFunc = cancelFunc
    73  
    74  	defer func() {
    75  		cancelFunc()
    76  	}()
    77  
    78  	nodeStream, err := rpcClient.NodeRPC.NodeStream(ctx)
    79  	if err != nil {
    80  		if this.isQuiting {
    81  			return nil
    82  		}
    83  		return err
    84  	}
    85  	this.stream = nodeStream
    86  
    87  	for {
    88  		if this.isQuiting {
    89  			remotelogs.Println("API_STREAM", "quit")
    90  			break
    91  		}
    92  
    93  		message, streamErr := nodeStream.Recv()
    94  		if streamErr != nil {
    95  			if this.isQuiting {
    96  				remotelogs.Println("API_STREAM", "quit")
    97  				return nil
    98  			}
    99  			return streamErr
   100  		}
   101  
   102  		// 处理消息
   103  		switch message.Code {
   104  		case messageconfigs.MessageCodeConnectedAPINode: // 连接API节点成功
   105  			err = this.handleConnectedAPINode(message)
   106  		case messageconfigs.MessageCodeWriteCache: // 写入缓存
   107  			err = this.handleWriteCache(message)
   108  		case messageconfigs.MessageCodeReadCache: // 读取缓存
   109  			err = this.handleReadCache(message)
   110  		case messageconfigs.MessageCodeStatCache: // 统计缓存
   111  			err = this.handleStatCache(message)
   112  		case messageconfigs.MessageCodeCleanCache: // 清理缓存
   113  			err = this.handleCleanCache(message)
   114  		case messageconfigs.MessageCodeNewNodeTask: // 有新的任务
   115  			err = this.handleNewNodeTask(message)
   116  		case messageconfigs.MessageCodeCheckSystemdService: // 检查Systemd服务
   117  			err = this.handleCheckSystemdService(message)
   118  		case messageconfigs.MessageCodeCheckLocalFirewall: // 检查本地防火墙
   119  			err = this.handleCheckLocalFirewall(message)
   120  		case messageconfigs.MessageCodeChangeAPINode: // 修改API节点地址
   121  			err = this.handleChangeAPINode(message)
   122  		default:
   123  			err = this.handleUnknownMessage(message)
   124  		}
   125  		if err != nil {
   126  			remotelogs.Error("API_STREAM", "handle message failed: "+err.Error())
   127  		}
   128  	}
   129  
   130  	return nil
   131  }
   132  
   133  // 连接API节点成功
   134  func (this *APIStream) handleConnectedAPINode(message *pb.NodeStreamMessage) error {
   135  	// 更改连接的APINode信息
   136  	if len(message.DataJSON) == 0 {
   137  		return nil
   138  	}
   139  	msg := &messageconfigs.ConnectedAPINodeMessage{}
   140  	err := json.Unmarshal(message.DataJSON, msg)
   141  	if err != nil {
   142  		return errors.Wrap(err)
   143  	}
   144  
   145  	_, err = rpc.SharedRPC()
   146  	if err != nil {
   147  		return errors.Wrap(err)
   148  	}
   149  
   150  	remotelogs.Println("API_STREAM", "connected to api node '"+strconv.FormatInt(msg.APINodeId, 10)+"'")
   151  
   152  	// 重新读取配置
   153  	if nodeConfigUpdatedAt == 0 {
   154  		select {
   155  		case nodeConfigChangedNotify <- true:
   156  		default:
   157  
   158  		}
   159  	}
   160  
   161  	return nil
   162  }
   163  
   164  // 写入缓存
   165  func (this *APIStream) handleWriteCache(message *pb.NodeStreamMessage) error {
   166  	msg := &messageconfigs.WriteCacheMessage{}
   167  	err := json.Unmarshal(message.DataJSON, msg)
   168  	if err != nil {
   169  		this.replyFail(message.RequestId, "decode message data failed: "+err.Error())
   170  		return err
   171  	}
   172  
   173  	storage, shouldStop, err := this.cacheStorage(message, msg.CachePolicyJSON)
   174  	if err != nil {
   175  		return err
   176  	}
   177  	if shouldStop {
   178  		defer func() {
   179  			storage.Stop()
   180  		}()
   181  	}
   182  
   183  	expiredAt := time.Now().Unix() + msg.LifeSeconds
   184  	writer, err := storage.OpenWriter(msg.Key, expiredAt, 200, -1, int64(len(msg.Value)), -1, false)
   185  	if err != nil {
   186  		this.replyFail(message.RequestId, "prepare writing failed: "+err.Error())
   187  		return err
   188  	}
   189  
   190  	// 写入一个空的Header
   191  	_, err = writer.WriteHeader([]byte(":"))
   192  	if err != nil {
   193  		this.replyFail(message.RequestId, "write failed: "+err.Error())
   194  		return err
   195  	}
   196  
   197  	// 写入数据
   198  	_, err = writer.Write(msg.Value)
   199  	if err != nil {
   200  		this.replyFail(message.RequestId, "write failed: "+err.Error())
   201  		return err
   202  	}
   203  
   204  	err = writer.Close()
   205  	if err != nil {
   206  		this.replyFail(message.RequestId, "write failed: "+err.Error())
   207  		return err
   208  	}
   209  	storage.AddToList(&caches.Item{
   210  		Type:       writer.ItemType(),
   211  		Key:        msg.Key,
   212  		ExpiresAt:  expiredAt,
   213  		HeaderSize: writer.HeaderSize(),
   214  		BodySize:   writer.BodySize(),
   215  	})
   216  
   217  	this.replyOk(message.RequestId, "write ok")
   218  
   219  	return nil
   220  }
   221  
   222  // 读取缓存
   223  func (this *APIStream) handleReadCache(message *pb.NodeStreamMessage) error {
   224  	msg := &messageconfigs.ReadCacheMessage{}
   225  	err := json.Unmarshal(message.DataJSON, msg)
   226  	if err != nil {
   227  		this.replyFail(message.RequestId, "decode message data failed: "+err.Error())
   228  		return err
   229  	}
   230  
   231  	storage, shouldStop, err := this.cacheStorage(message, msg.CachePolicyJSON)
   232  	if err != nil {
   233  		return err
   234  	}
   235  	if shouldStop {
   236  		defer func() {
   237  			storage.Stop()
   238  		}()
   239  	}
   240  
   241  	reader, err := storage.OpenReader(msg.Key, false, false)
   242  	if err != nil {
   243  		if err == caches.ErrNotFound {
   244  			this.replyFail(message.RequestId, "key not found")
   245  			return nil
   246  		}
   247  		this.replyFail(message.RequestId, "read key failed: "+err.Error())
   248  		return nil
   249  	}
   250  	defer func() {
   251  		_ = reader.Close()
   252  	}()
   253  
   254  	this.replyOk(message.RequestId, "value "+strconv.FormatInt(reader.BodySize(), 10)+" bytes")
   255  
   256  	return nil
   257  }
   258  
   259  // 统计缓存
   260  func (this *APIStream) handleStatCache(message *pb.NodeStreamMessage) error {
   261  	msg := &messageconfigs.ReadCacheMessage{}
   262  	err := json.Unmarshal(message.DataJSON, msg)
   263  	if err != nil {
   264  		this.replyFail(message.RequestId, "decode message data failed: "+err.Error())
   265  		return err
   266  	}
   267  
   268  	storage, shouldStop, err := this.cacheStorage(message, msg.CachePolicyJSON)
   269  	if err != nil {
   270  		return err
   271  	}
   272  	if shouldStop {
   273  		defer func() {
   274  			storage.Stop()
   275  		}()
   276  	}
   277  
   278  	stat, err := storage.Stat()
   279  	if err != nil {
   280  		this.replyFail(message.RequestId, "stat failed: "+err.Error())
   281  		return err
   282  	}
   283  
   284  	sizeFormat := ""
   285  	if stat.Size < (1 << 10) {
   286  		sizeFormat = strconv.FormatInt(stat.Size, 10) + " Bytes"
   287  	} else if stat.Size < (1 << 20) {
   288  		sizeFormat = fmt.Sprintf("%.2f KiB", float64(stat.Size)/(1<<10))
   289  	} else if stat.Size < (1 << 30) {
   290  		sizeFormat = fmt.Sprintf("%.2f MiB", float64(stat.Size)/(1<<20))
   291  	} else {
   292  		sizeFormat = fmt.Sprintf("%.2f GiB", float64(stat.Size)/(1<<30))
   293  	}
   294  	this.replyOk(message.RequestId, "size:"+sizeFormat+", count:"+strconv.Itoa(stat.Count))
   295  
   296  	return nil
   297  }
   298  
   299  // 清理缓存
   300  func (this *APIStream) handleCleanCache(message *pb.NodeStreamMessage) error {
   301  	msg := &messageconfigs.ReadCacheMessage{}
   302  	err := json.Unmarshal(message.DataJSON, msg)
   303  	if err != nil {
   304  		this.replyFail(message.RequestId, "decode message data failed: "+err.Error())
   305  		return err
   306  	}
   307  
   308  	storage, shouldStop, err := this.cacheStorage(message, msg.CachePolicyJSON)
   309  	if err != nil {
   310  		return err
   311  	}
   312  	if shouldStop {
   313  		defer func() {
   314  			storage.Stop()
   315  		}()
   316  	}
   317  
   318  	err = storage.CleanAll()
   319  	if err != nil {
   320  		this.replyFail(message.RequestId, "clean cache failed: "+err.Error())
   321  		return err
   322  	}
   323  
   324  	this.replyOk(message.RequestId, "ok")
   325  
   326  	return nil
   327  }
   328  
   329  // 处理配置变化
   330  func (this *APIStream) handleNewNodeTask(message *pb.NodeStreamMessage) error {
   331  	select {
   332  	case nodeTaskNotify <- true:
   333  	default:
   334  
   335  	}
   336  	this.replyOk(message.RequestId, "ok")
   337  	return nil
   338  }
   339  
   340  // 检查Systemd服务
   341  func (this *APIStream) handleCheckSystemdService(message *pb.NodeStreamMessage) error {
   342  	systemctl, err := executils.LookPath("systemctl")
   343  	if err != nil {
   344  		this.replyFail(message.RequestId, "'systemctl' not found")
   345  		return nil
   346  	}
   347  	if len(systemctl) == 0 {
   348  		this.replyFail(message.RequestId, "'systemctl' not found")
   349  		return nil
   350  	}
   351  
   352  	var shortName = teaconst.SystemdServiceName
   353  	var cmd = executils.NewTimeoutCmd(10*time.Second, systemctl, "is-enabled", shortName)
   354  	cmd.WithStdout()
   355  	err = cmd.Run()
   356  	if err != nil {
   357  		this.replyFail(message.RequestId, "'systemctl' command error: "+err.Error())
   358  		return nil
   359  	}
   360  	if cmd.Stdout() == "enabled" {
   361  		this.replyOk(message.RequestId, "ok")
   362  	} else {
   363  		this.replyFail(message.RequestId, "not installed")
   364  	}
   365  	return nil
   366  }
   367  
   368  // 检查本地防火墙
   369  func (this *APIStream) handleCheckLocalFirewall(message *pb.NodeStreamMessage) error {
   370  	var dataMessage = &messageconfigs.CheckLocalFirewallMessage{}
   371  	err := json.Unmarshal(message.DataJSON, dataMessage)
   372  	if err != nil {
   373  		this.replyFail(message.RequestId, "decode message data failed: "+err.Error())
   374  		return nil
   375  	}
   376  
   377  	// nft
   378  	if dataMessage.Name == "nftables" {
   379  		if runtime.GOOS != "linux" {
   380  			this.replyFail(message.RequestId, "not Linux system")
   381  			return nil
   382  		}
   383  
   384  		nft, err := executils.LookPath("nft")
   385  		if err != nil {
   386  			this.replyFail(message.RequestId, "'nft' not found: "+err.Error())
   387  			return nil
   388  		}
   389  
   390  		var cmd = executils.NewTimeoutCmd(10*time.Second, nft, "--version")
   391  		cmd.WithStdout()
   392  		err = cmd.Run()
   393  		if err != nil {
   394  			this.replyFail(message.RequestId, "get version failed: "+err.Error())
   395  			return nil
   396  		}
   397  
   398  		var outputString = cmd.Stdout()
   399  		var versionMatches = regexp.MustCompile(`nftables v([\d.]+)`).FindStringSubmatch(outputString)
   400  		if len(versionMatches) <= 1 {
   401  			this.replyFail(message.RequestId, "can not get nft version")
   402  			return nil
   403  		}
   404  		var version = versionMatches[1]
   405  
   406  		var result = maps.Map{
   407  			"version": version,
   408  		}
   409  
   410  		var protectionConfig = sharedNodeConfig.DDoSProtection
   411  		err = firewalls.SharedDDoSProtectionManager.Apply(protectionConfig)
   412  		if err != nil {
   413  			this.replyFail(message.RequestId, dataMessage.Name+" was installed, but apply DDoS protection config failed: "+err.Error())
   414  		} else {
   415  			this.replyOk(message.RequestId, string(result.AsJSON()))
   416  		}
   417  	} else {
   418  		this.replyFail(message.RequestId, "invalid firewall name '"+dataMessage.Name+"'")
   419  	}
   420  
   421  	return nil
   422  }
   423  
   424  // 修改API地址
   425  func (this *APIStream) handleChangeAPINode(message *pb.NodeStreamMessage) error {
   426  	config, err := configs.LoadAPIConfig()
   427  	if err != nil {
   428  		this.replyFail(message.RequestId, "read config error: "+err.Error())
   429  		return nil
   430  	}
   431  
   432  	var messageData = &messageconfigs.ChangeAPINodeMessage{}
   433  	err = json.Unmarshal(message.DataJSON, messageData)
   434  	if err != nil {
   435  		this.replyFail(message.RequestId, "unmarshal message failed: "+err.Error())
   436  		return nil
   437  	}
   438  
   439  	_, err = url.Parse(messageData.Addr)
   440  	if err != nil {
   441  		this.replyFail(message.RequestId, "invalid new api node address: '"+messageData.Addr+"'")
   442  		return nil
   443  	}
   444  
   445  	config.RPCEndpoints = []string{messageData.Addr}
   446  
   447  	// 保存到文件
   448  	err = config.WriteFile(Tea.ConfigFile(configs.ConfigFileName))
   449  	if err != nil {
   450  		this.replyFail(message.RequestId, "save config file failed: "+err.Error())
   451  		return nil
   452  	}
   453  
   454  	this.replyOk(message.RequestId, "")
   455  
   456  	goman.New(func() {
   457  		// 延后生效,防止变更前的API无法读取到状态
   458  		time.Sleep(1 * time.Second)
   459  
   460  		rpcClient, err := rpc.SharedRPC()
   461  		if err != nil {
   462  			remotelogs.Error("API_STREAM", "change rpc endpoint to '"+
   463  				messageData.Addr+"' failed: "+err.Error())
   464  			return
   465  		}
   466  
   467  		rpcClient.Close()
   468  
   469  		err = rpcClient.UpdateConfig(config)
   470  		if err != nil {
   471  			remotelogs.Error("API_STREAM", "change rpc endpoint to '"+
   472  				messageData.Addr+"' failed: "+err.Error())
   473  			return
   474  		}
   475  
   476  		remotelogs.Println("API_STREAM", "change rpc endpoint to '"+
   477  			messageData.Addr+"' successfully")
   478  	})
   479  
   480  	return nil
   481  }
   482  
   483  // 处理未知消息
   484  func (this *APIStream) handleUnknownMessage(message *pb.NodeStreamMessage) error {
   485  	this.replyFail(message.RequestId, "unknown message code '"+message.Code+"'")
   486  	return nil
   487  }
   488  
   489  // 回复失败
   490  func (this *APIStream) replyFail(requestId int64, message string) {
   491  	_ = this.stream.Send(&pb.NodeStreamMessage{RequestId: requestId, IsOk: false, Message: message})
   492  }
   493  
   494  // 回复成功
   495  func (this *APIStream) replyOk(requestId int64, message string) {
   496  	_ = this.stream.Send(&pb.NodeStreamMessage{RequestId: requestId, IsOk: true, Message: message})
   497  }
   498  
   499  // 回复成功并包含数据
   500  func (this *APIStream) replyOkData(requestId int64, message string, dataJSON []byte) {
   501  	_ = this.stream.Send(&pb.NodeStreamMessage{RequestId: requestId, IsOk: true, Message: message, DataJSON: dataJSON})
   502  }
   503  
   504  // 获取缓存存取对象
   505  func (this *APIStream) cacheStorage(message *pb.NodeStreamMessage, cachePolicyJSON []byte) (storage caches.StorageInterface, shouldStop bool, err error) {
   506  	cachePolicy := &serverconfigs.HTTPCachePolicy{}
   507  	err = json.Unmarshal(cachePolicyJSON, cachePolicy)
   508  	if err != nil {
   509  		this.replyFail(message.RequestId, "decode cache policy config failed: "+err.Error())
   510  		return nil, false, err
   511  	}
   512  
   513  	storage = caches.SharedManager.FindStorageWithPolicy(cachePolicy.Id)
   514  	if storage == nil {
   515  		storage = caches.SharedManager.NewStorageWithPolicy(cachePolicy)
   516  		if storage == nil {
   517  			this.replyFail(message.RequestId, "invalid storage type '"+cachePolicy.Type+"'")
   518  			return nil, false, errors.New("invalid storage type '" + cachePolicy.Type + "'")
   519  		}
   520  		err = storage.Init()
   521  		if err != nil {
   522  			this.replyFail(message.RequestId, "storage init failed: "+err.Error())
   523  			return nil, false, err
   524  		}
   525  		shouldStop = true
   526  	}
   527  
   528  	return
   529  }