hummingbird.ml.operator_converters._tree_commons

Collection of classes and functions shared by all tree converters.

class hummingbird.ml.operator_converters._tree_commons.ApplyBasePredictionPostTransform(base_prediction)

Bases: PostTransform

class hummingbird.ml.operator_converters._tree_commons.ApplySigmoidBasePredictionPostTransform(base_prediction)

Bases: PostTransform

class hummingbird.ml.operator_converters._tree_commons.ApplySigmoidPostTransform

Bases: PostTransform

class hummingbird.ml.operator_converters._tree_commons.ApplySoftmaxBasePredictionPostTransform(base_prediction)

Bases: PostTransform

class hummingbird.ml.operator_converters._tree_commons.ApplySoftmaxPostTransform

Bases: PostTransform

class hummingbird.ml.operator_converters._tree_commons.ApplyTweedieBasePredictionPostTransform(base_prediction)

Bases: PostTransform

class hummingbird.ml.operator_converters._tree_commons.ApplyTweediePostTransform

Bases: PostTransform
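The post-transform classes above have no docstrings. The following pure-Python sketch shows what transforms with these names plausibly compute on a batch of raw ensemble scores. It is a reading of the class names under several assumptions, not Hummingbird's code (which operates on PyTorch tensors). First, the base prediction is added to every score, either as a scalar or as one offset per column. Second, the binary sigmoid output is expanded to two columns [1 - p, p]. Third, softmax is applied row-wise. Fourth, the Tweedie transform exponentiates the score (log link). The `_ThenAfterBase` helper is hypothetical.

```python
import math
from typing import List, Sequence, Union

Batch = List[List[float]]  # rows are samples, columns are raw scores


class PostTransform:
    """Base class: the identity transform. Subclasses override __call__."""

    def __call__(self, x: Batch) -> Batch:
        return x


class ApplyBasePredictionPostTransform(PostTransform):
    def __init__(self, base_prediction: Union[float, Sequence[float]]):
        # A scalar is added to every column; a sequence gives one offset per column.
        self.base_prediction = base_prediction

    def __call__(self, x: Batch) -> Batch:
        b = self.base_prediction
        out = []
        for row in x:
            offsets = list(b) if isinstance(b, (list, tuple)) else [b] * len(row)
            out.append([v + o for v, o in zip(row, offsets)])
        return out


class ApplySigmoidPostTransform(PostTransform):
    def __call__(self, x: Batch) -> Batch:
        # Binary case: one score per row becomes [P(class 0), P(class 1)].
        out = []
        for (v,) in x:
            # Numerically stable sigmoid for both signs of v.
            p = 1.0 / (1.0 + math.exp(-v)) if v >= 0 else math.exp(v) / (1.0 + math.exp(v))
            out.append([1.0 - p, p])
        return out


class ApplySoftmaxPostTransform(PostTransform):
    def __call__(self, x: Batch) -> Batch:
        out = []
        for row in x:
            m = max(row)  # subtract the row max for numerical stability
            exps = [math.exp(v - m) for v in row]
            total = sum(exps)
            out.append([e / total for e in exps])
        return out


class ApplyTweediePostTransform(PostTransform):
    def __call__(self, x: Batch) -> Batch:
        # Assumed log link: the prediction is exp(raw score).
        return [[math.exp(v) for v in row] for row in x]


class _ThenAfterBase(PostTransform):
    """Hypothetical helper: add the base prediction, then apply `then`."""

    def __init__(self, base_prediction, then: PostTransform):
        self.add_base = ApplyBasePredictionPostTransform(base_prediction)
        self.then = then

    def __call__(self, x: Batch) -> Batch:
        return self.then(self.add_base(x))


class ApplySigmoidBasePredictionPostTransform(_ThenAfterBase):
    def __init__(self, base_prediction):
        super().__init__(base_prediction, ApplySigmoidPostTransform())


class ApplySoftmaxBasePredictionPostTransform(_ThenAfterBase):
    def __init__(self, base_prediction):
        super().__init__(base_prediction, ApplySoftmaxPostTransform())


class ApplyTweedieBasePredictionPostTransform(_ThenAfterBase):
    def __init__(self, base_prediction):
        super().__init__(base_prediction, ApplyTweediePostTransform())


if __name__ == "__main__":
    raw = [[0.0], [2.0]]
    print(ApplySigmoidBasePredictionPostTransform(-2.0)(raw))
```

Composing "add base prediction, then apply the link function" keeps each *BasePrediction* variant a thin wrapper around its plain counterpart.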

class hummingbird.ml.operator_converters._tree_commons.Node(id=None)

Bases: object

Class defining a tree node.
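Only the constructor `Node(id=None)` is documented. The sketch below shows how such a node class might be linked into a tree. It assumes each node also carries `left`/`right` children, with `-1` meaning "no child" (the convention scikit-learn uses in `children_left`/`children_right`). The attribute names and the `build_nodes` helper are hypothetical.

```python
from typing import Dict, List, Optional, Union


class Node:
    """A tree node; children are Node objects or -1 when absent (hypothetical fields)."""

    def __init__(self, id: Optional[int] = None):
        self.id = id
        self.left: Union["Node", int] = -1
        self.right: Union["Node", int] = -1


def build_nodes(lefts: List[int], rights: List[int]) -> Node:
    """Link Node objects from parallel child arrays and return the root (node 0)."""
    nodes: Dict[int, Node] = {i: Node(i) for i in range(len(lefts))}
    for i, (left, right) in enumerate(zip(lefts, rights)):
        if left != -1:
            nodes[i].left = nodes[left]
        if right != -1:
            nodes[i].right = nodes[right]
    return nodes[0]


if __name__ == "__main__":
    # Root 0 splits into leaves 1 and 2.
    root = build_nodes([1, -1, -1], [2, -1, -1])
    print(root.left.id, root.right.id)  # 1 2
```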

class hummingbird.ml.operator_converters._tree_commons.PostTransform

Bases: object

class hummingbird.ml.operator_converters._tree_commons.TreeParameters(lefts, rights, features, thresholds, values)

Bases: object

Class containing a convenient in-memory representation of a decision tree.
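A minimal sketch of such a container follows. It assumes the five constructor arguments are parallel per-node arrays indexed by node id, with `-1` in `lefts`/`rights` marking a leaf (scikit-learn's convention). The dataclass form, the length check, and the `leaf_ids` helper are hypothetical additions for illustration.

```python
from dataclasses import dataclass
from typing import List


@dataclass
class TreeParameters:
    """Parallel per-node arrays describing one decision tree (sketch)."""

    lefts: List[int]         # index of the left child, -1 for leaves
    rights: List[int]        # index of the right child, -1 for leaves
    features: List[int]      # feature tested at each internal node
    thresholds: List[float]  # split threshold at each internal node
    values: List[float]      # prediction stored at each node (used at leaves)

    def __post_init__(self):
        n = len(self.lefts)
        arrays = (self.rights, self.features, self.thresholds, self.values)
        if not all(len(a) == n for a in arrays):
            raise ValueError("all per-node arrays must have the same length")

    def leaf_ids(self) -> List[int]:
        return [i for i, left in enumerate(self.lefts) if left == -1]


if __name__ == "__main__":
    # x[0] <= 0.5 ? 10.0 : 20.0 ; -2 is scikit-learn's "undefined" marker at leaves.
    tree = TreeParameters([1, -1, -1], [2, -1, -1], [0, -2, -2], [0.5, -2.0, -2.0], [0.0, 10.0, 20.0])
    print(tree.leaf_ids())  # [1, 2]
```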

hummingbird.ml.operator_converters._tree_commons._find_depth(node, current_depth)

Recursive function traversing a tree and returning the maximum depth.

hummingbird.ml.operator_converters._tree_commons._find_max_depth(tree_parameters)

Function traversing all trees in sequence and returning the maximum depth.

hummingbird.ml.operator_converters._tree_commons.convert_decision_ensemble_tree_common(operator, tree_infos, get_parameters, get_parameters_for_tree_trav, n_features, classes=None, extra_config={})

hummingbird.ml.operator_converters._tree_commons.get_parameters_for_gemm_common(lefts, rights, features, thresholds, values, n_features, extra_config={})

Common function used by all tree algorithms to generate the parameters for the GEMM strategy.

Args:

lefts: The left nodes
rights: The right nodes
features: The features used in the decision nodes
thresholds: The thresholds used in the decision nodes
values: The values stored in the leaf nodes
n_features: The number of expected input features

Returns:

The weights and bias for the GEMM implementation
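The point of a GEMM strategy is to evaluate a tree with dense matrix products instead of node-by-node branching. The pure-Python sketch below illustrates one such formulation; it is not the library's code. It assumes scikit-learn-style arrays (`-1` marks a leaf), that "go left" means `x[feature] <= threshold`, and one scalar value per leaf. The matrices are:

- `A` selects the feature tested by each internal node.
- `B` holds the thresholds.
- `C` marks whether a leaf lies in an internal node's left (+1) or right (-1) subtree.
- `D` counts the left edges on each leaf's path.
- `E` holds the leaf values.

The pairs (A, B) and (C, D) play the role of weights and bias. `gemm_parameters`, `matmul`, and `gemm_predict` are hypothetical names.

```python
from typing import List

Matrix = List[List[float]]


def gemm_parameters(lefts, rights, features, thresholds, values, n_features):
    """Build the GEMM matrices A, B, C, D, E for one tree (illustrative sketch)."""
    internal = [i for i, left in enumerate(lefts) if left != -1]
    leaves = [i for i, left in enumerate(lefts) if left == -1]
    internal_col = {node: j for j, node in enumerate(internal)}
    leaf_col = {node: k for k, node in enumerate(leaves)}

    # A (n_features x n_internal): one-hot selector of the feature each internal node tests.
    A = [[0.0] * len(internal) for _ in range(n_features)]
    for j, node in enumerate(internal):
        A[features[node]][j] = 1.0
    # B (n_internal): the threshold of each internal node.
    B = [thresholds[node] for node in internal]

    # C (n_internal x n_leaves): +1 if the leaf is in the node's left subtree, -1 if right.
    # D (n_leaves): number of left edges on the root-to-leaf path.
    C = [[0.0] * len(leaves) for _ in internal]
    D = [0.0] * len(leaves)
    stack = [(0, [])]  # (node id, list of (ancestor id, went_left))
    while stack:
        node, path = stack.pop()
        if lefts[node] == -1:
            k = leaf_col[node]
            for ancestor, went_left in path:
                C[internal_col[ancestor]][k] = 1.0 if went_left else -1.0
                D[k] += 1.0 if went_left else 0.0
            continue
        stack.append((lefts[node], path + [(node, True)]))
        stack.append((rights[node], path + [(node, False)]))

    # E (n_leaves): the value stored in each leaf.
    E = [values[node] for node in leaves]
    return A, B, C, D, E


def matmul(X: Matrix, M: Matrix) -> Matrix:
    return [[sum(x * M[i][j] for i, x in enumerate(row)) for j in range(len(M[0]))] for row in X]


def gemm_predict(X: Matrix, A, B, C, D, E) -> List[float]:
    # 1) Evaluate every decision at once: T[s][j] = 1 if sample s goes left at node j.
    T = [[1.0 if v <= b else 0.0 for v, b in zip(row, B)] for row in matmul(X, A)]
    # 2) A leaf is reached iff its path score (T @ C) equals its left-edge count D.
    reached = [[1.0 if v == d else 0.0 for v, d in zip(row, D)] for row in matmul(T, C)]
    # 3) Exactly one leaf is reached per sample, so a dot product with E selects its value.
    return [sum(r * e for r, e in zip(row, E)) for row in reached]


if __name__ == "__main__":
    # x[0] <= 0.5 ? (x[1] <= 2.0 ? 10.0 : 20.0) : 30.0
    A, B, C, D, E = gemm_parameters(
        [1, 3, -1, -1, -1], [2, 4, -1, -1, -1], [0, 1, -2, -2, -2],
        [0.5, 2.0, -2.0, -2.0, -2.0], [0.0, 0.0, 30.0, 10.0, 20.0], n_features=2,
    )
    print(gemm_predict([[0.2, 1.0], [0.2, 5.0], [0.9, 0.0]], A, B, C, D, E))  # [10.0, 20.0, 30.0]
```

The path score in step 2 can never exceed `D`: a wrong turn either drops a +1 or adds a -1. Equality therefore identifies the single reached leaf. This sketch assumes the tree has at least one internal node.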

hummingbird.ml.operator_converters._tree_commons.get_parameters_for_sklearn_common(tree_infos, extra_config)

Parse sklearn-based trees, including SklearnRandomForestClassifier/Regressor and SklearnGradientBoostingClassifier/Regressor.

Args:

tree_infos: The information representing a tree (ensemble)

Returns:

The tree parameters wrapped into an instance of operator_converters._tree_commons.TreeParameters
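In a fitted scikit-learn tree, the per-node arrays live on the estimator's `tree_` attribute (`children_left`, `children_right`, `feature`, `threshold`, `value`). The sketch below shows how those arrays could be copied into the five `TreeParameters` fields for every tree in an ensemble. It assumes `tree_infos` is a list of fitted trees, such as a random forest's `estimators_`. The local `TreeParameters`, `_as_list`, and `parse_sklearn_trees` are hypothetical stand-ins. A `SimpleNamespace` can mock `tree_` so the example runs without scikit-learn installed.

```python
from dataclasses import dataclass
from types import SimpleNamespace
from typing import Any, List


@dataclass
class TreeParameters:
    lefts: List[int]
    rights: List[int]
    features: List[int]
    thresholds: List[float]
    values: Any  # per-node class weights (classifiers) or outputs (regressors)


def _as_list(array: Any) -> list:
    # NumPy arrays (what scikit-learn stores) convert via tolist(); plain sequences via list().
    return array.tolist() if hasattr(array, "tolist") else list(array)


def parse_sklearn_trees(tree_infos: List[Any]) -> List[TreeParameters]:
    """Copy each fitted tree's `tree_` arrays into a TreeParameters (sketch)."""
    params = []
    for estimator in tree_infos:
        t = estimator.tree_
        params.append(
            TreeParameters(
                lefts=_as_list(t.children_left),
                rights=_as_list(t.children_right),
                features=_as_list(t.feature),
                thresholds=_as_list(t.threshold),
                values=_as_list(t.value),
            )
        )
    return params


if __name__ == "__main__":
    fake = SimpleNamespace(tree_=SimpleNamespace(
        children_left=[1, -1, -1], children_right=[2, -1, -1], feature=[0, -2, -2],
        threshold=[0.5, -2.0, -2.0], value=[[[3.0, 3.0]], [[3.0, 0.0]], [[0.0, 3.0]]],
    ))
    print(parse_sklearn_trees([fake])[0].lefts)  # [1, -1, -1]
```

With scikit-learn available, `parse_sklearn_trees(forest.estimators_)` would apply the same copying to a fitted `RandomForestClassifier`.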

hummingbird.ml.operator_converters._tree_commons.get_parameters_for_tree_trav_common(lefts, rights, features, thresholds, values, extra_config={})

Common function used by all tree algorithms to generate the parameters for the tree_trav strategies.

Args:

lefts: The left nodes
rights: The right nodes
features: The features used in the decision nodes
thresholds: The thresholds used in the decision nodes
values: The values stored in the leaf nodes

Returns:

An array containing the extracted parameters
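The sketch below shows how tree-traversal parameters can drive prediction. Every sample starts at the root and repeatedly moves to its left or right child. It assumes that "left" means `x[feature] <= threshold`. It also assumes that leaves are rewritten as self-loops (`lefts[leaf] = rights[leaf] = leaf`), so running exactly `max_depth` steps for every sample is safe: a sample that reaches a leaf early simply stays there. The fixed step count is what makes such a traversal easy to batch. `prepare_tree_trav` and `tree_trav_predict` are hypothetical names.

```python
from typing import List, Tuple


def prepare_tree_trav(lefts, rights, features) -> Tuple[List[int], List[int], List[int]]:
    """Turn -1 children into self-loops and give leaves a harmless feature index (0)."""
    new_lefts, new_rights, new_features = [], [], []
    for i, (left, right, feature) in enumerate(zip(lefts, rights, features)):
        if left == -1:  # leaf: loop back to itself; the feature read here is irrelevant
            new_lefts.append(i)
            new_rights.append(i)
            new_features.append(0)
        else:
            new_lefts.append(left)
            new_rights.append(right)
            new_features.append(feature)
    return new_lefts, new_rights, new_features


def tree_trav_predict(X, lefts, rights, features, thresholds, values, max_depth):
    lefts, rights, features = prepare_tree_trav(lefts, rights, features)
    predictions = []
    for x in X:
        node = 0
        for _ in range(max_depth):  # fixed step count; finished samples stay on their leaf
            go_left = x[features[node]] <= thresholds[node]
            node = lefts[node] if go_left else rights[node]
        predictions.append(values[node])
    return predictions


if __name__ == "__main__":
    # x[0] <= 0.5 ? (x[1] <= 2.0 ? 10.0 : 20.0) : 30.0
    tree = ([1, 3, -1, -1, -1], [2, 4, -1, -1, -1], [0, 1, -2, -2, -2],
            [0.5, 2.0, -2.0, -2.0, -2.0], [0.0, 0.0, 30.0, 10.0, 20.0])
    print(tree_trav_predict([[0.2, 1.0], [0.9, 0.0]], *tree, max_depth=2))  # [10.0, 30.0]
```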

hummingbird.ml.operator_converters._tree_commons.get_parameters_for_tree_trav_sklearn(lefts, rights, features, thresholds, values, classes=None, extra_config={})

Generates the tree parameters for sklearn trees, including SklearnRandomForestClassifier/Regressor and SklearnGradientBoostingClassifier.

Args:

lefts: The left nodes
rights: The right nodes
features: The features used in the decision nodes
thresholds: The thresholds used in the decision nodes
values: The values stored in the leaf nodes
classes: The list of class labels, or None for a regression model

Returns:

An array containing the extracted parameters
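When `classes` is given, the leaf values of a classification tree are per-class weights that have to become probabilities before prediction. The sketch below illustrates one such normalization step; it is not the library's exact preprocessing. It assumes `values` holds one row of non-negative per-class weights per node (counts or fractions, depending on how the tree was produced). For regression (`classes=None`), values pass through unchanged. `normalize_leaf_values` is a hypothetical name.

```python
from typing import List, Optional, Sequence


def normalize_leaf_values(values: List[List[float]], classes: Optional[Sequence] = None) -> List[List[float]]:
    """Turn per-class leaf weights into probabilities for classifiers (sketch)."""
    if classes is None:  # regression: keep raw leaf outputs
        return [list(row) for row in values]
    normalized = []
    for row in values:
        total = sum(row)
        # Guard against division by zero for all-zero rows.
        normalized.append([v / total for v in row] if total > 0 else [0.0] * len(row))
    return normalized


if __name__ == "__main__":
    print(normalize_leaf_values([[3.0, 1.0], [0.0, 2.0]], classes=[0, 1]))  # [[0.75, 0.25], [0.0, 1.0]]
```

Normalizing rows that are already fractions is a no-op, so the step is safe under either encoding.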

hummingbird.ml.operator_converters._tree_commons.get_tree_implementation_by_config_or_depth(extra_config, max_depth, low=3, high=10)

Utility function used to pick the tree implementation based on input parameters and heuristics. The current heuristic is: GEMM when max_depth <= low, PerfTreeTrav when low < max_depth <= high, and TreeTrav when max_depth > high.

Args:

extra_config: Extra configuration, which can also be used to select the implementation explicitly
max_depth: The maximum tree-depth found in the tree model
low: The maximum depth for which the GEMM strategy is used
high: The maximum depth for which the PerfTreeTrav strategy is used

Returns:

A tree implementation
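The sketch below encodes the documented thresholds (GEMM up to `low`, PerfTreeTrav up to `high`, TreeTrav beyond) so the boundaries can be checked directly. Several pieces are assumptions made for illustration: the `TreeImpl` enum, the function name, and the `"tree_implementation"` config key used for an explicit override are all hypothetical. Treating an unknown depth (`None`) as deep is also a choice of this sketch.

```python
from enum import Enum
from typing import Optional


class TreeImpl(Enum):
    gemm = 1
    tree_trav = 2
    perf_tree_trav = 3


def pick_tree_implementation(extra_config: dict, max_depth: Optional[int],
                             low: int = 3, high: int = 10) -> TreeImpl:
    """GEMM if depth <= low, PerfTreeTrav if low < depth <= high, else TreeTrav (sketch)."""
    # Hypothetical config key: an explicit choice overrides the depth heuristic.
    choice = extra_config.get("tree_implementation")
    if choice is not None:
        return TreeImpl[choice]
    if max_depth is not None and max_depth <= low:
        return TreeImpl.gemm
    if max_depth is not None and max_depth <= high:
        return TreeImpl.perf_tree_trav
    return TreeImpl.tree_trav  # deep (or unknown-depth) trees


if __name__ == "__main__":
    print(pick_tree_implementation({}, 3), pick_tree_implementation({}, 11))
```

Shallow trees favor GEMM because its matrices grow with the number of nodes, which grows exponentially with depth. Deeper trees favor traversal, whose cost grows only linearly with depth.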

hummingbird.ml.operator_converters._tree_commons.get_tree_params_and_type(tree_infos, get_tree_parameters, extra_config)

Populate the parameters from the trees and pick the tree implementation strategy.

Args:

tree_infos: The information representing a tree (ensemble)
get_tree_parameters: A function specifying how to parse the tree_infos into an operator_converters._tree_commons.TreeParameters object
extra_config: Extra configuration, also used to select the best conversion strategy

Returns:

The tree parameters, the maximum tree-depth, and the tree implementation to use