Skip to content

Commit 3a09f79

Browse files
committed
Use the string "off" instead of None to disable mixed precision
1 parent 023d111 commit 3a09f79

File tree

5 files changed

+22
-11
lines changed

5 files changed

+22
-11
lines changed

setup.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -58,7 +58,7 @@
5858
"numpy~=1.23",
5959
"packaging~=24.1",
6060
"psutil~=5.9",
61-
"pyyaml~=6.0",
61+
"ruamel.yaml~=0.18",
6262
"rich~=13.7",
6363
"sacrebleu~=2.4",
6464
"tiktoken~=0.7",

src/fairseq2/recipes/common/_model.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -230,7 +230,7 @@ def load(self, recipe_config: object, gangs: Gangs) -> Model:
230230

231231
# Load the model.
232232
trainer_section = get_config_section(recipe_config, "trainer", TrainerSection)
233-
if trainer_section.mixed_precision is None:
233+
if trainer_section.mixed_precision == "off":
234234
dtype = trainer_section.dtype
235235
else:
236236
dtype = torch.float32
@@ -359,7 +359,7 @@ def load(self, recipe_config: object, gangs: Gangs) -> Model:
359359

360360
# Load the model.
361361
trainer_section = get_config_section(recipe_config, "trainer", TrainerSection)
362-
if trainer_section.mixed_precision is None:
362+
if trainer_section.mixed_precision == "off":
363363
dtype = trainer_section.dtype
364364
else:
365365
dtype = torch.float32
@@ -498,7 +498,7 @@ def load(self, recipe_config: object, gangs: Gangs) -> Model:
498498

499499
# Create the model.
500500
trainer_section = get_config_section(recipe_config, "trainer", TrainerSection)
501-
if trainer_section.mixed_precision is None:
501+
if trainer_section.mixed_precision == "off":
502502
dtype = trainer_section.dtype
503503
else:
504504
dtype = torch.float32

src/fairseq2/recipes/config.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -118,7 +118,7 @@ class TrainerSection:
118118

119119
fsdp: FsdpSection = field(default_factory=lambda: FsdpSection())
120120

121-
mixed_precision: Literal["static", "dynamic"] | None = "static"
121+
mixed_precision: Literal["static", "dynamic", "off"] = "static"
122122
"""
123123
If 'off', the whole training will be run in `dtype`. If 'static', forward
124124
and backward passes will be run in `dtype`, but the optimizer step will be

src/fairseq2/recipes/wav2vec2/asr/_train.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -244,7 +244,7 @@ def load_wav2vec2_asr_trainer(
244244
config.pretrained_model.name,
245245
gangs,
246246
config.trainer.dtype,
247-
mp=config.trainer.mixed_precision is not None,
247+
mp=config.trainer.mixed_precision != "off",
248248
)
249249

250250
pt_module = cast(Wav2Vec2Model, pt_model.module)

src/fairseq2/utils/yaml.py

Lines changed: 16 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -10,9 +10,9 @@
1010
from pathlib import Path
1111
from typing import IO, TypeAlias, final
1212

13-
import yaml
13+
from ruamel.yaml import YAML
14+
from ruamel.yaml.error import YAMLError
1415
from typing_extensions import override
15-
from yaml import YAMLError
1616

1717
from fairseq2.utils.file import FileMode, FileSystem
1818

@@ -32,9 +32,12 @@ def dump(self, obj: object, output: Path | IO[str]) -> None: ...
3232

3333
@final
3434
class StandardYamlLoader(YamlLoader):
35+
_yaml: YAML
3536
_file_system: FileSystem
3637

3738
def __init__(self, file_system: FileSystem) -> None:
39+
self._yaml = YAML(typ="safe", pure=True)
40+
3841
self._file_system = file_system
3942

4043
@override
@@ -47,16 +50,22 @@ def load(self, input_: Path | IO[str]) -> list[object]:
4750
finally:
4851
fp.close()
4952

50-
itr = yaml.safe_load_all(input_)
53+
itr = self._yaml.load_all(input_)
5154

5255
return list(itr)
5356

5457

5558
@final
5659
class StandardYamlDumper(YamlDumper):
60+
_yaml: YAML
5761
_file_system: FileSystem
5862

5963
def __init__(self, file_system: FileSystem) -> None:
64+
self._yaml = YAML(typ="safe", pure=True)
65+
66+
self._yaml.default_flow_style = False
67+
self._yaml.sort_base_mapping_type_on_output = False # type: ignore[assignment]
68+
6069
self._file_system = file_system
6170

6271
@override
@@ -69,8 +78,10 @@ def dump(self, obj: object, output: Path | IO[str]) -> None:
6978
finally:
7079
fp.close()
7180
else:
72-
yaml.safe_dump(obj, output, sort_keys=False)
81+
self._yaml.dump(obj, output)
7382

7483

7584
def read_yaml(s: str) -> object:
76-
return yaml.safe_load(s)
85+
yaml = YAML(typ="safe", pure=True)
86+
87+
return yaml.load(s)

0 commit comments

Comments
 (0)