#! /usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Code generation script for class methods
to be exported as public API
"""
import argparse
import ast
import astor
import os
from pathlib import Path
import sys

from textwrap import indent

# Generated modules are written next to their source as PREFIX + basename.
PREFIX = "_generated"

HEADER = """# ***********************************************************
# ******* WARNING: AUTOGENERATED! ALL EDITS WILL BE LOST ******
# *************************************************************
from ._run import GLOBAL_RUN_CONTEXT, _NO_SEND
from ._ki import LOCALS_KEY_KI_PROTECTION_ENABLED
from ._instrumentation import Instrument

# fmt: off
"""

FOOTER = """# fmt: on
"""

# Body of every generated wrapper. The three placeholders are filled with:
# " await " or " " (depending on whether the method is async), the lookup
# path on GLOBAL_RUN_CONTEXT, and the call expression (name + pass-through
# arguments).
TEMPLATE = """locals()[LOCALS_KEY_KI_PROTECTION_ENABLED] = True
try:
    return{}GLOBAL_RUN_CONTEXT.{}.{}
except AttributeError:
    raise RuntimeError("must be called from async context")
"""


def is_function(node):
    """Check if the AST node is either a function
    or an async function
    """
    return isinstance(node, (ast.FunctionDef, ast.AsyncFunctionDef))


def is_public(node):
    """Check if the AST node has a _public decorator"""
    if not is_function(node):
        return False
    for decorator in node.decorator_list:
        # Only a bare ``@_public`` name counts; attribute forms such as
        # ``@mod._public`` are not recognized.
        if isinstance(decorator, ast.Name) and decorator.id == "_public":
            return True
    return False


def get_public_methods(tree):
    """Yield the methods marked as public.

    The function walks the given tree and yields all nodes that are
    (async) functions decorated with ``@_public``.
    """
    for node in ast.walk(tree):
        if is_public(node):
            yield node


def create_passthrough_args(funcdef):
    """Given a function definition, create a string that represents taking
    all the arguments from the function, and passing them through to another
    invocation of the same function.

    Positional-only and regular positional parameters are passed
    positionally, ``*args``/``**kwargs`` are re-splatted, and keyword-only
    parameters are passed as ``name=name``.

    Example input: ast.parse("def f(a, *, b): ...")
    Example output: "(a, b=b)"

    Example input: ast.parse("def f(a, /, b, *c, d, **e): ...")
    Example output: "(a, b, *c, d=d, **e)"
    """
    arguments = funcdef.args
    # ``posonlyargs`` only exists on Python 3.8+. Positional-only parameters
    # come before the regular ones. They must be forwarded too, otherwise
    # the generated wrapper would silently drop them.
    positional = list(getattr(arguments, "posonlyargs", [])) + list(arguments.args)
    call_args = [arg.arg for arg in positional]
    if arguments.vararg:
        call_args.append("*" + arguments.vararg.arg)
    call_args.extend(arg.arg + "=" + arg.arg for arg in arguments.kwonlyargs)
    if arguments.kwarg:
        call_args.append("**" + arguments.kwarg.arg)
    return "({})".format(", ".join(call_args))


def gen_public_wrappers_source(source_path: Path, lookup_path: str) -> str:
    """Scan the given .py file for @_public decorators, and generate
    wrapper functions.

    Args:
      source_path: the Python file to scan.
      lookup_path: dotted attribute path on GLOBAL_RUN_CONTEXT that the
        generated wrappers forward to (e.g. "runner.io_manager").

    Returns:
      The full source text of the generated module (HEADER, one wrapper per
      public method, FOOTER), joined by blank lines.
    """
    generated = [HEADER]
    source = astor.code_to_ast.parse_file(source_path)
    for method in get_public_methods(source):
        # Remove self from arguments. ``self`` is the first positional
        # parameter. It is in ``posonlyargs`` when the method declares
        # positional-only parameters, and in ``args`` otherwise. The ``or``
        # yields the list object itself, so ``del`` mutates the AST in place.
        leading = getattr(method.args, "posonlyargs", None) or method.args.args
        assert leading[0].arg == "self"
        del leading[0]

        # Remove decorators
        method.decorator_list = []

        # Create pass through arguments
        new_args = create_passthrough_args(method)

        # Remove method body without the docstring
        if ast.get_docstring(method) is None:
            del method.body[:]
        else:
            # The first entry is always the docstring
            del method.body[1:]

        # Create the function definition including the body
        func = astor.to_source(method, indent_with=" " * 4)

        # Create export function body
        template = TEMPLATE.format(
            " await " if isinstance(method, ast.AsyncFunctionDef) else " ",
            lookup_path,
            method.name + new_args,
        )

        # Assemble function definition arguments and body
        snippet = func + indent(template, " " * 4)

        # Append the snippet to the corresponding module
        generated.append(snippet)
    generated.append(FOOTER)
    return "\n\n".join(generated)


def matches_disk_files(new_files):
    """Return True if every path in ``new_files`` exists on disk and its
    current contents equal the corresponding new source text.
    """
    for new_path, new_source in new_files.items():
        if not os.path.exists(new_path):
            return False
        with open(new_path, "r", encoding="utf-8") as old_file:
            old_source = old_file.read()
        if old_source != new_source:
            return False
    return True


def process(sources_and_lookups, *, do_test):
    """Generate wrapper modules for each (source_path, lookup_path) pair.

    With ``do_test`` set, only compare the generated text against the files
    on disk and exit with status 1 if any are outdated. Otherwise (re)write
    the generated files.
    """
    new_files = {}
    for source_path, lookup_path in sources_and_lookups:
        print("Scanning:", source_path)
        new_source = gen_public_wrappers_source(source_path, lookup_path)
        dirname, basename = os.path.split(source_path)
        new_path = os.path.join(dirname, PREFIX + basename)
        new_files[new_path] = new_source
    if do_test:
        if not matches_disk_files(new_files):
            print("Generated sources are outdated. Please regenerate.")
            sys.exit(1)
        else:
            print("Generated sources are up to date.")
    else:
        for new_path, new_source in new_files.items():
            with open(new_path, "w", encoding="utf-8") as f:
                f.write(new_source)
        print("Regenerated sources successfully.")


# This is in fact run in CI, but only in the formatting check job, which
# doesn't collect coverage.
def main():  # pragma: no cover
    """Command-line entry point; must be run from the repository root."""
    parser = argparse.ArgumentParser(
        description="Generate python code for public api wrappers"
    )
    parser.add_argument(
        "--test", "-t", action="store_true", help="test if code is still up to date"
    )
    parsed_args = parser.parse_args()

    source_root = Path.cwd()
    # Double-check we found the right directory
    assert (source_root / "LICENSE").exists()
    core = source_root / "trio/_core"
    to_wrap = [
        (core / "_run.py", "runner"),
        (core / "_instrumentation.py", "runner.instruments"),
        (core / "_io_windows.py", "runner.io_manager"),
        (core / "_io_epoll.py", "runner.io_manager"),
        (core / "_io_kqueue.py", "runner.io_manager"),
    ]

    process(to_wrap, do_test=parsed_args.test)


if __name__ == "__main__":  # pragma: no cover
    main()