github.com/rigado/snapd@v2.42.5-go-mod+incompatible/asserts/header_checks.go (about)

     1  // -*- Mode: Go; indent-tabs-mode: t -*-
     2  
     3  /*
     4   * Copyright (C) 2015-2016 Canonical Ltd
     5   *
     6   * This program is free software: you can redistribute it and/or modify
     7   * it under the terms of the GNU General Public License version 3 as
     8   * published by the Free Software Foundation.
     9   *
    10   * This program is distributed in the hope that it will be useful,
    11   * but WITHOUT ANY WARRANTY; without even the implied warranty of
    12   * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
    13   * GNU General Public License for more details.
    14   *
    15   * You should have received a copy of the GNU General Public License
    16   * along with this program.  If not, see <http://www.gnu.org/licenses/>.
    17   *
    18   */
    19  
    20  package asserts
    21  
    22  import (
    23  	"crypto"
    24  	"encoding/base64"
    25  	"fmt"
    26  	"regexp"
    27  	"strconv"
    28  	"strings"
    29  	"time"
    30  )
    31  
    32  // common checks used when decoding/assembling assertions
    33  
    34  func checkExistsString(headers map[string]interface{}, name string) (string, error) {
    35  	return checkExistsStringWhat(headers, name, "header")
    36  }
    37  
    38  func checkExistsStringWhat(m map[string]interface{}, name, what string) (string, error) {
    39  	value, ok := m[name]
    40  	if !ok {
    41  		return "", fmt.Errorf("%q %s is mandatory", name, what)
    42  	}
    43  	s, ok := value.(string)
    44  	if !ok {
    45  		return "", fmt.Errorf("%q %s must be a string", name, what)
    46  	}
    47  	return s, nil
    48  }
    49  
    50  func checkNotEmptyString(headers map[string]interface{}, name string) (string, error) {
    51  	return checkNotEmptyStringWhat(headers, name, "header")
    52  }
    53  
    54  func checkNotEmptyStringWhat(m map[string]interface{}, name, what string) (string, error) {
    55  	s, err := checkExistsStringWhat(m, name, what)
    56  	if err != nil {
    57  		return "", err
    58  	}
    59  	if len(s) == 0 {
    60  		return "", fmt.Errorf("%q %s should not be empty", name, what)
    61  	}
    62  	return s, nil
    63  }
    64  
    65  func checkOptionalStringWhat(headers map[string]interface{}, name, what string) (string, error) {
    66  	value, ok := headers[name]
    67  	if !ok {
    68  		return "", nil
    69  	}
    70  	s, ok := value.(string)
    71  	if !ok {
    72  		return "", fmt.Errorf("%q %s must be a string", name, what)
    73  	}
    74  	return s, nil
    75  }
    76  
    77  func checkOptionalString(headers map[string]interface{}, name string) (string, error) {
    78  	return checkOptionalStringWhat(headers, name, "header")
    79  }
    80  
    81  func checkPrimaryKey(headers map[string]interface{}, primKey string) (string, error) {
    82  	value, err := checkNotEmptyString(headers, primKey)
    83  	if err != nil {
    84  		return "", err
    85  	}
    86  	if strings.Contains(value, "/") {
    87  		return "", fmt.Errorf("%q primary key header cannot contain '/'", primKey)
    88  	}
    89  	return value, nil
    90  }
    91  
    92  func checkAssertType(assertType *AssertionType) error {
    93  	if assertType == nil {
    94  		return fmt.Errorf("internal error: assertion type cannot be nil")
    95  	}
    96  	// sanity check against known canonical
    97  	sanity := typeRegistry[assertType.Name]
    98  	switch sanity {
    99  	case assertType:
   100  		// fine, matches canonical
   101  		return nil
   102  	case nil:
   103  		return fmt.Errorf("internal error: unknown assertion type: %q", assertType.Name)
   104  	default:
   105  		return fmt.Errorf("internal error: unpredefined assertion type for name %q used (unexpected address %p)", assertType.Name, assertType)
   106  	}
   107  }
   108  
   109  // use 'defl' default if missing
   110  func checkIntWithDefault(headers map[string]interface{}, name string, defl int) (int, error) {
   111  	value, ok := headers[name]
   112  	if !ok {
   113  		return defl, nil
   114  	}
   115  	s, ok := value.(string)
   116  	if !ok {
   117  		return -1, fmt.Errorf("%q header is not an integer: %v", name, value)
   118  	}
   119  	m, err := strconv.Atoi(s)
   120  	if err != nil {
   121  		return -1, fmt.Errorf("%q header is not an integer: %v", name, s)
   122  	}
   123  	return m, nil
   124  }
   125  
   126  func checkInt(headers map[string]interface{}, name string) (int, error) {
   127  	valueStr, err := checkNotEmptyString(headers, name)
   128  	if err != nil {
   129  		return -1, err
   130  	}
   131  	value, err := strconv.Atoi(valueStr)
   132  	if err != nil {
   133  		return -1, fmt.Errorf("%q header is not an integer: %v", name, valueStr)
   134  	}
   135  	return value, nil
   136  }
   137  
   138  func checkRFC3339Date(headers map[string]interface{}, name string) (time.Time, error) {
   139  	return checkRFC3339DateWhat(headers, name, "header")
   140  }
   141  
   142  func checkRFC3339DateWhat(m map[string]interface{}, name, what string) (time.Time, error) {
   143  	dateStr, err := checkNotEmptyStringWhat(m, name, what)
   144  	if err != nil {
   145  		return time.Time{}, err
   146  	}
   147  	date, err := time.Parse(time.RFC3339, dateStr)
   148  	if err != nil {
   149  		return time.Time{}, fmt.Errorf("%q %s is not a RFC3339 date: %v", name, what, err)
   150  	}
   151  	return date, nil
   152  }
   153  
   154  func checkRFC3339DateWithDefault(headers map[string]interface{}, name string, defl time.Time) (time.Time, error) {
   155  	return checkRFC3339DateWithDefaultWhat(headers, name, "header", defl)
   156  }
   157  
   158  func checkRFC3339DateWithDefaultWhat(m map[string]interface{}, name, what string, defl time.Time) (time.Time, error) {
   159  	value, ok := m[name]
   160  	if !ok {
   161  		return defl, nil
   162  	}
   163  	dateStr, ok := value.(string)
   164  	if !ok {
   165  		return time.Time{}, fmt.Errorf("%q %s must be a string", name, what)
   166  	}
   167  	date, err := time.Parse(time.RFC3339, dateStr)
   168  	if err != nil {
   169  		return time.Time{}, fmt.Errorf("%q %s is not a RFC3339 date: %v", name, what, err)
   170  	}
   171  	return date, nil
   172  }
   173  
   174  func checkUint(headers map[string]interface{}, name string, bitSize int) (uint64, error) {
   175  	valueStr, err := checkNotEmptyString(headers, name)
   176  	if err != nil {
   177  		return 0, err
   178  	}
   179  
   180  	value, err := strconv.ParseUint(valueStr, 10, bitSize)
   181  	if err != nil {
   182  		return 0, fmt.Errorf("%q header is not an unsigned integer: %v", name, valueStr)
   183  	}
   184  	return value, nil
   185  }
   186  
   187  func checkDigest(headers map[string]interface{}, name string, h crypto.Hash) ([]byte, error) {
   188  	digestStr, err := checkNotEmptyString(headers, name)
   189  	if err != nil {
   190  		return nil, err
   191  	}
   192  	b, err := base64.RawURLEncoding.DecodeString(digestStr)
   193  	if err != nil {
   194  		return nil, fmt.Errorf("%q header cannot be decoded: %v", name, err)
   195  	}
   196  	if len(b) != h.Size() {
   197  		return nil, fmt.Errorf("%q header does not have the expected bit length: %d", name, len(b)*8)
   198  	}
   199  
   200  	return b, nil
   201  }
   202  
   203  // checkStringListInMap returns the `name` entry in the `m` map as a (possibly nil) `[]string`
   204  // if `m` has an entry for `name` and it isn't a `[]string`, an error is returned
   205  // if pattern is not nil, all the strings must match that pattern, otherwise an error is returned
   206  // `what` is a descriptor, used for error messages
   207  func checkStringListInMap(m map[string]interface{}, name, what string, pattern *regexp.Regexp) ([]string, error) {
   208  	value, ok := m[name]
   209  	if !ok {
   210  		return nil, nil
   211  	}
   212  	lst, ok := value.([]interface{})
   213  	if !ok {
   214  		return nil, fmt.Errorf("%s must be a list of strings", what)
   215  	}
   216  	if len(lst) == 0 {
   217  		return nil, nil
   218  	}
   219  	res := make([]string, len(lst))
   220  	for i, v := range lst {
   221  		s, ok := v.(string)
   222  		if !ok {
   223  			return nil, fmt.Errorf("%s must be a list of strings", what)
   224  		}
   225  		if pattern != nil && !pattern.MatchString(s) {
   226  			return nil, fmt.Errorf("%s contains an invalid element: %q", what, s)
   227  		}
   228  		res[i] = s
   229  	}
   230  	return res, nil
   231  }
   232  
   233  func checkStringList(headers map[string]interface{}, name string) ([]string, error) {
   234  	return checkStringListMatches(headers, name, nil)
   235  }
   236  
   237  func checkStringListMatches(headers map[string]interface{}, name string, pattern *regexp.Regexp) ([]string, error) {
   238  	return checkStringListInMap(headers, name, fmt.Sprintf("%q header", name), pattern)
   239  }
   240  
   241  func checkStringMatches(headers map[string]interface{}, name string, pattern *regexp.Regexp) (string, error) {
   242  	return checkStringMatchesWhat(headers, name, "header", pattern)
   243  }
   244  
   245  func checkStringMatchesWhat(headers map[string]interface{}, name, what string, pattern *regexp.Regexp) (string, error) {
   246  	s, err := checkNotEmptyStringWhat(headers, name, what)
   247  	if err != nil {
   248  		return "", err
   249  	}
   250  	if !pattern.MatchString(s) {
   251  		return "", fmt.Errorf("%q %s contains invalid characters: %q", name, what, s)
   252  	}
   253  	return s, nil
   254  }
   255  
   256  func checkOptionalBool(headers map[string]interface{}, name string) (bool, error) {
   257  	value, ok := headers[name]
   258  	if !ok {
   259  		return false, nil
   260  	}
   261  	s, ok := value.(string)
   262  	if !ok || (s != "true" && s != "false") {
   263  		return false, fmt.Errorf("%q header must be 'true' or 'false'", name)
   264  	}
   265  	return s == "true", nil
   266  }
   267  
   268  func checkMap(headers map[string]interface{}, name string) (map[string]interface{}, error) {
   269  	value, ok := headers[name]
   270  	if !ok {
   271  		return nil, nil
   272  	}
   273  	m, ok := value.(map[string]interface{})
   274  	if !ok {
   275  		return nil, fmt.Errorf("%q header must be a map", name)
   276  	}
   277  	return m, nil
   278  }