You cannot select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

193 lines
5.8 KiB
Python

#! /usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Code generation script for class methods
to be exported as public API
"""
import argparse
import ast
import astor
import os
from pathlib import Path
import sys
from textwrap import indent
# Prepended to a source file's basename to name its generated module,
# e.g. ``_run.py`` -> ``_generated_run.py``.
PREFIX = "_generated"

# Emitted at the top of every generated module: a do-not-edit banner plus the
# imports that the generated code (wrapper bodies and copied signatures) may
# reference. ``# fmt: off`` keeps code formatters away from the generated code.
HEADER = """# ***********************************************************
# ******* WARNING: AUTOGENERATED! ALL EDITS WILL BE LOST ******
# *************************************************************
from ._run import GLOBAL_RUN_CONTEXT, _NO_SEND
from ._ki import LOCALS_KEY_KI_PROTECTION_ENABLED
from ._instrumentation import Instrument
# fmt: off
"""

# Emitted at the end of every generated module; re-enables formatting.
FOOTER = """# fmt: on
"""

# Body of each generated wrapper function. Placeholders, in order:
#   1. the call prefix: " await " for async functions, " " otherwise;
#   2. the attribute path under GLOBAL_RUN_CONTEXT (e.g. "runner.io_manager");
#   3. the method name followed by its pass-through argument list.
# The whole template is indented one level and appended after the function
# signature, so the statements under ``try:``/``except`` MUST keep their own
# extra indentation or the generated module is not valid Python.
# The AttributeError branch presumably fires when no run context is active in
# this thread (GLOBAL_RUN_CONTEXT lacks the attribute) -- TODO confirm.
TEMPLATE = """locals()[LOCALS_KEY_KI_PROTECTION_ENABLED] = True
try:
    return{}GLOBAL_RUN_CONTEXT.{}.{}
except AttributeError:
    raise RuntimeError("must be called from async context")
