Skip to content

Commit 85b6fef

Browse files
committed
feat: add fast ASR backend
Signed-off-by: BBC, Esquire <bbc@chintellalaw.com>
1 parent c74d378 commit 85b6fef

4 files changed

Lines changed: 496 additions & 9 deletions

File tree

docling/cli/main.py

Lines changed: 38 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -52,6 +52,19 @@
5252
WHISPER_TURBO,
5353
WHISPER_TURBO_MLX,
5454
WHISPER_TURBO_NATIVE,
55+
# WhisperS2T models
56+
WHISPER_TINY_S2T,
57+
WHISPER_TINY_EN_S2T,
58+
WHISPER_BASE_S2T,
59+
WHISPER_BASE_EN_S2T,
60+
WHISPER_SMALL_S2T,
61+
WHISPER_SMALL_EN_S2T,
62+
WHISPER_DISTIL_SMALL_EN_S2T,
63+
WHISPER_MEDIUM_S2T,
64+
WHISPER_MEDIUM_EN_S2T,
65+
WHISPER_DISTIL_MEDIUM_EN_S2T,
66+
WHISPER_LARGE_V3_S2T,
67+
WHISPER_DISTIL_LARGE_V3_S2T,
5568
AsrModelType,
5669
)
5770
from docling.datamodel.backend_options import PdfBackendOptions
@@ -874,7 +887,6 @@ def convert( # noqa: C901
874887
# enable_remote_services=enable_remote_services,
875888
# artifacts_path = artifacts_path
876889
)
877-
878890
# Auto-selecting models (choose best implementation for hardware)
879891
if asr_model == AsrModelType.WHISPER_TINY:
880892
asr_pipeline_options.asr_options = WHISPER_TINY
@@ -888,7 +900,6 @@ def convert( # noqa: C901
888900
asr_pipeline_options.asr_options = WHISPER_LARGE
889901
elif asr_model == AsrModelType.WHISPER_TURBO:
890902
asr_pipeline_options.asr_options = WHISPER_TURBO
891-
892903
# Explicit MLX models (force MLX implementation)
893904
elif asr_model == AsrModelType.WHISPER_TINY_MLX:
894905
asr_pipeline_options.asr_options = WHISPER_TINY_MLX
@@ -902,7 +913,6 @@ def convert( # noqa: C901
902913
asr_pipeline_options.asr_options = WHISPER_LARGE_MLX
903914
elif asr_model == AsrModelType.WHISPER_TURBO_MLX:
904915
asr_pipeline_options.asr_options = WHISPER_TURBO_MLX
905-
906916
# Explicit Native models (force native implementation)
907917
elif asr_model == AsrModelType.WHISPER_TINY_NATIVE:
908918
asr_pipeline_options.asr_options = WHISPER_TINY_NATIVE
@@ -916,13 +926,35 @@ def convert( # noqa: C901
916926
asr_pipeline_options.asr_options = WHISPER_LARGE_NATIVE
917927
elif asr_model == AsrModelType.WHISPER_TURBO_NATIVE:
918928
asr_pipeline_options.asr_options = WHISPER_TURBO_NATIVE
919-
929+
# Explicit WhisperS2T models (CTranslate2 backend - fastest)
930+
elif asr_model == AsrModelType.WHISPER_TINY_S2T:
931+
asr_pipeline_options.asr_options = WHISPER_TINY_S2T
932+
elif asr_model == AsrModelType.WHISPER_TINY_EN_S2T:
933+
asr_pipeline_options.asr_options = WHISPER_TINY_EN_S2T
934+
elif asr_model == AsrModelType.WHISPER_BASE_S2T:
935+
asr_pipeline_options.asr_options = WHISPER_BASE_S2T
936+
elif asr_model == AsrModelType.WHISPER_BASE_EN_S2T:
937+
asr_pipeline_options.asr_options = WHISPER_BASE_EN_S2T
938+
elif asr_model == AsrModelType.WHISPER_SMALL_S2T:
939+
asr_pipeline_options.asr_options = WHISPER_SMALL_S2T
940+
elif asr_model == AsrModelType.WHISPER_SMALL_EN_S2T:
941+
asr_pipeline_options.asr_options = WHISPER_SMALL_EN_S2T
942+
elif asr_model == AsrModelType.WHISPER_DISTIL_SMALL_EN_S2T:
943+
asr_pipeline_options.asr_options = WHISPER_DISTIL_SMALL_EN_S2T
944+
elif asr_model == AsrModelType.WHISPER_MEDIUM_S2T:
945+
asr_pipeline_options.asr_options = WHISPER_MEDIUM_S2T
946+
elif asr_model == AsrModelType.WHISPER_MEDIUM_EN_S2T:
947+
asr_pipeline_options.asr_options = WHISPER_MEDIUM_EN_S2T
948+
elif asr_model == AsrModelType.WHISPER_DISTIL_MEDIUM_EN_S2T:
949+
asr_pipeline_options.asr_options = WHISPER_DISTIL_MEDIUM_EN_S2T
950+
elif asr_model == AsrModelType.WHISPER_LARGE_V3_S2T:
951+
asr_pipeline_options.asr_options = WHISPER_LARGE_V3_S2T
952+
elif asr_model == AsrModelType.WHISPER_DISTIL_LARGE_V3_S2T:
953+
asr_pipeline_options.asr_options = WHISPER_DISTIL_LARGE_V3_S2T
920954
else:
921955
_log.error(f"{asr_model} is not known")
922956
raise ValueError(f"{asr_model} is not known")
923-
924957
_log.debug(f"ASR pipeline_options: {asr_pipeline_options}")
925-
926958
audio_format_option = AudioFormatOption(
927959
pipeline_cls=AsrPipeline,
928960
pipeline_options=asr_pipeline_options,

docling/datamodel/asr_model_specs.py

Lines changed: 145 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
InferenceAsrFramework,
1313
InlineAsrMlxWhisperOptions,
1414
InlineAsrNativeWhisperOptions,
15+
InlineAsrWhisperS2TOptions,
1516
TransformersModelType,
1617
)
1718

@@ -463,9 +464,138 @@ def _get_whisper_turbo_model():
463464
max_time_chunk=30.0,
464465
)
465466

467+
# =============================================================================
# WhisperS2T Models (CTranslate2 backend - fastest option for CPU/CUDA)
# =============================================================================

# NOTE(review): every spec below pins language="en", including the
# multilingual checkpoints (tiny, base, small, medium, large-v3). Confirm
# this is the intended default rather than an oversight — forcing "en" on a
# multilingual model restricts it to English transcription.


def _make_s2t_options(repo_id: str, batch_size: int) -> InlineAsrWhisperS2TOptions:
    """Build a WhisperS2T option set with the defaults shared by all specs.

    Only ``repo_id`` and ``batch_size`` vary between the published specs;
    keeping the shared kwargs here prevents the twelve literals below from
    drifting apart.
    """
    return InlineAsrWhisperS2TOptions(
        repo_id=repo_id,
        inference_framework=InferenceAsrFramework.WHISPER_S2T,
        language="en",
        task="transcribe",
        compute_type="float16",
        batch_size=batch_size,
        beam_size=1,
    )


# Tiny models
WHISPER_TINY_S2T = _make_s2t_options("tiny", batch_size=16)
WHISPER_TINY_EN_S2T = _make_s2t_options("tiny.en", batch_size=16)

# Base models
WHISPER_BASE_S2T = _make_s2t_options("base", batch_size=12)
WHISPER_BASE_EN_S2T = _make_s2t_options("base.en", batch_size=12)

# Small models
WHISPER_SMALL_S2T = _make_s2t_options("small", batch_size=8)
WHISPER_SMALL_EN_S2T = _make_s2t_options("small.en", batch_size=8)
WHISPER_DISTIL_SMALL_EN_S2T = _make_s2t_options("distil-small.en", batch_size=10)

# Medium models
WHISPER_MEDIUM_S2T = _make_s2t_options("medium", batch_size=6)
WHISPER_MEDIUM_EN_S2T = _make_s2t_options("medium.en", batch_size=6)
WHISPER_DISTIL_MEDIUM_EN_S2T = _make_s2t_options("distil-medium.en", batch_size=8)

# Large models
WHISPER_LARGE_V3_S2T = _make_s2t_options("large-v3", batch_size=4)
WHISPER_DISTIL_LARGE_V3_S2T = _make_s2t_options("distil-large-v3", batch_size=6)
595+
466596
# Note: The main WHISPER_* models (WHISPER_TURBO, WHISPER_BASE, etc.) automatically
467597
# select the best implementation (MLX on Apple Silicon, Native elsewhere).
468-
# Use the explicit _MLX or _NATIVE variants if you need to force a specific implementation.
598+
# Use the explicit _MLX, _NATIVE, or _S2T variants if you need to force a specific implementation.
469599

470600

471601
class AsrModelType(str, Enum):
@@ -492,3 +622,17 @@ class AsrModelType(str, Enum):
492622
WHISPER_BASE_NATIVE = "whisper_base_native"
493623
WHISPER_LARGE_NATIVE = "whisper_large_native"
494624
WHISPER_TURBO_NATIVE = "whisper_turbo_native"
625+
626+
# Explicit WhisperS2T models (CTranslate2 backend - fastest)
627+
WHISPER_TINY_S2T = "whisper_tiny_s2t"
628+
WHISPER_TINY_EN_S2T = "whisper_tiny_en_s2t"
629+
WHISPER_BASE_S2T = "whisper_base_s2t"
630+
WHISPER_BASE_EN_S2T = "whisper_base_en_s2t"
631+
WHISPER_SMALL_S2T = "whisper_small_s2t"
632+
WHISPER_SMALL_EN_S2T = "whisper_small_en_s2t"
633+
WHISPER_DISTIL_SMALL_EN_S2T = "whisper_distil_small_en_s2t"
634+
WHISPER_MEDIUM_S2T = "whisper_medium_s2t"
635+
WHISPER_MEDIUM_EN_S2T = "whisper_medium_en_s2t"
636+
WHISPER_DISTIL_MEDIUM_EN_S2T = "whisper_distil_medium_en_s2t"
637+
WHISPER_LARGE_V3_S2T = "whisper_large_v3_s2t"
638+
WHISPER_DISTIL_LARGE_V3_S2T = "whisper_distil_large_v3_s2t"

docling/datamodel/pipeline_options_asr_model.py

Lines changed: 113 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@ class InferenceAsrFramework(str, Enum):
2929
MLX = "mlx"
3030
# TRANSFORMERS = "transformers" # disabled for now
3131
WHISPER = "whisper"
32+
WHISPER_S2T = "whisper_s2t"
3233

3334

3435
class InlineAsrOptions(BaseAsrOptions):
@@ -262,3 +263,115 @@ class InlineAsrMlxWhisperOptions(InlineAsrOptions):
262263
)
263264
),
264265
] = 2.4
266+
267+
268+
class InlineAsrWhisperS2TOptions(InlineAsrOptions):
    """Configuration for WhisperS2T (CTranslate2-based) high-speed ASR.

    Uses the whisper_s2t library with a CTranslate2 backend for fast
    inference on CPU and CUDA devices. Requires the whisper-s2t-reborn
    package to be installed.
    """

    # Framework selector; fixed to the WhisperS2T backend for this class.
    inference_framework: Annotated[
        InferenceAsrFramework,
        Field(
            description=(
                "Inference framework for ASR. Uses WhisperS2T with CTranslate2 "
                "backend for optimized high-speed inference."
            )
        ),
    ] = InferenceAsrFramework.WHISPER_S2T

    # Transcription language (ISO 639-1).
    language: Annotated[
        str,
        Field(
            description=(
                "Language code for transcription. Use ISO 639-1 codes "
                "(e.g., `en`, `es`, `fr`)."
            ),
            examples=["en", "es", "fr", "de", "ja", "zh"],
        ),
    ] = "en"

    # Whisper task: same-language transcription vs. translation to English.
    task: Annotated[
        str,
        Field(
            description=(
                "ASR task type. `transcribe` converts speech to text in the "
                "same language. `translate` converts speech to English text."
            ),
            examples=["transcribe", "translate"],
        ),
    ] = "transcribe"

    # CTranslate2 numeric precision; lower precision trades accuracy for speed.
    compute_type: Annotated[
        str,
        Field(
            description=(
                "Computation precision for CTranslate2. Options: `float32`, "
                "`float16`, `bfloat16`. Lower precision increases speed and "
                "reduces memory. bfloat16 requires compute capability >= 8.6."
            ),
            examples=["float32", "float16", "bfloat16"],
        ),
    ] = "float16"

    # Parallel audio segments per forward pass.
    batch_size: Annotated[
        int,
        Field(
            description=(
                "Number of audio segments to process in parallel. Higher values "
                "increase throughput but require more VRAM."
            )
        ),
    ] = 8

    # Decoding beam width; 1 = greedy.
    beam_size: Annotated[
        int,
        Field(
            description=(
                "Beam size for beam search decoding. 1 = greedy decoding (fastest), "
                "higher values (e.g., 5) may improve accuracy at cost of speed."
            )
        ),
    ] = 1

    # Word-level timestamps require an extra alignment model.
    word_timestamps: Annotated[
        bool,
        Field(
            description=(
                "Generate word-level timestamps. Requires an additional alignment "
                "model and increases processing time."
            )
        ),
    ] = False

    # CPU-only tuning knob.
    cpu_threads: Annotated[
        int,
        Field(
            description=(
                "Number of CPU threads for inference. Only used when device is CPU."
            )
        ),
    ] = 4

    # CTranslate2 worker parallelism.
    num_workers: Annotated[
        int,
        Field(
            description=(
                "Number of parallel workers for CTranslate2."
            )
        ),
    ] = 1

    # Optional conditioning prompt (domain vocabulary, style).
    initial_prompt: Annotated[
        Optional[str],
        Field(
            description=(
                "Optional text prompt to condition the transcription style or "
                "provide context. Useful for domain-specific vocabulary."
            )
        ),
    ] = None

    # Accelerators this backend can run on (no MPS/MLX support).
    supported_devices: Annotated[
        list[AcceleratorDevice],
        Field(
            description=(
                "Hardware accelerators supported by WhisperS2T."
            )
        ),
    ] = [
        AcceleratorDevice.CPU,
        AcceleratorDevice.CUDA,
    ]

0 commit comments

Comments
 (0)