gitee.com/h79/goutils@v1.22.10/dao/es/es.go (about)

     1  package es
     2  
import (
	"context"
	"encoding/json"
	"errors"
	"fmt"
	"sort"
	"strings"

	"github.com/olivere/elastic/v7"
)
    10  
    11  func NotFoundErr(err error) bool {
    12  	if elastic.IsNotFound(err) {
    13  		return true
    14  	}
    15  	e, ok := err.(*elastic.Error)
    16  	if !ok {
    17  		return false
    18  	}
    19  	return e.Details.Reason == "all shards failed"
    20  }
    21  
    22  func defaultMapper[T any](sh *elastic.SearchHit) (t *T, err error) {
    23  	err = json.Unmarshal(sh.Source, &t)
    24  	if err != nil {
    25  		return nil, err
    26  	}
    27  	return t, err
    28  }
    29  
    30  // Search ES搜索
    31  func Search(ctx context.Context, client *elastic.Client, filter *Query, valueFunc func(*elastic.SearchHit)) (int64, error) {
    32  	service, err := BuildQuery(ctx, client, filter)
    33  	if err != nil {
    34  		return 0, err
    35  	}
    36  	if len(filter.Index) == 0 {
    37  		return 0, fmt.Errorf("索引不能为空")
    38  	}
    39  	resp, err := service.Do(ctx)
    40  	if err != nil {
    41  		if NotFoundErr(err) {
    42  			return 0, nil
    43  		}
    44  		return 0, err
    45  	}
    46  
    47  	if resp.TotalHits() == 0 {
    48  		return 0, nil
    49  	}
    50  	for _, hit := range resp.Hits.Hits {
    51  		valueFunc(hit)
    52  	}
    53  	return resp.TotalHits(), nil
    54  }
    55  
    56  // SearchAny ES搜索
    57  func SearchAny[T any](ctx context.Context, client *elastic.Client, filter *Query) ([]*T, int64, error) {
    58  	var list []*T
    59  	total, err := Search(ctx, client, filter, func(eh *elastic.SearchHit) {
    60  		v, err := defaultMapper[T](eh)
    61  		if err != nil {
    62  			return
    63  		}
    64  		list = append(list, v)
    65  	})
    66  	return list, total, err
    67  }
    68  
    69  func BatchAdd[T any](ctx context.Context, client *elastic.Client, index string, datas map[string]T, refresh string) error {
    70  	var bulk = client.Bulk().Index(index)
    71  	for k, v := range datas {
    72  		bulk.Add(elastic.NewBulkCreateRequest().Id(k).Doc(v))
    73  	}
    74  	_, err := bulk.Refresh(refresh).Do(ctx)
    75  	return err
    76  }
    77  
    78  func BuildQuery(ctx context.Context, client *elastic.Client, filter *Query) (*elastic.SearchService, error) {
    79  	boolQuery := elastic.NewBoolQuery()
    80  	boolQuery.Must(filter.MustQuery...)
    81  	boolQuery.MustNot(filter.MustNotQuery...)
    82  	boolQuery.Should(filter.ShouldQuery...)
    83  	boolQuery.Filter(filter.Filters...)
    84  
    85  	// 当should不为空时,保证至少匹配should中的一项
    86  	if len(filter.MustQuery) == 0 && len(filter.MustNotQuery) == 0 && len(filter.ShouldQuery) > 0 {
    87  		boolQuery.MinimumShouldMatch("1")
    88  	}
    89  	service := client.Search().Index(filter.Index).Query(boolQuery)
    90  	if len(filter.Sorters) > 0 {
    91  		service = service.SortBy(filter.Sorters...)
    92  	}
    93  	if filter.Size > 0 {
    94  		if filter.PageIndex < 1 {
    95  			filter.PageIndex = 1
    96  		}
    97  		var from = (filter.PageIndex - 1) * filter.Size
    98  		service = service.From(from).Size(filter.Size)
    99  	} else {
   100  		if filter.MaxWindows == 0 {
   101  			filter.MaxWindows = 5000000
   102  		}
   103  		service = service.Size(filter.MaxWindows)
   104  	}
   105  	return service, nil
   106  }
   107  
   108  func BatchUpdate(ctx context.Context, client *elastic.Client, index string, datas map[string]map[string]interface{}, refresh string) error {
   109  	var bulk = client.Bulk().Index(index)
   110  	for k, v := range datas {
   111  		bulk.Add(elastic.NewBulkUpdateRequest().Id(k).Doc(v))
   112  	}
   113  	_, err := bulk.Refresh(refresh).Do(ctx)
   114  	return err
   115  }
   116  
   117  // UpdateWhere 根据查询更新
   118  func UpdateWhere(ctx context.Context, client *elastic.Client, csr *Query, data map[string]interface{}, refresh string) error {
   119  	if len(csr.Index) == 0 {
   120  		return errors.New("index 不能为空")
   121  	}
   122  	if len(csr.Index) == 0 {
   123  		return errors.New("index 不能为空")
   124  	}
   125  	//ctx._source['name']=params['name']
   126  	var i = 0
   127  	var script = ""
   128  	for k := range data {
   129  		if i > 0 {
   130  			script += ";"
   131  		}
   132  		script += fmt.Sprintf("ctx._source['%s']=params['%s']", k, k)
   133  		i++
   134  	}
   135  	esScript := elastic.NewScript(script).Params(data)
   136  	res, err := client.UpdateByQuery(csr.Index).Query(csr.ToQuery()).Script(esScript).Refresh(refresh).Do(ctx)
   137  	if err != nil {
   138  		return err
   139  	}
   140  	if res.Updated >= 1 && res.Total >= 1 {
   141  		return nil
   142  	}
   143  	return errors.New("not update")
   144  }
   145  
   146  func BatchDelete(ctx context.Context, client *elastic.Client, index string, ids []string, refresh string) error {
   147  	bulk := client.Bulk().Index(index)
   148  	for _, v := range ids {
   149  		bulk.Add(elastic.NewBulkDeleteRequest().Id(v))
   150  	}
   151  	_, err := bulk.Refresh(refresh).Do(ctx)
   152  	if err != nil {
   153  		if elastic.IsNotFound(err) {
   154  			return nil
   155  		}
   156  		return err
   157  	}
   158  	return err
   159  }
   160  
   161  // Agg 统计
   162  func Agg(ctx context.Context, client *elastic.Client, csr *Query, agg elastic.Aggregation, valueFunc func(json.RawMessage) error) (err error) {
   163  	searchService, err := BuildQuery(ctx, client, csr)
   164  	if err != nil {
   165  		return err
   166  	}
   167  	searchResult, err := searchService.Size(0).Aggregation("agg", agg).Do(ctx)
   168  	if err != nil {
   169  		if NotFoundErr(err) {
   170  			return nil
   171  		}
   172  		return err
   173  	}
   174  	v, ok := searchResult.Aggregations["agg"]
   175  	if !ok {
   176  		return nil
   177  	}
   178  	return valueFunc(v)
   179  }
   180  
   181  func BuildIntTermsQuery(name string, list []int64) (query *elastic.TermsQuery) {
   182  	if len(list) == 0 {
   183  		return nil
   184  	}
   185  	l := make([]interface{}, len(list))
   186  	for index, value := range list {
   187  		l[index] = value
   188  	}
   189  	return elastic.NewTermsQuery(name, l...)
   190  }
   191  
   192  func BuildTermsQuery[T any](name string, list []T) (query *elastic.TermsQuery) {
   193  	if len(list) == 0 {
   194  		return nil
   195  	}
   196  	l := make([]interface{}, len(list))
   197  	for index, value := range list {
   198  		l[index] = value
   199  	}
   200  	return elastic.NewTermsQuery(name, l...)
   201  }
   202  
   203  func BuildTermQuery[T any](name string, val T) (query *elastic.TermQuery) {
   204  	return elastic.NewTermQuery(name, val)
   205  }
   206  
   207  func BuildExistsQuery(name string) (query *elastic.ExistsQuery) {
   208  	return elastic.NewExistsQuery(name)
   209  }
   210  
   211  func MustTermsQueryIf[T any](ok bool, search *Query, name string, list []T) *Query {
   212  	if !ok {
   213  		return search
   214  	}
   215  	if len(list) == 0 {
   216  		return search
   217  	}
   218  	search.MustQuery = append(search.MustQuery, BuildTermsQuery(name, list))
   219  	return search
   220  }
   221  
   222  func BuildWildcardQuery(names []string, term string) (query *elastic.BoolQuery) {
   223  	var should = make([]elastic.Query, len(names))
   224  	for i := range names {
   225  		should[i] = elastic.NewWildcardQuery(names[i], fmt.Sprintf("*%v*", term))
   226  	}
   227  	return elastic.NewBoolQuery().Should(should...)
   228  }
   229  
   230  func BuildNestedWildcardQuery(path string, name string, card string) (query *elastic.NestedQuery) {
   231  	return elastic.NewNestedQuery(path, BuildWildcardQuery([]string{name}, card))
   232  }
   233  
   234  func BuildNestedTermsQuery[T any](path string, name string, list []T) (query *elastic.NestedQuery) {
   235  	if len(list) == 0 {
   236  		return nil
   237  	}
   238  	return elastic.NewNestedQuery(path, BuildTermsQuery(name, list))
   239  }
   240  
