#!/usr/bin/env python3
# -*- coding: utf-8 -*-

# Copyright 2019 The MLIR Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# Script for updating SPIR-V dialect by scraping information from SPIR-V
# HTML and JSON specs from the Internet.
#
# For example, to define the enum attribute for SPIR-V memory model:
#
# ./gen_spirv_dialect.py --base-td-path /path/to/SPIRVBase.td \
#                        --new-enum MemoryModel
#
# The 'operand_kinds' dict of spirv.core.grammar.json contains all supported
# SPIR-V enum classes.

import argparse
import json
import re
import textwrap

SPIRV_HTML_SPEC_URL = 'https://www.khronos.org/registry/spir-v/specs/unified1/SPIRV.html'
SPIRV_JSON_SPEC_URL = 'https://raw.githubusercontent.com/KhronosGroup/SPIRV-Headers/master/include/spirv/unified1/spirv.core.grammar.json'

AUTOGEN_OP_DEF_SEPARATOR = '\n// -----\n\n'
AUTOGEN_ENUM_SECTION_MARKER = 'enum section. Generated from SPIR-V spec; DO NOT MODIFY!'
AUTOGEN_OPCODE_SECTION_MARKER = (
    'opcode section. Generated from SPIR-V spec; DO NOT MODIFY!')

# Upper bound (in seconds) for each HTTP request, so that a stalled
# connection cannot hang the script forever.
_HTTP_TIMEOUT_SECONDS = 60


def get_spirv_doc_from_html_spec():
    """Extracts instruction documentation from SPIR-V HTML spec.

    Downloads SPIRV_HTML_SPEC_URL and walks every table inside the 'sect3'
    divs that are siblings of the instructions section header.

    Returns:
      - A dict mapping the id of the anchor in each instruction table (used
        as the opname) to the documentation text following its first line.
    """
    # Third-party dependencies are imported lazily so that the pure
    # formatting helpers in this module stay importable without them.
    import requests
    from bs4 import BeautifulSoup

    response = requests.get(SPIRV_HTML_SPEC_URL, timeout=_HTTP_TIMEOUT_SECONDS)
    # Fail loudly instead of trying to parse an HTTP error page.
    response.raise_for_status()
    spirv = BeautifulSoup(response.content, 'html.parser')

    section_anchor = spirv.find(
        'h3', {'id': '_a_id_instructions_a_instructions'})

    doc = {}
    for section in section_anchor.parent.find_all('div', {'class': 'sect3'}):
        for table in section.find_all('table'):
            inst_html = table.tbody.tr.td.p
            opname = inst_html.a['id']
            # Ignore the first line, which is just the opname.
            doc[opname] = inst_html.text.split('\n', 1)[1].strip()

    return doc


def get_spirv_grammar_from_json_spec():
    """Extracts operand kind and instruction grammar from SPIR-V JSON spec.

    Returns:
      - A list containing all operand kinds' grammar
      - A list containing all instructions' grammar
    """
    # Imported lazily; see get_spirv_doc_from_html_spec().
    import requests

    response = requests.get(SPIRV_JSON_SPEC_URL, timeout=_HTTP_TIMEOUT_SECONDS)
    # Fail loudly instead of feeding an HTTP error page to the JSON parser.
    response.raise_for_status()
    spirv = json.loads(response.content)

    return spirv['operand_kinds'], spirv['instructions']


def split_list_into_sublists(items, offset):
    """Split the list of items into multiple sublists.

    Each item is accounted for as its length plus 2 (for the ', ' separator);
    a new sublist is started whenever adding an item would push the running
    total of the current sublist above 80.

    Arguments:
      - items: a list of strings
      - offset: accepted for interface compatibility; the current
        implementation does not use it in the length calculation.

    Returns:
      - A list of non-empty sublists that together contain all items in
        their original order.
    """
    chunks = []
    chunk = []
    chunk_len = 0

    for item in items:
        item_len = len(item) + 2
        # Only flush a non-empty chunk: an oversized first item must not
        # produce an empty leading sublist.
        if chunk and chunk_len + item_len > 80:
            chunks.append(chunk)
            chunk = []
            chunk_len = 0
        chunk.append(item)
        chunk_len += item_len

    if chunk:
        chunks.append(chunk)

    return chunks


