# coding = utf-8
# -*- coding:utf-8 -*-
# Copyright 2023 Matrix Origin
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# Origin: github.com/matrixorigin/matrixone@v1.2.0/pkg/udf/pythonservice/pyserver/server.py
"""gRPC server that evaluates Python user-defined functions (UDFs).

A client streams ``pb2.Request`` messages to ``Server.run``.  A data request
carries the argument vectors; when the UDF is an imported package that is not
installed yet, the server answers with ``PkgRequest`` and the client streams
the package in blocks before the function is evaluated.
"""
import argparse
import datetime
import decimal
import enum
import importlib
import json
import logging
import os
import shutil
import subprocess
import threading
from concurrent import futures
from typing import Any, Callable, Dict, Iterator, Optional, Tuple

import grpc

import udf_pb2 as pb2
import udf_pb2_grpc as pb2_grpc

DEFAULT_DECIMAL_SCALE = 16

DATE_FORMAT = '%Y-%m-%d'
DATETIME_FORMAT = '%Y-%m-%d %H:%M:%S'
DATETIME_FORMAT_WITH_PRECISION = '%Y-%m-%d %H:%M:%S.%f'

ROOT_PATH = os.path.dirname(os.path.abspath(__file__))
# name of the empty marker file written once a package is fully installed
INSTALLED_LABEL = 'installed'

# optional attributes a UDF may carry to tune how it is called
OPTION_VECTOR = 'vector'
OPTION_DECIMAL_PRECISION = 'decimal_precision'

logging.basicConfig(
    level=logging.INFO,
    format='[%(asctime)s] - [%(name)s] - [%(levelname)s] - [%(threadName)s] : %(message)s'
)
log = logging.getLogger('Server')


class Server(pb2_grpc.ServiceServicer):
    """Servicer that installs imported UDF packages and evaluates UDFs."""

    def run(self, requestIterator: Iterator[pb2.Request], context) -> Iterator[pb2.Response]:
        """Handle one request stream and yield the responses for it.

        For every incoming request:

        * ``DataRequest`` for a non-import UDF is evaluated immediately.
        * ``DataRequest`` for an import UDF is evaluated immediately when the
          package is installed, waits when another stream is installing it,
          or is remembered and answered with ``PkgRequest`` when the package
          does not exist yet.
        * ``PkgResponse`` appends one package block to disk; after the last
          block the package is installed and the remembered data request is
          evaluated, otherwise the next block is requested.

        Raises:
            AssertionError: if a request's UDF fails ``checkUdf``.
            Exception: on an unknown request type, or a package block that
                arrives before any data request on this stream.
        """
        firstRequest: Optional[pb2.Request] = None
        path: Optional[str] = None
        filename: Optional[str] = None
        item: Optional[InstallingItem] = None

        firstBlock = True
        lastBlock = False

        try:
            for request in requestIterator:
                # check
                checkUdf(request.udf)

                # the first request
                if request.type == pb2.DataRequest:
                    log.info('receive data')

                    if not request.udf.isImport:
                        yield self.calculate(request, "", "")
                        continue

                    path, filename, item, status = functionStatus(request.udf)

                    if status == FunctionStatus.NotExist:
                        # this stream becomes the installer: ask for the package
                        firstRequest = request
                        yield pb2.Response(type=pb2.PkgRequest)

                    elif status == FunctionStatus.Installing:
                        with item.condition:
                            # block until the installing stream signals completion;
                            # loop so a spurious wakeup cannot release us early
                            while not item.installed:
                                log.info('waiting')
                                item.condition.wait()
                                log.info('finish waiting')
                        # evaluate outside the lock so it is not held while the
                        # generator is suspended
                        yield self.calculate(request, path, filename)

                    else:
                        yield self.calculate(request, path, filename)

                # the second request (optional)
                # var firstRequest, path, filename and item are not null
                elif request.type == pb2.PkgResponse:
                    log.info('receive pkg')

                    if path is None or filename is None:
                        raise Exception('error udf request: pkg received before data request')

                    # install pkg, do not need write lock
                    absPath = os.path.join(ROOT_PATH, path)
                    lastBlock = request.udf.importPkg.last
                    try:
                        if firstBlock:
                            # start from a clean directory for this package version
                            if os.path.exists(absPath):
                                shutil.rmtree(absPath)
                            os.makedirs(absPath, exist_ok=True)
                            firstBlock = False

                        file = os.path.join(absPath, filename)

                        # blocks arrive in order; append each one to the file
                        with open(file, 'ab+') as f:
                            f.write(request.udf.importPkg.data)

                        if lastBlock:
                            if request.udf.body.endswith('.whl'):
                                subprocess.check_call(['pip', 'install', '--no-index', file, '-t', absPath])
                                os.remove(file)

                            # mark the pkg is installed without error
                            with open(os.path.join(absPath, INSTALLED_LABEL), 'w'):
                                pass

                    except Exception:
                        # drop the partial install so a later request starts over
                        shutil.rmtree(absPath, ignore_errors=True)
                        raise

                    finally:
                        if lastBlock:
                            _finishInstalling(item, path)

                    if not lastBlock:
                        yield pb2.Response(type=pb2.PkgRequest)
                    else:
                        yield self.calculate(firstRequest, path, filename)

                else:
                    raise Exception('error udf request type')
        # notify all
        finally:
            # whatever happened on this stream, never leave waiters blocked
            if item is None and path is not None:
                with INSTALLING_MAP_LOCK:
                    item = INSTALLING_MAP.get(path)
            _finishInstalling(item, path)

    def calculate(self, request: pb2.Request, filepath: str, filename: str) -> pb2.Response:
        """Load the UDF named by ``request.udf`` and apply it to the request vectors.

        The function is called once per row with scalar arguments, unless it
        has an attribute ``vector`` that is ``True``; then it is called once
        with one list per argument vector and must return ``request.length``
        values.

        Returns:
            A ``DataResponse`` whose vector holds the converted results.

        Raises:
            AssertionError: if a vector UDF returns the wrong number of values
                or a value has the wrong Python type for the return type.
        """
        log.info('calculating')

        # load function
        func = loadFunction(request.udf, filepath, filename)

        # set precision; the decimal module only accepts prec >= 1, so any
        # other value falls back to the default
        prec = getattr(func, OPTION_DECIMAL_PRECISION, None)
        if type(prec) is int and prec >= 1:
            decimal.getcontext().prec = prec
        else:
            decimal.getcontext().prec = DEFAULT_DECIMAL_SCALE

        # scalar or vector
        vector = getattr(func, OPTION_VECTOR, False) is True

        # init result
        result = pb2.DataVector(
            const=False,
            data=[],
            length=request.length,
            type=request.udf.retType,
            scale=defaultScale(request.udf.retType)
        )

        # calculate
        if vector:
            params = [
                [getValueFromDataVector(vec, row) for row in range(request.length)]
                for vec in request.vectors
            ]
            values = func(*params)
            if len(values) != request.length:
                raise AssertionError(
                    f'request length {request.length} is not same with result length {len(values)}')
            for value in values:
                result.data.append(value2Data(value, result.type))
        else:
            for row in range(request.length):
                params = [getValueFromDataVector(vec, row) for vec in request.vectors]
                value = func(*params)
                result.data.append(value2Data(value, result.type))
        log.info('finish calculating')
        return pb2.Response(vector=result, type=pb2.DataResponse)


