"""Source code for ``dowhy.do_sampler``: the base class for samplers from an interventional (do-operator) distribution."""

import logging
import numpy as np
import pandas as pd

from dowhy.utils.api import parse_state

class DoSampler:
    """Base class for a sampler from the interventional distribution."""

    def __init__(self, data, params=None, variable_types=None, num_cores=1, causal_model=None,
                 keep_original_treatment=False):
        """Initialize a do sampler with data and the names of the relevant variables.

        Do sampling implements the do() operation from Pearl (2000). This operation is defined on a causal
        Bayesian network, an explicit implementation of which is the basis for the MCMC sampling method. We
        abstract the idea behind the three-step process to allow other methods as well.

        The `disrupt_causes` method is the means to make treatment assignment ignorable. In the Pearlian
        framework, this is where we cut the edges pointing into the causal state. With other methods, this will
        typically be done with some approach which assumes conditional ignorability (e.g. weighting, or explicit
        conditioning with Robins' G-formula).

        Next, the `make_treatment_effective` method reflects the assumption that the intervention we impose is
        "effective". Most simply, we fix the causal state to some specific value. This step is skipped, and the
        original treatment values are kept, when `keep_original_treatment` is True.

        Finally, we sample from the resulting distribution. This is done either with the `point_sample` method,
        when the inference method doesn't support batch sampling, or with the `sample` method when it does. For
        convenience, `point_sample` parallelizes with `multiprocessing`, using the `num_cores` kwarg to set the
        number of worker processes.

        While different methods will have their own class attributes, the `_df` attribute should be common to all
        methods. This is the temporary dataset which starts as a copy of the original data and is modified to
        reflect the steps of the do operation. Read through the existing methods (weighting is likely the most
        minimal) to get an idea of how this works before implementing one yourself.

        :param data: pandas.DataFrame containing the data. It is copied, never modified in place.
        :param params: (optional) dict of additional method parameters; each key/value pair is set as an
            attribute on the sampler.
        :param variable_types: dict: A dictionary containing the variables' names and types. 'c' for continuous,
            'o' for ordered, 'd' for discrete, and 'u' for unordered discrete. If empty or None,
            `_infer_variable_types` is called, which this base class does not implement.
        :param num_cores: int: number of worker processes used by `point_sample`; 1 disables multiprocessing.
        :param causal_model: the causal model. Required: the target estimand comes from its
            `identify_effect()`, and the treatment/outcome names come from its `_treatment`/`_outcome`.
        :param keep_original_treatment: bool: Whether to keep the original treatment assignments instead of
            overwriting them in `make_treatment_effective`.
        :raises ValueError: if `causal_model` is None.
        """
        if causal_model is None:
            # Without this check the first use below fails with an opaque AttributeError on None.
            raise ValueError("DoSampler requires a causal_model: the estimand, treatment and outcome names "
                             "are all taken from it.")
        self._data = data.copy()
        self._causal_model = causal_model
        self._target_estimand = self._causal_model.identify_effect()
        self._treatment_names = parse_state(self._causal_model._treatment)
        self._outcome_names = parse_state(self._causal_model._outcome)
        self._estimate = None
        self._variable_types = variable_types
        self.num_cores = num_cores
        self.point_sampler = True
        self.sampler = None
        self.keep_original_treatment = keep_original_treatment

        # Method-specific parameters become plain attributes, so they may override the defaults set above.
        if params is not None:
            for key, value in params.items():
                setattr(self, key, value)

        self._df = self._data.copy()
        if not self._variable_types:
            self._infer_variable_types()

        backdoor_variables = self._target_estimand.backdoor_variables
        self.dep_type = [self._variable_types[var] for var in self._outcome_names]
        self.indep_type = [self._variable_types[var] for var in self._treatment_and_backdoor_names()]
        self.density_types = [self._variable_types[var] for var in backdoor_variables]
        self.outcome_lower_support = self._data[self._outcome_names].min().values
        self.outcome_upper_support = self._data[self._outcome_names].max().values
        self.logger = logging.getLogger(__name__)

    def _treatment_and_backdoor_names(self):
        """Return the conditioning columns (treatments followed by backdoor variables), in sampling order."""
        return self._treatment_names + self._target_estimand.backdoor_variables

    def _sample_point(self, x_z):
        """Override this if your sampling method only allows sampling one point at a time.

        :param x_z: the values of x and z for one row, in the order of
            `self._treatment_names + self._target_estimand.backdoor_variables`. `point_sample` passes a
            pandas.Series when `num_cores == 1` and a numpy array otherwise.
        :return: numpy.array: a sampled outcome point.
        """
        raise NotImplementedError

    def reset(self):
        """Restore `_df` to a fresh copy of the original data.

        If your `DoSampler` has more attributes than the `_df` attribute, you should reset them all to their
        initialization values by overriding this method.

        :return: None
        """
        self._df = self._data.copy()

    def make_treatment_effective(self, x):
        """Set the treatment columns of `_df` to `x`, unless `keep_original_treatment` is True.

        This is most likely the implementation you'd like to use, but some methods may require overriding this
        method to make the treatment effective.

        :param x: value(s) assigned to the treatment columns of `_df`.
        :return: None
        """
        if not self.keep_original_treatment:
            self._df[self._treatment_names] = x

    def disrupt_causes(self):
        """Override this method to render treatment assignment conditionally ignorable.

        :return: None
        """
        raise NotImplementedError

    def point_sample(self):
        """Sample outcomes row by row with `_sample_point` and write them into the outcome columns of `_df`.

        With `num_cores == 1`, rows are passed through `DataFrame.apply`. Otherwise, the rows' numpy values are
        mapped over a `multiprocessing.Pool` with `num_cores` workers.

        :return: None
        """
        columns = self._treatment_and_backdoor_names()
        if self.num_cores == 1:
            sampled_outcomes = self._df[columns].apply(self._sample_point, axis=1)
            # `apply` keeps `_df`'s index, so the frame built from it lines up with `_df` row for row.
            sampled_outcomes = pd.DataFrame(sampled_outcomes, columns=self._outcome_names)
        else:
            from multiprocessing import Pool
            # The context manager shuts the pool down so worker processes are not leaked.
            with Pool(self.num_cores) as pool:
                sampled = pool.map(self._sample_point, self._df[columns].values)
            # Reuse `_df`'s index: assignment into `_df` aligns on index, and a fresh RangeIndex would
            # misplace or drop outcomes whenever `_df` has a different (e.g. resampled) index.
            sampled_outcomes = pd.DataFrame(np.asarray(sampled), columns=self._outcome_names,
                                            index=self._df.index)
        self._df[self._outcome_names] = sampled_outcomes

    def sample(self):
        """Sample all outcomes at once with `self.sampler.sample` and write them into `_df`.

        By default, this expects a sampler to be built on class initialization, which exposes a `sample`
        method. Override this method if you want to use a different approach to sampling.

        :return: None
        """
        columns = self._treatment_and_backdoor_names()
        sampled = self.sampler.sample(self._df[columns].values)
        # The sampler works on raw values, so its output is positional; pin it to `_df`'s index before the
        # index-aligned assignment below.
        sampled_outcomes = pd.DataFrame(np.asarray(sampled), columns=self._outcome_names, index=self._df.index)
        self._df[self._outcome_names] = sampled_outcomes

    def do_sample(self, x):
        """Draw a sample from the interventional distribution with the treatment set to `x`.

        Resets `_df`, calls `disrupt_causes`, then `make_treatment_effective(x)`, and finally samples outcomes
        with `point_sample` if `point_sampler` is True, otherwise with `sample`.

        :param x: treatment value(s) passed to `make_treatment_effective`.
        :return: pandas.DataFrame: the sampler's `_df` (not a copy) holding the sampled outcomes.
        """
        self.reset()
        self.disrupt_causes()
        self.make_treatment_effective(x)
        if self.point_sampler:
            self.point_sample()
        else:
            self.sample()
        return self._df

    def _infer_variable_types(self):
        raise NotImplementedError('Variable type inference not implemented. Use the variable_types kwarg.')