obvs.metrics

Metrics Module

This module provides classes for computing various metrics related to language modeling tasks.

Classes:
PrecisionAtKMetric: Computes the Precision@k metric for a batch of estimated probabilities vs true token indices.

The update method takes the model outputs (the logits argument) and the true token index for each example in the batch. The compute method returns the Precision@k result over all accumulated batches, where a prediction counts as correct if the true token appears anywhere in the top-k predictions.

SurprisalMetric: Computes the Surprisal metric for a batch of estimated probabilities vs true token indices.

The update method takes the model outputs (the logits argument) and the true token index for each example in the batch. The compute method returns the average Surprisal over all accumulated batches.

These classes can be used to evaluate language models in two ways: how often the true token falls within the top-k predictions, and how surprising the true token is under the estimated probabilities.
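As a rough guide to what the two metrics measure, the definitions can be written as below. This is a sketch under stated assumptions, not a statement of the library's exact implementation: N is the number of accumulated examples, z_i the logits for example i, y_i its true token index, p_i the softmax of z_i, and the logarithm is assumed to be natural.

```latex
\text{Precision@}k \;=\; \frac{1}{N}\sum_{i=1}^{N} \mathbb{1}\!\left[\, y_i \in \mathrm{TopK}_k(z_i) \,\right],
\qquad
\text{Surprisal} \;=\; -\frac{1}{N}\sum_{i=1}^{N} \log p_i(y_i),
\quad p_i = \mathrm{softmax}(z_i)
```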

Module Contents

Classes

PrecisionAtKMetric

Compute Precision@k metric for a batch of estimated probabilities vs true token indices.

SurprisalMetric

Compute Surprisal metric for a batch of estimated probabilities vs true token indices.

class obvs.metrics.PrecisionAtKMetric(topk=10, dist_sync_on_step=False, batch_size=None)

Bases: torchmetrics.Metric

Compute Precision@k metric for a batch of estimated probabilities vs true token indices. The update method takes the model outputs (the logits argument) and the true token index for each example in the batch. The compute method returns the Precision@k result over all accumulated batches, where a prediction counts as correct if the true token appears anywhere in the top-k predictions.

update(logits, true_token_index) → None

Update the metric state with a batch of logits and the corresponding true token indices. Overrides torchmetrics.Metric.update.

static batch(logits, true_token_index, topk) → torch.Tensor

compute() → torch.Tensor

Compute the final Precision@k value from the accumulated state. Overrides torchmetrics.Metric.compute.

State variables are synchronized automatically when running on a distributed backend.
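The update/compute pattern described for PrecisionAtKMetric can be illustrated with a small plain-Python stand-in. This is a sketch, not the obvs implementation: the names PrecisionAtKSketch and top_k_indices are hypothetical, nested lists stand in for tensors, and returning 0.0 before any update is an assumption.

```python
from typing import List, Sequence


def top_k_indices(logits: Sequence[float], k: int) -> List[int]:
    """Return the indices of the k largest logits, highest first."""
    order = sorted(range(len(logits)), key=lambda i: logits[i], reverse=True)
    return order[:k]


class PrecisionAtKSketch:
    """Hypothetical plain-Python stand-in for the update/compute pattern."""

    def __init__(self, topk: int = 10) -> None:
        self.topk = topk
        self.correct = 0  # examples whose true token was in the top-k
        self.total = 0    # examples seen so far

    def update(
        self,
        logits: Sequence[Sequence[float]],
        true_token_index: Sequence[int],
    ) -> None:
        # One row of logits and one true token index per example.
        for row, target in zip(logits, true_token_index):
            self.correct += int(target in top_k_indices(row, self.topk))
            self.total += 1

    def compute(self) -> float:
        # Assumption: report 0.0 when nothing has been accumulated yet.
        return self.correct / self.total if self.total else 0.0


metric = PrecisionAtKSketch(topk=2)
# Batch 1: the first example is a hit (token 2 is in the top 2), the second a miss.
metric.update([[0.1, 2.0, 1.5, -1.0], [3.0, 0.2, 0.1, 0.0]], [2, 3])
# Batch 2: a hit (token 3 has the largest logit).
metric.update([[0.0, 0.5, 0.4, 0.9]], [3])
precision = metric.compute()  # 2 hits out of 3 examples
```

Keeping only a hit count and an example count as state means the result is the same however the examples are split into batches.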

class obvs.metrics.SurprisalMetric(dist_sync_on_step=False, batch_size=None)

Bases: torchmetrics.Metric

Compute Surprisal metric for a batch of estimated probabilities vs true token indices. The update method takes the model outputs (the logits argument) and the true token index for each example in the batch. The compute method returns the average Surprisal over all accumulated batches.

update(logits, true_token_index) → None

Update the metric state with a batch of logits and the corresponding true token indices. Overrides torchmetrics.Metric.update.

static batch(logits, true_token_index) → torch.Tensor

compute() → torch.Tensor

Compute the average Surprisal from the accumulated state. Overrides torchmetrics.Metric.compute.

State variables are synchronized automatically when running on a distributed backend.
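To make the averaged surprisal concrete, here is a minimal plain-Python sketch. It assumes that surprisal is the negative natural log of the softmax probability assigned to the true token; the log base and normalization used by obvs are not stated in this page, and the names SurprisalSketch and log_softmax are hypothetical.

```python
import math
from typing import List, Sequence


def log_softmax(logits: Sequence[float]) -> List[float]:
    """Numerically stable log-softmax: shift by the max before exponentiating."""
    m = max(logits)
    log_z = m + math.log(sum(math.exp(x - m) for x in logits))
    return [x - log_z for x in logits]


class SurprisalSketch:
    """Hypothetical plain-Python stand-in for the update/compute pattern."""

    def __init__(self) -> None:
        self.total = 0.0  # running sum of per-example surprisal (in nats)
        self.count = 0    # examples seen so far

    def update(
        self,
        logits: Sequence[Sequence[float]],
        true_token_index: Sequence[int],
    ) -> None:
        for row, target in zip(logits, true_token_index):
            # Surprisal of the true token: -log p(true token), natural log assumed.
            self.total += -log_softmax(row)[target]
            self.count += 1

    def compute(self) -> float:
        # Assumption: report 0.0 when nothing has been accumulated yet.
        return self.total / self.count if self.count else 0.0


metric = SurprisalSketch()
# Uniform over 4 tokens: p = 1/4, so surprisal = log(4).
metric.update([[0.0, 0.0, 0.0, 0.0]], [1])
# Token 0 has p = 3 / (3 + 1 + 1 + 1) = 1/2, so surprisal = log(2).
metric.update([[math.log(3), 0.0, 0.0, 0.0]], [0])
average = metric.compute()  # (log(4) + log(2)) / 2
```

Lower values mean the model placed more probability on the true token; a uniform distribution over V tokens gives log(V).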