github.com/apache/beam/sdks/v2@v2.48.2/python/apache_beam/testing/util.py (about)

     1  #
     2  # Licensed to the Apache Software Foundation (ASF) under one or more
     3  # contributor license agreements.  See the NOTICE file distributed with
     4  # this work for additional information regarding copyright ownership.
     5  # The ASF licenses this file to You under the Apache License, Version 2.0
     6  # (the "License"); you may not use this file except in compliance with
     7  # the License.  You may obtain a copy of the License at
     8  #
     9  #    http://www.apache.org/licenses/LICENSE-2.0
    10  #
    11  # Unless required by applicable law or agreed to in writing, software
    12  # distributed under the License is distributed on an "AS IS" BASIS,
    13  # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    14  # See the License for the specific language governing permissions and
    15  # limitations under the License.
    16  #
    17  
    18  """Utilities for testing Beam pipelines."""
    19  
    20  # pytype: skip-file
    21  
    22  import collections
    23  import glob
    24  import io
    25  import tempfile
    26  from typing import Iterable
    27  
    28  from apache_beam import pvalue
    29  from apache_beam.transforms import window
    30  from apache_beam.transforms.core import Create
    31  from apache_beam.transforms.core import DoFn
    32  from apache_beam.transforms.core import Map
    33  from apache_beam.transforms.core import ParDo
    34  from apache_beam.transforms.core import WindowInto
    35  from apache_beam.transforms.ptransform import PTransform
    36  from apache_beam.transforms.util import CoGroupByKey
    37  
