You cannot select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
430 lines
12 KiB
Python
430 lines
12 KiB
Python
6 years ago
|
"""
python generate_sparsetools.py [--no-force]

Generate manual wrappers for C++ sparsetools code.

For every compilation unit listed in COMPILATION_UNITS this writes
sparsetools/<unit>_impl.h (next to this script), containing the typenum
dispatch function, one thunk per routine and one Python method wrapper per
routine. It also writes sparsetools/sparsetools_impl.h with the method
declarations and the PyMethodDef table.

All files are regenerated by default. With --no-force, a file is only
rewritten when it is missing or older than this script.

Type codes used:

    'i':  integer scalar
    'I':  integer array
    'T':  data array
    'B':  boolean array
    'V':  std::vector<integer>*
    'W':  std::vector<data>*
    '*':  indicates that the next argument is an output argument
    'v':  void
    'l':  64-bit integer scalar

See sparsetools.cxx for more details.

"""
|
||
|
import optparse
|
||
|
import os
|
||
|
from distutils.dep_util import newer
|
||
|
|
||
|
#
# List of all routines and their argument types.
#
# The first code indicates the return value, the rest the arguments.
#
# Each non-blank line of the tables below has the form
#
#     <routine name>  <return code>  <argument codes>
#
# using the type codes from the module docstring. main() strips all
# whitespace from the codes and skips lines starting with '#'. A routine
# with any 'T' argument is generated for every (index type, data type)
# combination; see main() for the exact selection rule.
#

