github.com/canonical/ubuntu-image@v0.0.0-20240430122802-2202fe98b290/tests/lib/external/snapd-testing-tools/utils/spread-shellcheck (about)

     1  #!/usr/bin/env python3
     2  
     3  # Copyright (C) 2022 Canonical Ltd
     4  #
     5  # This program is free software: you can redistribute it and/or modify
     6  # it under the terms of the GNU General Public License version 3 as
     7  # published by the Free Software Foundation.
     8  #
     9  # This program is distributed in the hope that it will be useful,
    10  # but WITHOUT ANY WARRANTY; without even the implied warranty of
    11  # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
    12  # GNU General Public License for more details.
    13  #
    14  # You should have received a copy of the GNU General Public License
    15  # along with this program.  If not, see <http://www.gnu.org/licenses/>.
    16  
    17  import argparse
    18  import binascii
    19  import hashlib
    20  import itertools
    21  import logging
    22  import os
    23  import re
    24  import subprocess
    25  import yaml
    26  
    27  from collections import namedtuple
    28  from concurrent.futures import ThreadPoolExecutor
    29  from multiprocessing import cpu_count
    30  from pathlib import Path
    31  from threading import Lock
    32  from typing import Dict
    33  
    34  
# default shell dialect passed to shellcheck via -s
SHELLCHECK_SHELL = os.getenv('SHELLCHECK_SHELL', 'bash')
# set to non-empty to ignore all errors
NO_FAIL = os.getenv('NO_FAIL')
# set to non-empty to enable debug logging (same effect as V or --verbose)
D = os.getenv('D')
# set to non-empty to enable verbose logging
V = os.getenv('V')
# set to a number to use these many threads (defaults to the CPU count)
N = int(os.getenv('N') or cpu_count())
# file with list of files that can fail validation
CAN_FAIL = os.getenv('CAN_FAIL')

