github.com/GuanceCloud/cliutils@v1.1.21/dialtesting/websocket.go (about)

     1  // Unless explicitly stated otherwise all files in this repository are licensed
     2  // under the MIT License.
     3  // This product includes software developed at Guance Cloud (https://www.guance.com/).
     4  // Copyright 2021-present Guance, Inc.
     5  
     6  package dialtesting
     7  
     8  import (
     9  	"context"
    10  	"crypto/tls"
    11  	"encoding/base64"
    12  	"encoding/json"
    13  	"fmt"
    14  	"net"
    15  	"net/http"
    16  	"net/url"
    17  	"strings"
    18  	"time"
    19  
    20  	"github.com/GuanceCloud/cliutils"
    21  	"github.com/gorilla/websocket"
    22  )
    23  
// WebsocketResponseTime is a response-time success rule: CheckResult reports a
// failure when the measured request cost is larger than the target.
type WebsocketResponseTime struct {
	// IsContainDNS makes CheckResult add the DNS lookup cost to the request
	// cost before comparing it with the target.
	IsContainDNS bool   `json:"is_contain_dns"`
	// Target is the threshold, written in time.ParseDuration syntax.
	Target       string `json:"target"`

	// targetTime is Target parsed by init; rules whose value is not positive
	// are skipped by CheckResult.
	targetTime time.Duration
}
    30  
// WebsocketSuccess groups the checks that make up one success condition.
// Every populated field is evaluated by CheckResult.
type WebsocketSuccess struct {
	// ResponseTime holds the response-time thresholds.
	ResponseTime    []*WebsocketResponseTime    `json:"response_time,omitempty"`
	// ResponseMessage holds the checks applied to the message read from the server.
	ResponseMessage []*SuccessOption            `json:"response_message,omitempty"`
	// Header maps a response header name to the checks applied to its value.
	Header          map[string][]*SuccessOption `json:"header,omitempty"`
}
    36  
// WebsocketOptRequest carries optional settings for the websocket request.
type WebsocketOptRequest struct {
	// Timeout is a time.ParseDuration string; when empty, init falls back to 30s.
	Timeout string            `json:"timeout,omitempty"`
	// Headers are extra request headers sent with the handshake.
	Headers map[string]string `json:"headers,omitempty"`
}
    41  
// WebsocketOptAuth holds the credentials used to build the Authorization header.
type WebsocketOptAuth struct {
	// basic auth: the header is only generated when both values are non-empty.
	Username string `json:"username,omitempty"`
	Password string `json:"password,omitempty"`
}
    47  
// WebsocketAdvanceOption bundles the optional request and authentication
// settings of a websocket task. Both members may be nil.
type WebsocketAdvanceOption struct {
	RequestOptions *WebsocketOptRequest `json:"request_options,omitempty"`
	Auth           *WebsocketOptAuth    `json:"auth,omitempty"`
}
    52  
