hummingbird.ml.operator_converters._gbdt_commons
Collection of functions shared among the GBDT converters.
- hummingbird.ml.operator_converters._gbdt_commons.convert_gbdt_classifier_common(operator, tree_infos, get_tree_parameters, n_features, n_classes, classes=None, extra_config={}, decision_cond='<=')
Common converter for GBDT classifiers.
- Args:
tree_infos: The information representing a tree (ensemble).
get_tree_parameters: A function specifying how to parse the tree_infos into parameters.
n_features: The number of features input to the model.
n_classes: How many classes are expected (1 for regression tasks).
classes: The classes used for classification, or None when implementing a regression model.
extra_config: Extra configuration used to properly implement the source tree.
decision_cond: The condition of the decision nodes, read in the order x <cond> threshold. Default '<='. Accepted values are '<=', '<', '>=', and '>'.
- Returns:
A tree implementation in PyTorch
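For a classifier, the raw outputs of the individual trees must eventually be turned into class probabilities. As an illustration only (this is not Hummingbird's code), the sketch below assumes a common GBDT convention: binary models add one tree per boosting round to a single margin, while multi-class models add one tree per class per round, interleaved so that tree i belongs to class i % n_classes. The helper name `gbdt_class_probabilities` is hypothetical.

```python
import math
from typing import List, Sequence


def gbdt_class_probabilities(tree_outputs: Sequence[float], n_classes: int) -> List[float]:
    """Hypothetical sketch: aggregate per-tree leaf values into class probabilities.

    Assumes binary models use one tree per round and multi-class models use
    one tree per class per round, interleaved so tree i feeds class i % n_classes.
    """
    if n_classes < 2:
        raise ValueError("n_classes must be at least 2 for classification")
    if n_classes == 2:
        # Binary: every tree adds to one margin; the sigmoid gives P(class 1).
        margin = sum(tree_outputs)
        p1 = 1.0 / (1.0 + math.exp(-margin))
        return [1.0 - p1, p1]
    # Multi-class: accumulate the outputs of the trees assigned to each class.
    scores = [0.0] * n_classes
    for i, value in enumerate(tree_outputs):
        scores[i % n_classes] += value
    # Numerically stable softmax over the per-class scores.
    top = max(scores)
    exps = [math.exp(s - top) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]
```

For example, `gbdt_class_probabilities([0.0, 0.0], 2)` returns `[0.5, 0.5]`, since a zero margin maps to a sigmoid value of 0.5. Real converters may handle the binary case, base scores, or the tree layout differently.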
- hummingbird.ml.operator_converters._gbdt_commons.convert_gbdt_common(operator, tree_infos, get_tree_parameters, n_features, classes=None, extra_config={}, decision_cond='<=')
Common converter for GBDT models.
- Args:
tree_infos: The information representing a tree (ensemble).
get_tree_parameters: A function specifying how to parse the tree_infos into parameters.
n_features: The number of features input to the model.
classes: The classes used for classification, or None when implementing a regression model.
extra_config: Extra configuration used to properly implement the source tree.
decision_cond: The condition of the decision nodes, read in the order x <cond> threshold. Default '<='. Accepted values are '<=', '<', '>=', and '>'.
- Returns:
A tree implementation in PyTorch
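The get_tree_parameters and decision_cond arguments work together: the first turns each entry of tree_infos into per-node parameters, and the second fixes how each node compares a feature value x against its threshold. The minimal sketch below illustrates this under stated assumptions; the dict-based tree format, the `TreeParams` container, and the `get_tree_parameters` and `predict_tree` helpers are all hypothetical, and it assumes a sample goes to the left child when `x <cond> threshold` holds.

```python
import operator as op
from dataclasses import dataclass
from typing import Callable, Dict, List, Sequence

# Each documented decision_cond value mapped to a Python comparison.
# Assumption: a sample moves to the LEFT child when `x <cond> threshold` is true.
DECISION_CONDS: Dict[str, Callable[[float, float], bool]] = {
    "<=": op.le,
    "<": op.lt,
    ">=": op.ge,
    ">": op.gt,
}


@dataclass
class TreeParams:
    """Hypothetical flat-array tree layout; -1 in lefts/rights marks a leaf."""
    lefts: List[int]
    rights: List[int]
    features: List[int]
    thresholds: List[float]
    values: List[float]


def get_tree_parameters(tree_info: dict) -> TreeParams:
    """Hypothetical parser: reads a dict-of-lists tree into TreeParams."""
    return TreeParams(
        lefts=list(tree_info["left"]),
        rights=list(tree_info["right"]),
        features=list(tree_info["feature"]),
        thresholds=[float(t) for t in tree_info["threshold"]],
        values=[float(v) for v in tree_info["value"]],
    )


def predict_tree(params: TreeParams, x: Sequence[float], decision_cond: str = "<=") -> float:
    """Walk one tree from the root (node 0) to a leaf and return the leaf value."""
    cond = DECISION_CONDS[decision_cond]
    node = 0
    while params.lefts[node] != -1:
        go_left = cond(x[params.features[node]], params.thresholds[node])
        node = params.lefts[node] if go_left else params.rights[node]
    return params.values[node]
```

On a one-split tree with threshold 0.5, an input of exactly 0.5 reaches the left leaf under '<=' but the right leaf under '<', which illustrates why the converted model has to match the comparison convention of the source library.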