torchmetrics
QuantileMetric
Bases: CatMetric
Computes quantiles over a set of observations.
__init__
__init__(q: Tensor | list[Tensor] | tuple[Tensor, ...], interpolation: Literal['linear', 'lower', 'higher', 'nearest', 'midpoint'] | str = 'linear', nan_strategy: str | float = 'warn', **kwargs: Any) -> None
Construct a QuantileMetric.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
q | Tensor \| list[Tensor] \| tuple[Tensor, ...] | The quantiles to compute. Values should be in the range [0, 1]. Should be a tensor, or a list or tuple of tensors. | required |
interpolation | Literal['linear', 'lower', 'higher', 'nearest', 'midpoint'] \| str | One of 'linear', 'lower', 'higher', 'nearest', or 'midpoint'. | 'linear' |
nan_strategy | str \| float | How NaN values in the observations are handled: a strategy name (str) or a float. | 'warn' |
**kwargs | Any | Additional arguments to pass to the base metric class. | {} |
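The constraints on `q` and `interpolation` can be illustrated with a small validation sketch. It is not the library's code: `normalize_quantiles`, `check_interpolation`, and the use of plain Python floats instead of tensors are assumptions made for illustration.

```python
from __future__ import annotations

from typing import Sequence

# Interpolation modes listed in the QuantileMetric signature.
INTERPOLATIONS = ("linear", "lower", "higher", "nearest", "midpoint")


def normalize_quantiles(q: float | Sequence[float]) -> list[float]:
    """Hypothetical helper: turn a scalar or a sequence of quantiles into a
    list of floats and check that each one lies in [0, 1]."""
    values = [float(q)] if isinstance(q, (int, float)) else [float(v) for v in q]
    if not values:
        raise ValueError("at least one quantile is required")
    for v in values:
        if not 0.0 <= v <= 1.0:
            raise ValueError(f"quantile {v} is outside [0, 1]")
    return values


def check_interpolation(mode: str) -> str:
    """Hypothetical helper: reject interpolation names outside the signature's Literal."""
    if mode not in INTERPOLATIONS:
        raise ValueError(f"unknown interpolation {mode!r}; expected one of {INTERPOLATIONS}")
    return mode
```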
compute
Compute the configured quantiles over the observations.
Returns:
Type | Description |
---|---|
dict[float, torch.Tensor] | A mapping of each quantile to its corresponding value. |
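The five interpolation modes accepted by the constructor use the same names as `torch.quantile`. The pure-Python sketch below reproduces those definitions to show what each mode yields; whether QuantileMetric delegates to `torch.quantile` internally is an assumption, and the hypothetical `quantiles` helper returns plain floats rather than tensors.

```python
from __future__ import annotations

import math
from typing import Sequence

MODES = ("linear", "lower", "higher", "nearest", "midpoint")


def quantiles(
    observations: Sequence[float],
    qs: Sequence[float],
    interpolation: str = "linear",
) -> dict[float, float]:
    """Hypothetical helper: map each quantile in `qs` to its value."""
    if interpolation not in MODES:
        raise ValueError(f"unknown interpolation {interpolation!r}")
    xs = sorted(observations)
    if not xs:
        raise ValueError("no observations")
    result: dict[float, float] = {}
    for q in qs:
        # Fractional position of q among the sorted observations.
        pos = q * (len(xs) - 1)
        lo, hi = math.floor(pos), math.ceil(pos)
        if interpolation == "linear":
            value = xs[lo] + (xs[hi] - xs[lo]) * (pos - lo)
        elif interpolation == "lower":
            value = xs[lo]
        elif interpolation == "higher":
            value = xs[hi]
        elif interpolation == "nearest":
            # Python's round() breaks .5 ties toward the even index.
            value = xs[round(pos)]
        else:  # "midpoint"
            value = (xs[lo] + xs[hi]) / 2
        result[q] = value
    return result
```

For example, with observations [1, 2, 3, 4] and q = 0.5, the target position is 1.5, between sorted indices 1 and 2, so 'linear' and 'midpoint' give 2.5, 'lower' gives 2, and 'higher' gives 3.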
TableMetric
Bases: Metric
Synchronizes the state of a dict[str, list[Any]] across multiple worker processes.
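Conceptually, synchronizing the table state means every worker ends up holding the same table, formed by concatenating each column across workers. The sketch below simulates that result with a plain list of per-worker tables standing in for a distributed all-gather; `merge_worker_tables` and the rank-ordered concatenation are assumptions, not TableMetric's actual implementation.

```python
from __future__ import annotations

from typing import Any


def merge_worker_tables(
    worker_tables: list[dict[str, list[Any]]],
) -> dict[str, list[Any]]:
    """Hypothetical: combine per-worker tables into the table every worker sees.

    `worker_tables[i]` stands in for what rank i contributed; a real
    implementation would obtain this list through a collective such as an
    all-gather instead of receiving it as an argument.
    """
    merged: dict[str, list[Any]] = {}
    for table in worker_tables:  # iterate in rank order
        for key, column in table.items():
            # Append this worker's rows to the shared column.
            merged.setdefault(key, []).extend(column)
    return merged
```

Because every worker would apply the same merge to the same gathered list, all workers end up with identical tables.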
compute
Return the aggregated table.
Raises:
Type | Description |
---|---|
ValueError | If no updates have been performed yet. |
Returns:
Type | Description |
---|---|
dict[str, list[Any]] | The aggregated table. |
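The `ValueError` raised by `compute` follows from the state starting out empty. A minimal sketch of that behavior, using a hypothetical `TableAccumulator` class that holds only local state and performs no cross-worker synchronization:

```python
from __future__ import annotations

from typing import Any, Optional


class TableAccumulator:
    """Hypothetical local stand-in for TableMetric's state (no syncing)."""

    def __init__(self) -> None:
        # None means update() has never been called.
        self._table: Optional[dict[str, list[Any]]] = None

    def update(self, table: dict[str, list[Any]]) -> None:
        if self._table is None:
            # The first update defines the columns.
            self._table = {key: list(rows) for key, rows in table.items()}
        else:
            for key, rows in table.items():
                self._table[key].extend(rows)

    def compute(self) -> dict[str, list[Any]]:
        if self._table is None:
            raise ValueError("no updates have been performed yet")
        # Return a copy so callers cannot mutate the accumulated state.
        return {key: list(rows) for key, rows in self._table.items()}
```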
update
Aggregate the incoming update from all worker processes and share it with all other worker processes.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
table | dict[str, list[Any]] | Additional rows to add to the aggregated table. | required |
Raises:
Type | Description |
---|---|
ValueError | If the table update has no columns. |
ValueError | If the table update's keys do not match the aggregated table's keys. |
ValueError | If the table update's values are not lists. |
ValueError | If the table update's lists are not all of equal length. |
ValueError | If the table update has no rows. |
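The five `ValueError` conditions listed for `update` can be expressed as a single validation function. This is a sketch under assumptions: `validate_table_update` is a hypothetical name, the checks run in the order the table lists them, and the key check applies only once an aggregated table exists (passed here as `expected_keys`).

```python
from __future__ import annotations

from typing import Any, Collection, Optional


def validate_table_update(
    table: dict[str, list[Any]],
    expected_keys: Optional[Collection[str]] = None,
) -> int:
    """Hypothetical check mirroring update's documented errors.

    Returns the number of rows in the update when it is valid.
    """
    if not table:
        raise ValueError("table update has no columns")
    if expected_keys is not None and set(table) != set(expected_keys):
        raise ValueError("table update keys do not match the aggregated table")
    if not all(isinstance(column, list) for column in table.values()):
        raise ValueError("table update values must be lists")
    lengths = {len(column) for column in table.values()}
    if len(lengths) != 1:
        raise ValueError("table update lists are not all of equal length")
    n_rows = lengths.pop()
    if n_rows == 0:
        raise ValueError("table update has no rows")
    return n_rows
```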