obvs.patchscope
🩺 Patchscope Module.
Implementation of the patchscopes framework: https://arxiv.org/abs/2401.06102
Patchscopes are a framework for inspecting how large language models (LLMs) work internally. They use a model's own ability to generate human-readable text to interpret and explain what is encoded in its hidden layers. By using one language model (the target model) to decode the hidden representations of another (the source model), which may be the same model, patchscopes unify several existing interpretability techniques and open up new applications.
This is the main patchscopes module. It builds on the abstract base class PatchscopeBase (obvs.patchscope_base.PatchscopeBase).
PatchscopeBase takes a source representation defined by (S, i, M, ℓ), which identifies where the original hidden representation is drawn from:
- S is the source input sequence.
- i is the position within that sequence. NB: we extend the method to allow a range of positions.
- M is the original model that processes the sequence.
- ℓ is the layer in model M from which the hidden representation is taken.
It then patches this representation into a target context defined by (T, i*, f, M*, ℓ*), which specifies the intervention (the patching operation):
- T is the target prompt, which may be the same as the source prompt S or different.
- i* is the position in the target prompt that receives the patched representation. NB: we extend the method to allow a range of positions.
- f is the mapping function applied to the hidden representation before it is patched into the target context. It can be the identity function or a more complex transformation.
- M* is the model (the same as M or a different one) in which the patching is performed.
- ℓ* is the layer in the target model M* whose hidden representation $\bar{h}_{i^*}^{\ell^*}$ is replaced with the patched version.
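The two tuples can be illustrated with a toy, dependency-free sketch. Everything in it is hypothetical: the "model" is a list of plain Python functions acting on one vector per position (each position also sees the previous one, loosely mimicking causal attention), and layers are 0-indexed. `source_representation` reads the hidden state at (S, i, M, ℓ); `patched_forward` runs T and overwrites the output of layer ℓ* at position i* with f applied to that state. This is not how obvs implements patching, which operates on nnsight models.

```python
from typing import Callable, List

Vector = List[float]
Hidden = List[Vector]  # one vector per sequence position


def make_toy_model(num_layers: int) -> List[Callable[[Hidden], Hidden]]:
    """Hypothetical toy model: layer k adds half of the previous position's
    vector (a crude stand-in for causal attention) plus a bias of k + 1."""
    def make_layer(k: int) -> Callable[[Hidden], Hidden]:
        def layer(hidden: Hidden) -> Hidden:
            out = []
            for pos, vec in enumerate(hidden):
                prev = hidden[pos - 1] if pos > 0 else [0.0] * len(vec)
                out.append([x + 0.5 * p + (k + 1) for x, p in zip(vec, prev)])
            return out
        return layer
    return [make_layer(k) for k in range(num_layers)]


def source_representation(model, source: Hidden, position: int, layer: int) -> Vector:
    """(S, i, M, ℓ): run M on S and return the output of layer ℓ at position i."""
    hidden = source
    for idx, block in enumerate(model):
        hidden = block(hidden)
        if idx == layer:
            return list(hidden[position])
    raise IndexError(f"layer {layer} out of range")


def patched_forward(model, target: Hidden, position: int, layer: int, patch: Vector) -> Hidden:
    """(T, i*, M*, ℓ*): run M* on T, overwriting layer ℓ*'s output at position i*."""
    hidden = target
    for idx, block in enumerate(model):
        hidden = block(hidden)  # each layer returns a new list, so T is not mutated
        if idx == layer:
            hidden[position] = list(patch)
    return hidden


def double(vec: Vector) -> Vector:
    """Example mapping function f (the identity is the simplest choice)."""
    return [2.0 * x for x in vec]


model = make_toy_model(num_layers=4)  # here M = M*
S = [[1.0, 0.0], [0.0, 1.0], [2.0, 2.0]]
T = [[0.0, 0.0], [0.0, 0.0]]
h = source_representation(model, S, position=2, layer=1)
out = patched_forward(model, T, position=1, layer=2, patch=double(h))
print(h)    # [5.75, 6.5]
print(out)  # [[10.0, 10.0], [18.5, 20.0]]
```

Only layers after ℓ* see the patched vector, and position 0 of the target is unaffected because information in this toy flows only from earlier to later positions.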
The simplest patchscope is defined by the following parameters:
S = T, i = i*, M = M*, ℓ = ℓ*, f = identity function
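With these parameters the patch writes back exactly the value that was read, so the output should match an unpatched run. A toy check of that reasoning, where the list-of-functions "model" and the helpers `make_toy_model` and `run` are hypothetical and unrelated to the obvs API:

```python
from typing import Callable, List, Optional, Tuple

Vector = List[float]
Hidden = List[Vector]


def make_toy_model(num_layers: int) -> List[Callable[[Hidden], Hidden]]:
    # Hypothetical toy layers: mix in half of the previous position, add bias k.
    def make_layer(k: int) -> Callable[[Hidden], Hidden]:
        def layer(hidden: Hidden) -> Hidden:
            return [
                [x + 0.5 * p + k for x, p in zip(vec, hidden[pos - 1] if pos else [0.0] * len(vec))]
                for pos, vec in enumerate(hidden)
            ]
        return layer
    return [make_layer(k) for k in range(num_layers)]


def run(model, inputs: Hidden, patch: Optional[Tuple[int, int, Vector]] = None) -> List[Hidden]:
    """Return every layer's output; patch=(layer, position, vector) overwrites one slot."""
    outputs, hidden = [], inputs
    for idx, block in enumerate(model):
        hidden = block(hidden)
        if patch is not None and idx == patch[0]:
            hidden[patch[1]] = list(patch[2])
        outputs.append(hidden)
    return outputs


def identity(vec: Vector) -> Vector:
    return vec


model = make_toy_model(num_layers=3)
S = [[1.0, 2.0], [3.0, 4.0]]
layer, position = 1, 1
h = run(model, S)[layer][position]                                 # (S, i, M, ℓ)
patched = run(model, S, patch=(layer, position, identity(h)))[-1]  # T = S, i* = i, ℓ* = ℓ
plain = run(model, S)[-1]
print(patched == plain)  # True
```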
This configuration is indistinguishable from an ordinary forward pass, because the representation is patched back into exactly the place it was taken from. The simplest configuration that does something interesting is the logit lens, where:
ℓ = range(L*), ℓ* = L*
That is, we take the hidden representation from each layer of the source model and patch it into the final layer of the target model. This is useful for output prediction tasks, because it shows which prediction each intermediate representation of the source model already encodes, and therefore how the final output takes shape across layers.
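A toy illustration of the logit lens as a patchscope: each source layer's representation is patched into the last layer of the target run, so only the unembedding remains to be applied. The three-layer "model", its vocabulary and the identity unembedding are all hypothetical, chosen so that the prediction visibly changes with depth; layers are 0-indexed, so the final layer is `NUM_LAYERS - 1`.

```python
from typing import List, Optional, Tuple

Vector = List[float]
VOCAB = ["cat", "dog", "fish"]
# Hypothetical toy model: layer k adds a fixed direction to every position.
DIRECTIONS = [[1.0, 0.0, 0.0], [0.0, 2.0, 0.0], [0.0, 0.0, 4.0]]
NUM_LAYERS = len(DIRECTIONS)


def run_layers(hidden: List[Vector],
               patch: Optional[Tuple[int, int, Vector]] = None) -> List[List[Vector]]:
    """Return every layer's output; patch=(layer, position, vector) overwrites one slot."""
    outputs = []
    for k, direction in enumerate(DIRECTIONS):
        hidden = [[x + d for x, d in zip(vec, direction)] for vec in hidden]
        if patch is not None and patch[0] == k:
            hidden[patch[1]] = list(patch[2])
        outputs.append(hidden)
    return outputs


def unembed(vec: Vector) -> str:
    """Toy identity unembedding: logit j is coordinate j; return the argmax token."""
    return VOCAB[max(range(len(vec)), key=lambda j: vec[j])]


def logit_lens(source: List[Vector], position: int) -> List[str]:
    source_outputs = run_layers(source)
    target = [[0.0, 0.0, 0.0]]       # its content at the patched slot is overwritten
    last = NUM_LAYERS - 1            # ℓ* = final layer
    predictions = []
    for layer in range(NUM_LAYERS):  # ℓ ranges over every source layer
        h = source_outputs[layer][position]
        target_outputs = run_layers(target, patch=(last, 0, h))  # f = identity, i* = 0
        predictions.append(unembed(target_outputs[last][0]))
    return predictions


print(logit_lens([[0.0, 0.0, 0.0]], position=0))  # ['cat', 'dog', 'fish']
```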
Module Contents
Classes
SourceContext | Source context for the patchscope
TargetContext | Target context for the patchscope
- class obvs.patchscope.SourceContext
Source context for the patchscope
- property prompt: str | torch.Tensor
The prompt, either a text string or a soft-prompt tensor
- property text_prompt: str
The text prompt, either given as input or generated from the soft prompt
- property soft_prompt: torch.Tensor | None
The soft prompt input, or None
- _prompt: str | torch.Tensor
- _text_prompt: str
- _soft_prompt: torch.Tensor | None
- prompt: str | torch.Tensor
- position: collections.abc.Sequence[int] | None
- layer: int
- head: collections.abc.Sequence[int] | None
- model_name: str = 'gpt2'
- device: str
- class obvs.patchscope.TargetContext
Bases: SourceContext
Target context for the patchscope. Its parameters are identical to those of the source context, with the addition of a mapping function and max_new_tokens, which controls the generation length.
- mapping_function: collections.abc.Callable[[torch.Tensor], torch.Tensor]
- max_new_tokens: int = 10
- static from_source(source: SourceContext, mapping_function: collections.abc.Callable[[torch.Tensor], torch.Tensor] | None = None, max_new_tokens: int = 10) → TargetContext
Construct a target context from the source context
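A minimal sketch of what `from_source` plausibly does, assuming both contexts are dataclasses: copy every source field, then add the target-only parameters, treating a `None` mapping function as the identity. `SourceContextSketch` and `TargetContextSketch` are hypothetical stand-ins that mirror the documented fields with torch-free types; the defaults for `layer` and `device` are placeholders, not the library's.

```python
from collections.abc import Callable, Sequence
from dataclasses import dataclass, fields
from typing import Any, Optional


def identity(x: Any) -> Any:
    return x


@dataclass
class SourceContextSketch:
    """Hypothetical stand-in mirroring the documented SourceContext fields."""
    prompt: Any                               # str or a soft-prompt tensor
    position: Optional[Sequence[int]] = None
    layer: int = -1                           # placeholder default
    head: Optional[Sequence[int]] = None
    model_name: str = "gpt2"
    device: str = "cpu"                       # placeholder default


@dataclass
class TargetContextSketch(SourceContextSketch):
    """Adds the target-only parameters documented for TargetContext."""
    mapping_function: Callable[[Any], Any] = identity
    max_new_tokens: int = 10

    @staticmethod
    def from_source(source: SourceContextSketch,
                    mapping_function: Optional[Callable[[Any], Any]] = None,
                    max_new_tokens: int = 10) -> "TargetContextSketch":
        # Copy every field defined on the source context...
        shared = {f.name: getattr(source, f.name) for f in fields(SourceContextSketch)}
        # ...then add the target-only parameters (assumption: None means identity).
        return TargetContextSketch(
            **shared,
            mapping_function=mapping_function if mapping_function is not None else identity,
            max_new_tokens=max_new_tokens,
        )


source = SourceContextSketch(prompt="The capital of France is", position=[-1], layer=6)
target = TargetContextSketch.from_source(source, max_new_tokens=5)
print(target.layer, target.max_new_tokens, target.mapping_function(3))  # 6 5 3
```

Starting from the source context this way makes the simplest patchscope (T = S, ℓ* = ℓ, f = identity) the default, and any field can then be changed on the target alone.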
- class obvs.patchscope.ModelLoader
- static load(model_name: str, device: str) → nnsight.LanguageModel
- class obvs.patchscope.Patchscope(source: SourceContext, target: TargetContext)
Bases: obvs.patchscope_base.PatchscopeBase
- REMOTE: bool = False
- source_forward_pass() → None
Get the source representation.
We use the "trace" context so that the REMOTE option can be applied.
This requires knowing the names of the layer modules for each model architecture.
- manipulate_source() → torch.Tensor
Get the hidden state from the source representation.
NB: This is separated out from the source_forward_pass method to allow for batching.
- map() → None
Apply the mapping function to the source representation
- target_forward_pass() → None
Patch the target representation. To support multi-token generation, we save the output for max_new_tokens iterations.
We use the "generate" context, which supports remote operation and multi-token generation.
This requires knowing the names of the layer modules for each model architecture.
- manipulate_target() → None
- check_patchscope_setup() → bool
Check whether the patchscope is correctly set up before running
- run() → None
Run the patchscope
- clear() → None
Clear the outputs and the cache
- over(source_layers: collections.abc.Sequence[int], target_layers: collections.abc.Sequence[int]) → list[torch.Tensor]
Run the patchscope over the specified set of layers.
- Parameters:
source_layers – A list of layer indices or a range of layer indices.
target_layers – A list of layer indices or a range of layer indices.
- Returns:
A source_layers x target_layers x max_new_tokens list of outputs.
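The documented return shape (source_layers x target_layers x max_new_tokens) suggests that `over` visits every combination of a source layer and a target layer. A hedged sketch of that iteration order, where the hypothetical `run_one(s, t)` stands in for configuring both layers, calling `run()` and collecting the generated tokens:

```python
from collections.abc import Callable, Sequence
from typing import Any, List


def over_sketch(run_one: Callable[[int, int], Any],
                source_layers: Sequence[int],
                target_layers: Sequence[int]) -> List[List[Any]]:
    """Hypothetical sketch: one result per (source layer, target layer) combination."""
    return [[run_one(s, t) for t in target_layers] for s in source_layers]


# Stand-in runner that just records which layers were used.
grid = over_sketch(lambda s, t: (s, t), range(3), [5, 11])
print(len(grid), len(grid[0]))  # 3 2
print(grid[2][1])               # (2, 11)
```

The outer index is the source layer and the inner index the target layer; in the real method each cell would additionally hold up to max_new_tokens generated outputs.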
- over_pairs(source_layers: collections.abc.Sequence[int], target_layers: collections.abc.Sequence[int]) → list[torch.Tensor]
Run the patchscope over the specified set of layers in pairs.
- Parameters:
source_layers – A list of layer indices or a range of layer indices.
target_layers – A list of layer indices or a range of layer indices.
- Returns:
A source_layers x target_layers x max_new_tokens list of outputs.
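If "in pairs" means that source_layers[k] is patched into target_layers[k] (an assumption; the docstring gives the same return shape as for `over`), the loop would zip the two sequences rather than take every combination. A sketch under that assumption, again with a hypothetical `run_one(s, t)` standing in for a single patchscope run:

```python
from collections.abc import Callable, Sequence
from typing import Any, List


def over_pairs_sketch(run_one: Callable[[int, int], Any],
                      source_layers: Sequence[int],
                      target_layers: Sequence[int]) -> List[Any]:
    """Hypothetical sketch: one result per (source_layers[k], target_layers[k]) pair."""
    if len(source_layers) != len(target_layers):
        raise ValueError("source_layers and target_layers must have the same length")
    return [run_one(s, t) for s, t in zip(source_layers, target_layers)]


# Typical use: patch each layer into the same layer of the target model.
results = over_pairs_sketch(lambda s, t: (s, t), range(4), range(4))
print(results)  # [(0, 0), (1, 1), (2, 2), (3, 3)]
```

Under this reading the work grows linearly with the number of layers instead of quadratically, which is the usual reason to prefer pairs when only ℓ = ℓ* (or another fixed offset) is of interest.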