github.com/apache/beam/sdks/v2@v2.48.2/python/apache_beam/utils/shared.py (about)

     1  #
     2  # Licensed to the Apache Software Foundation (ASF) under one or more
     3  # contributor license agreements.  See the NOTICE file distributed with
     4  # this work for additional information regarding copyright ownership.
     5  # The ASF licenses this file to You under the Apache License, Version 2.0
     6  # (the "License"); you may not use this file except in compliance with
     7  # the License.  You may obtain a copy of the License at
     8  #
     9  #    http://www.apache.org/licenses/LICENSE-2.0
    10  #
    11  # Unless required by applicable law or agreed to in writing, software
    12  # distributed under the License is distributed on an "AS IS" BASIS,
    13  # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    14  # See the License for the specific language governing permissions and
    15  # limitations under the License.
    16  #
    17  
    18  """Shared class.
    19  
    20  Shared is a helper class for managing a single instance of an object
    21  shared by multiple threads within the same process. Instances of Shared
    22  are serializable objects that can be shared by all threads of each worker
    23  process. A Shared object encapsulates a weak reference to a singleton
    24  instance of the shared resource. The singleton is lazily initialized by
    25  calls to Shared.acquire().
    26  
    27  Example usage:
    28  
    29  To share a very large list across all threads of each worker in a DoFn::
    30  
    31    # Several built-in types such as list and dict do not directly support weak
    32    # references but can add support through subclassing:
    33    # https://docs.python.org/3/library/weakref.html
    34    class WeakRefList(list):
    35      pass
    36  
    37    class GetNthStringFn(beam.DoFn):
    38      def __init__(self):
    39        self._shared_handle = shared.Shared()
    40  
    41      def setup(self):
    42        # setup is a good place to initialize transient in-memory resources.
    43        def initialize_list():
    44          # Build the giant initial list.
    45          return WeakRefList([str(i) for i in range(1000000)])
    46  
    47        self._giant_list = self._shared_handle.acquire(initialize_list)
    48  
    49      def process(self, element):
    50        yield self._giant_list[element]
    51  
    52    p = beam.Pipeline()
    53    (p | beam.Create([2, 4, 6, 8])
    54       | beam.ParDo(GetNthStringFn()))
    55  
    56  
    57  Real-world uses will typically involve using a side-input to a DoFn to
    58  initialize the shared resource in a way that can't be done with just its
    59  constructor::
    60  
    61    class RainbowTableLookupFn(beam.DoFn):
    62      def __init__(self):
    63        self._shared_handle = shared.Shared()
    64  
    65      def process(self, element, table_elements):
    66        def construct_table():
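        # Construct the rainbow table from the table elements.
        # The table contains lines in the form "string::hash"
        # Note: a plain dict does not support weak references, so the table
        # must be a dict subclass, e.g. ``class WeakRefDict(dict): pass``.
        result = WeakRefDict()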
    70          for key, value in table_elements:
    71            result[value] = key
    72          return result
    73  
    74        rainbow_table = self._shared_handle.acquire(construct_table)
    75        unhashed_str = rainbow_table.get(element)
    76        if unhashed_str is not None:
    77          yield unhashed_str
    78  
    79    p = beam.Pipeline()
    80    reverse_hash_table = p | "ReverseHashTable" >> beam.Create([
    81                    ('a', '0cc175b9c0f1b6a831c399e269772661'),
    82                    ('b', '92eb5ffee6ae2fec3ad71c777531578f'),
    83                    ('c', '4a8a08f09d37b73795649038408b5f33'),
    84                    ('d', '8277e0910d750195b448797616e091ad')])
    85    unhashed = (p
    86                | 'Hashes' >> beam.Create([
    87                    '0cc175b9c0f1b6a831c399e269772661',
    88                    '8277e0910d750195b448797616e091ad'])
    89                | 'Unhash' >> beam.ParDo(
    90                     RainbowTableLookupFn(), reverse_hash_table))
    91  
    92  """
    93  import threading
    94  import uuid
    95  import weakref
    96  from typing import Any
    97  from typing import Callable
    98  from typing import Text
    99  
   100  
   101  class _SharedControlBlock(object):
   102    """Wrapper class for holding objects in the SharedMap.
   103  
   104    We need this so we can call constructors for distinct Shared elements in the
   105    SharedMap concurrently.
   106    """
   107    def __init__(self):
   108      self._lock = threading.Lock()
   109      self._ref = None
   110      self._tag = None
   111  
   112    def acquire(
   113        self,
   114        constructor_fn,  # type: Callable[[], Any]
   115        tag=None  # type: Any
   116    ):
   117      # type: (...) -> Any
   118  
   119      """Acquire a reference to the object this shared control block manages.
   120  
   121      Args:
   122        constructor_fn: function that initialises / constructs the object if not
   123          present in the cache. This function should take no arguments. It should
   124          return an initialised object, or None if the object could not be
   125          initialised / constructed.
   126        tag: an optional indentifier to store with the cached object. If
   127          subsequent calls to acquire use different tags, the object will be
   128          reloaded rather than returned from cache.
   129  
   130      Returns:
   131        An initialised object, either from a previous initialisation, or
   132        newly-constructed.
   133      """
   134      with self._lock:
   135        # self._ref is None if this is a new control block.
   136        # self._ref() is None if the weak reference was GCed.
   137        # self._tag != tag if user specifies a new identifier
   138        if self._ref is None or self._ref() is None or self._tag != tag:
   139          result = constructor_fn()
   140          if result is None:
   141            return None
   142          self._ref = weakref.ref(result)
   143          self._tag = tag
   144        else:
   145          result = self._ref()
   146      return result
   147  
   148  
   149  class _SharedMap(object):
   150    """Map for storing objects pointed to by Shared.
   151  
   152    The behaviour of SharedMap is as follows: when acquire is called, if the
   153    Shared object has already been initialised, we return the already-initialised
   154    copy. If not, we call the constructor_fn to construct it, and store it in
   155    the cache.
   156  
   157    One big caveat is this: we want to support cases where there is some delay
   158    between reacquistion of Shared objects, i.e. there may be a short period of
   159    time in which there are no references to the object before it is reacquired.
   160  
   161    This happens in various Beam runners (e.g. Dataflow runner): if we use a
   162    single thread for doing predictions with a large model, when the thread
   163    finishes its workitem, it will release the reference to the model. Since
   164    there's only a single thread, the model will have zero references to it
   165    and will be garbage collected. Shortly after this, the process receives a new
   166    workitem, creates a new thread, and attempts to reacquire the model. If we
   167    don't keep the model alive in between, the new thread will have to
   168    reinitialise the model from scratch.
   169  
   170    As such, we need to do some extra work to manage cached objects' lifetime.
   171    Ideally we would want to release the shared objects once the stage is
   172    complete, but we don't have information about that. As such, we work around
   173    this limitation as follows: when an object is first initialised, we create and
   174    maintain an explicit reference to it. This means that it will always have one
   175    reference to it from within _SharedMap.
   176  
   177    When acquire is called for a *different* object, we delete explicit references
   178    to *all other objects*. This means that if there are no external references to
   179    these objects, they will be garbage collected.
   180  
   181    This has the following implications:
   182    *  A shared object won't be GC'ed if there isn't another acquire called for
   183       a different shared object. This is okay for our use-cases. This means
   184       that the shared object will be kept alive for all stages fused with the
   185       stage that works with the shared object. However, all these stages would
   186       be allocated the same memory anyway, even if the shared object
   187       were released after the stage that uses it was done with it.
   188    *  Each stage can only use exactly one Shared token, otherwise only one
   189       Shared token, *NOT NECESSARILY THE LATEST*, will be "kept-alive" (using
   190       multiple shared tokens per-stage won't affect correctness, but will have
   191       no performance benefit either)
   192    *  If there are two different stages using separate Shared tokens, but which
   193       get fused together, only one Shared token will be "kept-alive". This
   194       effectively means that the Shared tokens do nothing: since S2 displaces S1,
   195       and after S2 executes a new thread is created starting with S1 again, which
   196       displaces S2.
   197  
   198    Related issues:
   199      BEAM-562 - DoFn reuse
   200    """
   201    def __init__(self):
   202      # Lock that protects cache_map
   203      self._lock = threading.Lock()
   204  
   205      # Dictionary of references to shared control blocks
   206      self._cache_map = {}
   207  
   208      # Tuple of (key, obj), where obj is an object we explicitly hold a reference
   209      # to keep it alive
   210      self._keepalive = (None, None)
   211  
   212    def make_key(self):
   213      # type: (...) -> Text
   214      return str(uuid.uuid1())
   215  
   216    def acquire(
   217        self,
   218        key,  # type: Text
   219        constructor_fn,  # type: Callable[[], Any]
   220        tag=None  # type: Any
   221    ):
   222      # type: (...) -> Any
   223  
   224      """Acquire a reference to a Shared object.
   225  
   226      Args:
   227        key: the key to the shared object
   228        constructor_fn: function that initialises / constructs the object if not
   229          present in the cache. This function should take no arguments. It should
   230          return an initialised object, or None if the object could not be
   231          initialised / constructed.
   232        tag: an optional indentifier to store with the cached object. If
   233          subsequent calls to acquire use different tags, the object will be
   234          reloaded rather than returned from cache.
   235  
   236      Returns:
   237        A reference to the initialised object, either from the cache, or
   238        newly-constructed.
   239      """
   240      with self._lock:
   241        control_block = self._cache_map.get(key)
   242        if control_block is None:
   243          control_block = _SharedControlBlock()
   244          self._cache_map[key] = control_block
   245  
   246      result = control_block.acquire(constructor_fn, tag)
   247  
   248      # Because we release the lock in between, if we acquire multiple Shareds
   249      # in a short time, there's no guarantee as to which one will be kept alive.
   250      with self._lock:
   251        self._keepalive = (key, result)
   252  
   253      return result
   254  
   255  
# Module-level _SharedMap instance used by every Shared handle defined below;
# it is created once, when this module is imported.
_shared_map = _SharedMap()
   258  
   259  
   260  class Shared(object):
   261    """Handle for managing shared per-process objects.
   262  
   263    Each instance of a Shared object represents a distinct handle to a distinct
   264    object. Example usage is described in the file comment of shared.py.
   265  
   266    This object has the following limitations:
   267    *  A shared object won't be GC'ed if there isn't another acquire called for
   268    a different shared object.
   269    *  Each stage can only use exactly one Shared token, otherwise only one
   270    Shared token, *NOT NECESSARILY THE LATEST*, will be "kept-alive".
   271    *  If there are two different stages using separate Shared tokens, but which
   272    get fused together, only one Shared token will be "kept-alive".
   273  
   274    (See documentation of _SharedMap for details.)
   275    """
   276  
   277    # TODO(altay): Consider allowing users to also pass in a key (GUID)
   278    # for more easily sharing of identifiable expensive objects. User would be
   279    # responsible for handling collisions.
   280    def __init__(self):
   281      self._key = _shared_map.make_key()
   282  
   283    def acquire(
   284        self,
   285        constructor_fn,  # type: Callable[[], Any]
   286        tag=None  # type: Any
   287    ):
   288      # type: (...) -> Any
   289  
   290      """Acquire a reference to the object associated with this Shared handle.
   291  
   292      Args:
   293        constructor_fn: function that initialises / constructs the object if not
   294          present in the cache. This function should take no arguments. It should
   295          return an initialised object, or None if the object could not be
   296          initialised / constructed.
   297        tag: an optional indentifier to store with the cached object. If
   298          subsequent calls to acquire use different tags, the object will be
   299          reloaded rather than returned from cache.
   300  
   301      Returns:
   302        A reference to an initialised object, either from the cache, or
   303        newly-constructed.
   304      """
   305      return _shared_map.acquire(self._key, constructor_fn, tag)