github.com/apache/beam/sdks/v2@v2.48.2/python/apache_beam/io/gcp/pubsub_test.py (about)

     1  # coding=utf-8
     2  #
     3  # Licensed to the Apache Software Foundation (ASF) under one or more
     4  # contributor license agreements.  See the NOTICE file distributed with
     5  # this work for additional information regarding copyright ownership.
     6  # The ASF licenses this file to You under the Apache License, Version 2.0
     7  # (the "License"); you may not use this file except in compliance with
     8  # the License.  You may obtain a copy of the License at
     9  #
    10  #    http://www.apache.org/licenses/LICENSE-2.0
    11  #
    12  # Unless required by applicable law or agreed to in writing, software
    13  # distributed under the License is distributed on an "AS IS" BASIS,
    14  # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    15  # See the License for the specific language governing permissions and
    16  # limitations under the License.
    17  #
    18  
    19  """Unit tests for PubSub sources and sinks."""
    20  
    21  # pytype: skip-file
    22  
    23  import logging
    24  import unittest
    25  
    26  import hamcrest as hc
    27  import mock
    28  
    29  import apache_beam as beam
    30  from apache_beam.io import Read
    31  from apache_beam.io import Write
    32  from apache_beam.io.gcp.pubsub import MultipleReadFromPubSub
    33  from apache_beam.io.gcp.pubsub import PubsubMessage
    34  from apache_beam.io.gcp.pubsub import PubSubSourceDescriptor
    35  from apache_beam.io.gcp.pubsub import ReadFromPubSub
    36  from apache_beam.io.gcp.pubsub import ReadStringsFromPubSub
    37  from apache_beam.io.gcp.pubsub import WriteStringsToPubSub
    38  from apache_beam.io.gcp.pubsub import WriteToPubSub
    39  from apache_beam.io.gcp.pubsub import _PubSubSink
    40  from apache_beam.io.gcp.pubsub import _PubSubSource
    41  from apache_beam.options.pipeline_options import PipelineOptions
    42  from apache_beam.options.pipeline_options import StandardOptions
    43  from apache_beam.portability import common_urns
    44  from apache_beam.portability.api import beam_runner_api_pb2
    45  from apache_beam.runners import pipeline_context
    46  from apache_beam.runners.direct import transform_evaluator
    47  from apache_beam.runners.direct.direct_runner import _DirectReadFromPubSub
    48  from apache_beam.runners.direct.direct_runner import _get_transform_overrides
    49  from apache_beam.runners.direct.transform_evaluator import _PubSubReadEvaluator
    50  from apache_beam.testing import test_utils
    51  from apache_beam.testing.test_pipeline import TestPipeline
    52  from apache_beam.testing.util import TestWindowedValue
    53  from apache_beam.testing.util import assert_that
    54  from apache_beam.testing.util import equal_to
    55  from apache_beam.transforms import window
    56  from apache_beam.transforms.core import Create
    57  from apache_beam.transforms.display import DisplayData
    58  from apache_beam.transforms.display_test import DisplayDataItemMatcher
    59  from apache_beam.utils import proto_utils
    60  from apache_beam.utils import timestamp
    61  
    62  # Protect against environments where the PubSub library is not available.
    63  try:
    64    from google.cloud import pubsub
    65  except ImportError:
    66    pubsub = None
    67  
    68  
    69  class TestPubsubMessage(unittest.TestCase):
    70    def test_payload_valid(self):
    71      _ = PubsubMessage('', None)
    72      _ = PubsubMessage('data', None)
    73      _ = PubsubMessage(None, {'k': 'v'})
    74  
    75    def test_payload_invalid(self):
    76      with self.assertRaisesRegex(ValueError, r'data.*attributes.*must be set'):
    77        _ = PubsubMessage(None, None)
    78      with self.assertRaisesRegex(ValueError, r'data.*attributes.*must be set'):
    79        _ = PubsubMessage(None, {})
    80  
    81    @unittest.skipIf(pubsub is None, 'GCP dependencies are not installed')
    82    def test_proto_conversion(self):
    83      data = b'data'
    84      attributes = {'k1': 'v1', 'k2': 'v2'}
    85      m = PubsubMessage(data, attributes)
    86      m_converted = PubsubMessage._from_proto_str(m._to_proto_str())
    87      self.assertEqual(m_converted.data, data)
    88      self.assertEqual(m_converted.attributes, attributes)
    89  
    90    @unittest.skipIf(pubsub is None, 'GCP dependencies are not installed')
    91    def test_payload_publish_invalid(self):
    92      with self.assertRaisesRegex(ValueError, r'data field.*10MB'):
    93        msg = PubsubMessage(b'0' * 1024 * 1024 * 11, None)
    94        msg._to_proto_str(for_publish=True)
    95      with self.assertRaisesRegex(ValueError, 'attribute key'):
    96        msg = PubsubMessage(b'0', {'0' * 257: '0'})
    97        msg._to_proto_str(for_publish=True)
    98      with self.assertRaisesRegex(ValueError, 'attribute value'):
    99        msg = PubsubMessage(b'0', {'0' * 100: '0' * 1025})
   100        msg._to_proto_str(for_publish=True)
   101      with self.assertRaisesRegex(ValueError, '100 attributes'):
   102        attributes = {}
   103        for i in range(0, 101):
   104          attributes[str(i)] = str(i)
   105        msg = PubsubMessage(b'0', attributes)
   106        msg._to_proto_str(for_publish=True)
   107      with self.assertRaisesRegex(ValueError, 'ordering key'):
   108        msg = PubsubMessage(b'0', None, ordering_key='0' * 1301)
   109        msg._to_proto_str(for_publish=True)
   110  
   111    def test_eq(self):
   112      a = PubsubMessage(b'abc', {1: 2, 3: 4})
   113      b = PubsubMessage(b'abc', {1: 2, 3: 4})
   114      c = PubsubMessage(b'abc', {1: 2})
   115      self.assertTrue(a == b)
   116      self.assertTrue(a != c)
   117      self.assertTrue(b != c)
   118  
   119    def test_hash(self):
   120      a = PubsubMessage(b'abc', {1: 2, 3: 4})
   121      b = PubsubMessage(b'abc', {1: 2, 3: 4})
   122      c = PubsubMessage(b'abc', {1: 2})
   123      self.assertTrue(hash(a) == hash(b))
   124      self.assertTrue(hash(a) != hash(c))
   125      self.assertTrue(hash(b) != hash(c))
   126  
   127    def test_repr(self):
   128      a = PubsubMessage(b'abc', {1: 2, 3: 4})
   129      b = PubsubMessage(b'abc', {1: 2, 3: 4})
   130      c = PubsubMessage(b'abc', {1: 2})
   131      self.assertTrue(repr(a) == repr(b))
   132      self.assertTrue(repr(a) != repr(c))
   133      self.assertTrue(repr(b) != repr(c))
   134  
   135  
   136  @unittest.skipIf(pubsub is None, 'GCP dependencies are not installed')
   137  class TestReadFromPubSubOverride(unittest.TestCase):
   138    def test_expand_with_topic(self):
   139      options = PipelineOptions([])
   140      options.view_as(StandardOptions).streaming = True
   141      p = TestPipeline(options=options)
   142      pcoll = (
   143          p
   144          | ReadFromPubSub(
   145              'projects/fakeprj/topics/a_topic',
   146              None,
   147              'a_label',
   148              with_attributes=False,
   149              timestamp_attribute=None)
   150          | beam.Map(lambda x: x))
   151      self.assertEqual(bytes, pcoll.element_type)
   152  
   153      # Apply the necessary PTransformOverrides.
   154      overrides = _get_transform_overrides(options)
   155      p.replace_all(overrides)
   156  
   157      # Note that the direct output of ReadFromPubSub will be replaced
   158      # by a PTransformOverride, so we use a no-op Map.
   159      read_transform = pcoll.producer.inputs[0].producer.transform
   160  
   161      # Ensure that the properties passed through correctly
   162      source = read_transform._source
   163      self.assertEqual('a_topic', source.topic_name)
   164      self.assertEqual('a_label', source.id_label)
   165  
   166    def test_expand_with_subscription(self):
   167      options = PipelineOptions([])
   168      options.view_as(StandardOptions).streaming = True
   169      p = TestPipeline(options=options)
   170      pcoll = (
   171          p
   172          | ReadFromPubSub(
   173              None,
   174              'projects/fakeprj/subscriptions/a_subscription',
   175              'a_label',
   176              with_attributes=False,
   177              timestamp_attribute=None)
   178          | beam.Map(lambda x: x))
   179      self.assertEqual(bytes, pcoll.element_type)
   180  
   181      # Apply the necessary PTransformOverrides.
   182      overrides = _get_transform_overrides(options)
   183      p.replace_all(overrides)
   184  
   185      # Note that the direct output of ReadFromPubSub will be replaced
   186      # by a PTransformOverride, so we use a no-op Map.
   187      read_transform = pcoll.producer.inputs[0].producer.transform
   188  
   189      # Ensure that the properties passed through correctly
   190      source = read_transform._source
   191      self.assertEqual('a_subscription', source.subscription_name)
   192      self.assertEqual('a_label', source.id_label)
   193  
   194    def test_expand_with_no_topic_or_subscription(self):
   195      with self.assertRaisesRegex(
   196          ValueError, "Either a topic or subscription must be provided."):
   197        ReadFromPubSub(
   198            None,
   199            None,
   200            'a_label',
   201            with_attributes=False,
   202            timestamp_attribute=None)
   203  
   204    def test_expand_with_both_topic_and_subscription(self):
   205      with self.assertRaisesRegex(
   206          ValueError, "Only one of topic or subscription should be provided."):
   207        ReadFromPubSub(
   208            'a_topic',
   209            'a_subscription',
   210            'a_label',
   211            with_attributes=False,
   212            timestamp_attribute=None)
   213  
   214    def test_expand_with_other_options(self):
   215      options = PipelineOptions([])
   216      options.view_as(StandardOptions).streaming = True
   217      p = TestPipeline(options=options)
   218      pcoll = (
   219          p
   220          | ReadFromPubSub(
   221              'projects/fakeprj/topics/a_topic',
   222              None,
   223              'a_label',
   224              with_attributes=True,
   225              timestamp_attribute='time')
   226          | beam.Map(lambda x: x))
   227      self.assertEqual(PubsubMessage, pcoll.element_type)
   228  
   229      # Apply the necessary PTransformOverrides.
   230      overrides = _get_transform_overrides(options)
   231      p.replace_all(overrides)
   232  
   233      # Note that the direct output of ReadFromPubSub will be replaced
   234      # by a PTransformOverride, so we use a no-op Map.
   235      read_transform = pcoll.producer.inputs[0].producer.transform
   236  
   237      # Ensure that the properties passed through correctly
   238      source = read_transform._source
   239      self.assertTrue(source.with_attributes)
   240      self.assertEqual('time', source.timestamp_attribute)
   241  
   242  
   243  @unittest.skipIf(pubsub is None, 'GCP dependencies are not installed')
   244  class TestMultiReadFromPubSubOverride(unittest.TestCase):
   245    def test_expand_with_multiple_sources(self):
   246      options = PipelineOptions([])
   247      options.view_as(StandardOptions).streaming = True
   248      p = TestPipeline(options=options)
   249      topics = [
   250          'projects/fakeprj/topics/a_topic', 'projects/fakeprj2/topics/b_topic'
   251      ]
   252      subscriptions = ['projects/fakeprj/subscriptions/a_subscription']
   253  
   254      pubsub_sources = [
   255          PubSubSourceDescriptor(descriptor)
   256          for descriptor in topics + subscriptions
   257      ]
   258      pcoll = (p | MultipleReadFromPubSub(pubsub_sources) | beam.Map(lambda x: x))
   259  
   260      # Apply the necessary PTransformOverrides.
   261      overrides = _get_transform_overrides(options)
   262      p.replace_all(overrides)
   263  
   264      self.assertEqual(bytes, pcoll.element_type)
   265  
   266      # Ensure that the sources are passed through correctly
   267      read_transforms = pcoll.producer.inputs[0].producer.inputs
   268      topics_list = []
   269      subscription_list = []
   270      for read_transform in read_transforms:
   271        source = read_transform.producer.transform._source
   272        if source.full_topic:
   273          topics_list.append(source.full_topic)
   274        else:
   275          subscription_list.append(source.full_subscription)
   276      self.assertEqual(topics_list, topics)
   277      self.assertEqual(subscription_list, subscriptions)
   278  
   279    def test_expand_with_multiple_sources_and_attributes(self):
   280      options = PipelineOptions([])
   281      options.view_as(StandardOptions).streaming = True
   282      p = TestPipeline(options=options)
   283      topics = [
   284          'projects/fakeprj/topics/a_topic', 'projects/fakeprj2/topics/b_topic'
   285      ]
   286      subscriptions = ['projects/fakeprj/subscriptions/a_subscription']
   287  
   288      pubsub_sources = [
   289          PubSubSourceDescriptor(descriptor)
   290          for descriptor in topics + subscriptions
   291      ]
   292      pcoll = (
   293          p | MultipleReadFromPubSub(pubsub_sources, with_attributes=True)
   294          | beam.Map(lambda x: x))
   295  
   296      # Apply the necessary PTransformOverrides.
   297      overrides = _get_transform_overrides(options)
   298      p.replace_all(overrides)
   299  
   300      self.assertEqual(PubsubMessage, pcoll.element_type)
   301  
   302      # Ensure that the sources are passed through correctly
   303      read_transforms = pcoll.producer.inputs[0].producer.inputs
   304      topics_list = []
   305      subscription_list = []
   306      for read_transform in read_transforms:
   307        source = read_transform.producer.transform._source
   308        if source.full_topic:
   309          topics_list.append(source.full_topic)
   310        else:
   311          subscription_list.append(source.full_subscription)
   312      self.assertEqual(topics_list, topics)
   313      self.assertEqual(subscription_list, subscriptions)
   314  
   315    def test_expand_with_multiple_sources_and_other_options(self):
   316      options = PipelineOptions([])
   317      options.view_as(StandardOptions).streaming = True
   318      p = TestPipeline(options=options)
   319      sources = [
   320          'projects/fakeprj/topics/a_topic',
   321          'projects/fakeprj2/topics/b_topic',
   322          'projects/fakeprj/subscriptions/a_subscription'
   323      ]
   324      id_labels = ['a_label_topic', 'b_label_topic', 'a_label_subscription']
   325      timestamp_attributes = ['a_ta_topic', 'b_ta_topic', 'a_ta_subscription']
   326  
   327      pubsub_sources = [
   328          PubSubSourceDescriptor(
   329              source=source,
   330              id_label=id_label,
   331              timestamp_attribute=timestamp_attribute) for source,
   332          id_label,
   333          timestamp_attribute in zip(sources, id_labels, timestamp_attributes)
   334      ]
   335  
   336      pcoll = (p | MultipleReadFromPubSub(pubsub_sources) | beam.Map(lambda x: x))
   337  
   338      # Apply the necessary PTransformOverrides.
   339      overrides = _get_transform_overrides(options)
   340      p.replace_all(overrides)
   341  
   342      self.assertEqual(bytes, pcoll.element_type)
   343  
   344      # Ensure that the sources are passed through correctly
   345      read_transforms = pcoll.producer.inputs[0].producer.inputs
   346      for i, read_transform in enumerate(read_transforms):
   347        id_label = id_labels[i]
   348        timestamp_attribute = timestamp_attributes[i]
   349  
   350        source = read_transform.producer.transform._source
   351        self.assertEqual(source.id_label, id_label)
   352        self.assertEqual(source.with_attributes, False)
   353        self.assertEqual(source.timestamp_attribute, timestamp_attribute)
   354  
   355    def test_expand_with_wrong_source(self):
   356      with self.assertRaisesRegex(
   357          ValueError,
   358          r'PubSub source descriptor must be in the form '
   359          r'"projects/<project>/topics/<topic>"'
   360          ' or "projects/<project>/subscription/<subscription>".*'):
   361        MultipleReadFromPubSub([PubSubSourceDescriptor('not_a_proper_source')])
   362  
   363  
   364  @unittest.skipIf(pubsub is None, 'GCP dependencies are not installed')
   365  class TestWriteStringsToPubSubOverride(unittest.TestCase):
   366    def test_expand_deprecated(self):
   367      options = PipelineOptions([])
   368      options.view_as(StandardOptions).streaming = True
   369      p = TestPipeline(options=options)
   370      pcoll = (
   371          p
   372          | ReadFromPubSub('projects/fakeprj/topics/baz')
   373          | WriteStringsToPubSub('projects/fakeprj/topics/a_topic')
   374          | beam.Map(lambda x: x))
   375  
   376      # Apply the necessary PTransformOverrides.
   377      overrides = _get_transform_overrides(options)
   378      p.replace_all(overrides)
   379  
   380      # Note that the direct output of ReadFromPubSub will be replaced
   381      # by a PTransformOverride, so we use a no-op Map.
   382      write_transform = pcoll.producer.inputs[0].producer.transform
   383  
   384      # Ensure that the properties passed through correctly
   385      self.assertEqual('a_topic', write_transform.dofn.short_topic_name)
   386  
   387    def test_expand(self):
   388      options = PipelineOptions([])
   389      options.view_as(StandardOptions).streaming = True
   390      p = TestPipeline(options=options)
   391      pcoll = (
   392          p
   393          | ReadFromPubSub('projects/fakeprj/topics/baz')
   394          | WriteToPubSub(
   395              'projects/fakeprj/topics/a_topic', with_attributes=True)
   396          | beam.Map(lambda x: x))
   397  
   398      # Apply the necessary PTransformOverrides.
   399      overrides = _get_transform_overrides(options)
   400      p.replace_all(overrides)
   401  
   402      # Note that the direct output of ReadFromPubSub will be replaced
   403      # by a PTransformOverride, so we use a no-op Map.
   404      write_transform = pcoll.producer.inputs[0].producer.transform
   405  
   406      # Ensure that the properties passed through correctly
   407      self.assertEqual('a_topic', write_transform.dofn.short_topic_name)
   408      self.assertEqual(True, write_transform.dofn.with_attributes)
   409      # TODO(https://github.com/apache/beam/issues/18939): These properties
   410      # aren't supported yet in direct runner.
   411      self.assertEqual(None, write_transform.dofn.id_label)
   412      self.assertEqual(None, write_transform.dofn.timestamp_attribute)
   413  
   414  
   415  @unittest.skipIf(pubsub is None, 'GCP dependencies are not installed')
   416  class TestPubSubSource(unittest.TestCase):
   417    def test_display_data_topic(self):
   418      source = _PubSubSource('projects/fakeprj/topics/a_topic', None, 'a_label')
   419      dd = DisplayData.create_from(source)
   420      expected_items = [
   421          DisplayDataItemMatcher('topic', 'projects/fakeprj/topics/a_topic'),
   422          DisplayDataItemMatcher('id_label', 'a_label'),
   423          DisplayDataItemMatcher('with_attributes', False),
   424      ]
   425  
   426      hc.assert_that(dd.items, hc.contains_inanyorder(*expected_items))
   427  
   428    def test_display_data_subscription(self):
   429      source = _PubSubSource(
   430          None, 'projects/fakeprj/subscriptions/a_subscription', 'a_label')
   431      dd = DisplayData.create_from(source)
   432      expected_items = [
   433          DisplayDataItemMatcher(
   434              'subscription', 'projects/fakeprj/subscriptions/a_subscription'),
   435          DisplayDataItemMatcher('id_label', 'a_label'),
   436          DisplayDataItemMatcher('with_attributes', False),
   437      ]
   438  
   439      hc.assert_that(dd.items, hc.contains_inanyorder(*expected_items))
   440  
   441    def test_display_data_no_subscription(self):
   442      source = _PubSubSource('projects/fakeprj/topics/a_topic')
   443      dd = DisplayData.create_from(source)
   444      expected_items = [
   445          DisplayDataItemMatcher('topic', 'projects/fakeprj/topics/a_topic'),
   446          DisplayDataItemMatcher('with_attributes', False),
   447      ]
   448  
   449      hc.assert_that(dd.items, hc.contains_inanyorder(*expected_items))
   450  
   451  
   452  @unittest.skipIf(pubsub is None, 'GCP dependencies are not installed')
   453  class TestPubSubSink(unittest.TestCase):
   454    def test_display_data(self):
   455      sink = WriteToPubSub(
   456          'projects/fakeprj/topics/a_topic',
   457          id_label='id',
   458          timestamp_attribute='time')
   459      dd = DisplayData.create_from(sink)
   460      expected_items = [
   461          DisplayDataItemMatcher('topic', 'projects/fakeprj/topics/a_topic'),
   462          DisplayDataItemMatcher('id_label', 'id'),
   463          DisplayDataItemMatcher('with_attributes', True),
   464          DisplayDataItemMatcher('timestamp_attribute', 'time'),
   465      ]
   466  
   467      hc.assert_that(dd.items, hc.contains_inanyorder(*expected_items))
   468  
   469  
   470  class TestPubSubReadEvaluator(object):
   471    """Wrapper of _PubSubReadEvaluator that makes it bounded."""
   472  
   473    _pubsub_read_evaluator = _PubSubReadEvaluator
   474  
   475    def __init__(self, *args, **kwargs):
   476      self._evaluator = self._pubsub_read_evaluator(*args, **kwargs)
   477  
   478    def start_bundle(self):
   479      return self._evaluator.start_bundle()
   480  
   481    def process_element(self, element):
   482      return self._evaluator.process_element(element)
   483  
   484    def finish_bundle(self):
   485      result = self._evaluator.finish_bundle()
   486      result.unprocessed_bundles = []
   487      result.keyed_watermark_holds = {None: None}
   488      return result
   489  
   490  