def checkUdf(udf: pb2.Udf):
    """Validate the mandatory fields of a UDF description.

    Raises:
        AssertionError: if a required field is empty or the language is not
            ``python``.  Raised explicitly so the check survives ``python -O``.
    """
    checks = (
        (udf.handler != "", "udf handler should not be null"),
        (udf.body != "", "udf body should not be null"),
        (udf.language == "python", "udf language should be python"),
        (udf.modifiedTime != "", "udf modifiedTime should not be null"),
        (udf.db != "", "udf db should not be null"),
    )
    for ok, message in checks:
        if not ok:
            raise AssertionError(message)


class FunctionStatus(enum.Enum):
    """Installation state of an imported UDF package."""
    NotExist = 0
    Installing = 1
    Installed = 2


class InstallingItem:
    """Synchronisation state for one package that is being installed."""

    def __init__(self):
        # one condition per item, so finishing one package does not wake
        # the waiters of another
        self.condition = threading.Condition()
        # set to True (under ``condition``) once the installing stream is done
        self.installed = False


# key: db/func/modified_time, value: InstallingItem
INSTALLING_MAP: Dict[str, Optional[InstallingItem]] = {}
INSTALLING_MAP_LOCK = threading.RLock()


def _finishInstalling(item: Optional[InstallingItem], path: Optional[str]) -> None:
    """Mark ``item`` as installed, wake its waiters and forget ``path``."""
    if item is not None:
        with item.condition:
            item.installed = True
            item.condition.notify_all()

    if path is not None:
        with INSTALLING_MAP_LOCK:
            # a missing key and a None value are treated alike by functionStatus
            INSTALLING_MAP.pop(path, None)


