You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

202 lines
4.3 KiB
Python

from collections import OrderedDict, deque
from datetime import date, time, datetime
from decimal import Decimal
from fractions import Fraction
import ast
import enum
import typing
class CannotEval(Exception):
    """Exception raised by the checking helpers in this module.

    ``of_type`` and ``of_standard_types`` raise it when a value does not
    pass their type check. Both ``repr()`` and ``str()`` of an instance
    are just the class name; any constructor arguments are not shown.
    """

    def __repr__(self):
        # Look the name up on the runtime class so subclasses report
        # their own name.
        return type(self).__name__

    # str() and repr() share one implementation.
    __str__ = __repr__
def is_any(x, *args):
    """Return True if ``x`` is the very same object as one of ``args``.

    The comparison is by identity (``is``), never by equality, so an
    object that merely compares equal to a candidate does not match.
    With no candidates the result is False.
    """
    for candidate in args:
        if x is candidate:
            return True
    return False
def of_type(x, *types):
    """Return ``x`` unchanged if its exact type is one of ``types``.

    The check compares ``type(x)`` by identity against each entry of
    ``types``, so an instance of a subclass is not accepted.

    Raises:
        CannotEval: if ``type(x)`` is none of ``types``.
    """
    if not is_any(type(x), *types):
        raise CannotEval
    return x
def of_standard_types(x, *, check_dict_values: bool, deep: bool):
    """Return ``x`` unchanged if ``is_standard_types`` accepts it.

    ``check_dict_values`` and ``deep`` are keyword-only and are passed
    straight through to ``is_standard_types``.

    Raises:
        CannotEval: if ``is_standard_types`` returns a false value.
    """
    acceptable = is_standard_types(
        x, check_dict_values=check_dict_values, deep=deep
    )
    if not acceptable:
        raise CannotEval
    return x
def is_standard_types(x, *, check_dict_values: bool, deep: bool):
    """Report whether ``x`` passes the ``_is_standard_types_deep`` check.

    Returns the first element of the pair produced by
    ``_is_standard_types_deep(x, check_dict_values, deep)``. If that
    call raises ``RecursionError``, the value is treated as not
    standard and False is returned.
    """
    try:
        verdict = _is_standard_types_deep(x, check_dict_values, deep)
    except RecursionError:
        # The recursive walk hit the interpreter's recursion limit.
        return False
    return verdict[0]
def _is_standard_types_deep(x, check_dict_values: bool, deep: bool):
    """Recursive worker behind ``is_standard_types``.

    Returns a pair ``(is_standard, length)``:

    * Values whose exact type is one of the leaf types listed below
      give ``(True, 0)``.
    * Values whose exact type is one of the listed container types (or
      ``slice``) start with ``length = len(x)`` (0 for a slice). When
      ``deep`` is false the result is ``(True, length)`` without looking
      at the contents. Otherwise every item is checked recursively and
      the lengths reported by the items are added to ``length``.
    * Anything else gives ``(False, 0)``.

    All type tests are on ``type(x)`` by identity, so subclasses of the
    listed types are not accepted.

    For ``dict``/``OrderedDict`` the items checked are keys and values
    when ``check_dict_values`` is true, and keys only otherwise. For a
    slice they are ``start``, ``stop`` and ``step``.
    """
    kind = type(x)

    leaf_types = (
        str,
        int,
        bool,
        float,
        bytes,
        complex,
        date,
        time,
        datetime,
        Fraction,
        Decimal,
        type(None),
        object,
    )
    if is_any(kind, *leaf_types):
        return True, 0

    container_types = (
        tuple, frozenset, list, set, dict, OrderedDict, deque, slice,
    )
    if not is_any(kind, *container_types):
        return False, 0

    # From here on ``kind`` is exactly one of ``container_types``.
    is_slice = kind is slice

    # A slice has no len(); it starts the running count at zero.
    total = 0 if is_slice else len(x)

    assert isinstance(deep, bool)
    if not deep:
        return True, total

    if is_slice:
        children = [x.start, x.stop, x.step]
    elif check_dict_values and kind in (dict, OrderedDict):
        # Flatten (key, value) pairs so keys and values are both checked.
        children = (part for entry in x.items() for part in entry)
    else:
        # Plain iteration; for a dict this visits keys only.
        children = x

    for child in children:
        # Stop once the accumulated size exceeds the limit.
        if total > 100000:
            return False, total
        child_ok, child_length = _is_standard_types_deep(
            child, check_dict_values, deep
        )
        if not child_ok:
            return False, total
        total += child_length

    return True, total
# Member-less Enum subclass; this module uses it only as a sample value
# in ``safe_name_samples``.
class _E(enum.Enum):
    pass
# Sample class with one method of each flavour (two plain methods, a
# classmethod and a staticmethod). The bodies do nothing; this module
# uses the methods only as sample values in ``safe_name_samples``.
class _C:
    def foo(self):  # pragma: nocover
        pass

    def bar(self):  # pragma: nocover
        pass

    @classmethod
    def cm(cls):  # pragma: nocover
        pass

    @staticmethod
    def sm():  # pragma: nocover
        pass