def uniquify(lst, equality_fn):
    """Returns a list after pruning duplicate elements.

    Arguments:
      - lst: List whose elements are to be uniqued.
      - equality_fn: Function mapping an element to a hashable key; elements
        with the same key are considered duplicates.

    Returns:
      - A list with all duplicates removed. The order of elements is the same
        as in the original list, with only the first occurrence of duplicates
        retained.
    """
    keys = set()
    unique_lst = []
    for elem in lst:
        key = equality_fn(elem)
        if key not in keys:
            unique_lst.append(elem)
            keys.add(key)
    return unique_lst


def gen_operand_kind_enum_attr(operand_kind):
    """Generates the TableGen I32EnumAttr definition for the given operand kind.

    Arguments:
      - operand_kind: a dict with the operand kind's JSON grammar; kinds
        without an 'enumerants' entry are not enums and yield ('', '').

    Returns:
      - The operand kind's name
      - A string containing the TableGen I32EnumAttr definition
    """
    if 'enumerants' not in operand_kind:
        return '', ''

    kind_name = operand_kind['kind']
    # The acronym is made of the uppercase ASCII letters of the kind name.
    kind_acronym = ''.join(c for c in kind_name if 'A' <= c <= 'Z')
    kind_cases = [(case['enumerant'], case['value'])
                  for case in operand_kind['enumerants']]
    # Several enumerants may share one value; keep only the first of each.
    kind_cases = uniquify(kind_cases, lambda x: x[1])
    max_len = max((len(symbol) for (symbol, _) in kind_cases), default=0)

    # Generate the definition for each enum case
    fmt_str = 'def SPV_{acronym}_{symbol} {colon:>{offset}} '\
              'I32EnumAttrCase<"{symbol}", {value}>;'
    case_defs = [
        fmt_str.format(
            acronym=kind_acronym,
            symbol=symbol,
            value=value,
            colon=':',
            # Right-align the colon so that all cases line up.
            offset=(max_len + 1 - len(symbol)))
        for (symbol, value) in kind_cases
    ]
    case_defs = '\n'.join(case_defs)

    # Generate the list of enum case names
    fmt_str = 'SPV_{acronym}_{symbol}'
    case_names = [fmt_str.format(acronym=kind_acronym, symbol=symbol)
                  for (symbol, _) in kind_cases]

    # Split them into sublists and concatenate into multiple lines
    case_names = split_list_into_sublists(case_names, 6)
    case_names = ['{:6}'.format('') + ', '.join(sublist)
                  for sublist in case_names]
    case_names = ',\n'.join(case_names)

    # Generate the enum attribute definition
    enum_attr = 'def SPV_{name}Attr :\n '\
        'I32EnumAttr<"{name}", "valid SPIR-V {name}", [\n{cases}\n ]> {{\n'\
        ' let returnType = "::mlir::spirv::{name}";\n'\
        ' let convertFromStorage = '\
        '"static_cast<::mlir::spirv::{name}>($_self.getInt())";\n'\
        ' let cppNamespace = "::mlir::spirv";\n}}'.format(
            name=kind_name, cases=case_names)
    return kind_name, case_defs + '\n\n' + enum_attr


def gen_opcode(instructions):
    """Generates the TableGen definition to map opname to opcode.

    Arguments:
      - instructions: a list of dicts, each with 'opname' and 'opcode'

    Returns:
      - A string containing the TableGen SPV_OpCode definition
    """
    max_len = max((len(inst['opname']) for inst in instructions), default=0)
    def_fmt_str = 'def SPV_OC_{name} {colon:>{offset}} '\
                  'I32EnumAttrCase<"{name}", {value}>;'
    opcode_defs = [
        def_fmt_str.format(
            name=inst['opname'],
            value=inst['opcode'],
            colon=':',
            # Right-align the colon so that all cases line up.
            offset=(max_len + 1 - len(inst['opname'])))
        for inst in instructions
    ]
    opcode_str = '\n'.join(opcode_defs)

    decl_fmt_str = 'SPV_OC_{name}'
    opcode_list = [
        decl_fmt_str.format(name=inst['opname']) for inst in instructions
    ]
    opcode_list = split_list_into_sublists(opcode_list, 6)
    opcode_list = [
        '{:6}'.format('') + ', '.join(sublist) for sublist in opcode_list
    ]
    opcode_list = ',\n'.join(opcode_list)
    enum_attr = 'def SPV_OpcodeAttr :\n'\
                ' I32EnumAttr<"{name}", "valid SPIR-V instructions", [\n'\
                '{lst}\n'\
                ' ]> {{\n'\
                ' let returnType = "::mlir::spirv::{name}";\n'\
                ' let convertFromStorage = '\
                '"static_cast<::mlir::spirv::{name}>($_self.getInt())";\n'\
                ' let cppNamespace = "::mlir::spirv";\n}}'.format(
                    name='Opcode', lst=opcode_list)
    return opcode_str + '\n\n' + enum_attr


