github.com/apache/beam/sdks/v2@v2.48.2/python/apache_beam/ml/gcp/recommendations_ai.py (about)

     1  #
     2  # Licensed to the Apache Software Foundation (ASF) under one or more
     3  # contributor license agreements.  See the NOTICE file distributed with
     4  # this work for additional information regarding copyright ownership.
     5  # The ASF licenses this file to You under the Apache License, Version 2.0
     6  # (the "License"); you may not use this file except in compliance with
     7  # the License.  You may obtain a copy of the License at
     8  #
     9  #    http://www.apache.org/licenses/LICENSE-2.0
    10  #
    11  # Unless required by applicable law or agreed to in writing, software
    12  # distributed under the License is distributed on an "AS IS" BASIS,
    13  # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    14  # See the License for the specific language governing permissions and
    15  # limitations under the License.
    16  #
    17  
    18  """A connector for sending API requests to the GCP Recommendations AI
    19  API (https://cloud.google.com/recommendations).
    20  """
    21  
    22  from __future__ import absolute_import
    23  
    24  from typing import Sequence
    25  from typing import Tuple
    26  
    27  from google.api_core.retry import Retry
    28  
    29  from apache_beam import pvalue
    30  from apache_beam.metrics import Metrics
    31  from apache_beam.options.pipeline_options import GoogleCloudOptions
    32  from apache_beam.transforms import DoFn
    33  from apache_beam.transforms import ParDo
    34  from apache_beam.transforms import PTransform
    35  from apache_beam.transforms.util import GroupIntoBatches
    36  from cachetools.func import ttl_cache
    37  
    38  # pylint: disable=wrong-import-order, wrong-import-position, ungrouped-imports
    39  try:
    40    from google.cloud import recommendationengine
    41  except ImportError:
    42    raise ImportError(
    43        'Google Cloud Recommendation AI not supported for this execution '
    44        'environment (could not import google.cloud.recommendationengine).')
    45  # pylint: enable=wrong-import-order, wrong-import-position, ungrouped-imports
    46  
    47  __all__ = [
    48      'CreateCatalogItem',
    49      'WriteUserEvent',
    50      'ImportCatalogItems',
    51      'ImportUserEvents',
    52      'PredictUserEvent'
    53  ]
    54  
    55  FAILED_CATALOG_ITEMS = "failed_catalog_items"
    56  
    57  
    58  @ttl_cache(maxsize=128, ttl=3600)
    59  def get_recommendation_prediction_client():
    60    """Returns a Recommendation AI - Prediction Service client."""
    61    _client = recommendationengine.PredictionServiceClient()
    62    return _client
    63  
    64  
    65  @ttl_cache(maxsize=128, ttl=3600)
    66  def get_recommendation_catalog_client():
    67    """Returns a Recommendation AI - Catalog Service client."""
    68    _client = recommendationengine.CatalogServiceClient()
    69    return _client
    70  
    71  
    72  @ttl_cache(maxsize=128, ttl=3600)
    73  def get_recommendation_user_event_client():
    74    """Returns a Recommendation AI - UserEvent Service client."""
    75    _client = recommendationengine.UserEventServiceClient()
    76    return _client
    77  
    78  
    79  class CreateCatalogItem(PTransform):
    80    """Creates catalogitem information.
    81      The ``PTransform`` returns a PCollectionTuple with a PCollections of
    82      successfully and failed created CatalogItems.
    83  
    84      Example usage::
    85  
    86        pipeline | CreateCatalogItem(
    87          project='example-gcp-project',
    88          catalog_name='my-catalog')
    89      """
    90    def __init__(
    91        self,
    92        project: str = None,
    93        retry: Retry = None,
    94        timeout: float = 120,
    95        metadata: Sequence[Tuple[str, str]] = (),
    96        catalog_name: str = "default_catalog"):
    97      """Initializes a :class:`CreateCatalogItem` transform.
    98  
    99          Args:
   100              project (str): Optional. GCP project name in which the catalog
   101                data will be imported.
   102              retry: Optional. Designation of what
   103                errors, if any, should be retried.
   104              timeout (float): Optional. The amount of time, in seconds, to wait
   105                for the request to complete.
   106              metadata: Optional. Strings which
   107                should be sent along with the request as metadata.
   108              catalog_name (str): Optional. Name of the catalog.
   109                Default: 'default_catalog'
   110          """
   111      self.project = project
   112      self.retry = retry
   113      self.timeout = timeout
   114      self.metadata = metadata
   115      self.catalog_name = catalog_name
   116  
   117    def expand(self, pcoll):
   118      if self.project is None:
   119        self.project = pcoll.pipeline.options.view_as(GoogleCloudOptions).project
   120      if self.project is None:
   121        raise ValueError(
   122            """GCP project name needs to be specified in "project" pipeline
   123              option""")
   124      return pcoll | ParDo(
   125          _CreateCatalogItemFn(
   126              self.project,
   127              self.retry,
   128              self.timeout,
   129              self.metadata,
   130              self.catalog_name))
   131  
   132  
   133  class _CreateCatalogItemFn(DoFn):
   134    def __init__(
   135        self,
   136        project: str = None,
   137        retry: Retry = None,
   138        timeout: float = 120,
   139        metadata: Sequence[Tuple[str, str]] = (),
   140        catalog_name: str = None):
   141      self._client = None
   142      self.retry = retry
   143      self.timeout = timeout
   144      self.metadata = metadata
   145      self.parent = f"projects/{project}/locations/global/catalogs/{catalog_name}"
   146      self.counter = Metrics.counter(self.__class__, "api_calls")
   147  
   148    def setup(self):
   149      if self._client is None:
   150        self._client = get_recommendation_catalog_client()
   151  
   152    def process(self, element):
   153      catalog_item = recommendationengine.CatalogItem(element)
   154      request = recommendationengine.CreateCatalogItemRequest(
   155          parent=self.parent, catalog_item=catalog_item)
   156  
   157      try:
   158        created_catalog_item = self._client.create_catalog_item(
   159            request=request,
   160            retry=self.retry,
   161            timeout=self.timeout,
   162            metadata=self.metadata)
   163  
   164        self.counter.inc()
   165        yield recommendationengine.CatalogItem.to_dict(created_catalog_item)
   166      except Exception:
   167        yield pvalue.TaggedOutput(
   168            FAILED_CATALOG_ITEMS,
   169            recommendationengine.CatalogItem.to_dict(catalog_item))
   170  
   171  
   172  class ImportCatalogItems(PTransform):
   173    """Imports catalogitems in bulk.
   174      The `PTransform` returns a PCollectionTuple with PCollections of
   175      successfully and failed imported CatalogItems.
   176  
   177      Example usage::
   178  
   179        pipeline
   180        | ImportCatalogItems(
   181            project='example-gcp-project',
   182            catalog_name='my-catalog')
   183      """
   184    def __init__(
   185        self,
   186        max_batch_size: int = 5000,
   187        project: str = None,
   188        retry: Retry = None,
   189        timeout: float = 120,
   190        metadata: Sequence[Tuple[str, str]] = (),
   191        catalog_name: str = "default_catalog"):
   192      """Initializes a :class:`ImportCatalogItems` transform
   193  
   194          Args:
   195              batch_size (int): Required. Maximum number of catalogitems per
   196                request.
   197              project (str): Optional. GCP project name in which the catalog
   198                data will be imported.
   199              retry: Optional. Designation of what
   200                errors, if any, should be retried.
   201              timeout (float): Optional. The amount of time, in seconds, to wait
   202                for the request to complete.
   203              metadata: Optional. Strings which
   204                should be sent along with the request as metadata.
   205              catalog_name (str): Optional. Name of the catalog.
   206                Default: 'default_catalog'
   207          """
   208      self.max_batch_size = max_batch_size
   209      self.project = project
   210      self.retry = retry
   211      self.timeout = timeout
   212      self.metadata = metadata
   213      self.catalog_name = catalog_name
   214  
   215    def expand(self, pcoll):
   216      if self.project is None:
   217        self.project = pcoll.pipeline.options.view_as(GoogleCloudOptions).project
   218      if self.project is None:
   219        raise ValueError(
   220            'GCP project name needs to be specified in "project" pipeline option')
   221      return (
   222          pcoll | GroupIntoBatches.WithShardedKey(self.max_batch_size) | ParDo(
   223              _ImportCatalogItemsFn(
   224                  self.project,
   225                  self.retry,
   226                  self.timeout,
   227                  self.metadata,
   228                  self.catalog_name)))
   229  
   230  
   231  class _ImportCatalogItemsFn(DoFn):
   232    def __init__(
   233        self,
   234        project=None,
   235        retry=None,
   236        timeout=120,
   237        metadata=None,
   238        catalog_name=None):
   239      self._client = None
   240      self.retry = retry
   241      self.timeout = timeout
   242      self.metadata = metadata
   243      self.parent = f"projects/{project}/locations/global/catalogs/{catalog_name}"
   244      self.counter = Metrics.counter(self.__class__, "api_calls")
   245  
   246    def setup(self):
   247      if self._client is None:
   248        self.client = get_recommendation_catalog_client()
   249  
   250    def process(self, element):
   251      catalog_items = [recommendationengine.CatalogItem(e) for e in element[1]]
   252      catalog_inline_source = recommendationengine.CatalogInlineSource(
   253          {"catalog_items": catalog_items})
   254      input_config = recommendationengine.InputConfig(
   255          catalog_inline_source=catalog_inline_source)
   256  
   257      request = recommendationengine.ImportCatalogItemsRequest(
   258          parent=self.parent, input_config=input_config)
   259  
   260      try:
   261        operation = self._client.import_catalog_items(
   262            request=request,
   263            retry=self.retry,
   264            timeout=self.timeout,
   265            metadata=self.metadata)
   266        self.counter.inc(len(catalog_items))
   267        yield operation.result()
   268      except Exception:
   269        yield pvalue.TaggedOutput(FAILED_CATALOG_ITEMS, catalog_items)
   270  
   271  
   272  class WriteUserEvent(PTransform):
   273    """Write user event information.
   274      The `PTransform` returns a PCollectionTuple with PCollections of
   275      successfully and failed written UserEvents.
   276  
   277      Example usage::
   278  
   279        pipeline
   280        | WriteUserEvent(
   281            project='example-gcp-project',
   282            catalog_name='my-catalog',
   283            event_store='my_event_store')
   284      """
   285    def __init__(
   286        self,
   287        project: str = None,
   288        retry: Retry = None,
   289        timeout: float = 120,
   290        metadata: Sequence[Tuple[str, str]] = (),
   291        catalog_name: str = "default_catalog",
   292        event_store: str = "default_event_store"):
   293      """Initializes a :class:`WriteUserEvent` transform.
   294  
   295          Args:
   296              project (str): Optional. GCP project name in which the catalog
   297                data will be imported.
   298              retry: Optional. Designation of what
   299                errors, if any, should be retried.
   300              timeout (float): Optional. The amount of time, in seconds, to wait
   301                for the request to complete.
   302              metadata: Optional. Strings which
   303                should be sent along with the request as metadata.
   304              catalog_name (str): Optional. Name of the catalog.
   305                Default: 'default_catalog'
   306              event_store (str): Optional. Name of the event store.
   307                Default: 'default_event_store'
   308          """
   309      self.project = project
   310      self.retry = retry
   311      self.timeout = timeout
   312      self.metadata = metadata
   313      self.catalog_name = catalog_name
   314      self.event_store = event_store
   315  
   316    def expand(self, pcoll):
   317      if self.project is None:
   318        self.project = pcoll.pipeline.options.view_as(GoogleCloudOptions).project
   319      if self.project is None:
   320        raise ValueError(
   321            'GCP project name needs to be specified in "project" pipeline option')
   322      return pcoll | ParDo(
   323          _WriteUserEventFn(
   324              self.project,
   325              self.retry,
   326              self.timeout,
   327              self.metadata,
   328              self.catalog_name,
   329              self.event_store))
   330  
   331  
   332  class _WriteUserEventFn(DoFn):
   333    FAILED_USER_EVENTS = "failed_user_events"
   334  
   335    def __init__(
   336        self,
   337        project=None,
   338        retry=None,
   339        timeout=120,
   340        metadata=None,
   341        catalog_name=None,
   342        event_store=None):
   343      self._client = None
   344      self.retry = retry
   345      self.timeout = timeout
   346      self.metadata = metadata
   347      self.parent = f"projects/{project}/locations/global/catalogs/"\
   348                    f"{catalog_name}/eventStores/{event_store}"
   349      self.counter = Metrics.counter(self.__class__, "api_calls")
   350  
   351    def setup(self):
   352      if self._client is None:
   353        self._client = get_recommendation_user_event_client()
   354  
   355    def process(self, element):
   356      user_event = recommendationengine.UserEvent(element)
   357      request = recommendationengine.WriteUserEventRequest(
   358          parent=self.parent, user_event=user_event)
   359  
   360      try:
   361        created_user_event = self._client.write_user_event(request)
   362        self.counter.inc()
   363        yield recommendationengine.UserEvent.to_dict(created_user_event)
   364      except Exception:
   365        yield pvalue.TaggedOutput(
   366            self.FAILED_USER_EVENTS,
   367            recommendationengine.UserEvent.to_dict(user_event))
   368  
   369  
   370  class ImportUserEvents(PTransform):
   371    """Imports userevents in bulk.
   372      The `PTransform` returns a PCollectionTuple with PCollections of
   373      successfully and failed imported UserEvents.
   374  
   375      Example usage::
   376  
   377        pipeline
   378        | ImportUserEvents(
   379            project='example-gcp-project',
   380            catalog_name='my-catalog',
   381            event_store='my_event_store')
   382      """
   383    def __init__(
   384        self,
   385        max_batch_size: int = 5000,
   386        project: str = None,
   387        retry: Retry = None,
   388        timeout: float = 120,
   389        metadata: Sequence[Tuple[str, str]] = (),
   390        catalog_name: str = "default_catalog",
   391        event_store: str = "default_event_store"):
   392      """Initializes a :class:`WriteUserEvent` transform.
   393  
   394          Args:
   395              batch_size (int): Required. Maximum number of catalogitems
   396                per request.
   397              project (str): Optional. GCP project name in which the catalog
   398                data will be imported.
   399              retry: Optional. Designation of what
   400                errors, if any, should be retried.
   401              timeout (float): Optional. The amount of time, in seconds, to wait
   402                for the request to complete.
   403              metadata: Optional. Strings which
   404                should be sent along with the request as metadata.
   405              catalog_name (str): Optional. Name of the catalog.
   406                Default: 'default_catalog'
   407              event_store (str): Optional. Name of the event store.
   408                Default: 'default_event_store'
   409          """
   410      self.max_batch_size = max_batch_size
   411      self.project = project
   412      self.retry = retry
   413      self.timeout = timeout
   414      self.metadata = metadata
   415      self.catalog_name = catalog_name
   416      self.event_store = event_store
   417  
   418    def expand(self, pcoll):
   419      if self.project is None:
   420        self.project = pcoll.pipeline.options.view_as(GoogleCloudOptions).project
   421      if self.project is None:
   422        raise ValueError(
   423            'GCP project name needs to be specified in "project" pipeline option')
   424      return (
   425          pcoll | GroupIntoBatches.WithShardedKey(self.max_batch_size) | ParDo(
   426              _ImportUserEventsFn(
   427                  self.project,
   428                  self.retry,
   429                  self.timeout,
   430                  self.metadata,
   431                  self.catalog_name,
   432                  self.event_store)))
   433  
   434  
   435  class _ImportUserEventsFn(DoFn):
   436    FAILED_USER_EVENTS = "failed_user_events"
   437  
   438    def __init__(
   439        self,
   440        project=None,
   441        retry=None,
   442        timeout=120,
   443        metadata=None,
   444        catalog_name=None,
   445        event_store=None):
   446      self._client = None
   447      self.retry = retry
   448      self.timeout = timeout
   449      self.metadata = metadata
   450      self.parent = f"projects/{project}/locations/global/catalogs/"\
   451                    f"{catalog_name}/eventStores/{event_store}"
   452      self.counter = Metrics.counter(self.__class__, "api_calls")
   453  
   454    def setup(self):
   455      if self._client is None:
   456        self.client = get_recommendation_user_event_client()
   457  
   458    def process(self, element):
   459  
   460      user_events = [recommendationengine.UserEvent(e) for e in element[1]]
   461      user_event_inline_source = recommendationengine.UserEventInlineSource(
   462          {"user_events": user_events})
   463      input_config = recommendationengine.InputConfig(
   464          user_event_inline_source=user_event_inline_source)
   465  
   466      request = recommendationengine.ImportUserEventsRequest(
   467          parent=self.parent, input_config=input_config)
   468  
   469      try:
   470        operation = self._client.write_user_event(request)
   471        self.counter.inc(len(user_events))
   472        yield recommendationengine.PredictResponse.to_dict(operation.result())
   473      except Exception:
   474        yield pvalue.TaggedOutput(self.FAILED_USER_EVENTS, user_events)
   475  
   476  
   477  class PredictUserEvent(PTransform):
   478    """Make a recommendation prediction.
   479      The `PTransform` returns a PCollection
   480  
   481      Example usage::
   482  
   483        pipeline
   484        | PredictUserEvent(
   485            project='example-gcp-project',
   486            catalog_name='my-catalog',
   487            event_store='my_event_store',
   488            placement_id='recently_viewed_default')
   489      """
   490    def __init__(
   491        self,
   492        project: str = None,
   493        retry: Retry = None,
   494        timeout: float = 120,
   495        metadata: Sequence[Tuple[str, str]] = (),
   496        catalog_name: str = "default_catalog",
   497        event_store: str = "default_event_store",
   498        placement_id: str = None):
   499      """Initializes a :class:`PredictUserEvent` transform.
   500  
   501          Args:
   502              project (str): Optional. GCP project name in which the catalog
   503                data will be imported.
   504              retry: Optional. Designation of what
   505                errors, if any, should be retried.
   506              timeout (float): Optional. The amount of time, in seconds, to wait
   507                for the request to complete.
   508              metadata: Optional. Strings which
   509                should be sent along with the request as metadata.
   510              catalog_name (str): Optional. Name of the catalog.
   511                Default: 'default_catalog'
   512              event_store (str): Optional. Name of the event store.
   513                Default: 'default_event_store'
   514              placement_id (str): Required. ID of the recommendation engine
   515                placement. This id is used to identify the set of models that
   516                will be used to make the prediction.
   517          """
   518      self.project = project
   519      self.retry = retry
   520      self.timeout = timeout
   521      self.metadata = metadata
   522      self.placement_id = placement_id
   523      self.catalog_name = catalog_name
   524      self.event_store = event_store
   525      if placement_id is None:
   526        raise ValueError('placement_id must be specified')
   527      else:
   528        self.placement_id = placement_id
   529  
   530    def expand(self, pcoll):
   531      if self.project is None:
   532        self.project = pcoll.pipeline.options.view_as(GoogleCloudOptions).project
   533      if self.project is None:
   534        raise ValueError(
   535            'GCP project name needs to be specified in "project" pipeline option')
   536      return pcoll | ParDo(
   537          _PredictUserEventFn(
   538              self.project,
   539              self.retry,
   540              self.timeout,
   541              self.metadata,
   542              self.catalog_name,
   543              self.event_store,
   544              self.placement_id))
   545  
   546  
   547  class _PredictUserEventFn(DoFn):
   548    FAILED_PREDICTIONS = "failed_predictions"
   549  
   550    def __init__(
   551        self,
   552        project=None,
   553        retry=None,
   554        timeout=120,
   555        metadata=None,
   556        catalog_name=None,
   557        event_store=None,
   558        placement_id=None):
   559      self._client = None
   560      self.retry = retry
   561      self.timeout = timeout
   562      self.metadata = metadata
   563      self.name = f"projects/{project}/locations/global/catalogs/"\
   564                  f"{catalog_name}/eventStores/{event_store}/placements/"\
   565                  f"{placement_id}"
   566      self.counter = Metrics.counter(self.__class__, "api_calls")
   567  
   568    def setup(self):
   569      if self._client is None:
   570        self._client = get_recommendation_prediction_client()
   571  
   572    def process(self, element):
   573      user_event = recommendationengine.UserEvent(element)
   574      request = recommendationengine.PredictRequest(
   575          name=self.name, user_event=user_event)
   576  
   577      try:
   578        prediction = self._client.predict(request)
   579        self.counter.inc()
   580        yield [
   581            recommendationengine.PredictResponse.to_dict(p)
   582            for p in prediction.pages
   583        ]
   584      except Exception:
   585        yield pvalue.TaggedOutput(self.FAILED_PREDICTIONS, user_event)