# -------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.
# --------------------------------------------------------------------------
from enum import Enum
from typing import Union
from pydantic import validator
from olive.common.config_utils import ConfigBase, validate_config
from olive.data.config import DataConfig
from olive.evaluator.accuracy import AccuracyBase
from olive.evaluator.metric_config import LatencyMetricConfig, MetricGoal, get_user_config_class
class MetricType(str, Enum):
    """Top-level kind of an evaluation metric.

    Because the enum mixes in ``str``, each member compares equal to its raw
    string value (e.g. ``MetricType.LATENCY == "latency"``).
    """

    ACCURACY = "accuracy"  # refined by an AccuracySubType
    LATENCY = "latency"  # refined by a LatencySubType
    CUSTOM = "custom"  # takes no sub_type (Metric resets it to None)
class AccuracySubType(str, Enum):
    """Sub-types available for an ``accuracy`` metric.

    Members are used as keys into ``AccuracyBase.registry`` to find the
    implementation (and its config class) for the chosen sub-type.
    """

    ACCURACY_SCORE = "accuracy_score"
    F1_SCORE = "f1_score"
    PRECISION = "precision"
    RECALL = "recall"
    # Per the TODO below, AUC uses a config that differs from the other sub-types.
    AUC = "auc"
class LatencySubType(str, Enum):
    """Sub-types available for a ``latency`` metric.

    Each value names the statistic reported over latency measurements. These are
    presumably mean/max/min and percentiles (``p999`` = 99.9th percentile). The
    aggregation itself is implemented outside this module.
    """

    # Summary statistics.
    AVG = "avg"
    MAX = "max"
    MIN = "min"
    # Percentile statistics.
    P50 = "p50"
    P75 = "p75"
    P90 = "p90"
    P95 = "p95"
    P99 = "p99"
    P999 = "p999"
