# Source: github.com/apache/beam/sdks/v2@v2.48.2/python/apache_beam/testing/util.py
#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements.  See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License.  You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

"""Utilities for testing Beam pipelines."""

# pytype: skip-file

import collections
import glob
import io
import tempfile
from typing import Iterable

from apache_beam import pvalue
from apache_beam.transforms import window
from apache_beam.transforms.core import Create
from apache_beam.transforms.core import DoFn
from apache_beam.transforms.core import Map
from apache_beam.transforms.core import ParDo
from apache_beam.transforms.core import WindowInto
from apache_beam.transforms.ptransform import PTransform
from apache_beam.transforms.util import CoGroupByKey

__all__ = [
    'assert_that',
    'equal_to',
    'equal_to_per_window',
    'is_empty',
    'is_not_empty',
    'matches_all',
    # open_shards is internal and has no backwards compatibility guarantees.
    'open_shards',
    'TestWindowedValue',
]


class BeamAssertException(Exception):
  """Signals that an assert_that matcher rejected the actual PCollection.

  Matchers passed to assert_that raise this when the materialized contents
  do not meet their expectations.
  """


# Used for reifying timestamps and windows for assert_that matchers.
TestWindowedValue = collections.namedtuple(
    'TestWindowedValue', 'value timestamp windows')


def contains_in_any_order(iterable):
  """Creates an object that matches another iterable if they both have the
  same count of items.

  Two iterables match when they hold the same elements with the same
  multiplicities, regardless of order (i.e. multiset equality).

  Arguments:
    iterable: An iterable of hashable objects.

  Returns:
    An object whose ``__eq__`` compares the element counts of ``iterable``
    against those of the other operand.
  """
  class InAnyOrder(object):
    def __init__(self, iterable):
      self._counter = collections.Counter(iterable)

    def __eq__(self, other):
      return self._counter == collections.Counter(other)

    def __hash__(self):
      # collections.Counter is a dict subclass and therefore unhashable;
      # hashing it directly always raised TypeError. Hash an immutable
      # snapshot of the (element, count) pairs instead.
      return hash(frozenset(self._counter.items()))

    def __repr__(self):
      return "InAnyOrder(%s)" % self._counter

  return InAnyOrder(iterable)


class _EqualToPerWindowMatcher(object):
  """Matcher comparing reified windowed values against per-window elements.

  Built by equal_to_per_window(). assert_that() recognizes instances of this
  class and forces reify_windows=True and use_global_window=True for them.
  """
  def __init__(self, expected_window_to_elements):
    # Mapping of window -> elements expected to be observed in that window.
    self._expected_window_to_elements = expected_window_to_elements

  def __call__(self, value):
    """Asserts that ``value`` matches the expected windows exactly.

    Args:
      value: An iterable of TestWindowedValue.

    Raises:
      BeamAssertException: If a received element is not a TestWindowedValue,
        falls in a window that is not expected, is not expected in its
        window, or if some expected element is never received.
    """
    # Work on a private copy: matched elements are removed below, and doing
    # that on the caller's lists mutated user data and made the matcher
    # single-use (a second call would start from already-emptied lists).
    _expected = {
        win: list(elements)
        for win, elements in self._expected_window_to_elements.items()
    }

    # Match the given windowed value to an expected window. Fails if the window
    # doesn't exist or the element wasn't found in the window.
    def match(windowed_value):
      actual = windowed_value.value
      # Only the first window of the reified value is consulted.
      window_key = windowed_value.windows[0]
      if window_key not in _expected:
        raise BeamAssertException(
            'Failed assert: window {} not found in any expected '
            'windows {}'.format(window_key, list(_expected.keys())))

      # Remove any matched elements from the window. This is used later on to
      # assert that all elements in the window were matched with actual
      # elements.
      try:
        _expected[window_key].remove(actual)
      except ValueError:
        raise BeamAssertException(
            'Failed assert: element {} not found in window '
            '{}:{}'.format(actual, window_key, _expected[window_key]))

    # Run the matcher for each window and value pair. Fails if the
    # windowed_value is not a TestWindowedValue.
    for windowed_value in value:
      if not isinstance(windowed_value, TestWindowedValue):
        raise BeamAssertException(
            'Failed assert: Received element {} is not of type '
            'TestWindowedValue. Did you forget to set reify_windows=True '
            'on the assertion?'.format(windowed_value))
      match(windowed_value)

    # Finally, some elements may not have been matched. Assert that we removed
    # all the elements that we received from the expected list. If the list is
    # non-empty, then there are unmatched elements.
    for win in _expected:
      if _expected[win]:
        raise BeamAssertException(
            'Failed assert: unmatched elements {} in window {}'.format(
                _expected[win], win))


