You cannot select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

472 lines
12 KiB
Python

2 years ago
"""Pickle related utilities. Perhaps this should be called 'can'."""
# Copyright (c) IPython Development Team.
# Distributed under the terms of the Modified BSD License.
import typing
import warnings
# Importing this module is deprecated. stacklevel=2 presumably makes the
# warning point at the code doing the import rather than at this line.
warnings.warn(
    "ipykernel.pickleutil is deprecated. It has moved to ipyparallel.",
    category=DeprecationWarning,
    stacklevel=2,
)
import copy
import pickle
import sys
from types import FunctionType
# This registers a hook when it's imported
from ipyparallel.serialize import codeutil # noqa F401
from traitlets.log import get_logger
from traitlets.utils.importstring import import_item
# Aliases kept under their historical names; in this module ``buffer`` is
# simply ``memoryview`` and ``class_type`` is ``type``.
buffer, class_type = memoryview, type
# Protocol passed to ``pickle.dumps`` whenever this module pickles data itself.
PICKLE_PROTOCOL = pickle.DEFAULT_PROTOCOL
def _get_cell_type(a=None):
    """Return the type of a closure cell.

    That type is not reliably importable, so a throwaway closure over ``a`` is
    built and the type of its single cell is returned.
    """
    cells = (lambda: a).__closure__
    return type(cells[0])  # type:ignore[index]


# The closure-cell type, computed once at import time.
cell_type = _get_cell_type()
# -------------------------------------------------------------------------------
# Functions
# -------------------------------------------------------------------------------
def interactive(f):
    """Decorator making a function appear as if it were defined interactively.

    For plain Python functions, a new function object is built from the same
    code object but with ``__main__.__dict__`` as its globals. Global names
    therefore resolve in the user namespace instead of the defining module.
    Positional defaults, keyword-only defaults and the docstring are carried
    over. Functions with a closure are not supported: ``FunctionType`` raises
    because no closure is passed.

    Whatever object is passed in, its ``__module__`` is set to ``"__main__"``
    so that uncanning treats it as interactively defined.

    Parameters
    ----------
    f
        The object to decorate, normally a function.

    Returns
    -------
    The rebuilt function (when ``f`` is a function), otherwise ``f`` itself,
    with ``__module__`` set to ``"__main__"``.
    """
    # build new FunctionType, so it can have the right globals
    # interactive functions never have closures, that's kind of the point
    if isinstance(f, FunctionType):
        mainmod = __import__("__main__")
        new_f = FunctionType(
            f.__code__,
            mainmod.__dict__,
            f.__name__,
            f.__defaults__,
        )
        # The FunctionType constructor takes no keyword-only defaults. Without
        # copying them, a decorated ``def g(*, x=1)`` would demand ``x`` on
        # every call. Copy the dict so the two functions don't share it.
        if f.__kwdefaults__ is not None:
            new_f.__kwdefaults__ = dict(f.__kwdefaults__)
        new_f.__doc__ = f.__doc__
        f = new_f
    # associate with __main__ for uncanning
    f.__module__ = "__main__"
    return f
def _use_pickle_module(module):
    """Install *module* as the pickler used for canning.

    Rebinds this module's ``pickle`` global. When ``ipykernel.serialize`` is
    importable, its ``pickle`` attribute is updated too. Special handling of
    plain functions is removed from ``can_map``, so functions are left to the
    new pickler.
    """
    global pickle
    pickle = module
    try:
        from ipykernel import serialize
    except ImportError:
        pass
    else:
        serialize.pickle = module  # type:ignore[attr-defined]
    # disable special function handling, let the new pickler take care of it
    can_map.pop(FunctionType, None)


def use_dill():
    """use dill to expand serialization support

    adds support for object methods and closures to serialization.
    """
    # import dill causes most of the magic
    import dill

    # dill doesn't work with cPickle,
    # tell the two relevant modules to use plain pickle
    _use_pickle_module(dill)


def use_cloudpickle():
    """use cloudpickle to expand serialization support

    adds support for object methods and closures to serialization.
    """
    import cloudpickle

    _use_pickle_module(cloudpickle)
