You cannot select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

202 lines
4.3 KiB
Python

2 years ago
from collections import OrderedDict, deque
from datetime import date, time, datetime
from decimal import Decimal
from fractions import Fraction
import ast
import enum
import typing
class CannotEval(Exception):
def __repr__(self):
return self.__class__.__name__
__str__ = __repr__
def is_any(x, *args):
return any(
x is arg
for arg in args
)
def of_type(x, *types):
if is_any(type(x), *types):
return x
else:
raise CannotEval
def of_standard_types(x, *, check_dict_values: bool, deep: bool):
if is_standard_types(x, check_dict_values=check_dict_values, deep=deep):
return x
else:
raise CannotEval
def is_standard_types(x, *, check_dict_values: bool, deep: bool):
try:
return _is_standard_types_deep(x, check_dict_values, deep)[0]
except RecursionError:
return False
def _is_standard_types_deep(x, check_dict_values: bool, deep: bool):
typ = type(x)
if is_any(
typ,
str,
int,
bool,
float,
bytes,
complex,
date,
time,
datetime,
Fraction,
Decimal,
type(None),
object,
):
return True, 0
if is_any(typ, tuple, frozenset, list, set, dict, OrderedDict, deque, slice):
if typ in [slice]:
length = 0
else:
length = len(x)
assert isinstance(deep, bool)
if not deep:
return True, length
if check_dict_values and typ in (dict, OrderedDict):
items = (v for pair in x.items() for v in pair)
elif typ is slice:
items = [x.start, x.stop, x.step]
else:
items = x
for item in items:
if length > 100000:
return False, length
is_standard, item_length = _is_standard_types_deep(
item, check_dict_values, deep
)
if not is_standard:
return False, length
length += item_length
return True, length
return False, 0
class _E(enum.Enum):
pass
class _C:
def foo(self): pass # pragma: nocover
def bar(self): pass # pragma: nocover
@classmethod
def cm(cls): pass # pragma: nocover
@staticmethod
def sm(): pass # pragma: nocover
safe_name_samples = {
"len": len,
"append": list.append,
"__add__": list.__add__,
"insert": [].insert,
"__mul__": [].__mul__,
"fromkeys": dict.__dict__['fromkeys'],
"is_any": is_any,
"__repr__": CannotEval.__repr__,
"foo": _C().foo,
"bar": _C.bar,
"cm": _C.cm,
"sm": _C.sm,
"ast": ast,
"CannotEval": CannotEval,
"_E": _E,
}
typing_annotation_samples = {
name: getattr(typing, name)
for name in "List Dict Tuple Set Callable Mapping".split()
}
safe_name_types = tuple({
type(f)
for f in safe_name_samples.values()
})
typing_annotation_types = tuple({
type(f)
for f in typing_annotation_samples.values()
})
def eq_checking_types(a, b):
return type(a) is type(b) and a == b
def ast_name(node):
if isinstance(node, ast.Name):
return node.id
elif isinstance(node, ast.Attribute):
return node.attr
else:
return None
def safe_name(value):
typ = type(value)
if is_any(typ, *safe_name_types):
return value.__name__
elif value is typing.Optional:
return "Optional"
elif value is typing.Union:
return "Union"
elif is_any(typ, *typing_annotation_types):
return getattr(value, "__name__", None) or getattr(value, "_name", None)
else:
return None
def has_ast_name(value, node):
value_name = safe_name(value)
if type(value_name) is not str:
return False
return eq_checking_types(ast_name(node), value_name)
def copy_ast_without_context(x):
if isinstance(x, ast.AST):
kwargs = {
field: copy_ast_without_context(getattr(x, field))
for field in x._fields
if field != 'ctx'
if hasattr(x, field)
}
return type(x)(**kwargs)
elif isinstance(x, list):
return list(map(copy_ast_without_context, x))
else:
return x
def ensure_dict(x):
"""
Handles invalid non-dict inputs
"""
try:
return dict(x)
except Exception:
return {}