github.com/apache/beam/sdks/v2@v2.48.2/python/apache_beam/runners/worker/data_sampler.py (about)

     1  #
     2  # Licensed to the Apache Software Foundation (ASF) under one or more
     3  # contributor license agreements.  See the NOTICE file distributed with
     4  # this work for additional information regarding copyright ownership.
     5  # The ASF licenses this file to You under the Apache License, Version 2.0
     6  # (the "License"); you may not use this file except in compliance with
     7  # the License.  You may obtain a copy of the License at
     8  #
     9  #    http://www.apache.org/licenses/LICENSE-2.0
    10  #
    11  # Unless required by applicable law or agreed to in writing, software
    12  # distributed under the License is distributed on an "AS IS" BASIS,
    13  # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    14  # See the License for the specific language governing permissions and
    15  # limitations under the License.
    16  #
    17  
    18  """Functionality for sampling elements during bundle execution."""
    19  
    20  # pytype: skip-file
    21  
    22  import collections
    23  import threading
    24  import time
    25  from typing import Any
    26  from typing import DefaultDict
    27  from typing import Deque
    28  from typing import Dict
    29  from typing import Iterable
    30  from typing import List
    31  from typing import Optional
    32  from typing import Union
    33  
    34  from apache_beam.coders.coder_impl import CoderImpl
    35  from apache_beam.coders.coder_impl import WindowedValueCoderImpl
    36  from apache_beam.coders.coders import Coder
    37  from apache_beam.utils.windowed_value import WindowedValue
    38  
    39  
    40  class OutputSampler:
    41    """Represents a way to sample an output of a PTransform.
    42  
    43    This is configurable to only keep max_samples (see constructor) sampled
    44    elements in memory. The first 10 elements are always sampled, then after each
    45    sample_every_sec (see constructor).
    46    """
    47    def __init__(
    48        self,
    49        coder: Coder,
    50        max_samples: int = 10,
    51        sample_every_sec: float = 30,
    52        clock=None) -> None:
    53      self._samples: Deque[Any] = collections.deque(maxlen=max_samples)
    54      self._coder_impl: CoderImpl = coder.get_impl()
    55      self._sample_count: int = 0
    56      self._sample_every_sec: float = sample_every_sec
    57      self._clock = clock
    58      self._last_sample_sec: float = self.time()
    59  
    60    def remove_windowed_value(self, el: Union[WindowedValue, Any]) -> Any:
    61      """Retrieves the value from the WindowedValue.
    62  
    63      The Python SDK passes elements as WindowedValues, which may not match the
    64      coder for that particular PCollection.
    65      """
    66      if isinstance(el, WindowedValue):
    67        return self.remove_windowed_value(el.value)
    68      return el
    69  
    70    def time(self) -> float:
    71      """Returns the current time. Used for mocking out the clock for testing."""
    72      return self._clock.time() if self._clock else time.time()
    73  
    74    def flush(self) -> List[bytes]:
    75      """Returns all samples and clears buffer."""
    76      if isinstance(self._coder_impl, WindowedValueCoderImpl):
    77        samples = [s for s in self._samples]
    78      else:
    79        samples = [self.remove_windowed_value(s) for s in self._samples]
    80  
    81      # Encode in the nested context b/c this ensures that the SDK can decode the
    82      # bytes with the ToStringFn.
    83      self._samples.clear()
    84      return [self._coder_impl.encode_nested(s) for s in samples]
    85  
    86    def sample(self, element: Any) -> None:
    87      """Samples the given element to an internal buffer.
    88  
    89      Samples are only taken for the first 10 elements then every
    90      `self._sample_every_sec` second after.
    91      """
    92      self._sample_count += 1
    93      now = self.time()
    94      sample_diff = now - self._last_sample_sec
    95  
    96      if self._sample_count <= 10 or sample_diff >= self._sample_every_sec:
    97        self._samples.append(element)
    98        self._last_sample_sec = now
    99  
   100  
   101  class DataSampler:
   102    """A class for querying any samples generated during execution.
   103  
   104    This class is meant to be a singleton with regard to a particular
   105    `sdk_worker.SdkHarness`. When creating the operators, individual
   106    `OutputSampler`s are created from `DataSampler.sample_output`. This allows for
   107    multi-threaded sampling of a PCollection across the SdkHarness.
   108  
   109    Samples generated during execution can then be sampled with the `samples`
   110    method. This filters samples from the given pcollection ids.
   111    """
   112    def __init__(
   113        self, max_samples: int = 10, sample_every_sec: float = 30) -> None:
   114      # Key is PCollection id. Is guarded by the _samplers_lock.
   115      self._samplers: Dict[str, OutputSampler] = {}
   116      # Bundles are processed in parallel, so new samplers may be added when the
   117      # runner queries for samples.
   118      self._samplers_lock: threading.Lock = threading.Lock()
   119      self._max_samples = max_samples
   120      self._sample_every_sec = sample_every_sec
   121  
   122    def sample_output(self, pcoll_id: str, coder: Coder) -> OutputSampler:
   123      """Create or get an OutputSampler for a pcoll_id."""
   124      with self._samplers_lock:
   125        if pcoll_id in self._samplers:
   126          sampler = self._samplers[pcoll_id]
   127        else:
   128          sampler = OutputSampler(
   129              coder, self._max_samples, self._sample_every_sec)
   130          self._samplers[pcoll_id] = sampler
   131        return sampler
   132  
   133    def samples(
   134        self,
   135        pcollection_ids: Optional[Iterable[str]] = None
   136    ) -> Dict[str, List[bytes]]:
   137      """Returns samples filtered PCollection ids.
   138  
   139      All samples from the given PCollections are returned. Empty lists are
   140      wildcards.
   141      """
   142      ret: DefaultDict[str, List[bytes]] = collections.defaultdict(lambda: [])
   143  
   144      with self._samplers_lock:
   145        samplers = self._samplers.copy()
   146  
   147      for pcoll_id in samplers:
   148        if pcollection_ids and pcoll_id not in pcollection_ids:
   149          continue
   150  
   151        samples = samplers[pcoll_id].flush()
   152        if samples:
   153          ret[pcoll_id].extend(samples)
   154  
   155      return dict(ret)