def update_td_opcodes(path, instructions, filter_list):
    """Updates SPIRVBase.td with new generated opcode cases.

    The file must contain AUTOGEN_OPCODE_SECTION_MARKER exactly twice; the
    text between the two markers is replaced. The caller's filter_list is
    not modified.

    Arguments:
      - path: the path to SPIRVBase.td
      - instructions: a list containing all SPIR-V instructions' grammar
      - filter_list: a list containing new opnames to add
    """
    with open(path, 'r') as f:
        content = f.read()

    content = content.split(AUTOGEN_OPCODE_SECTION_MARKER)
    assert len(content) == 3

    # Extend opcode list with existing list. The slice drops the
    # 'def SPV_OC_' prefix of each match.
    existing_opcodes = [
        k[11:] for k in re.findall(r'def SPV_OC_\w+', content[1])]
    wanted = set(filter_list)
    wanted.update(existing_opcodes)

    # Generate the opcode for all wanted instructions in SPIR-V
    filter_instrs = [inst for inst in instructions
                     if inst['opname'] in wanted]
    # Sort instruction based on opcode
    filter_instrs.sort(key=lambda inst: inst['opcode'])
    opcode = gen_opcode(filter_instrs)

    # Substitute the opcode
    content = content[0] + AUTOGEN_OPCODE_SECTION_MARKER + '\n\n' + \
        opcode + '\n\n// End ' + AUTOGEN_OPCODE_SECTION_MARKER \
        + content[2]

    with open(path, 'w') as f:
        f.write(content)


def update_td_enum_attrs(path, operand_kinds, filter_list):
    """Updates SPIRVBase.td with new generated enum definitions.

    The file must contain AUTOGEN_ENUM_SECTION_MARKER exactly twice; the
    text between the two markers is replaced. The caller's filter_list is
    not modified.

    Arguments:
      - path: the path to SPIRVBase.td
      - operand_kinds: a list containing all operand kinds' grammar
      - filter_list: a list containing new enums to add
    """
    with open(path, 'r') as f:
        content = f.read()

    content = content.split(AUTOGEN_ENUM_SECTION_MARKER)
    assert len(content) == 3

    # Extend filter list with existing enum definitions. The slice drops the
    # 'def SPV_' prefix and the 'Attr' suffix of each match.
    existing_kinds = [
        k[8:-4] for k in re.findall(r'def SPV_\w+Attr', content[1])]
    wanted = set(filter_list)
    wanted.update(existing_kinds)

    # Generate definitions for all enums in filter list
    defs = [gen_operand_kind_enum_attr(kind)
            for kind in operand_kinds if kind['kind'] in wanted]
    # Operand kinds without enumerants produce no definition; skip them so
    # that they do not leave stray blank lines in the section.
    defs = [enum for enum in defs if enum[1]]
    # Sort alphabetically according to enum name
    defs.sort(key=lambda enum: enum[0])
    # Only keep the definitions from now on
    defs = [enum[1] for enum in defs]

    # Substitute the old section
    content = content[0] + AUTOGEN_ENUM_SECTION_MARKER + '\n\n' + \
        '\n\n'.join(defs) + "\n\n// End " + AUTOGEN_ENUM_SECTION_MARKER \
        + content[2]

    with open(path, 'w') as f:
        f.write(content)


def snake_casify(name):
    """Normalizes the given name for use as an ODS argument name.

    All runs of non-word characters are removed and the remaining text is
    lowercased. Because whitespace is removed as well, separate words end up
    fused together (for example "'Operand 1'" becomes "operand1") rather
    than joined with underscores.
    """
    name = re.sub(r'\W+', '', name).split()
    name = [s.lower() for s in name]
    return '_'.join(name)