// WebsocketTask describes one websocket dial test: it connects to URL, sends
// Message, reads one reply and evaluates the SuccessWhen rules against the
// outcome. The exported fields are the JSON task definition; the unexported
// fields hold state derived by init and the measurements of the latest Run.
type WebsocketTask struct {
	// URL is the websocket endpoint; init parses it and fills in port 443 for
	// "wss" and 80 for "ws" when no port is given.
	URL               string                  `json:"url"`
	// Message is sent to the server as a text message once connected.
	Message           string                  `json:"message"`
	// SuccessWhen lists the success rules; init rejects an empty list.
	SuccessWhen       []*WebsocketSuccess     `json:"success_when"`
	AdvanceOptions    *WebsocketAdvanceOption `json:"advance_options,omitempty"`
	// SuccessWhenLogic selects how checks combine in GetResults: "or" needs one
	// passing check, any other value requires that no check fails.
	SuccessWhenLogic  string                  `json:"success_when_logic"`
	// ExternalID is required by Check and is part of the value returned by ID.
	ExternalID        string                  `json:"external_id"`
	// Name is reported as the "name" tag.
	Name              string                  `json:"name"`
	AK                string                  `json:"access_key"`
	PostURL           string                  `json:"post_url"`
	// CurStatus is compared with StatusStop by init.
	CurStatus         string                  `json:"status"`
	// Frequency is a time.ParseDuration string used as the ticker interval.
	Frequency         string                  `json:"frequency"`
	Region            string                  `json:"region"`
	OwnerExternalID   string                  `json:"owner_external_id"`
	// Tags are copied into the tags returned by GetResults.
	Tags              map[string]string       `json:"tags,omitempty"`
	Labels            []string                `json:"labels,omitempty"`
	UpdateTime        int64                   `json:"update_time,omitempty"`
	WorkspaceLanguage string                  `json:"workspace_language,omitempty"`
	TagsInfo          string                  `json:"tags_info,omitempty"`

	// reqCost is the time spent establishing the websocket connection.
	reqCost         time.Duration
	// reqDNSCost is the time spent resolving the host name, when one had to be resolved.
	reqDNSCost      time.Duration
	// responseMessage is the message read from the server after sending Message.
	responseMessage string
	// resp is the HTTP response of the websocket handshake.
	resp            *http.Response
	// parsedURL and hostname are derived from URL by init.
	parsedURL       *url.URL
	hostname        string
	// reqError is the text of the error hit by the latest Run, if any.
	reqError        string
	// timeout bounds the dial; set by init.
	timeout         time.Duration
	// ticker fires every Frequency; only created by a non-debug init.
	ticker          *time.Ticker
}
    83  
    84  func (t *WebsocketTask) init(debug bool) error {
    85  	t.timeout = 30 * time.Second
    86  	if t.AdvanceOptions != nil {
    87  		if t.AdvanceOptions.RequestOptions != nil && len(t.AdvanceOptions.RequestOptions.Timeout) > 0 {
    88  			if timeout, err := time.ParseDuration(t.AdvanceOptions.RequestOptions.Timeout); err != nil {
    89  				return err
    90  			} else {
    91  				t.timeout = timeout
    92  			}
    93  		}
    94  	}
    95  
    96  	if !debug {
    97  		du, err := time.ParseDuration(t.Frequency)
    98  		if err != nil {
    99  			return err
   100  		}
   101  		if t.ticker != nil {
   102  			t.ticker.Stop()
   103  		}
   104  		t.ticker = time.NewTicker(du)
   105  	}
   106  
   107  	if strings.EqualFold(t.CurStatus, StatusStop) {
   108  		return nil
   109  	}
   110  
   111  	if len(t.SuccessWhen) == 0 {
   112  		return fmt.Errorf(`no any check rule`)
   113  	}
   114  
   115  	for _, checker := range t.SuccessWhen {
   116  		if checker.ResponseTime != nil {
   117  			for _, v := range checker.ResponseTime {
   118  				du, err := time.ParseDuration(v.Target)
   119  				if err != nil {
   120  					return err
   121  				}
   122  				v.targetTime = du
   123  			}
   124  		}
   125  
   126  		for _, vs := range checker.Header {
   127  			for _, v := range vs {
   128  				err := genReg(v)
   129  				if err != nil {
   130  					return err
   131  				}
   132  			}
   133  		}
   134  
   135  		for _, v := range checker.ResponseMessage {
   136  			err := genReg(v)
   137  			if err != nil {
   138  				return err
   139  			}
   140  		}
   141  	}
   142  
   143  	if parsedURL, err := url.Parse(t.URL); err != nil {
   144  		return err
   145  	} else {
   146  		if parsedURL.Port() == "" {
   147  			port := ""
   148  			if parsedURL.Scheme == "wss" {
   149  				port = "443"
   150  			} else if parsedURL.Scheme == "ws" {
   151  				port = "80"
   152  			}
   153  			parsedURL.Host = net.JoinHostPort(parsedURL.Host, port)
   154  		}
   155  		t.parsedURL = parsedURL
   156  		t.hostname = parsedURL.Hostname()
   157  	}
   158  
   159  	return nil
   160  }
   161  
// InitDebug initializes the task in debug mode: the same validation as Init,
// but no scheduling ticker is created.
func (t *WebsocketTask) InitDebug() error {
	return t.init(true)
}
   165  