# bsr.h
BSR_ROUTINES = """
bsr_diagonal v iiiiiIIT*T
bsr_tocsr v iiiiIIT*I*I*T
bsr_scale_rows v iiiiII*TT
bsr_scale_columns v iiiiII*TT
bsr_sort_indices v iiii*I*I*T
bsr_transpose v iiiiIIT*I*I*T
bsr_matmat_pass2 v iiiiiIITIIT*I*I*T
bsr_matvec v iiiiIITT*T
bsr_matvecs v iiiiiIITT*T
bsr_elmul_bsr v iiiiIITIIT*I*I*T
bsr_eldiv_bsr v iiiiIITIIT*I*I*T
bsr_plus_bsr v iiiiIITIIT*I*I*T
bsr_minus_bsr v iiiiIITIIT*I*I*T
bsr_maximum_bsr v iiiiIITIIT*I*I*T
bsr_minimum_bsr v iiiiIITIIT*I*I*T
bsr_ne_bsr v iiiiIITIIT*I*I*B
bsr_lt_bsr v iiiiIITIIT*I*I*B
bsr_gt_bsr v iiiiIITIIT*I*I*B
bsr_le_bsr v iiiiIITIIT*I*I*B
bsr_ge_bsr v iiiiIITIIT*I*I*B
"""
|
||
|
|
||
|
# csc.h
# Same "<name> <return code> <argument codes>" line format as BSR_ROUTINES.
CSC_ROUTINES = """
csc_diagonal v iiiIIT*T
csc_tocsr v iiIIT*I*I*T
csc_matmat_pass1 v iiIIII*I
csc_matmat_pass2 v iiIITIIT*I*I*T
csc_matvec v iiIITT*T
csc_matvecs v iiiIITT*T
csc_elmul_csc v iiIITIIT*I*I*T
csc_eldiv_csc v iiIITIIT*I*I*T
csc_plus_csc v iiIITIIT*I*I*T
csc_minus_csc v iiIITIIT*I*I*T
csc_maximum_csc v iiIITIIT*I*I*T
csc_minimum_csc v iiIITIIT*I*I*T
csc_ne_csc v iiIITIIT*I*I*B
csc_lt_csc v iiIITIIT*I*I*B
csc_gt_csc v iiIITIIT*I*I*B
csc_le_csc v iiIITIIT*I*I*B
csc_ge_csc v iiIITIIT*I*I*B
"""
|
||
|
|
||
|
# csr.h
# Same "<name> <return code> <argument codes>" line format as BSR_ROUTINES.
# Entries whose codes contain no 'T' (e.g. csr_count_blocks, expandptr) are
# only instantiated per index type.
CSR_ROUTINES = """
csr_matmat_pass1 v iiIIII*I
csr_matmat_pass2 v iiIITIIT*I*I*T
csr_diagonal v iiiIIT*T
csr_tocsc v iiIIT*I*I*T
csr_tobsr v iiiiIIT*I*I*T
csr_todense v iiIIT*T
csr_matvec v iiIITT*T
csr_matvecs v iiiIITT*T
csr_elmul_csr v iiIITIIT*I*I*T
csr_eldiv_csr v iiIITIIT*I*I*T
csr_plus_csr v iiIITIIT*I*I*T
csr_minus_csr v iiIITIIT*I*I*T
csr_maximum_csr v iiIITIIT*I*I*T
csr_minimum_csr v iiIITIIT*I*I*T
csr_ne_csr v iiIITIIT*I*I*B
csr_lt_csr v iiIITIIT*I*I*B
csr_gt_csr v iiIITIIT*I*I*B
csr_le_csr v iiIITIIT*I*I*B
csr_ge_csr v iiIITIIT*I*I*B
csr_scale_rows v iiII*TT
csr_scale_columns v iiII*TT
csr_sort_indices v iI*I*T
csr_eliminate_zeros v ii*I*I*T
csr_sum_duplicates v ii*I*I*T
get_csr_submatrix v iiIITiiii*V*V*W
csr_sample_values v iiIITiII*T
csr_count_blocks i iiiiII
csr_sample_offsets i iiIIiII*I
expandptr v iI*I
test_throw_error i
csr_has_sorted_indices i iII
csr_has_canonical_format i iII
"""
|
||
|
|
||
|
# coo.h, dia.h, csgraph.h
# Same line format as BSR_ROUTINES; 'l' marks a 64-bit integer scalar
# (passed as npy_int64 regardless of the index type).
OTHER_ROUTINES = """
coo_tocsr v iiiIIT*I*I*T
coo_todense v iilIIT*Ti
coo_matvec v lIITT*T
dia_matvec v iiiiITT*T
cs_graph_components i iII*I
"""
|
||
|
|
||
|
# List of compilation units
# Each (unit_name, routines) pair is written by main() to
# sparsetools/<unit_name>_impl.h. Routine names must be unique across all
# units; main() raises ValueError on duplicates.
COMPILATION_UNITS = [
    ('bsr', BSR_ROUTINES),
    ('csr', CSR_ROUTINES),
    ('csc', CSC_ROUTINES),
    ('other', OTHER_ROUTINES),
]
|
||
|
|
||
|
#
# List of the supported index typenums and the corresponding C++ types
#
# get_thunk_type_set() assigns the generated case numbers in the order of
# this list, so reordering changes the generated dispatch code.
#
I_TYPES = [
    ('NPY_INT32', 'npy_int32'),
    ('NPY_INT64', 'npy_int64'),
]
|
||
|
|
||
|
#
# List of the supported data typenums and the corresponding C++ types
#
# As with I_TYPES, the order here determines the generated case numbers.
# Bool and complex typenums map to *_wrapper C++ types, which are defined
# outside this script (presumably in the sparsetools headers; confirm).
#
T_TYPES = [
    ('NPY_BOOL', 'npy_bool_wrapper'),
    ('NPY_BYTE', 'npy_byte'),
    ('NPY_UBYTE', 'npy_ubyte'),
    ('NPY_SHORT', 'npy_short'),
    ('NPY_USHORT', 'npy_ushort'),
    ('NPY_INT', 'npy_int'),
    ('NPY_UINT', 'npy_uint'),
    ('NPY_LONG', 'npy_long'),
    ('NPY_ULONG', 'npy_ulong'),
    ('NPY_LONGLONG', 'npy_longlong'),
    ('NPY_ULONGLONG', 'npy_ulonglong'),
    ('NPY_FLOAT', 'npy_float'),
    ('NPY_DOUBLE', 'npy_double'),
    ('NPY_LONGDOUBLE', 'npy_longdouble'),
    ('NPY_CFLOAT', 'npy_cfloat_wrapper'),
    ('NPY_CDOUBLE', 'npy_cdouble_wrapper'),
    ('NPY_CLONGDOUBLE', 'npy_clongdouble_wrapper'),
]
|
||
|
|
||
|
#
# Code templates
#

