Saturday, November 13, 2010

New plotter for phylogenetic trees: plotting


Here is a first pass at a plotter, following up on the last post.

The code is still ugly, but I thought I would show you what I have. What I really need to do is think harder about how this module should be organized. In the meantime, you can compare the result with what we got here.

[UPDATE: Modified the code to be easily customizable. Next up: more extensive testing.]
[UPDATE2: Modified the code yet again.]

output:

         alternate name dict   None
default e node dot color k
dot color dict None
e node bar color r
e node bar color list None
e node dot size 75
e node dots visible True
e node label default color k
e node label size 14
e node labels visible True
figure type png
horizontal axis visible True
i node bar color k
i node dot color magenta
i node dots visible False
i node vertical bar color k
label width factor 1.1
line width 2
node label color dict None
r node dot color orange
using alternate names False
using e node specific colors False
vertical axis visible False


code listing:

import matplotlib.pyplot as plt
import tree_utils as tu

def plot(D,attr=None,ofn=None):
    """Draw the tree described by D with matplotlib and save it to a file.

    D    -- tree dictionary.  D['meta'] must supply 'max_xy',
            'e_node_names' (external nodes) and 'i_node_names' (internal
            nodes, the first of which is drawn as the root).  D[name] must
            hold each node's 'x' and 'y'; every non-root node also needs
            'dist_to_parent', and every internal node 'y_bott' and 'y_top'.
    attr -- optional dict of display settings.  Keys it leaves out fall
            back to the values from get_default_attr(); the caller's dict
            is not modified.
    ofn  -- output file name without extension ('example' when omitted).
            The extension is attr['figure_type'].
    """
    # overlay the caller's settings on the defaults so that a partial
    # dict customizes a few values instead of raising KeyError
    settings = get_default_attr()
    if attr:
        settings.update(attr)
    attr = settings

    SZ = attr['e_node_dot_size']
    lw = attr['line_width']
    maxx, maxy = D['meta']['max_xy']

    e_node_names = D['meta']['e_node_names']
    i_node_names = D['meta']['i_node_names']

    # actual node dictionaries
    e_node_dicts = [D[n] for n in e_node_names]
    i_node_dicts = [D[n] for n in i_node_names]

    # external nodes
    bar_colors = attr['e_node_bar_color_list']
    for i, (name, nD) in enumerate(zip(e_node_names, e_node_dicts)):
        x, y = nD['x'], nD['y']
        xleft = x - nD['dist_to_parent']
        # bars: a per-node color list, when given, wins over the single color
        if bar_colors:
            c = bar_colors[i]
        else:
            c = attr['e_node_bar_color']
        plt.plot([xleft,x],[y,y],color=c,lw=lw,zorder=1)
        # dots
        if attr['e_node_dots_visible']:
            if attr['using_e_node_specific_colors']:
                c = attr['dot_color_dict'][name]
            else:
                c = attr['default_e_node_dot_color']
            plt.scatter(x,y,color=c,s=SZ,zorder=2)

    # internal nodes; the root (index 0) is handled separately below
    for nD in i_node_dicts[1:]:
        x, y = nD['x'], nD['y']
        x0 = x - nD['dist_to_parent']
        # bars
        c = attr['i_node_bar_color']
        plt.plot([x0,x],[y,y],color=c,lw=lw,zorder=1)
        # verticals
        c = attr['i_node_vertical_bar_color']
        plt.plot([x,x],[nD['y_bott'],nD['y_top']],color=c,lw=lw,zorder=1)
        # dots
        if attr['i_node_dots_visible']:
            c = attr['i_node_dot_color']
            plt.scatter(x,y,color=c,s=SZ,zorder=2)

    # root: no horizontal bar, and its own dot color
    root = i_node_dicts[0]
    x, y = root['x'], root['y']
    # verticals
    c = attr['i_node_vertical_bar_color']
    plt.plot([x,x],[root['y_bott'],root['y_top']],color=c,lw=lw,zorder=1)
    # dots
    if attr['i_node_dots_visible']:
        c = attr['r_node_dot_color']
        plt.scatter(x,y,color=c,s=SZ,zorder=2)

    # extra room on the right for the labels: 1% of maxx per character of
    # the longest name, scaled by the label width factor
    longest = max([len(name) for name in e_node_names])
    max_label_width = maxx/100.0 * longest
    max_label_width *= attr['label_width_factor']

    # external node labels
    if attr['e_node_labels_visible']:
        dx = maxx/100.0 * 4
        for name, nD in zip(e_node_names, e_node_dicts):
            if attr['using_alternate_names']:
                s = attr['alternate_name_dict'][name]
            else:
                s = name
            if attr['using_e_node_specific_colors']:
                c = attr['node_label_color_dict'][name]
            else:
                c = attr['e_node_label_default_color']
            plt.text(nD['x'] + dx, nD['y'], s,
                     fontname = 'Helvetica',
                     fontsize = attr['e_node_label_size'],
                     color = c,
                     ha = 'left',va = 'center')

    # gca() returns the axes that were just drawn on; a bare plt.axes()
    # adds a new, empty axes on top of them in current matplotlib
    ax = plt.gca()
    if not attr['vertical_axis_visible']:
        ax.yaxis.set_visible(False)
    if not attr['horizontal_axis_visible']:
        ax.xaxis.set_visible(False)
    ax.set_xlim(-maxx/10.0,maxx*1.1 + max_label_width)
    ax.set_ylim(-maxy/10.0,maxy*1.1)

    if ofn:
        plt.savefig(ofn + '.' + attr['figure_type'])
    else:
        plt.savefig('example.' + attr['figure_type'])

def get_default_attr():
    """Return a new dict holding the default value of every plot setting."""
    return {
        # figure-wide settings
        'e_node_dot_size': 75,
        'line_width': 2,
        'figure_type': 'png',
        'horizontal_axis_visible': True,
        'vertical_axis_visible': False,

        # external nodes
        'e_node_bar_color': 'r',
        'e_node_bar_color_list': None,
        'e_node_dots_visible': True,
        'e_node_labels_visible': True,
        'default_e_node_dot_color': 'k',
        'e_node_label_size': 14,
        'e_node_label_default_color': 'k',
        'label_width_factor': 1.1,

        # internal nodes and root
        'i_node_bar_color': 'k',
        'i_node_vertical_bar_color': 'k',
        'i_node_dots_visible': False,
        'i_node_dot_color': 'magenta',
        'r_node_dot_color': 'orange',

        # optional per-node overrides, switched off by default
        'using_alternate_names': False,
        'alternate_name_dict': None,
        'using_e_node_specific_colors': False,
        'node_label_color_dict': None,
        'dot_color_dict': None,
    }

def print_defaults():
    """Print every default plot setting, one per line, sorted by key.

    Underscores in each key are shown as spaces, and the keys are
    right-justified to the width of the longest one so the values line up.
    """
    attr = get_default_attr()
    N = max(len(k) for k in attr)
    for k in sorted(attr):
        label = k.replace('_',' ').rjust(N)
        # one pre-formatted string prints identically whether print is the
        # Python 2 statement or the Python 3 function
        print('%s   %s' % (label, attr[k]))

if __name__ == '__main__':
    # script entry point: read 'tree.txt' through tree_utils, turn it into
    # the dictionary plot() takes, draw it with the default settings, and
    # then list those defaults
    tree = tu.make_tree_dict(ts=tu.load_data('tree.txt'))
    plot(tree)
    print_defaults()