# names of spread YAML sections that carry shell scripts to check
SECTIONS = ['prepare', 'prepare-each', 'restore', 'restore-each',
            'debug', 'debug-each', 'execute', 'repack']
    51  
    52  
    53  def parse_arguments():
    54      parser = argparse.ArgumentParser(description='spread shellcheck helper')
    55      parser.add_argument('-s', '--shell', default='bash',
    56                          help='shell')
    57      parser.add_argument('-n', '--no-errors', action='store_true',
    58                          default=False, help='ignore all errors ')
    59      parser.add_argument('-v', '--verbose', action='store_true',
    60                          default=False, help='verbose logging')
    61      parser.add_argument('--can-fail', default=None,
    62                          help=('file with list of files that are can fail '
    63                                'validation'))
    64      parser.add_argument('-P', '--max-procs', default=N, type=int, metavar='N',
    65                          help='run these many shellchecks in parallel (default: %(default)s)')
    66      parser.add_argument('-e', '--exclude', default=[], action="append",
    67                          help='path to exclude of the shell check')
    68      parser.add_argument('--no-cache', help='disable caching', action='store_true')
    69      parser.add_argument('paths', nargs='+', help='paths to check')
    70      return parser.parse_args()
    71  
    72  
    73  class ShellcheckRunError(Exception):
    74      def __init__(self, stderr):
    75          super().__init__()
    76          self.stderr = stderr
    77  
    78  
    79  class ShellcheckError(Exception):
    80      def __init__(self, path):
    81          super().__init__()
    82          self.sectionerrors = {}
    83          self.path = path
    84  
    85      def addfailure(self, section, error):
    86          self.sectionerrors[section] = error
    87  
    88      def __len__(self):
    89          return len(self.sectionerrors)
    90  
    91  
    92  class ShellcheckFailures(Exception):
    93      def __init__(self, failures=None):
    94          super().__init__()
    95          self.failures = set()
    96          if failures:
    97              self.failures = set(failures)
    98  
    99      def merge(self, otherfailures):
   100          self.failures = self.failures.union(otherfailures.failures)
   101  
   102      def __len__(self):
   103          return len(self.failures)
   104  
   105      def intersection(self, other):
   106          return self.failures.intersection(other)
   107  
   108      def difference(self, other):
   109          return self.failures.difference(other)
   110  
   111      def __iter__(self):
   112          return iter(self.failures)
   113  
   114  
   115  def checksection(data, env: Dict[str, str]):
   116      # spread shell snippets are executed under 'set -e' shell, make sure
   117      # shellcheck knows about that
   118      script_data = []
   119      script_data.append('set -e')
   120  
   121      for key, value in env.items():
   122          value = str(value)
   123          disabled_warnings = set()
   124          export_disabled_warnings = set()
   125          def replacement(match):
   126              if match.group(0) == '"':
   127                  # SC2089 and SC2090 are about quotes vs arrays
   128                  # We cannot have arrays in environment variables of spread
   129                  # So we do have to use quotes
   130                  disabled_warnings.add('SC2089')
   131                  export_disabled_warnings.add('SC2090')
   132                  return r'\"'
   133              else:
   134                  assert(match.group('command') is not None)
   135                  # "Useless" echo. This is what we get.
   136                  # We cannot just evaluate to please shellcheck.
   137                  disabled_warnings.add('SC2116')
   138                  return '$({})'.format(match.group('command'))
   139          value = re.sub(r'[$][(]HOST:(?P<command>.*)[)]|"', replacement, value)
   140          # converts
   141          # FOO: "$(HOST: echo $foo)"     -> FOO="$(echo $foo)"
   142          # FOO: "$(HOST: echo \"$foo\")" -> FOO="$(echo "$foo")"
   143          # FOO: "foo"                    -> FOO="foo"
   144          # FOO: "\"foo\""                -> FOO="\"foo\""
   145          if disabled_warnings:
   146              script_data.append("# shellcheck disable={}".format(','.join(disabled_warnings)))
   147          script_data.append("{}=\"{}\"".format(key, value))
   148          if export_disabled_warnings:
   149              script_data.append("# shellcheck disable={}".format(','.join(export_disabled_warnings)))
   150          script_data.append("export {}".format(key, value))
   151      script_data.append(data)
   152      proc = subprocess.Popen("shellcheck -s {} -x -".format(SHELLCHECK_SHELL),
   153                              stdout=subprocess.PIPE,
   154                              stdin=subprocess.PIPE,
   155                              shell=True)
   156      stdout, _ = proc.communicate(input='\n'.join(script_data).encode('utf-8'), timeout=60)
   157      if proc.returncode != 0:
   158          raise ShellcheckRunError(stdout)
   159  
   160  
   161  class Cacher:
   162      _instance = None
   163  
   164      def __init__(self):
   165          self._enabled = True
   166          self._lock = Lock()
   167          self._hit =0
   168          self._miss = 0
   169          self._shellcheck_version = None
   170          self._probe_shellcheck_version()
   171  
   172      @classmethod
   173      def init(cls):
   174          cls._instance = Cacher()
   175  
   176      @classmethod
   177      def get(cls):
   178          return cls._instance
   179  
   180      def disable(self):
   181          logging.debug("caching is disabled")
   182          self._enabled = False
   183  
   184      @staticmethod
   185      def _cache_path_for(digest):
   186          prefix = digest[:2]
   187          return Path.home().joinpath(".cache", "spread-shellcheck", prefix, digest)
   188  
   189      def is_cached(self, data, path):
   190          if not self._enabled:
   191              return False, ""
   192          # the digest uses script content and shellcheck versions as inputs, but
   193          # consider other possible inputs: path to the *.yaml file (so moving
   194          # the script around would cause a miss) or even the contents of this
   195          # script
   196          h = hashlib.sha256()
   197          h.update(self._shellcheck_version)
   198          h.update(data)
   199          hdg = binascii.b2a_hex(h.digest()).decode()
   200          cachepath = Cacher._cache_path_for(hdg)
   201          logging.debug("cache stamp %s, exists? %s", cachepath.as_posix(), cachepath.exists())
   202          hit = cachepath.exists()
   203          self._record_cache_event(hit)
   204          return hit, hdg
   205  
   206      def cache_success(self, digest, path):
   207          if not self._enabled:
   208              return
   209          cachepath = Cacher._cache_path_for(digest)
   210          logging.debug("cache success, path %s", cachepath.as_posix())
   211          cachepath.parent.mkdir(parents=True, exist_ok=True)
   212          cachepath.touch()
   213  
   214      def _record_cache_event(self, hit):
   215          with self._lock:
   216              if hit:
   217                  self._hit += 1
   218              else:
   219                  self._miss += 1
   220  
   221      def _probe_shellcheck_version(self):
   222          logging.debug("probing shellcheck version")
   223          out = subprocess.check_output("shellcheck --version", shell=True)
   224          self._shellcheck_version = out
   225  
   226      @property
   227      def stats(self):
   228          return namedtuple('Stats', ['hit', 'miss'])(self._hit, self._miss)
   229  
   230  
   231  def checkfile(path, executor):
   232      logging.debug("checking file %s", path)
   233      with open(path, mode='rb') as inf:
   234          rawdata = inf.read()
   235          cached, digest = Cacher.get().is_cached(rawdata, path)
   236          if cached:
   237              logging.debug("entry %s already cached", digest)
   238              return
   239          data = yaml.safe_load(rawdata)
   240  
   241      errors = ShellcheckError(path)
   242      # TODO: handle stacking of environment from other places that influence it:
   243      # spread.yaml -> global env + backend env + suite env -> task.yaml (task
   244      # env + variant env).
   245      env = {}
   246      for key, value in data.get("environment", {}).items():
   247          if "/" in key:
   248              # TODO: re-check with each variant's value set.
   249              key = key.split('/', 1)[0]
   250          env[key] = value
   251      for section in SECTIONS:
   252          if section not in data:
   253              continue
   254          try:
   255              logging.debug("%s: checking section %s", path, section)
   256              checksection(data[section], env)
   257          except ShellcheckRunError as serr:
   258              errors.addfailure(section, serr.stderr.decode('utf-8'))
   259  
   260      if path.endswith('spread.yaml') and 'suites' in data:
   261          # check suites
   262          suites_sections_and_futures = []
   263          for suite in data['suites'].keys():
   264              for section in SECTIONS:
   265                  if section not in data['suites'][suite]:
   266                      continue
   267                  logging.debug("%s (suite %s): checking section %s", path, suite, section)
   268                  future = executor.submit(checksection, data['suites'][suite][section], env)
   269                  suites_sections_and_futures.append((suite, section, future))
   270          for item in suites_sections_and_futures:
   271              suite, section, future = item
   272              try:
   273                  future.result()
   274              except ShellcheckRunError as serr:
   275                  errors.addfailure('suites/' + suite + '/' + section,
   276                                  serr.stderr.decode('utf-8'))
   277  
   278      if errors:
   279          raise errors
   280      # only stamp the cache when the script was found to be valid
   281      Cacher.get().cache_success(digest, path)
   282  
   283  
   284  def is_file_in_dirs(file, dirs):
   285      for dir in dirs:
   286          if os.path.abspath(file).startswith('{}/'.format(os.path.abspath(dir))):
   287              print('Skipping {}'.format(file))
   288              return True
   289  
   290      return False
   291  
   292  
   293  def findfiles(locations, exclude):
   294      for loc in locations:
   295          if os.path.isdir(loc):
   296              for root, _, files in os.walk(loc, topdown=True):
   297                  for name in files:
   298                      if name in ['spread.yaml', 'task.yaml']:
   299                          full_path = os.path.join(root, name)
   300                          if not is_file_in_dirs(full_path, exclude):
   301                              yield full_path
   302          else:
   303              full_path = os.path.abspath(loc)
   304              if not is_file_in_dirs(full_path, exclude):
   305                  yield full_path
   306  
   307  
   308  def check1path(path, executor):
   309      try:
   310          checkfile(path, executor)
   311      except ShellcheckError as err:
   312          return err
   313      return None
   314  
   315  
   316  def checkpaths(locs, exclude, executor):
   317      # setup iterator
   318      locations = findfiles(locs, exclude)
   319      failed = []
   320      for serr in executor.map(check1path, locations, itertools.repeat(executor)):
   321          if serr is None:
   322              continue
   323          logging.error(('shellcheck failed for file %s in sections: '
   324                         '%s; error log follows'),
   325                        serr.path, ', '.join(serr.sectionerrors.keys()))
   326          for section, error in serr.sectionerrors.items():
   327              logging.error("%s: section '%s':\n%s", serr.path, section, error)
   328          failed.append(serr.path)
   329  
   330      if failed:
   331          raise ShellcheckFailures(failures=failed)
   332  
   333  
   334  def loadfilelist(flistpath):
   335      flist = set()
   336      with open(flistpath) as inf:
   337          for line in inf:
   338              if not line.startswith('#'):
   339                  flist.add(line.strip())
   340      return flist
   341  
   342  
