github.com/apache/beam/sdks/v2@v2.48.2/python/apache_beam/io/gcp/tests/pubsub_matcher.py (about)

     1  #
     2  # Licensed to the Apache Software Foundation (ASF) under one or more
     3  # contributor license agreements.  See the NOTICE file distributed with
     4  # this work for additional information regarding copyright ownership.
     5  # The ASF licenses this file to You under the Apache License, Version 2.0
     6  # (the "License"); you may not use this file except in compliance with
     7  # the License.  You may obtain a copy of the License at
     8  #
     9  #    http://www.apache.org/licenses/LICENSE-2.0
    10  #
    11  # Unless required by applicable law or agreed to in writing, software
    12  # distributed under the License is distributed on an "AS IS" BASIS,
    13  # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    14  # See the License for the specific language governing permissions and
    15  # limitations under the License.
    16  #
    17  
    18  """PubSub verifier used for end-to-end test."""
    19  
    20  # pytype: skip-file
    21  
    22  import logging
    23  import time
    24  from collections import Counter
    25  
    26  from hamcrest.core.base_matcher import BaseMatcher
    27  
    28  from apache_beam.io.gcp.pubsub import PubsubMessage
    29  
    30  __all__ = ['PubSubMessageMatcher']
    31  
    32  # Protect against environments where pubsub library is not available.
    33  try:
    34    from google.cloud import pubsub
    35  except ImportError:
    36    pubsub = None
    37  
# Default overall budget, in seconds, for waiting until all expected messages
# have been pulled (5 minutes).
DEFAULT_TIMEOUT = 5 * 60
# Default pause, in seconds, between two consecutive pull requests.
DEFAULT_SLEEP_TIME = 1
# Default upper bound on the number of messages requested in a single pull.
DEFAULT_MAX_MESSAGES_IN_ONE_PULL = 50
# Default timeout, in seconds, passed to each individual pull request.
DEFAULT_PULL_TIMEOUT = 30.0

