obvs.patchscope

🩺 Patchscope Module.

Implementation of the patchscopes framework: https://arxiv.org/abs/2401.06102

Patchscopes are a framework for inspecting and explaining the inner workings of large language models (LLMs). They use a model's own ability to generate human-readable text to explain what is encoded in its hidden representations. A hidden representation is taken from one language model (the source model) and patched into the forward pass of another (the target model, which may be the same model). The target's output is then read as an explanation. In this way, patchscopes unify several existing interpretability techniques and make new uses possible.

This is the main patchscopes module. It builds on the abstract base class PatchscopeBase (obvs.patchscope_base).

The patchscope base class takes a hidden representation described by the source tuple:

(S, i, M, ℓ) describes the source from which the original hidden representation is drawn:
  • S is the source input sequence.
  • i is the position within that sequence. NB: We extend the method to allow a range of positions.
  • M is the original model that processes the sequence.
  • ℓ is the layer in model M from which the hidden representation is taken.

It then patches that representation into a target context described by the target tuple:

(T, i*, f, M*, ℓ*) defines the target context for the intervention (the patching operation):
  • T is the target prompt, which may be the same as the source prompt S or different.
  • i* is the position in the target prompt that receives the patched representation. NB: We extend the method to allow a range of positions.
  • f is the mapping function applied to the hidden representation. It may transform the representation before it is patched into the target context, and can be a simple identity function or a more complex transformation.
  • M* is the model in which the patching operation is performed. It may be the same as M or different.
  • ℓ* is the layer in the target model M* whose hidden representation h̄_{i*}^{ℓ*} is replaced with the patched version.
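A minimal, dependency-free sketch of one patching operation may help make the two tuples concrete. Everything in it is a hypothetical toy and not the obvs implementation. A "model" is an embedding table plus a list of causal toy layers. `forward` returns every layer's output and can overwrite one hidden state. `patchscope` reads the state at (S, i, M, ℓ), applies f, and writes the result at (T, i*, M*, ℓ*) before the remaining target layers run. Layers are 0-indexed, and the range-of-positions extension is omitted.

```python
from typing import Callable, List, Optional, Sequence, Tuple

Vector = List[float]
Layer = Callable[[List[Vector]], List[Vector]]
ToyModel = Tuple[List[Vector], List[Layer]]  # hypothetical: (embedding table, layers)


def make_layer(scale: float) -> Layer:
    """Toy causal layer: scale each position and add the mean of positions 0..p."""
    def layer(hs: List[Vector]) -> List[Vector]:
        return [[scale * x + sum(hs[q][j] for q in range(p + 1)) / (p + 1)
                 for j, x in enumerate(h)]
                for p, h in enumerate(hs)]
    return layer


def forward(model: ToyModel, tokens: Sequence[int],
            patch: Optional[Tuple[int, int, Vector]] = None) -> List[List[Vector]]:
    """Return every layer's output; patch=(layer, position, vector) overwrites one state."""
    embed, layers = model
    hs = [list(embed[t]) for t in tokens]
    outputs = []
    for idx, layer in enumerate(layers):
        hs = layer(hs)
        if patch is not None and patch[0] == idx:
            hs[patch[1]] = list(patch[2])  # the intervention at (ℓ*, i*)
        outputs.append(hs)
    return outputs


def patchscope(source_model: ToyModel, S: Sequence[int], i: int, layer: int,
               target_model: ToyModel, T: Sequence[int], i_star: int, layer_star: int,
               f: Callable[[Vector], Vector] = lambda h: h) -> List[List[Vector]]:
    h = forward(source_model, S)[layer][i]  # read h_i^ℓ from M on S
    return forward(target_model, T, patch=(layer_star, i_star, f(h)))  # write f(h) into M* on T


model = ([[1.0, 0.0], [0.0, 1.0], [0.5, 0.5]],
         [make_layer(0.5), make_layer(2.0), make_layer(1.0)])
prompt = [0, 1, 2]
# With S = T, i = i*, M = M*, ℓ = ℓ* and f = identity, the result equals a plain forward pass.
assert patchscope(model, prompt, 1, 1, model, prompt, 1, 1) == forward(model, prompt)
```

The toy layers are causal, so a patch at (ℓ*, i*) only changes position i* and later positions in the layers after ℓ*.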

The simplest patchscope is defined by the following parameters:

  • S = T
  • i = i*
  • M = M*
  • ℓ = ℓ*
  • f = identity function

This is indistinguishable from an ordinary forward pass. The simplest configuration that does something interesting is the logit lens, where:

  • ℓ = range(L*)
  • ℓ* = L*

That is, we take the hidden representation from each layer of the source model and patch it into the final layer of the target model. This is useful for output-prediction tasks because it shows how the hidden representations from different layers of the source model contribute to the final output of the target model.
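Under the same toy assumptions, the logit lens can be sketched as a sweep over source layers. The `make_layer`, `forward` and `decode` helpers and the unembedding matrix are hypothetical and do not come from obvs. Because ℓ* is the final layer, only the unembedding runs after the patch. Each entry is therefore the token the final readout would predict from that intermediate state. The code uses 0-based indices, so the final layer is `len(layers) - 1`.

```python
from typing import Callable, List, Sequence, Tuple

Vector = List[float]
Layer = Callable[[List[Vector]], List[Vector]]
ToyModel = Tuple[List[Vector], List[Layer]]  # hypothetical: (embedding table, layers)


def make_layer(scale: float) -> Layer:
    """Toy causal layer: scale each position and add the mean of positions 0..p."""
    def layer(hs: List[Vector]) -> List[Vector]:
        return [[scale * x + sum(hs[q][j] for q in range(p + 1)) / (p + 1)
                 for j, x in enumerate(h)]
                for p, h in enumerate(hs)]
    return layer


def forward(model: ToyModel, tokens: Sequence[int], patch=None) -> List[List[Vector]]:
    """Return every layer's output; patch=(layer, position, vector) overwrites one state."""
    embed, layers = model
    hs = [list(embed[t]) for t in tokens]
    outputs = []
    for idx, layer in enumerate(layers):
        hs = layer(hs)
        if patch is not None and patch[0] == idx:
            hs[patch[1]] = list(patch[2])
        outputs.append(hs)
    return outputs


def decode(unembed: List[Vector], h: Vector) -> int:
    """Greedy readout: index of the largest logit."""
    logits = [sum(w * x for w, x in zip(row, h)) for row in unembed]
    return max(range(len(logits)), key=logits.__getitem__)


def logit_lens(model: ToyModel, unembed: List[Vector], tokens: Sequence[int],
               position: int = -1) -> List[int]:
    """S = T, i = i*, f = identity; ℓ sweeps every layer and ℓ* is the final one."""
    hidden = forward(model, tokens)
    last = len(model[1]) - 1  # 0-based index of the final layer
    preds = []
    for layer in range(len(model[1])):
        h = hidden[layer][position]
        # Nothing runs after the final layer, so only the unembedding sees the patch.
        patched = forward(model, tokens, patch=(last, position, h))
        preds.append(decode(unembed, patched[last][position]))
    return preds


model = ([[1.0, 0.0], [0.0, 1.0], [0.5, 0.5]],
         [make_layer(0.5), make_layer(-1.0), make_layer(1.0)])
unembed = [[1.0, 0.0], [0.0, 1.0], [-1.0, -1.0]]
print(logit_lens(model, unembed, [0, 1]))  # one predicted token id per source layer
```

The last entry always matches the model's normal prediction, because patching the final layer's own output changes nothing.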

Module Contents

Classes

SourceContext

Source context for the patchscope

TargetContext

Target context for the patchscope

ModelLoader

Patchscope

class obvs.patchscope.SourceContext

Source context for the patchscope

property prompt: str | torch.Tensor

The prompt

property text_prompt: str

The text prompt, either as given or generated from the soft prompt

property soft_prompt: torch.Tensor | None

The soft prompt input or None

_prompt: str | torch.Tensor
_text_prompt: str
_soft_prompt: torch.Tensor | None
prompt: str | torch.Tensor
position: collections.abc.Sequence[int] | None
layer: int
head: collections.abc.Sequence[int] | None
model_name: str = 'gpt2'
device: str
class obvs.patchscope.TargetContext

Bases: SourceContext

Target context for the patchscope. Its parameters are identical to those of the source context, with the addition of a mapping function and max_new_tokens to control the generation length.

mapping_function: collections.abc.Callable[[torch.Tensor], torch.Tensor]
max_new_tokens: int = 10
static from_source(source: SourceContext, mapping_function: collections.abc.Callable[[torch.Tensor], torch.Tensor] | None = None, max_new_tokens: int = 10) -> TargetContext

Construct a target context from the source context.
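A hypothetical stand-in built with plain dataclasses illustrates what from_source is described as doing: it copies the source fields and adds the target-only parameters. The field names and the defaults for model_name and max_new_tokens follow the documentation above. The defaults for position, layer, head and device, the identity default for the mapping function, and the `*Sketch` class names are assumptions of this sketch. None of them are the obvs API.

```python
from dataclasses import dataclass, fields
from typing import Callable, Optional, Sequence


def identity(x):
    """Default mapping function (an assumption of this sketch)."""
    return x


@dataclass
class SourceContextSketch:
    # Hypothetical stand-in for SourceContext; only the field names come from the docs.
    prompt: str
    position: Optional[Sequence[int]] = None  # assumed default
    layer: int = -1                           # assumed default
    head: Optional[Sequence[int]] = None      # assumed default
    model_name: str = "gpt2"
    device: str = "cpu"                       # assumed default


@dataclass
class TargetContextSketch(SourceContextSketch):
    mapping_function: Callable = identity
    max_new_tokens: int = 10

    @staticmethod
    def from_source(source: SourceContextSketch,
                    mapping_function: Optional[Callable] = None,
                    max_new_tokens: int = 10) -> "TargetContextSketch":
        # Copy every source field, then add the two target-only parameters.
        shared = {fld.name: getattr(source, fld.name) for fld in fields(source)}
        return TargetContextSketch(
            **shared,
            mapping_function=identity if mapping_function is None else mapping_function,
            max_new_tokens=max_new_tokens,
        )


src = SourceContextSketch(prompt="The capital of France is", layer=3, position=[5])
tgt = TargetContextSketch.from_source(src, max_new_tokens=5)
print(tgt.layer, tgt.max_new_tokens, tgt.mapping_function("h"))  # 3 5 h
```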

class obvs.patchscope.ModelLoader
static load(model_name: str, device: str) -> nnsight.LanguageModel
class obvs.patchscope.Patchscope(source: SourceContext, target: TargetContext)

Bases: obvs.patchscope_base.PatchscopeBase

REMOTE: bool = False
source_forward_pass() -> None

Get the source representation.

We use the ‘trace’ context so that the REMOTE option can be passed.

Layer names differ between architectures, so they must be known for each model.

manipulate_source() -> torch.Tensor

Get the hidden state from the source representation.

NB: This is separated out from the source_forward_pass method to allow for batching.

map() -> None

Apply the mapping function to the source representation.
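When M and M* are different models, their hidden sizes may differ, and f then has to do more than pass the vector through. One possible choice is an affine map f(h) = W h + b, shown here in plain Python purely as an assumption. `make_affine_map` and the example weights are hypothetical. In practice, W and b would typically be fitted for a specific model pair.

```python
from typing import Callable, List

Vector = List[float]


def make_affine_map(weight: List[Vector], bias: Vector) -> Callable[[Vector], Vector]:
    """Hypothetical helper: return f(h) = W h + b, mapping d_source -> len(weight) dims."""
    if len(weight) != len(bias):
        raise ValueError("W needs one row per target dimension, matching len(bias)")

    def f(h: Vector) -> Vector:
        if any(len(row) != len(h) for row in weight):
            raise ValueError("each row of W must have d_source entries")
        # Row-by-row dot product plus bias gives one target coordinate per row.
        return [sum(w * x for w, x in zip(row, h)) + b for row, b in zip(weight, bias)]

    return f


# Example: lift a 2-dimensional source state into a 3-dimensional target space.
f = make_affine_map([[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]], [0.0, 0.0, 1.0])
print(f([2.0, 3.0]))  # [2.0, 3.0, 6.0]
```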

target_forward_pass() -> None

Patch the target representation. To support multi-token generation, we save the output for max_new_tokens iterations.

We use the ‘generate’ context, which supports remote operation and multi-token generation.

Layer names differ between architectures, so they must be known for each model.
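The multi-token behaviour described for target_forward_pass can be illustrated with a toy greedy decoder that keeps one output per generated token for max_new_tokens steps. This is a hypothetical sketch, not nnsight's ‘generate’ context. The toy model, `forward` and `generate_with_patch` are illustrative names, and the patch is given as (layer, position, vector).

```python
from typing import List, Optional, Sequence, Tuple

Vector = List[float]


def make_layer(scale: float):
    """Toy causal layer: scale each position and add the mean of positions 0..p."""
    def layer(hs: List[Vector]) -> List[Vector]:
        return [[scale * x + sum(hs[q][j] for q in range(p + 1)) / (p + 1)
                 for j, x in enumerate(h)]
                for p, h in enumerate(hs)]
    return layer


def forward(embed, layers, tokens, patch=None) -> List[List[Vector]]:
    """Return every layer's output; patch=(layer, position, vector) overwrites one state."""
    hs = [list(embed[t]) for t in tokens]
    outputs = []
    for idx, layer in enumerate(layers):
        hs = layer(hs)
        if patch is not None and patch[0] == idx:
            hs[patch[1]] = list(patch[2])
        outputs.append(hs)
    return outputs


def generate_with_patch(embed, layers, unembed, prompt: Sequence[int],
                        patch: Optional[Tuple[int, int, Vector]],
                        max_new_tokens: int) -> List[int]:
    """Greedy decoding that saves one output per step, max_new_tokens steps in total.

    The toy has no KV cache, so the full sequence is re-run each step and the patch
    is re-applied. Its position must be a non-negative index into the prompt; because
    the toy layers are causal, this matches patching once and caching the result.
    """
    seq = list(prompt)
    new_tokens = []
    for _ in range(max_new_tokens):
        h_last = forward(embed, layers, seq, patch)[-1][-1]
        logits = [sum(w * x for w, x in zip(row, h_last)) for row in unembed]
        next_token = max(range(len(logits)), key=logits.__getitem__)
        new_tokens.append(next_token)
        seq.append(next_token)
    return new_tokens


embed = [[1.0, 0.0], [0.0, 1.0], [0.5, 0.5]]
layers = [make_layer(0.5), make_layer(2.0)]
unembed = [[1.0, 0.0], [0.0, 1.0], [0.6, 0.6]]
print(generate_with_patch(embed, layers, unembed, [0, 1], (0, 0, [0.0, 3.0]), 4))
```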

manipulate_target() -> None
check_patchscope_setup() -> bool

Check that the patchscope is correctly set up before running.

run() -> None

Run the patchscope.

clear() -> None

Clear the outputs and the cache.

over(source_layers: collections.abc.Sequence[int], target_layers: collections.abc.Sequence[int]) -> list[torch.Tensor]

Run the patchscope over every combination of the specified source and target layers.

Parameters:
  • source_layers – A list of layer indices or a range of layer indices.

  • target_layers – A list of layer indices or a range of layer indices.

Returns:

A source_layers x target_layers x max_new_tokens list of outputs.
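The nesting of the result can be illustrated with a sketch in which a stand-in function replaces a single patchscope run. `over_grid` and `fake_run` are hypothetical names. The sketch only shows the shape: one entry per source layer, each holding one entry per target layer, each holding max_new_tokens outputs.

```python
from typing import Callable, List, Sequence, TypeVar

Out = TypeVar("Out")


def over_grid(run_once: Callable[[int, int], Out],
              source_layers: Sequence[int],
              target_layers: Sequence[int]) -> List[List[Out]]:
    """Hypothetical sketch: one run per (source layer, target layer) combination."""
    return [[run_once(s, t) for t in target_layers] for s in source_layers]


max_new_tokens = 3


def fake_run(s: int, t: int) -> List[str]:
    # Stand-in for one patchscope run that yields max_new_tokens outputs.
    return [f"src{s}->tgt{t} step{k}" for k in range(max_new_tokens)]


grid = over_grid(fake_run, range(2), [4, 5, 6])
print(len(grid), len(grid[0]), len(grid[0][0]))  # 2 3 3
```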

over_pairs(source_layers: collections.abc.Sequence[int], target_layers: collections.abc.Sequence[int]) -> list[torch.Tensor]

Run the patchscope over the specified set of layers in pairs.

Parameters:
  • source_layers – A list of layer indices or a range of layer indices.

  • target_layers – A list of layer indices or a range of layer indices.

Returns:

A source_layers x target_layers x max_new_tokens list of outputs.
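How the pairs are formed is not specified further here. The sketch below assumes element-wise pairing, matching the k-th source layer with the k-th target layer as zip does. Under that assumption, each pair yields one run of max_new_tokens outputs. `over_pairs_zip` and `fake_run` are hypothetical names.

```python
from typing import Callable, List, Sequence, TypeVar

Out = TypeVar("Out")


def over_pairs_zip(run_once: Callable[[int, int], Out],
                   source_layers: Sequence[int],
                   target_layers: Sequence[int]) -> List[Out]:
    """Hypothetical sketch: pair the k-th source layer with the k-th target layer."""
    if len(source_layers) != len(target_layers):
        raise ValueError("source_layers and target_layers must have the same length")
    return [run_once(s, t) for s, t in zip(source_layers, target_layers)]


def fake_run(s: int, t: int) -> List[str]:
    # Stand-in for one patchscope run producing max_new_tokens = 2 outputs.
    return [f"src{s}->tgt{t} step{k}" for k in range(2)]


pairs = over_pairs_zip(fake_run, range(3), range(3))
print(pairs[1])  # ['src1->tgt1 step0', 'src1->tgt1 step1']
```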