# Register TestPubSubReadEvaluator as the test override for
# _DirectReadFromPubSub. This mutates a class attribute at import time, so it
# applies to every test in the process. Presumably the registry consults this
# mapping when choosing an evaluator -- confirm in transform_evaluator.
transform_evaluator.TransformEvaluatorRegistry._test_evaluators_overrides = {
    _DirectReadFromPubSub: TestPubSubReadEvaluator,  # type: ignore[dict-item]
}
   494  
   495  
   496  @unittest.skipIf(pubsub is None, 'GCP dependencies are not installed')
   497  @mock.patch('google.cloud.pubsub.SubscriberClient')
   498  class TestReadFromPubSub(unittest.TestCase):
   499    def test_read_messages_success(self, mock_pubsub):
   500      data = b'data'
   501      publish_time_secs = 1520861821
   502      publish_time_nanos = 234567000
   503      attributes = {'key': 'value'}
   504      ack_id = 'ack_id'
   505      pull_response = test_utils.create_pull_response([
   506          test_utils.PullResponseMessage(
   507              data, attributes, publish_time_secs, publish_time_nanos, ack_id)
   508      ])
   509      expected_elements = [
   510          TestWindowedValue(
   511              PubsubMessage(data, attributes),
   512              timestamp.Timestamp(1520861821.234567), [window.GlobalWindow()])
   513      ]
   514      mock_pubsub.return_value.pull.return_value = pull_response
   515  
   516      options = PipelineOptions([])
   517      options.view_as(StandardOptions).streaming = True
   518      with TestPipeline(options=options) as p:
   519        pcoll = (
   520            p
   521            | ReadFromPubSub(
   522                'projects/fakeprj/topics/a_topic',
   523                None,
   524                None,
   525                with_attributes=True))
   526        assert_that(pcoll, equal_to(expected_elements), reify_windows=True)
   527      mock_pubsub.return_value.acknowledge.assert_has_calls(
   528          [mock.call(subscription=mock.ANY, ack_ids=[ack_id])])
   529  
   530      mock_pubsub.return_value.close.assert_has_calls([mock.call()])
   531  
   532    def test_read_strings_success(self, mock_pubsub):
   533      data = u'🤷 ¯\\_(ツ)_/¯'
   534      data_encoded = data.encode('utf-8')
   535      ack_id = 'ack_id'
   536      pull_response = test_utils.create_pull_response(
   537          [test_utils.PullResponseMessage(data_encoded, ack_id=ack_id)])
   538      expected_elements = [data]
   539      mock_pubsub.return_value.pull.return_value = pull_response
   540  
   541      options = PipelineOptions([])
   542      options.view_as(StandardOptions).streaming = True
   543      with TestPipeline(options=options) as p:
   544        pcoll = (
   545            p
   546            | ReadStringsFromPubSub(
   547                'projects/fakeprj/topics/a_topic', None, None))
   548        assert_that(pcoll, equal_to(expected_elements))
   549      mock_pubsub.return_value.acknowledge.assert_has_calls(
   550          [mock.call(subscription=mock.ANY, ack_ids=[ack_id])])
   551  
   552      mock_pubsub.return_value.close.assert_has_calls([mock.call()])
   553  
   554    def test_read_data_success(self, mock_pubsub):
   555      data_encoded = u'🤷 ¯\\_(ツ)_/¯'.encode('utf-8')
   556      ack_id = 'ack_id'
   557      pull_response = test_utils.create_pull_response(
   558          [test_utils.PullResponseMessage(data_encoded, ack_id=ack_id)])
   559      expected_elements = [data_encoded]
   560      mock_pubsub.return_value.pull.return_value = pull_response
   561  
   562      options = PipelineOptions([])
   563      options.view_as(StandardOptions).streaming = True
   564      with TestPipeline(options=options) as p:
   565        pcoll = (
   566            p
   567            | ReadFromPubSub('projects/fakeprj/topics/a_topic', None, None))
   568        assert_that(pcoll, equal_to(expected_elements))
   569      mock_pubsub.return_value.acknowledge.assert_has_calls(
   570          [mock.call(subscription=mock.ANY, ack_ids=[ack_id])])
   571  
   572      mock_pubsub.return_value.close.assert_has_calls([mock.call()])
   573  
   574    def test_read_messages_timestamp_attribute_milli_success(self, mock_pubsub):
   575      data = b'data'
   576      attributes = {'time': '1337'}
   577      publish_time_secs = 1520861821
   578      publish_time_nanos = 234567000
   579      ack_id = 'ack_id'
   580      pull_response = test_utils.create_pull_response([
   581          test_utils.PullResponseMessage(
   582              data, attributes, publish_time_secs, publish_time_nanos, ack_id)
   583      ])
   584      expected_elements = [
   585          TestWindowedValue(
   586              PubsubMessage(data, attributes),
   587              timestamp.Timestamp(micros=int(attributes['time']) * 1000),
   588              [window.GlobalWindow()]),
   589      ]
   590      mock_pubsub.return_value.pull.return_value = pull_response
   591  
   592      options = PipelineOptions([])
   593      options.view_as(StandardOptions).streaming = True
   594      with TestPipeline(options=options) as p:
   595        pcoll = (
   596            p
   597            | ReadFromPubSub(
   598                'projects/fakeprj/topics/a_topic',
   599                None,
   600                None,
   601                with_attributes=True,
   602                timestamp_attribute='time'))
   603        assert_that(pcoll, equal_to(expected_elements), reify_windows=True)
   604      mock_pubsub.return_value.acknowledge.assert_has_calls(
   605          [mock.call(subscription=mock.ANY, ack_ids=[ack_id])])
   606  
   607      mock_pubsub.return_value.close.assert_has_calls([mock.call()])
   608  
   609    def test_read_messages_timestamp_attribute_rfc3339_success(self, mock_pubsub):
   610      data = b'data'
   611      attributes = {'time': '2018-03-12T13:37:01.234567Z'}
   612      publish_time_secs = 1337000000
   613      publish_time_nanos = 133700000
   614      ack_id = 'ack_id'
   615      pull_response = test_utils.create_pull_response([
   616          test_utils.PullResponseMessage(
   617              data, attributes, publish_time_secs, publish_time_nanos, ack_id)
   618      ])
   619      expected_elements = [
   620          TestWindowedValue(
   621              PubsubMessage(data, attributes),
   622              timestamp.Timestamp.from_rfc3339(attributes['time']),
   623              [window.GlobalWindow()]),
   624      ]
   625      mock_pubsub.return_value.pull.return_value = pull_response
   626  
   627      options = PipelineOptions([])
   628      options.view_as(StandardOptions).streaming = True
   629      with TestPipeline(options=options) as p:
   630        pcoll = (
   631            p
   632            | ReadFromPubSub(
   633                'projects/fakeprj/topics/a_topic',
   634                None,
   635                None,
   636                with_attributes=True,
   637                timestamp_attribute='time'))
   638        assert_that(pcoll, equal_to(expected_elements), reify_windows=True)
   639      mock_pubsub.return_value.acknowledge.assert_has_calls(
   640          [mock.call(subscription=mock.ANY, ack_ids=[ack_id])])
   641  
   642      mock_pubsub.return_value.close.assert_has_calls([mock.call()])
   643  
   644    def test_read_messages_timestamp_attribute_missing(self, mock_pubsub):
   645      data = b'data'
   646      attributes = {}
   647      publish_time_secs = 1520861821
   648      publish_time_nanos = 234567000
   649      publish_time = '2018-03-12T13:37:01.234567Z'
   650      ack_id = 'ack_id'
   651      pull_response = test_utils.create_pull_response([
   652          test_utils.PullResponseMessage(
   653              data, attributes, publish_time_secs, publish_time_nanos, ack_id)
   654      ])
   655      expected_elements = [
   656          TestWindowedValue(
   657              PubsubMessage(data, attributes),
   658              timestamp.Timestamp.from_rfc3339(publish_time),
   659              [window.GlobalWindow()]),
   660      ]
   661      mock_pubsub.return_value.pull.return_value = pull_response
   662  
   663      options = PipelineOptions([])
   664      options.view_as(StandardOptions).streaming = True
   665      with TestPipeline(options=options) as p:
   666        pcoll = (
   667            p
   668            | ReadFromPubSub(
   669                'projects/fakeprj/topics/a_topic',
   670                None,
   671                None,
   672                with_attributes=True,
   673                timestamp_attribute='nonexistent'))
   674        assert_that(pcoll, equal_to(expected_elements), reify_windows=True)
   675      mock_pubsub.return_value.acknowledge.assert_has_calls(
   676          [mock.call(subscription=mock.ANY, ack_ids=[ack_id])])
   677  
   678      mock_pubsub.return_value.close.assert_has_calls([mock.call()])
   679  
   680    def test_read_messages_timestamp_attribute_fail_parse(self, mock_pubsub):
   681      data = b'data'
   682      attributes = {'time': '1337 unparseable'}
   683      publish_time_secs = 1520861821
   684      publish_time_nanos = 234567000
   685      ack_id = 'ack_id'
   686      pull_response = test_utils.create_pull_response([
   687          test_utils.PullResponseMessage(
   688              data, attributes, publish_time_secs, publish_time_nanos, ack_id)
   689      ])
   690      mock_pubsub.return_value.pull.return_value = pull_response
   691  
   692      options = PipelineOptions([])
   693      options.view_as(StandardOptions).streaming = True
   694      p = TestPipeline(options=options)
   695      _ = (
   696          p
   697          | ReadFromPubSub(
   698              'projects/fakeprj/topics/a_topic',
   699              None,
   700              None,
   701              with_attributes=True,
   702              timestamp_attribute='time'))
   703      with self.assertRaisesRegex(ValueError, r'parse'):
   704        p.run()
   705      mock_pubsub.return_value.acknowledge.assert_not_called()
   706  
   707      mock_pubsub.return_value.close.assert_has_calls([mock.call()])
   708  
   709    def test_read_message_id_label_unsupported(self, unused_mock_pubsub):
   710      # id_label is unsupported in DirectRunner.
   711      options = PipelineOptions([])
   712      options.view_as(StandardOptions).streaming = True
   713      with self.assertRaisesRegex(NotImplementedError,
   714                                  r'id_label is not supported'):
   715        with TestPipeline(options=options) as p:
   716          _ = (
   717              p | ReadFromPubSub(
   718                  'projects/fakeprj/topics/a_topic', None, 'a_label'))
   719  
   720    def test_runner_api_transformation_with_topic(self, unused_mock_pubsub):
   721      source = _PubSubSource(
   722          topic='projects/fakeprj/topics/a_topic',
   723          subscription=None,
   724          id_label='a_label',
   725          timestamp_attribute='b_label',
   726          with_attributes=True)
   727      transform = Read(source)
   728  
   729      context = pipeline_context.PipelineContext()
   730      proto_transform_spec = transform.to_runner_api(context)
   731      self.assertEqual(
   732          common_urns.composites.PUBSUB_READ.urn, proto_transform_spec.urn)
   733  
   734      pubsub_read_payload = (
   735          proto_utils.parse_Bytes(
   736              proto_transform_spec.payload,
   737              beam_runner_api_pb2.PubSubReadPayload))
   738      self.assertEqual(
   739          'projects/fakeprj/topics/a_topic', pubsub_read_payload.topic)
   740      self.assertEqual('a_label', pubsub_read_payload.id_attribute)
   741      self.assertEqual('b_label', pubsub_read_payload.timestamp_attribute)
   742      self.assertEqual('', pubsub_read_payload.subscription)
   743      self.assertTrue(pubsub_read_payload.with_attributes)
   744  
   745      proto_transform = beam_runner_api_pb2.PTransform(
   746          unique_name="dummy_label", spec=proto_transform_spec)
   747  
   748      transform_from_proto = Read.from_runner_api_parameter(
   749          proto_transform, pubsub_read_payload, None)
   750      self.assertTrue(isinstance(transform_from_proto, Read))
   751      self.assertTrue(isinstance(transform_from_proto.source, _PubSubSource))
   752      self.assertEqual(
   753          'projects/fakeprj/topics/a_topic',
   754          transform_from_proto.source.full_topic)
   755      self.assertTrue(transform_from_proto.source.with_attributes)
   756  
   757    def test_runner_api_transformation_properties_none(self, unused_mock_pubsub):
   758      # Confirming that properties stay None after a runner API transformation.
   759      source = _PubSubSource(
   760          topic='projects/fakeprj/topics/a_topic', with_attributes=True)
   761      transform = Read(source)
   762  
   763      context = pipeline_context.PipelineContext()
   764      proto_transform_spec = transform.to_runner_api(context)
   765      self.assertEqual(
   766          common_urns.composites.PUBSUB_READ.urn, proto_transform_spec.urn)
   767  
   768      pubsub_read_payload = (
   769          proto_utils.parse_Bytes(
   770              proto_transform_spec.payload,
   771              beam_runner_api_pb2.PubSubReadPayload))
   772  
   773      proto_transform = beam_runner_api_pb2.PTransform(
   774          unique_name="dummy_label", spec=proto_transform_spec)
   775  
   776      transform_from_proto = Read.from_runner_api_parameter(
   777          proto_transform, pubsub_read_payload, None)
   778      self.assertIsNone(transform_from_proto.source.full_subscription)
   779      self.assertIsNone(transform_from_proto.source.id_label)
   780      self.assertIsNone(transform_from_proto.source.timestamp_attribute)
   781  
   782    def test_runner_api_transformation_with_subscription(
   783        self, unused_mock_pubsub):
   784      source = _PubSubSource(
   785          topic=None,
   786          subscription='projects/fakeprj/subscriptions/a_subscription',
   787          id_label='a_label',
   788          timestamp_attribute='b_label',
   789          with_attributes=True)
   790      transform = Read(source)
   791  
   792      context = pipeline_context.PipelineContext()
   793      proto_transform_spec = transform.to_runner_api(context)
   794      self.assertEqual(
   795          common_urns.composites.PUBSUB_READ.urn, proto_transform_spec.urn)
   796  
   797      pubsub_read_payload = (
   798          proto_utils.parse_Bytes(
   799              proto_transform_spec.payload,
   800              beam_runner_api_pb2.PubSubReadPayload))
   801      self.assertEqual(
   802          'projects/fakeprj/subscriptions/a_subscription',
   803          pubsub_read_payload.subscription)
   804      self.assertEqual('a_label', pubsub_read_payload.id_attribute)
   805      self.assertEqual('b_label', pubsub_read_payload.timestamp_attribute)
   806      self.assertEqual('', pubsub_read_payload.topic)
   807      self.assertTrue(pubsub_read_payload.with_attributes)
   808  
   809      proto_transform = beam_runner_api_pb2.PTransform(
   810          unique_name="dummy_label", spec=proto_transform_spec)
   811  
   812      transform_from_proto = Read.from_runner_api_parameter(
   813          proto_transform, pubsub_read_payload, None)
   814      self.assertTrue(isinstance(transform_from_proto, Read))
   815      self.assertTrue(isinstance(transform_from_proto.source, _PubSubSource))
   816      self.assertTrue(transform_from_proto.source.with_attributes)
   817      self.assertEqual(
   818          'projects/fakeprj/subscriptions/a_subscription',
   819          transform_from_proto.source.full_subscription)
   820  
   821  
   822  @unittest.skipIf(pubsub is None, 'GCP dependencies are not installed')
   823  @mock.patch('google.cloud.pubsub.PublisherClient')
   824  class TestWriteToPubSub(unittest.TestCase):
   825    def test_write_messages_success(self, mock_pubsub):
   826      data = 'data'
   827      payloads = [data]
   828  
   829      options = PipelineOptions([])
   830      options.view_as(StandardOptions).streaming = True
   831      with TestPipeline(options=options) as p:
   832        _ = (
   833            p
   834            | Create(payloads)
   835            | WriteToPubSub(
   836                'projects/fakeprj/topics/a_topic', with_attributes=False))
   837      mock_pubsub.return_value.publish.assert_has_calls(
   838          [mock.call(mock.ANY, data)])
   839  
   840    def test_write_messages_deprecated(self, mock_pubsub):
   841      data = 'data'
   842      data_bytes = b'data'
   843      payloads = [data]
   844  
   845      options = PipelineOptions([])
   846      options.view_as(StandardOptions).streaming = True
   847      with TestPipeline(options=options) as p:
   848        _ = (
   849            p
   850            | Create(payloads)
   851            | WriteStringsToPubSub('projects/fakeprj/topics/a_topic'))
   852      mock_pubsub.return_value.publish.assert_has_calls(
   853          [mock.call(mock.ANY, data_bytes)])
   854  
   855    def test_write_messages_with_attributes_success(self, mock_pubsub):
   856      data = b'data'
   857      attributes = {'key': 'value'}
   858      payloads = [PubsubMessage(data, attributes)]
   859  
   860      options = PipelineOptions([])
   861      options.view_as(StandardOptions).streaming = True
   862      with TestPipeline(options=options) as p:
   863        _ = (
   864            p
   865            | Create(payloads)
   866            | WriteToPubSub(
   867                'projects/fakeprj/topics/a_topic', with_attributes=True))
   868      mock_pubsub.return_value.publish.assert_has_calls(
   869          [mock.call(mock.ANY, data, **attributes)])
   870  
   871    def test_write_messages_with_attributes_error(self, mock_pubsub):
   872      data = 'data'
   873      # Sending raw data when WriteToPubSub expects a PubsubMessage object.
   874      payloads = [data]
   875  
   876      options = PipelineOptions([])
   877      options.view_as(StandardOptions).streaming = True
   878      with self.assertRaisesRegex(AttributeError, r'str.*has no attribute.*data'):
   879        with TestPipeline(options=options) as p:
   880          _ = (
   881              p
   882              | Create(payloads)
   883              | WriteToPubSub(
   884                  'projects/fakeprj/topics/a_topic', with_attributes=True))
   885  
   886    def test_write_messages_unsupported_features(self, mock_pubsub):
   887      data = b'data'
   888      attributes = {'key': 'value'}
   889      payloads = [PubsubMessage(data, attributes)]
   890  
   891      options = PipelineOptions([])
   892      options.view_as(StandardOptions).streaming = True
   893      with self.assertRaisesRegex(NotImplementedError,
   894                                  r'id_label is not supported'):
   895        with TestPipeline(options=options) as p:
   896          _ = (
   897              p
   898              | Create(payloads)
   899              | WriteToPubSub(
   900                  'projects/fakeprj/topics/a_topic', id_label='a_label'))
   901  
   902      options = PipelineOptions([])
   903      options.view_as(StandardOptions).streaming = True
   904      with self.assertRaisesRegex(NotImplementedError,
   905                                  r'timestamp_attribute is not supported'):
   906        with TestPipeline(options=options) as p:
   907          _ = (
   908              p
   909              | Create(payloads)
   910              | WriteToPubSub(
   911                  'projects/fakeprj/topics/a_topic',
   912                  timestamp_attribute='timestamp'))
   913  
   914    def test_runner_api_transformation(self, unused_mock_pubsub):
   915      sink = _PubSubSink(
   916          topic='projects/fakeprj/topics/a_topic',
   917          id_label=None,
   918          # We expect encoded PubSub write transform to always return attributes.
   919          timestamp_attribute=None)
   920      transform = Write(sink)
   921  
   922      context = pipeline_context.PipelineContext()
   923      proto_transform_spec = transform.to_runner_api(context)
   924      self.assertEqual(
   925          common_urns.composites.PUBSUB_WRITE.urn, proto_transform_spec.urn)
   926  
   927      pubsub_write_payload = (
   928          proto_utils.parse_Bytes(
   929              proto_transform_spec.payload,
   930              beam_runner_api_pb2.PubSubWritePayload))
   931  
   932      self.assertEqual(
   933          'projects/fakeprj/topics/a_topic', pubsub_write_payload.topic)
   934  
   935      proto_transform = beam_runner_api_pb2.PTransform(
   936          unique_name="dummy_label", spec=proto_transform_spec)
   937  
   938      transform_from_proto = Write.from_runner_api_parameter(
   939          proto_transform, pubsub_write_payload, None)
   940      self.assertTrue(isinstance(transform_from_proto, Write))
   941      self.assertTrue(isinstance(transform_from_proto.sink, _PubSubSink))
   942      self.assertEqual(
   943          'projects/fakeprj/topics/a_topic', transform_from_proto.sink.full_topic)
   944  
   945    def test_runner_api_transformation_properties_none(self, unused_mock_pubsub):
   946      # Confirming that properties stay None after a runner API transformation.
   947      sink = _PubSubSink(
   948          topic='projects/fakeprj/topics/a_topic',
   949          id_label=None,
   950          # We expect encoded PubSub write transform to always return attributes.
   951          timestamp_attribute=None)
   952      transform = Write(sink)
   953  
   954      context = pipeline_context.PipelineContext()
   955      proto_transform_spec = transform.to_runner_api(context)
   956      self.assertEqual(
   957          common_urns.composites.PUBSUB_WRITE.urn, proto_transform_spec.urn)
   958  
   959      pubsub_write_payload = (
   960          proto_utils.parse_Bytes(
   961              proto_transform_spec.payload,
   962              beam_runner_api_pb2.PubSubWritePayload))
   963      proto_transform = beam_runner_api_pb2.PTransform(
   964          unique_name="dummy_label", spec=proto_transform_spec)
   965      transform_from_proto = Write.from_runner_api_parameter(
   966          proto_transform, pubsub_write_payload, None)
   967  
   968      self.assertTrue(isinstance(transform_from_proto, Write))
   969      self.assertTrue(isinstance(transform_from_proto.sink, _PubSubSink))
   970      self.assertIsNone(transform_from_proto.sink.id_label)
   971      self.assertIsNone(transform_from_proto.sink.timestamp_attribute)
   972  
   973  
   974  if __name__ == '__main__':
   975    logging.getLogger().setLevel(logging.INFO)
   976    unittest.main()