#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements.  See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License.  You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
"""Unit tests for the interactive ``StreamingCache``.

Covers the cache's bookkeeping (exists/size/clear), reading cached
TestStreamFileRecords back as TestStreamPayload events (single and multiple
readers), round-tripping TestStream output through ``StreamingCache.sink``,
and honoring ``ib.options.cache_root``.
"""

# pytype: skip-file

import shutil
import tempfile
import unittest

from apache_beam import coders
from apache_beam.options.pipeline_options import StandardOptions
from apache_beam.portability.api import beam_interactive_api_pb2
from apache_beam.portability.api import beam_runner_api_pb2
from apache_beam.runners.interactive import interactive_beam as ib
from apache_beam.runners.interactive.cache_manager import SafeFastPrimitivesCoder
from apache_beam.runners.interactive.caching.cacheable import CacheKey
from apache_beam.runners.interactive.caching.streaming_cache import StreamingCache
from apache_beam.runners.interactive.testing.test_cache_manager import FileRecordsBuilder
from apache_beam.testing.test_pipeline import TestPipeline
from apache_beam.testing.test_stream import TestStream
from apache_beam.testing.util import *
from apache_beam.transforms.window import TimestampedValue

# Nose automatically detects tests if they match a regex. Here, it mistakes
# these protos for tests. For more info see the Nose docs at:
# https://nose.readthedocs.io/en/latest/writing_tests.html
beam_runner_api_pb2.TestStreamPayload.__test__ = False  # type: ignore[attr-defined]
beam_interactive_api_pb2.TestStreamFileHeader.__test__ = False  # type: ignore[attr-defined]
beam_interactive_api_pb2.TestStreamFileRecord.__test__ = False  # type: ignore[attr-defined]

# Short aliases for the protos the expected-event lists are built from.
_Event = beam_runner_api_pb2.TestStreamPayload.Event
_TimestampedElement = beam_runner_api_pb2.TestStreamPayload.TimestampedElement

# The helpers below take seconds; the TestStreamPayload fields they fill are
# in microseconds.
_MICROS_PER_SEC = 10**6


def _advance_processing_time_event(secs):
  """Returns an AdvanceProcessingTime event advancing by ``secs`` seconds."""
  return _Event(
      processing_time_event=_Event.AdvanceProcessingTime(
          advance_duration=secs * _MICROS_PER_SEC))


def _advance_watermark_event(watermark_secs, tag):
  """Returns an AdvanceWatermark event moving ``tag`` to ``watermark_secs``."""
  return _Event(
      watermark_event=_Event.AdvanceWatermark(
          new_watermark=watermark_secs * _MICROS_PER_SEC, tag=tag))


def _add_elements_event(coder, tag, *elements):
  """Returns an AddElements event for ``tag``.

  Args:
    coder: coder used to encode each element value.
    tag: the output tag the elements are emitted under.
    *elements: ``(value, event_time_secs)`` pairs, in emission order.
  """
  return _Event(
      element_event=_Event.AddElements(
          elements=[
              _TimestampedElement(
                  encoded_element=coder.encode(value),
                  timestamp=event_time_secs * _MICROS_PER_SEC)
              for value, event_time_secs in elements
          ],
          tag=tag))


