github.com/matrixorigin/matrixone@v1.2.0/pkg/udf/pythonservice/pyserver/server.py (about)

     1  # coding = utf-8
     2  # -*- coding:utf-8 -*-
     3  # Copyright 2023 Matrix Origin
     4  #
     5  # Licensed under the Apache License, Version 2.0 (the "License");
     6  # you may not use this file except in compliance with the License.
     7  # You may obtain a copy of the License at
     8  #
     9  #      http://www.apache.org/licenses/LICENSE-2.0
    10  #
    11  # Unless required by applicable law or agreed to in writing, software
    12  # distributed under the License is distributed on an "AS IS" BASIS,
    13  # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    14  # See the License for the specific language governing permissions and
    15  # limitations under the License.
import argparse
import datetime
import decimal
import enum
import importlib
import json
import logging
import os
import shutil
import subprocess
import sys
import threading
from concurrent import futures
from typing import Any, Callable, Optional, Iterator, Dict, Tuple

import grpc

import udf_pb2 as pb2
import udf_pb2_grpc as pb2_grpc
    34  
# Fallback precision of the decimal context, used when a UDF does not carry a valid
# `decimal_precision` attribute. NOTE(review): despite its name this is a precision, not a scale.
DEFAULT_DECIMAL_SCALE = 16

# strptime formats used to parse DATE and DATETIME/TIMESTAMP values received from the client
DATE_FORMAT = '%Y-%m-%d'
DATETIME_FORMAT = '%Y-%m-%d %H:%M:%S'
DATETIME_FORMAT_WITH_PRECISION = '%Y-%m-%d %H:%M:%S.%f'

# directory containing this file; imported UDF packages are installed below it
ROOT_PATH = os.path.dirname(os.path.abspath(__file__))
# name of the marker file written into a package directory once its installation succeeded
INSTALLED_LABEL = 'installed'

# optional attributes a UDF function object may carry:
#   vector = True          -> called once with whole argument columns instead of once per row
#   decimal_precision = n  -> precision of the decimal context while the UDF runs
OPTION_VECTOR = 'vector'
OPTION_DECIMAL_PRECISION = 'decimal_precision'

