hummingbird.ml.operator_converters._gbdt_commons

A collection of functions shared among GBDT converters.

hummingbird.ml.operator_converters._gbdt_commons.convert_gbdt_classifier_common(operator, tree_infos, get_tree_parameters, n_features, n_classes, classes=None, extra_config={}, decision_cond='<=')

Common converter for GBDT classifiers.

Args:

tree_infos: The information representing a tree (ensemble).
get_tree_parameters: A function specifying how to parse tree_infos into tree parameters.
n_features: The number of features input to the model.
n_classes: How many classes are expected. 1 for regression tasks.
classes: The classes used for classification. None if implementing a regression model.
extra_config: Extra configuration used to properly implement the source tree.
decision_cond: The condition of the decision nodes, in x <cond> threshold order. Default: '<='. Supported values: '<=', '<', '>=', '>'.

Returns:

A tree implementation in PyTorch
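For intuition about n_classes and classes, the following is a minimal plain-Python sketch of how the per-tree outputs of a multi-class GBDT are commonly combined into class probabilities. It assumes the trees are interleaved by class (tree i contributes to class i % n_classes) and treats a single score as a binary-classification logit; aggregate_scores and to_probabilities are hypothetical helpers, not Hummingbird functions, and the sketch makes no claim about how Hummingbird itself handles these cases (it builds a PyTorch model instead).

```python
import math


def aggregate_scores(tree_outputs, n_classes, base_score=0.0):
    """Sum per-tree outputs into one raw score per class.

    Assumes trees are interleaved by class, so tree i contributes to
    class i % n_classes (a common GBDT layout, assumed here).
    """
    scores = [base_score] * n_classes
    for i, out in enumerate(tree_outputs):
        scores[i % n_classes] += out
    return scores


def to_probabilities(scores):
    """Sigmoid for a single (binary) score, numerically stable softmax otherwise."""
    if len(scores) == 1:
        p = 1.0 / (1.0 + math.exp(-scores[0]))
        return [1.0 - p, p]
    m = max(scores)
    exps = [math.exp(s - m) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]


# Three classes, two boosting rounds -> six trees (outputs made up for illustration).
tree_outputs = [0.5, -0.2, 0.1, 0.3, 0.0, -0.1]
classes = ["a", "b", "c"]
probs = to_probabilities(aggregate_scores(tree_outputs, n_classes=3))
label = classes[probs.index(max(probs))]
print(label)  # a
```

The predicted label is looked up in classes by the index of the highest probability, which is why a classifier converter needs the class list in addition to n_classes.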

hummingbird.ml.operator_converters._gbdt_commons.convert_gbdt_common(operator, tree_infos, get_tree_parameters, n_features, classes=None, extra_config={}, decision_cond='<=')

Common converter for GBDT models.

Args:

tree_infos: The information representing a tree (ensemble).
get_tree_parameters: A function specifying how to parse tree_infos into tree parameters.
n_features: The number of features input to the model.
classes: The classes used for classification. None if implementing a regression model.
extra_config: Extra configuration used to properly implement the source tree.
decision_cond: The condition of the decision nodes, in x <cond> threshold order. Default: '<='. Supported values: '<=', '<', '>=', '>'.

Returns:

A tree implementation in PyTorch
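As an illustration of the roles of get_tree_parameters and decision_cond, here is a minimal plain-Python sketch. It assumes a hypothetical nested-dict format for tree_infos, a hypothetical parallel-array layout (lefts, rights, features, thresholds, values) for the parsed tree parameters, and that a node routes the input to its left child when x[feature] <cond> threshold holds. parse_nested_tree, predict_tree and predict_regression are illustrative names, not Hummingbird functions, and the regression output is taken to be the plain sum of the tree outputs (any base score or learning rate is assumed to be folded into the leaf values).

```python
import operator

# Supported decision_cond strings mapped to comparisons. A node is assumed to
# route x to its left child when `x[feature] <cond> threshold` is true.
CONDITIONS = {"<=": operator.le, "<": operator.lt, ">=": operator.ge, ">": operator.gt}


def parse_nested_tree(tree):
    """Hypothetical stand-in for get_tree_parameters.

    Flattens a nested-dict tree into parallel arrays in pre-order.
    Internal node: {"feature": int, "threshold": float, "left": ..., "right": ...}
    Leaf node:     {"value": float}
    Leaves get -1 in lefts, rights and features.
    """
    lefts, rights, features, thresholds, values = [], [], [], [], []

    def visit(node):
        idx = len(lefts)
        # Reserve this node's slot before visiting children so ids are pre-order.
        lefts.append(-1)
        rights.append(-1)
        features.append(-1)
        thresholds.append(0.0)
        values.append(0.0)
        if "value" in node:
            values[idx] = node["value"]
        else:
            features[idx] = node["feature"]
            thresholds[idx] = node["threshold"]
            lefts[idx] = visit(node["left"])
            rights[idx] = visit(node["right"])
        return idx

    visit(tree)
    return lefts, rights, features, thresholds, values


def predict_tree(x, params, decision_cond="<="):
    """Walk one flattened tree from the root to a leaf and return its value."""
    if decision_cond not in CONDITIONS:
        raise ValueError(f"unsupported decision_cond: {decision_cond!r}")
    cond = CONDITIONS[decision_cond]
    lefts, rights, features, thresholds, values = params
    node = 0
    while lefts[node] != -1:  # -1 marks a leaf
        node = lefts[node] if cond(x[features[node]], thresholds[node]) else rights[node]
    return values[node]


def predict_regression(x, trees, decision_cond="<="):
    """Regression output: sum of the outputs of all trees in the ensemble."""
    return sum(predict_tree(x, t, decision_cond) for t in trees)


# Two small trees in the hypothetical nested-dict format.
tree_infos = [
    {"feature": 0, "threshold": 0.5, "left": {"value": -1.0}, "right": {"value": 1.0}},
    {"feature": 1, "threshold": 2.0, "left": {"value": 0.25}, "right": {"value": 0.5}},
]
trees = [parse_nested_tree(t) for t in tree_infos]
print(predict_regression([0.5, 2.0], trees, "<="))  # -0.75
print(predict_regression([0.5, 2.0], trees, "<"))   # 1.5
```

Swapping '<=' for '<' changes the result only when a feature value equals a threshold exactly, as in the usage above; this is why the converter needs to know which condition the source library uses.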