code.vegaprotocol.io/vega@v0.79.0/datanode/sqlstore/utils.go (about)

     1  // Copyright (C) 2023 Gobalsky Labs Limited
     2  //
     3  // This program is free software: you can redistribute it and/or modify
     4  // it under the terms of the GNU Affero General Public License as
     5  // published by the Free Software Foundation, either version 3 of the
     6  // License, or (at your option) any later version.
     7  //
     8  // This program is distributed in the hope that it will be useful,
     9  // but WITHOUT ANY WARRANTY; without even the implied warranty of
    10  // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
    11  // GNU Affero General Public License for more details.
    12  //
    13  // You should have received a copy of the GNU Affero General Public License
    14  // along with this program.  If not, see <http://www.gnu.org/licenses/>.
    15  
    16  package sqlstore
    17  
    18  import (
    19  	"fmt"
    20  	"reflect"
    21  	"strconv"
    22  	"strings"
    23  
    24  	"code.vegaprotocol.io/vega/datanode/entities"
    25  
    26  	"github.com/georgysavva/scany/dbscan"
    27  )
    28  
    29  // A handy little helper function for building queries. Appends 'value'
    30  // to the 'args' slice and returns a string '$N' referring to the index
    31  // of the value in args. For example:
    32  //
    33  //	var args []interface{}
    34  //	query = "select * from foo where id=" + nextBindVar(&args, 100)
    35  //	db.Query(query, args...)
    36  func nextBindVar(args *[]interface{}, value interface{}) string {
    37  	*args = append(*args, value)
    38  	return "$" + strconv.Itoa(len(*args))
    39  }
    40  
    41  func orderAndPaginateWithCursor(query string, pagination entities.CursorPagination, cursors CursorQueryParameters,
    42  	args ...interface{}) (string, []interface{},
    43  ) {
    44  	var order string
    45  
    46  	whereOrAnd := "WHERE"
    47  
    48  	if strings.Contains(strings.ToUpper(query), "WHERE") {
    49  		whereOrAnd = "AND"
    50  	}
    51  
    52  	var cursor string
    53  	cursor, args = cursors.Where(args...)
    54  	if cursor != "" {
    55  		query = fmt.Sprintf("%s %s %s", query, whereOrAnd, cursor)
    56  	}
    57  
    58  	limit := calculateLimit(pagination)
    59  
    60  	if limit == 0 {
    61  		// return everything ordered by the cursor column ordered ascending
    62  		order = cursors.OrderBy()
    63  		query = fmt.Sprintf("%s ORDER BY %s", query, order)
    64  		return query, args
    65  	}
    66  
    67  	order = cursors.OrderBy()
    68  	query = fmt.Sprintf("%s ORDER BY %s", query, order)
    69  	query = fmt.Sprintf("%s LIMIT %d", query, limit)
    70  
    71  	return query, args
    72  }
    73  
    74  func calculateLimit(pagination entities.CursorPagination) int {
    75  	var limit int32
    76  	if pagination.HasForward() && pagination.Forward.Limit != nil {
    77  		limit = *pagination.Forward.Limit + 1
    78  		if pagination.Forward.HasCursor() {
    79  			limit = *pagination.Forward.Limit + 2 // +2 to make sure we get the previous and next cursor
    80  		}
    81  	} else if pagination.HasBackward() && pagination.Backward.Limit != nil {
    82  		limit = *pagination.Backward.Limit + 1
    83  		if pagination.Backward.HasCursor() {
    84  			limit = *pagination.Backward.Limit + 2 // +2 to make sure we get the previous and next cursor
    85  		}
    86  	}
    87  
    88  	return int(limit)
    89  }
    90  
    91  func extractPaginationInfo(pagination entities.CursorPagination) (Sorting, Compare, string) {
    92  	var cmp Compare
    93  	var value string
    94  
    95  	sort := ASC
    96  
    97  	if pagination.NewestFirst {
    98  		sort = DESC
    99  	}
   100  
   101  	if pagination.HasForward() {
   102  		if pagination.Forward.HasCursor() {
   103  			cmp = GE
   104  			if pagination.NewestFirst {
   105  				cmp = LE
   106  			}
   107  			value = pagination.Forward.Cursor.Value()
   108  		}
   109  	} else if pagination.HasBackward() {
   110  		sort = DESC
   111  
   112  		if pagination.NewestFirst {
   113  			sort = ASC
   114  		}
   115  
   116  		if pagination.Backward.HasCursor() {
   117  			cmp = LE
   118  			if pagination.NewestFirst {
   119  				cmp = GE
   120  			}
   121  			value = pagination.Backward.Cursor.Value()
   122  		}
   123  	}
   124  
   125  	return sort, cmp, value
   126  }
   127  
   128  func extractCursorFromPagination(pagination entities.CursorPagination) (cursor string) {
   129  	if pagination.HasForward() && pagination.Forward.HasCursor() {
   130  		cursor = pagination.Forward.Cursor.Value()
   131  	} else if pagination.HasBackward() && pagination.Backward.HasCursor() {
   132  		cursor = pagination.Backward.Cursor.Value()
   133  	}
   134  	return
   135  }
   136  
   137  // StructValueForColumn replicates some of the unexported functionality from Scanny. You pass a
   138  // struct (or pointer to a struct), and a column name. It converts the struct field names into
   139  // database column names in a similar way to scanny and if one matches colName, that field value
   140  // is returned. For example
   141  //
   142  //	type Foo struct {
   143  //		Thingy        int `db:"wotsit"`
   144  //		SomethingElse int
   145  //	}
   146  //
   147  //	val, err := StructValueForColumn(foo, "wotsit")             -> 1
   148  //	val, err := StructValueForColumn(&foo, "something_else")    -> 2
   149  //
   150  // NB - not all functionality of scanny is supported (but could be added if needed)
   151  //   - we don't support embedded structs
   152  //   - assumes the 'dbTag' is the default 'db'
   153  func StructValueForColumn(obj any, colName string) (interface{}, error) {
   154  	structType := reflect.TypeOf(obj)
   155  	structValue := reflect.ValueOf(obj)
   156  
   157  	if structType.Kind() == reflect.Pointer {
   158  		structType = structType.Elem()
   159  		structValue = structValue.Elem()
   160  	}
   161  
   162  	if structType.Kind() != reflect.Struct {
   163  		return nil, fmt.Errorf("obj must be struct")
   164  	}
   165  
   166  	for i := 0; i < structType.NumField(); i++ {
   167  		field := structType.Field(i)
   168  		thisColName := field.Tag.Get("db")
   169  		if thisColName == "" {
   170  			thisColName = dbscan.SnakeCaseMapper(field.Name)
   171  		}
   172  		if thisColName == colName {
   173  			fieldValue := structValue.Field(i)
   174  			return fieldValue.Interface(), nil
   175  		}
   176  	}
   177  	return nil, fmt.Errorf("no field matching column name %s", colName)
   178  }
   179  
   180  func filterDateRange(query, dateColumn string, dateRange entities.DateRange, isFirstCondition bool, args ...interface{}) (string, []interface{}) {
   181  	conditions := []string{}
   182  
   183  	if dateRange.Start != nil {
   184  		conditions = append(conditions, fmt.Sprintf("%s >= %s", dateColumn, nextBindVar(&args, *dateRange.Start)))
   185  	}
   186  
   187  	if dateRange.End != nil {
   188  		conditions = append(conditions, fmt.Sprintf("%s < %s", dateColumn, nextBindVar(&args, *dateRange.End)))
   189  	}
   190  
   191  	if len(conditions) <= 0 {
   192  		return query, args
   193  	}
   194  
   195  	finalConditions := strings.Join(conditions, " AND ")
   196  	if isFirstCondition {
   197  		query = fmt.Sprintf("%s where %s", query, finalConditions)
   198  	} else {
   199  		query = fmt.Sprintf("%s AND %s", query, finalConditions)
   200  	}
   201  
   202  	return query, args
   203  }