github.com/apache/beam/sdks/v2@v2.48.2/python/apache_beam/examples/snippets/transforms/elementwise/partition.py (about)

     1  # coding=utf-8
     2  #
     3  # Licensed to the Apache Software Foundation (ASF) under one or more
     4  # contributor license agreements.  See the NOTICE file distributed with
     5  # this work for additional information regarding copyright ownership.
     6  # The ASF licenses this file to You under the Apache License, Version 2.0
     7  # (the "License"); you may not use this file except in compliance with
     8  # the License.  You may obtain a copy of the License at
     9  #
    10  #    http://www.apache.org/licenses/LICENSE-2.0
    11  #
    12  # Unless required by applicable law or agreed to in writing, software
    13  # distributed under the License is distributed on an "AS IS" BASIS,
    14  # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    15  # See the License for the specific language governing permissions and
    16  # limitations under the License.
    17  #
    18  
    19  # pytype: skip-file
    20  
    21  
    22  def partition_function(test=None):
    23    # pylint: disable=line-too-long, expression-not-assigned
    24    # [START partition_function]
    25    import apache_beam as beam
    26  
    27    durations = ['annual', 'biennial', 'perennial']
    28  
    29    def by_duration(plant, num_partitions):
    30      return durations.index(plant['duration'])
    31  
    32    with beam.Pipeline() as pipeline:
    33      annuals, biennials, perennials = (
    34          pipeline
    35          | 'Gardening plants' >> beam.Create([
    36              {'icon': '🍓', 'name': 'Strawberry', 'duration': 'perennial'},
    37              {'icon': '🥕', 'name': 'Carrot', 'duration': 'biennial'},
    38              {'icon': '🍆', 'name': 'Eggplant', 'duration': 'perennial'},
    39              {'icon': '🍅', 'name': 'Tomato', 'duration': 'annual'},
    40              {'icon': '🥔', 'name': 'Potato', 'duration': 'perennial'},
    41          ])
    42          | 'Partition' >> beam.Partition(by_duration, len(durations))
    43      )
    44  
    45      annuals | 'Annuals' >> beam.Map(lambda x: print('annual: {}'.format(x)))
    46      biennials | 'Biennials' >> beam.Map(
    47          lambda x: print('biennial: {}'.format(x)))
    48      perennials | 'Perennials' >> beam.Map(
    49          lambda x: print('perennial: {}'.format(x)))
    50      # [END partition_function]
    51      # pylint: enable=line-too-long, expression-not-assigned
    52      if test:
    53        test(annuals, biennials, perennials)
    54  
    55  
    56  def partition_lambda(test=None):
    57    # pylint: disable=line-too-long, expression-not-assigned
    58    # [START partition_lambda]
    59    import apache_beam as beam
    60  
    61    durations = ['annual', 'biennial', 'perennial']
    62  
    63    with beam.Pipeline() as pipeline:
    64      annuals, biennials, perennials = (
    65          pipeline
    66          | 'Gardening plants' >> beam.Create([
    67              {'icon': '🍓', 'name': 'Strawberry', 'duration': 'perennial'},
    68              {'icon': '🥕', 'name': 'Carrot', 'duration': 'biennial'},
    69              {'icon': '🍆', 'name': 'Eggplant', 'duration': 'perennial'},
    70              {'icon': '🍅', 'name': 'Tomato', 'duration': 'annual'},
    71              {'icon': '🥔', 'name': 'Potato', 'duration': 'perennial'},
    72          ])
    73          | 'Partition' >> beam.Partition(
    74              lambda plant, num_partitions: durations.index(plant['duration']),
    75              len(durations),
    76          )
    77      )
    78  
    79      annuals | 'Annuals' >> beam.Map(lambda x: print('annual: {}'.format(x)))
    80      biennials | 'Biennials' >> beam.Map(
    81          lambda x: print('biennial: {}'.format(x)))
    82      perennials | 'Perennials' >> beam.Map(
    83          lambda x: print('perennial: {}'.format(x)))
    84      # [END partition_lambda]
    85      # pylint: enable=line-too-long, expression-not-assigned
    86      if test:
    87        test(annuals, biennials, perennials)
    88  
    89  
    90  def partition_multiple_arguments(test=None):
    91    # pylint: disable=expression-not-assigned
    92    # [START partition_multiple_arguments]
    93    import apache_beam as beam
    94    import json
    95  
    96    def split_dataset(plant, num_partitions, ratio):
    97      assert num_partitions == len(ratio)
    98      bucket = sum(map(ord, json.dumps(plant))) % sum(ratio)
    99      total = 0
   100      for i, part in enumerate(ratio):
   101        total += part
   102        if bucket < total:
   103          return i
   104      return len(ratio) - 1
   105  
   106    with beam.Pipeline() as pipeline:
   107      train_dataset, test_dataset = (
   108          pipeline
   109          | 'Gardening plants' >> beam.Create([
   110              {'icon': '🍓', 'name': 'Strawberry', 'duration': 'perennial'},
   111              {'icon': '🥕', 'name': 'Carrot', 'duration': 'biennial'},
   112              {'icon': '🍆', 'name': 'Eggplant', 'duration': 'perennial'},
   113              {'icon': '🍅', 'name': 'Tomato', 'duration': 'annual'},
   114              {'icon': '🥔', 'name': 'Potato', 'duration': 'perennial'},
   115          ])
   116          | 'Partition' >> beam.Partition(split_dataset, 2, ratio=[8, 2])
   117      )
   118  
   119      train_dataset | 'Train' >> beam.Map(lambda x: print('train: {}'.format(x)))
   120      test_dataset | 'Test' >> beam.Map(lambda x: print('test: {}'.format(x)))
   121      # [END partition_multiple_arguments]
   122      # pylint: enable=expression-not-assigned
   123      if test:
   124        test(train_dataset, test_dataset)