I've been struggling to save my graphs to the specific directory with some certaion look.
Here is the example data and what I've tried so far
import pandas as pd
import numpy as np
import itertools
import seaborn as sns
from matplotlib.colors import ListedColormap
print("seaborn version {}".format(sns.__version__))
# R expand.grid() function in Python
# https://stackoverflow.com/a/12131385/1135316
def expandgrid(*itrs):
product = list(itertools.product(*itrs))
return {'Var{}'.format(i+1):[x[i] for x in product] for i in range(len(itrs))}
ltt= ['lt1','lt2']
methods=['method 1', 'method 2', 'method 3', 'method 4']
labels = ['label1','label2']
times = range(0,100,10)
data = pd.DataFrame(expandgrid(ltt,methods,labels, times, times))
data.columns = ['ltt','method','labels','dtsi','rtsi']
#data['nw_score'] = np.random.sample(data.shape[0])
data['nw_score'] = np.random.choice([0,1],data.shape[0])
data
Out[25]:
ltt method labels dtsi rtsi nw_score
0 lt1 method 1 label1 0 0 0
1 lt1 method 1 label1 0 10 1
2 lt1 method 1 label1 0 20 1
3 lt1 method 1 label1 0 30 1
4 lt1 method 1 label1 0 40 1
... ... ... ... ... ...
1595 lt2 method 4 label2 90 50 0
1596 lt2 method 4 label2 90 60 0
1597 lt2 method 4 label2 90 70 0
1598 lt2 method 4 label2 90 80 0
1599 lt2 method 4 label2 90 90 0
labels_fill = {0:'red',1:'blue'}
def facet(data,color):
data = data.pivot(index="dtsi", columns='rtsi', values='nw_score')
g = sns.heatmap(data, cmap=ListedColormap(['red', 'blue']), cbar=False,annot=True)
for l in data.ltt.unique():
# print(l)
with sns.plotting_context(font_scale=5.5):
g = sns.FacetGrid(data,row="labels", col="method+l", size=2, aspect=1,margin_titles=False)
g = g.map_dataframe(facet)
g.add_legend()
# g.set(xlabel='common xlabel', ylabel='common ylabel')
#g.set_titles(col_template="{col_name}", fontweight='bold', fontsize=18)
g.set_titles(template="")
for ax,m in zip(g.axes[0,:],methods):
ax.set_title(m, fontweight='bold', fontsize=12)
for ax,l in zip(g.axes[:,0],labels):
ax.set_ylabel(l, fontweight='bold', fontsize=12, rotation=0, ha='right', va='center')
# g.fig.tight_layout()
save_results_to = 'D:/plots'
if not os.path.exists(save_results_to):
os.makedirs(save_results_to)
g.savefig(save_results_to + l+ '.png', dpi = 300)
When I ran the code above I'm getting an error which says
ValueError: Index contains duplicate entries, cannot reshape
the expected graph format