hummingbird.ml.operator_converters._tree_commons
Collection of classes and functions shared among all tree converters.
- class hummingbird.ml.operator_converters._tree_commons.ApplyBasePredictionPostTransform(base_prediction)
Bases: PostTransform
- class hummingbird.ml.operator_converters._tree_commons.ApplySigmoidBasePredictionPostTransform(base_prediction)
Bases: PostTransform
- class hummingbird.ml.operator_converters._tree_commons.ApplySigmoidPostTransform
Bases: PostTransform
- class hummingbird.ml.operator_converters._tree_commons.ApplySoftmaxBasePredictionPostTransform(base_prediction)
Bases: PostTransform
- class hummingbird.ml.operator_converters._tree_commons.ApplySoftmaxPostTransform
Bases: PostTransform
- class hummingbird.ml.operator_converters._tree_commons.ApplyTweedieBasePredictionPostTransform(base_prediction)
Bases: PostTransform
- class hummingbird.ml.operator_converters._tree_commons.ApplyTweediePostTransform
Bases: PostTransform
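The post-transform classes above carry no docstrings. Their names suggest that each applies a final link function (sigmoid, softmax, or the exponential used by Tweedie regression) to the raw ensemble scores, with the *BasePrediction* variants first adding a base prediction. The pure-Python sketch below illustrates those semantics; all function names are hypothetical, and details such as the output shape of the sigmoid variant (returning both class probabilities) are assumptions, not Hummingbird's documented behavior.

```python
import math
from typing import List


def apply_base_prediction(scores: List[float], base_prediction: float) -> List[float]:
    # Shift every raw score by the model's initial (base) prediction.
    return [s + base_prediction for s in scores]


def apply_sigmoid(scores: List[float]) -> List[List[float]]:
    # Binary classification: map each score to [P(class 0), P(class 1)].
    # Returning both columns is an assumption of this sketch.
    out = []
    for s in scores:
        p = 1.0 / (1.0 + math.exp(-s))
        out.append([1.0 - p, p])
    return out


def apply_softmax(rows: List[List[float]]) -> List[List[float]]:
    # Multiclass: normalize each row of per-class scores into probabilities.
    out = []
    for row in rows:
        m = max(row)  # subtract the row max for numerical stability
        exps = [math.exp(v - m) for v in row]
        total = sum(exps)
        out.append([e / total for e in exps])
    return out


def apply_tweedie(scores: List[float]) -> List[float]:
    # Tweedie regression with a log link: predictions are exp(score).
    return [math.exp(s) for s in scores]


def apply_sigmoid_base_prediction(scores: List[float], base_prediction: float) -> List[List[float]]:
    # A *BasePrediction* variant: add the base prediction, then apply the link.
    return apply_sigmoid(apply_base_prediction(scores, base_prediction))


print(apply_sigmoid_base_prediction([0.0], 0.0))  # [[0.5, 0.5]]
print(apply_tweedie([0.0]))  # [1.0]
```

The softmax and Tweedie base-prediction variants would compose the same way; for multiclass models the base prediction is presumably one value per class rather than a scalar.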
- class hummingbird.ml.operator_converters._tree_commons.Node(id=None)
Bases: object
Class defining a tree node.
- class hummingbird.ml.operator_converters._tree_commons.TreeParameters(lefts, rights, features, thresholds, values)
Bases: object
Class containing a convenient in-memory representation of a decision tree.
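TreeParameters stores a tree as parallel arrays indexed by node id (lefts, rights, features, thresholds, values). The sketch below mirrors that layout with a hypothetical `TreeArrays` dataclass and a scalar prediction loop. The conventions that a leaf has `-1` children and that a sample goes left when `x[feature] <= threshold` are assumptions borrowed from scikit-learn's tree encoding, not documented behavior of this module.

```python
from dataclasses import dataclass
from typing import List


@dataclass
class TreeArrays:
    # Hypothetical analogue of TreeParameters: one entry per node id.
    lefts: List[int]         # left child id, or -1 at a leaf (assumed convention)
    rights: List[int]        # right child id, or -1 at a leaf
    features: List[int]      # feature index tested at an internal node (unused at leaves)
    thresholds: List[float]  # split threshold at an internal node
    values: List[float]      # prediction stored at a leaf


def predict_one(tree: TreeArrays, x: List[float]) -> float:
    node = 0  # start at the root
    while tree.lefts[node] != -1:
        # Assumed split rule: go left when the feature value is <= threshold.
        if x[tree.features[node]] <= tree.thresholds[node]:
            node = tree.lefts[node]
        else:
            node = tree.rights[node]
    return tree.values[node]


# Depth-1 tree: root 0 splits on feature 0 at 0.5; nodes 1 and 2 are leaves.
stump = TreeArrays(
    lefts=[1, -1, -1],
    rights=[2, -1, -1],
    features=[0, -2, -2],
    thresholds=[0.5, 0.0, 0.0],
    values=[0.0, 10.0, 20.0],
)
print(predict_one(stump, [0.2]), predict_one(stump, [0.9]))  # 10.0 20.0
```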
- hummingbird.ml.operator_converters._tree_commons._find_depth(node, current_depth)
Recursive function traversing a tree and returning the maximum depth.
- hummingbird.ml.operator_converters._tree_commons._find_max_depth(tree_parameters)
Function traversing all trees in sequence and returning the maximum depth.
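The recursion behind `_find_depth` and `_find_max_depth` can be sketched as follows. The `SketchNode` class is a hypothetical stand-in for `Node` with optional `left`/`right` children, and counting a single-leaf tree as depth 0 is an assumption of this sketch; the module's exact depth convention is not documented here.

```python
from typing import List, Optional


class SketchNode:
    # Hypothetical stand-in for Node: both children are None at a leaf.
    def __init__(self, left: Optional["SketchNode"] = None, right: Optional["SketchNode"] = None):
        self.left = left
        self.right = right


def find_depth(node: SketchNode, current_depth: int) -> int:
    # Return the depth of the deepest leaf reachable from `node`.
    children = [c for c in (node.left, node.right) if c is not None]
    if not children:
        return current_depth
    return max(find_depth(c, current_depth + 1) for c in children)


def find_max_depth(roots: List[SketchNode]) -> int:
    # Visit each tree of the ensemble in sequence and keep the largest depth.
    return max(find_depth(root, 0) for root in roots)


leaf = SketchNode()
stump = SketchNode(SketchNode(), SketchNode())
deeper = SketchNode(SketchNode(SketchNode(), SketchNode()), SketchNode())
print(find_max_depth([leaf, stump, deeper]))  # 2
```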
- hummingbird.ml.operator_converters._tree_commons.convert_decision_ensemble_tree_common(operator, tree_infos, get_parameters, get_parameters_for_tree_trav, n_features, classes=None, extra_config={})
- hummingbird.ml.operator_converters._tree_commons.get_parameters_for_gemm_common(lefts, rights, features, thresholds, values, n_features, extra_config={})
Common function used by all tree algorithms to generate the parameters for the GEMM strategy.
- Args:
  - lefts: The left nodes
  - rights: The right nodes
  - features: The features used in the decision nodes
  - thresholds: The thresholds used in the decision nodes
  - values: The values stored in the leaf nodes
  - n_features: The number of expected input features
- Returns:
The weights and biases for the GEMM implementation
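The GEMM strategy turns tree evaluation into dense matrix products, following the formulation described in the Hummingbird paper. The pure-Python sketch below builds five tensors for a single tree (A: which feature each internal node tests, B: thresholds, C: +1/-1 marks for whether a leaf lies in a node's left or right subtree, D: number of left turns on each leaf's path, E: leaf values) and evaluates one sample. The helper names, the `<=` split rule, and the `-1` leaf convention are assumptions; the real function returns weights and biases for whole ensembles, and how it packs them is not shown here.

```python
from typing import List

Matrix = List[List[float]]


def matmul(a: Matrix, b: Matrix) -> Matrix:
    # Plain matrix product standing in for a GEMM kernel (b must be non-empty).
    return [
        [sum(a[i][k] * b[k][j] for k in range(len(b))) for j in range(len(b[0]))]
        for i in range(len(a))
    ]


def build_gemm(lefts, rights, features, thresholds, values, n_features):
    # Assumes at least one internal node and -1 children at leaves.
    internal = [n for n in range(len(lefts)) if lefts[n] != -1]
    leaves = [n for n in range(len(lefts)) if lefts[n] == -1]
    row = {node: k for k, node in enumerate(internal)}
    leaf_col = {node: l for l, node in enumerate(leaves)}

    # A[f][k] = 1 if internal node k tests feature f; B holds the thresholds.
    A = [[1.0 if features[n] == f else 0.0 for n in internal] for f in range(n_features)]
    B = [thresholds[n] for n in internal]

    # C[k][l] = +1 if leaf l is in the left subtree of internal node k, -1 if in the right.
    C = [[0.0] * len(leaves) for _ in internal]

    def mark(node: int, k: int, sign: float) -> None:
        # Tag every leaf below `node` with `sign` in row k.
        if lefts[node] == -1:
            C[k][leaf_col[node]] = sign
        else:
            mark(lefts[node], k, sign)
            mark(rights[node], k, sign)

    for n in internal:
        mark(lefts[n], row[n], 1.0)
        mark(rights[n], row[n], -1.0)

    # D[l] = number of "go left" decisions on the path to leaf l.
    D = [sum(1.0 for k in range(len(internal)) if C[k][l] == 1.0) for l in range(len(leaves))]
    # E maps each leaf to its stored value (a single output column here).
    E = [[values[n]] for n in leaves]
    return A, B, C, D, E


def predict_gemm(x: List[float], A, B, C, D, E) -> float:
    t1 = matmul([x], A)[0]  # feature value seen by each internal node
    go_left = [[1.0 if v <= b else 0.0 for v, b in zip(t1, B)]]  # assumed <= split rule
    t2 = matmul(go_left, C)[0]
    reached = [[1.0 if v == d else 0.0 for v, d in zip(t2, D)]]  # one-hot: the leaf reached
    return matmul(reached, E)[0][0]


# Node 0 splits on feature 0, node 1 on feature 1; nodes 2, 3, 4 are leaves.
lefts = [1, 3, -1, -1, -1]
rights = [2, 4, -1, -1, -1]
features = [0, 1, -2, -2, -2]
thresholds = [0.5, 0.3, 0.0, 0.0, 0.0]
values = [0.0, 0.0, 30.0, 10.0, 20.0]

A, B, C, D, E = build_gemm(lefts, rights, features, thresholds, values, n_features=2)
print(predict_gemm([0.2, 0.1], A, B, C, D, E))  # 10.0
print(predict_gemm([0.9, 0.1], A, B, C, D, E))  # 30.0
```

A leaf is selected only when every ancestor sends the sample toward it, since only then does the signed sum in `t2` reach that leaf's left-turn count `D`. Because every product runs over all nodes, the matrices grow with tree size, which is consistent with the depth heuristic below preferring GEMM only for shallow trees.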
- hummingbird.ml.operator_converters._tree_commons.get_parameters_for_sklearn_common(tree_infos, extra_config)
Parse sklearn-based trees, including SklearnRandomForestClassifier/Regressor and SklearnGradientBoostingClassifier/Regressor.
- Args:
  - tree_infos: The information representing a tree (ensemble)
- Returns:
The tree parameters wrapped into an instance of operator_converters._tree_commons.TreeParameters
- hummingbird.ml.operator_converters._tree_commons.get_parameters_for_tree_trav_common(lefts, rights, features, thresholds, values, extra_config={})
Common function used by all tree algorithms to generate the parameters for the tree_trav strategies.
- Args:
  - lefts: The left nodes
  - rights: The right nodes
  - features: The features used in the decision nodes
  - thresholds: The thresholds used in the decision nodes
  - values: The values stored in the leaf nodes
- Returns:
An array containing the extracted parameters
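The tree-traversal strategies can evaluate a whole batch in lock step: every sample advances one level per iteration for exactly `max_depth` iterations, which maps naturally onto tensor gather operations. The plain-Python sketch below illustrates the idea. Rewriting leaves so that they point to themselves (making extra iterations harmless), the placeholder feature index at leaves, and the `<=` split rule are all assumptions of the sketch, not a description of the arrays this function actually returns.

```python
from typing import List


def self_loop_leaves(lefts: List[int], rights: List[int]):
    # Rewrite -1 children so that each leaf points to itself (assumed preprocessing).
    new_lefts = [i if l == -1 else l for i, l in enumerate(lefts)]
    new_rights = [i if r == -1 else r for i, r in enumerate(rights)]
    return new_lefts, new_rights


def tree_trav_batch(X, lefts, rights, features, thresholds, values, max_depth):
    lefts, rights = self_loop_leaves(lefts, rights)
    # Leaves need a valid feature index for the lookup; 0 is used as a placeholder.
    features = [f if f >= 0 else 0 for f in features]
    nodes = [0] * len(X)  # every sample starts at the root
    for _ in range(max_depth):
        # One level per iteration for all samples; leaves stay put via self-loops.
        nodes = [
            lefts[n] if x[features[n]] <= thresholds[n] else rights[n]
            for x, n in zip(X, nodes)
        ]
    return [values[n] for n in nodes]


lefts = [1, 3, -1, -1, -1]
rights = [2, 4, -1, -1, -1]
features = [0, 1, -2, -2, -2]
thresholds = [0.5, 0.3, 0.0, 0.0, 0.0]
values = [0.0, 0.0, 30.0, 10.0, 20.0]

X = [[0.2, 0.1], [0.2, 0.9], [0.9, 0.1]]
print(tree_trav_batch(X, lefts, rights, features, thresholds, values, max_depth=2))
# [10.0, 20.0, 30.0]
```

The fixed iteration count avoids per-sample branching; sample 3 reaches leaf 2 after one step and simply stays there during the second.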
- hummingbird.ml.operator_converters._tree_commons.get_parameters_for_tree_trav_sklearn(lefts, rights, features, thresholds, values, classes=None, extra_config={})
Generate the tree parameters for sklearn trees, including SklearnRandomForestClassifier/Regressor and SklearnGradientBoostingClassifier.
- Args:
  - lefts: The left nodes
  - rights: The right nodes
  - features: The features used in the decision nodes
  - thresholds: The thresholds used in the decision nodes
  - values: The values stored in the leaf nodes
  - classes: The list of class labels, or None for a regression model
- Returns:
An array containing the extracted parameters
- hummingbird.ml.operator_converters._tree_commons.get_tree_implementation_by_config_or_depth(extra_config, max_depth, low=3, high=10)
Utility function used to pick the tree implementation based on input parameters and heuristics. The current heuristic uses GEMM when max_depth <= low, PerfTreeTrav when low < max_depth <= high, and TreeTrav when max_depth > high.
- Args:
  - max_depth: The maximum tree depth found in the tree model
  - low: The maximum depth for which the GEMM strategy is used
  - high: The maximum depth for which the PerfTreeTrav strategy is used
- Returns:
A tree implementation
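The depth heuristic can be written as a small selection function. In this sketch the string labels and the `"tree_implementation"` key used for an explicit override in `extra_config` are hypothetical; only the `low`/`high` thresholds and their defaults come from the signature above.

```python
# Hypothetical labels for the three strategies named in the docstring.
GEMM, PERF_TREE_TRAV, TREE_TRAV = "gemm", "perf_tree_trav", "tree_trav"


def pick_tree_implementation(extra_config: dict, max_depth: int, low: int = 3, high: int = 10) -> str:
    # An explicit choice in extra_config wins (the key name is an assumption).
    if "tree_implementation" in extra_config:
        return extra_config["tree_implementation"]
    # Otherwise apply the depth heuristic: GEMM <= low < PerfTreeTrav <= high < TreeTrav.
    if max_depth <= low:
        return GEMM
    if max_depth <= high:
        return PERF_TREE_TRAV
    return TREE_TRAV


for depth in (3, 4, 10, 11):
    print(depth, pick_tree_implementation({}, depth))
# 3 gemm
# 4 perf_tree_trav
# 10 perf_tree_trav
# 11 tree_trav
```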
- hummingbird.ml.operator_converters._tree_commons.get_tree_params_and_type(tree_infos, get_tree_parameters, extra_config)
Populate the parameters from the trees and pick the tree implementation strategy.
- Args:
  - tree_infos: The information representing a tree (ensemble)
  - get_tree_parameters: A function specifying how to parse the tree_infos into an operator_converters._tree_commons.TreeParameters object
  - extra_config: Extra configuration, also used to select the best conversion strategy
- Returns:
The tree parameters, the maximum tree depth, and the tree implementation to use