// SortField describes one sort criterion accepted by Query.Sort.
//
// OrderField is the name of the field to sort on. Desc selects descending
// order when true; buildSort passes !Desc as the ascending flag of the field
// sort it creates.
type SortField struct {
	OrderField string `json:"order_field"`
	Desc       bool   `json:"desc"`
}
   245  
// Query collects everything needed to build a search request.
//
// Fields, as consumed by BuildQuery and ToQuery:
//   - Index: the index the search (or update-by-query) is sent to.
//   - MustQuery, MustNotQuery, ShouldQuery, Filters: the clause lists handed to
//     the bool query's Must, MustNot, Should and Filter respectively.
//   - Sorters: sort criteria applied with SortBy when non-empty.
//   - PageIndex: page number, starting at 1; values below 1 are raised to 1
//     when Size > 0.
//   - Size: page size; paging is applied only when it is greater than 0.
//   - MaxWindows: request size used when Size is not positive; BuildQuery sets
//     it to 5000000 when it is zero.
type Query struct {
	Index        string
	MustQuery    []elastic.Query
	MustNotQuery []elastic.Query
	ShouldQuery  []elastic.Query
	Filters      []elastic.Query
	Sorters      []elastic.Sorter
	PageIndex    int
	Size         int
	MaxWindows   int
}
   257  
   258  func (cs *Query) ToQuery() elastic.Query {
   259  	boolQuery := elastic.NewBoolQuery()
   260  	boolQuery.Must(cs.MustQuery...)
   261  	boolQuery.MustNot(cs.MustNotQuery...)
   262  	boolQuery.Should(cs.ShouldQuery...)
   263  	boolQuery.Filter(cs.Filters...)
   264  
   265  	// 当should不为空时,保证至少匹配should中的一项
   266  	if len(cs.MustQuery) == 0 && len(cs.MustNotQuery) == 0 && len(cs.ShouldQuery) > 0 {
   267  		boolQuery.MinimumShouldMatch("1")
   268  	}
   269  	return boolQuery
   270  }
   271  
   272  // FilterDeleteQuery 过滤删除的
   273  func (cs *Query) FilterDeleteQuery() *Query {
   274  	if cs.Filters == nil {
   275  		cs.Filters = []elastic.Query{}
   276  	}
   277  	cs.Filters = append(cs.Filters, elastic.NewTermQuery("is_delete", 0))
   278  	return cs
   279  }
   280  
   281  func (cs *Query) FilterDeleteQueryIf(filterDel bool) *Query {
   282  	if cs.Filters == nil {
   283  		cs.Filters = []elastic.Query{}
   284  	}
   285  	if !filterDel {
   286  		return cs
   287  	}
   288  	cs.Filters = append(cs.Filters, elastic.NewTermQuery("is_delete", 0))
   289  	return cs
   290  }
   291  
   292  func (cs *Query) MustWildcardQueryIf(ok bool, name string, term string) *Query {
   293  	if !ok {
   294  		return cs
   295  	}
   296  	cs.MustQuery = append(cs.MustQuery, elastic.NewWildcardQuery(name, fmt.Sprintf("*%v*", term)))
   297  	return cs
   298  }
   299  
   300  func (cs *Query) MustTermQueryIf(ok bool, name string, term interface{}) *Query {
   301  	if !ok {
   302  		return cs
   303  	}
   304  	cs.MustQuery = append(cs.MustQuery, elastic.NewTermQuery(name, term))
   305  	return cs
   306  }
   307  
   308  func (cs *Query) MustNotTermQueryIf(ok bool, name string, term interface{}) *Query {
   309  	if !ok {
   310  		return cs
   311  	}
   312  	cs.MustNotQuery = append(cs.MustNotQuery, elastic.NewTermQuery(name, term))
   313  	return cs
   314  }
   315  
   316  // MustBitScriptQueryIf 位运算
   317  func (cs *Query) MustBitScriptQueryIf(ok bool, name string, val int32) *Query {
   318  	if !ok {
   319  		return cs
   320  	}
   321  	cs.MustQuery = append(cs.MustNotQuery, elastic.NewScriptQuery(elastic.NewScript(fmt.Sprintf("(doc['%v'].value&%d)==%d", name, val, val))))
   322  	return cs
   323  }
   324  
   325  func (cs *Query) TryMustTermsQuery(name string, list []string) *Query {
   326  	if len(list) == 0 {
   327  		return cs
   328  	}
   329  	cs.MustQuery = append(cs.MustQuery, BuildTermsQuery(name, list))
   330  	return cs
   331  }
   332  
   333  func (cs *Query) MustTermsQueryIf(ok bool, name string, list []string) *Query {
   334  	return MustTermsQueryIf(ok, cs, name, list)
   335  }
   336  
   337  func (cs *Query) MustTermsUint8QueryIf(ok bool, name string, list []uint8) *Query {
   338  	return MustTermsQueryIf(ok, cs, name, list)
   339  }
   340  
   341  func (cs *Query) MustTermsQueryInt32If(ok bool, name string, list []int32) *Query {
   342  	return MustTermsQueryIf(ok, cs, name, list)
   343  }
   344  
   345  func (cs *Query) MustTermsQueryIntIf(ok bool, name string, list []int) *Query {
   346  	return MustTermsQueryIf(ok, cs, name, list)
   347  }
   348  
   349  func (cs *Query) TryMustNotTermsQuery(name string, list []string) *Query {
   350  	if len(list) == 0 {
   351  		return cs
   352  	}
   353  	cs.MustNotQuery = append(cs.MustNotQuery, BuildTermsQuery(name, list))
   354  	return cs
   355  }
   356  
   357  func (cs *Query) TryMustIntTermsQuery(name string, list []int64) *Query {
   358  	if len(list) == 0 {
   359  		return cs
   360  	}
   361  	cs.MustQuery = append(cs.MustQuery, BuildIntTermsQuery(name, list))
   362  	return cs
   363  }
   364  
   365  func (cs *Query) ExistsQueryIf(ok bool, name string) *Query {
   366  	if !ok {
   367  		return cs
   368  	}
   369  	cs.MustQuery = append(cs.MustQuery, BuildExistsQuery(name))
   370  	return cs
   371  }
   372  
   373  // TryMustNestedStringTermsQuery 如 (path:to,name:to.id)
   374  func (cs *Query) TryMustNestedStringTermsQuery(path string, name string, list []string) *Query {
   375  	if len(list) == 0 {
   376  		return cs
   377  	}
   378  	cs.MustQuery = append(cs.MustQuery, BuildNestedTermsQuery(path, name, list))
   379  	return cs
   380  }
   381  
   382  // TryMustNestedWildcardQuery 如 (path:to,name:to.id)
   383  func (cs *Query) TryMustNestedWildcardQuery(path string, name string, card string) *Query {
   384  	cs.MustQuery = append(cs.MustQuery, BuildNestedWildcardQuery(path, name, card))
   385  	return cs
   386  }
   387  
   388  // TryMustNestedInt32TermsQuery 如 (path:to,name:to.id)
   389  func (cs *Query) TryMustNestedInt32TermsQuery(path string, name string, list []int32) *Query {
   390  	if len(list) == 0 {
   391  		return cs
   392  	}
   393  	cs.MustQuery = append(cs.MustQuery, BuildNestedTermsQuery(path, name, list))
   394  	return cs
   395  }
   396  
   397  func (cs *Query) TryMustRangeQuery(name string, from int64, to int64) *Query {
   398  	if from <= 0 && to <= 0 {
   399  		return cs
   400  	}
   401  	var q = elastic.NewRangeQuery(name)
   402  	if from > 0 {
   403  		q = q.Gte(from)
   404  	}
   405  	if to > 0 {
   406  		q = q.Lte(to)
   407  	}
   408  	cs.MustQuery = append(cs.MustQuery, q)
   409  	return cs
   410  }
   411  
   412  // MustRangeFromQueryIf optType 1 gt 2 gte 3 lt 4 lte
   413  func (cs *Query) MustRangeFromQueryIf(ok bool, name string, from int64, optType int64) *Query {
   414  	if !ok {
   415  		return cs
   416  	}
   417  	var q = elastic.NewRangeQuery(name)
   418  	switch optType {
   419  	case 1:
   420  		q = q.Gt(from)
   421  	case 2:
   422  		q = q.Gte(from)
   423  	case 3:
   424  		q = q.Lt(from)
   425  	case 4:
   426  		q = q.Lte(from)
   427  	}
   428  	cs.MustQuery = append(cs.MustQuery, q)
   429  	return cs
   430  }
   431  
   432  // MustWildcardOrQueryIf 多字段模糊匹配, 用bool query.should 拼接or 条件
   433  func (cs *Query) MustWildcardOrQueryIf(ok bool, names []string, term string) *Query {
   434  	if !ok {
   435  		return cs
   436  	}
   437  	var should = make([]elastic.Query, len(names))
   438  	for i := range names {
   439  		should[i] = elastic.NewWildcardQuery(names[i], fmt.Sprintf("*%v*", term))
   440  	}
   441  	cs.MustQuery = append(cs.MustQuery, elastic.NewBoolQuery().Should(should...))
   442  	return cs
   443  }
   444  
   445  func (cs *Query) Sort(list []*SortField, dft *SortField) *Query {
   446  	if len(list) == 0 {
   447  		if dft == nil {
   448  			return cs
   449  		}
   450  		return cs.buildSort(dft)
   451  	}
   452  	for _, v := range list {
   453  		cs.buildSort(v)
   454  	}
   455  	return cs
   456  }
   457  
   458  func (cs *Query) buildSort(field *SortField) *Query {
   459  	if field == nil {
   460  		return cs
   461  	}
   462  	var fsort = elastic.NewFieldSort(field.OrderField).Order(!field.Desc)
   463  	cs.Sorters = append(cs.Sorters, fsort)
   464  	return cs
   465  }