github.com/apache/beam/sdks/v2@v2.48.2/python/apache_beam/ml/gcp/recommendations_ai.py

#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements.  See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License.  You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

"""A connector for sending API requests to the GCP Recommendations AI
API (https://cloud.google.com/recommendations).
"""

from __future__ import absolute_import

from typing import Sequence
from typing import Tuple

from google.api_core.retry import Retry

from apache_beam import pvalue
from apache_beam.metrics import Metrics
from apache_beam.options.pipeline_options import GoogleCloudOptions
from apache_beam.transforms import DoFn
from apache_beam.transforms import ParDo
from apache_beam.transforms import PTransform
from apache_beam.transforms.util import GroupIntoBatches
from cachetools.func import ttl_cache

# pylint: disable=wrong-import-order, wrong-import-position, ungrouped-imports
try:
  from google.cloud import recommendationengine
except ImportError:
  raise ImportError(
      'Google Cloud Recommendation AI not supported for this execution '
      'environment (could not import google.cloud.recommendationengine).')
# pylint: enable=wrong-import-order, wrong-import-position, ungrouped-imports

__all__ = [
    'CreateCatalogItem',
    'WriteUserEvent',
    'ImportCatalogItems',
    'ImportUserEvents',
    'PredictUserEvent'
]

FAILED_CATALOG_ITEMS = "failed_catalog_items"


@ttl_cache(maxsize=128, ttl=3600)
def get_recommendation_prediction_client():
  """Returns a Recommendation AI - Prediction Service client."""
  _client = recommendationengine.PredictionServiceClient()
  return _client


@ttl_cache(maxsize=128, ttl=3600)
def get_recommendation_catalog_client():
  """Returns a Recommendation AI - Catalog Service client."""
  _client = recommendationengine.CatalogServiceClient()
  return _client


@ttl_cache(maxsize=128, ttl=3600)
def get_recommendation_user_event_client():
  """Returns a Recommendation AI - UserEvent Service client."""
  _client = recommendationengine.UserEventServiceClient()
  return _client


class CreateCatalogItem(PTransform):
  """Creates catalog item information.
  The ``PTransform`` returns a PCollectionTuple with PCollections of
  successfully created and failed CatalogItems.

  Example usage::

    pipeline | CreateCatalogItem(
      project='example-gcp-project',
      catalog_name='my-catalog')
  """
  def __init__(
      self,
      project: str = None,
      retry: Retry = None,
      timeout: float = 120,
      metadata: Sequence[Tuple[str, str]] = (),
      catalog_name: str = "default_catalog"):
    """Initializes a :class:`CreateCatalogItem` transform.

    Args:
      project (str): Optional. GCP project name in which the catalog
        data will be imported.
      retry: Optional. Designation of what
        errors, if any, should be retried.
      timeout (float): Optional. The amount of time, in seconds, to wait
        for the request to complete.
      metadata: Optional. Strings which
        should be sent along with the request as metadata.
      catalog_name (str): Optional. Name of the catalog.
        Default: 'default_catalog'
    """
    self.project = project
    self.retry = retry
    self.timeout = timeout
    self.metadata = metadata
    self.catalog_name = catalog_name

  def expand(self, pcoll):
    if self.project is None:
      self.project = pcoll.pipeline.options.view_as(GoogleCloudOptions).project
    if self.project is None:
      raise ValueError(
          'GCP project name needs to be specified in "project" pipeline '
          'option')
    return pcoll | ParDo(
        _CreateCatalogItemFn(
            self.project,
            self.retry,
            self.timeout,
            self.metadata,
            self.catalog_name))


class _CreateCatalogItemFn(DoFn):
  def __init__(
      self,
      project: str = None,
      retry: Retry = None,
      timeout: float = 120,
      metadata: Sequence[Tuple[str, str]] = (),
      catalog_name: str = None):
    self._client = None
    self.retry = retry
    self.timeout = timeout
    self.metadata = metadata
    self.parent = f"projects/{project}/locations/global/catalogs/{catalog_name}"
    self.counter = Metrics.counter(self.__class__, "api_calls")

  def setup(self):
    # The client is created lazily in setup() so the DoFn stays picklable;
    # the ttl_cache-backed getter shares clients within a worker process.
    if self._client is None:
      self._client = get_recommendation_catalog_client()

  def process(self, element):
    catalog_item = recommendationengine.CatalogItem(element)
    request = recommendationengine.CreateCatalogItemRequest(
        parent=self.parent, catalog_item=catalog_item)

    try:
      created_catalog_item = self._client.create_catalog_item(
          request=request,
          retry=self.retry,
          timeout=self.timeout,
          metadata=self.metadata)

      self.counter.inc()
      yield recommendationengine.CatalogItem.to_dict(created_catalog_item)
    except Exception:
      # Route items that could not be created to the failure output.
      yield pvalue.TaggedOutput(
          FAILED_CATALOG_ITEMS,
          recommendationengine.CatalogItem.to_dict(catalog_item))
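
# A minimal pipeline sketch for CreateCatalogItem, assuming dict-shaped
# catalog items whose fields mirror the v1beta1 CatalogItem message
# (id, title, category_hierarchies); the project name below is hypothetical.
#
#   import apache_beam as beam
#   from apache_beam.ml.gcp import recommendations_ai
#
#   with beam.Pipeline() as p:
#     _ = (
#         p
#         | beam.Create([{
#             'id': 'item-1',
#             'title': 'Example product',
#             'category_hierarchies': [{'categories': ['electronics']}],
#         }])
#         | recommendations_ai.CreateCatalogItem(
#             project='example-gcp-project',
#             catalog_name='default_catalog'))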


class ImportCatalogItems(PTransform):
  """Imports catalog items in bulk.
  The ``PTransform`` returns a PCollectionTuple with PCollections of
  successfully imported and failed CatalogItems.

  Example usage::

    pipeline
    | ImportCatalogItems(
        project='example-gcp-project',
        catalog_name='my-catalog')
  """
  def __init__(
      self,
      max_batch_size: int = 5000,
      project: str = None,
      retry: Retry = None,
      timeout: float = 120,
      metadata: Sequence[Tuple[str, str]] = (),
      catalog_name: str = "default_catalog"):
    """Initializes an :class:`ImportCatalogItems` transform.

    Args:
      max_batch_size (int): Optional. Maximum number of catalog items per
        request. Default: 5000
      project (str): Optional. GCP project name in which the catalog
        data will be imported.
      retry: Optional. Designation of what
        errors, if any, should be retried.
      timeout (float): Optional. The amount of time, in seconds, to wait
        for the request to complete.
      metadata: Optional. Strings which
        should be sent along with the request as metadata.
      catalog_name (str): Optional. Name of the catalog.
        Default: 'default_catalog'
    """
    self.max_batch_size = max_batch_size
    self.project = project
    self.retry = retry
    self.timeout = timeout
    self.metadata = metadata
    self.catalog_name = catalog_name

  def expand(self, pcoll):
    if self.project is None:
      self.project = pcoll.pipeline.options.view_as(GoogleCloudOptions).project
    if self.project is None:
      raise ValueError(
          'GCP project name needs to be specified in "project" pipeline option')
    return (
        pcoll | GroupIntoBatches.WithShardedKey(self.max_batch_size) | ParDo(
            _ImportCatalogItemsFn(
                self.project,
                self.retry,
                self.timeout,
                self.metadata,
                self.catalog_name)))


class _ImportCatalogItemsFn(DoFn):
  def __init__(
      self,
      project=None,
      retry=None,
      timeout=120,
      metadata=None,
      catalog_name=None):
    self._client = None
    self.retry = retry
    self.timeout = timeout
    self.metadata = metadata
    self.parent = f"projects/{project}/locations/global/catalogs/{catalog_name}"
    self.counter = Metrics.counter(self.__class__, "api_calls")

  def setup(self):
    if self._client is None:
      self._client = get_recommendation_catalog_client()

  def process(self, element):
    # element is a (sharded key, batch of catalog item dicts) pair produced
    # by GroupIntoBatches.WithShardedKey.
    catalog_items = [recommendationengine.CatalogItem(e) for e in element[1]]
    catalog_inline_source = recommendationengine.CatalogInlineSource(
        {"catalog_items": catalog_items})
    input_config = recommendationengine.InputConfig(
        catalog_inline_source=catalog_inline_source)

    request = recommendationengine.ImportCatalogItemsRequest(
        parent=self.parent, input_config=input_config)

    try:
      operation = self._client.import_catalog_items(
          request=request,
          retry=self.retry,
          timeout=self.timeout,
          metadata=self.metadata)
      self.counter.inc(len(catalog_items))
      yield operation.result()
    except Exception:
      yield pvalue.TaggedOutput(FAILED_CATALOG_ITEMS, catalog_items)
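
# A minimal pipeline sketch for ImportCatalogItems. Because the transform
# batches its input with GroupIntoBatches.WithShardedKey, the elements must
# already be key-value pairs; the constant key and project name below are
# only illustrative.
#
#   import apache_beam as beam
#   from apache_beam.ml.gcp import recommendations_ai
#
#   items = [
#       {'id': 'item-1', 'title': 'Example product A',
#        'category_hierarchies': [{'categories': ['electronics']}]},
#       {'id': 'item-2', 'title': 'Example product B',
#        'category_hierarchies': [{'categories': ['electronics']}]},
#   ]
#
#   with beam.Pipeline() as p:
#     _ = (
#         p
#         | beam.Create(items)
#         | beam.Map(lambda item: ('batch', item))  # key needed for batching
#         | recommendations_ai.ImportCatalogItems(
#             max_batch_size=500,
#             project='example-gcp-project',
#             catalog_name='default_catalog'))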


class WriteUserEvent(PTransform):
  """Writes user event information.
  The ``PTransform`` returns a PCollectionTuple with PCollections of
  successfully written and failed UserEvents.

  Example usage::

    pipeline
    | WriteUserEvent(
        project='example-gcp-project',
        catalog_name='my-catalog',
        event_store='my_event_store')
  """
  def __init__(
      self,
      project: str = None,
      retry: Retry = None,
      timeout: float = 120,
      metadata: Sequence[Tuple[str, str]] = (),
      catalog_name: str = "default_catalog",
      event_store: str = "default_event_store"):
    """Initializes a :class:`WriteUserEvent` transform.

    Args:
      project (str): Optional. GCP project name in which the user event
        data will be written.
      retry: Optional. Designation of what
        errors, if any, should be retried.
      timeout (float): Optional. The amount of time, in seconds, to wait
        for the request to complete.
      metadata: Optional. Strings which
        should be sent along with the request as metadata.
      catalog_name (str): Optional. Name of the catalog.
        Default: 'default_catalog'
      event_store (str): Optional. Name of the event store.
        Default: 'default_event_store'
    """
    self.project = project
    self.retry = retry
    self.timeout = timeout
    self.metadata = metadata
    self.catalog_name = catalog_name
    self.event_store = event_store

  def expand(self, pcoll):
    if self.project is None:
      self.project = pcoll.pipeline.options.view_as(GoogleCloudOptions).project
    if self.project is None:
      raise ValueError(
          'GCP project name needs to be specified in "project" pipeline option')
    return pcoll | ParDo(
        _WriteUserEventFn(
            self.project,
            self.retry,
            self.timeout,
            self.metadata,
            self.catalog_name,
            self.event_store))


class _WriteUserEventFn(DoFn):
  FAILED_USER_EVENTS = "failed_user_events"

  def __init__(
      self,
      project=None,
      retry=None,
      timeout=120,
      metadata=None,
      catalog_name=None,
      event_store=None):
    self._client = None
    self.retry = retry
    self.timeout = timeout
    self.metadata = metadata
    self.parent = f"projects/{project}/locations/global/catalogs/"\
        f"{catalog_name}/eventStores/{event_store}"
    self.counter = Metrics.counter(self.__class__, "api_calls")

  def setup(self):
    if self._client is None:
      self._client = get_recommendation_user_event_client()

  def process(self, element):
    user_event = recommendationengine.UserEvent(element)
    request = recommendationengine.WriteUserEventRequest(
        parent=self.parent, user_event=user_event)

    try:
      created_user_event = self._client.write_user_event(
          request=request,
          retry=self.retry,
          timeout=self.timeout,
          metadata=self.metadata)
      self.counter.inc()
      yield recommendationengine.UserEvent.to_dict(created_user_event)
    except Exception:
      yield pvalue.TaggedOutput(
          self.FAILED_USER_EVENTS,
          recommendationengine.UserEvent.to_dict(user_event))
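
# A minimal pipeline sketch for WriteUserEvent, assuming a dict shaped like
# the v1beta1 UserEvent message (event_type plus user_info.visitor_id); the
# project and event store names below are hypothetical.
#
#   import apache_beam as beam
#   from apache_beam.ml.gcp import recommendations_ai
#
#   with beam.Pipeline() as p:
#     _ = (
#         p
#         | beam.Create([{
#             'event_type': 'detail-page-view',
#             'user_info': {'visitor_id': 'visitor-1'},
#             'product_event_detail': {
#                 'product_details': [{'id': 'item-1'}],
#             },
#         }])
#         | recommendations_ai.WriteUserEvent(
#             project='example-gcp-project',
#             catalog_name='default_catalog',
#             event_store='default_event_store'))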


class ImportUserEvents(PTransform):
  """Imports user events in bulk.
  The ``PTransform`` returns a PCollectionTuple with PCollections of
  successfully imported and failed UserEvents.

  Example usage::

    pipeline
    | ImportUserEvents(
        project='example-gcp-project',
        catalog_name='my-catalog',
        event_store='my_event_store')
  """
  def __init__(
      self,
      max_batch_size: int = 5000,
      project: str = None,
      retry: Retry = None,
      timeout: float = 120,
      metadata: Sequence[Tuple[str, str]] = (),
      catalog_name: str = "default_catalog",
      event_store: str = "default_event_store"):
    """Initializes an :class:`ImportUserEvents` transform.

    Args:
      max_batch_size (int): Optional. Maximum number of user events per
        request. Default: 5000
      project (str): Optional. GCP project name in which the user event
        data will be imported.
      retry: Optional. Designation of what
        errors, if any, should be retried.
      timeout (float): Optional. The amount of time, in seconds, to wait
        for the request to complete.
      metadata: Optional. Strings which
        should be sent along with the request as metadata.
      catalog_name (str): Optional. Name of the catalog.
        Default: 'default_catalog'
      event_store (str): Optional. Name of the event store.
        Default: 'default_event_store'
    """
    self.max_batch_size = max_batch_size
    self.project = project
    self.retry = retry
    self.timeout = timeout
    self.metadata = metadata
    self.catalog_name = catalog_name
    self.event_store = event_store

  def expand(self, pcoll):
    if self.project is None:
      self.project = pcoll.pipeline.options.view_as(GoogleCloudOptions).project
    if self.project is None:
      raise ValueError(
          'GCP project name needs to be specified in "project" pipeline option')
    return (
        pcoll | GroupIntoBatches.WithShardedKey(self.max_batch_size) | ParDo(
            _ImportUserEventsFn(
                self.project,
                self.retry,
                self.timeout,
                self.metadata,
                self.catalog_name,
                self.event_store)))


class _ImportUserEventsFn(DoFn):
  FAILED_USER_EVENTS = "failed_user_events"

  def __init__(
      self,
      project=None,
      retry=None,
      timeout=120,
      metadata=None,
      catalog_name=None,
      event_store=None):
    self._client = None
    self.retry = retry
    self.timeout = timeout
    self.metadata = metadata
    self.parent = f"projects/{project}/locations/global/catalogs/"\
        f"{catalog_name}/eventStores/{event_store}"
    self.counter = Metrics.counter(self.__class__, "api_calls")

  def setup(self):
    if self._client is None:
      self._client = get_recommendation_user_event_client()

  def process(self, element):
    # element is a (sharded key, batch of user event dicts) pair produced by
    # GroupIntoBatches.WithShardedKey.
    user_events = [recommendationengine.UserEvent(e) for e in element[1]]
    user_event_inline_source = recommendationengine.UserEventInlineSource(
        {"user_events": user_events})
    input_config = recommendationengine.InputConfig(
        user_event_inline_source=user_event_inline_source)

    request = recommendationengine.ImportUserEventsRequest(
        parent=self.parent, input_config=input_config)

    try:
      operation = self._client.import_user_events(
          request=request,
          retry=self.retry,
          timeout=self.timeout,
          metadata=self.metadata)
      self.counter.inc(len(user_events))
      yield recommendationengine.ImportUserEventsResponse.to_dict(
          operation.result())
    except Exception:
      yield pvalue.TaggedOutput(self.FAILED_USER_EVENTS, user_events)
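
# A minimal pipeline sketch for ImportUserEvents. As with ImportCatalogItems,
# the input must already be key-value pairs because the transform batches
# with GroupIntoBatches.WithShardedKey; names below are hypothetical.
#
#   import apache_beam as beam
#   from apache_beam.ml.gcp import recommendations_ai
#
#   events = [
#       {'event_type': 'detail-page-view',
#        'user_info': {'visitor_id': 'visitor-1'}},
#       {'event_type': 'detail-page-view',
#        'user_info': {'visitor_id': 'visitor-2'}},
#   ]
#
#   with beam.Pipeline() as p:
#     _ = (
#         p
#         | beam.Create(events)
#         | beam.Map(lambda event: (event['user_info']['visitor_id'], event))
#         | recommendations_ai.ImportUserEvents(
#             max_batch_size=500,
#             project='example-gcp-project',
#             catalog_name='default_catalog',
#             event_store='default_event_store'))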


class PredictUserEvent(PTransform):
  """Makes a recommendation prediction.
  The ``PTransform`` returns a PCollection of prediction results.

  Example usage::

    pipeline
    | PredictUserEvent(
        project='example-gcp-project',
        catalog_name='my-catalog',
        event_store='my_event_store',
        placement_id='recently_viewed_default')
  """
  def __init__(
      self,
      project: str = None,
      retry: Retry = None,
      timeout: float = 120,
      metadata: Sequence[Tuple[str, str]] = (),
      catalog_name: str = "default_catalog",
      event_store: str = "default_event_store",
      placement_id: str = None):
    """Initializes a :class:`PredictUserEvent` transform.

    Args:
      project (str): Optional. GCP project name in which the catalog
        and event store reside.
      retry: Optional. Designation of what
        errors, if any, should be retried.
      timeout (float): Optional. The amount of time, in seconds, to wait
        for the request to complete.
      metadata: Optional. Strings which
        should be sent along with the request as metadata.
      catalog_name (str): Optional. Name of the catalog.
        Default: 'default_catalog'
      event_store (str): Optional. Name of the event store.
        Default: 'default_event_store'
      placement_id (str): Required. ID of the recommendation engine
        placement. This id is used to identify the set of models that
        will be used to make the prediction.
    """
    self.project = project
    self.retry = retry
    self.timeout = timeout
    self.metadata = metadata
    self.catalog_name = catalog_name
    self.event_store = event_store
    if placement_id is None:
      raise ValueError('placement_id must be specified')
    self.placement_id = placement_id

  def expand(self, pcoll):
    if self.project is None:
      self.project = pcoll.pipeline.options.view_as(GoogleCloudOptions).project
    if self.project is None:
      raise ValueError(
          'GCP project name needs to be specified in "project" pipeline option')
    return pcoll | ParDo(
        _PredictUserEventFn(
            self.project,
            self.retry,
            self.timeout,
            self.metadata,
            self.catalog_name,
            self.event_store,
            self.placement_id))


class _PredictUserEventFn(DoFn):
  FAILED_PREDICTIONS = "failed_predictions"

  def __init__(
      self,
      project=None,
      retry=None,
      timeout=120,
      metadata=None,
      catalog_name=None,
      event_store=None,
      placement_id=None):
    self._client = None
    self.retry = retry
    self.timeout = timeout
    self.metadata = metadata
    self.name = f"projects/{project}/locations/global/catalogs/"\
        f"{catalog_name}/eventStores/{event_store}/placements/"\
        f"{placement_id}"
    self.counter = Metrics.counter(self.__class__, "api_calls")

  def setup(self):
    if self._client is None:
      self._client = get_recommendation_prediction_client()

  def process(self, element):
    user_event = recommendationengine.UserEvent(element)
    request = recommendationengine.PredictRequest(
        name=self.name, user_event=user_event)

    try:
      prediction = self._client.predict(
          request=request,
          retry=self.retry,
          timeout=self.timeout,
          metadata=self.metadata)
      self.counter.inc()
      # Collect all response pages for this user event into a single output.
      yield [
          recommendationengine.PredictResponse.to_dict(p)
          for p in prediction.pages
      ]
    except Exception:
      yield pvalue.TaggedOutput(self.FAILED_PREDICTIONS, user_event)
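
# A minimal pipeline sketch for PredictUserEvent. Each input element is a
# UserEvent dict describing the prediction context, and each output element
# is the list of prediction response pages for that event; the placement and
# project names below are hypothetical.
#
#   import apache_beam as beam
#   from apache_beam.ml.gcp import recommendations_ai
#
#   with beam.Pipeline() as p:
#     _ = (
#         p
#         | beam.Create([{
#             'event_type': 'home-page-view',
#             'user_info': {'visitor_id': 'visitor-1'},
#         }])
#         | recommendations_ai.PredictUserEvent(
#             project='example-gcp-project',
#             catalog_name='default_catalog',
#             event_store='default_event_store',
#             placement_id='recently_viewed_default')
#         | beam.Map(print))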