
Synthetic Control Method

Category
Statistics 📊
Published on
July 14, 2024

Introduction

The Synthetic Control Method is a causal inference technique, introduced by Abadie and Gardeazabal in 2003, for estimating the effect of an intervention on an aggregate unit such as a region or a country.

It is particularly valuable when experimental methods such as randomized trials are not feasible, and it can be seen as an extension of the Difference-in-Differences method.

The core principle is to construct a counterfactual for the treatment group from similar units that were not affected by the intervention. This "synthetic" control group is a weighted combination of untreated units from the "donor pool," and the difference between the actual data and the synthetic control after the intervention serves as an estimate of its effect.
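The same idea can be written as equations. The notation here is an assumption of this sketch, not the article's: unit 1 is treated, units 2 to J+1 form the donor pool, Y_jt is the outcome of unit j in period t, and the intervention starts after period T_0. The weight constraints are those of the original formulation; choosing the weights by minimizing the pre-treatment squared gap, as written here, is a simplification, since the original method also matches on other pre-treatment characteristics. The Python implementation in this article fits the weights without these constraints.

```latex
\hat{Y}^{N}_{1t} = \sum_{j=2}^{J+1} w_j \, Y_{jt},
\qquad
\hat{\tau}_{1t} = Y_{1t} - \hat{Y}^{N}_{1t}, \quad t > T_0

w_j \ge 0, \qquad \sum_{j=2}^{J+1} w_j = 1,
\qquad
\mathbf{w}^{*} = \arg\min_{\mathbf{w}} \sum_{t \le T_0} \Big( Y_{1t} - \sum_{j=2}^{J+1} w_j \, Y_{jt} \Big)^{2}
```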

The robustness of the results is then assessed through checks and sensitivity analyses, such as placebo tests that apply the same procedure to units or periods where no intervention took place.
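One common robustness check is the placebo-in-space test: apply the same procedure to each untreated unit as if it had been treated, and see whether the real treated unit's post-treatment gap stands out. Below is a minimal sketch on simulated data. It assumes unconstrained least-squares weights without intercept (the same idea as the scikit-learn regression used in the implementation) and uses the post/pre RMSPE ratio (root mean squared prediction error of the gap) as the test statistic. The simulated panel and the names `gap_rmspe_ratio` and `placebo_p_value` are hypothetical. It also excludes the actually treated unit from every placebo's donor pool, which is one design choice among several.

```python
import numpy as np


def gap_rmspe_ratio(panel: np.ndarray, unit: int, treated: int, n_pre: int) -> float:
    """Post/pre RMSPE ratio of the gap when `unit` plays the treated unit.

    The donor pool is every other column except the actually treated unit,
    so its post-treatment jump does not leak into the placebo fits.
    """
    donor_cols = [c for c in range(panel.shape[1]) if c not in (unit, treated)]
    target, donors = panel[:, unit], panel[:, donor_cols]
    # Unconstrained least-squares weights, no intercept, fitted on the pre-period only
    weights, *_ = np.linalg.lstsq(donors[:n_pre], target[:n_pre], rcond=None)
    gap = target - donors @ weights
    pre_rmspe = np.sqrt(np.mean(gap[:n_pre] ** 2))
    post_rmspe = np.sqrt(np.mean(gap[n_pre:] ** 2))
    return float(post_rmspe / pre_rmspe)


def placebo_p_value(panel: np.ndarray, treated: int, n_pre: int) -> float:
    """Share of units whose ratio is at least as large as the treated unit's."""
    ratios = np.array(
        [gap_rmspe_ratio(panel, u, treated, n_pre) for u in range(panel.shape[1])]
    )
    return float(np.mean(ratios >= ratios[treated]))


# Hypothetical simulated panel: 30 years x 8 units, column 0 is the treated unit
rng = np.random.default_rng(0)
n_years, n_pre, n_units = 30, 20, 8
trends = rng.normal(size=(n_years, 2)).cumsum(axis=0)  # two shared trends
loadings = rng.uniform(0.5, 1.5, size=(2, n_units))
panel = 100 + trends @ loadings + rng.normal(scale=0.5, size=(n_years, n_units))
panel[n_pre:, 0] += 10.0  # simulated effect on unit 0 after the pre-period

p_value = placebo_p_value(panel, treated=0, n_pre=n_pre)
print(f"placebo p-value: {p_value:.3f}")
```

The p-value counts the treated unit itself, so with 8 units the smallest attainable value is 1/8.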

Implementation in Python

  1. First, let’s load the necessary packages (only the usual libraries, no specialized package in this example) and some sample data.
    # Import libraries
    import pandas as pd
    import numpy as np
    import matplotlib.pyplot as plt
    import seaborn as sns
    from sklearn import linear_model
    from sklearn.metrics import mean_squared_error, mean_absolute_error, r2_score
    
    # Import source data
    df = pd.read_csv('scm_data.csv').set_index('year')
    df.head()
            control_1  control_2  control_3  control_4  treatment
    year
    1990    224        122        76         682        1031
    1991    251        151        82         660        1045
    1992    260        152        84         655        1028
    1993    254        136        85         790        1119
    1994    286        152        86         794        1074
  2. Then we split the DataFrame into the control units (the “donor pool”) and the treatment unit, and check them visually.
    # DataFrame for the 'donor pool' (control units)
    X = df.drop(columns=['treatment'])
    
    # Dataframe for treatment group
    Y = df[['treatment']]
    
    # Plot all groups
    sns.lineplot(data=Y, x='year', y='treatment')
    for col in X.columns:  # every control unit, including control_1
        sns.lineplot(data=X, x='year', y=col, color='grey')
    
    treatment_year = 2004
    plt.axvline(x=treatment_year, color='firebrick', linestyle='--', lw=1)

    The change in trend from 2004 onward is clearly visible for the treatment group.

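  3. We can now fit a linear regression on the pre-treatment period with scikit-learn. The treatment series is regressed on the control series without an intercept, so the coefficients act as the weights of the donor units.
    # Fit a linear regression without intercept on the pre-treatment window
    # (as written, the window includes treatment_year itself; use < to exclude it)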
    linear_reg = linear_model.LinearRegression(fit_intercept=False).fit(
        X.loc[X.index <= treatment_year],
        Y.loc[Y.index <= treatment_year]
    )
    
    # Get weights of the fitted model
    weights = linear_reg.coef_[0]
    weights
    array([ 1.0404247 ,  2.05844824, -1.09116727,  0.88922022])
  4. Next, we use the fitted weights to build the synthetic control group as a weighted sum of the control units.
    # Compute the synthetic control group (weighted sum of the control columns)
    X['synthetic'] = np.dot(X, weights)
    
    # Plot synthetic group and treatment group
    sns.lineplot(Y, x='year', y='treatment')
    sns.lineplot(X, x='year', y='synthetic', linestyle='--')
  5. We can calculate some metrics to assess how closely the synthetic control fits the treatment group over the pre-treatment period.
    # Pre-treatment fit metrics of the linear regression
    Y_pred = X['synthetic'].loc[X.index <= treatment_year]
    
    rmse = np.sqrt(mean_squared_error(Y.loc[Y.index <= treatment_year], Y_pred))
    mae = mean_absolute_error(Y.loc[Y.index <= treatment_year], Y_pred)
    r2 = r2_score(Y.loc[Y.index <= treatment_year], Y_pred)
    
    print(f'RMSE: {rmse}')
    print(f'MAE: {mae}')
    print(f'R²: {r2}')
    RMSE: 64.42106137354853
    MAE: 50.00772638278276
    R²: 0.9259623777267626
  6. Finally, we can look at the estimated effect of the treatment by subtracting the counterfactual (the synthetic control) from the treatment group.
    # Plot the estimated effect: gap between treatment and synthetic series
    sns.lineplot(data=Y['treatment'] - X['synthetic'])
    plt.axvline(x=treatment_year, color='firebrick', linestyle='--', lw=1)
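The scikit-learn regression fitted earlier leaves the weights unconstrained, which is why one of them is negative (-1.09 for control_3) and they sum to about 2.90, so the synthetic series is partly an extrapolation rather than a weighted average. The original method restricts the weights to be non-negative and to sum to one. Below is a minimal sketch of that constrained fit with `scipy.optimize.minimize` (method `"SLSQP"`). SciPy is an assumption here, since the article does not use it, and `scm_weights` is a hypothetical helper that matches on pre-treatment outcomes only.

```python
import numpy as np
from scipy.optimize import minimize


def scm_weights(donors_pre: np.ndarray, target_pre: np.ndarray) -> np.ndarray:
    """Donor weights that are non-negative, sum to one, and minimize the
    mean squared pre-treatment gap between the target and the weighted donors."""
    n_donors = donors_pre.shape[1]

    def loss(w: np.ndarray) -> float:
        return float(np.mean((target_pre - donors_pre @ w) ** 2))

    result = minimize(
        loss,
        x0=np.full(n_donors, 1.0 / n_donors),  # start from equal weights
        method="SLSQP",
        bounds=[(0.0, 1.0)] * n_donors,  # 0 <= w_j <= 1
        constraints=[{"type": "eq", "fun": lambda w: np.sum(w) - 1.0}],  # weights sum to 1
    )
    return result.x


# Hypothetical check: the target is an exact convex mix of donors 0 and 2
rng = np.random.default_rng(1)
donors_pre = rng.uniform(50, 150, size=(20, 4))
target_pre = donors_pre @ np.array([0.7, 0.0, 0.3, 0.0])
w = scm_weights(donors_pre, target_pre)
print(np.round(w, 3))
```

On the article's data, the inputs would be the pre-treatment rows of the control and treatment columns, e.g. `X.loc[X.index <= treatment_year, ['control_1', 'control_2', 'control_3', 'control_4']].to_numpy()` and `Y.loc[Y.index <= treatment_year, 'treatment'].to_numpy()`. In the rows printed by `df.head()`, however, the treatment series lies above every control, and a weighted average cannot exceed its largest input, so the constrained fit would miss the level. Adding an intercept or demeaning the series are possible workarounds, but they depart from the original formulation.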

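To put a number on the gap plotted in the last step, a small helper can average and sum it over the post-treatment years. `summarize_effect` is a hypothetical helper, not from the article, and the toy series below uses made-up numbers. Because the article's regression window includes 2004, a consistent call on its data would be `summarize_effect(Y['treatment'] - X['synthetic'], first_post_year=treatment_year + 1)`.

```python
import pandas as pd


def summarize_effect(gap: pd.Series, first_post_year: int) -> dict[str, float]:
    """Average and total treated-minus-synthetic gap over the post-treatment years.

    `gap` is indexed by year; years >= first_post_year count as post-treatment.
    """
    post = gap[gap.index >= first_post_year]
    return {
        "mean_effect": float(post.mean()),  # average yearly gap after treatment
        "cumulative_effect": float(post.sum()),  # gap summed over the post-treatment years
    }


# Toy gap series with made-up numbers (not the article's data)
toy_gap = pd.Series([1.0, -1.0, 0.5, 10.0, 12.0], index=[2000, 2001, 2002, 2003, 2004])
summary = summarize_effect(toy_gap, first_post_year=2003)
print(summary)  # {'mean_effect': 11.0, 'cumulative_effect': 22.0}
```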
Resources