github.com/meulengracht/snapd@v0.0.0-20210719210640-8bde69bcc84e/asserts/header_checks.go (about) 1 // -*- Mode: Go; indent-tabs-mode: t -*- 2 3 /* 4 * Copyright (C) 2015-2020 Canonical Ltd 5 * 6 * This program is free software: you can redistribute it and/or modify 7 * it under the terms of the GNU General Public License version 3 as 8 * published by the Free Software Foundation. 9 * 10 * This program is distributed in the hope that it will be useful, 11 * but WITHOUT ANY WARRANTY; without even the implied warranty of 12 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 13 * GNU General Public License for more details. 14 * 15 * You should have received a copy of the GNU General Public License 16 * along with this program. If not, see <http://www.gnu.org/licenses/>. 17 * 18 */ 19 20 package asserts 21 22 import ( 23 "crypto" 24 "encoding/base64" 25 "fmt" 26 "regexp" 27 "strconv" 28 "strings" 29 "time" 30 ) 31 32 // common checks used when decoding/assembling assertions 33 34 func checkExistsString(headers map[string]interface{}, name string) (string, error) { 35 return checkExistsStringWhat(headers, name, "header") 36 } 37 38 func checkExistsStringWhat(m map[string]interface{}, name, what string) (string, error) { 39 value, ok := m[name] 40 if !ok { 41 return "", fmt.Errorf("%q %s is mandatory", name, what) 42 } 43 s, ok := value.(string) 44 if !ok { 45 return "", fmt.Errorf("%q %s must be a string", name, what) 46 } 47 return s, nil 48 } 49 50 func checkNotEmptyString(headers map[string]interface{}, name string) (string, error) { 51 return checkNotEmptyStringWhat(headers, name, "header") 52 } 53 54 func checkNotEmptyStringWhat(m map[string]interface{}, name, what string) (string, error) { 55 s, err := checkExistsStringWhat(m, name, what) 56 if err != nil { 57 return "", err 58 } 59 if len(s) == 0 { 60 return "", fmt.Errorf("%q %s should not be empty", name, what) 61 } 62 return s, nil 63 } 64 65 func checkOptionalStringWhat(headers map[string]interface{}, name, what string) (string, error) { 66 value, ok := headers[name] 67 if !ok { 68 return "", nil 69 } 70 s, ok := value.(string) 71 if !ok { 72 return "", fmt.Errorf("%q %s must be a string", name, what) 73 } 74 return s, nil 75 } 76 77 func checkOptionalString(headers map[string]interface{}, name string) (string, error) { 78 return checkOptionalStringWhat(headers, name, "header") 79 } 80 81 func checkPrimaryKey(headers map[string]interface{}, primKey string) (string, error) { 82 value, err := checkNotEmptyString(headers, primKey) 83 if err != nil { 84 return "", err 85 } 86 if strings.Contains(value, "/") { 87 return "", fmt.Errorf("%q primary key header cannot contain '/'", primKey) 88 } 89 return value, nil 90 } 91 92 func checkAssertType(assertType *AssertionType) error { 93 if assertType == nil { 94 return fmt.Errorf("internal error: assertion type cannot be nil") 95 } 96 // sanity check against known canonical 97 sanity := typeRegistry[assertType.Name] 98 switch sanity { 99 case assertType: 100 // fine, matches canonical 101 return nil 102 case nil: 103 return fmt.Errorf("internal error: unknown assertion type: %q", assertType.Name) 104 default: 105 return fmt.Errorf("internal error: unpredefined assertion type for name %q used (unexpected address %p)", assertType.Name, assertType) 106 } 107 } 108 109 // use 'defl' default if missing 110 func checkIntWithDefault(headers map[string]interface{}, name string, defl int) (int, error) { 111 value, ok := headers[name] 112 if !ok { 113 return defl, nil 114 } 115 s, ok := value.(string) 116 if !ok { 117 return -1, fmt.Errorf("%q header is not an integer: %v", name, value) 118 } 119 m, err := atoi(s, "%q %s", name, "header") 120 if err != nil { 121 return -1, err 122 } 123 return m, nil 124 } 125 126 func checkInt(headers map[string]interface{}, name string) (int, error) { 127 return checkIntWhat(headers, name, "header") 128 } 129 130 func checkIntWhat(headers map[string]interface{}, name, what string) (int, error) { 131 valueStr, err := checkNotEmptyStringWhat(headers, name, what) 132 if err != nil { 133 return -1, err 134 } 135 value, err := atoi(valueStr, "%q %s", name, what) 136 if err != nil { 137 return -1, err 138 } 139 return value, nil 140 } 141 142 type intSyntaxError string 143 144 func (e intSyntaxError) Error() string { 145 return string(e) 146 } 147 148 func atoi(valueStr, whichFmt string, whichArgs ...interface{}) (int, error) { 149 value, err := strconv.Atoi(valueStr) 150 if err != nil { 151 which := fmt.Sprintf(whichFmt, whichArgs...) 152 if ne, ok := err.(*strconv.NumError); ok && ne.Err == strconv.ErrRange { 153 return -1, fmt.Errorf("%s is out of range: %v", which, valueStr) 154 } 155 return -1, intSyntaxError(fmt.Sprintf("%s is not an integer: %v", which, valueStr)) 156 } 157 if prefixZeros(valueStr) { 158 return -1, fmt.Errorf("%s has invalid prefix zeros: %s", fmt.Sprintf(whichFmt, whichArgs...), valueStr) 159 } 160 return value, nil 161 } 162 163 func prefixZeros(s string) bool { 164 return strings.HasPrefix(s, "0") && s != "0" 165 } 166 167 func checkRFC3339Date(headers map[string]interface{}, name string) (time.Time, error) { 168 return checkRFC3339DateWhat(headers, name, "header") 169 } 170 171 func checkRFC3339DateWhat(m map[string]interface{}, name, what string) (time.Time, error) { 172 dateStr, err := checkNotEmptyStringWhat(m, name, what) 173 if err != nil { 174 return time.Time{}, err 175 } 176 date, err := time.Parse(time.RFC3339, dateStr) 177 if err != nil { 178 return time.Time{}, fmt.Errorf("%q %s is not a RFC3339 date: %v", name, what, err) 179 } 180 return date, nil 181 } 182 183 func checkRFC3339DateWithDefault(headers map[string]interface{}, name string, defl time.Time) (time.Time, error) { 184 return checkRFC3339DateWithDefaultWhat(headers, name, "header", defl) 185 } 186 187 func checkRFC3339DateWithDefaultWhat(m map[string]interface{}, name, what string, defl time.Time) (time.Time, error) { 188 value, ok := m[name] 189 if !ok { 190 return defl, nil 191 } 192 dateStr, ok := value.(string) 193 if !ok { 194 return time.Time{}, fmt.Errorf("%q %s must be a string", name, what) 195 } 196 date, err := time.Parse(time.RFC3339, dateStr) 197 if err != nil { 198 return time.Time{}, fmt.Errorf("%q %s is not a RFC3339 date: %v", name, what, err) 199 } 200 return date, nil 201 } 202 203 func checkUint(headers map[string]interface{}, name string, bitSize int) (uint64, error) { 204 valueStr, err := checkNotEmptyString(headers, name) 205 if err != nil { 206 return 0, err 207 } 208 value, err := strconv.ParseUint(valueStr, 10, bitSize) 209 if err != nil { 210 if ne, ok := err.(*strconv.NumError); ok && ne.Err == strconv.ErrRange { 211 return 0, fmt.Errorf("%q header is out of range: %v", name, valueStr) 212 } 213 return 0, fmt.Errorf("%q header is not an unsigned integer: %v", name, valueStr) 214 } 215 if prefixZeros(valueStr) { 216 return 0, fmt.Errorf("%q header has invalid prefix zeros: %s", name, valueStr) 217 } 218 return value, nil 219 } 220 221 func checkDigest(headers map[string]interface{}, name string, h crypto.Hash) ([]byte, error) { 222 digestStr, err := checkNotEmptyString(headers, name) 223 if err != nil { 224 return nil, err 225 } 226 b, err := base64.RawURLEncoding.DecodeString(digestStr) 227 if err != nil { 228 return nil, fmt.Errorf("%q header cannot be decoded: %v", name, err) 229 } 230 if len(b) != h.Size() { 231 return nil, fmt.Errorf("%q header does not have the expected bit length: %d", name, len(b)*8) 232 } 233 234 return b, nil 235 } 236 237 // checkStringListInMap returns the `name` entry in the `m` map as a (possibly nil) `[]string` 238 // if `m` has an entry for `name` and it isn't a `[]string`, an error is returned 239 // if pattern is not nil, all the strings must match that pattern, otherwise an error is returned 240 // `what` is a descriptor, used for error messages 241 func checkStringListInMap(m map[string]interface{}, name, what string, pattern *regexp.Regexp) ([]string, error) { 242 value, ok := m[name] 243 if !ok { 244 return nil, nil 245 } 246 lst, ok := value.([]interface{}) 247 if !ok { 248 return nil, fmt.Errorf("%s must be a list of strings", what) 249 } 250 if len(lst) == 0 { 251 return nil, nil 252 } 253 res := make([]string, len(lst)) 254 for i, v := range lst { 255 s, ok := v.(string) 256 if !ok { 257 return nil, fmt.Errorf("%s must be a list of strings", what) 258 } 259 if pattern != nil && !pattern.MatchString(s) { 260 return nil, fmt.Errorf("%s contains an invalid element: %q", what, s) 261 } 262 res[i] = s 263 } 264 return res, nil 265 } 266 267 func checkStringList(headers map[string]interface{}, name string) ([]string, error) { 268 return checkStringListMatches(headers, name, nil) 269 } 270 271 func checkStringListMatches(headers map[string]interface{}, name string, pattern *regexp.Regexp) ([]string, error) { 272 return checkStringListInMap(headers, name, fmt.Sprintf("%q header", name), pattern) 273 } 274 275 func checkStringMatches(headers map[string]interface{}, name string, pattern *regexp.Regexp) (string, error) { 276 return checkStringMatchesWhat(headers, name, "header", pattern) 277 } 278 279 func checkStringMatchesWhat(headers map[string]interface{}, name, what string, pattern *regexp.Regexp) (string, error) { 280 s, err := checkNotEmptyStringWhat(headers, name, what) 281 if err != nil { 282 return "", err 283 } 284 if !pattern.MatchString(s) { 285 return "", fmt.Errorf("%q %s contains invalid characters: %q", name, what, s) 286 } 287 return s, nil 288 } 289 290 func checkOptionalBool(headers map[string]interface{}, name string) (bool, error) { 291 value, ok := headers[name] 292 if !ok { 293 return false, nil 294 } 295 s, ok := value.(string) 296 if !ok || (s != "true" && s != "false") { 297 return false, fmt.Errorf("%q header must be 'true' or 'false'", name) 298 } 299 return s == "true", nil 300 } 301 302 func checkMap(headers map[string]interface{}, name string) (map[string]interface{}, error) { 303 value, ok := headers[name] 304 if !ok { 305 return nil, nil 306 } 307 m, ok := value.(map[string]interface{}) 308 if !ok { 309 return nil, fmt.Errorf("%q header must be a map", name) 310 } 311 return m, nil 312 }