def map_spec_operand_to_ods_argument(operand):
    """Maps an operand in SPIR-V JSON spec to an op argument in ODS.

    Arguments:
      - A dict containing the operand's kind, quantifier, and name

    Returns:
      - A string containing both the type and name for the argument

    Raises:
      - AssertionError for result operands, for kinds that are not
        implemented yet, and for unsupported quantifiers.
    """
    kind = operand['kind']
    quantifier = operand.get('quantifier', '')

    # These instruction "operands" are for encoding the results; they should
    # not be handled here.
    assert kind != 'IdResultType', 'unexpected to handle "IdResultType" kind'
    assert kind != 'IdResult', 'unexpected to handle "IdResult" kind'

    if kind == 'IdRef':
        if quantifier == '':
            arg_type = 'SPV_Type'
        elif quantifier == '?':
            arg_type = 'SPV_Optional<SPV_Type>'
        else:
            arg_type = 'Variadic<SPV_Type>'
    elif kind in ('IdMemorySemantics', 'IdScope'):
        # TODO(antiagainst): Need to further constrain 'IdMemorySemantics'
        # and 'IdScope' given that they should be generated from OpConstant.
        assert quantifier == '', ('unexpected to have optional/variadic memory '
                                  'semantics or scope <id>')
        arg_type = 'I32'
    elif kind == 'LiteralInteger':
        if quantifier == '':
            arg_type = 'I32Attr'
        elif quantifier == '?':
            arg_type = 'OptionalAttr<I32Attr>'
        else:
            arg_type = 'OptionalAttr<I32ArrayAttr>'
    elif kind in ('LiteralString',
                  'LiteralContextDependentNumber',
                  'LiteralExtInstInteger',
                  'LiteralSpecConstantOpInteger',
                  'PairLiteralIntegerIdRef',
                  'PairIdRefLiteralInteger',
                  'PairIdRefIdRef'):
        # Raised explicitly so the check survives running with -O.
        raise AssertionError('"{}" kind unimplemented'.format(kind))
    else:
        # The rest are all enum operands that we represent with op attributes.
        assert quantifier != '*', 'unexpected to have variadic enum attribute'
        arg_type = 'SPV_{}Attr'.format(kind)
        if quantifier == '?':
            arg_type = 'OptionalAttr<{}>'.format(arg_type)

    # Fall back to the lowercased kind when the operand carries no name.
    name = operand.get('name', '')
    name = snake_casify(name) if name else kind.lower()

    return '{}:${}'.format(arg_type, name)


def get_op_definition(instruction, doc, existing_info, inst_category):
    """Generates the TableGen op definition for the given SPIR-V instruction.

    Arguments:
      - instruction: the instruction's SPIR-V JSON grammar
      - doc: the instruction's SPIR-V HTML doc; its first line is used as the
        summary and the remaining lines as the description
      - existing_info: a dict containing potential manually specified sections
        for this instruction
      - inst_category: the suffix of the TableGen base class ('SPV_' +
        inst_category); arguments and results are only emitted for 'Op'

    Returns:
      - A string containing the TableGen op definition
    """
    fmt_str = ('def SPV_{opname}Op : '
               'SPV_{inst_category}<"{opname}"{category_args}[{traits}]> '
               '{{\n let summary = {summary};\n\n let description = '
               '[{{\n{description}\n\n ### Custom assembly '
               'form\n{assembly}}}];\n')
    if inst_category == 'Op':
        fmt_str += '\n let arguments = (ins{args});\n\n'\
                   ' let results = (outs{results});\n\n'

    fmt_str += '{extras}'\
               '}}\n'

    # Drop the leading 'Op' of the SPIR-V opname.
    opname = instruction['opname'][2:]
    category_args = existing_info.get('category_args', None)
    if category_args is None:
        category_args = ', '

    # partition() also copes with a doc that only has a summary line.
    summary, _, description = doc.partition('\n')
    wrapper = textwrap.TextWrapper(
        width=76, initial_indent=' ', subsequent_indent=' ')

    # Format summary. If the summary can fit in the same line, we print it out
    # as a "-quoted string; otherwise, wrap the lines using "[{...}]".
    summary = summary.strip()
    if len(summary) + len(' let summary = "";') <= 80:
        summary = '"{}"'.format(summary)
    else:
        summary = '[{{\n{}\n }}]'.format(wrapper.fill(summary))

    # Wrap description; empty lines are dropped and the remaining paragraphs
    # are separated by one blank line.
    description = description.split('\n')
    description = [wrapper.fill(line) for line in description if line]
    description = '\n\n'.join(description)

    operands = instruction.get('operands', [])

    # Set op's result
    results = ''
    if len(operands) > 0 and operands[0]['kind'] == 'IdResultType':
        results = '\n SPV_Type:$result\n '
        operands = operands[1:]
    # A manually specified results section takes precedence.
    if 'results' in existing_info:
        results = existing_info['results']

    # Ignore the operand standing for the result <id>
    if len(operands) > 0 and operands[0]['kind'] == 'IdResult':
        operands = operands[1:]

    # Set op's arguments; a manually specified section takes precedence.
    arguments = existing_info.get('arguments', None)
    if arguments is None:
        arguments = [map_spec_operand_to_ods_argument(o) for o in operands]
        arguments = ',\n '.join(arguments)
        if arguments:
            # Prepend and append whitespace for formatting
            arguments = '\n {}\n '.format(arguments)

    assembly = existing_info.get('assembly', None)
    if assembly is None:
        assembly = '\n ``` {.ebnf}\n'\
                   ' [TODO]\n'\
                   ' ```\n\n'\
                   ' For example:\n\n'\
                   ' ```\n'\
                   ' [TODO]\n'\
                   ' ```\n '

    return fmt_str.format(
        opname=opname,
        category_args=category_args,
        inst_category=inst_category,
        traits=existing_info.get('traits', ''),
        summary=summary,
        description=description,
        assembly=assembly,
        args=arguments,
        results=results,
        extras=existing_info.get('extras', ''))


