You cannot select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
627 lines
24 KiB
Python
627 lines
24 KiB
Python
5 years ago
|
# -*- coding: utf-8 -*-
# Natural Language Toolkit: ASCII visualization of NLTK trees
#
# Copyright (C) 2001-2019 NLTK Project
# Author: Andreas van Cranenburgh <A.W.vanCranenburgh@uva.nl>
#         Peter Ljunglöf <peter.ljunglof@gu.se>
# URL: <http://nltk.org/>
# For license information, see LICENSE.TXT
|
||
|
|
||
|
"""
Pretty-printing of discontinuous trees.
Adapted from the disco-dop project, by Andreas van Cranenburgh.
https://github.com/andreasvc/disco-dop

Interesting reference (not used for this code):
T. Eschbach et al., Orth. Hypergraph Drawing, Journal of
Graph Algorithms and Applications, 10(2) 141--157 (2006)149.
http://jgaa.info/accepted/2006/EschbachGuentherBecker2006.10.2.pdf
"""
|
||
|
|
||
|
from __future__ import division, print_function, unicode_literals
|
||
|
|
||
|
import re
from collections import defaultdict
from operator import itemgetter

try:  # Python 3 (cgi.escape was removed in Python 3.8)
    from html import escape
except ImportError:  # Python 2 fallback
    from cgi import escape

