code.vegaprotocol.io/vega@v0.79.0/datanode/sqlstore/utils.go (about) 1 // Copyright (C) 2023 Gobalsky Labs Limited 2 // 3 // This program is free software: you can redistribute it and/or modify 4 // it under the terms of the GNU Affero General Public License as 5 // published by the Free Software Foundation, either version 3 of the 6 // License, or (at your option) any later version. 7 // 8 // This program is distributed in the hope that it will be useful, 9 // but WITHOUT ANY WARRANTY; without even the implied warranty of 10 // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 11 // GNU Affero General Public License for more details. 12 // 13 // You should have received a copy of the GNU Affero General Public License 14 // along with this program. If not, see <http://www.gnu.org/licenses/>. 15 16 package sqlstore 17 18 import ( 19 "fmt" 20 "reflect" 21 "strconv" 22 "strings" 23 24 "code.vegaprotocol.io/vega/datanode/entities" 25 26 "github.com/georgysavva/scany/dbscan" 27 ) 28 29 // A handy little helper function for building queries. Appends 'value' 30 // to the 'args' slice and returns a string '$N' referring to the index 31 // of the value in args. For example: 32 // 33 // var args []interface{} 34 // query = "select * from foo where id=" + nextBindVar(&args, 100) 35 // db.Query(query, args...) 36 func nextBindVar(args *[]interface{}, value interface{}) string { 37 *args = append(*args, value) 38 return "$" + strconv.Itoa(len(*args)) 39 } 40 41 func orderAndPaginateWithCursor(query string, pagination entities.CursorPagination, cursors CursorQueryParameters, 42 args ...interface{}) (string, []interface{}, 43 ) { 44 var order string 45 46 whereOrAnd := "WHERE" 47 48 if strings.Contains(strings.ToUpper(query), "WHERE") { 49 whereOrAnd = "AND" 50 } 51 52 var cursor string 53 cursor, args = cursors.Where(args...) 54 if cursor != "" { 55 query = fmt.Sprintf("%s %s %s", query, whereOrAnd, cursor) 56 } 57 58 limit := calculateLimit(pagination) 59 60 if limit == 0 { 61 // return everything ordered by the cursor column ordered ascending 62 order = cursors.OrderBy() 63 query = fmt.Sprintf("%s ORDER BY %s", query, order) 64 return query, args 65 } 66 67 order = cursors.OrderBy() 68 query = fmt.Sprintf("%s ORDER BY %s", query, order) 69 query = fmt.Sprintf("%s LIMIT %d", query, limit) 70 71 return query, args 72 } 73 74 func calculateLimit(pagination entities.CursorPagination) int { 75 var limit int32 76 if pagination.HasForward() && pagination.Forward.Limit != nil { 77 limit = *pagination.Forward.Limit + 1 78 if pagination.Forward.HasCursor() { 79 limit = *pagination.Forward.Limit + 2 // +2 to make sure we get the previous and next cursor 80 } 81 } else if pagination.HasBackward() && pagination.Backward.Limit != nil { 82 limit = *pagination.Backward.Limit + 1 83 if pagination.Backward.HasCursor() { 84 limit = *pagination.Backward.Limit + 2 // +2 to make sure we get the previous and next cursor 85 } 86 } 87 88 return int(limit) 89 } 90 91 func extractPaginationInfo(pagination entities.CursorPagination) (Sorting, Compare, string) { 92 var cmp Compare 93 var value string 94 95 sort := ASC 96 97 if pagination.NewestFirst { 98 sort = DESC 99 } 100 101 if pagination.HasForward() { 102 if pagination.Forward.HasCursor() { 103 cmp = GE 104 if pagination.NewestFirst { 105 cmp = LE 106 } 107 value = pagination.Forward.Cursor.Value() 108 } 109 } else if pagination.HasBackward() { 110 sort = DESC 111 112 if pagination.NewestFirst { 113 sort = ASC 114 } 115 116 if pagination.Backward.HasCursor() { 117 cmp = LE 118 if pagination.NewestFirst { 119 cmp = GE 120 } 121 value = pagination.Backward.Cursor.Value() 122 } 123 } 124 125 return sort, cmp, value 126 } 127 128 func extractCursorFromPagination(pagination entities.CursorPagination) (cursor string) { 129 if pagination.HasForward() && pagination.Forward.HasCursor() { 130 cursor = pagination.Forward.Cursor.Value() 131 } else if pagination.HasBackward() && pagination.Backward.HasCursor() { 132 cursor = pagination.Backward.Cursor.Value() 133 } 134 return 135 } 136 137 // StructValueForColumn replicates some of the unexported functionality from Scanny. You pass a 138 // struct (or pointer to a struct), and a column name. It converts the struct field names into 139 // database column names in a similar way to scanny and if one matches colName, that field value 140 // is returned. For example 141 // 142 // type Foo struct { 143 // Thingy int `db:"wotsit"` 144 // SomethingElse int 145 // } 146 // 147 // val, err := StructValueForColumn(foo, "wotsit") -> 1 148 // val, err := StructValueForColumn(&foo, "something_else") -> 2 149 // 150 // NB - not all functionality of scanny is supported (but could be added if needed) 151 // - we don't support embedded structs 152 // - assumes the 'dbTag' is the default 'db' 153 func StructValueForColumn(obj any, colName string) (interface{}, error) { 154 structType := reflect.TypeOf(obj) 155 structValue := reflect.ValueOf(obj) 156 157 if structType.Kind() == reflect.Pointer { 158 structType = structType.Elem() 159 structValue = structValue.Elem() 160 } 161 162 if structType.Kind() != reflect.Struct { 163 return nil, fmt.Errorf("obj must be struct") 164 } 165 166 for i := 0; i < structType.NumField(); i++ { 167 field := structType.Field(i) 168 thisColName := field.Tag.Get("db") 169 if thisColName == "" { 170 thisColName = dbscan.SnakeCaseMapper(field.Name) 171 } 172 if thisColName == colName { 173 fieldValue := structValue.Field(i) 174 return fieldValue.Interface(), nil 175 } 176 } 177 return nil, fmt.Errorf("no field matching column name %s", colName) 178 } 179 180 func filterDateRange(query, dateColumn string, dateRange entities.DateRange, isFirstCondition bool, args ...interface{}) (string, []interface{}) { 181 conditions := []string{} 182 183 if dateRange.Start != nil { 184 conditions = append(conditions, fmt.Sprintf("%s >= %s", dateColumn, nextBindVar(&args, *dateRange.Start))) 185 } 186 187 if dateRange.End != nil { 188 conditions = append(conditions, fmt.Sprintf("%s < %s", dateColumn, nextBindVar(&args, *dateRange.End))) 189 } 190 191 if len(conditions) <= 0 { 192 return query, args 193 } 194 195 finalConditions := strings.Join(conditions, " AND ") 196 if isFirstCondition { 197 query = fmt.Sprintf("%s where %s", query, finalConditions) 198 } else { 199 query = fmt.Sprintf("%s AND %s", query, finalConditions) 200 } 201 202 return query, args 203 }