diff --git a/allennlp/common/cached_transformers.py b/allennlp/common/cached_transformers.py
index 3177faa2ab8..4408e7a118f 100644
--- a/allennlp/common/cached_transformers.py
+++ b/allennlp/common/cached_transformers.py
@@ -14,7 +14,6 @@ class TransformerSpec(NamedTuple):
     model_name: str
     override_weights_file: Optional[str] = None
     override_weights_strip_prefix: Optional[str] = None
-    reinit_modules: Optional[Union[int, Tuple[int, ...], Tuple[str, ...]]] = None
 
 
 _model_cache: Dict[TransformerSpec, transformers.PreTrainedModel] = {}
@@ -66,9 +65,8 @@ def get(
         model_name,
         override_weights_file,
         override_weights_strip_prefix,
-        reinit_modules,
     )
-    transformer = _model_cache.get(spec, None)
+    transformer = None if reinit_modules is not None else _model_cache.get(spec, None)
     if transformer is None:
         if not load_weights:
             if override_weights_file is not None:
@@ -181,7 +179,9 @@ def strip_prefix(s):
             model_name,
             **kwargs,
         )
-        _model_cache[spec] = transformer
+        # Don't cache transformers with reinitialized weights.
+        if reinit_modules is None:
+            _model_cache[spec] = transformer
 
     if make_copy:
         import copy