# One sample object per kind of value whose ``__name__`` is read by
# ``safe_name``. Only the *types* of these values are used (collected
# into ``safe_name_types``); each key is the sample's ``__name__``.
safe_name_samples = {
    # Builtin callables, unbound and bound.
    "len": len,
    "append": list.append,
    "__add__": list.__add__,
    "insert": [].insert,
    "__mul__": [].__mul__,
    "fromkeys": dict.__dict__['fromkeys'],

    # Functions and methods defined in Python.
    "is_any": is_any,
    "__repr__": CannotEval.__repr__,
    "foo": _C().foo,
    "bar": _C.bar,
    "cm": _C.cm,
    "sm": _C.sm,

    # A module, a plain class and an Enum class.
    "ast": ast,
    "CannotEval": CannotEval,
    "_E": _E,
}
# Sample objects from ``typing``, keyed by attribute name. Their types
# are collected into ``typing_annotation_types``.
typing_annotation_samples = {
    alias_name: getattr(typing, alias_name)
    for alias_name in "List Dict Tuple Set Callable Mapping".split()
}
# Distinct types of the values in ``safe_name_samples``. Built through a
# set to drop duplicates, so the order of the tuple carries no meaning.
safe_name_types = tuple({
    type(sample)
    for sample in safe_name_samples.values()
})
# Distinct types of the values in ``typing_annotation_samples``. Built
# through a set to drop duplicates, so the order of the tuple carries no
# meaning.
typing_annotation_types = tuple({
    type(sample)
    for sample in typing_annotation_samples.values()
})
def eq_checking_types(a, b):
    """Compare ``a`` and ``b``, requiring them to have the same exact type.

    Returns False when ``type(a)`` is not ``type(b)``; otherwise returns
    the result of ``a == b``. So ``1`` and ``1.0``, or ``True`` and
    ``1``, are not considered equal here.
    """
    if type(a) is not type(b):
        return False
    return a == b
def ast_name(node):
    """Return the identifier carried by a Name or Attribute node.

    For ``ast.Name`` this is ``node.id``; for ``ast.Attribute`` it is
    ``node.attr`` (the part after the last dot). Any other input gives
    None.
    """
    if isinstance(node, ast.Name):
        return node.id
    if isinstance(node, ast.Attribute):
        return node.attr
    return None
def safe_name(value):
    """Return a name for ``value`` if its type is one this module trusts.

    * If ``type(value)`` is one of ``safe_name_types`` (by identity),
      the result is ``value.__name__``.
    * ``typing.Optional`` and ``typing.Union`` map to ``"Optional"`` and
      ``"Union"``.
    * If ``type(value)`` is one of ``typing_annotation_types``, the
      result is ``value.__name__`` when that attribute exists and is
      truthy, otherwise ``value._name`` (or None if that is missing too).
    * Anything else gives None.

    The result is not guaranteed to be a ``str``.
    """
    kind = type(value)

    if is_any(kind, *safe_name_types):
        return value.__name__

    if value is typing.Optional:
        return "Optional"

    if value is typing.Union:
        return "Union"

    if is_any(kind, *typing_annotation_types):
        public_name = getattr(value, "__name__", None)
        if public_name:
            return public_name
        # Fall back to the private ``_name`` attribute.
        return getattr(value, "_name", None)

    return None
def has_ast_name(value, node):
    """Check whether ``node`` names ``value``.

    Returns False unless ``safe_name(value)`` is exactly a ``str``.
    Otherwise returns whether ``ast_name(node)`` equals that name under
    ``eq_checking_types`` (same exact type and equal).
    """
    expected = safe_name(value)
    return type(expected) is str and eq_checking_types(ast_name(node), expected)
def copy_ast_without_context(x):
    """Recursively copy an AST value, leaving out every ``ctx`` field.

    * An ``ast.AST`` node is rebuilt as a new node of the same type from
      the fields named in ``_fields``, skipping ``ctx`` and any field
      the node does not actually have. Field values are copied
      recursively. Nothing outside ``_fields`` is passed to the new node.
    * A list is copied element by element into a new list.
    * Any other value is returned as is.
    """
    if isinstance(x, ast.AST):
        copied_fields = {}
        for name in x._fields:
            if name == 'ctx' or not hasattr(x, name):
                continue
            copied_fields[name] = copy_ast_without_context(getattr(x, name))
        return type(x)(**copied_fields)

    if isinstance(x, list):
        return [copy_ast_without_context(element) for element in x]

    return x
def ensure_dict(x):
    """
    Convert ``x`` to a new dict via ``dict(x)``.

    Handles invalid non-dict inputs: if ``dict(x)`` raises any
    ``Exception``, an empty dict is returned instead.
    """
    try:
        result = dict(x)
    except Exception:
        # Deliberate best-effort: any unconvertible input becomes {}.
        result = {}
    return result