github.com/juju/juju@v0.0.0-20240430160146-1752b71fcf00/scripts/juju-txn-helper/txn_helper.py (about)

     1  #!/usr/bin/env python3
     2  # Copyright Canonical Ltd.
     3  # Licensed under the GNU Affero General Public License version 3.0.
     4  
     5  """txn_helper.py - A model transaction queue analysis tool"""
     6  
     7  import argparse
     8  import pprint
     9  import re
    10  import sys
    11  import textwrap
    12  from enum import Enum
    13  
    14  from bson import ObjectId
    15  from pymongo import MongoClient
    16  
    17  
    18  class OpState(Enum):
    19      """Juju transaction operation states."""
    20      PREPARING = 1
    21      PREPARED = 2
    22      ABORTING = 3
    23      APPLYING = 4
    24      ABORTED = 5
    25      APPLIED = 6
    26  
    27  
    28  class AssertionType(Enum):
    29      """Transaction assertion types.
    30  
    31      Note that these are simply names used for convenience in this script;
    32      DOC_EXISTS/DOC_MISSING refer to the d+/d- assertion codes, while QUERY_DOC
    33      refers to assertions which specify MongoDB query documents.
    34      """
    35      DOC_EXISTS = 1
    36      DOC_MISSING = 2
    37      QUERY_DOC = 3
    38  
    39  
    40  STATE_MAP = {enum_entry.value: enum_entry.name.lower() for enum_entry in OpState}
    41  
    42  DOC_EXISTS = "d+"
    43  DOC_MISSING = "d-"
    44  
    45  UUID_REGEX = re.compile(
    46      r"^[\dA-Fa-f]{8}-[\dA-Fa-f]{4}-[\dA-Fa-f]{4}-[\dA-Fa-f]{4}-[\dA-Fa-f]{12}$")
    47  
    48  
    49  def main():
    50      """The main program entry point."""
    51      args = parse_args()
    52      client_args = create_client_args(args)
    53      client = MongoClient(**client_args)
    54      if args.model:
    55          # Target all transactions from the specified model's txn-queue.
    56          model_uuid = get_model_uuid(client, args)
    57          txn_queue = get_model_transaction_queue(client, model_uuid)
    58          state_filter = None
    59      else:
    60          # Default behavior is to filter by transaction state, usually on the "aborted" state.
    61          txn_queue = None
    62          state_filter = args.state_filter
    63      walk_transaction_queue(client, txn_queue, state_filter, args.dump_transaction, args.include_passes, args.count)
    64  
    65  
    66  def parse_args():
    67      """Parse the command line arguments."""
    68      ap = argparse.ArgumentParser()    # pylint: disable=invalid-name
    69      ap.add_argument('model', nargs='?',
    70                      help='Name or UUID of model to examine.  If not specified, the full '
    71                           'transaction collection will be examined instead.')
    72      ap.add_argument('-H', '--host', help='Mongo hostname or URI to use, if required.')
    73      ap.add_argument('-u', '--user', help='Mongo username to use, if required.')
    74      ap.add_argument('-p', '--password', help='Mongo password to use, if required.')
    75      ap.add_argument('--auth-database', default='admin',
    76                      help='Mongo auth database to use, if required.  (Default: %(default)s)')
    77      ap.add_argument('--state', dest='state_filter', type=int, default=OpState.ABORTED.value,
    78                      help="Filter by state number.  This is used when querying the full "
    79                           "transaction queue; it is ignored if querying a specific object's "
    80                           "transaction queue.  (Default: %(default)s)")
    81      ap.add_argument('-c', '--count', type=int, help='Count of entries to examine')
    82      ap.add_argument('-d', '--dump-transaction', action='store_true',
    83                      help='Additionally pretty-print entire transactions to stdout')
    84      ap.add_argument('-P', '--include-passes', action='store_true', help='Include pass details')
    85      ap.add_argument('-s', '--ssl', '--tls', dest='tls', action='store_true', help='Enable TLS')
    86      return ap.parse_args()
    87  
    88  
    89  def create_client_args(args):
    90      """Create a set of client arguments suitable for talking to the target Mongo instance."""
    91      client_args = {}
    92      if args.host:
    93          client_args['host'] = args.host
    94      if args.user:
    95          client_args['username'] = args.user
    96          client_args['authSource'] = args.auth_database
    97      if args.password:
    98          client_args['password'] = args.password
    99      if args.tls:
   100          client_args['tls'] = True
   101          client_args['tlsAllowInvalidCertificates'] = True
   102      return client_args
   103  
   104  
   105  def get_model_uuid(client, args):
   106      """Given a model argument, convert it (if necessary) to a model UUID."""
   107      model = args.model
   108      print("Examining supplied model:", model)
   109      if UUID_REGEX.match(model):
   110          model_uuid = model
   111          print('Supplied model is a valid UUID; using as-is.')
   112      else:
   113          model_doc = client.juju.models.find_one({'name': model})
   114          if model_doc:
   115              model_uuid = model_doc['_id']
   116              print('Found matching UUID:', model_uuid)
   117          else:
   118              sys.exit('Could not find the specified model ({})'.format(model))
   119      return model_uuid
   120  
   121  
   122  def get_model_transaction_queue(client, model_uuid):
   123      """Retrieves a list of transaction IDs from the specified model's document."""
   124      model_doc = client.juju.models.find_one({"_id": model_uuid})
   125      if not model_doc:
   126          sys.exit('Could not find model with specified UUID ({})'.format(model_uuid))
   127      txn_queue = model_doc['txn-queue']
   128      print('Retrieved model transaction queue:')
   129      for ref in txn_queue:
   130          print('- {}'.format(ref))
   131      print('Converting to transaction IDs:')
   132      txn_ids = [ref.split("_")[0] for ref in txn_queue]
   133      for id_ in txn_ids:
   134          print('- {}'.format(id_))
   135      print()
   136      return txn_ids
   137  
   138  
   139  def walk_transaction_queue(client, txn_queue, state_filter, dump_transaction, include_passes, max_transaction_count):
   140      """Examine part or all of the transactions collection."""
   141      db_client = client.juju
   142      state_filter_args = {} if state_filter is None else {"s": state_filter}
   143      if txn_queue:
   144          # We'll perform one query for each specific transaction ID
   145          query_docs = [dict(state_filter_args, _id=ObjectId(txn)) for txn in txn_queue]
   146      else:
   147          # We'll use a single cursor and iterate over the full collection of transactions
   148          query_docs = [state_filter_args]
   149  
   150      txn_counter = 0
   151      for query_doc in query_docs:
   152          matches = db_client.txns.find(query_doc)
   153          for match in matches:
   154              print('Transaction {} (state: {}):'.format(str(match['_id']), get_state_as_string(match['s'])))
   155              if dump_transaction:
   156                  print('  Transaction dump:\n{}'.format(textwrap.indent(pprint.pformat(match), '    ')))
   157                  print()
   158              for op_index, op in enumerate(match['o']):                      # pylint: disable=invalid-name
   159                  _print_op_details(db_client, op_index, op, include_passes)  # pylint: disable=invalid-name
   160              txn_counter += 1
   161              print()
   162  
   163              if max_transaction_count and txn_counter >= max_transaction_count:
   164                  break
   165          if max_transaction_count and txn_counter >= max_transaction_count:
   166              break
   167  
   168  
   169  def get_state_as_string(i):
   170      """Given an integer transaction state, return its meaning as a string."""
   171      return STATE_MAP[i]
   172  
   173  
   174  def _print_op_details(db_client, op_index, op, include_passes):  # pylint: disable=invalid-name
   175      collection = getattr(db_client, op['c'])
   176      match_doc_id = op['d']
   177      find_filter = {"_id": match_doc_id}
   178      should_print = False
   179      if 'a' not in op:
   180          if include_passes:
   181              print('  Op {}: no assertion present; passes'.format(op_index))
   182      else:
   183          if op['a'] == DOC_MISSING:
   184              assertion_type = AssertionType.DOC_MISSING
   185              existing_doc = collection.find_one(find_filter)
   186              failed = existing_doc
   187          elif op['a'] == DOC_EXISTS:
   188              assertion_type = AssertionType.DOC_EXISTS
   189              existing_doc = collection.find_one(find_filter)
   190              failed = not existing_doc
   191          else:
   192              # Standard assertion.  Uses a query doc.
   193              assertion_type = AssertionType.QUERY_DOC
   194              find_filter.update(op['a'])
   195              existing_doc = collection.find_one(find_filter)
   196              failed = not existing_doc
   197  
   198          should_print = failed or include_passes
   199          if should_print:
   200              print('  Op {}: {} assertion {}'.format(op_index, assertion_type.name, 'FAILED' if failed else 'passed'))
   201              print("  Collection '{}', ID '{}'".format(op['c'], match_doc_id))
   202              if assertion_type == AssertionType.QUERY_DOC:
   203                  print(
   204                      '  Query doc tested was:\n{}'.format(textwrap.indent(pprint.pformat(find_filter), '    ')))
   205              if existing_doc:
   206                  print('  Existing doc is:\n{}'.format(textwrap.indent(pprint.pformat(existing_doc), '    ')))
   207              if 'i' in op:
   208                  print('  Insert doc is:\n{}'.format(textwrap.indent(pprint.pformat(op['i']), '    ')))
   209              if 'u' in op:
   210                  print('  Update doc is:\n{}'.format(textwrap.indent(pprint.pformat(op['u']), '    ')))
   211      if should_print:
   212          print()
   213  
   214  
# Run the tool only when this file is executed as a script, not when it is imported.
if __name__ == "__main__":
    main()