# -------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.
# --------------------------------------------------------------------------
import logging
from typing import Any, Dict, List, Optional, Union
from olive.common.config_utils import ConfigBase, validate_config
from olive.common.pydantic_v1 import validator
from olive.common.utils import StrEnumBase
from olive.data.config import DataConfig
from olive.evaluator.accuracy import AccuracyBase
from olive.evaluator.metric_config import LatencyMetricConfig, MetricGoal, ThroughputMetricConfig, get_user_config_class
# Module-level logger named after this module.
logger = logging.getLogger(__name__)
class MetricType(StrEnumBase):
    """Kind of metric to evaluate.

    ``Metric`` uses this value to pick the backend handling, the sub-type enum that sub-metric
    names are validated against, and the user config class (via ``get_user_config_class``).
    ``CUSTOM`` metrics skip backend and sub-type validation.
    """

    ACCURACY = "accuracy"
    LATENCY = "latency"
    THROUGHPUT = "throughput"
    CUSTOM = "custom"
class AccuracySubType(StrEnumBase):
    """Sub-metric names accepted for ``MetricType.ACCURACY`` metrics.

    With the ``torch_metrics`` backend, the member is also the key used to look up the
    implementation in ``AccuracyBase.registry``. Member order is significant: it is the order
    shown in validation error messages.
    """

    ACCURACY_SCORE = "accuracy_score"
    F1_SCORE = "f1_score"
    PRECISION = "precision"
    RECALL = "recall"
    AUROC = "auroc"
    PERPLEXITY = "perplexity"
class LatencySubType(StrEnumBase):
    """Statistics that can be requested for ``MetricType.LATENCY`` metrics.

    Values are reported in milliseconds. Member order is significant: it is the order shown
    in validation error messages.
    """

    # unit: millisecond
    AVG = "avg"
    MAX = "max"
    MIN = "min"
    P50 = "p50"
    P75 = "p75"
    P90 = "p90"
    P95 = "p95"
    P99 = "p99"
    P999 = "p999"
class ThroughputSubType(StrEnumBase):
    """Statistics that can be requested for ``MetricType.THROUGHPUT`` metrics.

    Values are reported in tokens per second (tps). The member values intentionally mirror
    ``LatencySubType``. Member order is significant: it is the order shown in validation
    error messages.
    """

    # unit: token per second, tps
    AVG = "avg"
    MAX = "max"
    MIN = "min"
    P50 = "p50"
    P75 = "p75"
    P90 = "p90"
    P95 = "p95"
    P99 = "p99"
    P999 = "p999"
class SubMetric(ConfigBase):
    """One sub-metric of a ``Metric`` (for example an accuracy score or a latency statistic).

    ``priority`` and ``higher_is_better`` describe how the sub-metric is ranked. ``goal`` is an
    optional target whose value is range-checked against ``higher_is_better``.
    """

    # ``name`` holds a sub-type enum member, or a free-form string (custom / huggingface metrics).
    # NOTE: pydantic tries Union members left to right. LatencySubType and ThroughputSubType share
    # the same string values, so a throughput name validates to the equal-valued LatencySubType
    # member. Both are string enums of equal value — TODO confirm no caller needs the exact class.
    name: Union[AccuracySubType, LatencySubType, ThroughputSubType, str]
    metric_config: ConfigBase = None
    # -1 (the default) means "no priority": the sub-metric is evaluated only, not ranked.
    priority: int = -1
    higher_is_better: bool = False
    goal: MetricGoal = None

    @validator("goal")
    def validate_goal(cls, v, values):
        """Check that a percentage goal lies strictly inside the range allowed for its direction.

        Goals whose type is neither ``percent-min-improvement`` nor ``percent-max-degradation``
        are returned unchanged.

        :raises ValueError: if ``higher_is_better`` failed validation, or if the goal value is
            outside the valid open interval.
        """
        if v is None:
            return v
        if v.type not in ["percent-min-improvement", "percent-max-degradation"]:
            return v
        # ``higher_is_better`` is declared before ``goal``; it is missing only if it failed validation.
        if "higher_is_better" not in values:
            raise ValueError("Invalid higher_is_better")

        higher_is_better = values["higher_is_better"]
        # A relative improvement of a lower-is-better metric (or degradation of a higher-is-better
        # one) cannot exceed 100%; the opposite directions are unbounded.
        ranges = {
            ("percent-min-improvement", True): (0, float("inf")),
            ("percent-min-improvement", False): (0, 100),
            ("percent-max-degradation", True): (0, 100),
            ("percent-max-degradation", False): (0, float("inf")),
        }
        valid_range = ranges[(v.type, higher_is_better)]
        if not valid_range[0] < v.value < valid_range[1]:
            raise ValueError(
                f"Invalid goal value {v.value} for {v.type} and higher_is_better={higher_is_better}. Valid range is"
                f" {valid_range}"
            )
        return v
