github.com/apache/beam/sdks/v2@v2.48.2/python/apache_beam/runners/pipeline_context_test.py (about)

     1  #
     2  # Licensed to the Apache Software Foundation (ASF) under one or more
     3  # contributor license agreements.  See the NOTICE file distributed with
     4  # this work for additional information regarding copyright ownership.
     5  # The ASF licenses this file to You under the Apache License, Version 2.0
     6  # (the "License"); you may not use this file except in compliance with
     7  # the License.  You may obtain a copy of the License at
     8  #
     9  #    http://www.apache.org/licenses/LICENSE-2.0
    10  #
    11  # Unless required by applicable law or agreed to in writing, software
    12  # distributed under the License is distributed on an "AS IS" BASIS,
    13  # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    14  # See the License for the specific language governing permissions and
    15  # limitations under the License.
    16  #
    17  
"""Unit tests for apache_beam.runners.pipeline_context."""
    19  
    20  # pytype: skip-file
    21  
    22  import unittest
    23  
    24  from apache_beam import coders
    25  from apache_beam.runners import pipeline_context
    26  from apache_beam.transforms import environments
    27  
    28  
    29  class PipelineContextTest(unittest.TestCase):
    30    def test_deduplication(self):
    31      context = pipeline_context.PipelineContext()
    32      bytes_coder_ref = context.coders.get_id(coders.BytesCoder())
    33      bytes_coder_ref2 = context.coders.get_id(coders.BytesCoder())
    34      self.assertEqual(bytes_coder_ref, bytes_coder_ref2)
    35  
    36    def test_deduplication_by_proto(self):
    37      context = pipeline_context.PipelineContext()
    38      env_proto = environments.SubprocessSDKEnvironment(
    39          command_string="foo").to_runner_api(None)
    40      env_ref_1 = context.environments.get_by_proto(env_proto)
    41      env_ref_2 = context.environments.get_by_proto(env_proto, deduplicate=True)
    42      self.assertEqual(env_ref_1, env_ref_2)
    43  
    44    def test_equal_environments_are_deduplicated_when_fetched_by_obj_or_proto(
    45        self):
    46      context = pipeline_context.PipelineContext()
    47  
    48      env = environments.SubprocessSDKEnvironment(command_string="foo")
    49      env_proto = env.to_runner_api(None)
    50      id_from_proto = context.environments.get_by_proto(env_proto)
    51      id_from_obj = context.environments.get_id(env)
    52      self.assertEqual(id_from_obj, id_from_proto)
    53      self.assertEqual(
    54          context.environments.get_by_id(id_from_obj).command_string, "foo")
    55  
    56      env = environments.SubprocessSDKEnvironment(command_string="bar")
    57      env_proto = env.to_runner_api(None)
    58      id_from_obj = context.environments.get_id(env)
    59      id_from_proto = context.environments.get_by_proto(
    60          env_proto, deduplicate=True)
    61      self.assertEqual(id_from_obj, id_from_proto)
    62      self.assertEqual(
    63          context.environments.get_by_id(id_from_obj).command_string, "bar")
    64  
    65    def test_serialization(self):
    66      context = pipeline_context.PipelineContext()
    67      float_coder_ref = context.coders.get_id(coders.FloatCoder())
    68      bytes_coder_ref = context.coders.get_id(coders.BytesCoder())
    69      proto = context.to_runner_api()
    70      context2 = pipeline_context.PipelineContext.from_runner_api(proto)
    71      self.assertEqual(
    72          coders.FloatCoder(), context2.coders.get_by_id(float_coder_ref))
    73      self.assertEqual(
    74          coders.BytesCoder(), context2.coders.get_by_id(bytes_coder_ref))
    75  
    76    def test_common_id_assignment(self):
    77      context = pipeline_context.PipelineContext()
    78      float_coder_ref = context.coders.get_id(coders.FloatCoder())
    79      bytes_coder_ref = context.coders.get_id(coders.BytesCoder())
    80      context2 = pipeline_context.PipelineContext(
    81          component_id_map=context.component_id_map)
    82  
    83      bytes_coder_ref2 = context2.coders.get_id(coders.BytesCoder())
    84      float_coder_ref2 = context2.coders.get_id(coders.FloatCoder())
    85  
    86      self.assertEqual(bytes_coder_ref, bytes_coder_ref2)
    87      self.assertEqual(float_coder_ref, float_coder_ref2)
    88  
    89  
    90  if __name__ == '__main__':
    91    unittest.main()