import matplotlib.pyplot as plt
import matplotlib as mpl
import numpy as np
import networkx as nx
[docs]
def dict_to_list(d):
    """Return the values of mapping ``d`` as a new list.

    Args:
        d (dict): any mapping.

    Returns:
        list: the values of ``d``, in the order ``d`` iterates them.
    """
    values = [*d.values()]
    return values
[docs]
def line_loop_index(R):
    """Determine if edges between two nodes should be lines or loops.

    Edges are read from ``R.edges`` as ``(u, v, key)`` triples. An edge whose
    key is ``1`` marks a doubled edge: it and its key-``0`` partner
    ``(u, v, 0)`` are both classified as loops. Every other edge is classified
    as a line (so edges with key ``>= 2`` are drawn as lines).

    Args:
        R (Reeb Graph): Reeb Graph whose ``edges`` yield ``(u, v, key)`` tuples.

    Returns:
        2-element tuple containing
            - **line_index (list)** : indices (into ``list(R.edges)``) of edges
              to be drawn as lines, in ascending order
            - **loop_index (list)** : indices of edges to be drawn as loops,
              as consecutive ``partner_index, key1_index`` pairs

    Raises:
        ValueError: if an edge ``(u, v, 1)`` has no ``(u, v, 0)`` partner.
    """
    edge_list = list(R.edges)
    # Look up each edge's position once instead of calling the O(n)
    # ``edge_list.index`` inside the loop (which made the whole pass O(n^2)).
    position = {edge: i for i, edge in enumerate(edge_list)}
    line_index = []
    loop_index = []
    paired = set()
    for i, edge in enumerate(edge_list):
        if edge[2] == 1:
            partner = edge[0:2] + (0,)
            try:
                j = position[partner]
            except KeyError:
                raise ValueError(
                    f"edge {edge!r} has no matching {partner!r} edge"
                ) from None
            loop_index.append(j)
            loop_index.append(i)
            paired.add(j)
        else:
            line_index.append(i)
    # Drop key-0 partners of doubled edges. Filtering at the end (rather than
    # list.remove) also works when the partner comes after its key-1 edge.
    line_index = [i for i in line_index if i not in paired]
    return (line_index, loop_index)
[docs]
def slope_intercept(pt0, pt1):
    """Return the slope and y-intercept of the line through two points.

    Args:
        pt0 (ordered pair): first point ``(x, y)``.
        pt1 (ordered pair): second point ``(x, y)``.

    Returns:
        2-element tuple containing
            - **m (float)** : slope
            - **b (float)** : intercept

    Raises:
        ZeroDivisionError: for plain Python numbers, when both points share
            the same x coordinate (a vertical line has no slope-intercept form).
    """
    x0, y0 = pt0[0], pt0[1]
    x1, y1 = pt1[0], pt1[1]
    rise = y0 - y1
    run = x0 - x1
    slope = rise / run
    return (slope, y0 - slope * x0)
[docs]
def bezier_curve(pt0, midpt, pt1, num=100):
    """Sample a quadratic Bezier curve for drawing a second edge between two nodes.

    The curve starts at ``pt0``, ends at ``pt1`` and is pulled toward the
    control point ``midpt``; it does not pass through ``midpt`` unless the
    three points are collinear. Points are evaluated directly from the
    Bernstein form ``B(t) = (1-t)^2 P0 + 2(1-t)t M + t^2 P1``, so segments
    with equal x coordinates (e.g. a vertically laid out Reeb graph) need no
    slope computation and cannot divide by zero.

    Args:
        pt0 (ordered pair): first point
        midpt (ordered pair): control point of the curve
        pt1 (ordered pair): second point
        num (int): number of samples, including both endpoints (default 100)

    Returns:
        points (list): ``num`` ``(x, y)`` tuples running from ``pt0`` to ``pt1``

    Raises:
        ValueError: if ``num`` is less than 2.
    """
    if num < 2:
        raise ValueError("num must be at least 2")
    x0, y0 = pt0[0], pt0[1]
    xm, ym = midpt[0], midpt[1]
    x1, y1 = pt1[0], pt1[1]
    points = []
    for k in range(num):
        # t runs over [0, 1] inclusive so the curve actually reaches pt1.
        t = k / (num - 1)
        s = 1 - t
        w0, wm, w1 = s * s, 2 * s * t, t * t
        points.append((w0 * x0 + wm * xm + w1 * x1, w0 * y0 + wm * ym + w1 * y1))
    return points