def functionStatus(udf: pb2.Udf) -> Tuple[str, str, Optional[InstallingItem], FunctionStatus]:
    """Determine whether the package of an imported UDF is installed.

    Returns:
        ``(path, filename, item, status)`` where ``path`` is the install
        directory relative to ``ROOT_PATH``, ``filename`` is the last
        component of ``udf.body``, and ``item`` is the registered
        ``InstallingItem`` (``None`` when the marker file already exists).
        When the package is neither registered nor installed, a new item is
        registered and ``NotExist`` is returned, so the caller is expected to
        install it.
    """
    with INSTALLING_MAP_LOCK:
        filepath, filename = os.path.split(udf.body)
        path = os.path.join('udf', udf.db, filepath[filepath.rfind('/') + 1:], udf.modifiedTime)
        item = INSTALLING_MAP.get(path)

        if item is not None:
            status = FunctionStatus.Installed if item.installed else FunctionStatus.Installing
            return path, filename, item, status

        if os.path.isfile(os.path.join(ROOT_PATH, path, INSTALLED_LABEL)):
            return path, filename, item, FunctionStatus.Installed

        item = InstallingItem()
        item.installed = False
        INSTALLING_MAP[path] = item
        return path, filename, item, FunctionStatus.NotExist


def loadFunction(udf: pb2.Udf, filepath: str, filename: str) -> Callable:
    """Return the callable named by ``udf.handler``.

    For a non-import UDF, ``udf.body`` is executed as Python source and the
    handler is looked up among the names it defines.  For an import UDF the
    module is imported relative to the package derived from ``filepath``.
    An import UDF whose body ends with neither ``.py`` nor ``.whl`` yields
    ``None``.

    Raises:
        KeyError: if an inline body does not define ``udf.handler``.
        Exception: if a ``.whl`` handler is not of the form
            ``<module>.<function>``.
    """
    if not udf.isImport:
        # NOTE(review): executes source received in the request; the caller
        # must be trusted.
        # A dedicated namespace is used instead of locals(): since Python 3.13
        # every locals() call returns a separate snapshot, so names defined by
        # exec() would not be visible in a second locals() call.
        namespace: Dict[str, Any] = {}
        exec(udf.body, namespace)
        # get function object
        return namespace[udf.handler]

    package = filepath.replace("/", ".")
    if udf.body.endswith('.py'):
        module = importlib.import_module(f'.{filename[:-3]}', package=package)
        return getattr(module, udf.handler)
    if udf.body.endswith('.whl'):
        dot = udf.handler.rfind('.')
        if dot < 1:
            raise Exception(
                "when you import a *.whl, the handler should be in the format of '<file or module name>.<function name>'")
        module = importlib.import_module(f'.{udf.handler[:dot]}', package=package)
        return getattr(module, udf.handler[dot + 1:])


def getDataFromDataVector(v: pb2.DataVector, i: int) -> pb2.Data:
    """Return element ``i`` of ``v``; a const vector always yields element 0."""
    if v is None:
        return pb2.Data()
    return v.data[0] if v.const else v.data[i]


def getValueFromDataVector(v: pb2.DataVector, i: int) -> Any:
    """Convert element ``i`` of ``v`` to the Python value passed to the UDF.

    Returns ``None`` when the element has no value set.

    Raises:
        Exception: if ``v.type`` is not a supported data type.
    """
    data = getDataFromDataVector(v, i)

    if data.WhichOneof("val") is None:
        return None

    typ = v.type
    if typ == pb2.BOOL:
        return data.boolVal
    if typ in (pb2.INT8, pb2.INT16, pb2.INT32):
        return data.intVal
    if typ == pb2.INT64:
        return data.int64Val
    if typ in (pb2.UINT8, pb2.UINT16, pb2.UINT32):
        return data.uintVal
    if typ == pb2.UINT64:
        return data.uint64Val
    if typ == pb2.FLOAT32:
        return data.floatVal
    if typ == pb2.FLOAT64:
        return data.doubleVal
    if typ in (pb2.CHAR, pb2.VARCHAR, pb2.TEXT, pb2.UUID):
        return data.stringVal
    if typ == pb2.JSON:
        return json.loads(data.stringVal)
    if typ == pb2.TIME:
        # text form is [-]H:M:S[.ffffff]; a leading '-' negates every component
        negative = data.stringVal.startswith('-')
        h, m, s = data.stringVal[1 if negative else 0:].split(':')
        if negative:
            return datetime.timedelta(hours=-int(h), minutes=-int(m), seconds=-float(s))
        return datetime.timedelta(hours=int(h), minutes=int(m), seconds=float(s))
    if typ == pb2.DATE:
        return datetime.datetime.strptime(data.stringVal, DATE_FORMAT).date()
    if typ in (pb2.DATETIME, pb2.TIMESTAMP):
        # a non-zero scale means the text carries fractional seconds
        formatStr = DATETIME_FORMAT if v.scale == 0 else DATETIME_FORMAT_WITH_PRECISION
        return datetime.datetime.strptime(data.stringVal, formatStr)
    if typ in (pb2.DECIMAL64, pb2.DECIMAL128):
        return decimal.Decimal(data.stringVal)
    if typ in (pb2.BINARY, pb2.VARBINARY, pb2.BLOB):
        return data.bytesVal
    raise Exception("vector type error")


