github.com/apache/beam/sdks/v2@v2.48.2/python/apache_beam/transforms/userstate_test.py (about) 1 # 2 # Licensed to the Apache Software Foundation (ASF) under one or more 3 # contributor license agreements. See the NOTICE file distributed with 4 # this work for additional information regarding copyright ownership. 5 # The ASF licenses this file to You under the Apache License, Version 2.0 6 # (the "License"); you may not use this file except in compliance with 7 # the License. You may obtain a copy of the License at 8 # 9 # http://www.apache.org/licenses/LICENSE-2.0 10 # 11 # Unless required by applicable law or agreed to in writing, software 12 # distributed under the License is distributed on an "AS IS" BASIS, 13 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 # See the License for the specific language governing permissions and 15 # limitations under the License. 16 # 17 18 """Unit tests for the Beam State and Timer API interfaces.""" 19 # pytype: skip-file 20 21 import unittest 22 from typing import Any 23 from typing import List 24 25 import mock 26 import pytest 27 28 import apache_beam as beam 29 from apache_beam.coders import BytesCoder 30 from apache_beam.coders import ListCoder 31 from apache_beam.coders import StrUtf8Coder 32 from apache_beam.coders import VarIntCoder 33 from apache_beam.options.pipeline_options import PipelineOptions 34 from apache_beam.portability import common_urns 35 from apache_beam.portability.api import beam_runner_api_pb2 36 from apache_beam.runners import pipeline_context 37 from apache_beam.runners.common import DoFnSignature 38 from apache_beam.testing.test_pipeline import TestPipeline 39 from apache_beam.testing.test_stream import TestStream 40 from apache_beam.testing.util import assert_that 41 from apache_beam.testing.util import equal_to 42 from apache_beam.transforms import trigger 43 from apache_beam.transforms import userstate 44 from apache_beam.transforms import window 45 from apache_beam.transforms.combiners import ToListCombineFn 46 from apache_beam.transforms.combiners import TopCombineFn 47 from apache_beam.transforms.core import DoFn 48 from apache_beam.transforms.timeutil import TimeDomain 49 from apache_beam.transforms.userstate import BagStateSpec 50 from apache_beam.transforms.userstate import CombiningValueStateSpec 51 from apache_beam.transforms.userstate import ReadModifyWriteStateSpec 52 from apache_beam.transforms.userstate import SetStateSpec 53 from apache_beam.transforms.userstate import TimerSpec 54 from apache_beam.transforms.userstate import get_dofn_specs 55 from apache_beam.transforms.userstate import is_stateful_dofn 56 from apache_beam.transforms.userstate import on_timer 57 from apache_beam.transforms.userstate import validate_stateful_dofn 58 59 60 class TestStatefulDoFn(DoFn): 61 """An example stateful DoFn with state and timers.""" 62 63 BUFFER_STATE_1 = BagStateSpec('buffer', BytesCoder()) 64 BUFFER_STATE_2 = BagStateSpec('buffer2', VarIntCoder()) 65 EXPIRY_TIMER_1 = TimerSpec('expiry1', TimeDomain.WATERMARK) 66 EXPIRY_TIMER_2 = TimerSpec('expiry2', TimeDomain.WATERMARK) 67 EXPIRY_TIMER_3 = TimerSpec('expiry3', TimeDomain.WATERMARK) 68 EXPIRY_TIMER_FAMILY = TimerSpec('expiry_family', TimeDomain.WATERMARK) 69 70 def process( 71 self, 72 element, 73 t=DoFn.TimestampParam, 74 buffer_1=DoFn.StateParam(BUFFER_STATE_1), 75 buffer_2=DoFn.StateParam(BUFFER_STATE_2), 76 timer_1=DoFn.TimerParam(EXPIRY_TIMER_1), 77 timer_2=DoFn.TimerParam(EXPIRY_TIMER_2), 78 dynamic_timer=DoFn.TimerParam(EXPIRY_TIMER_FAMILY)): 79 yield element 80 81 @on_timer(EXPIRY_TIMER_1) 82 def on_expiry_1( 83 self, 84 window=DoFn.WindowParam, 85 timestamp=DoFn.TimestampParam, 86 key=DoFn.KeyParam, 87 buffer=DoFn.StateParam(BUFFER_STATE_1), 88 timer_1=DoFn.TimerParam(EXPIRY_TIMER_1), 89 timer_2=DoFn.TimerParam(EXPIRY_TIMER_2), 90 timer_3=DoFn.TimerParam(EXPIRY_TIMER_3)): 91 yield 'expired1' 92 93 @on_timer(EXPIRY_TIMER_2) 94 def on_expiry_2( 95 self, 96 buffer=DoFn.StateParam(BUFFER_STATE_2), 97 timer_2=DoFn.TimerParam(EXPIRY_TIMER_2), 98 timer_3=DoFn.TimerParam(EXPIRY_TIMER_3)): 99 yield 'expired2' 100 101 @on_timer(EXPIRY_TIMER_3) 102 def on_expiry_3( 103 self, 104 buffer_1=DoFn.StateParam(BUFFER_STATE_1), 105 buffer_2=DoFn.StateParam(BUFFER_STATE_2), 106 timer_3=DoFn.TimerParam(EXPIRY_TIMER_3)): 107 yield 'expired3' 108 109 @on_timer(EXPIRY_TIMER_FAMILY) 110 def on_expiry_family( 111 self, 112 dynamic_timer=DoFn.TimerParam(EXPIRY_TIMER_FAMILY), 113 dynamic_timer_tag=DoFn.DynamicTimerTagParam): 114 yield (dynamic_timer_tag, 'expired_dynamic_timer') 115 116 117 class InterfaceTest(unittest.TestCase): 118 def _validate_dofn(self, dofn): 119 # Construction of DoFnSignature performs validation of the given DoFn. 120 # In particular, it ends up calling userstate._validate_stateful_dofn. 121 # That behavior is explicitly tested below in test_validate_dofn() 122 return DoFnSignature(dofn) 123 124 @mock.patch('apache_beam.transforms.userstate.validate_stateful_dofn') 125 def test_validate_dofn(self, unused_mock): 126 dofn = TestStatefulDoFn() 127 self._validate_dofn(dofn) 128 userstate.validate_stateful_dofn.assert_called_with(dofn) 129 130 def test_spec_construction(self): 131 BagStateSpec('statename', VarIntCoder()) 132 with self.assertRaises(TypeError): 133 BagStateSpec(123, VarIntCoder()) 134 135 CombiningValueStateSpec('statename', VarIntCoder(), TopCombineFn(10)) 136 with self.assertRaises(TypeError): 137 CombiningValueStateSpec(123, VarIntCoder(), TopCombineFn(10)) 138 with self.assertRaises(TypeError): 139 CombiningValueStateSpec('statename', VarIntCoder(), object()) 140 141 SetStateSpec('setstatename', VarIntCoder()) 142 with self.assertRaises(TypeError): 143 SetStateSpec(123, VarIntCoder()) 144 with self.assertRaises(TypeError): 145 SetStateSpec('setstatename', object()) 146 147 ReadModifyWriteStateSpec('valuestatename', VarIntCoder()) 148 with self.assertRaises(TypeError): 149 ReadModifyWriteStateSpec(123, VarIntCoder()) 150 with self.assertRaises(TypeError): 151 ReadModifyWriteStateSpec('valuestatename', object()) 152 153 # TODO: add more spec tests 154 with self.assertRaises(ValueError): 155 DoFn.TimerParam(BagStateSpec('elements', BytesCoder())) 156 157 TimerSpec('timer', TimeDomain.WATERMARK) 158 TimerSpec('timer', TimeDomain.REAL_TIME) 159 with self.assertRaises(ValueError): 160 TimerSpec('timer', 'bogus_time_domain') 161 with self.assertRaises(ValueError): 162 DoFn.StateParam(TimerSpec('timer', TimeDomain.WATERMARK)) 163 164 def test_state_spec_proto_conversion(self): 165 context = pipeline_context.PipelineContext() 166 state = BagStateSpec('statename', VarIntCoder()) 167 state_proto = state.to_runner_api(context) 168 self.assertEqual( 169 beam_runner_api_pb2.FunctionSpec(urn=common_urns.user_state.BAG.urn), 170 state_proto.protocol) 171 172 context = pipeline_context.PipelineContext() 173 state = CombiningValueStateSpec( 174 'statename', VarIntCoder(), TopCombineFn(10)) 175 state_proto = state.to_runner_api(context) 176 self.assertEqual( 177 beam_runner_api_pb2.FunctionSpec(urn=common_urns.user_state.BAG.urn), 178 state_proto.protocol) 179 180 context = pipeline_context.PipelineContext() 181 state = SetStateSpec('setstatename', VarIntCoder()) 182 state_proto = state.to_runner_api(context) 183 self.assertEqual( 184 beam_runner_api_pb2.FunctionSpec(urn=common_urns.user_state.BAG.urn), 185 state_proto.protocol) 186 187 context = pipeline_context.PipelineContext() 188 state = ReadModifyWriteStateSpec('valuestatename', VarIntCoder()) 189 state_proto = state.to_runner_api(context) 190 self.assertEqual( 191 beam_runner_api_pb2.FunctionSpec(urn=common_urns.user_state.BAG.urn), 192 state_proto.protocol) 193 194 def test_param_construction(self): 195 with self.assertRaises(ValueError): 196 DoFn.StateParam(TimerSpec('timer', TimeDomain.WATERMARK)) 197 with self.assertRaises(ValueError): 198 DoFn.TimerParam(BagStateSpec('elements', BytesCoder())) 199 200 def test_stateful_dofn_detection(self): 201 self.assertFalse(is_stateful_dofn(DoFn())) 202 self.assertTrue(is_stateful_dofn(TestStatefulDoFn())) 203 204 def test_good_signatures(self): 205 class BasicStatefulDoFn(DoFn): 206 BUFFER_STATE = BagStateSpec('buffer', BytesCoder()) 207 EXPIRY_TIMER = TimerSpec('expiry1', TimeDomain.WATERMARK) 208 EXPIRY_TIMER_FAMILY = TimerSpec('expiry_family_1', TimeDomain.WATERMARK) 209 210 def process( 211 self, 212 element, 213 buffer=DoFn.StateParam(BUFFER_STATE), 214 timer1=DoFn.TimerParam(EXPIRY_TIMER), 215 dynamic_timer=DoFn.TimerParam(EXPIRY_TIMER_FAMILY)): 216 yield element 217 218 @on_timer(EXPIRY_TIMER) 219 def expiry_callback(self, element, timer=DoFn.TimerParam(EXPIRY_TIMER)): 220 yield element 221 222 @on_timer(EXPIRY_TIMER_FAMILY) 223 def expiry_family_callback( 224 self, element, dynamic_timer=DoFn.TimerParam(EXPIRY_TIMER_FAMILY)): 225 yield element 226 227 # Validate get_dofn_specs() and timer callbacks in 228 # DoFnSignature. 229 stateful_dofn = BasicStatefulDoFn() 230 signature = self._validate_dofn(stateful_dofn) 231 expected_specs = ( 232 set([BasicStatefulDoFn.BUFFER_STATE]), 233 set([ 234 BasicStatefulDoFn.EXPIRY_TIMER, 235 BasicStatefulDoFn.EXPIRY_TIMER_FAMILY 236 ]), 237 ) 238 self.assertEqual(expected_specs, get_dofn_specs(stateful_dofn)) 239 self.assertEqual( 240 stateful_dofn.expiry_callback, 241 signature.timer_methods[BasicStatefulDoFn.EXPIRY_TIMER].method_value) 242 self.assertEqual( 243 stateful_dofn.expiry_family_callback, 244 signature.timer_methods[ 245 BasicStatefulDoFn.EXPIRY_TIMER_FAMILY].method_value) 246 247 stateful_dofn = TestStatefulDoFn() 248 signature = self._validate_dofn(stateful_dofn) 249 expected_specs = ( 250 set([TestStatefulDoFn.BUFFER_STATE_1, TestStatefulDoFn.BUFFER_STATE_2]), 251 set([ 252 TestStatefulDoFn.EXPIRY_TIMER_1, 253 TestStatefulDoFn.EXPIRY_TIMER_2, 254 TestStatefulDoFn.EXPIRY_TIMER_3, 255 TestStatefulDoFn.EXPIRY_TIMER_FAMILY 256 ])) 257 self.assertEqual(expected_specs, get_dofn_specs(stateful_dofn)) 258 self.assertEqual( 259 stateful_dofn.on_expiry_1, 260 signature.timer_methods[TestStatefulDoFn.EXPIRY_TIMER_1].method_value) 261 self.assertEqual( 262 stateful_dofn.on_expiry_2, 263 signature.timer_methods[TestStatefulDoFn.EXPIRY_TIMER_2].method_value) 264 self.assertEqual( 265 stateful_dofn.on_expiry_3, 266 signature.timer_methods[TestStatefulDoFn.EXPIRY_TIMER_3].method_value) 267 self.assertEqual( 268 stateful_dofn.on_expiry_family, 269 signature.timer_methods[ 270 TestStatefulDoFn.EXPIRY_TIMER_FAMILY].method_value) 271 272 def test_bad_signatures(self): 273 # (1) The same state parameter is duplicated on the process method. 274 class BadStatefulDoFn1(DoFn): 275 BUFFER_STATE = BagStateSpec('buffer', BytesCoder()) 276 277 def process( 278 self, 279 element, 280 b1=DoFn.StateParam(BUFFER_STATE), 281 b2=DoFn.StateParam(BUFFER_STATE)): 282 yield element 283 284 with self.assertRaises(ValueError): 285 self._validate_dofn(BadStatefulDoFn1()) 286 287 # (2) The same timer parameter is duplicated on the process method. 288 class BadStatefulDoFn2(DoFn): 289 TIMER = TimerSpec('timer', TimeDomain.WATERMARK) 290 291 def process( 292 self, element, t1=DoFn.TimerParam(TIMER), t2=DoFn.TimerParam(TIMER)): 293 yield element 294 295 with self.assertRaises(ValueError): 296 self._validate_dofn(BadStatefulDoFn2()) 297 298 # (3) The same state parameter is duplicated on the on_timer method. 299 class BadStatefulDoFn3(DoFn): 300 BUFFER_STATE = BagStateSpec('buffer', BytesCoder()) 301 EXPIRY_TIMER_1 = TimerSpec('expiry1', TimeDomain.WATERMARK) 302 EXPIRY_TIMER_2 = TimerSpec('expiry2', TimeDomain.WATERMARK) 303 304 @on_timer(EXPIRY_TIMER_1) 305 def expiry_callback( 306 self, 307 element, 308 b1=DoFn.StateParam(BUFFER_STATE), 309 b2=DoFn.StateParam(BUFFER_STATE)): 310 yield element 311 312 with self.assertRaises(ValueError): 313 self._validate_dofn(BadStatefulDoFn3()) 314 315 # (4) The same timer parameter is duplicated on the on_timer method. 316 class BadStatefulDoFn4(DoFn): 317 BUFFER_STATE = BagStateSpec('buffer', BytesCoder()) 318 EXPIRY_TIMER_1 = TimerSpec('expiry1', TimeDomain.WATERMARK) 319 EXPIRY_TIMER_2 = TimerSpec('expiry2', TimeDomain.WATERMARK) 320 321 @on_timer(EXPIRY_TIMER_1) 322 def expiry_callback( 323 self, 324 element, 325 t1=DoFn.TimerParam(EXPIRY_TIMER_2), 326 t2=DoFn.TimerParam(EXPIRY_TIMER_2)): 327 yield element 328 329 with self.assertRaises(ValueError): 330 self._validate_dofn(BadStatefulDoFn4()) 331 332 # (5) The same timer family parameter is duplicated on the process method. 333 class BadStatefulDoFn5(DoFn): 334 EXPIRY_TIMER_FAMILY = TimerSpec('dynamic_timer', TimeDomain.WATERMARK) 335 336 def process( 337 self, 338 element, 339 dynamic_timer_1=DoFn.TimerParam(EXPIRY_TIMER_FAMILY), 340 dynamic_timer_2=DoFn.TimerParam(EXPIRY_TIMER_FAMILY)): 341 yield element 342 343 with self.assertRaises(ValueError): 344 self._validate_dofn(BadStatefulDoFn5()) 345 346 def test_validation_typos(self): 347 # (1) Here, the user mistakenly used the same timer spec twice for two 348 # different timer callbacks. 349 with self.assertRaisesRegex( 350 ValueError, 351 r'Multiple on_timer callbacks registered for TimerSpec\(.*expiry1\).'): 352 353 class StatefulDoFnWithTimerWithTypo1(DoFn): # pylint: disable=unused-variable 354 BUFFER_STATE = BagStateSpec('buffer', BytesCoder()) 355 EXPIRY_TIMER_1 = TimerSpec('expiry1', TimeDomain.WATERMARK) 356 EXPIRY_TIMER_2 = TimerSpec('expiry2', TimeDomain.WATERMARK) 357 358 def process(self, element): 359 pass 360 361 @on_timer(EXPIRY_TIMER_1) 362 def on_expiry_1(self, buffer_state=DoFn.StateParam(BUFFER_STATE)): 363 yield 'expired1' 364 365 # Note that we mistakenly associate this with the first timer. 366 @on_timer(EXPIRY_TIMER_1) 367 def on_expiry_2(self, buffer_state=DoFn.StateParam(BUFFER_STATE)): 368 yield 'expired2' 369 370 # (2) Here, the user mistakenly used the same callback name and overwrote 371 # the first on_expiry_1 callback. 372 class StatefulDoFnWithTimerWithTypo2(DoFn): 373 BUFFER_STATE = BagStateSpec('buffer', BytesCoder()) 374 EXPIRY_TIMER_1 = TimerSpec('expiry1', TimeDomain.WATERMARK) 375 EXPIRY_TIMER_2 = TimerSpec('expiry2', TimeDomain.WATERMARK) 376 377 def process( 378 self, 379 element, 380 timer1=DoFn.TimerParam(EXPIRY_TIMER_1), 381 timer2=DoFn.TimerParam(EXPIRY_TIMER_2)): 382 pass 383 384 @on_timer(EXPIRY_TIMER_1) 385 def on_expiry_1(self, buffer_state=DoFn.StateParam(BUFFER_STATE)): 386 yield 'expired1' 387 388 # Note that we mistakenly reuse the "on_expiry_1" name; this is valid 389 # syntactically in Python. 390 @on_timer(EXPIRY_TIMER_2) 391 def on_expiry_1(self, buffer_state=DoFn.StateParam(BUFFER_STATE)): # pylint: disable=function-redefined 392 yield 'expired2' 393 394 # Use a stable string value for matching. 395 def __repr__(self): 396 return 'StatefulDoFnWithTimerWithTypo2' 397 398 dofn = StatefulDoFnWithTimerWithTypo2() 399 with self.assertRaisesRegex( 400 ValueError, 401 (r'The on_timer callback for TimerSpec\(.*expiry1\) is not the ' 402 r'specified .on_expiry_1 method for DoFn ' 403 r'StatefulDoFnWithTimerWithTypo2 \(perhaps it was overwritten\?\).')): 404 validate_stateful_dofn(dofn) 405 406 # (2) Here, the user forgot to add an on_timer decorator for 'expiry2' 407 class StatefulDoFnWithTimerWithTypo3(DoFn): 408 BUFFER_STATE = BagStateSpec('buffer', BytesCoder()) 409 EXPIRY_TIMER_1 = TimerSpec('expiry1', TimeDomain.WATERMARK) 410 EXPIRY_TIMER_2 = TimerSpec('expiry2', TimeDomain.WATERMARK) 411 412 def process( 413 self, 414 element, 415 timer1=DoFn.TimerParam(EXPIRY_TIMER_1), 416 timer2=DoFn.TimerParam(EXPIRY_TIMER_2)): 417 pass 418 419 @on_timer(EXPIRY_TIMER_1) 420 def on_expiry_1(self, buffer_state=DoFn.StateParam(BUFFER_STATE)): 421 yield 'expired1' 422 423 def on_expiry_2(self, buffer_state=DoFn.StateParam(BUFFER_STATE)): 424 yield 'expired2' 425 426 # Use a stable string value for matching. 427 def __repr__(self): 428 return 'StatefulDoFnWithTimerWithTypo3' 429 430 dofn = StatefulDoFnWithTimerWithTypo3() 431 with self.assertRaisesRegex( 432 ValueError, 433 (r'DoFn StatefulDoFnWithTimerWithTypo3 has a TimerSpec without an ' 434 r'associated on_timer callback: TimerSpec\(.*expiry2\).')): 435 validate_stateful_dofn(dofn) 436 437 438 class StatefulDoFnOnDirectRunnerTest(unittest.TestCase): 439 # pylint: disable=expression-not-assigned 440 all_records = None # type: List[Any] 441 442 def setUp(self): 443 # Use state on the TestCase class, since other references would be pickled 444 # into a closure and not have the desired side effects. 445 # 446 # TODO(https://github.com/apache/beam/issues/18987): Use assert_that after 447 # it works for the cases here in streaming mode. 448 StatefulDoFnOnDirectRunnerTest.all_records = [] 449 450 def record_dofn(self): 451 class RecordDoFn(DoFn): 452 def process(self, element): 453 StatefulDoFnOnDirectRunnerTest.all_records.append(element) 454 455 return RecordDoFn() 456 457 def test_simple_stateful_dofn(self): 458 class SimpleTestStatefulDoFn(DoFn): 459 BUFFER_STATE = BagStateSpec('buffer', BytesCoder()) 460 EXPIRY_TIMER = TimerSpec('expiry', TimeDomain.WATERMARK) 461 462 def process( 463 self, 464 element, 465 buffer=DoFn.StateParam(BUFFER_STATE), 466 timer1=DoFn.TimerParam(EXPIRY_TIMER)): 467 unused_key, value = element 468 buffer.add(b'A' + str(value).encode('latin1')) 469 timer1.set(20) 470 471 @on_timer(EXPIRY_TIMER) 472 def expiry_callback( 473 self, 474 buffer=DoFn.StateParam(BUFFER_STATE), 475 timer=DoFn.TimerParam(EXPIRY_TIMER)): 476 yield b''.join(sorted(buffer.read())) 477 478 with TestPipeline() as p: 479 test_stream = ( 480 TestStream().advance_watermark_to(10).add_elements( 481 [1, 482 2]).add_elements([3]).advance_watermark_to(25).add_elements([4])) 483 ( 484 p 485 | test_stream 486 | beam.Map(lambda x: ('mykey', x)) 487 | beam.ParDo(SimpleTestStatefulDoFn()) 488 | beam.ParDo(self.record_dofn())) 489 490 # Two firings should occur: once after element 3 since the timer should 491 # fire after the watermark passes time 20, and another time after element 492 # 4, since the timer issued at that point should fire immediately. 493 self.assertEqual([b'A1A2A3', b'A1A2A3A4'], 494 StatefulDoFnOnDirectRunnerTest.all_records) 495 496 def test_clearing_bag_state(self): 497 class BagStateClearingStatefulDoFn(beam.DoFn): 498 499 BAG_STATE = BagStateSpec('bag_state', StrUtf8Coder()) 500 EMIT_TIMER = TimerSpec('emit_timer', TimeDomain.WATERMARK) 501 CLEAR_TIMER = TimerSpec('clear_timer', TimeDomain.WATERMARK) 502 503 def process( 504 self, 505 element, 506 bag_state=beam.DoFn.StateParam(BAG_STATE), 507 emit_timer=beam.DoFn.TimerParam(EMIT_TIMER), 508 clear_timer=beam.DoFn.TimerParam(CLEAR_TIMER)): 509 value = element[1] 510 bag_state.add(value) 511 clear_timer.set(100) 512 emit_timer.set(1000) 513 514 @on_timer(EMIT_TIMER) 515 def emit_values(self, bag_state=beam.DoFn.StateParam(BAG_STATE)): 516 for value in bag_state.read(): 517 yield value 518 yield 'extra' 519 520 @on_timer(CLEAR_TIMER) 521 def clear_values(self, bag_state=beam.DoFn.StateParam(BAG_STATE)): 522 bag_state.clear() 523 524 with TestPipeline() as p: 525 test_stream = ( 526 TestStream().advance_watermark_to(0).add_elements([ 527 ('key', 'value') 528 ]).advance_watermark_to(100)) 529 530 _ = ( 531 p 532 | test_stream 533 | beam.ParDo(BagStateClearingStatefulDoFn()) 534 | beam.ParDo(self.record_dofn())) 535 536 self.assertEqual(['extra'], StatefulDoFnOnDirectRunnerTest.all_records) 537 538 def test_two_timers_one_function(self): 539 class BagStateClearingStatefulDoFn(beam.DoFn): 540 541 BAG_STATE = BagStateSpec('bag_state', StrUtf8Coder()) 542 EMIT_TIMER = TimerSpec('emit_timer', TimeDomain.WATERMARK) 543 EMIT_TWICE_TIMER = TimerSpec('clear_timer', TimeDomain.WATERMARK) 544 545 def process( 546 self, 547 element, 548 bag_state=beam.DoFn.StateParam(BAG_STATE), 549 emit_timer=beam.DoFn.TimerParam(EMIT_TIMER), 550 emit_twice_timer=beam.DoFn.TimerParam(EMIT_TWICE_TIMER)): 551 value = element[1] 552 bag_state.add(value) 553 emit_twice_timer.set(100) 554 emit_timer.set(1000) 555 556 @on_timer(EMIT_TWICE_TIMER) 557 @on_timer(EMIT_TIMER) 558 def emit_values(self, bag_state=beam.DoFn.StateParam(BAG_STATE)): 559 for value in bag_state.read(): 560 yield value 561 562 with TestPipeline() as p: 563 test_stream = ( 564 TestStream().advance_watermark_to(0).add_elements([ 565 ('key', 'value') 566 ]).advance_watermark_to(100)) 567 568 _ = ( 569 p 570 | test_stream 571 | beam.ParDo(BagStateClearingStatefulDoFn()) 572 | beam.ParDo(self.record_dofn())) 573 574 self.assertEqual(['value', 'value'], 575 StatefulDoFnOnDirectRunnerTest.all_records) 576 577 def test_simple_read_modify_write_stateful_dofn(self): 578 class SimpleTestReadModifyWriteStatefulDoFn(DoFn): 579 VALUE_STATE = ReadModifyWriteStateSpec('value', StrUtf8Coder()) 580 581 def process(self, element, last_element=DoFn.StateParam(VALUE_STATE)): 582 last_element.write('%s:%s' % element) 583 yield last_element.read() 584 585 with TestPipeline() as p: 586 test_stream = ( 587 TestStream().advance_watermark_to(0).add_elements([ 588 ('a', 1) 589 ]).advance_watermark_to(10).add_elements([ 590 ('a', 3) 591 ]).advance_watermark_to(20).add_elements([('a', 5)])) 592 ( 593 p | test_stream 594 | beam.ParDo(SimpleTestReadModifyWriteStatefulDoFn()) 595 | beam.ParDo(self.record_dofn())) 596 self.assertEqual(['a:1', 'a:3', 'a:5'], 597 StatefulDoFnOnDirectRunnerTest.all_records) 598 599 def test_clearing_read_modify_write_state(self): 600 class SimpleClearingReadModifyWriteStatefulDoFn(DoFn): 601 VALUE_STATE = ReadModifyWriteStateSpec('value', StrUtf8Coder()) 602 603 def process(self, element, last_element=DoFn.StateParam(VALUE_STATE)): 604 value = last_element.read() 605 if value is not None: 606 yield value 607 last_element.clear() 608 last_element.write("%s:%s" % (last_element.read(), element[1])) 609 if element[1] == 5: 610 yield last_element.read() 611 612 with TestPipeline() as p: 613 test_stream = ( 614 TestStream().advance_watermark_to(0).add_elements([ 615 ('a', 1) 616 ]).advance_watermark_to(10).add_elements([ 617 ('a', 3) 618 ]).advance_watermark_to(20).add_elements([('a', 5)])) 619 ( 620 p | test_stream 621 | beam.ParDo(SimpleClearingReadModifyWriteStatefulDoFn()) 622 | beam.ParDo(self.record_dofn())) 623 self.assertEqual(['None:1', 'None:3', 'None:5'], 624 StatefulDoFnOnDirectRunnerTest.all_records) 625 626 def test_simple_set_stateful_dofn(self): 627 class SimpleTestSetStatefulDoFn(DoFn): 628 BUFFER_STATE = SetStateSpec('buffer', VarIntCoder()) 629 EXPIRY_TIMER = TimerSpec('expiry', TimeDomain.WATERMARK) 630 631 def process( 632 self, 633 element, 634 buffer=DoFn.StateParam(BUFFER_STATE), 635 timer1=DoFn.TimerParam(EXPIRY_TIMER)): 636 unused_key, value = element 637 buffer.add(value) 638 timer1.set(20) 639 640 @on_timer(EXPIRY_TIMER) 641 def expiry_callback(self, buffer=DoFn.StateParam(BUFFER_STATE)): 642 yield sorted(buffer.read()) 643 644 with TestPipeline() as p: 645 test_stream = ( 646 TestStream().advance_watermark_to(10).add_elements( 647 [1, 2, 3]).add_elements([2]).advance_watermark_to(24)) 648 ( 649 p 650 | test_stream 651 | beam.Map(lambda x: ('mykey', x)) 652 | beam.ParDo(SimpleTestSetStatefulDoFn()) 653 | beam.ParDo(self.record_dofn())) 654 655 # Two firings should occur: once after element 3 since the timer should 656 # fire after the watermark passes time 20, and another time after element 657 # 4, since the timer issued at that point should fire immediately. 658 self.assertEqual([[1, 2, 3]], StatefulDoFnOnDirectRunnerTest.all_records) 659 660 def test_clearing_set_state(self): 661 class SetStateClearingStatefulDoFn(beam.DoFn): 662 663 SET_STATE = SetStateSpec('buffer', StrUtf8Coder()) 664 EMIT_TIMER = TimerSpec('emit_timer', TimeDomain.WATERMARK) 665 CLEAR_TIMER = TimerSpec('clear_timer', TimeDomain.WATERMARK) 666 667 def process( 668 self, 669 element, 670 set_state=beam.DoFn.StateParam(SET_STATE), 671 emit_timer=beam.DoFn.TimerParam(EMIT_TIMER), 672 clear_timer=beam.DoFn.TimerParam(CLEAR_TIMER)): 673 value = element[1] 674 set_state.add(value) 675 clear_timer.set(100) 676 emit_timer.set(1000) 677 678 @on_timer(EMIT_TIMER) 679 def emit_values(self, set_state=beam.DoFn.StateParam(SET_STATE)): 680 for value in set_state.read(): 681 yield value 682 683 @on_timer(CLEAR_TIMER) 684 def clear_values(self, set_state=beam.DoFn.StateParam(SET_STATE)): 685 set_state.clear() 686 set_state.add('different-value') 687 688 with TestPipeline() as p: 689 test_stream = ( 690 TestStream().advance_watermark_to(0).add_elements([ 691 ('key1', 'value1') 692 ]).advance_watermark_to(100)) 693 694 _ = ( 695 p 696 | test_stream 697 | beam.ParDo(SetStateClearingStatefulDoFn()) 698 | beam.ParDo(self.record_dofn())) 699 700 self.assertEqual(['different-value'], 701 StatefulDoFnOnDirectRunnerTest.all_records) 702 703 def test_stateful_set_state_portably(self): 704 class SetStatefulDoFn(beam.DoFn): 705 706 SET_STATE = SetStateSpec('buffer', VarIntCoder()) 707 708 def process(self, element, set_state=beam.DoFn.StateParam(SET_STATE)): 709 _, value = element 710 aggregated_value = 0 711 set_state.add(value) 712 for saved_value in set_state.read(): 713 aggregated_value += saved_value 714 yield aggregated_value 715 716 with TestPipeline() as p: 717 values = p | beam.Create([('key', 1), ('key', 2), ('key', 3), ('key', 4), 718 ('key', 3)], 719 reshuffle=False) 720 actual_values = (values | beam.ParDo(SetStatefulDoFn())) 721 assert_that(actual_values, equal_to([1, 3, 6, 10, 10])) 722 723 def test_stateful_set_state_clean_portably(self): 724 class SetStateClearingStatefulDoFn(beam.DoFn): 725 726 SET_STATE = SetStateSpec('buffer', VarIntCoder()) 727 EMIT_TIMER = TimerSpec('emit_timer', TimeDomain.WATERMARK) 728 729 def process( 730 self, 731 element, 732 set_state=beam.DoFn.StateParam(SET_STATE), 733 emit_timer=beam.DoFn.TimerParam(EMIT_TIMER)): 734 _, value = element 735 set_state.add(value) 736 737 all_elements = [element for element in set_state.read()] 738 739 if len(all_elements) == 5: 740 set_state.clear() 741 set_state.add(100) 742 emit_timer.set(1) 743 744 @on_timer(EMIT_TIMER) 745 def emit_values(self, set_state=beam.DoFn.StateParam(SET_STATE)): 746 yield sorted(set_state.read()) 747 748 with TestPipeline() as p: 749 values = p | beam.Create([('key', 1), ('key', 2), ('key', 3), ('key', 4), 750 ('key', 5)]) 751 actual_values = ( 752 values 753 | beam.Map(lambda t: window.TimestampedValue(t, 1)) 754 | beam.WindowInto(window.FixedWindows(1)) 755 | beam.ParDo(SetStateClearingStatefulDoFn())) 756 757 assert_that(actual_values, equal_to([[100]])) 758 759 def test_stateful_dofn_nonkeyed_input(self): 760 p = TestPipeline() 761 values = p | beam.Create([1, 2, 3]) 762 with self.assertRaisesRegex( 763 ValueError, 764 ('Input elements to the transform .* with stateful DoFn must be ' 765 'key-value pairs.')): 766 values | beam.ParDo(TestStatefulDoFn()) 767 768 def test_generate_sequence_with_realtime_timer(self): 769 from apache_beam.transforms.combiners import CountCombineFn 770 771 class GenerateRecords(beam.DoFn): 772 773 EMIT_TIMER = TimerSpec('emit_timer', TimeDomain.REAL_TIME) 774 COUNT_STATE = CombiningValueStateSpec( 775 'count_state', VarIntCoder(), CountCombineFn()) 776 777 def __init__(self, frequency, total_records): 778 self.total_records = total_records 779 self.frequency = frequency 780 781 def process(self, element, emit_timer=beam.DoFn.TimerParam(EMIT_TIMER)): 782 # Processing time timers should be set on ABSOLUTE TIME. 783 emit_timer.set(self.frequency) 784 yield element[1] 785 786 @on_timer(EMIT_TIMER) 787 def emit_values( 788 self, 789 emit_timer=beam.DoFn.TimerParam(EMIT_TIMER), 790 count_state=beam.DoFn.StateParam(COUNT_STATE)): 791 count = count_state.read() or 0 792 if self.total_records == count: 793 return 794 795 count_state.add(1) 796 # Processing time timers should be set on ABSOLUTE TIME. 797 emit_timer.set(count + 1 + self.frequency) 798 yield 'value' 799 800 TOTAL_RECORDS = 3 801 FREQUENCY = 1 802 803 test_stream = ( 804 TestStream().advance_watermark_to(0).add_elements([ 805 ('key', 0) 806 ]).advance_processing_time(1) # Timestamp: 1 807 .add_elements([('key', 1)]).advance_processing_time(1) # Timestamp: 2 808 .add_elements([('key', 2)]).advance_processing_time(1) # Timestamp: 3 809 .add_elements([('key', 3)])) 810 811 with beam.Pipeline(argv=['--streaming', '--runner=DirectRunner']) as p: 812 _ = ( 813 p 814 | test_stream 815 | beam.ParDo(GenerateRecords(FREQUENCY, TOTAL_RECORDS)) 816 | beam.ParDo(self.record_dofn())) 817 818 self.assertEqual( 819 # 4 RECORDS go through process 820 # 3 values are emitted from timer 821 # Timestamp moves gradually. 822 [0, 'value', 1, 'value', 2, 'value', 3], 823 StatefulDoFnOnDirectRunnerTest.all_records) 824 825 def test_simple_stateful_dofn_combining(self): 826 class SimpleTestStatefulDoFn(DoFn): 827 BUFFER_STATE = CombiningValueStateSpec( 828 'buffer', ListCoder(VarIntCoder()), ToListCombineFn()) 829 EXPIRY_TIMER = TimerSpec('expiry1', TimeDomain.WATERMARK) 830 831 def process( 832 self, 833 element, 834 buffer=DoFn.StateParam(BUFFER_STATE), 835 timer1=DoFn.TimerParam(EXPIRY_TIMER)): 836 unused_key, value = element 837 buffer.add(value) 838 timer1.set(20) 839 840 @on_timer(EXPIRY_TIMER) 841 def expiry_callback( 842 self, 843 buffer=DoFn.StateParam(BUFFER_STATE), 844 timer=DoFn.TimerParam(EXPIRY_TIMER)): 845 yield ''.join(str(x) for x in sorted(buffer.read())) 846 847 with TestPipeline() as p: 848 test_stream = ( 849 TestStream().advance_watermark_to(10).add_elements( 850 [1, 851 2]).add_elements([3]).advance_watermark_to(25).add_elements([4])) 852 ( 853 p 854 | test_stream 855 | beam.Map(lambda x: ('mykey', x)) 856 | beam.ParDo(SimpleTestStatefulDoFn()) 857 | beam.ParDo(self.record_dofn())) 858 859 self.assertEqual(['123', '1234'], 860 StatefulDoFnOnDirectRunnerTest.all_records) 861 862 def test_timer_output_timestamp(self): 863 class TimerEmittingStatefulDoFn(DoFn): 864 EMIT_TIMER_1 = TimerSpec('emit1', TimeDomain.WATERMARK) 865 EMIT_TIMER_2 = TimerSpec('emit2', TimeDomain.WATERMARK) 866 EMIT_TIMER_3 = TimerSpec('emit3', TimeDomain.WATERMARK) 867 868 def process( 869 self, 870 element, 871 timer1=DoFn.TimerParam(EMIT_TIMER_1), 872 timer2=DoFn.TimerParam(EMIT_TIMER_2), 873 timer3=DoFn.TimerParam(EMIT_TIMER_3)): 874 timer1.set(10) 875 timer2.set(20) 876 timer3.set(30) 877 878 @on_timer(EMIT_TIMER_1) 879 def emit_callback_1(self): 880 yield 'timer1' 881 882 @on_timer(EMIT_TIMER_2) 883 def emit_callback_2(self): 884 yield 'timer2' 885 886 @on_timer(EMIT_TIMER_3) 887 def emit_callback_3(self): 888 yield 'timer3' 889 890 class TimestampReifyingDoFn(DoFn): 891 def process(self, element, ts=DoFn.TimestampParam): 892 yield (element, int(ts)) 893 894 with TestPipeline() as p: 895 test_stream = (TestStream().advance_watermark_to(10).add_elements([1])) 896 ( 897 p 898 | test_stream 899 | beam.Map(lambda x: ('mykey', x)) 900 | beam.ParDo(TimerEmittingStatefulDoFn()) 901 | beam.ParDo(TimestampReifyingDoFn()) 902 | beam.ParDo(self.record_dofn())) 903 904 self.assertEqual([('timer1', 10), ('timer2', 20), ('timer3', 30)], 905 sorted(StatefulDoFnOnDirectRunnerTest.all_records)) 906 907 def test_timer_output_timestamp_and_window(self): 908 class TimerEmittingStatefulDoFn(DoFn): 909 EMIT_TIMER_1 = TimerSpec('emit1', TimeDomain.WATERMARK) 910 911 def process(self, element, timer1=DoFn.TimerParam(EMIT_TIMER_1)): 912 timer1.set(10) 913 914 @on_timer(EMIT_TIMER_1) 915 def emit_callback_1( 916 self, 917 window=DoFn.WindowParam, 918 ts=DoFn.TimestampParam, 919 key=DoFn.KeyParam): 920 yield ( 921 'timer1-{key}'.format(key=key), 922 int(ts), 923 int(window.start), 924 int(window.end)) 925 926 pipeline_options = PipelineOptions() 927 with TestPipeline(options=pipeline_options) as p: 928 test_stream = (TestStream().advance_watermark_to(10).add_elements([1])) 929 ( 930 p 931 | test_stream 932 | beam.Map(lambda x: ('mykey', x)) 933 | "window_into" >> beam.WindowInto( 934 window.FixedWindows(5), 935 accumulation_mode=trigger.AccumulationMode.DISCARDING) 936 | beam.ParDo(TimerEmittingStatefulDoFn()) 937 | beam.ParDo(self.record_dofn())) 938 939 self.assertEqual([('timer1-mykey', 10, 10, 15)], 940 sorted(StatefulDoFnOnDirectRunnerTest.all_records)) 941 942 def test_timer_default_tag(self): 943 class DynamicTimerDoFn(DoFn): 944 EMIT_TIMER_FAMILY = TimerSpec('emit', TimeDomain.WATERMARK) 945 946 def process(self, element, emit=DoFn.TimerParam(EMIT_TIMER_FAMILY)): 947 emit.set(10) 948 emit.set(20, dynamic_timer_tag='') 949 950 @on_timer(EMIT_TIMER_FAMILY) 951 def emit_callback( 952 self, ts=DoFn.TimestampParam, tag=DoFn.DynamicTimerTagParam): 953 yield (tag, ts) 954 955 with TestPipeline() as p: 956 test_stream = (TestStream().advance_watermark_to(10).add_elements( 957 [1])).advance_watermark_to_infinity() 958 ( 959 p 960 | test_stream 961 | beam.Map(lambda x: ('mykey', x)) 962 | beam.ParDo(DynamicTimerDoFn()) 963 | beam.ParDo(self.record_dofn())) 964 965 self.assertEqual([('', 20)], 966 sorted(StatefulDoFnOnDirectRunnerTest.all_records)) 967 968 def test_dynamic_timer_simple_dofn(self): 969 class DynamicTimerDoFn(DoFn): 970 EMIT_TIMER_FAMILY = TimerSpec('emit', TimeDomain.WATERMARK) 971 972 def process(self, element, emit=DoFn.TimerParam(EMIT_TIMER_FAMILY)): 973 emit.set(10, dynamic_timer_tag='emit1') 974 emit.set(20, dynamic_timer_tag='emit2') 975 emit.set(30, dynamic_timer_tag='emit3') 976 977 @on_timer(EMIT_TIMER_FAMILY) 978 def emit_callback( 979 self, ts=DoFn.TimestampParam, tag=DoFn.DynamicTimerTagParam): 980 yield (tag, ts) 981 982 with TestPipeline() as p: 983 test_stream = (TestStream().advance_watermark_to(10).add_elements( 984 [1])).advance_watermark_to_infinity() 985 ( 986 p 987 | test_stream 988 | beam.Map(lambda x: ('mykey', x)) 989 | beam.ParDo(DynamicTimerDoFn()) 990 | beam.ParDo(self.record_dofn())) 991 992 self.assertEqual([('emit1', 10), ('emit2', 20), ('emit3', 30)], 993 sorted(StatefulDoFnOnDirectRunnerTest.all_records)) 994 995 @pytest.mark.no_xdist 996 @pytest.mark.timeout(10) 997 def test_dynamic_timer_clear_then_set_timer(self): 998 class EmitTwoEvents(DoFn): 999 EMIT_CLEAR_SET_TIMER = TimerSpec('emitclear', TimeDomain.WATERMARK) 1000 1001 def process(self, element, emit=DoFn.TimerParam(EMIT_CLEAR_SET_TIMER)): 1002 yield ('1', 'set') 1003 emit.set(1) 1004 1005 @on_timer(EMIT_CLEAR_SET_TIMER) 1006 def emit_clear(self): 1007 yield ('1', 'clear') 1008 1009 class DynamicTimerDoFn(DoFn): 1010 EMIT_TIMER_FAMILY = TimerSpec('emit', TimeDomain.WATERMARK) 1011 1012 def process(self, element, emit=DoFn.TimerParam(EMIT_TIMER_FAMILY)): 1013 if element[1] == 'set': 1014 emit.set(10, dynamic_timer_tag='emit1') 1015 emit.set(20, dynamic_timer_tag='emit2') 1016 if element[1] == 'clear': 1017 emit.set(30, dynamic_timer_tag='emit3') 1018 emit.clear(dynamic_timer_tag='emit3') 1019 emit.set(40, dynamic_timer_tag='emit3') 1020 return [] 1021 1022 @on_timer(EMIT_TIMER_FAMILY) 1023 def emit_callback( 1024 self, ts=DoFn.TimestampParam, tag=DoFn.DynamicTimerTagParam): 1025 yield (tag, ts) 1026 1027 with TestPipeline() as p: 1028 res = ( 1029 p 1030 | beam.Create([('1', 'impulse')]) 1031 | beam.ParDo(EmitTwoEvents()) 1032 | beam.ParDo(DynamicTimerDoFn())) 1033 assert_that(res, equal_to([('emit1', 10), ('emit2', 20), ('emit3', 40)])) 1034 1035 def test_dynamic_timer_clear_timer(self): 1036 class DynamicTimerDoFn(DoFn): 1037 EMIT_TIMER_FAMILY = TimerSpec('emit', TimeDomain.WATERMARK) 1038 1039 def process(self, element, emit=DoFn.TimerParam(EMIT_TIMER_FAMILY)): 1040 if element[1] == 'set': 1041 emit.set(10, dynamic_timer_tag='emit1') 1042 emit.set(20, dynamic_timer_tag='emit2') 1043 emit.set(30, dynamic_timer_tag='emit3') 1044 if element[1] == 'clear': 1045 emit.clear(dynamic_timer_tag='emit3') 1046 1047 @on_timer(EMIT_TIMER_FAMILY) 1048 def emit_callback( 1049 self, ts=DoFn.TimestampParam, tag=DoFn.DynamicTimerTagParam): 1050 yield (tag, ts) 1051 1052 with TestPipeline() as p: 1053 test_stream = ( 1054 TestStream().advance_watermark_to(5).add_elements( 1055 ['set']).advance_watermark_to(10).add_elements( 1056 ['clear']).advance_watermark_to_infinity()) 1057 ( 1058 p 1059 | test_stream 1060 | beam.Map(lambda x: ('mykey', x)) 1061 | beam.ParDo(DynamicTimerDoFn()) 1062 | beam.ParDo(self.record_dofn())) 1063 1064 self.assertEqual([('emit1', 10), ('emit2', 20)], 1065 sorted(StatefulDoFnOnDirectRunnerTest.all_records)) 1066 1067 def test_dynamic_timer_multiple(self): 1068 class DynamicTimerDoFn(DoFn): 1069 EMIT_TIMER_FAMILY1 = TimerSpec('emit_family_1', TimeDomain.WATERMARK) 1070 EMIT_TIMER_FAMILY2 = TimerSpec('emit_family_2', TimeDomain.WATERMARK) 1071 1072 def process( 1073 self, 1074 element, 1075 emit1=DoFn.TimerParam(EMIT_TIMER_FAMILY1), 1076 emit2=DoFn.TimerParam(EMIT_TIMER_FAMILY2)): 1077 emit1.set(10, dynamic_timer_tag='emit11') 1078 emit1.set(20, dynamic_timer_tag='emit12') 1079 emit1.set(30, dynamic_timer_tag='emit13') 1080 emit2.set(30, dynamic_timer_tag='emit21') 1081 emit2.set(20, dynamic_timer_tag='emit22') 1082 emit2.set(10, dynamic_timer_tag='emit23') 1083 1084 @on_timer(EMIT_TIMER_FAMILY1) 1085 def emit_callback( 1086 self, ts=DoFn.TimestampParam, tag=DoFn.DynamicTimerTagParam): 1087 yield (tag, ts) 1088 1089 @on_timer(EMIT_TIMER_FAMILY2) 1090 def emit_callback_2( 1091 self, ts=DoFn.TimestampParam, tag=DoFn.DynamicTimerTagParam): 1092 yield (tag, ts) 1093 1094 with TestPipeline() as p: 1095 test_stream = ( 1096 TestStream().advance_watermark_to(5).add_elements( 1097 ['1']).advance_watermark_to_infinity()) 1098 ( 1099 p 1100 | test_stream 1101 | beam.Map(lambda x: ('mykey', x)) 1102 | beam.ParDo(DynamicTimerDoFn()) 1103 | beam.ParDo(self.record_dofn())) 1104 1105 self.assertEqual([('emit11', 10), ('emit12', 20), ('emit13', 30), 1106 ('emit21', 30), ('emit22', 20), ('emit23', 10)], 1107 sorted(StatefulDoFnOnDirectRunnerTest.all_records)) 1108 1109 def test_dynamic_timer_and_simple_timer(self): 1110 class DynamicTimerDoFn(DoFn): 1111 EMIT_TIMER_FAMILY = TimerSpec('emit', TimeDomain.WATERMARK) 1112 GC_TIMER = TimerSpec('gc', TimeDomain.WATERMARK) 1113 1114 def process( 1115 self, 1116 element, 1117 emit=DoFn.TimerParam(EMIT_TIMER_FAMILY), 1118 gc=DoFn.TimerParam(GC_TIMER)): 1119 emit.set(10, dynamic_timer_tag='emit1') 1120 emit.set(20, dynamic_timer_tag='emit2') 1121 emit.set(30, dynamic_timer_tag='emit3') 1122 gc.set(40) 1123 1124 @on_timer(EMIT_TIMER_FAMILY) 1125 def emit_callback( 1126 self, ts=DoFn.TimestampParam, tag=DoFn.DynamicTimerTagParam): 1127 yield (tag, ts) 1128 1129 @on_timer(GC_TIMER) 1130 def gc(self, ts=DoFn.TimestampParam): 1131 yield ('gc', ts) 1132 1133 with TestPipeline() as p: 1134 test_stream = ( 1135 TestStream().advance_watermark_to(5).add_elements( 1136 ['1']).advance_watermark_to_infinity()) 1137 ( 1138 p 1139 | test_stream 1140 | beam.Map(lambda x: ('mykey', x)) 1141 | beam.ParDo(DynamicTimerDoFn()) 1142 | beam.ParDo(self.record_dofn())) 1143 1144 self.assertEqual([('emit1', 10), ('emit2', 20), ('emit3', 30), ('gc', 40)], 1145 sorted(StatefulDoFnOnDirectRunnerTest.all_records)) 1146 1147 def test_index_assignment(self): 1148 class IndexAssigningStatefulDoFn(DoFn): 1149 INDEX_STATE = CombiningValueStateSpec('index', sum) 1150 1151 def process(self, element, state=DoFn.StateParam(INDEX_STATE)): 1152 unused_key, value = element 1153 current_index = state.read() 1154 yield (value, current_index) 1155 state.add(1) 1156 1157 with TestPipeline() as p: 1158 test_stream = ( 1159 TestStream().advance_watermark_to(10).add_elements([ 1160 'A', 'B' 1161 ]).add_elements(['C']).advance_watermark_to(25).add_elements(['D'])) 1162 ( 1163 p 1164 | test_stream 1165 | beam.Map(lambda x: ('mykey', x)) 1166 | beam.ParDo(IndexAssigningStatefulDoFn()) 1167 | beam.ParDo(self.record_dofn())) 1168 1169 self.assertEqual([('A', 0), ('B', 1), ('C', 2), ('D', 3)], 1170 StatefulDoFnOnDirectRunnerTest.all_records) 1171 1172 def test_hash_join(self): 1173 class HashJoinStatefulDoFn(DoFn): 1174 BUFFER_STATE = BagStateSpec('buffer', BytesCoder()) 1175 UNMATCHED_TIMER = TimerSpec('unmatched', TimeDomain.WATERMARK) 1176 1177 def process( 1178 self, 1179 element, 1180 state=DoFn.StateParam(BUFFER_STATE), 1181 timer=DoFn.TimerParam(UNMATCHED_TIMER)): 1182 key, value = element 1183 existing_values = list(state.read()) 1184 if not existing_values: 1185 state.add(value) 1186 timer.set(100) 1187 else: 1188 yield b'Record<%s,%s,%s>' % (key, existing_values[0], value) 1189 state.clear() 1190 timer.clear() 1191 1192 @on_timer(UNMATCHED_TIMER) 1193 def expiry_callback(self, state=DoFn.StateParam(BUFFER_STATE)): 1194 buffered = list(state.read()) 1195 assert len(buffered) == 1, buffered 1196 state.clear() 1197 yield b'Unmatched<%s>' % (buffered[0], ) 1198 1199 with TestPipeline() as p: 1200 test_stream = ( 1201 TestStream().advance_watermark_to(10).add_elements([ 1202 (b'A', b'a'), (b'B', b'b') 1203 ]).add_elements([ 1204 (b'A', b'aa'), (b'C', b'c') 1205 ]).advance_watermark_to(25).add_elements([ 1206 (b'A', b'aaa'), (b'B', b'bb') 1207 ]).add_elements([ 1208 (b'D', b'd'), (b'D', b'dd'), (b'D', b'ddd'), (b'D', b'dddd') 1209 ]).advance_watermark_to(125).add_elements([(b'C', b'cc')])) 1210 ( 1211 p 1212 | test_stream 1213 | beam.ParDo(HashJoinStatefulDoFn()) 1214 | beam.ParDo(self.record_dofn())) 1215 1216 equal_to(StatefulDoFnOnDirectRunnerTest.all_records)([ 1217 b'Record<A,a,aa>', 1218 b'Record<B,b,bb>', 1219 b'Record<D,d,dd>', 1220 b'Record<D,ddd,dddd>', 1221 b'Unmatched<aaa>', 1222 b'Unmatched<c>', 1223 b'Unmatched<cc>' 1224 ]) 1225 1226 1227 if __name__ == '__main__': 1228 unittest.main()