github.com/siglens/siglens@v0.0.0-20240328180423-f7ce9ae441ed/pkg/ast/pipesearch/wsSearchHandler.go

/*
Copyright 2023.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/

package pipesearch

import (
	"fmt"
	"math"
	"time"

	"github.com/dustin/go-humanize"
	"github.com/fasthttp/websocket"
	rutils "github.com/siglens/siglens/pkg/readerUtils"
	"github.com/siglens/siglens/pkg/segment"
	"github.com/siglens/siglens/pkg/segment/query"
	"github.com/siglens/siglens/pkg/segment/structs"
	segutils "github.com/siglens/siglens/pkg/segment/utils"
	"github.com/siglens/siglens/pkg/utils"
	log "github.com/sirupsen/logrus"
	"github.com/valyala/fasthttp"
)

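// ProcessPipeSearchWebsocket runs a single search query over a websocket
// connection. The first client message must be a JSON object with
// "state":"query"; its other fields (the search parameters read by
// ParseSearchBody, plus an optional "queryLanguage") drive the query. As an
// illustration only (field names besides "state" and "queryLanguage" are not
// defined in this file), the initial message might look like:
//
//	{"state": "query", "searchText": "...", "queryLanguage": "Splunk QL"}
//
// The handler then streams RUNNING and QUERY_UPDATE state messages, followed
// by a COMPLETE, TIMEOUT, or error message. A client message with
// "state":"cancel" cancels the running query.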
func ProcessPipeSearchWebsocket(conn *websocket.Conn, orgid uint64, ctx *fasthttp.RequestCtx) {

	qid := rutils.GetNextQid()
	event, err := readInitialEvent(qid, conn)
	defer utils.DeferableAddAccessLogEntry(
		time.Now(),
		func() time.Time { return time.Now() },
		"No-user", // TODO : Add logged in user when user auth is implemented
		ctx.Request.URI().String(),
		fmt.Sprintf("%+v", event),
		func() int { return ctx.Response.StatusCode() },
		true, // Log this even though it's a websocket connection
		"access.log",
	)

	if err != nil {
		log.Errorf("qid=%d, ProcessPipeSearchWebsocket: Failed to read initial event %+v!", qid, err)
		wErr := conn.WriteJSON(createErrorResponse(err.Error()))
		if wErr != nil {
			log.Errorf("qid=%d, ProcessPipeSearchWebsocket: failed to write error response to websocket! %+v", qid, wErr)
		}
		return
	}
	eventState, ok := event["state"]
	if !ok {
		log.Errorf("qid=%d, ProcessPipeSearchWebsocket: first request does not have 'state' as a key!", qid)
		wErr := conn.WriteJSON(createErrorResponse("request missing required key 'state'"))
		if wErr != nil {
			log.Errorf("qid=%d, ProcessPipeSearchWebsocket: failed to write error response to websocket! %+v", qid, wErr)
		}
		return
	}
	if eventState != "query" {
		log.Errorf("qid=%d, ProcessPipeSearchWebsocket: first request is not a query 'state'!", qid)
		wErr := conn.WriteJSON(createErrorResponse("first request should have 'state':'query'"))
		if wErr != nil {
			log.Errorf("qid=%d, ProcessPipeSearchWebsocket: failed to write error response to websocket! %+v", qid, wErr)
		}
		return
	}

	nowTs := utils.GetCurrentTimeInMs()
	searchText, startEpoch, endEpoch, sizeLimit, indexNameIn, scrollFrom := ParseSearchBody(event, nowTs)

	if scrollFrom > 10_000 {
		processMaxScrollComplete(conn, qid)
		return
	}

	ti := structs.InitTableInfo(indexNameIn, orgid, false)
	log.Infof("qid=%v, ProcessPipeSearchWebsocket: index=[%v] searchString=[%v] scrollFrom=[%v]",
		qid, ti.String(), searchText, scrollFrom)

	queryLanguageType := event["queryLanguage"]
	var simpleNode *structs.ASTNode
	var aggs *structs.QueryAggregators

	switch queryLanguageType {
	case "SQL", "Pipe QL", "Log QL", "Splunk QL":
		simpleNode, aggs, err = ParseRequest(searchText, startEpoch, endEpoch, qid, queryLanguageType.(string), indexNameIn)
	default:
		log.Infof("ProcessPipeSearchWebsocket: unknown queryLanguageType: %v; using Pipe QL instead", queryLanguageType)
		simpleNode, aggs, err = ParseRequest(searchText, startEpoch, endEpoch, qid, "Pipe QL", indexNameIn)
	}

	if err != nil {
		log.Errorf("qid=%d, ProcessPipeSearchWebsocket: failed to parse query err=%v", qid, err)
		wErr := conn.WriteJSON(createErrorResponse(err.Error()))
		if wErr != nil {
			log.Errorf("qid=%d, ProcessPipeSearchWebsocket: failed to write error response to websocket! %+v", qid, wErr)
		}
		return
	}

	if queryLanguageType == "SQL" && aggs != nil && aggs.TableName != "*" {
		indexNameIn = aggs.TableName
		ti = structs.InitTableInfo(indexNameIn, orgid, false) // Re-initialize ti with the updated indexNameIn
	}

	if aggs != nil && (aggs.GroupByRequest != nil || aggs.MeasureOperations != nil) {
		sizeLimit = 0
	} else if aggs != nil && (aggs.HasDedupBlockInChain() || aggs.HasSortBlockInChain() || aggs.HasRexBlockInChainWithStats() || aggs.HasTransactionArgumentsInChain()) {
		// 1. Dedup needs state information about the previous records, so we can
		// run into an issue if we show some records, then the user scrolls
		// down to see more and we run dedup on just the new records and add
		// them to the existing ones. To get around this, we can run the query
		// on all of the records initially so that scrolling down doesn't cause
		// another query to run.
		// 2. Sort is similar to Dedup; we need to process all the records at once
		// and extract those with top/rare priority based on the requirements.
		// 3. If there's a Rex block in the chain followed by a Stats block, we need to
		// see all the matched records before we apply or calculate the stats.
		sizeLimit = math.MaxUint64
	}

	// If MaxRows is used to limit the number of returned results, set `sizeLimit`
	// to it. Currently MaxRows is only valid as the root QueryAggregators.
	if aggs != nil && aggs.Limit != 0 {
		sizeLimit = uint64(aggs.Limit)
	}

	qc := structs.InitQueryContextWithTableInfo(ti, sizeLimit, scrollFrom, orgid, false)
	eventC, err := segment.ExecuteAsyncQuery(simpleNode, aggs, qid, qc)
	if err != nil {
		log.Errorf("qid=%d, ProcessPipeSearchWebsocket: failed to execute query err=%v", qid, err)
		wErr := conn.WriteJSON(createErrorResponse(err.Error()))
		if wErr != nil {
			log.Errorf("qid=%d, ProcessPipeSearchWebsocket: failed to write error response to websocket! %+v", qid, wErr)
		}
		return
	}
	websocketR := make(chan map[string]interface{})
	go listenToConnection(qid, websocketR, conn)
	for {
		select {
		case qscd, ok := <-eventC:
			if !ok {
				log.Errorf("qid=%v, ProcessPipeSearchWebsocket: query state channel was closed unexpectedly", qid)
				return
			}
			switch qscd.StateName {
			case query.RUNNING:
				processRunningUpdate(conn, qid)
			case query.QUERY_UPDATE:
				processQueryUpdate(conn, qid, sizeLimit, scrollFrom, qscd, aggs)
			case query.TIMEOUT:
				processTimeoutUpdate(conn, qid)
				return
			case query.COMPLETE:
				processCompleteUpdate(conn, sizeLimit, qid, aggs)
				query.DeleteQuery(qid)
				return
			default:
				log.Errorf("qid=%v, ProcessPipeSearchWebsocket: got unknown state %v", qid, qscd.StateName)
			}
		case readMsg := <-websocketR:
			log.Infof("qid=%d, Got message from websocket %+v", qid, readMsg)
			if readMsg["state"] == "cancel" {
				query.CancelQuery(qid)
			}
		}
	}
}

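// listenToConnection reads JSON messages from the websocket in a loop and
// forwards each one on the channel e. When the connection is closed or a read
// fails, it sends a synthetic {"state":"cancel"} event and returns.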
func listenToConnection(qid uint64, e chan map[string]interface{}, conn *websocket.Conn) {
	for {
		readEvent := make(map[string]interface{})
		err := conn.ReadJSON(&readEvent)
		if err != nil {
			if websocket.IsUnexpectedCloseError(err,
				websocket.CloseGoingAway, websocket.CloseNormalClosure) {
				log.Errorf("qid=%d, listenToConnection unexpected error: %+v", qid, err.Error())
			}
			cancelEvent := map[string]interface{}{"state": "cancel", "message": "websocket connection is closed"}
			e <- cancelEvent
			return
		}
		e <- readEvent
	}
}

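// readInitialEvent reads the first JSON message of a websocket search session.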
func readInitialEvent(qid uint64, conn *websocket.Conn) (map[string]interface{}, error) {
	readEvent := make(map[string]interface{})
	err := conn.ReadJSON(&readEvent)
	if err != nil {
		log.Errorf("qid=%d, readInitialEvent: Failed to read initial event from websocket: %v", qid, err)
		return readEvent, err
	}

	log.Infof("qid=%d, Read initial event from websocket %+v", qid, readEvent)
	return readEvent, nil
}

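// createErrorResponse builds the {"state":"error", "message":...} payload that
// is written to the client when a query fails.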
func createErrorResponse(errMsg string) map[string]interface{} {
	e := map[string]interface{}{
		"state":   "error",
		"message": errMsg,
	}
	return e
}

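// processTimeoutUpdate tells the client that the query timed out after
// query.CANCEL_QUERY_AFTER_SECONDS seconds.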
func processTimeoutUpdate(conn *websocket.Conn, qid uint64) {
	e := map[string]interface{}{
		"state":          query.TIMEOUT.String(),
		"qid":            qid,
		"timeoutSeconds": fmt.Sprintf("%v", query.CANCEL_QUERY_AFTER_SECONDS),
	}
	err := conn.WriteJSON(e)
	if err != nil {
		log.Errorf("qid=%d, processTimeoutUpdate: failed to write to websocket! %+v", qid, err)
	}
}

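// processRunningUpdate tells the client that the query is still running.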
func processRunningUpdate(conn *websocket.Conn, qid uint64) {
	e := map[string]interface{}{
		"state": query.RUNNING.String(),
		"qid":   qid,
	}
	wErr := conn.WriteJSON(e)
	if wErr != nil {
		log.Errorf("qid=%d, processRunningUpdate: failed to write running response to websocket! %+v", qid, wErr)
	}
}

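// processQueryUpdate sends an incremental QUERY_UPDATE response with the rows
// or aggregation buckets produced so far, the completion percentage, and the
// number of events searched.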
func processQueryUpdate(conn *websocket.Conn, qid uint64, sizeLimit uint64, scrollFrom int, qscd *query.QueryStateChanData,
	aggs *structs.QueryAggregators) {
	searchPercent := qscd.PercentComplete
	totalEventsSearched, err := query.GetTotalsRecsSearchedForQid(qid)
	if err != nil {
		log.Errorf("qid=%d, processQueryUpdate: failed to get total records searched: %+v", qid, err)
		wErr := conn.WriteJSON(createErrorResponse(err.Error()))
		if wErr != nil {
			log.Errorf("qid=%d, processQueryUpdate: failed to write error response to websocket! %+v", qid, wErr)
		}
		return
	}

	var wsResponse *PipeSearchWSUpdateResponse
	if qscd.QueryUpdate == nil {
		log.Errorf("qid=%d, processQueryUpdate: got nil query update!", qid)
		wErr := conn.WriteJSON(createErrorResponse("got nil query update"))
		if wErr != nil {
			log.Errorf("qid=%d, processQueryUpdate: failed to write error response to websocket! %+v", qid, wErr)
		}
		return
	}

	wsResponse, err = createRecsWsResp(qid, sizeLimit, searchPercent, scrollFrom, totalEventsSearched, qscd.QueryUpdate, aggs)
	if err != nil {
		wErr := conn.WriteJSON(createErrorResponse(err.Error()))
		if wErr != nil {
			log.Errorf("qid=%d, processQueryUpdate: failed to write error response to websocket! %+v", qid, wErr)
		}
		return
	}

	wErr := conn.WriteJSON(wsResponse)
	if wErr != nil {
		log.Errorf("qid=%d, processQueryUpdate: failed to write update response to websocket! %+v", qid, wErr)
	}
}

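// processCompleteUpdate sends the final COMPLETE response for a query: total
// matches, measure and group-by results, bucket count, and whether the client
// can scroll for more results. If the query accumulated search errors, an
// error response is sent instead.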
func processCompleteUpdate(conn *websocket.Conn, sizeLimit, qid uint64, aggs *structs.QueryAggregators) {
	queryC := query.GetQueryCountInfoForQid(qid)
	totalEventsSearched, err := query.GetTotalsRecsSearchedForQid(qid)
	if err != nil {
		log.Errorf("qid=%d, processCompleteUpdate: failed to get total records searched: %+v", qid, err)
	}
	numRRCs, err := query.GetNumMatchedRRCs(qid)
	if err != nil {
		log.Errorf("qid=%d, processCompleteUpdate: failed to get number of RRCs for qid! Error: %v", qid, err)
	}

	aggMeasureRes, aggMeasureFunctions, aggGroupByCols, bucketCount := query.GetMeasureResultsForQid(qid, true, 0, aggs.BucketLimit)

	var canScrollMore bool
	if numRRCs == sizeLimit {
		// If the number of RRCs is exactly equal to the requested size, there may
		// be more than sizeLimit matches, so the client can scroll for more.
		canScrollMore = true
	}
	queryType := query.GetQueryType(qid)
	resp := &PipeSearchCompleteResponse{
		TotalMatched:        convertQueryCountToTotalResponse(queryC),
		State:               query.COMPLETE.String(),
		TotalEventsSearched: humanize.Comma(int64(totalEventsSearched)),
		CanScrollMore:       canScrollMore,
		TotalRRCCount:       numRRCs,
		MeasureResults:      aggMeasureRes,
		MeasureFunctions:    aggMeasureFunctions,
		GroupByCols:         aggGroupByCols,
		Qtype:               queryType.String(),
		BucketCount:         bucketCount,
		IsTimechart:         aggs.UsedByTimechart(),
	}
	searchErrors, err := query.GetUniqueSearchErrors(qid)
	if err != nil {
		log.Errorf("qid=%d, processCompleteUpdate: failed to get search errors for qid! Error: %v", qid, err)
	} else if searchErrors == "" {
		wErr := conn.WriteJSON(resp)
		if wErr != nil {
			log.Errorf("qid=%d, processCompleteUpdate: failed to write complete response to websocket! %+v", qid, wErr)
		}
	} else {
		wErr := conn.WriteJSON(createErrorResponse(searchErrors))
		if wErr != nil {
			log.Errorf("qid=%d, processCompleteUpdate: failed to write error response to websocket! %+v", qid, wErr)
		}
	}
}

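// processMaxScrollComplete tells the client that the maximum scroll depth
// (10,000 records) has been reached and no further results will be returned.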
func processMaxScrollComplete(conn *websocket.Conn, qid uint64) {
	resp := &PipeSearchCompleteResponse{
		CanScrollMore: false,
	}
	qType := query.GetQueryType(qid)
	resp.Qtype = qType.String()
	wErr := conn.WriteJSON(resp)
	if wErr != nil {
		log.Errorf("qid=%d, processMaxScrollComplete: failed to write complete response to websocket! %+v", qid, wErr)
	}
}

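// createRecsWsResp builds a QUERY_UPDATE websocket response. For stats and
// group-by queries it attaches the measure results gathered so far; for raw
// record (RRC) queries it attaches the matching raw log rows, unless MaxRows
// is set and the search has not finished yet.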
func createRecsWsResp(qid uint64, sizeLimit uint64, searchPercent float64, scrollFrom int,
	totalEventsSearched uint64, qUpdate *query.QueryUpdate, aggs *structs.QueryAggregators) (*PipeSearchWSUpdateResponse, error) {

	qType := query.GetQueryType(qid)
	wsResponse := &PipeSearchWSUpdateResponse{
		Completion:               searchPercent,
		State:                    query.QUERY_UPDATE.String(),
		TotalEventsSearched:      humanize.Comma(int64(totalEventsSearched)),
		Qtype:                    qType.String(),
		SortByTimestampAtDefault: !aggs.HasSortBlockInChain(),
	}

	switch qType {
	case structs.SegmentStatsCmd, structs.GroupByCmd:
		if aggs.Next == nil { // We'll do chained aggs after all segments are searched.
			var doPull bool
			if qUpdate.RemoteID != "" {
				doPull = true
			}
			aggMeasureRes, aggMeasureFunctions, aggGroupByCols, bucketCount := query.GetMeasureResultsForQid(qid, doPull, qUpdate.SegKeyEnc, aggs.BucketLimit)
			wsResponse.MeasureResults = aggMeasureRes
			wsResponse.MeasureFunctions = aggMeasureFunctions
			wsResponse.GroupByCols = aggGroupByCols
			wsResponse.Qtype = qType.String()
			wsResponse.BucketCount = bucketCount
		}
	case structs.RRCCmd:
		useAnySegKey := false
		if aggs.OutputTransforms != nil && aggs.OutputTransforms.MaxRows != 0 {
			// When MaxRows limits the output, don't show any rows until the
			// search has completed; otherwise we might show a row and later
			// find another row with higher priority, pushing the displayed
			// row out of the top MaxRows.
			if searchPercent < 100 {
				break
			}

			sizeLimit = uint64(aggs.OutputTransforms.MaxRows)

			useAnySegKey = true
		}

		inrrcs, qc, segencmap, err := query.GetRawRecordInfoForQid(scrollFrom, qid)
		if err != nil {
			log.Errorf("qid=%d, createRecsWsResp: failed to get rrcs %v", qid, err)
			return nil, err
		}

		// Convert the matched RRCs into JSON rows; the local path also filters
		// out the rrcs that don't match the segkey.
		var allJson []map[string]interface{}
		var allCols []string
		if qUpdate.QUpdate == query.QUERY_UPDATE_REMOTE {
			// handle remote
			allJson, allCols, err = query.GetRemoteRawLogInfo(qUpdate.RemoteID, inrrcs, qid)
			if err != nil {
				log.Errorf("qid=%d, createRecsWsResp: failed to get remote raw logs and columns: %+v", qid, err)
				return nil, err
			}
		} else {
			// handle local
			allJson, allCols, err = getRawLogsAndColumns(inrrcs, qUpdate.SegKeyEnc, useAnySegKey, sizeLimit, segencmap, aggs, qid)
			if err != nil {
				log.Errorf("qid=%d, createRecsWsResp: failed to get raw logs and columns: %+v", qid, err)
				return nil, err
			}
		}

		wsResponse.Hits = PipeSearchResponse{
			Hits:         allJson,
			TotalMatched: qc,
		}
		wsResponse.AllPossibleColumns = allCols
		wsResponse.Qtype = qType.String()
	}
	return wsResponse, nil
}

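// getRawLogsAndColumns keeps the local RRCs that belong to the given segment
// key encoding (or all local RRCs when anySegKey is set) and converts them to
// JSON rows along with the list of columns they contain.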
func getRawLogsAndColumns(inrrcs []*segutils.RecordResultContainer, skEnc uint16, anySegKey bool, sizeLimit uint64,
	segencmap map[uint16]string, aggs *structs.QueryAggregators, qid uint64) ([]map[string]interface{}, []string, error) {
	found := uint64(0)
	rrcs := make([]*segutils.RecordResultContainer, len(inrrcs))
	for i := 0; i < len(inrrcs); i++ {
		if !inrrcs[i].SegKeyInfo.IsRemote && (anySegKey || inrrcs[i].SegKeyInfo.SegKeyEnc == skEnc) {
			rrcs[found] = inrrcs[i]
			found++
		}
	}
	rrcs = rrcs[:found]
	allJson, allCols, err := convertRRCsToJSONResponse(rrcs, sizeLimit, qid, segencmap, aggs)
	if err != nil {
		log.Errorf("qid=%d, getRawLogsAndColumns: failed to convert rrcs to json: %+v", qid, err)
		return nil, nil, err
	}
	return allJson, allCols, nil
}