Commit ac7b097

cbalioglu and MartinGleize authored and committed
Add OPT architecture
1 parent b22af06 commit ac7b097

File tree

9 files changed

+359
-1
lines changed


src/fairseq2/composition/models.py

Lines changed: 22 additions & 0 deletions
@@ -67,6 +67,13 @@
     create_nllb_model,
     register_nllb_configs,
 )
+from fairseq2.models.opt import (
+    OPT_FAMILY,
+    OPTConfig,
+    convert_opt_state_dict,
+    create_opt_model,
+    register_opt_configs,
+)
 from fairseq2.models.qwen import (
     QWEN_FAMILY,
     QwenConfig,
@@ -296,6 +303,21 @@ def _register_model_families(container: DependencyContainer) -> None:
 
     register_nllb_configs(container)
 
+    # OPT
+    register_model_family(
+        container,
+        OPT_FAMILY,
+        kls=TransformerLM,
+        config_kls=OPTConfig,
+        factory=create_opt_model,
+        state_dict_converter=convert_opt_state_dict,
+        compiler=compile_transformer_lm,
+        fsdp_applier=apply_fsdp_to_transformer_lm,
+        layerwise_ac_applier=apply_ac_to_transformer_lm,
+    )
+
+    register_opt_configs(container)
+
     # Qwen
     register_model_family(
         container,
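
The registration above binds the "opt" family name to TransformerLM, OPTConfig, the OPT factory, and the checkpoint converter, while reusing the generic Transformer LM compile, FSDP, and layerwise activation-checkpointing helpers. A minimal smoke check of that wiring (a sketch, not part of this commit) could look like:

# Sketch only (not in this commit): build the default "125m" configuration and
# verify the factory returns a TransformerLM, i.e. the same config/factory
# pairing that register_model_family() binds above.
from fairseq2.models.opt import OPTConfig, create_opt_model
from fairseq2.models.transformer_lm import TransformerLM

model = create_opt_model(OPTConfig())
assert isinstance(model, TransformerLM)
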
src/fairseq2/models/opt/__init__.py

Lines changed: 15 additions & 0 deletions

# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from __future__ import annotations

from fairseq2.models.opt.config import OPT_FAMILY as OPT_FAMILY
from fairseq2.models.opt.config import OPTConfig as OPTConfig
from fairseq2.models.opt.config import register_opt_configs as register_opt_configs
from fairseq2.models.opt.factory import OPTFactory as OPTFactory
from fairseq2.models.opt.factory import create_opt_model as create_opt_model
from fairseq2.models.opt.hub import get_opt_model_hub as get_opt_model_hub
from fairseq2.models.opt.interop import convert_opt_state_dict as convert_opt_state_dict

src/fairseq2/models/opt/config.py

Lines changed: 62 additions & 0 deletions
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from __future__ import annotations

from dataclasses import dataclass
from typing import Final

from fairseq2.runtime.config_registry import ConfigRegistrar
from fairseq2.runtime.dependency import DependencyContainer

OPT_FAMILY: Final = "opt"


@dataclass(kw_only=True)
class OPTConfig:
    """Holds the configuration of an OPT model.

    The default values correspond to the base architecture as described in
    :cite:t:`https://arxiv.org/abs/2205.01068`.
    """

    model_dim: int = 768
    """The dimensionality of the model."""

    max_seq_len: int = 2048 + 1
    """The maximum sequence length."""

    vocab_size: int = 50272
    """The size of the vocabulary."""

    pad_idx: int | None = 1
    """The index of the PAD symbol in the vocabulary."""

    attn_window_len: int = 2048
    """The local attention window length."""

    num_layers: int = 12
    """The number of decoder layers."""

    num_attn_heads: int = 12
    """The number of attention heads in decoder layers."""

    num_key_value_heads: int = 12
    """The number of key/value heads for Grouped Query Attention."""

    ffn_inner_dim: int = 3072
    """The dimensionality of inner projection layers in feed-forward networks."""

    dropout_p: float = 0.1
    """The dropout probability on outputs of Transformer layers."""


def register_opt_configs(container: DependencyContainer) -> None:
    arch = ConfigRegistrar(container, OPTConfig)

    @arch("125m")
    def _125m() -> OPTConfig:
        return OPTConfig()
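
Only the 125M base architecture is registered in this commit. Larger OPT variants could be added later with the same ConfigRegistrar decorator pattern; the sketch below shows a hypothetical "1.3b" entry (not part of this commit) using the commonly published OPT-1.3B hyperparameters, which should be verified before use.

# Hypothetical sketch, not in this commit: registering a second architecture
# with the same pattern used for "125m" above. The hyperparameter values are
# the published OPT-1.3B settings (24 layers, 2048 model dim, 32 heads, 8192 FFN).
def register_more_opt_configs(container: DependencyContainer) -> None:
    arch = ConfigRegistrar(container, OPTConfig)

    @arch("1.3b")
    def _1_3b() -> OPTConfig:
        return OPTConfig(
            model_dim=2048,
            num_layers=24,
            num_attn_heads=32,
            num_key_value_heads=32,
            ffn_inner_dim=8192,
        )
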

src/fairseq2/models/opt/factory.py

Lines changed: 175 additions & 0 deletions
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from __future__ import annotations

import torch.nn as nn

from fairseq2.models.opt.config import OPTConfig
from fairseq2.models.transformer import (
    CausalAttentionBias,
    FeedForwardNetwork,
    LocalAttentionStateFactory,
    MultiheadAttention,
    StandardFeedForwardNetwork,
    StandardMultiheadAttention,
    TransformerEmbeddingFrontend,
    TransformerFrontend,
    TransformerNormOrder,
    create_default_sdpa,
)
from fairseq2.models.transformer_lm import (
    StandardTransformerLMDecoder,
    StandardTransformerLMDecoderLayer,
    TransformerLM,
    TransformerLMDecoder,
    TransformerLMDecoderLayer,
)
from fairseq2.nn import (
    Embedding,
    LayerNorm,
    LearnedPositionEncoder,
    Linear,
    PositionEncoder,
    Projection,
    StandardEmbedding,
    StandardLayerNorm,
)


def create_opt_model(config: OPTConfig) -> TransformerLM:
    return OPTFactory(config).create_model()


class OPTFactory:
    def __init__(self, config: OPTConfig) -> None:
        self._config = config

    def create_model(self) -> TransformerLM:
        config = self._config

        decoder_frontend = self.create_decoder_frontend()

        decoder = self.create_decoder()

        final_proj = self.create_final_projection()

        return TransformerLM(
            config.model_dim,
            decoder_frontend,
            decoder,
            final_proj,
            config.pad_idx,
            config.max_seq_len,
        )

    def create_decoder_frontend(self) -> TransformerFrontend:
        config = self._config

        embed = self.create_embedding()

        pos_encoder = self.create_position_encoder()

        return TransformerEmbeddingFrontend(
            config.model_dim,
            embed,
            pos_encoder=pos_encoder,
            no_scale=True,
            # dropout_p=config.dropout_p,  # TODO: check if there is dropout here
        )

    def create_embedding(self) -> Embedding:
        config = self._config

        return StandardEmbedding(config.vocab_size, config.model_dim, config.pad_idx)

    def create_decoder(self) -> TransformerLMDecoder:
        config = self._config

        layers = []

        for _ in range(config.num_layers):
            layer = self.create_decoder_layer()

            layers.append(layer)

        layer_norm = self.create_layer_norm()

        return StandardTransformerLMDecoder(layers, layer_norm)

    def create_position_encoder(self) -> PositionEncoder:
        config = self._config

        return LearnedPositionEncoder(
            config.model_dim, config.max_seq_len, _legacy_pad_idx=1
        )

    def create_decoder_layer(self) -> TransformerLMDecoderLayer:
        config = self._config

        self_attn = self.create_self_attention()

        self_attn_layer_norm = self.create_layer_norm()

        ffn = self.create_ffn()

        ffn_layer_norm = self.create_layer_norm()

        return StandardTransformerLMDecoderLayer(
            self_attn,
            self_attn_layer_norm,
            ffn,
            ffn_layer_norm,
            norm_order=TransformerNormOrder.PRE,
            dropout_p=config.dropout_p,
        )

    def create_self_attention(self) -> MultiheadAttention:
        config = self._config

        attn_bias = CausalAttentionBias(attn_window_len=config.attn_window_len)

        sdpa = create_default_sdpa(attn_bias)

        state_factory = LocalAttentionStateFactory(config.attn_window_len)

        return StandardMultiheadAttention(
            config.model_dim,
            config.num_attn_heads,
            sdpa,
            num_key_value_heads=config.num_key_value_heads,
            bias=True,
            state_factory=state_factory,
        )

    def create_ffn(self) -> FeedForwardNetwork:
        config = self._config

        return StandardFeedForwardNetwork(
            config.model_dim, config.ffn_inner_dim, bias=True
        )

    def create_layer_norm(self) -> LayerNorm:
        config = self._config

        return StandardLayerNorm(config.model_dim, bias=True)

    def create_final_projection(self) -> Projection:
        config = self._config

        return Linear(
            config.model_dim,
            config.vocab_size,
            bias=False,
            init_fn=_init_final_projection,
        )


def _init_final_projection(proj: Linear) -> None:
    nn.init.normal_(proj.weight, std=proj.input_dim**-0.5)

    if proj.bias is not None:
        nn.init.zeros_(proj.bias)
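
With the default OPTConfig values the factory assembles the 125M-parameter model. A quick sanity sketch (not part of this commit, assumes PyTorch is installed) that instantiates it and counts parameters:

# Sketch only: build the default (125M) configuration and count parameters.
# The exact total depends on whether embeddings are tied and on bias terms,
# so treat ~125M as a ballpark rather than an exact expectation.
from fairseq2.models.opt import OPTConfig, create_opt_model

model = create_opt_model(OPTConfig())

num_params = sum(p.numel() for p in model.parameters())
print(f"{num_params / 1e6:.1f}M parameters")
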

src/fairseq2/models/opt/hub.py

Lines changed: 15 additions & 0 deletions
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from __future__ import annotations

from fairseq2.models import ModelHubAccessor
from fairseq2.models.opt.config import OPT_FAMILY, OPTConfig
from fairseq2.models.transformer_lm import TransformerLM

get_opt_model_hub = ModelHubAccessor(
    OPT_FAMILY, kls=TransformerLM, config_kls=OPTConfig
)

src/fairseq2/models/opt/interop.py

Lines changed: 37 additions & 0 deletions
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from __future__ import annotations

from fairseq2.models.opt.config import OPTConfig
from fairseq2.models.utils.checkpoint import convert_state_dict

_HG_KEY_MAP = {
    # fmt: off
    r"^model\.decoder\.embed_tokens\.": r"decoder_frontend.embed.",
    r"^model\.decoder\.embed_positions\.": r"decoder_frontend.pos_encoder.",
    r"^model\.decoder\.layers\.([0-9]+)\.self_attn_layer_norm\.": r"decoder.layers.\1.self_attn_layer_norm.",
    r"^model\.decoder\.layers\.([0-9]+)\.self_attn\.q_proj\.": r"decoder.layers.\1.self_attn.q_proj.",
    r"^model\.decoder\.layers\.([0-9]+)\.self_attn\.k_proj\.": r"decoder.layers.\1.self_attn.k_proj.",
    r"^model\.decoder\.layers\.([0-9]+)\.self_attn\.v_proj\.": r"decoder.layers.\1.self_attn.v_proj.",
    r"^model\.decoder\.layers\.([0-9]+)\.self_attn\.out_proj\.": r"decoder.layers.\1.self_attn.output_proj.",
    r"^model\.decoder\.layers\.([0-9]+)\.mlp\.gate_proj\.": r"decoder.layers.\1.ffn.gate_proj.",
    r"^model\.decoder\.layers\.([0-9]+)\.fc1\.": r"decoder.layers.\1.ffn.inner_proj.",
    r"^model\.decoder\.layers\.([0-9]+)\.fc2\.": r"decoder.layers.\1.ffn.output_proj.",
    r"^model\.decoder\.layers\.([0-9]+)\.final_layer_norm\.": r"decoder.layers.\1.ffn_layer_norm.",
    r"^model\.decoder\.final_layer_norm\.": r"decoder.layer_norm.",
    r"^lm_head\.": r"final_proj.",
    # fmt: on
}


def convert_opt_state_dict(
    state_dict: dict[str, object], config: OPTConfig
) -> dict[str, object]:
    if "model.decoder.embed_tokens.weight" in state_dict:  # Hugging Face
        state_dict = convert_state_dict(state_dict, _HG_KEY_MAP)

    return state_dict
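
The _HG_KEY_MAP regexes rewrite Hugging Face checkpoint parameter names into fairseq2 module paths before loading. The standalone sketch below (not part of this commit) illustrates the renaming with plain re.subn; the real conversion goes through fairseq2.models.utils.checkpoint.convert_state_dict.

# Illustration only: apply two of the mappings above to a single key.
import re

key_map = {
    r"^model\.decoder\.layers\.([0-9]+)\.self_attn\.out_proj\.": r"decoder.layers.\1.self_attn.output_proj.",
    r"^model\.decoder\.layers\.([0-9]+)\.fc1\.": r"decoder.layers.\1.ffn.inner_proj.",
}


def rename(key: str) -> str:
    for pattern, repl in key_map.items():
        new_key, num_subs = re.subn(pattern, repl, key)
        if num_subs:
            return new_key
    return key


print(rename("model.decoder.layers.3.fc1.weight"))
# -> decoder.layers.3.ffn.inner_proj.weight
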

src/fairseq2/nn/position_encoder.py

Lines changed: 6 additions & 1 deletion
@@ -225,6 +225,7 @@ def __init__(
         encoding_dim: int,
         max_seq_len: int,
         *,
+        _legacy_pad_idx: int | None = None,
         device: Device | None = None,
         dtype: DataType | None = None,
     ) -> None:
@@ -243,6 +244,10 @@
 
         self.max_seq_len = max_seq_len
 
+        # This is a legacy parameter that should only be set when the encodings
+        # must be compatible with fairseq.
+        self._legacy_pad_idx = 0 if _legacy_pad_idx is None else _legacy_pad_idx
+
         self.reset_parameters()
 
     def reset_parameters(self) -> None:
@@ -271,7 +276,7 @@ def forward(
                 f"The lengths of all sequences in `seqs` must be less than or equal to the maximum sequence length ({self.max_seq_len}), but at least one sequence is of length {max_seq_len} instead."
             )
 
-        indices = seqs_layout.position_indices + 1  # +1 for padding
+        indices = seqs_layout.position_indices + (1 + self._legacy_pad_idx)
 
         if not self.training and state_bag is not None:
             indices = state_bag.step_nr + indices
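
The new _legacy_pad_idx parameter exists because fairseq-trained checkpoints (OPT included) offset learned position embeddings by pad_idx + 1 rather than by 1. A small sketch of the resulting index arithmetic (not part of this commit):

# Sketch only: the default shift stays +1; with _legacy_pad_idx=1, as passed
# by the OPT factory, the shift becomes +2, matching the offset baked into
# fairseq-style learned position embeddings.
import torch

position_indices = torch.arange(4)               # [0, 1, 2, 3]

default_indices = position_indices + (1 + 0)     # [1, 2, 3, 4]
legacy_indices = position_indices + (1 + 1)      # [2, 3, 4, 5]

print(default_indices.tolist(), legacy_indices.tolist())
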

tests/unit/models/opt/__init__.py

Whitespace-only changes.

0 commit comments