class StreamingCacheTest(unittest.TestCase):
  def _set_ib_cache_root(self, cache_root):
    """Sets ``ib.options.cache_root``, restoring the old value on teardown.

    The option is process-global; registering the restore with ``addCleanup``
    guarantees it runs even when an assertion in the test fails, so the
    setting cannot leak into later tests.
    """
    previous_cache_root = ib.options.cache_root
    self.addCleanup(setattr, ib.options, 'cache_root', previous_cache_root)
    ib.options.cache_root = cache_root

  def test_exists(self):
    """Checks that exists() only reports labels that were written."""
    cache = StreamingCache(cache_dir=None)
    self.assertFalse(cache.exists('my_label'))
    cache.write([beam_interactive_api_pb2.TestStreamFileRecord()], 'my_label')
    self.assertTrue(cache.exists('my_label'))

    # '' shouldn't be treated as a wildcard to match everything.
    self.assertFalse(cache.exists(''))

  def test_empty(self):
    """Checks that reading a cache written with no records yields nothing."""
    CACHED_PCOLLECTION_KEY = repr(CacheKey('arbitrary_key', '', '', ''))

    cache = StreamingCache(cache_dir=None)
    self.assertFalse(cache.exists(CACHED_PCOLLECTION_KEY))
    cache.write([], CACHED_PCOLLECTION_KEY)
    reader, _ = cache.read(CACHED_PCOLLECTION_KEY)

    # Assert that an empty reader returns an empty list.
    self.assertEqual(list(reader), [])

  def test_size(self):
    """Checks size() against the encoded size of one written record."""
    cache = StreamingCache(cache_dir=None)
    cache.write([beam_interactive_api_pb2.TestStreamFileRecord()], 'my_label')
    coder = cache.load_pcoder('my_label')

    # Add one because of the new-line character that is also written.
    size = len(
        coder.encode(
            beam_interactive_api_pb2.TestStreamFileRecord().SerializeToString())
    ) + 1
    self.assertEqual(cache.size('my_label'), size)

  def test_clear(self):
    """Checks that clear() removes both the cached data and its capture key."""
    cache = StreamingCache(cache_dir=None)
    self.assertFalse(cache.exists('my_label'))
    cache.sink(['my_label'], is_capture=True)
    cache.write([beam_interactive_api_pb2.TestStreamFileRecord()], 'my_label')
    self.assertTrue(cache.exists('my_label'))
    self.assertEqual(cache.capture_keys, {'my_label'})
    self.assertTrue(cache.clear('my_label'))
    self.assertFalse(cache.exists('my_label'))
    self.assertFalse(cache.capture_keys)

  def test_single_reader(self):
    """
    Tests that we expect to see all the correctly emitted TestStreamPayloads.
    """
    CACHED_PCOLLECTION_KEY = repr(CacheKey('arbitrary_key', '', '', ''))

    values = (FileRecordsBuilder(tag=CACHED_PCOLLECTION_KEY)
              .add_element(element=0, event_time_secs=0)
              .advance_processing_time(1)
              .add_element(element=1, event_time_secs=1)
              .advance_processing_time(1)
              .add_element(element=2, event_time_secs=2)
              .build()) # yapf: disable

    cache = StreamingCache(cache_dir=None)
    cache.write(values, CACHED_PCOLLECTION_KEY)

    reader, _ = cache.read(CACHED_PCOLLECTION_KEY)
    coder = coders.FastPrimitivesCoder()
    events = list(reader)

    # Helper arguments are in seconds; the protos hold microseconds.
    expected = [
        _add_elements_event(coder, CACHED_PCOLLECTION_KEY, (0, 0)),
        _advance_processing_time_event(1),
        _add_elements_event(coder, CACHED_PCOLLECTION_KEY, (1, 1)),
        _advance_processing_time_event(1),
        _add_elements_event(coder, CACHED_PCOLLECTION_KEY, (2, 2)),
    ]
    self.assertSequenceEqual(events, expected)

  def test_multiple_readers(self):
    """Tests that the service advances the clock with multiple outputs.
    """

    CACHED_LETTERS = repr(CacheKey('letters', '', '', ''))
    CACHED_NUMBERS = repr(CacheKey('numbers', '', '', ''))
    CACHED_LATE = repr(CacheKey('late', '', '', ''))

    letters = (FileRecordsBuilder(CACHED_LETTERS)
               .advance_processing_time(1)
               .advance_watermark(watermark_secs=0)
               .add_element(element='a', event_time_secs=0)
               .advance_processing_time(10)
               .advance_watermark(watermark_secs=10)
               .add_element(element='b', event_time_secs=10)
               .build()) # yapf: disable

    numbers = (FileRecordsBuilder(CACHED_NUMBERS)
               .advance_processing_time(2)
               .add_element(element=1, event_time_secs=0)
               .advance_processing_time(1)
               .add_element(element=2, event_time_secs=0)
               .advance_processing_time(1)
               .add_element(element=2, event_time_secs=0)
               .build()) # yapf: disable

    late = (FileRecordsBuilder(CACHED_LATE)
            .advance_processing_time(101)
            .add_element(element='late', event_time_secs=0)
            .build()) # yapf: disable

    cache = StreamingCache(cache_dir=None)
    cache.write(letters, CACHED_LETTERS)
    cache.write(numbers, CACHED_NUMBERS)
    cache.write(late, CACHED_LATE)

    reader = cache.read_multiple([[CACHED_LETTERS], [CACHED_NUMBERS],
                                  [CACHED_LATE]])
    coder = coders.FastPrimitivesCoder()
    events = list(reader)

    # Helper arguments are in seconds; the protos hold microseconds.
    expected = [
        # Advances clock from 0 to 1
        _advance_processing_time_event(1),
        _advance_watermark_event(0, tag=CACHED_LETTERS),
        _add_elements_event(coder, CACHED_LETTERS, ('a', 0)),

        # Advances clock from 1 to 2
        _advance_processing_time_event(1),
        _add_elements_event(coder, CACHED_NUMBERS, (1, 0)),

        # Advances clock from 2 to 3
        _advance_processing_time_event(1),
        _add_elements_event(coder, CACHED_NUMBERS, (2, 0)),

        # Advances clock from 3 to 4
        _advance_processing_time_event(1),
        _add_elements_event(coder, CACHED_NUMBERS, (2, 0)),

        # Advances clock from 4 to 11
        _advance_processing_time_event(7),
        _advance_watermark_event(10, tag=CACHED_LETTERS),
        _add_elements_event(coder, CACHED_LETTERS, ('b', 10)),

        # Advances clock from 11 to 101
        _advance_processing_time_event(90),
        _add_elements_event(coder, CACHED_LATE, ('late', 0)),
    ]

    self.assertSequenceEqual(events, expected)

  def test_read_and_write(self):
    """An integration test between the Sink and Source.

    This ensures that the sink and source speak the same language in terms of
    coders, protos, order, and units.
    """
    CACHED_RECORDS = repr(CacheKey('records', '', '', ''))

    # Units here are in seconds.
    # `output_tags` expects a collection of tags, so pass a one-element tuple;
    # `(CACHED_RECORDS)` without the comma would be just the bare string.
    test_stream = (
        TestStream(output_tags=(CACHED_RECORDS,))
        .advance_watermark_to(0, tag=CACHED_RECORDS)
        .advance_processing_time(5)
        .add_elements(['a', 'b', 'c'], tag=CACHED_RECORDS)
        .advance_watermark_to(10, tag=CACHED_RECORDS)
        .advance_processing_time(1)
        .add_elements(
            [
                TimestampedValue('1', 15),
                TimestampedValue('2', 15),
                TimestampedValue('3', 15)
            ],
            tag=CACHED_RECORDS)) # yapf: disable

    coder = SafeFastPrimitivesCoder()
    cache = StreamingCache(cache_dir=None, sample_resolution_sec=1.0)

    # Assert that there are no capture keys at first.
    self.assertEqual(cache.capture_keys, set())

    options = StandardOptions(streaming=True)
    with TestPipeline(options=options) as p:
      records = (p | test_stream)[CACHED_RECORDS]

      # pylint: disable=expression-not-assigned
      records | cache.sink([CACHED_RECORDS], is_capture=True)

    reader, _ = cache.read(CACHED_RECORDS)
    actual_events = list(reader)

    # Assert that the capture keys are forwarded correctly.
    self.assertEqual(cache.capture_keys, {CACHED_RECORDS})

    # Helper arguments are in seconds; the protos hold microseconds.
    expected_events = [
        _advance_processing_time_event(5),
        _advance_watermark_event(0, tag=CACHED_RECORDS),
        _add_elements_event(
            coder, CACHED_RECORDS, ('a', 0), ('b', 0), ('c', 0)),
        _advance_processing_time_event(1),
        _advance_watermark_event(10, tag=CACHED_RECORDS),
        _add_elements_event(
            coder, CACHED_RECORDS, ('1', 15), ('2', 15), ('3', 15)),
    ]
    self.assertEqual(actual_events, expected_events)

  def test_read_and_write_multiple_outputs(self):
    """An integration test between the Sink and Source with multiple outputs.

    This tests the functionality that the StreamingCache reads from multiple
    files and combines them into a single sorted output.
    """
    LETTERS_TAG = repr(CacheKey('letters', '', '', ''))
    NUMBERS_TAG = repr(CacheKey('numbers', '', '', ''))

    # Units here are in seconds.
    test_stream = (TestStream()
                   .advance_watermark_to(0, tag=LETTERS_TAG)
                   .advance_processing_time(5)
                   .add_elements(['a', 'b', 'c'], tag=LETTERS_TAG)
                   .advance_watermark_to(10, tag=NUMBERS_TAG)
                   .advance_processing_time(1)
                   .add_elements(
                       [
                           TimestampedValue('1', 15),
                           TimestampedValue('2', 15),
                           TimestampedValue('3', 15)
                       ],
                       tag=NUMBERS_TAG)) # yapf: disable

    cache = StreamingCache(cache_dir=None, sample_resolution_sec=1.0)

    coder = SafeFastPrimitivesCoder()

    options = StandardOptions(streaming=True)
    with TestPipeline(options=options) as p:
      # pylint: disable=expression-not-assigned
      events = p | test_stream
      events[LETTERS_TAG] | 'Letters sink' >> cache.sink([LETTERS_TAG])
      events[NUMBERS_TAG] | 'Numbers sink' >> cache.sink([NUMBERS_TAG])

    reader = cache.read_multiple([[LETTERS_TAG], [NUMBERS_TAG]])
    actual_events = list(reader)

    # Helper arguments are in seconds; the protos hold microseconds.
    expected_events = [
        _advance_processing_time_event(5),
        _advance_watermark_event(0, tag=LETTERS_TAG),
        _add_elements_event(coder, LETTERS_TAG, ('a', 0), ('b', 0), ('c', 0)),
        _advance_processing_time_event(1),
        _advance_watermark_event(10, tag=NUMBERS_TAG),
        _advance_watermark_event(0, tag=LETTERS_TAG),
        _add_elements_event(
            coder, NUMBERS_TAG, ('1', 15), ('2', 15), ('3', 15)),
    ]

    self.assertListEqual(actual_events, expected_events)

  def test_always_default_coder_for_test_stream_records(self):
    """Checks that the pcoder loaded for written records is the default one."""
    CACHED_NUMBERS = repr(CacheKey('numbers', '', '', ''))
    numbers = (FileRecordsBuilder(CACHED_NUMBERS)
               .advance_processing_time(2)
               .add_element(element=1, event_time_secs=0)
               .advance_processing_time(1)
               .add_element(element=2, event_time_secs=0)
               .advance_processing_time(1)
               .add_element(element=2, event_time_secs=0)
               .build()) # yapf: disable
    cache = StreamingCache(cache_dir=None)
    cache.write(numbers, CACHED_NUMBERS)
    self.assertIs(
        type(cache.load_pcoder(CACHED_NUMBERS)), type(cache._default_pcoder))

  def test_streaming_cache_does_not_write_non_record_or_header_types(self):
    """Checks that write() rejects values that are not file records/headers."""
    cache = StreamingCache(cache_dir=None)
    self.assertRaises(TypeError, cache.write, 'some value', 'a key')

  def test_streaming_cache_uses_gcs_ib_cache_root(self):
    """
    Checks that StreamingCache._cache_dir is set to the
    cache_root set under Interactive Beam for a GCS directory.
    """
    # Set Interactive Beam specified cache dir to cloud storage; the previous
    # value is restored on teardown even if an assertion below fails.
    self._set_ib_cache_root('gs://')
    cache_manager_with_ib_option = StreamingCache(
        cache_dir=ib.options.cache_root)

    self.assertEqual(
        ib.options.cache_root, cache_manager_with_ib_option._cache_dir)

  def test_streaming_cache_uses_local_ib_cache_root(self):
    """
    Checks that StreamingCache._cache_dir is set to the
    cache_root set under Interactive Beam for a local directory
    and that the cached values are the same as the values of a
    cache using default settings.
    """
    CACHED_PCOLLECTION_KEY = repr(CacheKey('arbitrary_key', '', '', ''))
    values = (FileRecordsBuilder(CACHED_PCOLLECTION_KEY)
              .advance_processing_time(1)
              .advance_watermark(watermark_secs=0)
              .add_element(element=1, event_time_secs=0)
              .build()) # yapf: disable

    local_cache = StreamingCache(cache_dir=None)
    local_cache.write(values, CACHED_PCOLLECTION_KEY)
    reader_one, _ = local_cache.read(CACHED_PCOLLECTION_KEY)
    pcoll_list_one = list(reader_one)

    # Set Interactive Beam specified cache dir to a local directory. Use a
    # fresh per-run directory (removed on teardown) instead of a fixed path
    # such as /tmp/it-test/, so files left behind by an earlier run cannot
    # affect this one and nothing is leaked on disk.
    cache_root = tempfile.mkdtemp(prefix='ib-streaming-cache-test-')
    self.addCleanup(shutil.rmtree, cache_root, ignore_errors=True)
    self._set_ib_cache_root(cache_root)
    cache_manager_with_ib_option = StreamingCache(
        cache_dir=ib.options.cache_root)

    self.assertEqual(
        ib.options.cache_root, cache_manager_with_ib_option._cache_dir)

    cache_manager_with_ib_option.write(values, CACHED_PCOLLECTION_KEY)
    reader_two, _ = cache_manager_with_ib_option.read(CACHED_PCOLLECTION_KEY)
    pcoll_list_two = list(reader_two)

    self.assertEqual(pcoll_list_one, pcoll_list_two)


if __name__ == '__main__':
  unittest.main()