# Module-level logger; used to report a timeout while waiting for messages.
_LOGGER = logging.getLogger(__name__)
    44  
    45  
    46  class PubSubMessageMatcher(BaseMatcher):
    47    """Matcher that verifies messages from given subscription.
    48  
    49    This matcher can block the test and keep pulling messages from given
    50    subscription until all expected messages are shown or timeout.
    51    """
    52    def __init__(
    53        self,
    54        project,
    55        sub_name,
    56        expected_msg=None,
    57        expected_msg_len=None,
    58        timeout=DEFAULT_TIMEOUT,
    59        with_attributes=False,
    60        strip_attributes=None,
    61        sleep_time=DEFAULT_SLEEP_TIME,
    62        max_messages_in_one_pull=DEFAULT_MAX_MESSAGES_IN_ONE_PULL,
    63        pull_timeout=DEFAULT_PULL_TIMEOUT):
    64      """Initialize PubSubMessageMatcher object.
    65  
    66      Args:
    67        project: A name string of project.
    68        sub_name: A name string of subscription which is attached to output.
    69        expected_msg: A string list that contains expected message data pulled
    70          from the subscription. See also: with_attributes.
    71        expected_msg_len: Number of expected messages pulled from the
    72          subscription.
    73        timeout: Timeout in seconds to wait for all expected messages appears.
    74        with_attributes: If True, will match against both message data and
    75          attributes. If True, expected_msg should be a list of ``PubsubMessage``
    76          objects. Otherwise, it should be a list of ``bytes``.
    77        strip_attributes: List of strings. If with_attributes==True, strip the
    78          attributes keyed by these values from incoming messages.
    79          If a key is missing, will add an attribute with an error message as
    80          value to prevent a successful match.
    81        sleep_time: Time in seconds between which the pulls from pubsub are done.
    82        max_messages_in_one_pull: Maximum number of messages pulled from pubsub
    83          at once.
    84        pull_timeout: Time in seconds after which the pull from pubsub is repeated
    85      """
    86      if pubsub is None:
    87        raise ImportError('PubSub dependencies are not installed.')
    88      if not project:
    89        raise ValueError('Invalid project %s.' % project)
    90      if not sub_name:
    91        raise ValueError('Invalid subscription %s.' % sub_name)
    92      if not expected_msg_len and not expected_msg:
    93        raise ValueError(
    94            'Required expected_msg: {} or expected_msg_len: {}.'.format(
    95                expected_msg, expected_msg_len))
    96      if expected_msg and not isinstance(expected_msg, list):
    97        raise ValueError('Invalid expected messages %s.' % expected_msg)
    98      if expected_msg_len and not isinstance(expected_msg_len, int):
    99        raise ValueError('Invalid expected messages %s.' % expected_msg_len)
   100  
   101      self.project = project
   102      self.sub_name = sub_name
   103      self.expected_msg = expected_msg
   104      self.expected_msg_len = expected_msg_len or len(self.expected_msg)
   105      self.timeout = timeout
   106      self.messages = None
   107      self.messages_all_details = None
   108      self.with_attributes = with_attributes
   109      self.strip_attributes = strip_attributes
   110      self.sleep_time = sleep_time
   111      self.max_messages_in_one_pull = max_messages_in_one_pull
   112      self.pull_timeout = pull_timeout
   113  
   114    def _matches(self, _):
   115      if self.messages is None:
   116        self.messages, self.messages_all_details = self._wait_for_messages(
   117            self.expected_msg_len, self.timeout)
   118      if self.expected_msg:
   119        return Counter(self.messages) == Counter(self.expected_msg)
   120      else:
   121        return len(self.messages) == self.expected_msg_len
   122  
   123    def _wait_for_messages(self, expected_num, timeout):
   124      """Wait for messages from given subscription."""
   125      total_messages = []
   126      total_messages_all_details = []
   127  
   128      sub_client = pubsub.SubscriberClient()
   129      start_time = time.time()
   130      while time.time() - start_time <= timeout:
   131        response = sub_client.pull(
   132            subscription=self.sub_name,
   133            max_messages=self.max_messages_in_one_pull,
   134            timeout=self.pull_timeout)
   135        for rm in response.received_messages:
   136          msg = PubsubMessage._from_message(rm.message)
   137          full_message = (
   138              msg.data,
   139              msg.attributes,
   140              msg.attributes,
   141              msg.publish_time,
   142              msg.ordering_key)
   143          if not self.with_attributes:
   144            total_messages.append(msg.data)
   145            total_messages_all_details.append(full_message)
   146            continue
   147  
   148          if self.strip_attributes:
   149            for attr in self.strip_attributes:
   150              try:
   151                del msg.attributes[attr]
   152              except KeyError:
   153                msg.attributes[attr] = (
   154                    'PubSubMessageMatcher error: '
   155                    'expected attribute not found.')
   156          total_messages.append(msg)
   157          total_messages_all_details.append(full_message)
   158  
   159        ack_ids = [rm.ack_id for rm in response.received_messages]
   160        if ack_ids:
   161          sub_client.acknowledge(subscription=self.sub_name, ack_ids=ack_ids)
   162        if len(total_messages) >= expected_num:
   163          break
   164        time.sleep(self.sleep_time)
   165  
   166      if time.time() - start_time > timeout:
   167        _LOGGER.error(
   168            'Timeout after %d sec. Received %d messages from %s.',
   169            timeout,
   170            len(total_messages),
   171            self.sub_name)
   172      return total_messages, total_messages_all_details
   173  
   174    def describe_to(self, description):
   175      description.append_text('Expected %d messages.' % self.expected_msg_len)
   176  
   177    def describe_mismatch(self, _, mismatch_description):
   178      c_expected = Counter(self.expected_msg)
   179      c_actual = Counter(self.messages)
   180      mismatch_description.append_text("Got %d messages. " % (len(self.messages)))
   181      if self.expected_msg:
   182        expected = (c_expected - c_actual).items()
   183        unexpected = (c_actual - c_expected).items()
   184        unexpected_keys = [repr(item[0]) for item in unexpected]
   185        if self.with_attributes:
   186          unexpected_all_details = [
   187              x for x in self.messages_all_details
   188              if 'PubsubMessage(%s, %s)' % (repr(x[0]), x[1]) in unexpected_keys
   189          ]
   190        else:
   191          unexpected_all_details = [
   192              x for x in self.messages_all_details
   193              if repr(x[0]) in unexpected_keys
   194          ]
   195        mismatch_description.append_text(
   196            "Diffs (item, count):\n"
   197            "  Expected but not in actual: %s\n"
   198            "  Unexpected: %s\n"
   199            "  Unexpected (with all details): %s" %
   200            (expected, unexpected, unexpected_all_details))
   201      if self.with_attributes and self.strip_attributes:
   202        mismatch_description.append_text(
   203            '\n  Stripped attributes: %r' % self.strip_attributes)