# -------------------------------------------------------------------------------
# Classes
# -------------------------------------------------------------------------------
class CannedObject:
    """Base wrapper that prepares an arbitrary object for pickling.

    A shallow copy of the object is stored. Selected attributes of that copy
    are canned individually and are uncanned again in ``get_object``.
    """

    def __init__(self, obj, keys=None, hook=None):
        """can an object for safe pickling

        Parameters
        ----------
        obj
            The object to be canned
        keys : list (optional)
            list of attribute names that will be explicitly canned / uncanned
        hook : callable (optional)
            An optional extra callable,
            which can do additional processing of the uncanned object.
            It is called as ``hook(obj, g)`` after uncanning.

        Notes
        -----
        large data may be offloaded into the buffers list,
        used for zero-copy transfers.
        """
        self.keys = keys or []
        self.obj = copy.copy(obj)
        self.hook = can(hook)
        # Iterate the normalised list: ``keys`` itself defaults to None.
        for key in self.keys:
            setattr(self.obj, key, can(getattr(obj, key)))
        self.buffers = []

    def get_object(self, g=None):
        """Uncan the stored object.

        Attributes listed in ``keys`` are uncanned with namespace ``g`` and
        written back onto the stored copy, which is then passed to the hook
        (if any). The stored copy itself is returned, not a new copy.
        """
        if g is None:
            g = {}
        obj = self.obj
        for key in self.keys:
            setattr(obj, key, uncan(getattr(obj, key), g))
        if self.hook:
            self.hook = uncan(self.hook, g)
            self.hook(obj, g)
        return self.obj
class Reference(CannedObject):
    """object for wrapping a remote reference by name."""

    def __init__(self, name):
        """Store *name*, which must be a ``str``; raise TypeError otherwise."""
        if not isinstance(name, str):
            # Format a 1-tuple: ``"%r" % name`` misformats, or raises an
            # unrelated "not all arguments converted" error, for tuple names.
            raise TypeError("illegal name: %r" % (name,))
        self.name = name
        self.buffers = []

    def __repr__(self):
        return "<Reference: %r>" % self.name

    def get_object(self, g=None):
        """Evaluate the stored name in namespace ``g`` (a fresh dict if None)."""
        if g is None:
            g = {}
        # NOTE(review): the name is eval'd as an arbitrary Python expression;
        # only accept References from trusted peers.
        return eval(self.name, g)
class CannedCell(CannedObject):
    """Can a closure cell by canning its contents."""

    def __init__(self, cell):
        self.cell_contents = can(cell.cell_contents)
        # Every sibling canned class defines ``buffers``. Define it here too so
        # code that reads it uniformly from canned objects does not fail.
        self.buffers = []

    def get_object(self, g=None):
        """Rebuild a closure cell holding the uncanned contents."""
        cell_contents = uncan(self.cell_contents, g)

        def inner():
            return cell_contents

        return inner.__closure__[0]  # type:ignore[index]
class CannedFunction(CannedObject):
    """Can a plain Python function by value.

    The code object is stored as-is. Positional defaults, keyword-only
    defaults and closure cells are canned individually.
    """

    def __init__(self, f):
        self._check_type(f)
        self.code = f.__code__
        self.defaults: typing.Optional[typing.List[typing.Any]]
        if f.__defaults__:
            self.defaults = [can(fd) for fd in f.__defaults__]
        else:
            self.defaults = None
        self.closure: typing.Any
        closure = f.__closure__
        if closure:
            self.closure = tuple(can(cell) for cell in closure)
        else:
            self.closure = None
        # Keyword-only defaults are not part of __defaults__. Without them, the
        # rebuilt function would require every keyword-only argument.
        self.kwdefaults: typing.Optional[typing.Dict[str, typing.Any]]
        if f.__kwdefaults__:
            self.kwdefaults = {k: can(v) for k, v in f.__kwdefaults__.items()}
        else:
            self.kwdefaults = None
        self.module = f.__module__ or "__main__"
        self.__name__ = f.__name__
        self.buffers = []

    def _check_type(self, obj):
        assert isinstance(obj, FunctionType), "Not a function type"

    def get_object(self, g=None):
        """Rebuild the function.

        If the recorded module name does not start with ``"__"``, that module
        is imported and its globals are used. Otherwise ``g`` is used, or a
        fresh dict when ``g`` is None.
        """
        # try to load function back into its module:
        if not self.module.startswith("__"):
            __import__(self.module)
            g = sys.modules[self.module].__dict__
        if g is None:
            g = {}
        if self.defaults:
            defaults = tuple(uncan(cfd, g) for cfd in self.defaults)
        else:
            defaults = None
        if self.closure:
            closure = tuple(uncan(cell, g) for cell in self.closure)
        else:
            closure = None
        new_func = FunctionType(self.code, g, self.__name__, defaults, closure)
        # Objects canned by older versions of this class have no attribute
        # ``kwdefaults``.
        kwdefaults = getattr(self, "kwdefaults", None)
        if kwdefaults:
            new_func.__kwdefaults__ = {k: uncan(v, g) for k, v in kwdefaults.items()}
        return new_func