def defaultScale(typ: pb2.DataType, scale: Optional[int] = None) -> int:
    """Return ``scale`` when given, otherwise the default scale for ``typ``."""
    if scale is not None:
        return scale
    if typ in (pb2.FLOAT32, pb2.FLOAT64):
        return -1
    if typ in (pb2.DECIMAL64, pb2.DECIMAL128):
        return decimal.getcontext().prec
    if typ in (pb2.TIME, pb2.DATETIME, pb2.TIMESTAMP):
        return 6
    return 0


def _checkReturnType(value: Any, expected: type) -> None:
    """Raise ``AssertionError`` unless ``value`` is exactly of type ``expected``."""
    # exact match on purpose: bool must not pass as int, datetime not as date
    if type(value) is not expected:
        raise AssertionError(f'return type error, required {expected}, received {type(value)}')


def value2Data(value: Any, typ: pb2.DataType) -> pb2.Data:
    """Convert a UDF return value to ``pb2.Data`` for return type ``typ``.

    ``None`` becomes an empty ``pb2.Data``.

    Raises:
        AssertionError: if ``value`` is not exactly the Python type required
            for ``typ``.
        Exception: if ``typ`` is not a supported data type.
    """
    if value is None:
        return pb2.Data()

    if typ == pb2.BOOL:
        _checkReturnType(value, bool)
        return pb2.Data(boolVal=value)
    if typ in (pb2.INT8, pb2.INT16, pb2.INT32):
        _checkReturnType(value, int)
        return pb2.Data(intVal=value)
    if typ == pb2.INT64:
        _checkReturnType(value, int)
        return pb2.Data(int64Val=value)
    if typ in (pb2.UINT8, pb2.UINT16, pb2.UINT32):
        _checkReturnType(value, int)
        return pb2.Data(uintVal=value)
    if typ == pb2.UINT64:
        _checkReturnType(value, int)
        return pb2.Data(uint64Val=value)
    if typ == pb2.FLOAT32:
        _checkReturnType(value, float)
        return pb2.Data(floatVal=value)
    if typ == pb2.FLOAT64:
        _checkReturnType(value, float)
        return pb2.Data(doubleVal=value)
    if typ in (pb2.CHAR, pb2.VARCHAR, pb2.TEXT, pb2.UUID):
        _checkReturnType(value, str)
        return pb2.Data(stringVal=value)
    if typ == pb2.JSON:
        return pb2.Data(stringVal=json.dumps(value, ensure_ascii=False))
    if typ == pb2.TIME:
        _checkReturnType(value, datetime.timedelta)
        t: datetime.timedelta = value
        sign = ''
        if t.days < 0:
            # format the magnitude and prefix the sign
            t = t * -1
            sign = '-'
        totalMinutes, s = divmod(t.seconds, 60)
        hoursInDay, m = divmod(totalMinutes, 60)
        h = t.days * 24 + hoursInDay
        return pb2.Data(stringVal=f'{sign}{h:02d}:{m:02d}:{s:02d}.{t.microseconds:06d}')
    if typ == pb2.DATE:
        _checkReturnType(value, datetime.date)
        return pb2.Data(stringVal=str(value))
    if typ in (pb2.DATETIME, pb2.TIMESTAMP):
        _checkReturnType(value, datetime.datetime)
        return pb2.Data(stringVal=str(value))
    if typ in (pb2.DECIMAL64, pb2.DECIMAL128):
        _checkReturnType(value, decimal.Decimal)
        return pb2.Data(stringVal=str(value))
    if typ in (pb2.BINARY, pb2.VARBINARY, pb2.BLOB):
        _checkReturnType(value, bytes)
        return pb2.Data(bytesVal=value)
    raise Exception(f'unsupported return type: {type(value)}')


def run():
    """Start the gRPC server on ``ARGS.address`` and block until it terminates."""
    server = grpc.server(futures.ThreadPoolExecutor(), options=[
        ('grpc.max_send_message_length', 0x7fffffff),
        ('grpc.max_receive_message_length', 0x7fffffff)
    ])
    pb2_grpc.add_ServiceServicer_to_server(Server(), server)
    server.add_insecure_port(ARGS.address)
    server.start()
    server.wait_for_termination()


# parsed command-line arguments; filled in when the file runs as a script
ARGS = None

if __name__ == '__main__':
    parser = argparse.ArgumentParser(
        prog='python udf server',
        epilog='Copyright(r), 2023'
    )
    parser.add_argument('--address', default='[::]:50051', help='address')
    ARGS = parser.parse_args()

    log.info('python server start')

    run()