// Init validates the task and prepares it for scheduled execution, including
// the ticker derived from Frequency.
func (t *WebsocketTask) Init() error {
	return t.init(false)
}
   169  
   170  func (t *WebsocketTask) Check() error {
   171  	if t.ExternalID == "" {
   172  		return fmt.Errorf("external ID missing")
   173  	}
   174  
   175  	if len(t.URL) == 0 {
   176  		return fmt.Errorf("URL should not be empty")
   177  	}
   178  
   179  	return t.Init()
   180  }
   181  
   182  func (t *WebsocketTask) CheckResult() (reasons []string, succFlag bool) {
   183  	for _, chk := range t.SuccessWhen {
   184  		// check response time
   185  		if chk.ResponseTime != nil {
   186  			for _, v := range chk.ResponseTime {
   187  				reqCost := t.reqCost
   188  				if v.IsContainDNS {
   189  					reqCost += t.reqDNSCost
   190  				}
   191  
   192  				if reqCost > v.targetTime && v.targetTime > 0 {
   193  					reasons = append(reasons,
   194  						fmt.Sprintf("response time(%v) larger than %v", reqCost, v.targetTime))
   195  				} else if v.targetTime > 0 {
   196  					succFlag = true
   197  				}
   198  			}
   199  		}
   200  
   201  		// check message
   202  		if chk.ResponseMessage != nil {
   203  			for _, v := range chk.ResponseMessage {
   204  				if err := v.check(t.responseMessage, "response message"); err != nil {
   205  					reasons = append(reasons, err.Error())
   206  				} else {
   207  					succFlag = true
   208  				}
   209  			}
   210  		}
   211  
   212  		// check header
   213  		if t.resp != nil {
   214  			for k, vs := range chk.Header {
   215  				for _, v := range vs {
   216  					if err := v.check(t.resp.Header.Get(k), fmt.Sprintf("Websocket header `%s'", k)); err != nil {
   217  						reasons = append(reasons, err.Error())
   218  					} else {
   219  						succFlag = true
   220  					}
   221  				}
   222  			}
   223  		}
   224  	}
   225  
   226  	return reasons, succFlag
   227  }
   228  
   229  func (t *WebsocketTask) GetResults() (tags map[string]string, fields map[string]interface{}) {
   230  	tags = map[string]string{
   231  		"name":   t.Name,
   232  		"url":    t.URL,
   233  		"status": "FAIL",
   234  		"proto":  "websocket",
   235  	}
   236  
   237  	responseTime := int64(t.reqCost+t.reqDNSCost) / 1000        // us
   238  	responseTimeWithDNS := int64(t.reqCost+t.reqDNSCost) / 1000 // us
   239  
   240  	fields = map[string]interface{}{
   241  		"response_time":          responseTime,
   242  		"response_time_with_dns": responseTimeWithDNS,
   243  		"response_message":       t.responseMessage,
   244  		"sent_message":           t.Message,
   245  		"success":                int64(-1),
   246  	}
   247  
   248  	for k, v := range t.Tags {
   249  		tags[k] = v
   250  	}
   251  
   252  	message := map[string]interface{}{}
   253  
   254  	reasons, succFlag := t.CheckResult()
   255  	if t.reqError != "" {
   256  		reasons = append(reasons, t.reqError)
   257  	}
   258  
   259  	switch t.SuccessWhenLogic {
   260  	case "or":
   261  		if succFlag && t.reqError == "" {
   262  			tags["status"] = "OK"
   263  			fields["success"] = int64(1)
   264  			message["response_time"] = responseTime
   265  		} else {
   266  			message[`fail_reason`] = strings.Join(reasons, `;`)
   267  			fields[`fail_reason`] = strings.Join(reasons, `;`)
   268  		}
   269  	default:
   270  		if len(reasons) != 0 {
   271  			message[`fail_reason`] = strings.Join(reasons, `;`)
   272  			fields[`fail_reason`] = strings.Join(reasons, `;`)
   273  		} else {
   274  			message["response_time"] = responseTime
   275  		}
   276  
   277  		if t.reqError == "" && len(reasons) == 0 {
   278  			tags["status"] = "OK"
   279  			fields["success"] = int64(1)
   280  		}
   281  	}
   282  
   283  	if v, ok := fields[`fail_reason`]; ok && len(v.(string)) != 0 && t.resp != nil {
   284  		message[`response_header`] = t.resp.Header
   285  	}
   286  
   287  	data, err := json.Marshal(message)
   288  	if err != nil {
   289  		fields[`message`] = err.Error()
   290  	}
   291  
   292  	if len(data) > MaxMsgSize {
   293  		fields[`message`] = string(data[:MaxMsgSize])
   294  	} else {
   295  		fields[`message`] = string(data)
   296  	}
   297  
   298  	return tags, fields
   299  }
   300  