# TODO: support multiple subtypes at the same time for the same metric type.
# Otherwise, evaluating a model separately for each subtype wastes compute and time.
# Possible signature: names, subtypes: Union[str, List[str]]
# However, the accuracy metric poses a problem because AUC has a different config. This needs to be
# resolved so that a single metric type maps to a single metric config.
# That way, the user can return multiple metrics at once.
class Metric(ConfigBase):
    """Configuration of a single evaluation metric.

    The validators below normalize the config according to ``type``:

    * ``sub_type`` is forced to ``None`` for custom metrics. Otherwise it must be
      a valid value of ``AccuracySubType`` / ``LatencySubType``.
    * ``higher_is_better`` is forced to ``True`` for accuracy and ``False`` for
      latency. Only custom metrics keep the provided value.
    * ``metric_config`` is ``None`` for custom metrics. Otherwise it is validated
      against the config class for the type / sub-type.
    * Percentage-based ``goal`` values are range-checked against ``higher_is_better``.

    Field declaration order matters: with pydantic ``validator``s, ``values`` only
    contains fields declared (and successfully validated) before the current one.
    """

    name: str
    type: MetricType
    # Required for accuracy/latency metrics; always reset to None for custom metrics.
    sub_type: Union[AccuracySubType, LatencySubType] = None
    # Only user-controllable for custom metrics (see validate_higher_is_better).
    higher_is_better: bool = True
    priority_rank: int = 1
    goal: MetricGoal = None
    # Type-specific settings; validated against the class chosen in validate_metric_config.
    metric_config: ConfigBase = None
    # Validated against get_user_config_class(type) when provided.
    user_config: ConfigBase = None
    data_config: DataConfig = DataConfig()

    @validator("sub_type", always=True, pre=True)
    def validate_sub_type(cls, v, values):
        """Coerce ``sub_type`` into the sub-type enum that matches ``type``.

        Returns ``None`` for custom metrics regardless of the input.

        Raises:
            ValueError: if ``type`` failed validation, or ``v`` is not a valid value
                of the sub-type enum for ``type``.
        """
        if "type" not in values:
            raise ValueError("Invalid type")
        metric_type = values["type"]
        if metric_type == MetricType.CUSTOM:
            return None

        sub_type_enum = AccuracySubType if metric_type == MetricType.ACCURACY else LatencySubType
        try:
            return sub_type_enum(v)
        except ValueError:
            # Enum lookup is by value (e.g. "accuracy_score"), so report the accepted values.
            # The member names (e.g. "ACCURACY_SCORE") would themselves be rejected.
            valid_values = [member.value for member in sub_type_enum]
            # Use the explicit value: f-string formatting of str-mixin Enum members
            # changed in recent Python versions (it can render as "MetricType.ACCURACY").
            type_name = MetricType(metric_type).value
            raise ValueError(
                f"sub_type must be one of {valid_values} for {type_name} metric"
            ) from None

    @validator("higher_is_better", always=True, pre=True)
    def validate_higher_is_better(cls, v, values):
        """Fix ``higher_is_better`` for built-in metric types.

        Accuracy metrics are always higher-is-better and latency metrics always
        lower-is-better, overriding any provided value. Custom metrics keep ``v``.

        Raises:
            ValueError: if ``type`` failed validation, or a custom metric explicitly
                sets ``higher_is_better`` to ``None``.
        """
        if "type" not in values:
            raise ValueError("Invalid type")
        if values["type"] == MetricType.ACCURACY:
            return True
        if values["type"] == MetricType.LATENCY:
            return False
        # NOTE(review): the field defaults to True and this validator runs with always=True,
        # so v is None only when None is passed explicitly. An omitted value on a custom
        # metric silently becomes True. Confirm whether the field default should be None.
        if v is None:
            raise ValueError("higher_is_better must be specified for custom metric")
        return v

    @validator("metric_config", always=True, pre=True)
    def validate_metric_config(cls, v, values):
        """Validate ``metric_config`` against the config class for ``type``/``sub_type``.

        Returns ``None`` for custom metrics. Latency metrics use
        ``LatencyMetricConfig``. Accuracy metrics use the config class of the
        ``AccuracyBase.registry`` entry for ``sub_type``.

        Raises:
            ValueError: if ``type`` or ``sub_type`` failed validation, or ``type`` is
                not handled here.
        """
        if "type" not in values:
            raise ValueError("Invalid type")
        if "sub_type" not in values:
            raise ValueError("Invalid sub_type")
        metric_type = values["type"]
        if metric_type == MetricType.CUSTOM:
            return None

        if metric_type == MetricType.LATENCY:
            metric_config_class = LatencyMetricConfig
        elif metric_type == MetricType.ACCURACY:
            # Assumes registry is a mapping keyed by sub-type; verify against AccuracyBase.
            metric_config_class = AccuracyBase.registry[values["sub_type"]].get_config_class()
        else:
            # Guard against new MetricType members; otherwise metric_config_class is unbound.
            raise ValueError(f"Unsupported metric type: {MetricType(metric_type).value}")

        # validate_config presumably coerces v (dict/instance/None) into metric_config_class.
        return validate_config(v, ConfigBase, metric_config_class)

    @validator("user_config", pre=True)
    def validate_user_config(cls, v, values):
        """Validate ``user_config`` against the user-config class for ``type``.

        Without ``always=True`` this only runs when ``user_config`` is supplied.

        Raises:
            ValueError: if ``type`` failed validation.
        """
        if "type" not in values:
            raise ValueError("Invalid type")
        user_config_class = get_user_config_class(values["type"])
        return validate_config(v, ConfigBase, user_config_class)

    @validator("goal")
    def validate_goal(cls, v, values):
        """Range-check percentage-based goals against ``higher_is_better``.

        Only ``percent-min-improvement`` and ``percent-max-degradation`` goals are
        checked; other goal types and ``None`` pass through unchanged. A percent
        change that would push a lower-is-better metric (improvement) or a
        higher-is-better metric (degradation) below zero is capped at 100.

        Raises:
            ValueError: if ``higher_is_better`` failed validation, or the goal value
                lies outside the open interval allowed for its type.
        """
        if v is None:
            return v
        if v.type not in ["percent-min-improvement", "percent-max-degradation"]:
            return v

        if "higher_is_better" not in values:
            raise ValueError("Invalid higher_is_better")
        higher_is_better = values["higher_is_better"]
        # (goal type, higher_is_better) -> open interval of allowed percentages.
        ranges = {
            ("percent-min-improvement", True): (0, float("inf")),
            ("percent-min-improvement", False): (0, 100),
            ("percent-max-degradation", True): (0, 100),
            ("percent-max-degradation", False): (0, float("inf")),
        }
        valid_range = ranges[(v.type, higher_is_better)]
        # NOTE(review): endpoints are excluded, so e.g. a 0% max degradation is rejected.
        # Confirm this is intended.
        if not valid_range[0] < v.value < valid_range[1]:
            raise ValueError(
                f"Invalid goal value {v.value} for {v.type} and higher_is_better={higher_is_better}. Valid range is"
                f" {valid_range}"
            )
        return v