
torchmetrics

QuantileMetric

Bases: CatMetric

Computes quantiles over a set of observations.

dtype property

dtype: dtype

The torch.dtype that each update value is cast to. It is therefore also the torch.dtype of the compute result.

torch.quantile requires the q tensor to have the same dtype as the input tensor (the compute result), and both must be either float32 or float64.

__init__

__init__(
    q: float | Iterable[float] | Tensor,
    interpolation: Literal[
        "linear", "lower", "higher", "nearest", "midpoint"
    ]
    | str = "linear",
    nan_strategy: Literal[
        "error", "warn", "ignore", "disable"
    ]
    | float = "warn",
    **kwargs: Any,
) -> None

Construct a QuantileMetric.

Parameters:

Name Type Description Default

q

float | Iterable[float] | Tensor

The quantiles to compute. Each value must be in the range [0, 1]. Accepts a float, an iterable of floats, or a tensor; a list or tuple of tensors must be hstack-able. q is flattened, and a quantile value is computed for each of its elements.

required

interpolation

Literal['linear', 'lower', 'higher', 'nearest', 'midpoint'] | str

One of the following, where a and b are the two data points whose indices bracket the computed quantile index in the sorted observations:
* 'linear': a + (b - a) * fraction, where fraction is the fractional part of the computed quantile index.
* 'lower': a.
* 'higher': b.
* 'nearest': a or b, whichever's index is closer to the computed quantile index (rounding down for .5 fractions).
* 'midpoint': (a + b) / 2.

'linear'

nan_strategy

Literal['error', 'warn', 'ignore', 'disable'] | float

One of:
* 'error': raises a RuntimeError if any nan values are encountered.
* 'warn': gives a warning if any nan values are encountered, then continues.
* 'ignore': silently removes all nan values.
* 'disable': performs no nan checks.
* a float: imputes any nan values with this value.

'warn'

**kwargs

Any

Additional keyword arguments passed through to the base class, CatMetric.

{}
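To make the interpolation and nan_strategy options above concrete, here is a minimal pure-Python sketch of both, operating on plain lists of floats rather than tensors. It is not the class's implementation: the helper names `interpolate_quantile` and `apply_nan_strategy` are hypothetical, the computed quantile index is assumed to be `q * (n - 1)` over the sorted values (the convention torch.quantile uses), and 'warn' is assumed to drop nan values after warning, as torchmetrics' aggregation metrics do.

```python
import math
import warnings
from typing import List, Union


def interpolate_quantile(values: List[float], q: float, interpolation: str = "linear") -> float:
    """Hypothetical helper: one quantile of `values` using the documented interpolation rules."""
    if not values:
        raise ValueError("values must be non-empty")
    if not 0.0 <= q <= 1.0:
        raise ValueError("q must be in [0, 1]")
    data = sorted(values)
    index = q * (len(data) - 1)    # assumed: computed quantile index over the sorted data
    a = data[math.floor(index)]    # data point at or below the index
    b = data[math.ceil(index)]     # data point at or above the index
    fraction = index - math.floor(index)
    if interpolation == "linear":
        return a + (b - a) * fraction
    if interpolation == "lower":
        return a
    if interpolation == "higher":
        return b
    if interpolation == "nearest":
        # Pick the closer neighbour; a .5 fraction rounds down to a.
        return b if fraction > 0.5 else a
    if interpolation == "midpoint":
        return (a + b) / 2
    raise ValueError(f"unknown interpolation: {interpolation!r}")


def apply_nan_strategy(values: List[float], nan_strategy: Union[str, float] = "warn") -> List[float]:
    """Hypothetical helper: handle nan values according to `nan_strategy`."""
    if not isinstance(nan_strategy, float) and nan_strategy not in ("error", "warn", "ignore", "disable"):
        raise ValueError(f"unknown nan_strategy: {nan_strategy!r}")
    if nan_strategy == "disable":
        return list(values)  # no nan checks at all
    is_nan = [math.isnan(v) for v in values]
    if not any(is_nan):
        return list(values)
    if nan_strategy == "error":
        raise RuntimeError("Encountered nan values")
    if isinstance(nan_strategy, float):
        # Impute every nan with the provided value.
        return [nan_strategy if n else v for v, n in zip(values, is_nan)]
    if nan_strategy == "warn":
        warnings.warn("Encountered nan values; they will be removed.", UserWarning)
    # Assumed behaviour: 'warn' and 'ignore' both drop the nan values.
    return [v for v, n in zip(values, is_nan) if not n]
```

With the observations [1, 2, 3, 4] and q = 0.5, the index is 1.5, so a = 2 and b = 3: 'linear' and 'midpoint' give 2.5, 'lower' and 'nearest' give 2, and 'higher' gives 3.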

compute

compute() -> tuple[torch.Tensor, dict[float, torch.Tensor]]

Aggregate the observations and compute the configured quantiles.

Returns:

Type Description
tuple[torch.Tensor, dict[float, torch.Tensor]]

The concatenated observations and a mapping from each quantile to its corresponding value.
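The shape of the value returned by compute can be illustrated with a small pure-Python sketch: the updates are concatenated, q is flattened, and each quantile becomes a key in the returned dict. Plain lists stand in for tensors, the helper names `flatten_q` and `compute_quantiles` are hypothetical, and only 'linear' interpolation is shown.

```python
import math
from typing import Dict, Iterable, List, Tuple, Union

QSpec = Union[float, Iterable]  # a float or an arbitrarily nested iterable of floats


def flatten_q(q: QSpec) -> List[float]:
    """Hypothetical helper: flatten q into a list of floats, checking the [0, 1] range."""
    if isinstance(q, str):
        raise TypeError("q must be numeric, not a string")
    if isinstance(q, (int, float)):
        value = float(q)
        if not 0.0 <= value <= 1.0:
            raise ValueError(f"quantile {value} is outside [0, 1]")
        return [value]
    return [x for item in q for x in flatten_q(item)]


def compute_quantiles(updates: List[List[float]], q: QSpec) -> Tuple[List[float], Dict[float, float]]:
    """Hypothetical sketch: concatenate updates, then map each quantile to its value."""
    observations = [x for batch in updates for x in batch]  # concatenation, in update order
    if not observations:
        raise ValueError("no observations")
    data = sorted(observations)
    result: Dict[float, float] = {}
    for quantile in flatten_q(q):
        index = quantile * (len(data) - 1)
        lo, hi = data[math.floor(index)], data[math.ceil(index)]
        result[quantile] = lo + (hi - lo) * (index - math.floor(index))  # linear interpolation
    return observations, result
```

Because this sketch keys the result by quantile value, mirroring the documented dict[float, torch.Tensor] return type, repeated entries in q collapse to a single key.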

TableMetric

Bases: Metric

Synchronizes the state of a dict[str, list[Any]] across multiple worker processes.
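Conceptually, synchronizing a table means every worker contributes its rows and every worker ends up holding the same concatenated table. The sketch below simulates that merge in a single process with a hypothetical `merge_worker_tables` helper. In a real distributed run the per-worker tables would have to arrive through a collective gather (for example torch.distributed.all_gather_object); that is an assumption about the mechanism, not something this page documents.

```python
from typing import Any, Dict, List

Table = Dict[str, List[Any]]


def merge_worker_tables(tables: List[Table]) -> Table:
    """Hypothetical helper: concatenate per-worker tables column by column, in worker order."""
    merged: Table = {}
    for table in tables:
        if merged and set(table) != set(merged):
            raise ValueError("all workers must report the same columns")
        for column, values in table.items():
            # setdefault creates a fresh list, so the input lists are never mutated.
            merged.setdefault(column, []).extend(values)
    return merged
```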

__init__

__init__() -> None

Construct a TableMetric.

__len__

__len__() -> int

Return the number of rows in the aggregated table.

compute

compute() -> dict[str, list[Any]]

Return the aggregated table.

Raises:

Type Description
ValueError

If no updates have been performed yet.

Returns:

Type Description
dict[str, list[Any]]

The aggregated table.

reset

reset() -> None

Reset the aggregated table.

update

update(table: dict[str, list[Any]]) -> None

Aggregate the incoming update from all worker processes and share it with all other worker processes.

Parameters:

Name Type Description Default

table

dict[str, list[Any]]

Additional rows to add to the aggregated table.

required

Raises:

Type Description
ValueError

If the table update has no columns.

ValueError

If the table update's keys do not match the aggregated table's keys.

ValueError

If the table update's values are not lists.

ValueError

If the table update's lists are not all of equal length.

ValueError

If the table update has no rows.
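The validation rules listed for update, together with compute, reset and __len__, can be mirrored in a single-process sketch. `LocalTable` is a hypothetical stand-in with no cross-worker synchronization; the order in which the checks run and the length reported before any update are assumptions.

```python
from typing import Any, Dict, List, Optional

Table = Dict[str, List[Any]]


class LocalTable:
    """Hypothetical single-process stand-in for TableMetric (no synchronization)."""

    def __init__(self) -> None:
        self._table: Optional[Table] = None  # None until the first update

    def update(self, table: Table) -> None:
        # All checks run before any mutation, so a rejected update leaves the table intact.
        if not table:
            raise ValueError("The table update has no columns.")
        if self._table is not None and set(table) != set(self._table):
            raise ValueError("The table update's keys do not match the aggregated table's keys.")
        if not all(isinstance(values, list) for values in table.values()):
            raise ValueError("The table update's values must be lists.")
        lengths = {len(values) for values in table.values()}
        if len(lengths) != 1:
            raise ValueError("The table update's lists must all have equal length.")
        if lengths == {0}:
            raise ValueError("The table update has no rows.")
        if self._table is None:
            self._table = {column: [] for column in table}
        for column, values in table.items():
            self._table[column].extend(values)

    def compute(self) -> Table:
        if self._table is None:
            raise ValueError("No updates have been performed yet.")
        # Return copies so callers cannot mutate the aggregated state.
        return {column: list(values) for column, values in self._table.items()}

    def reset(self) -> None:
        self._table = None

    def __len__(self) -> int:
        # Assumption: a table that has never been updated has length 0.
        if self._table is None:
            return 0
        return len(next(iter(self._table.values())))
```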