0

Below is the sample code

import pandas as pd
from causalimpact import CausalImpact
# ARMA sample fixture taken from the tfcausalimpact test suite.
fixture_url = 'https://raw.githubusercontent.com/WillianFuks/tfcausalimpact/master/tests/fixtures/arma_data.csv'
# Row position at which the artificial intervention begins.
intervention_start = 70
# Size of the artificial shift added to the response after the intervention.
simulated_lift = 5

# Keep only the response ('y', column 0) and the covariate ('X').
data = pd.read_csv(fixture_url)[['y', 'X']]
# Shift the response from the intervention row onwards.
data.iloc[intervention_start:, 0] += simulated_lift

pre_period = [0, intervention_start - 1]
post_period = [intervention_start, 99]

# Fit the model on the pre-period and draw the default plot.
ci = CausalImpact(data, pre_period, post_period)
ci.plot()

I want to write the plot generated above to HTML, or at least save it as an image. Is there a way to do this, given that ci.plot() returns None (its type is NoneType)? Library: https://github.com/WillianFuks/tfcausalimpact

desertnaut
  • 57,590
  • 26
  • 140
  • 166

1 Answer

0

A very dirty (but working) solution would be to:

  1. Get the source code of ci.plot() using inspect:
import inspect
# Fetch the source text of ci.plot, then print it so it can be copied.
plot_source = inspect.getsource(ci.plot)
print(plot_source)
  2. Create a new function based on ci.plot() that actually saves the plot (it may also be possible to override the method on the class instead). In my case this is the function plot2, which takes a new parameter, path. The key difference from the original function is the last line, which saves the figure to path.
def plot2(self, path, panels=('original', 'pointwise', 'cumulative'),
          figsize=(15, 12)):
    """Plot the causal impact inference results and save the figure to `path`.

    Args
    ----
      path: str or path-like.
        Destination handed to the figure's `savefig` call.
      panels: list or tuple.
        Indicates which plot should be considered in the graphics. Valid
        entries are 'original', 'pointwise' and 'cumulative'.
      figsize: tuple.
        Changes the size of the graphics plotted.

    Raises
    ------
      RuntimeError: if inferences were not computed yet.
      ValueError: if `panels` contains a name that is not a valid panel.
    """
    # Validate everything before creating a figure, so that a rejected call
    # does not leave an unused figure behind.
    if self.summary_data is None:
        raise RuntimeError('Please first run inferences before plotting results')

    valid_panels = ['original', 'pointwise', 'cumulative']
    for panel in panels:
        if panel not in valid_panels:
            raise ValueError(
                '"{}" is not a valid panel. Valid panels are: {}.'.format(
                    panel, ', '.join(['"{}"'.format(e) for e in valid_panels])
                )
            )

    plt = self._get_plotter()
    fig = plt.figure(figsize=figsize)

    # First points can be noisy due to approximation techniques used in the
    # likelihood optimization process. We remove those points from the plots.
    llb = self.trained_model.filter_results.loglikelihood_burn
    inferences = self.inferences.iloc[llb:]

    # Position of the first post-period label; the vertical marker in each
    # panel is drawn one step earlier, at the last pre-intervention point.
    intervention_idx = inferences.index.get_loc(self.post_period[0])
    n_panels = len(panels)
    ax = plt.subplot(n_panels, 1, 1)
    idx = 1

    if 'original' in panels:
        ax.plot(pd.concat([self.pre_data.iloc[llb:, 0], self.post_data.iloc[:, 0]]),
                'k', label='y')
        ax.plot(inferences['preds'], 'b--', label='Predicted')
        ax.axvline(inferences.index[intervention_idx - 1], c='k', linestyle='--')
        ax.fill_between(
            self.pre_data.index[llb:].union(self.post_data.index),
            inferences['preds_lower'],
            inferences['preds_upper'],
            facecolor='blue',
            interpolate=True,
            alpha=0.25
        )
        ax.grid(True, linestyle='--')
        ax.legend()
        # Only the bottom-most panel keeps its x tick labels.
        if idx != n_panels:
            plt.setp(ax.get_xticklabels(), visible=False)
        idx += 1

    if 'pointwise' in panels:
        ax = plt.subplot(n_panels, 1, idx, sharex=ax)
        ax.plot(inferences['point_effects'], 'b--', label='Point Effects')
        ax.axvline(inferences.index[intervention_idx - 1], c='k', linestyle='--')
        ax.fill_between(
            inferences['point_effects'].index,
            inferences['point_effects_lower'],
            inferences['point_effects_upper'],
            facecolor='blue',
            interpolate=True,
            alpha=0.25
        )
        ax.axhline(y=0, color='k', linestyle='--')
        ax.grid(True, linestyle='--')
        ax.legend()
        if idx != n_panels:
            plt.setp(ax.get_xticklabels(), visible=False)
        idx += 1

    if 'cumulative' in panels:
        ax = plt.subplot(n_panels, 1, idx, sharex=ax)
        ax.plot(inferences['post_cum_effects'], 'b--',
                label='Cumulative Effect')
        ax.axvline(inferences.index[intervention_idx - 1], c='k', linestyle='--')
        ax.fill_between(
            inferences['post_cum_effects'].index,
            inferences['post_cum_effects_lower'],
            inferences['post_cum_effects_upper'],
            facecolor='blue',
            interpolate=True,
            alpha=0.25
        )
        ax.grid(True, linestyle='--')
        ax.axhline(y=0, color='k', linestyle='--')
        ax.legend()

    # Alert if points were removed due to loglikelihood burning data
    if llb > 0:
        text = ('Note: The first {} observations were removed due to approximate '
                'diffuse initialization.'.format(llb))
        fig.text(0.1, 0.01, text, fontsize='large')

    # Save the figure built above explicitly rather than whichever figure
    # happens to be "current" in the plotting backend.
    fig.savefig(path)
  3. Call the new function instead of ci.plot():
# Render the fitted model's panels and write them to an image file.
plot2(ci, path='myplot.png')