def get_string_between(base, start, end):
    """Extracts a substring with a specified start and end from a string.

    Arguments:
      - base: string to extract from.
      - start: string to use as the start of the substring.
      - end: string to use as the end of the substring.

    Returns:
      - The text between the first occurrence of start and the first
        following occurrence of end; empty if start was not found.
      - The part of the base after the end of the substring. Is the base
        string itself if the substring wasn't found.

    Raises:
      - AssertionError if start is found but end is not.
    """
    split = base.split(start, 1)
    if len(split) == 2:
        rest = split[1].split(end, 1)
        assert len(rest) == 2, \
            'cannot find end "{end}" while extracting substring '\
            'starting with {start}'.format(start=start, end=end)
        # rest[0] is returned untouched: it cannot contain `end` as a whole,
        # and stripping the individual characters of `end` from its tail
        # would eat legitimate content (e.g. a trailing ')' before ');\n').
        return rest[0], rest[1]
    return '', split[0]


def extract_td_op_info(op_def):
    """Extracts potentially manually specified sections in op's definition.

    Arguments:
      - op_def: A string containing the op's TableGen definition

    Returns:
      - A dict containing potential manually specified sections
    """
    # Get opname. The slice drops the 'def SPV_' prefix and the 'Op' suffix.
    opname = [o[8:-2] for o in re.findall(r'def SPV_\w+Op', op_def)]
    assert len(opname) == 1, 'more than one ops in the same section!'
    opname = opname[0]

    # Get category_args: the text between the quoted op name and the trait
    # list inside the template parameters of the base class.
    op_tmpl_params = op_def.split('<', 1)[1].split('>', 1)[0]
    _, rest = get_string_between(op_tmpl_params, '"', '"')
    category_args = rest.split('[', 1)[0]

    # Get traits
    traits, _ = get_string_between(rest, '[', ']')

    # Get custom assembly form
    assembly, rest = get_string_between(op_def, '### Custom assembly form\n',
                                        '}];\n')

    # Get arguments
    args, rest = get_string_between(rest, ' let arguments = (ins', ');\n')

    # Get results
    results, rest = get_string_between(rest, ' let results = (outs', ');\n')

    # Whatever remains (minus surrounding spaces, braces and newlines) is
    # treated as manually written extra content.
    extras = rest.strip(' }\n')
    if extras:
        extras = '\n {}\n'.format(extras)

    return {
        # Prefix with 'Op' to make it consistent with SPIR-V spec
        'opname': 'Op{}'.format(opname),
        'category_args': category_args,
        'traits': traits,
        'assembly': assembly,
        'arguments': args,
        'results': results,
        'extras': extras
    }


