github.com/polarismesh/polaris@v1.17.8/store/mysql/common.go (about)

     1  /**
     2   * Tencent is pleased to support the open source community by making Polaris available.
     3   *
     4   * Copyright (C) 2019 THL A29 Limited, a Tencent company. All rights reserved.
     5   *
     6   * Licensed under the BSD 3-Clause License (the "License");
     7   * you may not use this file except in compliance with the License.
     8   * You may obtain a copy of the License at
     9   *
    10   * https://opensource.org/licenses/BSD-3-Clause
    11   *
    12   * Unless required by applicable law or agreed to in writing, software distributed
    13   * under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR
    14   * CONDITIONS OF ANY KIND, either express or implied. See the License for the
    15   * specific language governing permissions and limitations under the License.
    16   */
    17  
    18  package sqldb
    19  
    20  import (
    21  	"bytes"
    22  	"database/sql"
    23  	"strings"
    24  	"time"
    25  	"unicode"
    26  
    27  	"github.com/polarismesh/polaris/store"
    28  )
    29  
    30  // QueryHandler is the interface that wraps the basic Query method.
    31  type QueryHandler func(query string, args ...interface{}) (*sql.Rows, error)
    32  
    33  // BatchHandler 批量查询数据的回调函数
    34  type BatchHandler func(objects []interface{}) error
    35  
    36  // BatchQuery 批量查询数据的对外接口
    37  // 每次最多查询200个
    38  func BatchQuery(label string, data []interface{}, handler BatchHandler) error {
    39  	// start := time.Now()
    40  	maxCount := 200
    41  	beg := 0
    42  	remain := len(data)
    43  	if remain == 0 {
    44  		return nil
    45  	}
    46  
    47  	progress := 0
    48  	for {
    49  		if remain > maxCount {
    50  			if err := handler(data[beg : beg+maxCount]); err != nil {
    51  				return err
    52  			}
    53  
    54  			beg += maxCount
    55  			remain -= maxCount
    56  			progress += maxCount
    57  			if progress%20000 == 0 {
    58  				log.Infof("[Store][database][Batch] query (%s) progress(%d / %d)", label, progress, len(data))
    59  			}
    60  		} else {
    61  			if err := handler(data[beg : beg+remain]); err != nil {
    62  				return err
    63  			}
    64  			break
    65  		}
    66  	}
    67  	// log.Infof("[Store][database][Batch] consume time: %v", time.Now().Sub(start))
    68  	return nil
    69  }
    70  
    71  // BatchOperation 批量操作
    72  // @note 每次最多操作100个
    73  func BatchOperation(label string, data []interface{}, handler BatchHandler) error {
    74  	if data == nil {
    75  		return nil
    76  	}
    77  	maxCount := 100
    78  	progress := 0
    79  	for begin := 0; begin < len(data); begin += maxCount {
    80  		end := begin + maxCount
    81  		if end > len(data) {
    82  			end = len(data)
    83  		}
    84  		if err := handler(data[begin:end]); err != nil {
    85  			return err
    86  		}
    87  		progress += end - begin
    88  		if progress%maxCount == 0 {
    89  			log.Infof("[Store][database][Batch] operation (%s) progress(%d/%d)", label, progress, len(data))
    90  		}
    91  	}
    92  	return nil
    93  }
    94  
    95  // queryEntryCount 单独查询count个数的执行函数
    96  func queryEntryCount(conn *BaseDB, str string, args []interface{}) (uint32, error) {
    97  	var count uint32
    98  	var err error
    99  	Retry("queryRow", func() error {
   100  		err = conn.QueryRow(str, args...).Scan(&count)
   101  		return err
   102  	})
   103  	switch {
   104  	case err == sql.ErrNoRows:
   105  		log.Errorf("[Store][database] not found any entry(%s)", str)
   106  		return 0, err
   107  	case err != nil:
   108  		log.Errorf("[Store][database] query entry count(%s) err: %s", str, err.Error())
   109  		return 0, err
   110  	default:
   111  		return count, nil
   112  	}
   113  }
   114  
   115  // aliasFilter2Where 别名查询转换
   116  var aliasFilter2Where = map[string]string{
   117  	"service":         "source.name",
   118  	"namespace":       "source.namespace",
   119  	"alias":           "alias.name",
   120  	"alias_namespace": "alias.namespace",
   121  	"owner":           "alias.owner",
   122  }
   123  
   124  // serviceAliasFilter2Where 别名查询字段转换函数
   125  func serviceAliasFilter2Where(filter map[string]string) map[string]string {
   126  	out := make(map[string]string)
   127  	for k, v := range filter {
   128  		if d, ok := aliasFilter2Where[k]; ok {
   129  			out[d] = v
   130  		} else {
   131  			out[k] = v
   132  		}
   133  	}
   134  
   135  	return out
   136  }
   137  
   138  // checkDataBaseAffectedRows 检查数据库处理返回的行数
   139  func checkDataBaseAffectedRows(result sql.Result, counts ...int64) error {
   140  	n, err := result.RowsAffected()
   141  	if err != nil {
   142  		log.Errorf("[Store][Database] get rows affected err: %s", err.Error())
   143  		return err
   144  	}
   145  
   146  	for _, c := range counts {
   147  		if n == c {
   148  			return nil
   149  		}
   150  	}
   151  
   152  	log.Errorf("[Store][Database] get rows affected result(%d) is not match expect(%+v)", n, counts)
   153  	return store.NewStatusError(store.AffectedRowsNotMatch, "affected rows not matched")
   154  }
   155  
   156  // timeToTimestamp 转时间戳(秒)
   157  // 由于 FROM_UNIXTIME 不支持负数,所以小于0的情况赋值为 1
   158  func timeToTimestamp(t time.Time) int64 {
   159  	ts := t.Unix()
   160  	if ts < 0 {
   161  		ts = 1
   162  	}
   163  	return ts
   164  }
   165  
   166  func toUnderscoreName(name string) string {
   167  	var buf bytes.Buffer
   168  	for i, token := range name {
   169  		if unicode.IsUpper(token) && i > 0 {
   170  			buf.WriteString("_")
   171  		}
   172  		buf.WriteString(strings.ToLower(string(token)))
   173  	}
   174  	return buf.String()
   175  }