github.com/apache/beam/sdks/v2@v2.48.2/python/apache_beam/options/value_provider.py (about)

     1  #
     2  # Licensed to the Apache Software Foundation (ASF) under one or more
     3  # contributor license agreements.  See the NOTICE file distributed with
     4  # this work for additional information regarding copyright ownership.
     5  # The ASF licenses this file to You under the Apache License, Version 2.0
     6  # (the "License"); you may not use this file except in compliance with
     7  # the License.  You may obtain a copy of the License at
     8  #
     9  #    http://www.apache.org/licenses/LICENSE-2.0
    10  #
    11  # Unless required by applicable law or agreed to in writing, software
    12  # distributed under the License is distributed on an "AS IS" BASIS,
    13  # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    14  # See the License for the specific language governing permissions and
    15  # limitations under the License.
    16  #
    17  
    18  """A ValueProvider abstracts the notion of fetching a value that may or
    19  may not be currently available.
    20  
    21  This can be used to parameterize transforms that only read values in at
    22  runtime, for example.
    23  """
    24  
    25  # pytype: skip-file
    26  
    27  from functools import wraps
    28  from typing import Set
    29  
    30  from apache_beam import error
    31  
    32  __all__ = [
    33      'ValueProvider',
    34      'StaticValueProvider',
    35      'RuntimeValueProvider',
    36      'NestedValueProvider',
    37      'check_accessible',
    38  ]
    39  
    40  
    41  class ValueProvider(object):
    42    """Base class that all other ValueProviders must implement.
    43    """
    44    def is_accessible(self):
    45      """Whether the contents of this ValueProvider is available to routines
    46      that run at graph construction time.
    47      """
    48      raise NotImplementedError(
    49          'ValueProvider.is_accessible implemented in derived classes')
    50  
    51    def get(self):
    52      """Return the value wrapped by this ValueProvider.
    53      """
    54      raise NotImplementedError(
    55          'ValueProvider.get implemented in derived classes')
    56  
    57  
    58  class StaticValueProvider(ValueProvider):
    59    """StaticValueProvider is an implementation of ValueProvider that allows
    60    for a static value to be provided.
    61    """
    62    def __init__(self, value_type, value):
    63      """
    64      Args:
    65          value_type: Type of the static value
    66          value: Static value
    67      """
    68      self.value_type = value_type
    69      self.value = value_type(value)
    70  
    71    def is_accessible(self):
    72      return True
    73  
    74    def get(self):
    75      return self.value
    76  
    77    def __str__(self):
    78      return str(self.value)
    79  
    80    def __eq__(self, other):
    81      if self.value == other:
    82        return True
    83      if isinstance(other, StaticValueProvider):
    84        if (self.value_type == other.value_type and self.value == other.value):
    85          return True
    86      return False
    87  
    88    def __hash__(self):
    89      return hash((type(self), self.value_type, self.value))
    90  
    91  
    92  class RuntimeValueProvider(ValueProvider):
    93    """RuntimeValueProvider is an implementation of ValueProvider that
    94    allows for a value to be provided at execution time rather than
    95    at graph construction time.
    96    """
    97    runtime_options = None
    98    experiments = set()  # type: Set[str]
    99  
   100    def __init__(self, option_name, value_type, default_value):
   101      self.option_name = option_name
   102      self.default_value = default_value
   103      self.value_type = value_type
   104  
   105    def is_accessible(self):
   106      return RuntimeValueProvider.runtime_options is not None
   107  
   108    @classmethod
   109    def get_value(cls, option_name, value_type, default_value):
   110      if not RuntimeValueProvider.runtime_options:
   111        return default_value
   112  
   113      candidate = RuntimeValueProvider.runtime_options.get(option_name)
   114      if candidate:
   115        return value_type(candidate)
   116      else:
   117        return default_value
   118  
   119    def get(self):
   120      if RuntimeValueProvider.runtime_options is None:
   121        raise error.RuntimeValueProviderError(
   122            '%s.get() not called from a runtime context' % self)
   123  
   124      return RuntimeValueProvider.get_value(
   125          self.option_name, self.value_type, self.default_value)
   126  
   127    @classmethod
   128    def set_runtime_options(cls, pipeline_options):
   129      RuntimeValueProvider.runtime_options = pipeline_options
   130      RuntimeValueProvider.experiments = RuntimeValueProvider.get_value(
   131          'experiments', set, set())
   132  
   133    def __str__(self):
   134      return '%s(option: %s, type: %s, default_value: %s)' % (
   135          self.__class__.__name__,
   136          self.option_name,
   137          self.value_type.__name__,
   138          repr(self.default_value))
   139  
   140  
   141  class NestedValueProvider(ValueProvider):
   142    """NestedValueProvider is an implementation of ValueProvider that allows
   143    for wrapping another ValueProvider object.
   144    """
   145    def __init__(self, value, translator):
   146      """Creates a NestedValueProvider that wraps the provided ValueProvider.
   147  
   148      Args:
   149        value: ValueProvider object to wrap
   150        translator: function that is applied to the ValueProvider
   151      Raises:
   152        ``RuntimeValueProviderError``: if any of the provided objects are not
   153          accessible.
   154      """
   155      self.value = value
   156      self.translator = translator
   157  
   158    def is_accessible(self):
   159      return self.value.is_accessible()
   160  
   161    def get(self):
   162      try:
   163        return self.cached_value
   164      except AttributeError:
   165        self.cached_value = self.translator(self.value.get())
   166        return self.cached_value
   167  
   168    def __str__(self):
   169      return "%s(value: %s, translator: %s)" % (
   170          self.__class__.__name__,
   171          self.value,
   172          self.translator.__name__,
   173      )
   174  
   175  
   176  def check_accessible(value_provider_list):
   177    """A decorator that checks accessibility of a list of ValueProvider objects.
   178  
   179    Args:
   180      value_provider_list: list of ValueProvider objects
   181    Raises:
   182      ``RuntimeValueProviderError``: if any of the provided objects are not
   183        accessible.
   184    """
   185    assert isinstance(value_provider_list, list)
   186  
   187    def _check_accessible(fnc):
   188      @wraps(fnc)
   189      def _f(self, *args, **kwargs):
   190        for obj in [getattr(self, vp) for vp in value_provider_list]:
   191          if not obj.is_accessible():
   192            raise error.RuntimeValueProviderError('%s not accessible' % obj)
   193        return fnc(self, *args, **kwargs)
   194  
   195      return _f
   196  
   197    return _check_accessible