首页 >  AI教程资讯 > Whisper JAX:让你的语音转文字功能速度快70倍!

Whisper JAX:让你的语音转文字功能速度快70倍!

1374 2024-12-31 00:00:00

项目简介

该存储库包含 OpenAI 的 Whisper 模型的优化 JAX代码,主要基于 Hugging Face Transformers Whisper 实现。与 OpenAI的 PyTorch 代码相比,Whisper JAX 的运行速度快了70 倍以上,使其成为可用的最快的 Whisper 运行。

JAX 代码在 CPU、GPU 和 TPU 上兼容,并且可以独立运行或作为推理端点。

安装

Whisper JAX 使用 Python 3.9 和 JAX 版本 0.4.5 进行了测试。安装假定您的设备上已安装最新版本的 JAX 包。您可以使用官方 JAX 安装指南来执行此操作:

https://github.com/google/jax#installation

一旦安装了适当版本的 JAX,就可以通过 pip 安装 Whisper JAX

pip install git+https://github.com/sanchit-gandhi/whisper-jax.git

要将 Whisper JAX 包更新到最新版本,只需运行:

pip install --upgrade --no-deps --force-reinstall git+https://github.com/sanchit-gandhi/whisper-jax.git

管道使用

运行 Whisper JAX 的推荐方式是通过FlaxWhisperPipline抽象类。此类处理所有必要的预处理和后处理,并包装生成方法以实现跨加速器设备的数据并行性。

Whisper JAX 利用 JAX 的pmap功能实现跨 GPU/TPU 设备的数据并行性。该函数在第一次调用时进行即时 (JIT)编译。此后,该函数将被缓存,使其能够以超快的速度运行:

from whisper_jax import FlaxWhisperPipline# instantiate pipelinepipeline = FlaxWhisperPipline("openai/whisper-large-v2")# JIT compile the forward call - slow, but we only do oncetext = pipeline("audio.mp3")# used cached function thereafter - super fast!!text = pipeline("audio.mp3")

半精度

通过在实例化管道时传递 dtype 参数,可以以半精度运行模型计算。通过以半精度存储中间张量,这将大大加快计算速度。模型权重的精度没有变化。

对于大多数 GPU,dtype 应设置为jnp.float16. 对于 A100 GPU 或 TPU,dtype 应设置为jnp.bfloat16:

from whisper_jax import FlaxWhisperPiplineimport jax.numpy as jnp# instantiate pipeline in bfloat16pipeline = FlaxWhisperPipline("openai/whisper-large-v2", dtype=jnp.bfloat16)

批处理

Whisper JAX 还提供跨加速器设备批量处理单个音频输入的选项。音频首先被分成 30 秒的片段,然后将片段分派到模型进行并行转录。所得到的转录在边界处缝合在一起以给出单个、统一的转录。实际上,如果选择的批处理大小足够大,则与顺序转录音频样本相比,批处理提供了 10 倍的加速,并且 WER 1的损失不到 1%。

要启用批处理,请batch_size在实例化管道时传递参数:

from whisper_jax import FlaxWhisperPipline# instantiate pipeline with batchingpipeline = FlaxWhisperPipline("openai/whisper-large-v2", batch_size=16)

任务

默认情况下,管道以所说的语言转录音频文件。对于语音翻译,请将参数设置 task为"translate":

# translatetext = pipeline("audio.mp3", task="translate")

时间戳

FlaxWhisperPipline还支持时间戳预测。请注意,启用时间戳将需要对前向调用进行第二次 JIT 编译,包括时间戳输出:

# transcribe and return timestampsoutputs = pipeline("audio.mp3", task="transcribe", return_timestamps=True)text = outputs["text"] # transcriptionchunks = outputs["chunks"] # transcription + timestamps

高级用法

更高级的用户可能希望探索不同的并行化技术。Whisper JAX 代码构建在T5x 代码库之上,这意味着它可以使用 T5x 分区约定的模型、激活和数据并行性来运行。要使用 T5x 分区,必须定义逻辑轴规则和模型分区数量。更多详细信息,用户可参考官方T5x分区指南:https://github.com/google-research/t5x/blob/main/docs/usage/partitioning.md

基准测试

我们将 Whisper JAX 与官方OpenAI 实现和

相关常用工具

查看更多

Copyright © 2025 AI图片论坛 版权所有. 站点地图