def update_td_op_definitions(path, instructions, docs, filter_list,
                             inst_category):
    """Updates SPIRVOps.td with newly generated op definition.

    The file is rewritten in place; nothing is returned. The caller's
    filter_list is not modified.

    Arguments:
      - path: path to SPIRVOps.td
      - instructions: SPIR-V JSON grammar for all instructions
      - docs: SPIR-V HTML doc for all instructions
      - filter_list: a list containing new opnames to include
      - inst_category: forwarded to get_op_definition() for every op

    Raises:
      - ValueError if an op has no grammar or no documentation entry.
    """
    with open(path, 'r') as f:
        content = f.read()

    # Split the file into chunks, each containing one op.
    ops = content.split(AUTOGEN_OP_DEF_SEPARATOR)
    header = ops[0]
    footer = ops[-1]
    ops = ops[1:-1]

    # For each existing op, extract the manually-written sections out to retain
    # them when re-generating the ops. Also include the existing ops in the
    # set of ops to generate.
    op_info_dict = {}
    for op in ops:
        info_dict = extract_td_op_info(op)
        op_info_dict[info_dict['opname']] = info_dict
    opnames = sorted(set(filter_list) | set(op_info_dict))

    # Index the grammar by opname, keeping the first entry for each name.
    grammar_by_opname = {}
    for inst in instructions:
        grammar_by_opname.setdefault(inst['opname'], inst)

    op_defs = []
    for opname in opnames:
        # Find the grammar spec for this op
        instruction = grammar_by_opname.get(opname)
        if instruction is None:
            raise ValueError(
                'no grammar found for instruction "{}"'.format(opname))
        if opname not in docs:
            raise ValueError(
                'no documentation found for instruction "{}"'.format(opname))
        op_defs.append(
            get_op_definition(instruction, docs[opname],
                              op_info_dict.get(opname, {}), inst_category))

    # Substitute the old op definitions
    op_defs = [header] + op_defs + [footer]
    content = AUTOGEN_OP_DEF_SEPARATOR.join(op_defs)

    with open(path, 'w') as f:
        f.write(content)


def main():
    """Parses the command line and updates the requested .td files."""
    cli_parser = argparse.ArgumentParser(
        description='Update SPIR-V dialect definitions using SPIR-V spec')

    cli_parser.add_argument(
        '--base-td-path',
        dest='base_td_path',
        type=str,
        default=None,
        help='Path to SPIRVBase.td')
    cli_parser.add_argument(
        '--op-td-path',
        dest='op_td_path',
        type=str,
        default=None,
        help='Path to SPIRVOps.td')

    cli_parser.add_argument(
        '--new-enum',
        dest='new_enum',
        type=str,
        default=None,
        help='SPIR-V enum to be added to SPIRVBase.td')
    cli_parser.add_argument(
        '--new-opcodes',
        dest='new_opcodes',
        type=str,
        default=None,
        nargs='*',
        help='update SPIR-V opcodes in SPIRVBase.td')
    cli_parser.add_argument(
        '--new-inst',
        dest='new_inst',
        type=str,
        default=None,
        nargs='*',
        help='SPIR-V instruction to be added to ops file')
    cli_parser.add_argument(
        '--inst-category',
        dest='inst_category',
        type=str,
        default='Op',
        help='SPIR-V instruction category used for choosing '
             'a suitable .td file and TableGen common base '
             'class to define this op')

    args = cli_parser.parse_args()

    # Validate option combinations up front (and without `assert`, which is
    # stripped under -O) so that mistakes are reported before any download.
    needs_base_td = (args.new_enum is not None or
                     args.new_opcodes is not None)
    if needs_base_td and args.base_td_path is None:
        cli_parser.error(
            '--base-td-path is required with --new-enum and --new-opcodes')
    if args.new_inst is not None and args.op_td_path is None:
        cli_parser.error('--op-td-path is required with --new-inst')

    operand_kinds, instructions = get_spirv_grammar_from_json_spec()

    # Define new enum attr
    if args.new_enum is not None:
        filter_list = [args.new_enum] if args.new_enum else []
        update_td_enum_attrs(args.base_td_path, operand_kinds, filter_list)

    # Define new opcode
    if args.new_opcodes is not None:
        update_td_opcodes(args.base_td_path, instructions, args.new_opcodes)

    # Define new op
    if args.new_inst is not None:
        docs = get_spirv_doc_from_html_spec()
        update_td_op_definitions(args.op_td_path, instructions, docs,
                                 args.new_inst, args.inst_category)
        print('Done. Note that this script just generates a template; '
              'please read the spec and update traits, arguments, and '
              'results accordingly.')


if __name__ == '__main__':
    main()