
torchmetrics

Modules:

quantile
table

Classes:

QuantileMetric: Computes quantiles over a set of observations.

TableMetric: Synchronizes the state of a dict[str, list[Any]] across multiple worker processes.

QuantileMetric

Bases: CatMetric

Computes quantiles over a set of observations.

Methods:

__init__: Construct a QuantileMetric.

compute: Compute the configured quantiles over the observations.

__init__

__init__(
    q: Tensor | list[Tensor] | tuple[Tensor, ...],
    interpolation: Literal[
        "linear", "lower", "higher", "nearest", "midpoint"
    ]
    | str = "linear",
    nan_strategy: str | float = "warn",
    **kwargs: Any,
) -> None

Construct a QuantileMetric.

Parameters:

q (Tensor | list[Tensor] | tuple[Tensor, ...], required)

The quantiles to compute, each in the range [0, 1]. Pass either a tensor or a list or tuple of tensors that can be combined with hstack. The result is flattened, and a quantile value is computed for each element of q.

interpolation (Literal['linear', 'lower', 'higher', 'nearest', 'midpoint'] | str, default 'linear')

How to interpolate when the requested quantile falls between two observations. Let a and b be the observations (in sorted order) immediately below and above the computed quantile index. One of:

* 'linear': a + (b - a) * fraction, where fraction is the fractional part of the computed quantile index.
* 'lower': a.
* 'higher': b.
* 'nearest': a or b, whichever's index is closer to the computed quantile index (rounding down for .5 fractions).
* 'midpoint': (a + b) / 2.

nan_strategy (str | float, default 'warn')

How to handle NaN values. One of:

* 'error': raise a RuntimeError if any NaN values are encountered.
* 'warn': issue a warning if any NaN values are encountered, then continue.
* 'ignore': silently remove all NaN values.
* a float: replace every NaN value with this value.

**kwargs (Any, optional)

Additional keyword arguments passed to the base metric class.
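To make the five interpolation modes concrete, here is a minimal pure-Python sketch that computes a single quantile from a list of numbers. It assumes the quantile index is q * (n - 1) over the sorted observations (the convention torch.quantile uses). `quantile_sorted` is a hypothetical helper, not part of this module, and the real metric operates on tensors rather than lists.

```python
from __future__ import annotations

import math


def quantile_sorted(values: list[float], q: float, interpolation: str = "linear") -> float:
    """Hypothetical helper: q-th quantile of `values` under one interpolation mode.

    Assumes the quantile index is q * (n - 1) over the sorted values.
    """
    if not 0.0 <= q <= 1.0:
        raise ValueError("q must be in the range [0, 1]")
    if not values:
        raise ValueError("need at least one observation")
    data = sorted(values)
    index = q * (len(data) - 1)          # computed quantile index
    lo, hi = math.floor(index), math.ceil(index)
    a, b = data[lo], data[hi]            # observations just below and above the index
    fraction = index - lo                # fractional part of the index
    if interpolation == "linear":
        return a + (b - a) * fraction
    if interpolation == "lower":
        return a
    if interpolation == "higher":
        return b
    if interpolation == "nearest":
        # A .5 fraction rounds down to a, as described for 'nearest'.
        return a if fraction <= 0.5 else b
    if interpolation == "midpoint":
        return (a + b) / 2
    raise ValueError(f"unknown interpolation: {interpolation!r}")


observations = [4.0, 1.0, 3.0, 2.0]
for mode in ("linear", "lower", "higher", "nearest", "midpoint"):
    print(mode, quantile_sorted(observations, 0.5, mode))
```

For these four observations the index for q = 0.5 is 1.5, so a = 2.0, b = 3.0 and fraction = 0.5: 'linear' and 'midpoint' give 2.5, 'lower' and 'nearest' give 2.0, and 'higher' gives 3.0.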

compute

compute() -> dict[float, torch.Tensor]

Compute the configured quantiles over the observations.

Returns:

dict[float, torch.Tensor]: A mapping from each quantile to its computed value.
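The sketch below approximates the shape of what compute() returns, including how each nan_strategy option could be applied before the quantiles are taken. It is plain Python with hypothetical names (`compute_quantiles`, `_linear_quantile`), uses only linear interpolation, and returns floats rather than tensors. Treating 'warn' as "warn, then drop the NaNs" is an assumption that mirrors torchmetrics' aggregation metrics; the nan_strategy description only says that it continues.

```python
from __future__ import annotations

import math
import warnings


def _linear_quantile(data: list[float], q: float) -> float:
    # Linear interpolation at index q * (n - 1) over already-sorted data (assumed convention).
    index = q * (len(data) - 1)
    lo, hi = math.floor(index), math.ceil(index)
    return data[lo] + (data[hi] - data[lo]) * (index - lo)


def compute_quantiles(observations, quantiles, nan_strategy="warn"):
    """Hypothetical sketch: return {quantile: value} for each requested quantile."""
    if any(math.isnan(x) for x in observations):
        if nan_strategy == "error":
            raise RuntimeError("Encountered NaN values in the observations")
        if nan_strategy in ("warn", "ignore"):
            if nan_strategy == "warn":
                # Assumption: warn, then drop the NaNs (as torchmetrics aggregators do).
                warnings.warn("Encountered NaN values; they will be removed")
            observations = [x for x in observations if not math.isnan(x)]
        elif isinstance(nan_strategy, float):
            # Impute every NaN with the given float.
            observations = [nan_strategy if math.isnan(x) else x for x in observations]
        else:
            raise ValueError(f"Unknown nan_strategy: {nan_strategy!r}")
    if not observations:
        raise ValueError("No observations left to compute quantiles over")
    data = sorted(observations)
    return {float(q): _linear_quantile(data, q) for q in quantiles}


obs = [1.0, float("nan"), 3.0, 2.0, 4.0]
print(compute_quantiles(obs, [0.0, 0.5, 1.0], nan_strategy="ignore"))
# {0.0: 1.0, 0.5: 2.5, 1.0: 4.0}
```

With nan_strategy="ignore" the NaN is dropped before sorting; with nan_strategy=0.0 it is replaced by 0.0 and counted as an ordinary observation, which shifts every quantile.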

TableMetric

Bases: Metric

Synchronizes the state of a dict[str, list[Any]] across multiple worker processes.
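A minimal sketch of the end state that this synchronization aims for: each worker contributes its own rows, and afterwards every worker holds the same combined table. It runs in a single process and uses a plain list as a stand-in for the updates gathered from each worker. It assumes rows are concatenated column by column in rank order; `merge_worker_tables` is a hypothetical helper, and the gathering mechanism (for example something like `torch.distributed.all_gather_object`) and the row order that TableMetric actually uses may differ.

```python
from __future__ import annotations

from typing import Any


def merge_worker_tables(per_worker: list[dict[str, list[Any]]]) -> dict[str, list[Any]]:
    """Concatenate per-worker tables column by column, in worker (rank) order.

    per_worker[r] stands in for the update gathered from worker r (hypothetical).
    """
    merged: dict[str, list[Any]] = {}
    for table in per_worker:
        for column, values in table.items():
            merged.setdefault(column, []).extend(values)
    return merged


rank0 = {"id": [1, 2], "label": ["cat", "dog"]}
rank1 = {"id": [3], "label": ["bird"]}
# After synchronization, every worker would hold this same table.
print(merge_worker_tables([rank0, rank1]))
# {'id': [1, 2, 3], 'label': ['cat', 'dog', 'bird']}
```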

Methods:

__init__: Construct a TableMetric.

__len__: Return the number of rows in the aggregated table.

compute: Return the aggregated table.

reset: Reset the aggregated table.

update: Aggregate the incoming update from all worker processes and share it with all other worker processes.

__init__

__init__() -> None

Construct a TableMetric.

__len__

__len__() -> int

Return the number of rows in the aggregated table.

compute

compute() -> dict[str, list[Any]]

Return the aggregated table.

Raises:

ValueError: If no updates have been performed yet.

Returns:

dict[str, list[Any]]: The aggregated table.

reset

reset() -> None

Reset the aggregated table.

update

update(table: dict[str, list[Any]]) -> None

Aggregate the incoming update from all worker processes and share it with all other worker processes.

Parameters:

table (dict[str, list[Any]], required)

Additional rows to add to the aggregated table, given as a mapping from column name to that column's values.

Raises:

ValueError: If the table update has no columns.

ValueError: If the table update's keys do not match the aggregated table's keys.

ValueError: If the table update's values are not lists.

ValueError: If the table update's lists are not all of equal length.

ValueError: If the table update has no rows.
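The validation rules listed for update(), together with __len__, compute() and reset(), can be sketched as a single-process class. `LocalTable` is hypothetical: it omits cross-process synchronization entirely, it checks the documented error conditions in the order they are listed here (which may not match the order TableMetric uses), and its error messages are illustrative.

```python
from __future__ import annotations

from typing import Any


class LocalTable:
    """Hypothetical single-process stand-in for TableMetric's bookkeeping."""

    def __init__(self) -> None:
        self._table: dict[str, list[Any]] | None = None  # None until the first update

    def update(self, table: dict[str, list[Any]]) -> None:
        # All checks run before any rows are appended.
        if not table:
            raise ValueError("table update has no columns")
        if self._table is not None and set(table) != set(self._table):
            raise ValueError("table update's keys do not match the aggregated table's keys")
        if not all(isinstance(values, list) for values in table.values()):
            raise ValueError("table update's values must be lists")
        lengths = {len(values) for values in table.values()}
        if len(lengths) != 1:
            raise ValueError("table update's lists must all have equal length")
        if lengths == {0}:
            raise ValueError("table update has no rows")
        if self._table is None:
            self._table = {key: [] for key in table}
        for key, values in table.items():
            self._table[key].extend(values)

    def __len__(self) -> int:
        # Number of rows = length of any column (all columns have equal length).
        if self._table is None:
            return 0
        return len(next(iter(self._table.values())))

    def compute(self) -> dict[str, list[Any]]:
        if self._table is None:
            raise ValueError("no updates have been performed yet")
        return self._table

    def reset(self) -> None:
        self._table = None


table = LocalTable()
table.update({"id": [1, 2], "score": [0.9, 0.4]})
table.update({"id": [3], "score": [0.7]})
print(len(table), table.compute())
# 3 {'id': [1, 2, 3], 'score': [0.9, 0.4, 0.7]}
```

Because every check runs before any rows are appended, a rejected update leaves the aggregated table unchanged in this sketch.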