# Per-routine C++ dispatcher, filled in by parse_routine():
#   %(name)s          routine name
#   %(thunk_content)s switch over get_thunk_case() calling name<...>(...)
# Arguments arrive as the untyped pointer array `a`; the generated calls
# cast each a[k] to the type given by the routine's argument codes.
THUNK_TEMPLATE = """
static PY_LONG_LONG %(name)s_thunk(int I_typenum, int T_typenum, void **a)
{
    %(thunk_content)s
}
"""
|
||
|
|
||
|
METHOD_TEMPLATE = """
|
||
|
NPY_VISIBILITY_HIDDEN PyObject *
|
||
|
%(name)s_method(PyObject *self, PyObject *args)
|
||
|
{
|
||
|
return call_thunk('%(ret_spec)s', "%(arg_spec)s", %(name)s_thunk, args);
|
||
|
}
|
||
|
"""
|
||
|
|
||
|
GET_THUNK_CASE_TEMPLATE = """
|
||
|
static int get_thunk_case(int I_typenum, int T_typenum)
|
||
|
{
|
||
|
%(content)s;
|
||
|
return -1;
|
||
|
}
|
||
|
"""
|
||
|
|
||
|
|
||
|
#
# Code generation
#

def get_thunk_type_set():
    """
    Enumerate every template instantiation and build the C++ case selector.

    Case numbers ``j`` are handed out sequentially: for each entry of
    I_TYPES, first the I-only case, then one case per entry of T_TYPES.

    Returns
    -------
    i_types : list [(j, I_typenum, None, I_type, None), ...]
        One entry per index type, for routines parameterized only by I.
    it_types : list [(j, I_typenum, T_typenum, I_type, T_type), ...]
        One entry per (index type, data type) pair, for routines
        parameterized by both I and T.
    getter_code : str
        C++ source of ``get_thunk_case(I_typenum, T_typenum)``. It returns
        the ``j`` matching the typenums (a ``T_typenum`` of -1 selects the
        I-only case), or -1 if nothing matches.
    """
    # C++ fragments; they do not depend on the loop state, so build them once.
    open_index_branch = """
        else if (I_typenum == %(I_typenum)s) {
            if (T_typenum == -1) { return %(j)s; }"""
    data_case = """
            else if (T_typenum == %(T_typenum)s) { return %(j)s; }"""
    close_index_branch = """
        }"""

    i_types = []
    it_types = []
    fragments = ["    if (0) {}"]
    next_case = 0

    for I_typenum, I_type in I_TYPES:
        # I-only instantiation, selected when the caller passes T_typenum == -1.
        fragments.append(open_index_branch % {'I_typenum': I_typenum,
                                              'j': next_case})
        i_types.append((next_case, I_typenum, None, I_type, None))
        next_case += 1

        # One instantiation per data type for this index type.
        for T_typenum, T_type in T_TYPES:
            fragments.append(data_case % {'T_typenum': T_typenum,
                                          'j': next_case})
            it_types.append((next_case, I_typenum, T_typenum, I_type, T_type))
            next_case += 1

        fragments.append(close_index_branch)

    getter_code = "".join(fragments)
    return i_types, it_types, GET_THUNK_CASE_TEMPLATE % {'content': getter_code}
