base
Base class for packaged Stained Glass Transforms.
Classes:
| Name | Description |
|---|---|
| StainedGlassTransform | Base class for all packaged Stained Glass Transforms. |
StainedGlassTransform
Bases: Module, ABC, ModelHubMixin, Generic[NoisyModelT, NoiseLayerT]
Base class for all packaged Stained Glass Transforms.
This class provides a common, simple interface for performing inference with pretrained Stained Glass Transforms, as well as for managing, loading, and saving them on the Hugging Face Hub and on disk. It also provides utilities for inferring the minimal parameters of the client and for optimizing loading and saving of the Stained Glass Transform by loading and saving only the necessary submodules.
Methods:
| Name | Description |
|---|---|
| __getstate__ | Return a json-serializable dictionary representing the state of the Stained Glass Transform. |
| __init__ | Initialize the Stained Glass Transform base client. |
| __repr__ | Safe string representation of the Stained Glass Transform module. |
| __setstate__ | Set the state of the Stained Glass Transform from a dictionary. |
| forward | Apply the Stained Glass Transform to an input. |
| from_pretrained | Load the client from the given path. |
| generate_model_card | Generate model card from instance model card metadata and class templates. |
| infer_minimal_parameters | Infer the minimal parameters of the client, excluding parameters not needed for the client. |
| infer_minimal_submodules | Infer the minimal set of submodules that are needed to apply the Stained Glass Transform. |
| manual_seed | Set seed to enable/disable reproducible behavior. |
| override_runtime_config_during_load | Override any attributes of the runtime config during loading of the Stained Glass Transform. |
| push_to_hub | Upload model checkpoint to the Hub. |
| save_pretrained | Save the client to the given path. |
| state_dict | Get the state dictionary of the client, excluding parameters not needed for the client. |
Attributes:
| Name | Type | Description |
|---|---|---|
| noise_layer | NoiseLayerT | Alias for the contained noise layer. |
| noisy_model | NoisyModelT | Alias for the contained NoisyModel. |
| parameter_names_relative_to_client | list[str] | Get the minimal parameters of the client, excluding parameters not needed for the client. |
| parameter_names_to_remove_relative_to_client | list[str] | Get the parameters to ignore when saving the client, that is, the parameters not needed by the client. |
| stainedglass_core_version | str \| None | Get the version of Stained Glass Core used to save the Stained Glass Transform. |
noisy_model
property
Alias for the contained NoisyModel.
Warning
A deserialized Stained Glass Transform may not include the complete base-model parameters. Calling the underlying noisy model referenced in this property can fail.
parameter_names_relative_to_client
property
Get the minimal parameters of the client, excluding parameters not needed for the client.
This property will first check if self.parameter_names_relative_to_base_model is set (this is usually set via the
parameter_names argument in the __init__ method). If it is, then it will return the parameters defined there, but with the
submodule names changed to be relative to the client.
If self.parameter_names_relative_to_base_model is not set, then it will return the parameters inferred by the
infer_minimal_parameters method's most recent call. This requires that the infer_minimal_parameters method has been called at
least once before accessing this property.
Note
self.parameter_names_relative_to_base_model, if specified, will override the inferred parameters in calculating this property.
Returns:
| Type | Description |
|---|---|
| list[str] | The minimal parameters of the client, excluding parameters not needed for the client. |
Raises:
| Type | Description |
|---|---|
| ValueError | If the minimal parameters of the base model have not been specified manually or inferred automatically. |
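The lookup order above can be sketched in plain Python. This is a minimal illustration, not the library's implementation. `resolve_client_parameter_names` and the `BASE_MODEL_PREFIX` used to re-root base-model names under the client are hypothetical, and the sketch assumes that inferred names are already relative to the client.

```python
from __future__ import annotations

# Hypothetical: where the base model sits inside the client. The real attribute
# path used by StainedGlassTransform may differ.
BASE_MODEL_PREFIX = "noisy_model.base_model."


def resolve_client_parameter_names(
    names_relative_to_base_model: list[str] | None,
    inferred_names_relative_to_client: list[str] | None,
) -> list[str]:
    """Follow the lookup order described for parameter_names_relative_to_client."""
    if names_relative_to_base_model is not None:
        # Explicit names win over inferred ones; re-root them under the client.
        return [BASE_MODEL_PREFIX + name for name in names_relative_to_base_model]
    if inferred_names_relative_to_client is not None:
        # Otherwise reuse the result of the most recent inference call.
        return list(inferred_names_relative_to_client)
    raise ValueError(
        "The minimal parameters have not been specified manually or inferred automatically."
    )


print(resolve_client_parameter_names(["model.embed_tokens.weight"], None))
# ['noisy_model.base_model.model.embed_tokens.weight']
```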
parameter_names_to_remove_relative_to_client
property
Get the parameters to ignore when saving the client, that is, the parameters not needed by the client.
This is effectively the set of all parameters in the client that are not in parameter_names_relative_to_client, accounting for duplicate parameters that are shared by multiple modules (and can therefore be accessed by multiple names).
Returns:
| Type | Description |
|---|---|
| list[str] | The parameters to ignore when saving the client, that is, the parameters not needed by the client. |
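One reading of "accounting for duplicate parameters" is that a name is removed only when its underlying parameter object is not kept under any name, so tied weights stay with the client. The sketch below assumes that reading and uses plain objects in place of tensors. `parameters_to_remove` is a hypothetical helper, and the library may resolve aliases differently.

```python
from __future__ import annotations


def parameters_to_remove(
    named_parameters: dict[str, object], kept_names: list[str]
) -> list[str]:
    """Return names whose underlying parameter object is not kept under any name."""
    # Compare by object identity so a parameter shared by several modules
    # (for example, tied weights) counts as kept under every one of its names.
    kept_ids = {id(named_parameters[name]) for name in kept_names}
    return [
        name for name, param in named_parameters.items() if id(param) not in kept_ids
    ]


shared = object()  # stands in for a tied weight tensor
params = {
    "embed_tokens.weight": shared,
    "lm_head.weight": shared,  # same object, reachable by a second name
    "layers.0.weight": object(),
}
print(parameters_to_remove(params, ["embed_tokens.weight"]))  # ['layers.0.weight']
```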
stainedglass_core_version
property
stainedglass_core_version: str | None
Get the version of Stained Glass Core used to save the Stained Glass Transform.
Returns:
| Type | Description |
|---|---|
| str \| None | The version of Stained Glass Core used to save the Stained Glass Transform. |
__getstate__
abstractmethod
Return a json-serializable dictionary representing the state of the Stained Glass Transform.
The output of this method should be able to be passed into __setstate__ to reconstruct the Stained Glass Transform.
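The contract is that the JSON-serializable output of __getstate__ must reconstruct the object when passed to __setstate__. A toy class (hypothetical, unrelated to any real transform) shows the round trip through json:

```python
import json
from typing import Any


class ToyTransform:
    """Hypothetical stand-in illustrating the __getstate__/__setstate__ contract."""

    def __init__(self, name: str, scale: float) -> None:
        self.name = name
        self.scale = scale

    def __getstate__(self) -> dict[str, Any]:
        # Only plain JSON types, so the result survives json.dumps/json.loads.
        return {"name": self.name, "scale": self.scale}

    def __setstate__(self, state: dict[str, Any], **kwargs: Any) -> None:
        self.name = state["name"]
        self.scale = state["scale"]


original = ToyTransform("demo", 0.5)
payload = json.dumps(original.__getstate__())

restored = ToyTransform.__new__(ToyTransform)  # skip __init__, as unpickling does
restored.__setstate__(json.loads(payload))
print(restored.name, restored.scale)  # demo 0.5
```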
__init__
__init__(
model: ~NoisyModelT,
noise_layer_type: type[~NoiseLayerT],
parameter_names: list[str] | None = None,
include_all_base_model_params: bool = False,
name: str | None = None,
model_card_data: ModelCardData | None = None,
) -> None
Initialize the Stained Glass Transform base client.
Warning
The constructor will automatically infer the minimal base model parameters required to calculate the base model's input
embeddings. This requires a forward pass and assumes the model has a static computational graph. If you want to manually specify
the minimal parameters, you can pass in the parameter_names argument. Note, however, that you must specify all of the
parameters necessary to calculate the base model's input embeddings.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| model | ~NoisyModelT | The NoisyModel used to train the Stained Glass Transform. | required |
| noise_layer_type | type[~NoiseLayerT] | The type of the noise layer used in the Stained Glass Transform. This is primarily used for type safety. | required |
| parameter_names | list[str] \| None | Parameters of the base model to be saved and loaded during serialization and deserialization. This should be the minimal list of parameters necessary to get the base model's input embeddings. If None, the minimal parameters are inferred automatically (see the warning above). | None |
| include_all_base_model_params | bool | Whether to include all base model parameters in the client. If … | False |
| name | str \| None | The name of the StainedGlassTransform. This is used to identify the transform when saving and loading. | None |
| model_card_data | ModelCardData \| None | Optional model card data to associate with the Stained Glass Transform. Useful for providing metadata when sharing the transform on the Hugging Face Hub. Follow the documentation on Model Cards and … | None |
__repr__
Safe string representation of the Stained Glass Transform module.
The Stained Glass Transform contains submodules whose weights are None because they are not needed for the SGT computation. The standard nn.Module.__repr__ calls extra_repr on each submodule, which may try to access those weight attributes and raise an AttributeError. This override temporarily patches nn.Module.__repr__ to handle those cases gracefully for the duration of the call.
Returns:
| Type | Description |
|---|---|
| str | A string representation of the Stained Glass Transform module. |
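The patch-and-restore pattern described above can be illustrated without PyTorch. In this sketch, `ToyModule`, `ToyLinear`, and `tolerant_module_repr` are hypothetical stand-ins for nn.Module, a submodule whose weight is None, and the override. The real override may format its fallback text differently.

```python
from __future__ import annotations

import contextlib
from collections.abc import Iterator
from types import SimpleNamespace


class ToyModule:
    """Hypothetical stand-in for the part of nn.Module that builds its repr."""

    def extra_repr(self) -> str:
        return ""

    def __repr__(self) -> str:
        return f"{type(self).__name__}({self.extra_repr()})"


class ToyLinear(ToyModule):
    def __init__(self, weight: SimpleNamespace | None) -> None:
        self.weight = weight

    def extra_repr(self) -> str:
        # Raises AttributeError when the weight was dropped (None).
        return f"shape={self.weight.shape}"


@contextlib.contextmanager
def tolerant_module_repr() -> Iterator[None]:
    """Temporarily patch ToyModule.__repr__ so it survives AttributeError."""
    original_repr = ToyModule.__repr__

    def safe_repr(self: ToyModule) -> str:
        try:
            return original_repr(self)
        except AttributeError:
            return f"{type(self).__name__}(<weights not loaded>)"

    ToyModule.__repr__ = safe_repr  # type: ignore[method-assign]
    try:
        yield
    finally:
        ToyModule.__repr__ = original_repr  # always restore the original


stripped = ToyLinear(weight=None)
with tolerant_module_repr():
    print(repr(stripped))  # ToyLinear(<weights not loaded>)
```

Restoring the original in a `finally` clause keeps the patch scoped to the call, even if formatting raises some other error.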
__setstate__
abstractmethod
Set the state of the Stained Glass Transform from a dictionary.
The input dictionary should be in the format produced by __getstate__.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
|  | dict[str, Any] | A dictionary representing the state of the Stained Glass Transform. | required |
| **kwargs | Any | Additional keyword arguments that may be needed to set the state. | required |
forward
abstractmethod
Apply the Stained Glass Transform to an input.
This method should be implemented by subclasses to define the forward pass of the Stained Glass Transform.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| *args | Any | Positional arguments for the forward pass. | required |
| **kwargs | Any | Keyword arguments for the forward pass. | required |
Returns:
| Type | Description |
|---|---|
| typing.Any | The output of the forward pass. |
from_pretrained
classmethod
from_pretrained(
pretrained_model_name_or_path: str | Path,
map_location: device | str | None = None,
index_file_name: str | None = None,
dtype: str | dtype | None = None,
noise_layer_attention: (
Literal[
"sdpa",
"flash_attention_2",
"flex_attention",
"transformers_default",
]
| None
) = None,
third_party_model_path: (
str | PathLike[str] | None
) = None,
*,
force_download: bool = False,
resume_download: bool | None = None,
proxies: bool | dict[Any, Any] | None = None,
token: str | bool | None = None,
cache_dir: str | Path | None = None,
local_files_only: bool = False,
revision: str | None = None,
trust_remote_code: bool = False,
**model_kwargs: Any
) -> typing.Self
Load the client from the given path.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| pretrained_model_name_or_path | str \| Path | The path to load the client from. This can be a path to a … | required |
| map_location | device \| str \| None | The location to map the client to. See torch.device for more information. | None |
| index_file_name | str \| None | The name of the index file to use within the zipfile. If None, the default index file name will be used. | None |
| dtype | str \| dtype \| None | The dtype, either as a string or a torch.dtype … | None |
| noise_layer_attention | Literal['sdpa', 'flash_attention_2', 'flex_attention', 'transformers_default'] \| None | The attention type to use for the noise layer. If None, the default attention type will be used. | None |
| third_party_model_path | str \| PathLike[str] \| None | The path or Hugging Face Hub reference to a third-party model to load. This is useful when loading SGTs whose internal structure depends on model implementations that cannot be imported directly from transformers but are present on the Hugging Face Hub. | None |
| force_download | bool | Whether to force the download of the client. If False, the client will be downloaded only if it is not already present in the cache. | False |
| resume_download | bool \| None | Unused. Required for compatibility with the Hugging Face Hub API. | None |
| proxies | bool \| dict[Any, Any] \| None | Unused. Required for compatibility with the Hugging Face Hub API. | None |
| token | str \| bool \| None | The token to use for authentication with the Hugging Face Hub API. | None |
| cache_dir | str \| Path \| None | The directory to use for caching the client. If None, the default cache directory will be used. | None |
| local_files_only | bool | Whether to only use local files and not attempt to download the client. If True, an error will be raised if the client is not present in the cache. | False |
| revision | str \| None | The revision of the client to use. This can be a branch name, tag name, or commit hash. If None, the default revision will be used. | None |
| trust_remote_code | bool | Whether to trust remote code when loading from the Hugging Face Hub. | False |
| **model_kwargs | Any | Unused. Required for compatibility with the Hugging Face Hub API. | required |
Returns:
| Type | Description |
|---|---|
| typing.Self | The loaded client. |
Raises:
| Type | Description |
|---|---|
| ValueError | If any … |
| IsADirectoryError | If the specified path is a directory, but a .sgt file path is required. |
generate_model_card
Generate model card from instance model card metadata and class templates.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| *args | Any | Positional arguments to huggingface_hub.ModelCard.from_template. Unused (because all arguments are passed by keyword). | required |
| **kwargs | Any | Keyword arguments to the template_str passed to huggingface_hub.ModelCard.from_template. | required |
Returns:
| Type | Description |
|---|---|
| huggingface_hub.repocard.ModelCard | Generated ModelCard object. |
Changed in version v2.8.0: Automatically generated model card files now respect instance model card metadata.
infer_minimal_parameters
Infer the minimal parameters of the client, excluding parameters not needed for the client.
This method will infer the minimal parameters of the client by tracing a forward pass through the model. This is useful when the minimal parameters are not known ahead of time.
Raises:
| Type | Description |
|---|---|
| ValueError | If the minimal parameters of the client have been specified … |
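Inferring parameters from a forward pass can be approximated by recording which parameters are read. This is a simplified stand-in that uses a dict-based toy model instead of a PyTorch module. `RecordingParameters` and `infer_minimal_parameter_names` are hypothetical and do not reflect how the library actually traces. Like any trace, this approach only sees the parameters touched along the executed path, which is why the constructor warning assumes a static computational graph.

```python
from __future__ import annotations

from collections.abc import Iterator, Mapping


class RecordingParameters(Mapping[str, float]):
    """Wrap a name -> value dict and remember which names were read."""

    def __init__(self, params: dict[str, float]) -> None:
        self._params = params
        self.accessed: list[str] = []

    def __getitem__(self, name: str) -> float:
        if name not in self.accessed:
            self.accessed.append(name)  # record first access, keeping order
        return self._params[name]

    def __iter__(self) -> Iterator[str]:
        return iter(self._params)

    def __len__(self) -> int:
        return len(self._params)


def toy_input_embeddings(params: Mapping[str, float], token_ids: list[int]) -> list[float]:
    # Toy "embedding lookup": only the embedding parameter is touched.
    return [params["embed_tokens.weight"] * token for token in token_ids]


def infer_minimal_parameter_names(params: dict[str, float]) -> list[str]:
    recorder = RecordingParameters(params)
    toy_input_embeddings(recorder, [1, 2, 3])  # a single traced forward pass
    return recorder.accessed


params = {"embed_tokens.weight": 0.5, "layers.0.weight": 1.0, "lm_head.weight": 2.0}
print(infer_minimal_parameter_names(params))  # ['embed_tokens.weight']
```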
infer_minimal_submodules
abstractmethod
Infer the minimal set of submodules that are needed to apply the Stained Glass Transform.
This method should return a list of submodule names that are needed to apply the Stained Glass Transform. This is used to optimize the loading and saving of the Stained Glass Transform by only loading and saving the necessary submodules.
Returns:
| Type | Description |
|---|---|
| list[str] | A list of submodule names that are needed to apply the Stained Glass Transform. |
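Because infer_minimal_submodules is abstract, each subclass decides how to compute its list. One simple, hypothetical heuristic maps each needed parameter name to the module that owns it by dropping the final attribute. `submodules_for_parameters` is not part of the library.

```python
from __future__ import annotations


def submodules_for_parameters(parameter_names: list[str]) -> list[str]:
    """Map parameter names to the submodules that own them, preserving order."""
    submodules: list[str] = []
    for name in parameter_names:
        # "model.embed_tokens.weight" -> "model.embed_tokens"
        owner, _, _ = name.rpartition(".")
        if owner not in submodules:
            submodules.append(owner)
    return submodules


print(submodules_for_parameters(["model.embed_tokens.weight", "model.norm.weight", "model.norm.bias"]))
# ['model.embed_tokens', 'model.norm']
```

A parameter registered directly on the root module maps to the empty string, which in PyTorch naming refers to the module itself.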
manual_seed
manual_seed(
seed: int | None, rank_dependent: bool = True
) -> None
Set seed to enable/disable reproducible behavior.
Setting seed to None will disable reproducible behavior.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| seed | int \| None | Value to seed into the random number generator. | required |
| rank_dependent | bool | Whether to add the distributed rank to the seed to ensure that each process samples different noise. | True |
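The manual_seed semantics can be sketched with Python's random module. This sketch assumes that the process rank comes from the RANK environment variable, which launchers such as torchrun set. The library may instead query torch.distributed, and `make_noise_rng` is a hypothetical helper rather than part of the API.

```python
from __future__ import annotations

import os
import random


def make_noise_rng(seed: int | None, rank_dependent: bool = True) -> random.Random:
    """Build an RNG following the manual_seed semantics described above."""
    if seed is None:
        # No seed: seed from system entropy, so results are not reproducible.
        return random.Random()
    # Assumption: the rank is read from RANK; default to 0 for single-process runs.
    rank = int(os.environ.get("RANK", "0"))
    # Adding the rank gives each process a different, but still reproducible, stream.
    effective_seed = seed + rank if rank_dependent else seed
    return random.Random(effective_seed)


print(make_noise_rng(42).random() == make_noise_rng(42).random())  # True
```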
override_runtime_config_during_load
abstractmethod
staticmethod
override_runtime_config_during_load(
runtime_config: dict[str, Any], **kwargs: Any
) -> dict[str, typing.Any]
Override any attributes of the runtime config during loading of the Stained Glass Transform.
This is normally used to apply backward-compatibility fixes to the output of __getstate__ or to embed additional kwargs into the config passed to __setstate__ during loading.
It is called during from_pretrained after the SGT config is loaded from file, but before the config is passed to __setstate__ to construct the SGT.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| runtime_config | dict[str, Any] | The runtime config loaded from the SGT config file. | required |
| **kwargs | Any | Additional keyword arguments that may be needed to override the runtime config. | required |
Returns:
| Type | Description |
|---|---|
| dict[str, typing.Any] | The overridden runtime config to be passed to __setstate__. |
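A subclass override typically returns a patched copy of the loaded config. In this sketch, the legacy key `noise_std`, the new key `noise_scale`, and the `device` keyword are all hypothetical examples of a backward-compatibility fix and an embedded load-time kwarg.

```python
from __future__ import annotations

import copy
from typing import Any


class ToyTransform:
    @staticmethod
    def override_runtime_config_during_load(
        runtime_config: dict[str, Any], **kwargs: Any
    ) -> dict[str, Any]:
        config = copy.deepcopy(runtime_config)  # leave the caller's dict untouched
        # Backward-compatibility fix: hypothetical older files used "noise_std".
        if "noise_std" in config:
            config["noise_scale"] = config.pop("noise_std")
        # Embed load-time kwargs so that __setstate__ receives them in the config.
        config.update(kwargs)
        return config


loaded = {"noise_std": 0.1}
print(ToyTransform.override_runtime_config_during_load(loaded, device="cpu"))
# {'noise_scale': 0.1, 'device': 'cpu'}
```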
push_to_hub
push_to_hub(
    repo_id: str,
    *,
    config: dict | DataclassInstance | None = None,
    commit_message: str = 'Upload using stainedglass_core.',
    private: bool | None = None,
    token: str | None = None,
    branch: str | None = None,
    create_pr: bool | None = None,
    allow_patterns: list[str] | str | None = None,
    ignore_patterns: list[str] | str | None = None,
    delete_patterns: list[str] | str | None = None,
    model_card_kwargs: dict[str, Any] | None = None,
) -> str
Upload model checkpoint to the Hub.
Warning
This method is currently not supported on StainedGlassTransform. Instead use save_pretrained with push_to_hub=True.
Use allow_patterns and ignore_patterns to precisely filter which files should be pushed to the Hub. Use delete_patterns to delete existing remote files in the same commit. See the huggingface_hub upload_folder reference for more details.
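Pattern filtering in huggingface_hub follows Unix shell-style wildcards. The helper below approximates the allow/ignore semantics with fnmatch. It is a hypothetical illustration only, since this method is not supported on StainedGlassTransform, and it may omit details of the real filtering, such as how directory patterns are handled.

```python
from __future__ import annotations

from fnmatch import fnmatch


def _as_list(patterns: list[str] | str | None) -> list[str] | None:
    if patterns is None:
        return None
    return [patterns] if isinstance(patterns, str) else list(patterns)


def filter_files(
    paths: list[str],
    allow_patterns: list[str] | str | None = None,
    ignore_patterns: list[str] | str | None = None,
) -> list[str]:
    allow = _as_list(allow_patterns)
    ignore = _as_list(ignore_patterns)
    selected = []
    for path in paths:
        # If allow patterns are given, a path must match at least one of them.
        if allow is not None and not any(fnmatch(path, p) for p in allow):
            continue
        # A path matching any ignore pattern is dropped.
        if ignore is not None and any(fnmatch(path, p) for p in ignore):
            continue
        selected.append(path)
    return selected


paths = ["model.sgt", "README.md", "logs/train.log"]
print(filter_files(paths, allow_patterns=["*.sgt", "*.md"]))  # ['model.sgt', 'README.md']
```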
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| repo_id | str | ID of the repository to push to (example: "username/my-model"). | required |
| config | dict \| DataclassInstance \| None | Model configuration specified as a key/value dictionary or a dataclass instance. | None |
| commit_message | str | Message to commit while pushing. | 'Upload using stainedglass_core.' |
| private | bool \| None | Whether the repository created should be private. If None (default), the repo will be public unless the organization's default is private. | None |
| token | str \| None | The token to use as HTTP bearer authorization for remote files. By default, it will use the token cached when running huggingface-cli login. | None |
| branch | str \| None | The git branch on which to push the model. This defaults to "main". | None |
| create_pr | bool \| None | Whether or not to create a Pull Request from branch with that commit. Defaults to False. | None |
| allow_patterns | list[str] \| str \| None | If provided, only files matching at least one pattern are pushed. | None |
| ignore_patterns | list[str] \| str \| None | If provided, files matching any of the patterns are not pushed. | None |
| delete_patterns | list[str] \| str \| None | If provided, remote files matching any of the patterns will be deleted from the repo. | None |
| model_card_kwargs | dict[str, Any] \| None | Additional arguments passed to the model card template to customize the model card. | None |
Returns:
| Type | Description |
|---|---|
| str | The URL of the commit of your model in the given repository. |
Raises:
| Type | Description |
|---|---|
| NotImplementedError | This method is not implemented. |
save_pretrained
save_pretrained(
save_directory: str | Path,
*,
compression: int = 8,
push_to_hub: bool = False,
repo_id: str | None = None,
private: bool = True,
config: dict | DataclassInstance | None = None,
model_card_kwargs: dict[str, Any] | None = None,
**push_to_hub_kwargs: Any
) -> None
Save the client to the given path.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| save_directory | str \| Path | The path to save the client to. Although this is called save_directory, it must be a file path, not a directory. | required |
| compression | int | The compression method to use for the ZIP file. Defaults to zipfile.ZIP_DEFLATED (8), but this can cause very slow serialization. If serialization time is a problem, use zipfile.ZIP_STORED instead. | 8 |
| push_to_hub | bool | Whether to push the client to the Hugging Face Hub. | False |
| repo_id | str \| None | The repository ID to push the client to. This is required if push_to_hub is True. | None |
| private | bool | Whether to make the repository private. This is only used if push_to_hub is True. | True |
| config | dict \| DataclassInstance \| None | Unused. Required for compatibility with the Hugging Face Hub API. | None |
| model_card_kwargs | dict[str, Any] \| None | The kwargs to pass to the model card generator. This is only used if push_to_hub is True. | None |
| **push_to_hub_kwargs | Any | The kwargs to pass to the … | required |
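The default compression=8 is the numeric value of zipfile.ZIP_DEFLATED. This standalone sketch compares archive sizes for the two modes on a synthetic payload. `zipped_size` and the member name are hypothetical. Real SGT weights will not compress as well as all-zero bytes, so the speed/size trade-off is best measured on an actual file.

```python
import io
import zipfile

# save_pretrained's default compression=8 is zipfile.ZIP_DEFLATED; 0 is ZIP_STORED.
print(zipfile.ZIP_DEFLATED, zipfile.ZIP_STORED)  # 8 0


def zipped_size(payload: bytes, compression: int) -> int:
    """Write payload into an in-memory ZIP archive and return the archive size in bytes."""
    buffer = io.BytesIO()
    with zipfile.ZipFile(buffer, mode="w", compression=compression) as archive:
        archive.writestr("weights.bin", payload)  # hypothetical member name
    return len(buffer.getvalue())


payload = bytes(1_000_000)  # one million zero bytes: compresses extremely well
stored = zipped_size(payload, zipfile.ZIP_STORED)
deflated = zipped_size(payload, zipfile.ZIP_DEFLATED)
print(deflated < stored)  # True
```

To favor speed over archive size, pass compression=zipfile.ZIP_STORED to save_pretrained.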
Raises:
| Type | Description |
|---|---|
| IsADirectoryError | If a directory is passed in. |
| ValueError | If … |
| UserWarning | If … |
Examples:
Uploading a Stained Glass Transform zipfile to the Hugging Face Hub (note that this will also create a local copy of the SGT zipfile):
>>> from stainedglass_core.transform import text
>>> sgt = text.StainedGlassTransformForText.from_pretrained(
... "path/to/sgt_file.sgt"
... )
>>> sgt.save_pretrained(
... "new-sgt-zipfile.sgt",
... push_to_hub=True,
... repo_id="username/new-sgt-repo",
... )
Optionally, you can override any model card metadata before uploading to the Hub. This can be useful for specifying the base model and datasets used for training the Stained Glass Transform. You can also specify additional metadata such as eval_results. See huggingface_hub.ModelCardData for more details on the available fields.
>>> sgt.model_card_data.base_model = (
... "meta-llama/Llama-3.1-8B-Instruct"
... )
>>> sgt.model_card_data.__dict__["base_model_relation"] = (
... "adapter"
... )
>>> sgt.model_card_data.datasets = ["Open-Orca/OpenOrca"]
>>> sgt.save_pretrained(
... "new-sgt-zipfile.sgt",
... push_to_hub=True,
... repo_id="username/new-sgt-repo",
... )
Changed in version v2.8.0: Added the ability to push a Stained Glass Transform to the Hugging Face Hub. BREAKING CHANGE: Argument `path` was renamed to `save_directory` for compatibility with ModelHubMixin.save_pretrained.
Changed in version v2.20.3: The model safetensors filename was changed for better compatibility with the Hugging Face Hub. This has no practical effect on saving or loading.
state_dict
Get the state dictionary of the client, excluding parameters not needed for the client.
The parameters considered necessary for the client are those passed into the constructor as parameter_names.
Parameters:
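| Name | Type | Description | Default |
|---|---|---|---|
| prefix | str | A prefix added to parameter and buffer names to compose the keys in state_dict. | '' |
| keep_vars | bool | By default the Tensors returned in the state dict are detached from autograd. If it's set to True, detaching will not be performed. | False |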
Returns:
| Type | Description |
|---|---|
| dict[str, typing.Any] | The state dictionary of the client, excluding parameters not needed for the client. |
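The filtering described above amounts to keeping only the needed keys. This sketch uses a plain dict instead of torch tensors and models only prefix, not keep_vars. `filtered_state_dict` is a hypothetical helper.

```python
from __future__ import annotations

from typing import Any


def filtered_state_dict(
    full_state: dict[str, Any], kept_names: list[str], prefix: str = ""
) -> dict[str, Any]:
    """Keep only the entries needed by the client and prepend prefix to each key."""
    kept = set(kept_names)
    return {prefix + name: value for name, value in full_state.items() if name in kept}


full = {"embed.weight": [0.1], "layers.0.weight": [0.2]}
print(filtered_state_dict(full, ["embed.weight"], prefix="sgt."))
# {'sgt.embed.weight': [0.1]}
```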