class CannedClass(CannedObject):
    """Can a class by value.

    The class namespace and every class after it in its MRO are canned
    individually. Uncanning recreates the class with
    ``type(name, bases, namespace)``.
    """

    def __init__(self, cls):
        self._check_type(cls)
        self.name = cls.__name__
        self.old_style = not isinstance(cls, type)
        # ``__dict__`` / ``__weakref__`` entries are left out of the copied
        # namespace (presumably because ``type()`` supplies its own).
        excluded = ("__weakref__", "__dict__")
        self._canned_dict = {
            attr: can(value)
            for attr, value in cls.__dict__.items()
            if attr not in excluded
        }
        lineage = [] if self.old_style else cls.mro()
        self.parents = [can(ancestor) for ancestor in lineage[1:]]
        self.buffers = []

    def _check_type(self, obj):
        assert isinstance(obj, class_type), "Not a class type"

    def get_object(self, g=None):
        """Recreate the class from its uncanned ancestors and namespace."""
        bases = tuple(uncan(ancestor, g) for ancestor in self.parents)
        namespace = uncan_dict(self._canned_dict, g=g)
        return type(self.name, bases, namespace)
class CannedArray(CannedObject):
    """Can a numpy array.

    Arrays that can be shipped as raw memory keep one contiguous buffer in
    ``buffers``, plus ``shape`` and ``dtype`` to rebuild them. All other
    arrays are pickled into ``buffers[0]``. That covers 0-d arrays, arrays
    with no elements, zero-itemsize dtypes and dtypes containing Python
    objects.
    """

    def __init__(self, obj):
        from numpy import ascontiguousarray

        self.shape = obj.shape
        # structured dtypes are stored as their ``descr`` list, others as the
        # dtype string; ``numpy.frombuffer`` accepts either form
        self.dtype = obj.dtype.descr if obj.dtype.fields else obj.dtype.str
        # ``hasobject`` also catches object fields in nested or titled
        # structured dtypes. Unpacking ``dtype.fields`` values as pairs would
        # break on titled fields, whose values are 3-tuples. Empty arrays and
        # zero-itemsize dtypes can't be rebuilt reliably with frombuffer.
        self.pickled = bool(
            obj.ndim == 0
            or obj.size == 0
            or obj.dtype.itemsize == 0
            or obj.dtype.hasobject
        )
        if self.pickled:
            # just pickle it
            self.buffers = [pickle.dumps(obj, PICKLE_PROTOCOL)]
        else:
            # ensure contiguous
            obj = ascontiguousarray(obj, dtype=None)
            self.buffers = [buffer(obj)]

    def get_object(self, g=None):
        """Rebuild the array from ``buffers[0]``."""
        from numpy import frombuffer

        data = self.buffers[0]
        if self.pickled:
            # we just pickled it
            # NOTE(review): unpickling executes arbitrary code; buffers must
            # come from a trusted peer.
            return pickle.loads(data)
        else:
            return frombuffer(data, dtype=self.dtype).reshape(self.shape)
class CannedBytes(CannedObject):
    """Can a bytes-like object by handing it over as its single buffer."""

    @staticmethod
    def wrap(buf: typing.Union[memoryview, bytes, typing.SupportsBytes]) -> bytes:
        """Cast a buffer or memoryview object to bytes"""
        # bytes (and subclasses) are already in the target form
        if isinstance(buf, bytes):
            return buf
        if isinstance(buf, memoryview):
            return buf.tobytes()
        return bytes(buf)

    def __init__(self, obj):
        self.buffers = [obj]

    def get_object(self, g=None):
        """Return the stored buffer converted with ``wrap``."""
        return self.wrap(self.buffers[0])
class CannedBuffer(CannedBytes):
    """Canned bytes whose uncanned value is produced by ``buffer``."""

    wrap = staticmethod(buffer)  # type:ignore[assignment]
class CannedMemoryView(CannedBytes):
    """Canned bytes whose uncanned value is a ``memoryview`` of the buffer."""

    wrap = staticmethod(memoryview)  # type:ignore[assignment]
# -------------------------------------------------------------------------------
# Functions
# -------------------------------------------------------------------------------
def _import_mapping(mapping, original=None):
    """Resolve the string keys of a type mapping into the classes they name.

    A key that imports successfully is replaced by the imported class, and the
    entry moves to the end of the mapping. A key that fails to import is
    dropped. An error is logged only for failing keys absent from *original*,
    i.e. user-added ones.
    """
    log = get_logger()
    log.debug("Importing canning map")
    # snapshot the keys: entries are popped / re-added while walking them
    for dotted in list(mapping):
        if not isinstance(dotted, str):
            continue
        try:
            resolved = import_item(dotted)
        except Exception:
            if original and dotted not in original:
                # only message on user-added classes
                log.error("canning class not importable: %r", dotted, exc_info=True)
            mapping.pop(dotted)
        else:
            mapping[resolved] = mapping.pop(dotted)