# Public API of this module (what `from apache_beam.testing.util import *`
# exports). BeamAssertException and contains_in_any_order are defined in this
# module but are not listed here, so they must be imported by name.
__all__ = [
    'assert_that',
    'equal_to',
    'equal_to_per_window',
    'is_empty',
    'is_not_empty',
    'matches_all',
    # open_shards is internal and has no backwards compatibility guarantees.
    'open_shards',
    'TestWindowedValue',
]
    49  
    50  
    51  class BeamAssertException(Exception):
    52    """Exception raised by matcher classes used by assert_that transform."""
    53  
    54    pass
    55  
    56  
    57  # Used for reifying timestamps and windows for assert_that matchers.
    58  TestWindowedValue = collections.namedtuple(
    59      'TestWindowedValue', 'value timestamp windows')
    60  
    61  
    62  def contains_in_any_order(iterable):
    63    """Creates an object that matches another iterable if they both have the
    64    same count of items.
    65  
    66    Arguments:
    67      iterable: An iterable of hashable objects.
    68    """
    69    class InAnyOrder(object):
    70      def __init__(self, iterable):
    71        self._counter = collections.Counter(iterable)
    72  
    73      def __eq__(self, other):
    74        return self._counter == collections.Counter(other)
    75  
    76      def __hash__(self):
    77        return hash(self._counter)
    78  
    79      def __repr__(self):
    80        return "InAnyOrder(%s)" % self._counter
    81  
    82    return InAnyOrder(iterable)
    83  
    84  
    85  class _EqualToPerWindowMatcher(object):
    86    def __init__(self, expected_window_to_elements):
    87      self._expected_window_to_elements = expected_window_to_elements
    88  
    89    def __call__(self, value):
    90      # Short-hand.
    91      _expected = self._expected_window_to_elements
    92  
    93      # Match the given windowed value to an expected window. Fails if the window
    94      # doesn't exist or the element wasn't found in the window.
    95      def match(windowed_value):
    96        actual = windowed_value.value
    97        window_key = windowed_value.windows[0]
    98        try:
    99          _expected[window_key]
   100        except KeyError:
   101          raise BeamAssertException(
   102              'Failed assert: window {} not found in any expected ' \
   103              'windows {}'.format(window_key, list(_expected.keys())))\
   104  
   105        # Remove any matched elements from the window. This is used later on to
   106        # assert that all elements in the window were matched with actual
   107        # elements.
   108        try:
   109          _expected[window_key].remove(actual)
   110        except ValueError:
   111          raise BeamAssertException(
   112              'Failed assert: element {} not found in window ' \
   113              '{}:{}'.format(actual, window_key, _expected[window_key]))\
   114  
   115      # Run the matcher for each window and value pair. Fails if the
   116      # windowed_value is not a TestWindowedValue.
   117      for windowed_value in value:
   118        if not isinstance(windowed_value, TestWindowedValue):
   119          raise BeamAssertException(
   120              'Failed assert: Received element {} is not of type ' \
   121              'TestWindowedValue. Did you forget to set reify_windows=True ' \
   122              'on the assertion?'.format(windowed_value))
   123        match(windowed_value)
   124  
   125      # Finally, some elements may not have been matched. Assert that we removed
   126      # all the elements that we received from the expected list. If the list is
   127      # non-empty, then there are unmatched elements.
   128      for win in _expected:
   129        if _expected[win]:
   130          raise BeamAssertException(
   131              'Failed assert: unmatched elements {} in window {}'.format(
   132                  _expected[win], win))
   133  
   134  
   135  def equal_to_per_window(expected_window_to_elements):
   136    """Matcher used by assert_that to check to assert expected windows.
   137  
   138    The 'assert_that' statement must have reify_windows=True. This assertion works
   139    when elements are emitted and are finally checked at the end of the window.
   140  
   141    Arguments:
   142      expected_window_to_elements: A dictionary where the keys are the windows
   143        to check and the values are the elements associated with each window.
   144    """
   145  
   146    return _EqualToPerWindowMatcher(expected_window_to_elements)
   147  
   148  
   149  # Note that equal_to checks if expected and actual are permutations of each
   150  # other. However, only permutations of the top level are checked. Therefore
   151  # [1,2] and [2,1] are considered equal and [[1,2]] and [[2,1]] are not.
   152  def equal_to(expected, equals_fn=None):
   153    def _equal(actual, equals_fn=equals_fn):
   154      expected_list = list(expected)
   155  
   156      # Try to compare actual and expected by sorting. This fails with a
   157      # TypeError in Python 3 if different types are present in the same
   158      # collection. It can also raise false negatives for types that don't have
   159      # a deterministic sort order, like pyarrow Tables as of 0.14.1
   160      if not equals_fn:
   161        equals_fn = lambda e, a: e == a
   162        try:
   163          sorted_expected = sorted(expected)
   164          sorted_actual = sorted(actual)
   165          if sorted_expected == sorted_actual:
   166            return
   167        except TypeError:
   168          pass
   169      # Slower method, used in two cases:
   170      # 1) If sorted expected != actual, use this method to verify the inequality.
   171      #    This ensures we don't raise any false negatives for types that don't
   172      #    have a deterministic sort order.
   173      # 2) As a fallback if we encounter a TypeError in python 3. this method
   174      #    works on collections that have different types.
   175      unexpected = []
   176      for element in actual:
   177        found = False
   178        for i, v in enumerate(expected_list):
   179          if equals_fn(v, element):
   180            found = True
   181            expected_list.pop(i)
   182            break
   183        if not found:
   184          unexpected.append(element)
   185      if unexpected or expected_list:
   186        msg = 'Failed assert: %r == %r' % (expected, actual)
   187        if unexpected:
   188          msg = msg + ', unexpected elements %r' % unexpected
   189        if expected_list:
   190          msg = msg + ', missing elements %r' % expected_list
   191        raise BeamAssertException(msg)
   192  
   193    return _equal
   194  
   195  
   196  def matches_all(expected):
   197    """Matcher used by assert_that to check a set of matchers.
   198  
   199    Args:
   200      expected: A list of elements or hamcrest matchers to be used to match
   201        the elements of a single PCollection.
   202    """
   203    def _matches(actual):
   204      from hamcrest.core import assert_that as hamcrest_assert
   205      from hamcrest.library.collection import contains_inanyorder
   206      expected_list = list(expected)
   207  
   208      hamcrest_assert(actual, contains_inanyorder(*expected_list))
   209  
   210    return _matches
   211  
   212  
   213  def is_empty():
   214    def _empty(actual):
   215      actual = list(actual)
   216      if actual:
   217        raise BeamAssertException('Failed assert: [] == %r' % actual)
   218  
   219    return _empty
   220  
   221  
   222  def is_not_empty():
   223    """
   224    This is test method which makes sure that the pcol is not empty and it has
   225    some data in it.
   226    :return:
   227    """
   228    def _not_empty(actual):
   229      actual = list(actual)
   230      if not actual:
   231        raise BeamAssertException('Failed assert: pcol is empty')
   232  
   233    return _not_empty
   234  
   235  
   236  def assert_that(
   237      actual,
   238      matcher,
   239      label='assert_that',
   240      reify_windows=False,
   241      use_global_window=True):
   242    """A PTransform that checks a PCollection has an expected value.
   243  
   244    Note that assert_that should be used only for testing pipelines since the
   245    check relies on materializing the entire PCollection being checked.
   246  
   247    Args:
   248      actual: A PCollection.
   249      matcher: A matcher function taking as argument the actual value of a
   250        materialized PCollection. The matcher validates this actual value against
   251        expectations and raises BeamAssertException if they are not met.
   252      label: Optional string label. This is needed in case several assert_that
   253        transforms are introduced in the same pipeline.
   254      reify_windows: If True, matcher is passed a list of TestWindowedValue.
   255      use_global_window: If False, matcher is passed a dictionary of
   256        (k, v) = (window, elements in the window).
   257  
   258    Returns:
   259      Ignored.
   260    """
   261    assert isinstance(actual, pvalue.PCollection), (
   262        '%s is not a supported type for Beam assert' % type(actual))
   263  
   264    if isinstance(matcher, _EqualToPerWindowMatcher):
   265      reify_windows = True
   266      use_global_window = True
   267  
   268    class ReifyTimestampWindow(DoFn):
   269      def process(
   270          self, element, timestamp=DoFn.TimestampParam, window=DoFn.WindowParam):
   271        # This returns TestWindowedValue instead of
   272        # beam.utils.windowed_value.WindowedValue because ParDo will extract
   273        # the timestamp and window out of the latter.
   274        return [TestWindowedValue(element, timestamp, [window])]
   275  
   276    class AddWindow(DoFn):
   277      def process(self, element, window=DoFn.WindowParam):
   278        yield element, window
   279  
   280    class AssertThat(PTransform):
   281      def expand(self, pcoll):
   282        if reify_windows:
   283          pcoll = pcoll | ParDo(ReifyTimestampWindow())
   284  
   285        keyed_singleton = pcoll.pipeline | Create([(None, None)])
   286        keyed_singleton.is_bounded = True
   287  
   288        if use_global_window:
   289          pcoll = pcoll | WindowInto(window.GlobalWindows())
   290  
   291        keyed_actual = pcoll | 'ToVoidKey' >> Map(lambda v: (None, v))
   292        keyed_actual.is_bounded = True
   293  
   294        # This is a CoGroupByKey so that the matcher always runs, even if the
   295        # PCollection is empty.
   296        plain_actual = ((keyed_singleton, keyed_actual)
   297                        | 'Group' >> CoGroupByKey()
   298                        | 'Unkey' >> Map(lambda k_values: k_values[1][1]))
   299  
   300        if not use_global_window:
   301          plain_actual = plain_actual | 'AddWindow' >> ParDo(AddWindow())
   302  
   303        plain_actual = plain_actual | 'Match' >> Map(matcher)
   304  
   305      def default_label(self):
   306        return label
   307  
   308    actual | AssertThat()  # pylint: disable=expression-not-assigned
   309  
   310  
   311  def open_shards(glob_pattern, mode='rt', encoding='utf-8'):
   312    """Returns a composite file of all shards matching the given glob pattern.
   313  
   314    Args:
   315      glob_pattern (str): Pattern used to match files which should be opened.
   316      mode (str): Specify the mode in which the file should be opened. For
   317                  available modes, check io.open() documentation.
   318      encoding (str): Name of the encoding used to decode or encode the file.
   319                      This should only be used in text mode.
   320  
   321    Returns:
   322      A stream with the contents of the opened files.
   323    """
   324    if 'b' in mode:
   325      encoding = None
   326  
   327    with tempfile.NamedTemporaryFile(delete=False) as out_file:
   328      for shard in glob.glob(glob_pattern):
   329        with open(shard, 'rb') as in_file:
   330          out_file.write(in_file.read())
   331      concatenated_file_name = out_file.name
   332    return io.open(concatenated_file_name, mode, encoding=encoding)
   333  
   334  
   335  def _sort_lists(result):
   336    if isinstance(result, list):
   337      return sorted(result)
   338    elif isinstance(result, tuple):
   339      return tuple(_sort_lists(e) for e in result)
   340    elif isinstance(result, dict):
   341      return {k: _sort_lists(v) for k, v in result.items()}
   342    elif isinstance(result, Iterable) and not isinstance(result, str):
   343      return sorted(result)
   344    else:
   345      return result
   346  
   347  
# A Map transform applying _sort_lists to every element, for easier
# order-insensitive testing: list-like values are sorted (top level only),
# while tuples and dict values are processed recursively.
SortLists = Map(_sort_lists)