Skip to content

Commit 9b9b8a3

Browse files
authored
Merge pull request #28 from chcwww/fix_wf
2 parents 3a0e32c + 2e9719f commit 9b9b8a3

File tree

4 files changed

+4
-4
lines changed

4 files changed

+4
-4
lines changed

libmultilabel/nn/attentionxml.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -287,7 +287,7 @@ def fit(self, datasets):
287287
logger.info(f"Finish training level 0")
288288

289289
logger.info(f"Best model loaded from {best_model_path}")
290-
model_0 = Model.load_from_checkpoint(best_model_path)
290+
model_0 = Model.load_from_checkpoint(best_model_path, weights_only=False)
291291

292292
logger.info(
293293
f"Predicting clusters by level-0 model. We then select {self.beam_width} clusters for each instance and "
@@ -422,11 +422,13 @@ def test(self, dataset):
422422
model_0 = Model.load_from_checkpoint(
423423
self.get_best_model_path(level=0),
424424
save_k_predictions=self.beam_width,
425+
weights_only=False,
425426
)
426427
model_1 = PLTModel.load_from_checkpoint(
427428
self.get_best_model_path(level=1),
428429
save_k_predictions=self.save_k_predictions,
429430
metrics=self.metrics,
431+
weights_only=False,
430432
)
431433

432434
word_dict_path = os.path.join(os.path.dirname(self.get_best_model_path(level=1)), self.WORD_DICT_NAME)

libmultilabel/nn/networks/bert.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,6 @@ def __init__(
3434
hidden_dropout_prob=encoder_hidden_dropout,
3535
attention_probs_dropout_prob=encoder_attention_dropout,
3636
classifier_dropout=post_encoder_dropout,
37-
torchscript=True,
3837
)
3938

4039
def forward(self, input):

libmultilabel/nn/networks/bert_attention.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,6 @@ def __init__(
4040

4141
self.lm = AutoModel.from_pretrained(
4242
lm_weight,
43-
torchscript=True,
4443
hidden_dropout_prob=encoder_hidden_dropout,
4544
attention_probs_dropout_prob=encoder_attention_dropout,
4645
)

torch_trainer.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -150,7 +150,7 @@ def _setup_model(
150150

151151
if checkpoint_path is not None:
152152
logging.info(f"Loading model from `{checkpoint_path}` with the previously saved hyper-parameter...")
153-
self.model = Model.load_from_checkpoint(checkpoint_path, log_path=log_path)
153+
self.model = Model.load_from_checkpoint(checkpoint_path, log_path=log_path, weights_only=False)
154154
word_dict_path = os.path.join(os.path.dirname(checkpoint_path), self.WORD_DICT_NAME)
155155
if os.path.exists(word_dict_path):
156156
with open(word_dict_path, "rb") as f:

0 commit comments

Comments (0)