
     1  /**
     2   * Tencent is pleased to support the open source community by making Polaris available.
     3   *
     4   * Copyright (C) 2019 THL A29 Limited, a Tencent company. All rights reserved.
     5   *
     6   * Licensed under the BSD 3-Clause License (the "License");
     7   * you may not use this file except in compliance with the License.
     8   * You may obtain a copy of the License at
     9   *
 * https://opensource.org/licenses/BSD-3-Clause
    11   *
    12   * Unless required by applicable law or agreed to in writing, software distributed
    13   * under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR
    14   * CONDITIONS OF ANY KIND, either express or implied. See the License for the
    15   * specific language governing permissions and limitations under the License.
    16   */
    18  package main
    20  import (
    21  	"bytes"
    22  	"context"
    23  	"fmt"
    24  	"io"
    25  	"log"
    26  	"net/http"
    27  	"os"
    28  	"os/signal"
    29  	"strconv"
    30  	"strings"
    31  	"syscall"
    32  	"time"
    34  	""
    35  	""
    36  	""
    37  	""
    38  	""
    39  	""
    40  	""
    41  )
// Runtime configuration, read once from the environment at package
// initialization. Parse errors are discarded, so a missing or malformed
// numeric/duration variable leaves the zero value; setDefault later replaces
// some of those zero values with built-in defaults.
var (
	// ServerAddr is the IP address of the Polaris server, read from SERVER_IP.
	ServerAddr = os.Getenv("SERVER_IP")
	// GRPCPort is the gRPC port of the Polaris server, read from SERVER_PORT.
	// Defaults to 8091.
	GRPCPort, _ = strconv.ParseInt(os.Getenv("SERVER_PORT"), 10, 64)
	// HttpPort is the HTTP port of the Polaris server. Defaults to 8090.
	// NOTE(review): this reads the same SERVER_PORT variable as GRPCPort, so
	// both ports get the same value whenever it is set — confirm whether a
	// dedicated variable was intended.
	HttpPort, _ = strconv.ParseInt(os.Getenv("SERVER_PORT"), 10, 64)
	// RunMode selects the run mode: VERIFY (verification) or BENCHMARK (load test).
	RunMode = os.Getenv("RUN_MODE")
	// Service is the name of the service the instances are registered under.
	Service = os.Getenv("SERVICE")
	// Namespace is the namespace of the service.
	Namespace = os.Getenv("NAMESPACE")
	// BasePort is the first port number used for registered instances.
	BasePort, _ = strconv.ParseInt(os.Getenv("BASE_PORT"), 10, 64)
	// PortNum is how many ports (instances) a single pod registers.
	PortNum, _ = strconv.ParseInt(os.Getenv("PORT_NUM"), 10, 64)
	// BeatInterval is the heartbeat period, in seconds.
	BeatInterval, _ = strconv.ParseInt(os.Getenv("BEAT_INTERVAL"), 10, 64)
	// CheckInterval is the period of the verification task, in
	// time.ParseDuration syntax (for example "30s" or "1m").
	CheckInterval, _ = time.ParseDuration(os.Getenv("CHECK_INTERVAL"))
	// PodIP is the IP address used as the host of registered instances.
	PodIP = os.Getenv("POD_IP")
	// metricsRegistry is the Prometheus registry holding this program's collectors.
	metricsRegistry = prometheus.NewRegistry()
	// heartbeatCount counts the heartbeats reported by this client.
	heartbeatCount = prometheus.NewCounter(prometheus.CounterOpts{
		Name: "client_beat_count",
	})
)
// Built-in fallbacks for the environment-driven configuration, plus the fixed
// port of the metrics endpoint.
const (
	defaultSeverIP       = ""                    // fallback server address
	defaultGrpcPort      = 8091                  // fallback Polaris gRPC port
	defaultHttpPort      = 8090                  // fallback Polaris HTTP port
	defaultBeatInterval  = 5                     // fallback heartbeat period, in seconds (BeatInterval's unit)
	defaultBasePort      = 8080                  // fallback first instance port
	defaultPortNum       = 1                     // fallback number of ports registered per pod
	defaultService       = "benchmark-heartbeat" // fallback service name
	defaultNamesapce     = "benchmark"           // fallback namespace
	metricsPort          = 9090                  // port of the Prometheus metrics HTTP server
	defaultCheckInterval = time.Minute           // fallback period of the verification task
)
    87  func setDefault() {
    88  	if ServerAddr == "" {
    89  		ServerAddr = defaultSeverIP
    90  	}
    91  	if GRPCPort == 0 {
    92  		GRPCPort = defaultGrpcPort
    93  	}
    94  	if HttpPort == 0 {
    95  		HttpPort = defaultHttpPort
    96  	}
    97  	if Service == "" {
    98  		Service = defaultService
    99  	}
   100  	if Namespace == "" {
   101  		Namespace = defaultNamesapce
   102  	}
   103  	if BasePort == 0 {
   104  		BasePort = defaultBasePort
   105  	}
   106  	if PortNum == 0 {
   107  		PortNum = 1
   108  	}
   109  	if CheckInterval == 0 {
   110  		CheckInterval = defaultCheckInterval
   111  	}
   112  	log.Printf("run_mode(%s)", RunMode)
   113  	log.Printf("server_addr(%s)", ServerAddr)
   114  	log.Printf("grpc_port(%d)", GRPCPort)
   115  	log.Printf("http_port(%d)", HttpPort)
   116  	log.Printf("namespace(%s)", Namespace)
   117  	log.Printf("service(%s)", Service)
   118  	log.Printf("base_port(%d)", BasePort)
   119  	log.Printf("port_num(%d)", PortNum)
   120  	log.Printf("check_interval(%v)", CheckInterval)
   121  }
   123  func setMetrics() {
   124  	_ = metricsRegistry.Register(heartbeatCount)
   125  }
   127  func main() {
   128  	setDefault()
   129  	setMetrics()
   130  	switch strings.ToLower(RunMode) {
   131  	case "verify":
   132  		go runVerifyMode()
   133  	case "benchmark":
   134  		go runBenchmarkMode()
   135  	default:
   136  		panic("unknown run mode, please export RUN_MODE=verify or RUN_MODE=benchmark")
   137  	}
   138  	go func() {
   139  		_ = http.ListenAndServe(fmt.Sprintf("", metricsPort),
   140  			promhttp.HandlerFor(metricsRegistry, promhttp.HandlerOpts{EnableOpenMetrics: true}))
   141  	}()
   142  	mainLoop()
   143  }
   145  func runVerifyMode() {
   146  	ticker := time.NewTicker(CheckInterval)
   148  	checkInstanceHealth := func() {
   149  		req := &service_manage.DiscoverRequest{
   150  			Type: service_manage.DiscoverRequest_INSTANCE,
   151  			Service: &service_manage.Service{
   152  				Namespace: wrapperspb.String(Namespace),
   153  				Name:      wrapperspb.String(Service),
   154  			},
   155  		}
   157  		marshaler := jsonpb.Marshaler{}
   158  		body, _ := marshaler.MarshalToString(req)
   159  		rsp, err := http.Post(fmt.Sprintf("http://%s:%d/v1/Discover", ServerAddr, HttpPort), "application/json", bytes.NewBufferString(body))
   160  		if err != nil {
   161  			log.Printf("[ERROR] send discover to server fail: %s", err.Error())
   162  			return
   163  		}
   165  		defer func() {
   166  			_ = rsp.Body.Close()
   167  		}()
   168  		data, _ := io.ReadAll(rsp.Body)
   169  		discoverRsp := &service_manage.DiscoverResponse{}
   170  		if err := jsonpb.Unmarshal(bytes.NewBuffer(data), discoverRsp); err != nil {
   171  			log.Printf("[ERROR] unmarshaler discover resp fail: %s", err.Error())
   172  			return
   173  		}
   174  		if discoverRsp.GetCode().GetValue() != uint32(model.Code_ExecuteSuccess) {
   175  			log.Printf("[ERROR] receive discover resp fail: %s", discoverRsp.GetInfo().GetValue())
   176  			return
   177  		}
   178  		// 检查实例健康状态
   179  	}
   181  	checkInstanceBeatTimestamp := func() {
   183  	}
   185  	for {
   186  		select {
   187  		case <-ticker.C:
   188  			checkInstanceHealth()
   189  			checkInstanceBeatTimestamp()
   190  		}
   191  	}
   192  }
   194  func runBenchmarkMode() {
   195  	conn, err := grpc.DialContext(context.Background(), fmt.Sprintf("%s:%d", ServerAddr, GRPCPort),
   196  		grpc.WithBlock(),
   197  		grpc.WithInsecure(),
   198  	)
   199  	if err != nil {
   200  		panic(err)
   201  	}
   203  	client := service_manage.NewPolarisGRPCClient(conn)
   205  	// 先注册
   206  	for i := 0; i < int(PortNum); i++ {
   207  		instance := &service_manage.Instance{
   208  			Namespace:         wrapperspb.String(Namespace),
   209  			Service:           wrapperspb.String(Service),
   210  			Host:              wrapperspb.String(PodIP),
   211  			Port:              wrapperspb.UInt32(uint32(int(BasePort) + i)),
   212  			EnableHealthCheck: wrapperspb.Bool(true),
   213  			HealthCheck: &service_manage.HealthCheck{
   214  				Type: service_manage.HealthCheck_HEARTBEAT,
   215  				Heartbeat: &service_manage.HeartbeatHealthCheck{
   216  					Ttl: wrapperspb.UInt32(uint32(BeatInterval)),
   217  				},
   218  			},
   219  		}
   221  		resp, err := client.RegisterInstance(context.Background(), instance)
   222  		if err != nil {
   223  			panic(err)
   224  		}
   225  		if resp.GetCode().GetValue() != uint32(model.Code_ExecuteSuccess) {
   226  			panic(resp.GetInfo().GetValue())
   227  		}
   228  		log.Printf("[INFO] instance register success id: %s", instance.GetId().GetValue())
   229  		instance.Id = resp.GetInstance().GetId()
   230  		go func(instance *service_manage.Instance) {
   231  			ticker := time.NewTicker(time.Duration(BeatInterval) * time.Second)
   232  			defer ticker.Stop()
   234  			for range ticker.C {
   235  				heartbeatCount.Inc()
   236  				resp, err := client.Heartbeat(context.Background(), instance)
   237  				if err != nil {
   238  					log.Printf("[ERROR] instance(%s) beat fail error: %s", instance.GetId().GetValue(), err.Error())
   239  				}
   240  				if resp.GetCode().GetValue() != uint32(model.Code_ExecuteSuccess) {
   241  					log.Printf("[ERROR] instance(%s) beat fail info: %s", instance.GetId().GetValue(), resp.GetInfo().GetValue())
   242  				}
   243  			}
   244  		}(instance)
   245  	}
   247  	// 发起任务开始定期 心跳上报
   248  }
   250  // mainLoop 等待信号量执行退出
   251  func mainLoop() {
   252  	ch := make(chan os.Signal, 1)
   254  	// 监听信号量
   255  	signal.Notify(ch, []os.Signal{
   256  		syscall.SIGINT, syscall.SIGTERM,
   257  		syscall.SIGSEGV, syscall.SIGUSR1, syscall.SIGUSR2,
   258  	}...)
   260  	for {
   261  		select {
   262  		case <-ch:
   263  			log.Printf("[INFO] catch signal, stop benchmark server")
   264  			return
   265  		}
   266  	}
   267  }