Skip to content

Commit b7bc285

Browse files
committed
fix
1 parent 56c8356 commit b7bc285

File tree

4 files changed

+14
-4
lines changed

4 files changed

+14
-4
lines changed

doctr/models/recognition/parseq/pytorch.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -403,7 +403,9 @@ def __call__(
403403
for encoded_seq in out_idxs.cpu().numpy()
404404
]
405405
# compute probabilties for each word up to the EOS token
406-
probs = [preds_prob[i, : len(word)].clip(0, 1).mean().item() for i, word in enumerate(word_values)]
406+
probs = [
407+
preds_prob[i, : len(word)].clip(0, 1).mean().item() if word else 0.0 for i, word in enumerate(word_values)
408+
]
407409

408410
return list(zip(word_values, probs))
409411

doctr/models/recognition/parseq/tensorflow.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -432,7 +432,10 @@ def __call__(
432432
word_values = [word.decode() for word in decoded_strings_pred.numpy().tolist()]
433433

434434
# compute probabilties for each word up to the EOS token
435-
probs = [preds_prob[i, : len(word)].numpy().clip(0, 1).mean().item() for i, word in enumerate(word_values)]
435+
probs = [
436+
preds_prob[i, : len(word)].numpy().clip(0, 1).mean().item() if word else 0.0
437+
for i, word in enumerate(word_values)
438+
]
436439

437440
return list(zip(word_values, probs))
438441

doctr/models/recognition/vitstr/pytorch.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -167,7 +167,9 @@ def __call__(
167167
for encoded_seq in out_idxs.cpu().numpy()
168168
]
169169
# compute probabilties for each word up to the EOS token
170-
probs = [preds_prob[i, : len(word)].clip(0, 1).mean().item() for i, word in enumerate(word_values)]
170+
probs = [
171+
preds_prob[i, : len(word)].clip(0, 1).mean().item() if word else 0.0 for i, word in enumerate(word_values)
172+
]
171173

172174
return list(zip(word_values, probs))
173175

doctr/models/recognition/vitstr/tensorflow.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -175,7 +175,10 @@ def __call__(
175175
word_values = [word.decode() for word in decoded_strings_pred.numpy().tolist()]
176176

177177
# compute probabilties for each word up to the EOS token
178-
probs = [preds_prob[i, : len(word)].numpy().clip(0, 1).mean().item() for i, word in enumerate(word_values)]
178+
probs = [
179+
preds_prob[i, : len(word)].numpy().clip(0, 1).mean().item() if word else 0.0
180+
for i, word in enumerate(word_values)
181+
]
179182

180183
return list(zip(word_values, probs))
181184

0 commit comments

Comments
 (0)