[docs]
def reeb_plot(
    R, with_labels=True, with_colorbar=False, cpx=0.1, cpy=0.1, ax=None, **kwargs
):
    """Main plotting function for the Reeb Graph Class.

    Nodes are drawn at ``R.pos_f`` and colored by the second coordinate of
    their position (the function value). Edges classified as lines by
    ``line_loop_index`` are drawn as straight grey segments. Each doubled edge
    pair is drawn once, as two grey Bezier curves bowing to either side.

    Parameters:
        R (Reeb Graph): object of Reeb Graph class; must provide ``nodes``,
            ``edges`` (``(u, v, key)`` triples) and ``pos_f`` (node -> ``(x, y)``).
        with_labels (bool): whether or not to plot node labels.
        with_colorbar (bool): whether or not to plot a colorbar for node colors.
        cpx (float): controls the curvature of loops; the control points are
            offset horizontally by ``cpx`` times the midpoint's x coordinate.
            For vertical Reeb graph, only mess with cpx.
        cpy (float): same as ``cpx`` for the vertical offset (``cpy`` times
            the midpoint's y coordinate).
        ax (matplotlib Axes, optional): axes to draw on; a new figure and
            axes are created when omitted.
        **kwargs: forwarded to ``networkx.draw_networkx_nodes``.
    """
    if ax is None:
        _, ax = plt.subplots()
    pos = R.pos_f
    edge_list = list(R.edges)
    line_index, loop_index = line_loop_index(R)

    # Some weird plotting to make the colored and labeled nodes work.
    # Take the function values from the pos_f dictionary, since the infinite
    # node should already have a position set there.
    color_map = [pos[v][1] for v in R.nodes]
    pathcollection = nx.draw_networkx_nodes(
        R, pos, node_color=color_map, ax=ax, **kwargs
    )
    if with_labels:
        nx.draw_networkx_labels(R, pos=pos, font_color="black", ax=ax)
    if with_colorbar:
        # Attach the colorbar to ``ax``'s own figure. ``plt.colorbar`` targets
        # the *current* figure, which is wrong when the caller passed an ``ax``
        # belonging to a different figure.
        ax.figure.colorbar(pathcollection, ax=ax)

    for i in line_index:
        node0, node1 = edge_list[i][0], edge_list[i][1]
        ax.plot(
            (pos[node0][0], pos[node1][0]),
            (pos[node0][1], pos[node1][1]),
            color="grey",
            zorder=0,
        )

    # loop_index contains both edges of each doubled pair, and each pair is
    # rendered as two curves; remember drawn pairs to avoid drawing them twice.
    drawn = set()
    for i in loop_index:
        node0, node1 = edge_list[i][0], edge_list[i][1]
        if (node0, node1) in drawn:
            continue
        drawn.add((node0, node1))
        xmid = (pos[node0][0] + pos[node1][0]) / 2
        ymid = (pos[node0][1] + pos[node1][1]) / 2
        # Control points on either side of the midpoint. The offsets scale with
        # the midpoint coordinates, so e.g. a pair whose midpoint is at x == 0
        # gets no horizontal bow (NOTE(review): possibly unintended).
        for sign in (-1, 1):
            ctrl = (xmid + sign * cpx * xmid, ymid + sign * cpy * ymid)
            c = np.array(bezier_curve(pos[node0], ctrl, pos[node1]))
            ax.plot(c[:, 0], c[:, 1], color="grey", zorder=0)

    ax.tick_params(left=True, bottom=False, labelleft=True, labelbottom=False)