class Metric(ConfigBase):
    """A metric to evaluate: its type, backend, sub-metrics, user config and optional data config.

    The validators below normalize ``backend``, each entry of ``sub_types`` and ``user_config``
    according to ``type``. Pydantic runs them in field-declaration order, so each validator can
    read the already-validated earlier fields from ``values``.
    """

    name: str
    type: MetricType
    backend: Optional[str] = "torch_metrics"
    sub_types: List[SubMetric]
    user_config: ConfigBase = None
    data_config: Optional[DataConfig] = None

    def get_inference_settings(self, framework):
        """Return the inference settings configured for ``framework``.

        :param framework: key looked up in ``user_config.inference_settings``.
        :return: the settings for ``framework``, or None when there is no user config, no
            inference settings, or no entry for ``framework``.
        """
        if self.user_config is None or not self.user_config.inference_settings:
            return None
        return self.user_config.inference_settings.get(framework)

    def get_run_kwargs(self) -> Dict[str, Any]:
        """Return ``user_config.run_kwargs``, or an empty dict when it is unset or empty."""
        return self.user_config.run_kwargs if (self.user_config and self.user_config.run_kwargs) else {}

    def get_sub_type_info(self, info_name, no_priority_filter=True, callback=lambda x: x):
        """Map each sub-metric name to ``callback(getattr(sub_type, info_name))``.

        :param info_name: name of the ``SubMetric`` attribute to collect (e.g. ``"higher_is_better"``).
        :param no_priority_filter: when True (the default), skip sub-metrics whose priority is <= 0
            (the default priority -1 means "no priority").
        :param callback: transformation applied to each collected value.
        :return: dict keyed by sub-metric name.
        """
        sub_type_info = {}
        for sub_type in self.sub_types:
            if no_priority_filter and sub_type.priority <= 0:
                continue
            sub_type_info[sub_type.name] = callback(getattr(sub_type, info_name))
        return sub_type_info

    @validator("backend", always=True, pre=True)
    def validate_backend(cls, v, values):
        """Validate the backend name; custom metrics have no backend and are normalized to None.

        :raises ValueError: if ``type`` failed validation, the backend is not registered, or the
            registered backend is unavailable.
        """
        # ``type`` is missing from ``values`` when its own validation failed. Raise ValueError
        # (which pydantic reports as a ValidationError) rather than letting a raw KeyError escape.
        if "type" not in values:
            raise ValueError("Invalid type")
        if values["type"] == MetricType.CUSTOM:
            return None

        # Imported lazily (kept as in the original).
        from olive.evaluator.metric_backend import MetricBackend

        # Explicit raises instead of ``assert`` so the checks are not stripped under ``python -O``.
        if v not in MetricBackend.registry:
            raise ValueError(f"Backend {v} is not in {list(MetricBackend.registry.keys())}")
        if MetricBackend.registry[v]() is None:
            raise ValueError(f"Backend {v} is not available")
        return v

    @validator("sub_types", always=True, pre=True, each_item=True)
    def validate_sub_types(cls, v, values):
        """Validate and normalize one raw sub-metric entry against the metric ``type`` and ``backend``.

        For non-custom metrics this:

        - checks the name against the backend;
        - converts ``name`` to the matching sub-type enum for ``torch_metrics``;
        - fills a per-type default for ``higher_is_better``;
        - validates ``metric_config`` with the matching config class.

        :return: a new dict; the caller's input mapping is not modified.
        :raises ValueError: on an invalid type/backend, an unknown sub-type name, or a missing
            ``higher_is_better`` for a ranked custom metric.
        """
        if "type" not in values:
            raise ValueError("Invalid type")
        # Shallow copy: the normalization below must not write back into the user's config dict.
        v = dict(v)
        metric_type = values["type"]

        if metric_type == MetricType.CUSTOM:
            if v.get("priority", -1) != -1 and v.get("higher_is_better", None) is None:
                raise ValueError(f"higher_is_better must be specified for ranked custom metric: {v['name']}")
            return v

        # ``backend`` is missing from ``values`` only if its own validation failed.
        if "backend" not in values:
            raise ValueError("Invalid backend")
        backend = values["backend"]

        # backend joint checking
        if backend == "huggingface_metrics":
            import evaluate

            try:
                evaluate.load(v["name"])
            except FileNotFoundError as e:
                raise ValueError(f"could not load metric {v['name']} from huggingface/evaluate") from e
        elif backend == "torch_metrics":
            sub_metric_type_cls = None
            if metric_type == MetricType.ACCURACY:
                sub_metric_type_cls = AccuracySubType
            elif metric_type == MetricType.LATENCY:
                sub_metric_type_cls = LatencySubType
            elif metric_type == MetricType.THROUGHPUT:
                sub_metric_type_cls = ThroughputSubType
            try:
                # raises ValueError if the name is not a member of the sub-type enum
                v["name"] = sub_metric_type_cls(v["name"])
            except ValueError:
                raise ValueError(
                    f"sub_type {v['name']} is not in {list(sub_metric_type_cls.__members__.keys())}"
                    f" for {values['type']} metric"
                ) from None

        # metric_config: pick the config class and the default ranking direction per metric type
        metric_config_cls = None
        if metric_type == MetricType.ACCURACY:
            v["higher_is_better"] = v.get("higher_is_better", True)
            if backend == "torch_metrics":
                metric_config_cls = AccuracyBase.registry[v["name"]].get_config_class()
            elif backend == "huggingface_metrics":
                from olive.evaluator.metric_backend import HuggingfaceMetrics

                metric_config_cls = HuggingfaceMetrics.get_config_class()
        elif metric_type == MetricType.LATENCY:
            v["higher_is_better"] = v.get("higher_is_better", False)
            metric_config_cls = LatencyMetricConfig
        elif metric_type == MetricType.THROUGHPUT:
            v["higher_is_better"] = v.get("higher_is_better", True)
            metric_config_cls = ThroughputMetricConfig
        v["metric_config"] = validate_config(v.get("metric_config", {}), metric_config_cls)
        return v

    @validator("user_config", pre=True, always=True)
    def validate_user_config(cls, v, values):
        """Validate ``user_config`` with the user config class registered for the metric ``type``.

        :raises ValueError: if ``type`` failed validation.
        """
        if "type" not in values:
            raise ValueError("Invalid type")
        user_config_class = get_user_config_class(values["type"])
        return validate_config(v, user_config_class)
def get_latency_config_from_metric(metric: Metric):
    """Return ``(warmup_num, repeat_test_num, sleep_num)`` from the first configured sub-metric.

    The first sub-metric whose ``metric_config`` is truthy supplies all three values. That config
    is presumably a ``LatencyMetricConfig``-like object exposing these attributes — TODO confirm
    for non-latency metrics. If no sub-metric has a config, ``(None, None, None)`` is returned.
    """
    first_config = next((st.metric_config for st in metric.sub_types if st.metric_config), None)
    if first_config is None:
        return None, None, None
    return first_config.warmup_num, first_config.repeat_test_num, first_config.sleep_num