|
||
|
|
||
|
|
||
|
def parse_routine(name, args, types):
    """
    Generate thunk and method code for a given routine.

    Parameters
    ----------
    name : str
        Name of the C++ routine
    args : str
        Type-code specification with whitespace removed (format explained
        in the module docstring): the first character is the return code,
        the remaining characters describe the arguments.
    types : list
        List of types to instantiate, as returned by `get_thunk_type_set`:
        tuples ``(j, I_typenum, T_typenum, I_type, T_type)``.

    Returns
    -------
    thunk_code : str
        C++ source of ``<name>_thunk``: a switch on ``get_thunk_case()``
        with one case per entry of `types`.
    method_code : str
        C++ source of ``<name>_method``, which passes the specs and the
        thunk to ``call_thunk``.

    Raises
    ------
    ValueError
        If the argument spec contains an unknown code, ends with a '*' that
        marks no argument, or has a 'V'/'W' argument not marked as output.
    """
    ret_spec = args[0]
    arg_spec = args[1:]

    # Validate the whole spec once, before generating anything. This way
    # errors do not depend on `types` being non-empty, and a dangling
    # trailing '*' is reported instead of being silently dropped.
    prev = None
    for code in arg_spec:
        if code not in "*iITBVWl":
            raise ValueError("Invalid spec character %r" % (code,))
        if code in "VW" and prev != '*':
            raise ValueError("%r argument must be an output arg" % (code,))
        prev = code
    if prev == '*':
        raise ValueError("Dangling '*' at end of argument spec %r" % (args,))

    def get_arglist(I_type, T_type):
        """
        Generate argument list for calling the C++ function.

        The k-th argument is read from a[k] of the thunk's ``void **a``
        ('*' markers do not consume an entry). Inputs are cast to const
        types; '*'-marked outputs are cast to non-const ones.
        """
        call_args = []
        writeable = False
        k = 0
        for t in arg_spec:
            if t == '*':
                writeable = True
                continue
            const = '' if writeable else 'const '
            writeable = False
            if t == 'i':
                # Scalars arrive by pointer and are dereferenced.
                call_args.append("*(%s*)a[%d]" % (const + I_type, k))
            elif t == 'I':
                call_args.append("(%s*)a[%d]" % (const + I_type, k))
            elif t == 'T':
                call_args.append("(%s*)a[%d]" % (const + T_type, k))
            elif t == 'B':
                call_args.append("(npy_bool_wrapper*)a[%d]" % (k,))
            elif t == 'V':
                call_args.append("(std::vector<%s>*)a[%d]" % (I_type, k))
            elif t == 'W':
                call_args.append("(std::vector<%s>*)a[%d]" % (T_type, k))
            else:
                # Only 'l' remains; the spec was validated above.
                call_args.append("*(%snpy_int64*)a[%d]" % (const, k))
            k += 1
        return ", ".join(call_args)

    # Generate thunk code: a giant switch statement with different
    # type combinations inside.
    parts = ["""int j = get_thunk_case(I_typenum, T_typenum);
    switch (j) {"""]
    for j, I_typenum, T_typenum, I_type, T_type in types:
        if T_type is None:
            dispatch = "%s" % (I_type,)
        else:
            dispatch = "%s,%s" % (I_type, T_type)
        if 'B' in arg_spec:
            # Boolean ('B') arrays add npy_bool_wrapper as a template argument.
            dispatch += ",npy_bool_wrapper"

        if ret_spec == 'v':
            piece = """
        case %(j)s:
            (void)%(name)s<%(dispatch)s>(%(arglist)s);
            return 0;"""
        else:
            piece = """
        case %(j)s:
            return %(name)s<%(dispatch)s>(%(arglist)s);"""
        parts.append(piece % dict(j=j, name=name, dispatch=dispatch,
                                  arglist=get_arglist(I_type, T_type)))

    parts.append("""
    default:
        throw std::runtime_error("internal error: invalid argument typenums");
    }""")
    thunk_content = "".join(parts)

    thunk_code = THUNK_TEMPLATE % dict(name=name,
                                       thunk_content=thunk_content)

    # Generate method code
    method_code = METHOD_TEMPLATE % dict(name=name,
                                         ret_spec=ret_spec,
                                         arg_spec=arg_spec)

    return thunk_code, method_code
