gitee.com/mysnapcore/mysnapd@v0.1.0/asserts/header_checks.go (about)

     1  // -*- Mode: Go; indent-tabs-mode: t -*-
     2  
     3  /*
     4   * Copyright (C) 2015-2022 Canonical Ltd
     5   *
     6   * This program is free software: you can redistribute it and/or modify
     7   * it under the terms of the GNU General Public License version 3 as
     8   * published by the Free Software Foundation.
     9   *
    10   * This program is distributed in the hope that it will be useful,
    11   * but WITHOUT ANY WARRANTY; without even the implied warranty of
    12   * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
    13   * GNU General Public License for more details.
    14   *
    15   * You should have received a copy of the GNU General Public License
    16   * along with this program.  If not, see <http://www.gnu.org/licenses/>.
    17   *
    18   */
    19  
    20  package asserts
    21  
    22  import (
    23  	"crypto"
    24  	"encoding/base64"
    25  	"fmt"
    26  	"regexp"
    27  	"strconv"
    28  	"strings"
    29  	"time"
    30  )
    31  
    32  // common checks used when decoding/assembling assertions
    33  
    34  func checkExistsString(headers map[string]interface{}, name string) (string, error) {
    35  	return checkExistsStringWhat(headers, name, "header")
    36  }
    37  
    38  func checkExistsStringWhat(m map[string]interface{}, name, what string) (string, error) {
    39  	value, ok := m[name]
    40  	if !ok {
    41  		return "", fmt.Errorf("%q %s is mandatory", name, what)
    42  	}
    43  	s, ok := value.(string)
    44  	if !ok {
    45  		return "", fmt.Errorf("%q %s must be a string", name, what)
    46  	}
    47  	return s, nil
    48  }
    49  
    50  func checkNotEmptyString(headers map[string]interface{}, name string) (string, error) {
    51  	return checkNotEmptyStringWhat(headers, name, "header")
    52  }
    53  
    54  func checkNotEmptyStringWhat(m map[string]interface{}, name, what string) (string, error) {
    55  	s, err := checkExistsStringWhat(m, name, what)
    56  	if err != nil {
    57  		return "", err
    58  	}
    59  	if len(s) == 0 {
    60  		return "", fmt.Errorf("%q %s should not be empty", name, what)
    61  	}
    62  	return s, nil
    63  }
    64  
    65  func checkOptionalStringWhat(headers map[string]interface{}, name, what string) (string, error) {
    66  	value, ok := headers[name]
    67  	if !ok {
    68  		return "", nil
    69  	}
    70  	s, ok := value.(string)
    71  	if !ok {
    72  		return "", fmt.Errorf("%q %s must be a string", name, what)
    73  	}
    74  	return s, nil
    75  }
    76  
    77  func checkOptionalString(headers map[string]interface{}, name string) (string, error) {
    78  	return checkOptionalStringWhat(headers, name, "header")
    79  }
    80  
    81  func checkPrimaryKey(headers map[string]interface{}, primKey string) (string, error) {
    82  	value, err := checkNotEmptyString(headers, primKey)
    83  	if err != nil {
    84  		return "", err
    85  	}
    86  	if strings.Contains(value, "/") {
    87  		return "", fmt.Errorf("%q primary key header cannot contain '/'", primKey)
    88  	}
    89  	return value, nil
    90  }
    91  
    92  func checkAssertType(assertType *AssertionType) error {
    93  	if assertType == nil {
    94  		return fmt.Errorf("internal error: assertion type cannot be nil")
    95  	}
    96  	// validity check against known canonical
    97  	validity := typeRegistry[assertType.Name]
    98  	switch validity {
    99  	case assertType:
   100  		// fine, matches canonical
   101  		return nil
   102  	case nil:
   103  		return fmt.Errorf("internal error: unknown assertion type: %q", assertType.Name)
   104  	default:
   105  		return fmt.Errorf("internal error: unpredefined assertion type for name %q used (unexpected address %p)", assertType.Name, assertType)
   106  	}
   107  }
   108  
   109  // use 'defl' default if missing
   110  func checkIntWithDefault(headers map[string]interface{}, name string, defl int) (int, error) {
   111  	value, ok := headers[name]
   112  	if !ok {
   113  		return defl, nil
   114  	}
   115  	s, ok := value.(string)
   116  	if !ok {
   117  		return -1, fmt.Errorf("%q header is not an integer: %v", name, value)
   118  	}
   119  	m, err := atoi(s, "%q %s", name, "header")
   120  	if err != nil {
   121  		return -1, err
   122  	}
   123  	return m, nil
   124  }
   125  
   126  func checkInt(headers map[string]interface{}, name string) (int, error) {
   127  	return checkIntWhat(headers, name, "header")
   128  }
   129  
   130  func checkIntWhat(headers map[string]interface{}, name, what string) (int, error) {
   131  	valueStr, err := checkNotEmptyStringWhat(headers, name, what)
   132  	if err != nil {
   133  		return -1, err
   134  	}
   135  	value, err := atoi(valueStr, "%q %s", name, what)
   136  	if err != nil {
   137  		return -1, err
   138  	}
   139  	return value, nil
   140  }
   141  
   142  type intSyntaxError string
   143  
   144  func (e intSyntaxError) Error() string {
   145  	return string(e)
   146  }
   147  
   148  func atoi(valueStr, whichFmt string, whichArgs ...interface{}) (int, error) {
   149  	value, err := strconv.Atoi(valueStr)
   150  	if err != nil {
   151  		which := fmt.Sprintf(whichFmt, whichArgs...)
   152  		if ne, ok := err.(*strconv.NumError); ok && ne.Err == strconv.ErrRange {
   153  			return -1, fmt.Errorf("%s is out of range: %v", which, valueStr)
   154  		}
   155  		return -1, intSyntaxError(fmt.Sprintf("%s is not an integer: %v", which, valueStr))
   156  	}
   157  	if prefixZeros(valueStr) {
   158  		return -1, fmt.Errorf("%s has invalid prefix zeros: %s", fmt.Sprintf(whichFmt, whichArgs...), valueStr)
   159  	}
   160  	return value, nil
   161  }
   162  
   163  func prefixZeros(s string) bool {
   164  	return strings.HasPrefix(s, "0") && s != "0"
   165  }
   166  
   167  func checkRFC3339Date(headers map[string]interface{}, name string) (time.Time, error) {
   168  	return checkRFC3339DateWhat(headers, name, "header")
   169  }
   170  
   171  func checkRFC3339DateWhat(m map[string]interface{}, name, what string) (time.Time, error) {
   172  	dateStr, err := checkNotEmptyStringWhat(m, name, what)
   173  	if err != nil {
   174  		return time.Time{}, err
   175  	}
   176  	date, err := time.Parse(time.RFC3339, dateStr)
   177  	if err != nil {
   178  		return time.Time{}, fmt.Errorf("%q %s is not a RFC3339 date: %v", name, what, err)
   179  	}
   180  	return date, nil
   181  }
   182  
   183  func checkRFC3339DateWithDefaultWhat(m map[string]interface{}, name, what string, defl time.Time) (time.Time, error) {
   184  	value, ok := m[name]
   185  	if !ok {
   186  		return defl, nil
   187  	}
   188  	dateStr, ok := value.(string)
   189  	if !ok {
   190  		return time.Time{}, fmt.Errorf("%q %s must be a string", name, what)
   191  	}
   192  	date, err := time.Parse(time.RFC3339, dateStr)
   193  	if err != nil {
   194  		return time.Time{}, fmt.Errorf("%q %s is not a RFC3339 date: %v", name, what, err)
   195  	}
   196  	return date, nil
   197  }
   198  
   199  func checkUint(headers map[string]interface{}, name string, bitSize int) (uint64, error) {
   200  	valueStr, err := checkNotEmptyString(headers, name)
   201  	if err != nil {
   202  		return 0, err
   203  	}
   204  	value, err := strconv.ParseUint(valueStr, 10, bitSize)
   205  	if err != nil {
   206  		if ne, ok := err.(*strconv.NumError); ok && ne.Err == strconv.ErrRange {
   207  			return 0, fmt.Errorf("%q header is out of range: %v", name, valueStr)
   208  		}
   209  		return 0, fmt.Errorf("%q header is not an unsigned integer: %v", name, valueStr)
   210  	}
   211  	if prefixZeros(valueStr) {
   212  		return 0, fmt.Errorf("%q header has invalid prefix zeros: %s", name, valueStr)
   213  	}
   214  	return value, nil
   215  }
   216  
   217  func checkDigest(headers map[string]interface{}, name string, h crypto.Hash) ([]byte, error) {
   218  	digestStr, err := checkNotEmptyString(headers, name)
   219  	if err != nil {
   220  		return nil, err
   221  	}
   222  	b, err := base64.RawURLEncoding.DecodeString(digestStr)
   223  	if err != nil {
   224  		return nil, fmt.Errorf("%q header cannot be decoded: %v", name, err)
   225  	}
   226  	if len(b) != h.Size() {
   227  		return nil, fmt.Errorf("%q header does not have the expected bit length: %d", name, len(b)*8)
   228  	}
   229  
   230  	return b, nil
   231  }
   232  
   233  // checkStringListInMap returns the `name` entry in the `m` map as a (possibly nil) `[]string`
   234  // if `m` has an entry for `name` and it isn't a `[]string`, an error is returned
   235  // if pattern is not nil, all the strings must match that pattern, otherwise an error is returned
   236  // `what` is a descriptor, used for error messages
   237  func checkStringListInMap(m map[string]interface{}, name, what string, pattern *regexp.Regexp) ([]string, error) {
   238  	value, ok := m[name]
   239  	if !ok {
   240  		return nil, nil
   241  	}
   242  	lst, ok := value.([]interface{})
   243  	if !ok {
   244  		return nil, fmt.Errorf("%s must be a list of strings", what)
   245  	}
   246  	if len(lst) == 0 {
   247  		return nil, nil
   248  	}
   249  	res := make([]string, len(lst))
   250  	for i, v := range lst {
   251  		s, ok := v.(string)
   252  		if !ok {
   253  			return nil, fmt.Errorf("%s must be a list of strings", what)
   254  		}
   255  		if pattern != nil && !pattern.MatchString(s) {
   256  			return nil, fmt.Errorf("%s contains an invalid element: %q", what, s)
   257  		}
   258  		res[i] = s
   259  	}
   260  	return res, nil
   261  }
   262  
   263  func checkStringList(headers map[string]interface{}, name string) ([]string, error) {
   264  	return checkStringListMatches(headers, name, nil)
   265  }
   266  
   267  func checkStringListMatches(headers map[string]interface{}, name string, pattern *regexp.Regexp) ([]string, error) {
   268  	return checkStringListInMap(headers, name, fmt.Sprintf("%q header", name), pattern)
   269  }
   270  
   271  func checkStringMatches(headers map[string]interface{}, name string, pattern *regexp.Regexp) (string, error) {
   272  	return checkStringMatchesWhat(headers, name, "header", pattern)
   273  }
   274  
   275  func checkStringMatchesWhat(headers map[string]interface{}, name, what string, pattern *regexp.Regexp) (string, error) {
   276  	s, err := checkNotEmptyStringWhat(headers, name, what)
   277  	if err != nil {
   278  		return "", err
   279  	}
   280  	if !pattern.MatchString(s) {
   281  		return "", fmt.Errorf("%q %s contains invalid characters: %q", name, what, s)
   282  	}
   283  	return s, nil
   284  }
   285  
   286  func checkOptionalBool(headers map[string]interface{}, name string) (bool, error) {
   287  	return checkOptionalBoolWhat(headers, name, "header")
   288  }
   289  
   290  func checkOptionalBoolWhat(headers map[string]interface{}, name, what string) (bool, error) {
   291  	value, ok := headers[name]
   292  	if !ok {
   293  		return false, nil
   294  	}
   295  	s, ok := value.(string)
   296  	if !ok || (s != "true" && s != "false") {
   297  		return false, fmt.Errorf("%q %s must be 'true' or 'false'", name, what)
   298  	}
   299  	return s == "true", nil
   300  }
   301  
   302  func checkMap(headers map[string]interface{}, name string) (map[string]interface{}, error) {
   303  	return checkMapWhat(headers, name, "header")
   304  }
   305  
   306  func checkMapWhat(m map[string]interface{}, name, what string) (map[string]interface{}, error) {
   307  	value, ok := m[name]
   308  	if !ok {
   309  		return nil, nil
   310  	}
   311  	mv, ok := value.(map[string]interface{})
   312  	if !ok {
   313  		return nil, fmt.Errorf("%q %s must be a map", name, what)
   314  	}
   315  	return mv, nil
   316  }