Skip to content

Commit 343be99

Browse files
committed
Nit improvements in model loading
1 parent 6ac4241 commit 343be99

File tree

4 files changed, with 71 additions and 77 deletions.

src/fairseq2/assets/_download_manager.py

Lines changed: 12 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,7 @@ def download_checkpoint(
3737
uri: str,
3838
model_name: str,
3939
*,
40-
shard_idx: int | None = None,
40+
shard_idx: int,
4141
force: bool = False,
4242
progress: bool = True,
4343
) -> Path:
@@ -62,18 +62,15 @@ def download_checkpoint(
6262
def download_tokenizer(
6363
self,
6464
uri: str,
65-
model_name: str,
65+
tokenizer_name: str,
6666
*,
67-
tokenizer_name: str | None = None,
6867
force: bool = False,
6968
progress: bool = True,
7069
) -> Path:
7170
"""Download the tokenizer at ``uri`` to the asset cache directory.
7271
7372
:param uri:
7473
The URI to download from.
75-
:param model_name:
76-
The name of the associated model.
7774
:param tokenizer_name:
7875
The name of the tokenizer.
7976
:param force:
@@ -129,15 +126,12 @@ def download_checkpoint(
129126
uri: str,
130127
model_name: str,
131128
*,
132-
shard_idx: int | None = None,
129+
shard_idx: int = 0,
133130
force: bool = False,
134131
progress: bool = True,
135132
) -> Path:
136133
display_name = f"checkpoint of {model_name}"
137134

138-
if shard_idx is not None:
139-
display_name = f"{display_name} (shard {shard_idx})"
140-
141135
op = _AssetDownloadOp(
142136
self._cache_dir, uri, display_name, force, progress, shard_idx
143137
)
@@ -148,16 +142,12 @@ def download_checkpoint(
148142
def download_tokenizer(
149143
self,
150144
uri: str,
151-
model_name: str,
145+
tokenizer_name: str,
152146
*,
153-
tokenizer_name: str | None = None,
154147
force: bool = False,
155148
progress: bool = True,
156149
) -> Path:
157-
if not tokenizer_name:
158-
display_name = f"tokenizer of {model_name}"
159-
else:
160-
display_name = f"{tokenizer_name} tokenizer of {model_name}"
150+
display_name = f"{tokenizer_name} tokenizer"
161151

162152
op = _AssetDownloadOp(self._cache_dir, uri, display_name, force, progress)
163153

@@ -187,7 +177,7 @@ class _AssetDownloadOp:
187177
_display_name: str
188178
_force: bool
189179
_progress: bool
190-
_shard_idx: int | None
180+
_shard_idx: int
191181

192182
def __init__(
193183
self,
@@ -196,7 +186,7 @@ def __init__(
196186
display_name: str,
197187
force: bool,
198188
progress: bool,
199-
shard_idx: int | None = None,
189+
shard_idx: int = 0,
200190
) -> None:
201191
self._cache_dir = cache_dir
202192
self._uri = uri
@@ -266,14 +256,12 @@ def _process_uri(self) -> None:
266256
self._uri = parsed_uri._replace(params="").geturl()
267257

268258
def _format_uri_with_shard_index(self) -> None:
269-
if self._shard_idx is None:
270-
return
271-
272259
sharded_uri = self._uri.replace("%7Bshard_idx%7D", str(self._shard_idx))
273-
if sharded_uri == self._uri:
274-
raise AssetDownloadError(
275-
f"`shard_idx` is specified, but the {self._display_name} is not sharded."
276-
)
260+
if self._shard_idx > 1:
261+
if sharded_uri == self._uri:
262+
raise AssetDownloadError(
263+
f"`shard_idx` is specified, but the {self._display_name} is not sharded."
264+
)
277265

278266
self._uri = sharded_uri
279267

src/fairseq2/models/_handler.py

Lines changed: 52 additions & 49 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,6 @@
1212

1313
import torch
1414
from torch.nn import Module
15-
from torch.nn.modules.utils import consume_prefix_in_state_dict_if_present
1615
from typing_extensions import override
1716

1817
from fairseq2.assets import (
@@ -264,51 +263,11 @@ def load_config(self, card: AssetCard) -> object:
264263
def create(
265264
self, config: object, gangs: Gangs, dtype: DataType, meta: bool
266265
) -> Module:
267-
if meta:
268-
if not self.supports_meta:
269-
raise NotSupportedError(
270-
f"The '{self._family}' model family does not support meta device initialization."
271-
)
272-
273-
device = META
274-
elif gangs.root.size != gangs.dp.size:
275-
device = CPU # Avoid OOM for sharded models.
276-
else:
277-
device = gangs.root.device
278-
279266
config = structure(config, self._configs.config_kls)
280267

281268
validate(config)
282269

283-
original_dtype = torch.get_default_dtype()
284-
285-
try:
286-
torch.set_default_dtype(dtype)
287-
288-
with device:
289-
model = self._factory(config)
290-
except NotImplementedError as ex:
291-
if "'Meta' backend" not in str(ex):
292-
raise
293-
294-
raise ContractError(
295-
"One or more operators in the model constructor have failed to initialize on the meta device. See the nested exception for details."
296-
) from ex
297-
finally:
298-
torch.set_default_dtype(original_dtype)
299-
300-
if gangs.root.size != gangs.dp.size:
301-
if self._sharder is None:
302-
raise NotSupportedError(
303-
f"The '{self._family}' model family does not support non-data parallelism."
304-
)
305-
306-
self._sharder(model, config, gangs)
307-
308-
if not meta and device != gangs.root.device:
309-
to_device(model, gangs.root.device)
310-
311-
return model
270+
return self._do_create(config, gangs, dtype, meta)
312271

313272
@override
314273
def load(
@@ -336,10 +295,8 @@ def load(
336295
except AssetCardError as ex:
337296
raise model_asset_card_error(model_name) from ex
338297

339-
shard_idx = gangs.tp.rank if num_shards > 1 else None
340-
341298
path = self._asset_download_manager.download_checkpoint(
342-
checkpoint_uri, model_name, shard_idx=shard_idx
299+
checkpoint_uri, model_name, shard_idx=gangs.tp.rank
343300
)
344301

345302
# Load the configuration.
@@ -394,6 +351,10 @@ def load_from_path(
394351
"`gangs` must be on a real device, but is on the meta device instead."
395352
)
396353

354+
config = structure(config, self._configs.config_kls)
355+
356+
validate(config)
357+
397358
if restrict is None:
398359
restrict = self._restrict
399360

@@ -421,7 +382,7 @@ def load_from_path(
421382
) from ex
422383

423384
# Create the model.
424-
model = self.create(config, gangs, dtype, meta=self.supports_meta)
385+
model = self._do_create(config, gangs, dtype, meta=self.supports_meta)
425386

426387
if self.supports_meta:
427388
# Move the model to the actual device without initializing. Its
@@ -448,9 +409,6 @@ def load_from_path(
448409
model_name, f"The model state dictionary in the '{model_name}' checkpoint is expected to be of type `dict`, but is of type `{type(state_dict)}` instead." # fmt: skip
449410
)
450411

451-
# Remove DDP 'module' prefix.
452-
consume_prefix_in_state_dict_if_present(state_dict, prefix="module.")
453-
454412
try:
455413
load_state_dict(model, state_dict)
456414
except (KeyError, ValueError) as ex:
@@ -465,6 +423,51 @@ def load_from_path(
465423

466424
return model
467425

426+
def _do_create(
427+
self, config: object, gangs: Gangs, dtype: DataType, meta: bool
428+
) -> Module:
429+
if meta:
430+
if not self.supports_meta:
431+
raise NotSupportedError(
432+
f"The '{self._family}' model family does not support meta device initialization."
433+
)
434+
435+
device = META
436+
elif gangs.root.size != gangs.dp.size:
437+
device = CPU # Avoid OOM for sharded models.
438+
else:
439+
device = gangs.root.device
440+
441+
original_dtype = torch.get_default_dtype()
442+
443+
try:
444+
torch.set_default_dtype(dtype)
445+
446+
with device:
447+
model = self._factory(config)
448+
except NotImplementedError as ex:
449+
if "'Meta' backend" not in str(ex):
450+
raise
451+
452+
raise ContractError(
453+
"One or more operators in the model constructor have failed to initialize on the meta device. See the nested exception for details."
454+
) from ex
455+
finally:
456+
torch.set_default_dtype(original_dtype)
457+
458+
if gangs.root.size != gangs.dp.size:
459+
if self._sharder is None:
460+
raise NotSupportedError(
461+
f"The '{self._family}' model family does not support non-data parallelism."
462+
)
463+
464+
self._sharder(model, config, gangs)
465+
466+
if not meta and device != gangs.root.device:
467+
to_device(model, gangs.root.device)
468+
469+
return model
470+
468471
def compile(self, model: Module, config: object) -> Module:
469472
if self._torch_compiler is None:
470473
raise NotSupportedError(

src/fairseq2/recipes/common/_distributed.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
from torch import Tensor
1414
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
1515
from torch.nn import Module
16+
from torch.nn.modules.utils import consume_prefix_in_state_dict_if_present
1617
from torch.nn.parallel import DistributedDataParallel as DDP
1718
from torch.optim import Optimizer
1819
from typing_extensions import override
@@ -113,7 +114,11 @@ def clip_gradient_norm(self, max_norm: float | None) -> Tensor:
113114

114115
@override
115116
def state_dict(self) -> dict[str, object]:
116-
return self._ddp.state_dict()
117+
state_dict = self._ddp.state_dict()
118+
119+
consume_prefix_in_state_dict_if_present(state_dict, prefix="module.")
120+
121+
return state_dict
117122

118123
@override
119124
def optim_state_dict(self, optim: Optimizer) -> dict[str, object]:

src/fairseq2/utils/structured.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -169,9 +169,7 @@ def _structure_dataclass(
169169
)
170170

171171
if isinstance(obj, kls):
172-
values = {f.name: getattr(obj, f.name) for f in fields(kls)}
173-
174-
return self._create_dataclass(kls, values, set_empty)
172+
return obj
175173

176174
if isinstance(obj, Mapping):
177175
values = self.structure(obj, dict[str, object])

0 commit comments

Comments (0)