logging.basicConfig(
    level=logging.INFO,
    format='[%(asctime)s] - [%(name)s] - [%(levelname)s] - [%(threadName)s] : %(message)s'
)
log = logging.getLogger('Server')
    52  
    53  
    54  class Server(pb2_grpc.ServiceServicer):
    55  
    56      def run(self, requestIterator: Iterator[pb2.Request], context) -> pb2.Response:
    57          firstRequest: Optional[pb2.Request] = None
    58          path: Optional[str] = None
    59          filename: Optional[str] = None
    60          item: Optional[InstallingItem] = None
    61  
    62          firstBlock = True
    63          lastBlock = False
    64  
    65          try:
    66              for request in requestIterator:
    67                  # check
    68                  checkUdf(request.udf)
    69  
    70                  # the first request
    71                  if request.type == pb2.DataRequest:
    72                      log.info('receive data')
    73  
    74                      if request.udf.isImport:
    75                          path, filename, item, status = functionStatus(request.udf)
    76  
    77                          if status == FunctionStatus.NotExist:
    78                              firstRequest = request
    79                              yield pb2.Response(type=pb2.PkgRequest)
    80  
    81                          elif status == FunctionStatus.Installing:
    82                              with item.condition:
    83                                  # block and waiting
    84                                  if not item.installed:
    85                                      log.info('waiting')
    86                                      item.condition.wait()
    87                                      log.info('finish waiting')
    88                              yield self.calculate(request, path, filename)
    89  
    90                          else:
    91                              yield self.calculate(request, path, filename)
    92  
    93                      else:
    94                          yield self.calculate(request, "", "")
    95  
    96                  # the second request (optional)
    97                  # var firstRequest, path, filename and item are not null
    98                  elif request.type == pb2.PkgResponse:
    99                      log.info('receive pkg')
   100  
   101                      # install pkg, do not need write lock
   102                      absPath = os.path.join(ROOT_PATH, path)
   103                      lastBlock = request.udf.importPkg.last
   104                      try:
   105                          if firstBlock:
   106                              if os.path.exists(absPath):
   107                                  shutil.rmtree(absPath)
   108                              os.makedirs(absPath, exist_ok=True)
   109                              firstBlock = False
   110  
   111                          file = os.path.join(absPath, filename)
   112  
   113                          with open(file, 'ab+') as f:
   114                              f.write(request.udf.importPkg.data)
   115  
   116                          if lastBlock:
   117                              if request.udf.body.endswith('.whl'):
   118                                  subprocess.check_call(['pip', 'install', '--no-index', file, '-t', absPath])
   119                                  os.remove(file)
   120  
   121                              # mark the pkg is installed without error
   122                              open(os.path.join(absPath, INSTALLED_LABEL), 'w').close()
   123  
   124                      except Exception as e:
   125                          shutil.rmtree(absPath, ignore_errors=True)
   126                          raise e
   127  
   128                      finally:
   129                          if lastBlock:
   130                              with item.condition:
   131                                  item.installed = True
   132                                  item.condition.notifyAll()
   133  
   134                              with INSTALLING_MAP_LOCK:
   135                                  INSTALLING_MAP[path] = None
   136  
   137                      if not lastBlock:
   138                          yield pb2.Response(type=pb2.PkgRequest)
   139                      else:
   140                          yield self.calculate(firstRequest, path, filename)
   141  
   142                  else:
   143                      raise Exception('error udf request type')
   144          # notify all
   145          finally:
   146              if item is None:
   147                  with INSTALLING_MAP_LOCK:
   148                      item = INSTALLING_MAP.get(path)
   149  
   150              if item is not None:
   151                  with item.condition:
   152                      item.installed = True
   153                      item.condition.notifyAll()
   154  
   155                  with INSTALLING_MAP_LOCK:
   156                      INSTALLING_MAP[path] = None
   157  
   158      def calculate(self, request: pb2.Request, filepath: str, filename: str) -> pb2.Response:
   159          log.info('calculating')
   160  
   161          # load function
   162          func = loadFunction(request.udf, filepath, filename)
   163  
   164          # set precision
   165          if hasattr(func, OPTION_DECIMAL_PRECISION):
   166              prec = getattr(func, OPTION_DECIMAL_PRECISION)
   167              if type(prec) is int and prec >= 0:
   168                  decimal.getcontext().prec = prec
   169              else:
   170                  decimal.getcontext().prec = DEFAULT_DECIMAL_SCALE
   171          else:
   172              decimal.getcontext().prec = DEFAULT_DECIMAL_SCALE
   173  
   174          # scalar or vector
   175          vector = False
   176          if hasattr(func, OPTION_VECTOR):
   177              vector = getattr(func, OPTION_VECTOR) is True
   178  
   179          # init result
   180          result = pb2.DataVector(
   181              const=False,
   182              data=[],
   183              length=request.length,
   184              type=request.udf.retType,
   185              scale=defaultScale(request.udf.retType)
   186          )
   187  
   188          # calculate
   189          if vector:
   190              params = []
   191              for i in range(len(request.vectors)):
   192                  params.append([getValueFromDataVector(request.vectors[i], j) for j in range(request.length)])
   193              values = func(*params)
   194              assert len(values) == request.length, f'request length {request.length} is not same with result length {len(values)}'
   195              for value in values:
   196                  data = value2Data(value, result.type)
   197                  result.data.append(data)
   198          else:
   199              for i in range(request.length):
   200                  params = [getValueFromDataVector(request.vectors[j], i) for j in range(len(request.vectors))]
   201                  value = func(*params)
   202                  data = value2Data(value, result.type)
   203                  result.data.append(data)
   204          log.info('finish calculating')
   205          return pb2.Response(vector=result, type=pb2.DataResponse)
   206  
   207  
   208  def checkUdf(udf: pb2.Udf):
   209      assert udf.handler != "", "udf handler should not be null"
   210      assert udf.body != "", "udf body should not be null"
   211      assert udf.language == "python", "udf language should be python"
   212      assert udf.modifiedTime != "", "udf modifiedTime should not be null"
   213      assert udf.db != "", "udf db should not be null"
   214  
   215  
class FunctionStatus(enum.Enum):
    """Installation state of an imported UDF package, as reported by functionStatus()."""
    # no installation marker and nobody installing: the caller must install the package
    NotExist = 0
    # another request registered an InstallingItem that is not finished yet
    Installing = 1
    # the package is ready to be imported
    Installed = 2
   220  
   221  
   222  class InstallingItem:
   223      condition = threading.Condition()
   224      installed = False
   225  
   226  
