obvs.metrics
Metrics Module
This module provides classes for computing two evaluation metrics for language modeling: Precision@k and Surprisal.
Classes:
- PrecisionAtKMetric: Computes the Precision@k metric for a batch of estimated probabilities vs true token indices.
The update method takes the predicted scores (the logits argument) and the true token indices for each example in the batch. The compute method returns Precision@k over all accumulated batches, where a prediction counts as correct if the true token appears anywhere among the top-k predictions.
- SurprisalMetric: Computes the Surprisal metric for a batch of estimated probabilities vs true token indices.
The update method takes the predicted scores (the logits argument) and the true token indices for each example in the batch. The compute method returns the average Surprisal over all accumulated batches.
Together, these classes evaluate a language model by measuring how often the true token appears among its top-k predictions (Precision@k) and how surprising the true tokens are under its predicted distribution, i.e. the negative log-probability the model assigns to them (Surprisal).
Module Contents
Classes
| PrecisionAtKMetric | Compute Precision@k metric for a batch of estimated probabilities vs true token indices. |
| SurprisalMetric | Compute Surprisal metric for a batch of estimated probabilities vs true token indices. |
- class obvs.metrics.PrecisionAtKMetric(topk=10, dist_sync_on_step=False, batch_size=None)
Bases: torchmetrics.Metric
Compute Precision@k metric for a batch of estimated probabilities vs true token indices. The update method takes the predicted scores (the logits argument) and the true token indices for each example in the batch. The compute method returns Precision@k over all accumulated batches, where a prediction counts as correct if the true token appears anywhere among the top-k predictions (k is set by topk).
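The per-batch check described above can be sketched in plain Python. This is a hypothetical stand-in for PrecisionAtKMetric.batch, not the obvs implementation. It assumes logits holds one row of unnormalized scores per example and that the result is the fraction of examples whose true token is among the topk highest-scoring indices (a quantity often called top-k accuracy). The real method works on torch tensors; lists are used here only to keep the sketch dependency-free.

```python
from typing import Sequence


def precision_at_k_batch(
    logits: Sequence[Sequence[float]],
    true_token_index: Sequence[int],
    topk: int,
) -> float:
    """Hypothetical sketch: fraction of examples whose true token is in the top-k.

    Not the obvs API; mirrors the definition in the docstring using plain lists.
    """
    if not logits:
        raise ValueError("empty batch")
    if len(logits) != len(true_token_index):
        raise ValueError("logits and true_token_index must have the same length")
    hits = 0
    for row, target in zip(logits, true_token_index):
        # Indices of the k highest scores. sorted() is stable even with
        # reverse=True, so tied scores keep the lower index first.
        top = sorted(range(len(row)), key=lambda i: row[i], reverse=True)[:topk]
        hits += int(target in top)
    return hits / len(logits)
```

Because only the ranking of scores matters, applying a softmax to the logits first would give the same result.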
- update(logits, true_token_index) → None
Update the metric state with a batch of logits and the corresponding true token indices.
- static batch(logits, true_token_index, topk) → torch.Tensor
- compute() → torch.Tensor
Return Precision@k over all accumulated batches.
State variables are synchronized automatically when running with a distributed backend.
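One way the update/compute split can work is to keep two running counts (hits and examples seen) and divide only in compute(). The class below is a hypothetical sketch of that pattern, not the obvs class. Whether PrecisionAtKMetric pools counts this way, rather than averaging per-batch scores, is an assumption.

```python
from typing import Sequence


class PrecisionAtKAccumulator:
    """Hypothetical sketch of the update/compute pattern; not the obvs class.

    State is two running counts, so compute() returns hits / examples pooled
    over every batch seen, not the mean of per-batch scores.
    """

    def __init__(self, topk: int = 10) -> None:
        self.topk = topk
        self.hits = 0
        self.total = 0

    def update(
        self,
        logits: Sequence[Sequence[float]],
        true_token_index: Sequence[int],
    ) -> None:
        for row, target in zip(logits, true_token_index):
            # Top-k indices by score, highest first.
            top = sorted(range(len(row)), key=lambda i: row[i], reverse=True)[: self.topk]
            self.hits += int(target in top)
            self.total += 1

    def compute(self) -> float:
        if self.total == 0:
            raise ValueError("compute() called before any update()")
        return self.hits / self.total
```

With batches of unequal size the two choices differ. A batch containing one hit followed by a batch of three misses gives a pooled 0.25, whereas averaging the two per-batch scores would give 0.5.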
- class obvs.metrics.SurprisalMetric(dist_sync_on_step=False, batch_size=None)
Bases: torchmetrics.Metric
Compute Surprisal metric for a batch of estimated probabilities vs true token indices. The update method takes the predicted scores (the logits argument) and the true token indices for each example in the batch. The compute method returns the average Surprisal over all accumulated batches.
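The surprisal of the true token is its negative log-probability, -log p(true token). The sketch below makes two assumptions that the docstring does not state: logits are unnormalized scores converted to probabilities with a softmax, and the natural log is used, so results are in nats. The name surprisal_batch is hypothetical and is not the obvs API.

```python
import math
from typing import List, Sequence


def log_softmax(row: Sequence[float]) -> List[float]:
    # Subtract the row maximum before exponentiating so large logits
    # cannot overflow math.exp.
    m = max(row)
    log_z = m + math.log(sum(math.exp(x - m) for x in row))
    return [x - log_z for x in row]


def surprisal_batch(
    logits: Sequence[Sequence[float]],
    true_token_index: Sequence[int],
) -> float:
    """Hypothetical sketch: mean of -log p(true token) over the batch, in nats."""
    if not logits:
        raise ValueError("empty batch")
    values = [-log_softmax(row)[t] for row, t in zip(logits, true_token_index)]
    return sum(values) / len(values)
```

A uniform distribution over two tokens gives ln 2 ≈ 0.693 nats for either token. Using math.log2 instead would report the same quantity in bits.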
- update(logits, true_token_index) → None
Update the metric state with a batch of logits and the corresponding true token indices.
- static batch(logits, true_token_index) → torch.Tensor
- compute() → torch.Tensor
Return the average Surprisal over all accumulated batches.
State variables are synchronized automatically when running with a distributed backend.
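Because an average surprisal can be rebuilt from a running sum and a count, this kind of state is easy to combine across processes. torchmetrics lets a metric register states with add_state(..., dist_reduce_fx="sum") so that they are summed during synchronization. Whether SurprisalMetric stores its state this way is an assumption. The hypothetical class below mimics the pattern in plain Python, with merge() standing in for the distributed sum.

```python
import math
from typing import Sequence


class SurprisalAccumulator:
    """Hypothetical sketch: running sum of surprisals plus an example count."""

    def __init__(self) -> None:
        self.total_surprisal = 0.0
        self.count = 0

    def update(
        self,
        logits: Sequence[Sequence[float]],
        true_token_index: Sequence[int],
    ) -> None:
        for row, target in zip(logits, true_token_index):
            # Stable log-partition function; -log p(target) = log_z - row[target].
            m = max(row)
            log_z = m + math.log(sum(math.exp(x - m) for x in row))
            self.total_surprisal += log_z - row[target]
            self.count += 1

    def merge(self, other: "SurprisalAccumulator") -> None:
        # Both states are sums, so combining two workers is plain addition,
        # analogous to a "sum" reduction during distributed sync.
        self.total_surprisal += other.total_surprisal
        self.count += other.count

    def compute(self) -> float:
        if self.count == 0:
            raise ValueError("compute() called before any update()")
        return self.total_surprisal / self.count
```

Merging the sums before dividing gives the same result as processing every example on one worker. Averaging each worker's mean would not, unless all workers saw the same number of examples.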