"""
def is_function(node):
    """Return True when *node* is a ``def`` or ``async def`` AST node."""
    return isinstance(node, (ast.FunctionDef, ast.AsyncFunctionDef))
def is_public(node):
    """Return True when *node* is a (possibly async) function carrying a bare
    ``@_public`` decorator.

    Only a plain name matches: attribute forms such as ``@mod._public`` and
    called forms such as ``@_public()`` are not ``ast.Name`` nodes and are
    therefore not recognised.
    """
    if is_function(node):
        return any(
            isinstance(decorator, ast.Name) and decorator.id == "_public"
            for decorator in node.decorator_list
        )
    return False
def get_public_methods(tree):
    """Yield every function node in *tree* that is marked ``@_public``.

    The entire tree is searched with ``ast.walk``, so functions nested inside
    classes are found. Matching nodes are produced lazily, in the order
    ``ast.walk`` visits them.
    """
    yield from filter(is_public, ast.walk(tree))
def create_passthrough_args(funcdef):
    """Given a function definition, create a string that represents taking all
    the arguments from the function, and passing them through to another
    invocation of the same function.

    Positional-only and ordinary positional parameters are forwarded by
    position, ``*args`` is re-splatted, keyword-only parameters are forwarded
    as ``name=name``, and ``**kwargs`` is re-splatted.

    Example input: ast.parse("def f(a, *, b): ...").body[0]
    Example output: "(a, b=b)"
    """
    arguments = funcdef.args
    # ``posonlyargs`` only exists on Python 3.8+ ASTs; older interpreters
    # cannot express positional-only parameters, so default to none.
    call_args = [arg.arg for arg in getattr(arguments, "posonlyargs", [])]
    call_args.extend(arg.arg for arg in arguments.args)
    if arguments.vararg:
        call_args.append("*" + arguments.vararg.arg)
    call_args.extend(arg.arg + "=" + arg.arg for arg in arguments.kwonlyargs)
    if arguments.kwarg:
        call_args.append("**" + arguments.kwarg.arg)
    return "({})".format(", ".join(call_args))
def gen_public_wrappers_source(source_path: Path, lookup_path: str) -> str:
    """Scan the given .py file for @_public decorators, and generate wrapper
    functions.

    Each ``@_public`` method becomes a function with the same signature minus
    ``self`` and with the method's docstring (if any). Its body, built from
    TEMPLATE, forwards every argument to the same-named method on
    ``GLOBAL_RUN_CONTEXT.<lookup_path>``.

    Args:
        source_path: Python source file to scan.
        lookup_path: dotted attribute path, relative to GLOBAL_RUN_CONTEXT,
            of the object that owns the methods (e.g. ``"runner"``).

    Returns:
        The text of the generated module: HEADER, one snippet per public
        method, then FOOTER, joined with blank lines.

    Raises:
        ValueError: if a ``@_public`` function does not take ``self`` as its
            first (non-positional-only) positional parameter.
    """
    generated = [HEADER]
    source = astor.code_to_ast.parse_file(source_path)
    for method in get_public_methods(source):
        # The wrapper drops ``self`` and looks the bound object up at call
        # time instead. This is an explicit check rather than an ``assert`` so
        # it still runs under ``python -O``. Without it, the ``del`` below
        # would silently remove the wrong parameter.
        positional = method.args.args
        if not positional or positional[0].arg != "self":
            raise ValueError(
                "{}: @_public function {!r} must take 'self' as its first "
                "parameter".format(source_path, method.name)
            )
        del positional[0]
        # Decorators are not copied onto the generated function.
        method.decorator_list = []
        # Forwarding call arguments; computed after ``self`` is removed.
        new_args = create_passthrough_args(method)
        # Strip the method body, keeping only the docstring if there is one.
        if ast.get_docstring(method) is None:
            del method.body[:]
        else:
            # ast.get_docstring only recognises a docstring in body[0].
            del method.body[1:]
        # Signature (plus docstring, if kept) as source text.
        func = astor.to_source(method, indent_with=" " * 4)
        # Forwarding body for this method.
        template = TEMPLATE.format(
            " await " if isinstance(method, ast.AsyncFunctionDef) else " ",
            lookup_path,
            method.name + new_args,
        )
        # Indent the template one level so it becomes the function's body.
        snippet = func + indent(template, " " * 4)
        generated.append(snippet)
    generated.append(FOOTER)
    return "\n\n".join(generated)
def matches_disk_files(new_files):
    """Return True only if every generated file already exists on disk with
    exactly the expected text.

    *new_files* maps output path -> expected contents. An empty mapping is
    trivially up to date.
    """
    for target, expected in new_files.items():
        if not os.path.exists(target):
            return False
        with open(target, "r", encoding="utf-8") as handle:
            on_disk = handle.read()
        if on_disk != expected:
            return False
    return True
def process(sources_and_lookups, *, do_test):
    """Generate a wrapper module for every ``(source_path, lookup_path)`` pair.

    The output for ``dir/name.py`` is ``dir/<PREFIX>name.py``. When
    ``do_test`` is true nothing is written: the generated text is compared
    with the files on disk, and the process exits with status 1 if any file
    differs or is missing. Otherwise every output file is (over)written.
    """
    generated = {}
    for src, lookup in sources_and_lookups:
        print("Scanning:", src)
        text = gen_public_wrappers_source(src, lookup)
        folder, filename = os.path.split(src)
        generated[os.path.join(folder, PREFIX + filename)] = text

    if not do_test:
        for target, text in generated.items():
            with open(target, "w", encoding="utf-8") as out:
                out.write(text)
        print("Regenerated sources successfully.")
        return

    if matches_disk_files(generated):
        print("Generated sources are up to date.")
    else:
        print("Generated sources are outdated. Please regenerate.")
        sys.exit(1)
# This is in fact run in CI, but only in the formatting check job, which
# doesn't collect coverage.
def main():  # pragma: no cover
    """Command-line entry point.

    Regenerates the public-API wrapper modules for the trio core files listed
    below. With ``--test``/``-t``, it instead only checks that the files on
    disk are up to date. Must be run from the repository root.
    """
    parser = argparse.ArgumentParser(
        description="Generate python code for public api wrappers"
    )
    parser.add_argument(
        "--test", "-t", action="store_true", help="test if code is still up to date"
    )
    parsed_args = parser.parse_args()

    source_root = Path.cwd()
    # Double-check we found the right directory. This is an explicit check
    # rather than an ``assert`` so it is not stripped under ``python -O``, and
    # so a wrong working directory yields a readable message (exit status 1).
    if not (source_root / "LICENSE").exists():
        sys.exit(
            "No LICENSE file found in {}; run this script from the "
            "repository root.".format(source_root)
        )

    core = source_root / "trio" / "_core"
    # (source file, attribute path under GLOBAL_RUN_CONTEXT that owns its
    # @_public methods)
    to_wrap = [
        (core / "_run.py", "runner"),
        (core / "_instrumentation.py", "runner.instruments"),
        (core / "_io_windows.py", "runner.io_manager"),
        (core / "_io_epoll.py", "runner.io_manager"),
        (core / "_io_kqueue.py", "runner.io_manager"),
    ]
    process(to_wrap, do_test=parsed_args.test)
# Run the generator only when this file is executed as a script, not when it
# is imported as a module.
if __name__ == "__main__":  # pragma: no cover
    main()