// MetricName returns the fixed metric name used for websocket dial tests.
func (t *WebsocketTask) MetricName() string {
	return `websocket_dial_testing`
}
   304  
   305  func (t *WebsocketTask) Clear() {
   306  	t.reqCost = 0
   307  	t.reqError = ""
   308  }
   309  
   310  func (t *WebsocketTask) Run() error {
   311  	t.Clear()
   312  
   313  	ctx, cancel := context.WithTimeout(context.Background(), t.timeout)
   314  	defer cancel()
   315  
   316  	hostIP := net.ParseIP(t.hostname)
   317  
   318  	if hostIP == nil { // host name
   319  		start := time.Now()
   320  		if ips, err := net.LookupIP(t.hostname); err != nil {
   321  			t.reqError = err.Error()
   322  			return err
   323  		} else {
   324  			if len(ips) == 0 {
   325  				err := fmt.Errorf("invalid host: %s, found no ip record", t.hostname)
   326  				t.reqError = err.Error()
   327  				return err
   328  			} else {
   329  				t.reqDNSCost = time.Since(start)
   330  				hostIP = ips[0] // TODO: support mutiple ip for one host
   331  			}
   332  		}
   333  	}
   334  
   335  	header := t.getHeader()
   336  
   337  	if len(header.Get("Host")) == 0 {
   338  		// set default Host
   339  		header.Add("Host", t.hostname)
   340  	}
   341  
   342  	t.parsedURL.Host = net.JoinHostPort(hostIP.String(), t.parsedURL.Port())
   343  
   344  	if t.parsedURL.Scheme == "wss" {
   345  		websocket.DefaultDialer.TLSClientConfig = &tls.Config{InsecureSkipVerify: true} // nolint:gosec
   346  	}
   347  
   348  	start := time.Now()
   349  
   350  	c, resp, err := websocket.DefaultDialer.DialContext(ctx, t.parsedURL.String(), header)
   351  	if err != nil {
   352  		t.reqError = err.Error()
   353  		t.reqDNSCost = 0
   354  		return err
   355  	}
   356  
   357  	t.reqCost = time.Since(start)
   358  	defer func() {
   359  		if err := c.Close(); err != nil {
   360  			_ = err // pass
   361  		}
   362  	}()
   363  
   364  	t.resp = resp
   365  
   366  	t.getMessage(c)
   367  	return nil
   368  }
   369  
   370  func (t *WebsocketTask) getMessage(c *websocket.Conn) {
   371  	err := c.WriteMessage(websocket.TextMessage, []byte(t.Message))
   372  	if err != nil {
   373  		t.reqError = err.Error()
   374  		return
   375  	}
   376  
   377  	if _, message, err := c.ReadMessage(); err != nil {
   378  		t.reqError = err.Error()
   379  		return
   380  	} else {
   381  		t.responseMessage = string(message)
   382  	}
   383  
   384  	// close error ignore
   385  	_ = c.WriteMessage(websocket.CloseMessage, websocket.FormatCloseMessage(websocket.CloseNormalClosure, ""))
   386  }
   387  
   388  func (t *WebsocketTask) getHeader() http.Header {
   389  	var header http.Header = make(http.Header)
   390  
   391  	if t.AdvanceOptions != nil {
   392  		if t.AdvanceOptions.RequestOptions != nil {
   393  			for k, v := range t.AdvanceOptions.RequestOptions.Headers {
   394  				header[k] = []string{v}
   395  			}
   396  
   397  			if t.AdvanceOptions.Auth != nil && len(t.AdvanceOptions.Auth.Username) > 0 && len(t.AdvanceOptions.Auth.Password) > 0 {
   398  				header["Authorization"] = []string{"Basic " + basicAuth(t.AdvanceOptions.Auth.Username, t.AdvanceOptions.Auth.Password)}
   399  			}
   400  		}
   401  	}
   402  
   403  	return header
   404  }
   405  