# Registry of package installations, shared by all request streams.
# key: package directory relative to ROOT_PATH, built by functionStatus() as
#      udf/<db>/<last directory of udf.body>/<modifiedTime>
# value: InstallingItem of the installation in progress, or None once it was released.
INSTALLING_MAP: Dict[str, Optional[InstallingItem]] = {}
# Guards every read and write of INSTALLING_MAP. It is a reentrant lock;
# NOTE(review): no nested acquisition is visible in this file — confirm whether RLock is needed.
INSTALLING_MAP_LOCK = threading.RLock()
   230  
   231  
   232  def functionStatus(udf: pb2.Udf) -> (str, str, Optional[InstallingItem], FunctionStatus):
   233      with INSTALLING_MAP_LOCK:
   234          filepath, filename = os.path.split(udf.body)
   235          path = os.path.join('udf', udf.db, filepath[filepath.rfind('/') + 1:], udf.modifiedTime)
   236          item = INSTALLING_MAP.get(path)
   237          if item is None:
   238              if os.path.isfile(os.path.join(ROOT_PATH, path, INSTALLED_LABEL)):
   239                  return path, filename, item, FunctionStatus.Installed
   240              else:
   241                  item = InstallingItem()
   242                  item.installed = False
   243                  INSTALLING_MAP[path] = item
   244                  return path, filename, item, FunctionStatus.NotExist
   245          else:
   246              if item.installed:
   247                  return path, filename, item, FunctionStatus.Installed
   248              else:
   249                  return path, filename, item, FunctionStatus.Installing
   250  
   251  
   252  def loadFunction(udf: pb2.Udf, filepath: str, filename: str) -> Callable:
   253      # load function
   254      if not udf.isImport:
   255          exec(udf.body, locals())
   256          # get function object
   257          return locals()[udf.handler]
   258      else:
   259          if udf.body.endswith('.py'):
   260              file = importlib.import_module(f'.{filename[:-3]}', package=filepath.replace("/", "."))
   261              return getattr(file, udf.handler)
   262          elif udf.body.endswith('.whl'):
   263              i = udf.handler.rfind('.')
   264              if i < 1:
   265                  raise Exception(
   266                      "when you import a *.whl, the handler should be in the format of '<file or module name>.<function name>'")
   267              file = importlib.import_module(f'.{udf.handler[:i]}', package=filepath.replace("/", "."))
   268              return getattr(file, udf.handler[i + 1:])
   269  
   270  
   271  def getDataFromDataVector(v: pb2.DataVector, i: int) -> pb2.Data:
   272      if v is None:
   273          return pb2.Data()
   274      if v.const:
   275          return v.data[0]
   276      return v.data[i]
   277  
   278  
   279  def getValueFromDataVector(v: pb2.DataVector, i: int) -> Any:
   280      data = getDataFromDataVector(v, i)
   281  
   282      if data.WhichOneof("val") is None:
   283          return None
   284  
   285      if v.type == pb2.BOOL:
   286          return data.boolVal
   287      if v.type in [pb2.INT8, pb2.INT16, pb2.INT32]:
   288          return data.intVal
   289      if v.type == pb2.INT64:
   290          return data.int64Val
   291      if v.type in [pb2.UINT8, pb2.INT16, pb2.INT32]:
   292          return data.uintVal
   293      if v.type == pb2.UINT64:
   294          return data.uint64Val
   295      if v.type == pb2.FLOAT32:
   296          return data.floatVal
   297      if v.type == pb2.FLOAT64:
   298          return data.doubleVal
   299      if v.type in [pb2.CHAR, pb2.VARCHAR, pb2.TEXT, pb2.UUID]:
   300          return data.stringVal
   301      if v.type == pb2.JSON:
   302          return json.loads(data.stringVal)
   303      if v.type == pb2.TIME:
   304          sign = 1 if data.stringVal[0] == '-' else 0
   305          h, m, s = data.stringVal[sign:].split(':')
   306          if sign == 0:
   307              return datetime.timedelta(hours=int(h), minutes=int(m), seconds=float(s))
   308          return datetime.timedelta(hours=-int(h), minutes=-int(m), seconds=-float(s))
   309      if v.type == pb2.DATE:
   310          return datetime.datetime.strptime(data.stringVal, DATE_FORMAT).date()
   311      if v.type in [pb2.DATETIME, pb2.TIMESTAMP]:
   312          formatStr = DATETIME_FORMAT if v.scale == 0 else DATETIME_FORMAT_WITH_PRECISION
   313          return datetime.datetime.strptime(data.stringVal, formatStr)
   314      if v.type in [pb2.DECIMAL64, pb2.DECIMAL128]:
   315          return decimal.Decimal(data.stringVal)
   316      if v.type in [pb2.BINARY, pb2.VARBINARY, pb2.BLOB]:
   317          return data.bytesVal
   318      else:
   319          raise Exception("vector type error")
   320  
   321  
   322  def defaultScale(typ: pb2.DataType, scale: Optional[int] = None) -> int:
   323      if scale is not None:
   324          return scale
   325      if typ in [pb2.FLOAT32, pb2.FLOAT64]:
   326          return -1
   327      if typ in [pb2.DECIMAL64, pb2.DECIMAL128]:
   328          return decimal.getcontext().prec
   329      if typ in [pb2.TIME, pb2.DATETIME, pb2.TIMESTAMP]:
   330          return 6
   331      return 0
   332  
   333  
   334  def value2Data(value: Any, typ: pb2.DataType) -> pb2.Data:
   335      if value is None:
   336          return pb2.Data()
   337  
   338      if typ == pb2.BOOL:
   339          assert type(value) is bool, f'return type error, required {bool}, received {type(value)}'
   340          return pb2.Data(boolVal=value)
   341      if typ in [pb2.INT8, pb2.INT16, pb2.INT32]:
   342          assert type(value) is int, f'return type error, required {int}, received {type(value)}'
   343          return pb2.Data(intVal=value)
   344      if typ == pb2.INT64:
   345          assert type(value) is int, f'return type error, required {int}, received {type(value)}'
   346          return pb2.Data(int64Val=value)
   347      if typ in [pb2.UINT8, pb2.INT16, pb2.INT32]:
   348          assert type(value) is int, f'return type error, required {int}, received {type(value)}'
   349          return pb2.Data(uintVal=value)
   350      if typ == pb2.UINT64:
   351          assert type(value) is int, f'return type error, required {int}, received {type(value)}'
   352          return pb2.Data(uint64Val=value)
   353      if typ == pb2.FLOAT32:
   354          assert type(value) is float, f'return type error, required {float}, received {type(value)}'
   355          return pb2.Data(floatVal=value)
   356      if typ == pb2.FLOAT64:
   357          assert type(value) is float, f'return type error, required {float}, received {type(value)}'
   358          return pb2.Data(doubleVal=value)
   359      if typ in [pb2.CHAR, pb2.VARCHAR, pb2.TEXT, pb2.UUID]:
   360          assert type(value) is str, f'return type error, required {str}, received {type(value)}'
   361          return pb2.Data(stringVal=value)
   362      if typ == pb2.JSON:
   363          return pb2.Data(stringVal=json.dumps(value, ensure_ascii=False))
   364      if typ == pb2.TIME:
   365          assert type(
   366              value) is datetime.timedelta, f'return type error, required {datetime.timedelta}, received {type(value)}'
   367          r = ''
   368          t: datetime.timedelta = value
   369          if t.days < 0:
   370              t = t * -1
   371              r += '-'
   372          h = t.days * 24 + t.seconds // 3600
   373          m = t.seconds % 3600 // 60
   374          s = t.seconds % 3600 % 60
   375          r += f'{h:02d}:{m:02d}:{s:02d}.{t.microseconds:06d}'
   376          return pb2.Data(stringVal=str(r))
   377      if typ == pb2.DATE:
   378          assert type(value) is datetime.date, f'return type error, required {datetime.date}, received {type(value)}'
   379          return pb2.Data(stringVal=str(value))
   380      if typ in [pb2.DATETIME, pb2.TIMESTAMP]:
   381          assert type(
   382              value) is datetime.datetime, f'return type error, required {datetime.datetime}, received {type(value)}'
   383          return pb2.Data(stringVal=str(value))
   384      if typ in [pb2.DECIMAL64, pb2.DECIMAL128]:
   385          assert type(value) is decimal.Decimal, f'return type error, required {decimal.Decimal}, received {type(value)}'
   386          return pb2.Data(stringVal=str(value))
   387      if typ in [pb2.BINARY, pb2.VARBINARY, pb2.BLOB]:
   388          assert type(value) is bytes, f'return type error, required {bytes}, received {type(value)}'
   389          return pb2.Data(bytesVal=value)
   390      else:
   391          raise Exception(f'unsupported return type: {type(value)}')
   392  
   393  
   394  def run():
   395      server = grpc.server(futures.ThreadPoolExecutor(), options=[
   396          ('grpc.max_send_message_length', 0x7fffffff),
   397          ('grpc.max_receive_message_length', 0x7fffffff)
   398      ])
   399      pb2_grpc.add_ServiceServicer_to_server(Server(), server)
   400      server.add_insecure_port(ARGS.address)
   401      server.start()
   402      server.wait_for_termination()
   403  
   404  
   405  ARGS = None
   406  
   407  if __name__ == '__main__':
   408      parser = argparse.ArgumentParser(
   409          prog='python udf server',
   410          epilog='Copyright(r), 2023'
   411      )
   412      parser.add_argument('--address', default='[::]:50051', help='address')
   413      ARGS = parser.parse_args()
   414  
   415      log.info('python server start')
   416  
   417      run()