import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt
# sample data
penguins = sns.load_dataset('penguins')
g = sns.JointGrid(data=penguins, x='bill_length_mm', y='bill_depth_mm', space=0)
g.plot_joint(sns.scatterplot)
sns.despine(top=False)
g.plot_marginals(sns.histplot, kde=True, bins=250, color='r')
g.ax_marg_x.remove()
In this code, I am plotting a joint scatter plot together with the marginal distribution of one axis. There is one thing I could not figure out: how can I remove the left and right borders of the marginal plot in this figure?