def equal_to_per_window(expected_window_to_elements):
  """Matcher used by assert_that to assert the contents of expected windows.

  assert_that detects this matcher and automatically sets reify_windows=True
  (and use_global_window=True), so the matcher receives TestWindowedValue
  elements. This assertion works when elements are emitted and are finally
  checked at the end of the window.

  Arguments:
    expected_window_to_elements: A dictionary where the keys are the windows
      to check and the values are the elements associated with each window.
      The dictionary and its values are not modified.
  """

  return _EqualToPerWindowMatcher(expected_window_to_elements)


# Note that equal_to checks if expected and actual are permutations of each
# other. However, only permutations of the top level are checked. Therefore
# [1,2] and [2,1] are considered equal and [[1,2]] and [[2,1]] are not.
def equal_to(expected, equals_fn=None):
  """Matcher used by assert_that to check a PCollection equals ``expected``.

  Args:
    expected: An iterable of the expected elements (order-insensitive at the
      top level).
    equals_fn: Optional ``(expected_element, actual_element) -> bool`` used
      by the element-by-element fallback comparison. Defaults to ``==``.

  Returns:
    A function taking the actual elements and raising BeamAssertException,
    listing unexpected and missing elements, when they differ.
  """
  def _equal(actual, equals_fn=equals_fn):
    # Materialize both sides exactly once. ``expected`` may be a one-shot
    # iterator and ``actual`` is consumed by both the sort-based fast path
    # and the fallback loop; re-iterating an exhausted iterator would
    # silently yield nothing and corrupt the comparison.
    expected_items = list(expected)
    actual_items = list(actual)

    if not equals_fn:
      equals_fn = lambda e, a: e == a

    # Try to compare actual and expected by sorting. This fails with a
    # TypeError in Python 3 if different types are present in the same
    # collection. It can also raise false negatives for types that don't have
    # a deterministic sort order, like pyarrow Tables as of 0.14.1
    try:
      if sorted(expected_items) == sorted(actual_items):
        return
    except TypeError:
      pass

    # Slower method, used in two cases:
    # 1) If sorted expected != actual, use this method to verify the inequality.
    #    This ensures we don't raise any false negatives for types that don't
    #    have a deterministic sort order.
    # 2) As a fallback if we encounter a TypeError in python 3. this method
    #    works on collections that have different types.
    remaining = list(expected_items)
    unexpected = []
    for element in actual_items:
      for i, candidate in enumerate(remaining):
        if equals_fn(candidate, element):
          # Consume the matched expected element so duplicates are counted.
          del remaining[i]
          break
      else:
        unexpected.append(element)

    if unexpected or remaining:
      msg = 'Failed assert: %r == %r' % (expected_items, actual_items)
      if unexpected:
        msg = msg + ', unexpected elements %r' % unexpected
      if remaining:
        msg = msg + ', missing elements %r' % remaining
      raise BeamAssertException(msg)

  return _equal


def matches_all(expected):
  """Matcher used by assert_that to check a set of matchers.

  Args:
    expected: A list of elements or hamcrest matchers to be used to match
      the elements of a single PCollection.

  Returns:
    A function that asserts, via hamcrest's contains_inanyorder, that the
    actual elements match ``expected`` in any order.
  """
  def _matches(actual):
    # Imported lazily so hamcrest is only required when this matcher is used.
    from hamcrest.core import assert_that as hamcrest_assert
    from hamcrest.library.collection import contains_inanyorder
    expected_list = list(expected)

    hamcrest_assert(actual, contains_inanyorder(*expected_list))

  return _matches


def is_empty():
  """Matcher used by assert_that to check that a PCollection is empty.

  Returns:
    A function raising BeamAssertException if the actual elements are
    non-empty.
  """
  def _empty(actual):
    actual = list(actual)
    if actual:
      raise BeamAssertException('Failed assert: [] == %r' % actual)

  return _empty


def is_not_empty():
  """Matcher used by assert_that to check that a PCollection is not empty.

  Returns:
    A function raising BeamAssertException if the actual elements are empty.
  """
  def _not_empty(actual):
    actual = list(actual)
    if not actual:
      raise BeamAssertException('Failed assert: pcol is empty')

  return _not_empty


