
The CohortManager applies a separate data processing pipeline to each cohort. It also provides a simple interface for creating and filtering multiple cohorts, and it allows a different estimator to be trained for each cohort, exposed through the predict() and predict_proba() interfaces. This class uses cohort.CohortDefinition internally to create, filter, and manipulate multiple cohorts. There are multiple ways of using the cohort.CohortManager class when building a pipeline; these scenarios are summarized in the following figure.

Balancing over cohorts

Figure 1 - The CohortManager class can be used in different ways to target mitigations to different cohorts. The main differences between these scenarios lie in whether the same or a different type of data mitigation is applied to each cohort’s data, and whether a single model or separate models are trained for the different cohorts. Depending on these choices, CohortManager takes care of slicing the data accordingly, applying the specified data mitigation strategy, merging the data back, and retraining the model(s).

The Cohort Manager - Scenarios and Examples notebook, located in notebooks/cohort/cohort_manager_scenarios.ipynb and listed in the Examples section below, shows how each of these scenarios can be implemented through simple code snippets.

class raimitigations.cohort.CohortManager(transform_pipe: Optional[list] = None, cohort_def: Optional[Union[dict, list, str]] = None, cohort_col: Optional[list] = None, cohort_json_files: Optional[list] = None, df: Optional[DataFrame] = None, label_col: Optional[str] = None, X: Optional[DataFrame] = None, y: Optional[DataFrame] = None, verbose: bool = True)

Concrete class that manages multiple cohort pipelines that are applied using the fit(), transform(), fit_resample(), predict(), and predict_proba() interfaces. The CohortManager uses multiple CohortDefinition objects to control the filters of each cohort, while using transformation pipelines to control which transformations should be applied to each cohort.

  • transform_pipe – the transformation pipeline to be used for each cohort. There are different ways to provide this parameter:

    1. An empty list or None: in this case, the CohortManager won’t apply any transformations over the dataset. The transform() method will simply return the dataset provided;

    2. A single transformer: in this case, this single transformer is placed in a list (a list with a single transformer), which is then replicated such that each cohort has its own list of transformations (pipeline);

    3. A list of transformers: in this case, this pipeline is replicated for each cohort;

    4. A list of pipelines: that is, a list of lists of transformers. In this case, the list must contain one pipeline for each cohort created, so the length of the transform_pipe parameter must equal the number of cohorts. The pipelines are assigned to the cohorts in the same order as the cohort definitions in the cohort_def parameter;

  • cohort_def – a list of cohort definitions or a dictionary of cohort definitions. Each cohort definition takes the same form as the value passed to the cohort_definition parameter of the CohortDefinition class. When using a list of cohort definitions, the cohorts are named automatically. When using a dictionary, each key is a cohort’s name and the value assigned to that key holds the cohort’s conditions. This parameter can’t be used together with the cohort_col parameter: only one of these two parameters may be used at a time. This parameter is ignored if cohort_json_files is provided;

  • cohort_col – a list of column names or indices, from which one cohort is created for each unique combination of values in these columns. This parameter can’t be used together with the cohort_def parameter: only one of these two parameters may be used at a time. This parameter is ignored if cohort_json_files is provided;

  • cohort_json_files – a list with the names of the JSON files that contain the definition of each cohort. Each cohort is saved in a single JSON file, so the length of cohort_json_files should equal the number of cohorts to be used;

  • df – the data frame to be used by the fit() method. This data frame must contain all the features, including the label column (specified in the label_col parameter). This parameter is mandatory if label_col is also provided. The user can also provide this dataset (along with label_col) when calling the fit() method. If df is provided when instantiating the class, it is not necessary to provide it again when calling fit(). It is also possible to use X and y instead of df and label_col, but one of the pairs (X, y) or (df, label_col) must be passed either when instantiating the class or when calling fit();

  • label_col – the name or index of the label column. This parameter is mandatory if df is provided;

  • X – contains only the features of the original dataset, that is, does not contain the label column. This is useful if the user has already separated the features from the label column prior to calling this class. This parameter is mandatory if y is provided;

  • y – contains only the label column of the original dataset. This parameter is mandatory if X is provided;

  • verbose – indicates whether internal messages should be printed or not.
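The four accepted forms of transform_pipe can be pictured as a normalization step that always ends with one pipeline per cohort. The snippet below is a minimal plain-Python sketch of that idea, not the library's implementation: the helper expand_transform_pipe and the ToyScaler class are hypothetical names introduced here, and the use of deepcopy to give each cohort its own transformer instances is an assumption.

```python
from copy import deepcopy


class ToyScaler:
    """Hypothetical stand-in for a real transformer object."""

    def __init__(self, name):
        self.name = name


def expand_transform_pipe(transform_pipe, n_cohorts):
    """Return a list with one pipeline (a list of transformers) per cohort."""
    # Case 1: None or an empty list -> every cohort gets an empty pipeline.
    if transform_pipe is None or transform_pipe == []:
        return [[] for _ in range(n_cohorts)]
    # Case 2: a single transformer -> wrap it in a one-element list first.
    if not isinstance(transform_pipe, list):
        transform_pipe = [transform_pipe]
    # Case 4: a list of pipelines -> one pipeline per cohort, in cohort order.
    if all(isinstance(pipe, list) for pipe in transform_pipe):
        if len(transform_pipe) != n_cohorts:
            raise ValueError("expected one pipeline per cohort")
        return [list(pipe) for pipe in transform_pipe]
    # Case 3 (and the wrapped case 2): replicate the pipeline for each cohort.
    # Assumption: copies are independent so cohorts do not share fitted state.
    return [deepcopy(transform_pipe) for _ in range(n_cohorts)]


# A single transformer is replicated into three independent pipelines.
shared = expand_transform_pipe(ToyScaler("minmax"), n_cohorts=3)

# A list of pipelines is assigned to the cohorts in order.
per_cohort = expand_transform_pipe([[ToyScaler("a")], [ToyScaler("b")]], n_cohorts=2)
```

In this sketch, a length mismatch in the list-of-pipelines case fails early with a ValueError, which mirrors the requirement that the number of pipelines equals the number of cohorts.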

fit(X: Optional[Union[DataFrame, ndarray]] = None, y: Optional[Union[Series, ndarray]] = None, df: Optional[DataFrame] = None, label_col: Optional[str] = None)

Calls the fit() method of all transformers in all pipelines. Each cohort has its own pipeline. This way, the following steps are executed: (i) iterate over each cohort, (ii) filter the dataset (X or df) using each cohort’s filter, (iii) cycle through each of the transformers in the cohort’s pipeline and call the transformer’s fit() method, (iv) after fitting the transformer, call its transform() method to get the updated subset, which is then used in the fit() call of the following transformer. Finally, check if all instances belong to only a single cohort.

  • X – contains only the features of the original dataset, that is, does not contain the label column;

  • y – contains only the label column of the original dataset;

  • df – the full dataset;

  • label_col – the name or index of the label column;

Check the documentation of the _set_df_mult method (DataProcessing class) for more information on how these parameters work.
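Steps (i) to (iv) and the final membership check can be sketched in plain Python as follows. This is an illustration under stated assumptions rather than the library's code: rows are plain dictionaries, cohort filters are ordinary predicates, MeanCenterer is a hypothetical toy transformer, and raising a ValueError when an instance falls into zero or several cohorts is an assumed behavior.

```python
class MeanCenterer:
    """Toy transformer: subtracts the mean of one column, learned in fit()."""

    def __init__(self, col):
        self.col = col
        self.mean_ = None

    def fit(self, rows):
        values = [row[self.col] for row in rows]
        self.mean_ = sum(values) / len(values)
        return self

    def transform(self, rows):
        return [{**row, self.col: row[self.col] - self.mean_} for row in rows]


def fit_cohorts(rows, cohort_filters, pipelines):
    """Fit each cohort's pipeline on the subset selected by that cohort's filter."""
    membership = [0] * len(rows)
    for name, keep in cohort_filters.items():              # (i) iterate over cohorts
        idx = [i for i, row in enumerate(rows) if keep(row)]  # (ii) filter the dataset
        for i in idx:
            membership[i] += 1
        subset = [rows[i] for i in idx]
        for transformer in pipelines[name]:                 # (iii) fit each transformer
            transformer.fit(subset)
            subset = transformer.transform(subset)          # (iv) feed the next one
    # Final check (assumed to be an error here): one cohort per instance.
    if any(count != 1 for count in membership):
        raise ValueError("each instance must belong to exactly one cohort")
    return pipelines


rows = [
    {"age": 20, "income": 10.0},
    {"age": 30, "income": 20.0},
    {"age": 50, "income": 100.0},
    {"age": 60, "income": 200.0},
]
cohort_filters = {"young": lambda r: r["age"] < 40, "old": lambda r: r["age"] >= 40}
pipelines = {"young": [MeanCenterer("income")], "old": [MeanCenterer("income")]}
fitted = fit_cohorts(rows, cohort_filters, pipelines)
```

Because each cohort has its own transformer instance, the "young" and "old" pipelines learn different means from the same column.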

fit_resample(X: Optional[Union[DataFrame, ndarray]] = None, y: Optional[Union[DataFrame, ndarray]] = None, df: Optional[Union[DataFrame, ndarray]] = None, rebalance_col: Optional[str] = None)

Calls the fit_resample() method of all transformers in all pipelines. Each cohort has its own pipeline. This way, the following steps are executed: (i) iterate over each cohort, (ii) filter the dataset (X or df) using each cohort’s filter, (iii) cycle through each of the transformers in the cohort’s pipeline and call the transformer’s fit_resample() method, (iv) after resampling using the current transformer, save the new subset and use it when calling the fit_resample() of the following transformer. Finally, check if all instances belong to only a single cohort.
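The per-cohort resampling flow can be sketched in plain Python as shown below. The sketch is hypothetical and does not use the library: cohorts are derived from the unique combinations of values in a list of columns (in the spirit of cohort_col), the pipeline is a list of plain functions, and oversample is a toy random oversampler standing in for a real rebalancing transformer.

```python
import random
from collections import Counter, defaultdict


def split_by_columns(rows, cohort_col):
    """Group rows by each unique combination of values in the cohort columns."""
    cohorts = defaultdict(list)
    for row in rows:
        cohorts[tuple(row[c] for c in cohort_col)].append(row)
    return dict(cohorts)


def oversample(rows, rebalance_col, rng):
    """Toy rebalancer: duplicate random rows until every class matches the largest."""
    counts = Counter(row[rebalance_col] for row in rows)
    target = max(counts.values())
    out = list(rows)
    for value, count in counts.items():
        pool = [row for row in rows if row[rebalance_col] == value]
        out.extend(rng.choice(pool) for _ in range(target - count))
    return out


def fit_resample_cohorts(rows, cohort_col, pipeline):
    """Resample each cohort separately, chaining the resamplers in `pipeline`."""
    resampled = []
    for subset in split_by_columns(rows, cohort_col).values():
        for resample in pipeline:
            subset = resample(subset)  # output feeds the next resampler
        resampled.extend(subset)
    return resampled


rng = random.Random(0)


def balance_labels(rows):
    return oversample(rows, "label", rng)


rows = [
    {"sex": "M", "label": 0}, {"sex": "M", "label": 0},
    {"sex": "M", "label": 0}, {"sex": "M", "label": 1},
    {"sex": "F", "label": 1}, {"sex": "F", "label": 1},
    {"sex": "F", "label": 0},
]
resampled = fit_resample_cohorts(rows, ["sex"], [balance_labels])
```

Rebalancing inside each cohort, rather than over the whole dataset, means the label distribution ends up even within every cohort, which a single global resampling would not guarantee.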

  • X – contains only the features of the original dataset, that is, does not contain the column used for rebalancing. This is useful if the user has already separated the features from the label column prior to calling this class. This parameter is mandatory if y is provided;

  • y – contains only the rebalance column of the original dataset. The rebalance operation is executed based on the data distribution of this column. This parameter is mandatory if X is provided;

  • df – the dataset to be rebalanced, which is used by the fit_resample() method. This data frame must contain all the features, including the rebalance column (specified in the rebalance_col parameter). This parameter is mandatory if rebalance_col is also provided. If df is provided when instantiating the class, it is not necessary to provide it again when calling fit_resample(). It is also possible to use X and y instead of df and rebalance_col, but one of the pairs (X, y) or (df, rebalance_col) must be available when fit_resample() is called;

  • rebalance_col – the name or index of the column used to do the rebalance operation. This parameter is mandatory if df is provided.


Returns

the resampled dataset.

Return type


predict(X: Union[DataFrame, ndarray], split_pred: bool = False)

Calls the transform() method of all transformers in all pipelines, followed by the predict() method for the estimator (which is always the last object in the pipeline).

  • X – contains only the features of the dataset to be transformed;

  • split_pred – if True, return a dictionary with the predictions for each cohort. If False, return a single predictions array;


Returns

an array with the predictions of all instances of the dataset, built from the predictions of each cohort, or a dictionary with the predictions for each cohort.

Return type

np.ndarray or dict
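The effect of split_pred can be illustrated with a small plain-Python sketch. It is not the library's code: ThresholdModel is a hypothetical toy estimator, the per-cohort transform() step is omitted for brevity, and writing each prediction back to the position of its original row is an assumption about how the merged array is assembled.

```python
class ThresholdModel:
    """Toy estimator: predicts 1 when the score reaches the threshold."""

    def __init__(self, threshold):
        self.threshold = threshold

    def predict(self, rows):
        return [int(row["score"] >= self.threshold) for row in rows]


def predict_cohorts(rows, cohort_filters, estimators, split_pred=False):
    """Predict with one estimator per cohort, then merge or split the results."""
    per_cohort = {}
    merged = [None] * len(rows)
    for name, keep in cohort_filters.items():
        idx = [i for i, row in enumerate(rows) if keep(row)]
        preds = estimators[name].predict([rows[i] for i in idx])
        per_cohort[name] = preds
        # Write each prediction back to the position of its original row.
        for i, pred in zip(idx, preds):
            merged[i] = pred
    return per_cohort if split_pred else merged


rows = [
    {"group": "a", "score": 0.4},
    {"group": "b", "score": 0.4},
    {"group": "a", "score": 0.9},
    {"group": "b", "score": 0.6},
]
cohort_filters = {
    "a": lambda r: r["group"] == "a",
    "b": lambda r: r["group"] == "b",
}
estimators = {"a": ThresholdModel(0.3), "b": ThresholdModel(0.5)}

merged = predict_cohorts(rows, cohort_filters, estimators)
split = predict_cohorts(rows, cohort_filters, estimators, split_pred=True)
```

The merged form keeps the row order of the input, so it can be compared directly against the true labels, while the split form is convenient for computing metrics per cohort.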

predict_proba(X: Union[DataFrame, ndarray], split_pred: bool = False)

Calls the transform() method of all transformers in all pipelines, followed by the predict_proba() method for the estimator (which is always the last object in the pipeline).

  • X – contains only the features of the dataset to be transformed;

  • split_pred – if True, return a dictionary with the predictions for each cohort. If False, return a single predictions array;


Returns

an array with the predictions of all instances of the dataset, built from the predictions of each cohort, or a dictionary with the predictions for each cohort.

Return type

np.ndarray or dict

transform(X: Union[DataFrame, ndarray])

Calls the transform() method of all transformers in all pipelines. Each cohort has its own pipeline. This way, the following steps are executed: (i) iterate over each cohort, (ii) filter the dataset X using each cohort’s filter, (iii) cycle through each of the transformers in the cohort’s pipeline and call the transformer’s transform() method, which returns a new transformed subset, that is then used in the transform() call of the following transformer. Finally, check if all instances belong to only a single cohort, and merge all cohort subsets into a single dataset.


  • X – contains only the features of the dataset to be transformed.


Returns

a dataset containing the transformed instances of all cohorts.

Return type


Class Diagram

Inheritance diagram of raimitigations.cohort.CohortManager
