github.com/apache/beam/sdks/v2@v2.48.2/python/apache_beam/io/external/generate_sequence_test.py (about)

     1  #
     2  # Licensed to the Apache Software Foundation (ASF) under one or more
     3  # contributor license agreements.  See the NOTICE file distributed with
     4  # this work for additional information regarding copyright ownership.
     5  # The ASF licenses this file to You under the Apache License, Version 2.0
     6  # (the "License"); you may not use this file except in compliance with
     7  # the License.  You may obtain a copy of the License at
     8  #
     9  #    http://www.apache.org/licenses/LICENSE-2.0
    10  #
    11  # Unless required by applicable law or agreed to in writing, software
    12  # distributed under the License is distributed on an "AS IS" BASIS,
    13  # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    14  # See the License for the specific language governing permissions and
    15  # limitations under the License.
    16  #
    17  
    18  """Unit tests for cross-language generate sequence."""
    19  
    20  # pytype: skip-file
    21  
    22  import logging
    23  import os
    24  import re
    25  import unittest
    26  
    27  import pytest
    28  
    29  from apache_beam.io.external.generate_sequence import GenerateSequence
    30  from apache_beam.testing.test_pipeline import TestPipeline
    31  from apache_beam.testing.util import assert_that
    32  from apache_beam.testing.util import equal_to
    33  from apache_beam.transforms.external import ExternalTransform
    34  from apache_beam.transforms.external import JavaClassLookupPayloadBuilder
    35  from apache_beam.transforms.external import JavaExternalTransform
    36  
    37  
    38  @pytest.mark.uses_java_expansion_service
    39  @unittest.skipUnless(
    40      os.environ.get('EXPANSION_PORT'),
    41      "EXPANSION_PORT environment var is not provided.")
    42  class XlangGenerateSequenceTest(unittest.TestCase):
    43    def test_generate_sequence(self):
    44      port = os.environ.get('EXPANSION_PORT')
    45      address = 'localhost:%s' % port
    46  
    47      try:
    48        with TestPipeline() as p:
    49          res = (
    50              p
    51              | GenerateSequence(start=1, stop=10, expansion_service=address))
    52  
    53          assert_that(res, equal_to(list(range(1, 10))))
    54      except RuntimeError as e:
    55        if re.search(GenerateSequence.URN, str(e)):
    56          print("looks like URN not implemented in expansion service, skipping.")
    57        else:
    58          raise e
    59  
    60    def test_generate_sequence_java_class_lookup_payload_builder(self):
    61      port = os.environ.get('EXPANSION_PORT')
    62      address = 'localhost:%s' % port
    63  
    64      with TestPipeline() as p:
    65        payload_builder = JavaClassLookupPayloadBuilder(
    66            'org.apache.beam.sdk.io.GenerateSequence')
    67        payload_builder.with_constructor_method('from', 1)
    68        payload_builder.add_builder_method('to', 10)
    69  
    70        res = (
    71            p
    72            | ExternalTransform(None, payload_builder, expansion_service=address))
    73        assert_that(res, equal_to(list(range(1, 10))))
    74  
    75    def test_generate_sequence_java_external_transform(self):
    76      port = os.environ.get('EXPANSION_PORT')
    77      address = 'localhost:%s' % port
    78  
    79      with TestPipeline() as p:
    80        java_transform = JavaExternalTransform(
    81            'org.apache.beam.sdk.io.GenerateSequence', expansion_service=address)
    82        # We have to use 'getattr' below for builder method 'from' of Java
    83        # 'GenerateSequence' class since 'from' is a reserved keyword for Python.
    84        java_transform = getattr(java_transform, 'from')(1).to(10)
    85        res = (p | java_transform)
    86  
    87        assert_that(res, equal_to(list(range(1, 10))))
    88  
    89  
    90  if __name__ == '__main__':
    91    logging.getLogger().setLevel(logging.INFO)
    92    unittest.main()