// Stop is a no-op for websocket tasks; it always returns nil.
func (t *WebsocketTask) Stop() error {
	return nil
}
   409  
// UpdateTimeUs returns the UpdateTime field as stored.
// NOTE(review): the name suggests microseconds; the unit is not established
// in this file.
func (t *WebsocketTask) UpdateTimeUs() int64 {
	return t.UpdateTime
}
   413  
   414  func (t *WebsocketTask) ID() string {
   415  	if t.ExternalID == `` {
   416  		return cliutils.XID("dtst_")
   417  	}
   418  	return fmt.Sprintf("%s_%s", t.AK, t.ExternalID)
   419  }
   420  
// GetOwnerExternalID returns the OwnerExternalID field.
func (t *WebsocketTask) GetOwnerExternalID() string {
	return t.OwnerExternalID
}
   424  
// SetOwnerExternalID stores exid in the OwnerExternalID field.
func (t *WebsocketTask) SetOwnerExternalID(exid string) {
	t.OwnerExternalID = exid
}
   428  
// SetRegionID stores regionID in the Region field.
func (t *WebsocketTask) SetRegionID(regionID string) {
	t.Region = regionID
}
   432  
// SetAk stores ak in the AK (access key) field.
func (t *WebsocketTask) SetAk(ak string) {
	t.AK = ak
}
   436  
// SetStatus stores status in the CurStatus field.
func (t *WebsocketTask) SetStatus(status string) {
	t.CurStatus = status
}
   440  
// SetUpdateTime stores ts in the UpdateTime field.
func (t *WebsocketTask) SetUpdateTime(ts int64) {
	t.UpdateTime = ts
}
   444  
// Status returns the CurStatus field.
func (t *WebsocketTask) Status() string {
	return t.CurStatus
}
   448  
// Ticker returns the scheduling ticker created by a non-debug init; it is nil
// when no such initialization has created one.
func (t *WebsocketTask) Ticker() *time.Ticker {
	return t.ticker
}
   452  
// Class returns the task class constant for websocket tasks.
func (t *WebsocketTask) Class() string {
	return ClassWebsocket
}
   456  
// GetFrequency returns the Frequency field, i.e. the unparsed duration string.
func (t *WebsocketTask) GetFrequency() string {
	return t.Frequency
}
   460  
// GetLineData always returns an empty string for websocket tasks.
func (t *WebsocketTask) GetLineData() string {
	return ""
}
   464  
// RegionName returns the Region field.
func (t *WebsocketTask) RegionName() string {
	return t.Region
}
   468  
// PostURLStr returns the PostURL field.
func (t *WebsocketTask) PostURLStr() string {
	return t.PostURL
}
   472  
// AccessKey returns the AK field.
func (t *WebsocketTask) AccessKey() string {
	return t.AK
}
   476  
// GetHostName passes the raw URL field to the package helper getHostName and
// returns its result unchanged.
// NOTE(review): presumably this extracts the host part of the URL; confirm
// against getHostName, which is defined outside this file.
func (t *WebsocketTask) GetHostName() (string, error) {
	return getHostName(t.URL)
}
   480  
   481  func basicAuth(username, password string) string {
   482  	auth := username + ":" + password
   483  	return base64.StdEncoding.EncodeToString([]byte(auth))
   484  }
   485  
   486  func (t *WebsocketTask) GetWorkspaceLanguage() string {
   487  	if t.WorkspaceLanguage == "en" {
   488  		return "en"
   489  	}
   490  	return "zh"
   491  }
   492  
// GetTagsInfo returns the TagsInfo field.
func (t *WebsocketTask) GetTagsInfo() string {
	return t.TagsInfo
}