|
||
|
|
||
|
|
||
|
def _write_generated(dst, force, chunks):
    """
    Write the autogen banner followed by `chunks` to `dst`.

    The file is skipped, with a message, when `force` is false and
    ``newer(__file__, dst)`` is false.
    """
    # newer() is evaluated first, as in the original condition.
    if newer(__file__, dst) or force:
        print("[generate_sparsetools] generating %r" % (dst,))
        with open(dst, 'w') as f:
            write_autogen_blurb(f)
            for chunk in chunks:
                f.write(chunk)
    else:
        print("[generate_sparsetools] %r already up-to-date" % (dst,))


def main():
    """
    Generate sparsetools/<unit>_impl.h for every entry of COMPILATION_UNITS,
    plus sparsetools/sparsetools_impl.h with the method declarations and
    the PyMethodDef table.

    Raises
    ------
    ValueError
        On a malformed routine line, an invalid type spec, or a routine
        name that appears more than once.
    """
    p = optparse.OptionParser(usage=(__doc__ or '').strip())
    p.add_option("--no-force", action="store_false",
                 dest="force", default=True)
    options, _ = p.parse_args()

    out_dir = os.path.join(os.path.dirname(__file__), 'sparsetools')

    names = []      # routine names in generation order (for the method table)
    seen = set()    # same names, for O(1) duplicate detection

    i_types, it_types, getter_code = get_thunk_type_set()

    # Generate *_impl.h for each compilation unit
    for unit_name, routines in COMPILATION_UNITS:
        thunks = []
        methods = []

        # Generate thunks and methods for all routines
        for line in routines.splitlines():
            line = line.strip()
            if not line or line.startswith('#'):
                continue

            try:
                name, args = line.split(None, 1)
            except ValueError:
                raise ValueError("Malformed line: %r" % (line,))

            # Whitespace inside the type codes is insignificant.
            args = "".join(args.split())

            # Routines with data-typed arguments need one instantiation per
            # (I, T) pair. 'W' (std::vector<data>*) needs T_type as well, so
            # a 'W'-only routine must not fall back to the I-only list.
            if 't' in args or 'T' in args or 'W' in args:
                thunk, method = parse_routine(name, args, it_types)
            else:
                thunk, method = parse_routine(name, args, i_types)

            if name in seen:
                raise ValueError("Duplicate routine %r" % (name,))

            seen.add(name)
            names.append(name)
            thunks.append(thunk)
            methods.append(method)

        # Produce output
        _write_generated(os.path.join(out_dir, unit_name + '_impl.h'),
                         options.force,
                         [getter_code] + thunks + methods)

    # Generate code for method struct
    method_defs = "".join(
        "NPY_VISIBILITY_HIDDEN PyObject *%s_method(PyObject *, PyObject *);\n" % (name,)
        for name in names)

    method_struct = """\nstatic struct PyMethodDef sparsetools_methods[] = {"""
    for name in names:
        method_struct += """
        {"%(name)s", (PyCFunction)%(name)s_method, METH_VARARGS, NULL},""" % dict(name=name)
    method_struct += """
        {NULL, NULL, 0, NULL}
    };"""

    # Produce sparsetools_impl.h
    _write_generated(os.path.join(out_dir, 'sparsetools_impl.h'),
                     options.force,
                     [method_defs, method_struct])
|
||
|
|
||
|
|
||
|
def write_autogen_blurb(stream):
    """Write the C comment banner that marks a file as generated output.

    Parameters
    ----------
    stream : file-like
        Text stream opened for writing; only ``stream.write`` is used.
    """
    banner = (
        "/* This file is autogenerated by generate_sparsetools.py\n"
        " * Do not edit manually or check into VCS.\n"
        " */\n"
    )
    stream.write(banner)
|
||
|
|
||
|
|
||
|
# Generate the headers only when run as a script, not when imported.
if __name__ == "__main__":
    main()
|