def assert_that(
    actual,
    matcher,
    label='assert_that',
    reify_windows=False,
    use_global_window=True):
  """A PTransform that checks a PCollection has an expected value.

  Note that assert_that should be used only for testing pipelines since the
  check relies on materializing the entire PCollection being checked.

  Args:
    actual: A PCollection.
    matcher: A matcher function taking as argument the actual value of a
      materialized PCollection. The matcher validates this actual value against
      expectations and raises BeamAssertException if they are not met.
    label: Optional string label. This is needed in case several assert_that
      transforms are introduced in the same pipeline.
    reify_windows: If True, matcher is passed a list of TestWindowedValue.
      Forced to True when matcher was built by equal_to_per_window.
    use_global_window: If True (the default), elements are re-windowed into
      the global window before grouping. If False, the original windowing is
      kept and each value handed to matcher is an ``(elements, window)``
      tuple.

  Returns:
    None. The assertion is attached to the pipeline as a side effect.
  """
  assert isinstance(actual, pvalue.PCollection), (
      '%s is not a supported type for Beam assert' % type(actual))

  if isinstance(matcher, _EqualToPerWindowMatcher):
    reify_windows = True
    use_global_window = True

  class ReifyTimestampWindow(DoFn):
    def process(
        self, element, timestamp=DoFn.TimestampParam, window=DoFn.WindowParam):
      # This returns TestWindowedValue instead of
      # beam.utils.windowed_value.WindowedValue because ParDo will extract
      # the timestamp and window out of the latter.
      return [TestWindowedValue(element, timestamp, [window])]

  class AddWindow(DoFn):
    def process(self, element, window=DoFn.WindowParam):
      yield element, window

  class AssertThat(PTransform):
    def expand(self, pcoll):
      if reify_windows:
        pcoll = pcoll | ParDo(ReifyTimestampWindow())

      keyed_singleton = pcoll.pipeline | Create([(None, None)])
      keyed_singleton.is_bounded = True

      if use_global_window:
        pcoll = pcoll | WindowInto(window.GlobalWindows())

      keyed_actual = pcoll | 'ToVoidKey' >> Map(lambda v: (None, v))
      keyed_actual.is_bounded = True

      # This is a CoGroupByKey so that the matcher always runs, even if the
      # PCollection is empty.
      plain_actual = ((keyed_singleton, keyed_actual)
                      | 'Group' >> CoGroupByKey()
                      | 'Unkey' >> Map(lambda k_values: k_values[1][1]))

      if not use_global_window:
        plain_actual = plain_actual | 'AddWindow' >> ParDo(AddWindow())

      plain_actual = plain_actual | 'Match' >> Map(matcher)

    def default_label(self):
      return label

  actual | AssertThat()  # pylint: disable=expression-not-assigned


def open_shards(glob_pattern, mode='rt', encoding='utf-8'):
  """Returns a composite file of all shards matching the given glob pattern.

  The matching shards are concatenated, in sorted path order, into a
  temporary file that is NOT deleted automatically; the returned stream
  reads from it.

  Args:
    glob_pattern (str): Pattern used to match files which should be opened.
    mode (str): Specify the mode in which the file should be opened. For
                available modes, check io.open() documentation.
    encoding (str): Name of the encoding used to decode or encode the file.
                    This should only be used in text mode.

  Returns:
    A stream with the contents of the opened files.
  """
  # io.open rejects an encoding argument in binary mode.
  if 'b' in mode:
    encoding = None

  with tempfile.NamedTemporaryFile(delete=False) as out_file:
    # glob.glob returns matches in arbitrary, filesystem-dependent order.
    # Sort so the concatenation (and what readers observe) is deterministic.
    for shard in sorted(glob.glob(glob_pattern)):
      with open(shard, 'rb') as in_file:
        out_file.write(in_file.read())
    concatenated_file_name = out_file.name
  return io.open(concatenated_file_name, mode, encoding=encoding)


def _sort_lists(result):
  """Returns ``result`` with lists and other non-text iterables sorted.

  Tuples and dict values are processed recursively; the elements of a list
  are not. Strings and bytes-like values are returned unchanged rather than
  being sorted into a list of characters / integers.
  """
  if isinstance(result, list):
    return sorted(result)
  elif isinstance(result, tuple):
    return tuple(_sort_lists(e) for e in result)
  elif isinstance(result, dict):
    return {k: _sort_lists(v) for k, v in result.items()}
  elif isinstance(result, Iterable) and not isinstance(
      result, (str, bytes, bytearray)):
    return sorted(result)
  else:
    return result


# A utility transform that sorts lists (see _sort_lists) for easier testing.
SortLists = Map(_sort_lists)