def main(opts):
    """Check all spread scripts reachable from opts.paths.

    Raises SystemExit(1) when validation fails, unless errors are being
    ignored (NO_FAIL/--no-errors) or the failing files are covered by the
    --can-fail whitelist.
    """
    paths = opts.paths or ['.']
    exclude = opts.exclude
    failures = ShellcheckFailures()
    with ThreadPoolExecutor(max_workers=opts.max_procs) as executor:
        try:
            checkpaths(paths, exclude, executor)
        except ShellcheckFailures as sf:
            failures.merge(sf)

    if not opts.no_cache:
        stats = Cacher.get().stats
        logging.info("cache stats: hit %d miss %d", stats.hit, stats.miss)

    if failures:
        if opts.can_fail:
            # the whitelist holds files that are expected to fail validation
            can_fail = loadfilelist(opts.can_fail)

            # any failure outside the whitelist is fatal
            unexpected = failures.difference(can_fail)
            if unexpected:
                logging.error(('validation failed for the following '
                               'non-whitelisted files:\n%s'),
                              '\n'.join([' - ' + f for f in
                                         sorted(unexpected)]))
                raise SystemExit(1)

            # a whitelisted file that now validates cleanly is also an
            # error, so the whitelist cannot grow stale
            did_not_fail = can_fail - failures.intersection(can_fail)
            if did_not_fail:
                logging.error(('the following files are whitelisted '
                               'but validated successfully:\n%s'),
                              '\n'.join([' - ' + f for f in
                                         sorted(did_not_fail)]))
                raise SystemExit(1)

            # no unexpected failures
            return

        logging.error('validation failed for the following files:\n%s',
                      '\n'.join([' - ' + f for f in sorted(failures)]))

        if NO_FAIL or opts.no_errors:
            logging.warning("ignoring errors")
        else:
            raise SystemExit(1)
   387  
   388  
   389  if __name__ == '__main__':
   390      opts = parse_arguments()
   391      if opts.verbose or D or V:
   392          lvl = logging.DEBUG
   393      else:
   394          lvl = logging.INFO
   395      logging.basicConfig(level=lvl)
   396  
   397      if CAN_FAIL:
   398          opts.can_fail = CAN_FAIL
   399  
   400      if NO_FAIL:
   401          opts.no_errors = True
   402  
   403      if opts.max_procs == 1:
   404          # TODO: temporary workaround for a deadlock when running with a single
   405          # worker
   406          opts.max_procs += 1
   407          logging.warning('workers count bumped to 2 to workaround a deadlock')
   408  
   409      Cacher.init()
   410      if opts.no_cache:
   411          Cacher.get().disable()
   412  
   413      main(opts)