from nltk.compat import python_2_unicode_compatible
from nltk.tree import Tree
from nltk.util import OrderedDict
|
||
|
|
||
|
# ANSI foreground colour codes keyed by colour name; the codes run
# consecutively from 30 (black) to 37 (white).
_ANSI_COLOR_NAMES = (
    'black',
    'red',
    'green',
    'yellow',
    'blue',
    'magenta',
    'cyan',
    'white',
)
ANSICOLOR = dict(zip(_ANSI_COLOR_NAMES, range(30, 30 + len(_ANSI_COLOR_NAMES))))
|
||
|
|
||
|
|
||
|
@python_2_unicode_compatible
class TreePrettyPrinter(object):
    """
    Pretty-print a tree in text format, either as ASCII or Unicode.
    The tree can be a normal tree, or discontinuous.

    ``TreePrettyPrinter(tree, sentence=None, highlight=())``
    creates an object from which different visualizations can be created.

    :param tree: a Tree object.
    :param sentence: a list of words (strings). If `sentence` is given,
        `tree` must contain integers as leaves, which are taken as indices
        in `sentence`. Using this you can display a discontinuous tree.
    :param highlight: Optionally, a sequence of Tree objects in `tree` which
        should be highlighted. Has the effect of only applying colors to nodes
        in this sequence (nodes should be given as Tree objects, terminals as
        indices).

    >>> from nltk.tree import Tree
    >>> tree = Tree.fromstring('(S (NP Mary) (VP walks))')
    >>> print(TreePrettyPrinter(tree).text())
    ... # doctest: +NORMALIZE_WHITESPACE
          S
      ____|____
     NP        VP
     |         |
    Mary     walks
    """

    def __init__(self, tree, sentence=None, highlight=()):
        """
        Compute the grid layout for ``tree`` and store it on the instance
        as ``nodes``, ``coords``, ``edges`` and ``highlight``.

        When ``sentence`` is None it is derived from the tree: if every leaf
        is an integer and no subtree is empty, the leaves themselves are
        used; otherwise the tree is copied, its terminals are replaced by
        running indices and the original terminals become the sentence.
        """
        if sentence is None:
            leaves = tree.leaves()
            if (
                leaves
                and not any(len(a) == 0 for a in tree.subtrees())
                and all(isinstance(a, int) for a in leaves)
            ):
                sentence = [str(a) for a in leaves]
            else:
                # this deals with empty nodes (frontier non-terminals)
                # and multiple/mixed terminals under non-terminals.
                tree = tree.copy(True)
                sentence = []
                for a in tree.subtrees():
                    if len(a) == 0:
                        # empty node: give it a placeholder terminal
                        a.append(len(sentence))
                        sentence.append(None)
                    elif any(not isinstance(b, Tree) for b in a):
                        for n, b in enumerate(a):
                            if not isinstance(b, Tree):
                                a[n] = len(sentence)
                                # tuple terminals are shown joined by '/';
                                # isinstance also covers tuple subclasses,
                                # which '%s' % b could not format.
                                if isinstance(b, tuple):
                                    b = '/'.join(b)
                                sentence.append('%s' % b)
        self.nodes, self.coords, self.edges, self.highlight = self.nodecoords(
            tree, sentence, highlight
        )

    def __str__(self):
        """Return the default plain-text rendering (see ``text()``)."""
        return self.text()

    def __repr__(self):
        """Return a short description including the number of nodes."""
        return '<TreePrettyPrinter with %d nodes>' % len(self.nodes)

    @staticmethod
    def nodecoords(tree, sentence, highlight):
        """
        Produce coordinates of nodes on a grid.

        Objective:

        - Produce coordinates for a non-overlapping placement of nodes and
            horizontal lines.
        - Order edges so that crossing edges cross a minimal number of previous
            horizontal lines (never vertical lines).

        Approach:

        - bottom up level order traversal (start at terminals)
        - at each level, identify nodes which cannot be on the same row
        - identify nodes which cannot be in the same column
        - place nodes into a grid at (row, column)
        - order child-parent edges with crossing edges last

        Coordinates are (row, column); the origin (0, 0) is at the top left;
        the root node is on row 0. Coordinates do not consider the size of a
        node (which depends on font, &c), so the width of a column of the grid
        should be automatically determined by the element with the greatest
        width in that column. Alternatively, the integer coordinates could be
        converted to coordinates in which the distances between adjacent nodes
        are non-uniform.

        Produces tuple (nodes, coords, edges, highlighted) where:

        - nodes[id]: Tree object for the node with this integer id
        - coords[id]: (n, m) coordinate where to draw node with id in the grid
        - edges[id]: parent id of node with this id (ordered dictionary)
        - highlighted: set of ids that should be highlighted

        :param tree: a Tree whose leaves are integer indices into `sentence`.
        :param sentence: list of tokens; a None entry is displayed as '...'.
        :param highlight: sequence of nodes to highlight; when empty, every
            node id is put in the highlighted set.
        :raises ValueError: if a leaf is not an integer, an index occurs more
            than once, an index is outside `sentence`, or no free grid cell
            can be found for a node.
        """

        def findcell(m, matrix, startoflevel, children):
            """
            Find vacant row, column index for node ``m``.
            Iterate over current rows for this level (try lowest first)
            and look for cell between first and last child of this node,
            add new row to level if no free row available.
            """
            candidates = [a for _, a in children[m]]
            minidx, maxidx = min(candidates), max(candidates)
            leaves = tree[m].leaves()
            center = scale * sum(leaves) // len(leaves)  # center of gravity
            if minidx < maxidx and not minidx < center < maxidx:
                center = sum(candidates) // len(candidates)
            if max(candidates) - min(candidates) > 2 * scale:
                center -= center % scale  # round to unscaled coordinate
                if minidx < maxidx and not minidx < center < maxidx:
                    center += scale
            if ids[m] == 0:
                # the node with id 0 always starts on a fresh row
                startoflevel = len(matrix)
            for rowidx in range(startoflevel, len(matrix) + 1):
                if rowidx == len(matrix):  # need to add a new row
                    matrix.append(
                        [
                            vertline if a not in (corner, None) else None
                            for a in matrix[-1]
                        ]
                    )
                row = matrix[rowidx]
                if len(children[m]) == 1:  # place unaries directly above child
                    return rowidx, next(iter(children[m]))[1]
                elif all(
                    a is None or a == vertline
                    for a in row[min(candidates) : max(candidates) + 1]
                ):
                    # find free column
                    for n in range(scale):
                        i = j = center + n
                        while j > minidx or i < maxidx:
                            if i < maxidx and (
                                matrix[rowidx][i] is None or i in candidates
                            ):
                                return rowidx, i
                            elif j > minidx and (
                                matrix[rowidx][j] is None or j in candidates
                            ):
                                return rowidx, j
                            i += scale
                            j -= scale
            # Arguments are listed in placeholder order (%s %s %d %d) so that
            # the intended ValueError is raised rather than a TypeError from
            # the string formatting itself.
            raise ValueError(
                'could not find a free cell for:\n%s\n%s\n'
                'min=%d; max=%d' % (tree[m], dumpmatrix(), minidx, maxidx)
            )

        def dumpmatrix():
            """Dump matrix contents for debugging purposes."""
            return '\n'.join(
                '%2d: %s' % (n, ' '.join(('%2r' % i)[:2] for i in row))
                for n, row in enumerate(matrix)
            )

        leaves = tree.leaves()
        if not all(isinstance(n, int) for n in leaves):
            raise ValueError('All leaves must be integer indices.')
        if len(leaves) != len(set(leaves)):
            raise ValueError('Indices must occur at most once.')
        if not all(0 <= n < len(sentence) for n in leaves):
            raise ValueError(
                'All leaves must be in the interval 0..n '
                'with n=len(sentence)\ntokens: %d indices: '
                '%r\nsentence: %s' % (len(sentence), tree.leaves(), sentence)
            )
        vertline, corner = -1, -2  # constants
        tree = tree.copy(True)
        # order the children of every node by their leftmost terminal index
        for a in tree.subtrees():
            a.sort(key=lambda n: min(n.leaves()) if isinstance(n, Tree) else n)
        scale = 2
        crossed = set()
        # internal nodes and lexical nodes (no frontiers)
        positions = tree.treepositions()
        maxdepth = max(map(len, positions)) + 1
        childcols = defaultdict(set)
        matrix = [[None] * (len(sentence) * scale)]
        nodes = {}
        ids = dict((a, n) for n, a in enumerate(positions))
        highlighted_nodes = set(
            n for a, n in ids.items() if not highlight or tree[a] in highlight
        )
        levels = dict((n, []) for n in range(maxdepth - 1))
        terminals = []
        for a in positions:
            node = tree[a]
            if isinstance(node, Tree):
                levels[maxdepth - node.height()].append(a)
            else:
                terminals.append(a)

        # within a level, handle nodes with the narrowest leaf span first
        for n in levels:
            levels[n].sort(key=lambda n: max(tree[n].leaves()) - min(tree[n].leaves()))
        terminals.sort()
        positions = set(positions)

        # terminals occupy the first matrix row, `scale` columns apart
        for m in terminals:
            i = int(tree[m]) * scale
            assert matrix[0][i] is None, (matrix[0][i], m, i)
            matrix[0][i] = ids[m]
            nodes[ids[m]] = sentence[tree[m]]
            if nodes[ids[m]] is None:
                # placeholder terminal: shown as '...' and never highlighted
                nodes[ids[m]] = '...'
                highlighted_nodes.discard(ids[m])
            positions.remove(m)
            childcols[m[:-1]].add((0, i))

        # add other nodes centered on their children,
        # if the center is already taken, back off
        # to the left and right alternately, until an empty cell is found.
        for n in sorted(levels, reverse=True):
            nodesatdepth = levels[n]
            startoflevel = len(matrix)
            matrix.append(
                [vertline if a not in (corner, None) else None for a in matrix[-1]]
            )
            for m in nodesatdepth:  # [::-1]:
                if n < maxdepth - 1 and childcols[m]:
                    _, pivot = min(childcols[m], key=itemgetter(1))
                    # NOTE(review): every cell written into ``matrix`` in this
                    # function is None or an int (node id, vertline, corner),
                    # so the isinstance(a, tuple) filters below never match
                    # and ``crossed`` appears to stay empty -- confirm intent.
                    if set(
                        a[:-1]
                        for row in matrix[:-1]
                        for a in row[:pivot]
                        if isinstance(a, tuple)
                    ) & set(
                        a[:-1]
                        for row in matrix[:-1]
                        for a in row[pivot:]
                        if isinstance(a, tuple)
                    ):
                        crossed.add(m)

                rowidx, i = findcell(m, matrix, startoflevel, childcols)
                positions.remove(m)

                # block positions where children of this node branch out
                for _, x in childcols[m]:
                    matrix[rowidx][x] = corner
                # assert m == () or matrix[rowidx][i] in (None, corner), (
                #         matrix[rowidx][i], m, str(tree), ' '.join(sentence))
                # node itself
                matrix[rowidx][i] = ids[m]
                nodes[ids[m]] = tree[m]
                # add column to the set of children for its parent
                if m != ():
                    childcols[m[:-1]].add((rowidx, i))
        assert len(positions) == 0

        # remove unused columns, right to left
        for m in range(scale * len(sentence) - 1, -1, -1):
            if not any(isinstance(row[m], (Tree, int)) for row in matrix):
                for row in matrix:
                    del row[m]

        # remove unused rows, reverse
        matrix = [
            row
            for row in reversed(matrix)
            if not all(a is None or a == vertline for a in row)
        ]

        # collect coordinates of nodes
        coords = {}
        for n, _ in enumerate(matrix):
            for m, i in enumerate(matrix[n]):
                if isinstance(i, int) and i >= 0:
                    coords[i] = n, m

        # move crossed edges last
        positions = sorted(
            [a for level in levels.values() for a in level],
            key=lambda a: a[:-1] in crossed,
        )

        # collect edges from node to node
        edges = OrderedDict()
        for i in reversed(positions):
            for j, _ in enumerate(tree[i]):
                edges[ids[i + (j,)]] = ids[i]

        return nodes, coords, edges, highlighted_nodes

    def text(
        self,
        nodedist=1,
        unicodelines=False,
        html=False,
        ansi=False,
        nodecolor='blue',
        leafcolor='red',
        funccolor='green',
        abbreviate=None,
        maxwidth=16,
    ):
        """
        :return: ASCII art for a discontinuous tree.

        :param nodedist: number of filler characters inserted between
            adjacent columns.
        :param unicodelines: whether to use Unicode line drawing characters
            instead of plain (7-bit) ASCII.
        :param html: whether to wrap output in html code (default plain text).
            Labels are HTML-escaped and highlighted labels are wrapped in
            ``<font color=...>`` tags.
        :param ansi: whether to produce colors with ANSI escape sequences
            (only effective when html==False).
        :param leafcolor, nodecolor: specify colors of leaves and phrasal
            nodes; effective when either html or ansi is True.
        :param funccolor: color used instead of `nodecolor` for Tree nodes
            whose label starts with '-'.
        :param abbreviate: if True, abbreviate labels longer than 5 characters.
            If integer, abbreviate labels longer than `abbreviate` characters.
        :param maxwidth: maximum number of characters before a label starts to
            wrap; pass None to disable.
        """
        # Identity test: ``abbreviate == True`` would also match the integer
        # 1 and silently replace it with 5.
        if abbreviate is True:
            abbreviate = 5
        if unicodelines:
            horzline = '\u2500'
            leftcorner = '\u250c'
            rightcorner = '\u2510'
            vertline = ' \u2502 '
            tee = horzline + '\u252C' + horzline
            bottom = horzline + '\u2534' + horzline
            cross = horzline + '\u253c' + horzline
            ellipsis = '\u2026'
        else:
            horzline = '_'
            leftcorner = rightcorner = ' '
            vertline = ' | '
            tee = 3 * horzline
            cross = bottom = '_|_'
            ellipsis = '.'

        def crosscell(cur, x=vertline):
            """Overwrite center of this cell with a vertical branch."""
            splitl = len(cur) - len(cur) // 2 - len(x) // 2 - 1
            lst = list(cur)
            lst[splitl : splitl + len(x)] = list(x)
            return ''.join(lst)

        result = []
        matrix = defaultdict(dict)
        maxnodewith = defaultdict(lambda: 3)
        maxnodeheight = defaultdict(lambda: 1)
        maxcol = 0
        minchildcol = {}
        maxchildcol = {}
        childcols = defaultdict(set)
        labels = {}
        # Only build the wrapping pattern when wrapping is enabled;
        # ``maxwidth - 4`` would raise TypeError for maxwidth=None.
        if maxwidth:
            wrapre = re.compile(
                '(.{%d,%d}\\b\\W*|.{%d})' % (maxwidth - 4, maxwidth, maxwidth)
            )
        else:
            wrapre = None
        # collect labels and coordinates
        for a in self.nodes:
            row, column = self.coords[a]
            matrix[row][column] = a
            maxcol = max(maxcol, column)
            label = (
                self.nodes[a].label()
                if isinstance(self.nodes[a], Tree)
                else self.nodes[a]
            )
            if abbreviate and len(label) > abbreviate:
                label = label[:abbreviate] + ellipsis
            if maxwidth and len(label) > maxwidth:
                label = wrapre.sub(r'\1\n', label).strip()
            label = label.split('\n')
            maxnodeheight[row] = max(maxnodeheight[row], len(label))
            maxnodewith[column] = max(maxnodewith[column], max(map(len, label)))
            labels[a] = label
            if a not in self.edges:
                continue  # e.g., root
            parent = self.edges[a]
            childcols[parent].add((row, column))
            minchildcol[parent] = min(minchildcol.get(parent, column), column)
            maxchildcol[parent] = max(maxchildcol.get(parent, column), column)
        # bottom up level order traversal
        for row in sorted(matrix, reverse=True):
            noderows = [
                [''.center(maxnodewith[col]) for col in range(maxcol + 1)]
                for _ in range(maxnodeheight[row])
            ]
            branchrow = [''.center(maxnodewith[col]) for col in range(maxcol + 1)]
            for col in matrix[row]:
                n = matrix[row][col]
                node = self.nodes[n]
                text = labels[n]
                if isinstance(node, Tree):
                    # draw horizontal branch towards children for this node
                    if n in minchildcol and minchildcol[n] < maxchildcol[n]:
                        i, j = minchildcol[n], maxchildcol[n]
                        a, b = (maxnodewith[i] + 1) // 2 - 1, maxnodewith[j] // 2
                        branchrow[i] = ((' ' * a) + leftcorner).ljust(
                            maxnodewith[i], horzline
                        )
                        branchrow[j] = (rightcorner + (' ' * b)).rjust(
                            maxnodewith[j], horzline
                        )
                        for i in range(minchildcol[n] + 1, maxchildcol[n]):
                            if i == col and any(a == i for _, a in childcols[n]):
                                line = cross
                            elif i == col:
                                line = bottom
                            elif any(a == i for _, a in childcols[n]):
                                line = tee
                            else:
                                line = horzline
                            branchrow[i] = line.center(maxnodewith[i], horzline)
                    else:  # if n and n in minchildcol:
                        branchrow[col] = crosscell(branchrow[col])
                text = [a.center(maxnodewith[col]) for a in text]
                color = nodecolor if isinstance(node, Tree) else leafcolor
                if isinstance(node, Tree) and node.label().startswith('-'):
                    color = funccolor
                if html:
                    # quote=False matches the default of the legacy
                    # cgi.escape, so quote characters in labels stay as-is.
                    text = [escape(a, quote=False) for a in text]
                    if n in self.highlight:
                        text = ['<font color=%s>%s</font>' % (color, a) for a in text]
                elif ansi and n in self.highlight:
                    text = ['\x1b[%d;1m%s\x1b[0m' % (ANSICOLOR[color], a) for a in text]
                for x in range(maxnodeheight[row]):
                    # draw vertical lines in partially filled multiline node
                    # labels, but only if it's not a frontier node.
                    noderows[x][col] = (
                        text[x]
                        if x < len(text)
                        else (vertline if childcols[n] else ' ').center(
                            maxnodewith[col], ' '
                        )
                    )
            # for each column, if there is a node below us which has a parent
            # above us, draw a vertical branch in that column.
            if row != max(matrix):
                for n, (childrow, col) in self.coords.items():
                    if n > 0 and self.coords[self.edges[n]][0] < row < childrow:
                        branchrow[col] = crosscell(branchrow[col])
                        if col not in matrix[row]:
                            for noderow in noderows:
                                noderow[col] = crosscell(noderow[col])
            # pad each branch cell with `nodedist` filler characters,
            # continuing a horizontal line into the gap where there is one
            branchrow = [
                a + ((a[-1] if a[-1] != ' ' else b[0]) * nodedist)
                for a, b in zip(branchrow, branchrow[1:] + [' '])
            ]
            result.append(''.join(branchrow))
            result.extend(
                (' ' * nodedist).join(noderow) for noderow in reversed(noderows)
            )
        # rows were produced bottom-up, so flip them for output
        return '\n'.join(reversed(result)) + '\n'

    def svg(self, nodecolor='blue', leafcolor='red', funccolor='green'):
        """
        :return: SVG representation of a tree.

        :param nodecolor: fill color for highlighted Tree nodes.
        :param leafcolor: fill color for highlighted leaves.
        :param funccolor: fill color for highlighted Tree nodes whose label
            starts with '-'.

        Nodes that are not highlighted are drawn in black.
        """
        fontsize = 12
        hscale = 40
        vscale = 25
        hstart = vstart = 20
        width = max(col for _, col in self.coords.values())
        height = max(row for row, _ in self.coords.values())
        result = [
            '<svg version="1.1" xmlns="http://www.w3.org/2000/svg" '
            'width="%dem" height="%dem" viewBox="%d %d %d %d">'
            % (
                width * 3,
                height * 2.5,
                -hstart,
                -vstart,
                width * hscale + 3 * hstart,
                height * vscale + 3 * vstart,
            )
        ]

        # map each parent id to the ids of its children; id 0 is skipped
        # because it has no entry in self.edges
        children = defaultdict(set)
        for n in self.nodes:
            if n:
                children[self.edges[n]].add(n)

        # horizontal branches from nodes to children
        for node in self.nodes:
            if not children[node]:
                continue
            y, x = self.coords[node]
            x *= hscale
            y *= vscale
            x += hstart
            y += vstart + fontsize // 2
            childx = [self.coords[c][1] for c in children[node]]
            xmin = hstart + hscale * min(childx)
            xmax = hstart + hscale * max(childx)
            result.append(
                '\t<polyline style="stroke:black; stroke-width:1; fill:none;" '
                'points="%g,%g %g,%g" />' % (xmin, y, xmax, y)
            )
            result.append(
                '\t<polyline style="stroke:black; stroke-width:1; fill:none;" '
                'points="%g,%g %g,%g" />' % (x, y, x, y - fontsize // 3)
            )

        # vertical branches from children to parents
        for child, parent in self.edges.items():
            y, _ = self.coords[parent]
            y *= vscale
            y += vstart + fontsize // 2
            childy, childx = self.coords[child]
            childx *= hscale
            childy *= vscale
            childx += hstart
            childy += vstart - fontsize
            # a wide white stroke is emitted first, then the thin black line
            result += [
                '\t<polyline style="stroke:white; stroke-width:10; fill:none;"'
                ' points="%g,%g %g,%g" />' % (childx, childy, childx, y + 5),
                '\t<polyline style="stroke:black; stroke-width:1; fill:none;"'
                ' points="%g,%g %g,%g" />' % (childx, childy, childx, y),
            ]

        # write nodes with coordinates
        for n, (row, column) in self.coords.items():
            node = self.nodes[n]
            x = column * hscale + hstart
            y = row * vscale + vstart
            if n in self.highlight:
                color = nodecolor if isinstance(node, Tree) else leafcolor
                if isinstance(node, Tree) and node.label().startswith('-'):
                    color = funccolor
            else:
                color = 'black'
            result += [
                '\t<text style="text-anchor: middle; fill: %s; '
                'font-size: %dpx;" x="%g" y="%g">%s</text>'
                % (
                    color,
                    fontsize,
                    x,
                    y,
                    # quote=False matches the legacy cgi.escape default
                    escape(
                        node.label() if isinstance(node, Tree) else node,
                        quote=False,
                    ),
                )
            ]

        result += ['</svg>']
        return '\n'.join(result)
|
||
|
|
||
|
|
||
|
def test():
    """Draw a few treebank trees and one discontinuous tree as a demo."""
    from nltk.corpus import treebank

    def print_tree(n, tree, sentence=None, ansi=True, **xargs):
        # Print a header, the bracketed tree, then the drawing.
        print()
        words = sentence or tree.leaves()
        print('{0}: "{1}"'.format(n, ' '.join(words)))
        print(tree)
        print()
        printer = TreePrettyPrinter(tree, sentence)
        try:
            print(printer.text(unicodelines=ansi, ansi=ansi, **xargs))
        except (UnicodeDecodeError, UnicodeEncodeError):
            # retry as plain ASCII without colours
            print(printer.text(unicodelines=False, ansi=False, **xargs))

    for n in (0, 1440, 1591, 2771, 2170):
        tree = treebank.parsed_sents()[n]
        print_tree(n, tree, nodedist=2, maxwidth=8)

    # plain ASCII rendering of the last treebank tree drawn above
    print()
    print('ASCII version:')
    ascii_printer = TreePrettyPrinter(tree)
    print(ascii_printer.text(nodedist=2))

    # leaves are read as integers, used as indices into the sentence below
    bracketing = (
        '(top (punct 8) (smain (noun 0) (verb 1) (inf (verb 5) (inf (verb 6) '
        '(conj (inf (pp (prep 2) (np (det 3) (noun 4))) (verb 7)) (inf (verb 9)) '
        '(vg 10) (inf (verb 11)))))) (punct 12))'
    )
    tree = Tree.fromstring(bracketing, read_leaf=int)
    sentence = (
        'Ze had met haar moeder kunnen gaan winkelen ,'
        ' zwemmen of terrassen .'
    ).split()
    print_tree('Discontinuous tree', tree, sentence, nodedist=2)
|
||
|
|
||
|
|
||
|
# Names exported by ``from <module> import *``.
__all__ = ["TreePrettyPrinter"]

# Run the drawing demo when the module is executed as a script.
if __name__ == "__main__":
    test()
|