def istype(obj, check):
    """Strict ``isinstance``: exact type match only, subclasses do not count.

    *check* may be a single class or a tuple of classes. In the tuple case the
    result is True when ``type(obj)`` is exactly one of them.
    """
    exact = type(obj)
    if not isinstance(check, tuple):
        return exact is check
    return any(exact is candidate for candidate in check)
def can(obj):
    """prepare an object for pickling

    ``type(obj)`` is matched exactly (see ``istype``) against the keys of
    ``can_map``, and the first matching canner is applied. Objects without a
    matching entry are returned unchanged. If a string (not yet imported) key
    is reached first, all string keys are resolved and the lookup is retried.
    """
    for key, canner in can_map.items():
        if isinstance(key, str):
            # resolve the lazily-declared types, then start over; returning
            # immediately means iteration of the mutated dict never resumes.
            # this will usually only happen once
            _import_mapping(can_map, _original_can_map)
            return can(obj)
        if istype(obj, key):
            return canner(obj)
    return obj
def can_class(obj):
    """Can classes defined in ``__main__``; return anything else unchanged."""
    defined_in_main = isinstance(obj, class_type) and obj.__module__ == "__main__"
    return CannedClass(obj) if defined_in_main else obj
def can_dict(obj):
    """can the *values* of a dict

    Only objects whose type is exactly ``dict`` are processed; a new dict is
    returned. Anything else is returned unchanged.
    """
    if not istype(obj, dict):
        return obj
    return {key: can(value) for key, value in obj.items()}
# Container types whose elements are (un)canned one by one; matched exactly.
sequence_types = (list, tuple, set)


def can_sequence(obj):
    """can the elements of a sequence

    Only exact lists, tuples and sets are processed; a new container of the
    same type is returned. Anything else is returned unchanged.
    """
    if not istype(obj, sequence_types):
        return obj
    container = type(obj)
    return container([can(item) for item in obj])
def uncan(obj, g=None):
    """invert canning

    The keys of ``uncan_map`` are tried in order with ``isinstance``, and the
    first match's uncanner is called as ``uncanner(obj, g)``. Unmatched
    objects are returned unchanged. If a string (not yet imported) key is
    reached first, all string keys are resolved and the lookup is retried.
    """
    for key, uncanner in uncan_map.items():
        if isinstance(key, str):
            # resolve the lazily-declared types, then start over; returning
            # immediately means iteration of the mutated dict never resumes.
            # this will usually only happen once
            _import_mapping(uncan_map, _original_uncan_map)
            return uncan(obj, g)
        if isinstance(obj, key):
            return uncanner(obj, g)
    return obj
def uncan_dict(obj, g=None):
    """Uncan the *values* of an exact ``dict`` using namespace ``g``.

    A new dict is returned; anything that is not exactly a dict is returned
    unchanged.
    """
    if not istype(obj, dict):
        return obj
    return {key: uncan(value, g) for key, value in obj.items()}
def uncan_sequence(obj, g=None):
    """Uncan each element of an exact list, tuple or set using namespace ``g``.

    A new container of the same type is returned; anything else is returned
    unchanged.
    """
    if not istype(obj, sequence_types):
        return obj
    container = type(obj)
    return container([uncan(item, g) for item in obj])
# -------------------------------------------------------------------------------
# API dictionaries
# -------------------------------------------------------------------------------
# These dicts can be extended for custom serialization of new objects
# Maps a type to the callable that cans instances of it.
# ``can`` matches types exactly (``istype``), so subclasses are not covered.
# A key may also be a dotted import string such as "numpy.ndarray". It is
# resolved by ``_import_mapping`` the first time ``can`` reaches it, so numpy
# is not imported when this module loads.
# These dicts can be extended for custom serialization of new objects.
can_map: typing.Dict[typing.Any, typing.Any] = {
    "numpy.ndarray": CannedArray,
    FunctionType: CannedFunction,
    bytes: CannedBytes,
    memoryview: CannedMemoryView,
    cell_type: CannedCell,
    class_type: can_class,
}

# ``buffer`` is bound to ``memoryview`` in this module, so this branch does not
# run here; presumably a leftover from when ``buffer`` was a distinct type.
if buffer is not memoryview:
    can_map[buffer] = CannedBuffer

# Maps a type to the callable that uncans it, called as ``uncanner(obj, g)``.
# ``uncan`` matches with ``isinstance`` in insertion order, so any
# CannedObject subclass is handled by the first entry.
uncan_map: typing.Dict[type, typing.Any] = {
    CannedObject: lambda obj, g: obj.get_object(g),
    dict: uncan_dict,
}

# for use in _import_mapping:
# snapshots of the built-in entries; an unimportable string key is logged
# only when it is absent from these (i.e. was added by a user).
_original_can_map = can_map.copy()
_original_uncan_map = uncan_map.copy()