1212
1313import torch
1414from torch .nn import Module
15- from torch .nn .modules .utils import consume_prefix_in_state_dict_if_present
1615from typing_extensions import override
1716
1817from fairseq2 .assets import (
@@ -264,51 +263,11 @@ def load_config(self, card: AssetCard) -> object:
264263 def create (
265264 self , config : object , gangs : Gangs , dtype : DataType , meta : bool
266265 ) -> Module :
267- if meta :
268- if not self .supports_meta :
269- raise NotSupportedError (
270- f"The '{ self ._family } ' model family does not support meta device initialization."
271- )
272-
273- device = META
274- elif gangs .root .size != gangs .dp .size :
275- device = CPU # Avoid OOM for sharded models.
276- else :
277- device = gangs .root .device
278-
279266 config = structure (config , self ._configs .config_kls )
280267
281268 validate (config )
282269
283- original_dtype = torch .get_default_dtype ()
284-
285- try :
286- torch .set_default_dtype (dtype )
287-
288- with device :
289- model = self ._factory (config )
290- except NotImplementedError as ex :
291- if "'Meta' backend" not in str (ex ):
292- raise
293-
294- raise ContractError (
295- "One or more operators in the model constructor have failed to initialize on the meta device. See the nested exception for details."
296- ) from ex
297- finally :
298- torch .set_default_dtype (original_dtype )
299-
300- if gangs .root .size != gangs .dp .size :
301- if self ._sharder is None :
302- raise NotSupportedError (
303- f"The '{ self ._family } ' model family does not support non-data parallelism."
304- )
305-
306- self ._sharder (model , config , gangs )
307-
308- if not meta and device != gangs .root .device :
309- to_device (model , gangs .root .device )
310-
311- return model
270+ return self ._do_create (config , gangs , dtype , meta )
312271
313272 @override
314273 def load (
@@ -336,10 +295,8 @@ def load(
336295 except AssetCardError as ex :
337296 raise model_asset_card_error (model_name ) from ex
338297
339- shard_idx = gangs .tp .rank if num_shards > 1 else None
340-
341298 path = self ._asset_download_manager .download_checkpoint (
342- checkpoint_uri , model_name , shard_idx = shard_idx
299+ checkpoint_uri , model_name , shard_idx = gangs . tp . rank
343300 )
344301
345302 # Load the configuration.
@@ -394,6 +351,10 @@ def load_from_path(
394351 "`gangs` must be on a real device, but is on the meta device instead."
395352 )
396353
354+ config = structure (config , self ._configs .config_kls )
355+
356+ validate (config )
357+
397358 if restrict is None :
398359 restrict = self ._restrict
399360
@@ -421,7 +382,7 @@ def load_from_path(
421382 ) from ex
422383
423384 # Create the model.
424- model = self .create (config , gangs , dtype , meta = self .supports_meta )
385+ model = self ._do_create (config , gangs , dtype , meta = self .supports_meta )
425386
426387 if self .supports_meta :
427388 # Move the model to the actual device without initializing. Its
@@ -448,9 +409,6 @@ def load_from_path(
448409 model_name , f"The model state dictionary in the '{ model_name } ' checkpoint is expected to be of type `dict`, but is of type `{ type (state_dict )} ` instead." # fmt: skip
449410 )
450411
451- # Remove DDP 'module' prefix.
452- consume_prefix_in_state_dict_if_present (state_dict , prefix = "module." )
453-
454412 try :
455413 load_state_dict (model , state_dict )
456414 except (KeyError , ValueError ) as ex :
@@ -465,6 +423,51 @@ def load_from_path(
465423
466424 return model
467425
426+ def _do_create (
427+ self , config : object , gangs : Gangs , dtype : DataType , meta : bool
428+ ) -> Module :
429+ if meta :
430+ if not self .supports_meta :
431+ raise NotSupportedError (
432+ f"The '{ self ._family } ' model family does not support meta device initialization."
433+ )
434+
435+ device = META
436+ elif gangs .root .size != gangs .dp .size :
437+ device = CPU # Avoid OOM for sharded models.
438+ else :
439+ device = gangs .root .device
440+
441+ original_dtype = torch .get_default_dtype ()
442+
443+ try :
444+ torch .set_default_dtype (dtype )
445+
446+ with device :
447+ model = self ._factory (config )
448+ except NotImplementedError as ex :
449+ if "'Meta' backend" not in str (ex ):
450+ raise
451+
452+ raise ContractError (
453+ "One or more operators in the model constructor have failed to initialize on the meta device. See the nested exception for details."
454+ ) from ex
455+ finally :
456+ torch .set_default_dtype (original_dtype )
457+
458+ if gangs .root .size != gangs .dp .size :
459+ if self ._sharder is None :
460+ raise NotSupportedError (
461+ f"The '{ self ._family } ' model family does not support non-data parallelism."
462+ )
463+
464+ self ._sharder (model , config , gangs )
465+
466+ if not meta and device != gangs .root .device :
467+ to_device (model , gangs .root .device )
468+
469+ return model
470+
468471 def compile (self , model : Module , config : object ) -> Module :
469472 if self ._torch_compiler is None :
470473 raise NotSupportedError (
0 commit comments