Our Projects
在 Amazon SageMaker 上使用 LoRA 微调 Whisper 模型 机器学习博客
Whisper 是一个自动语音识别ASR模型,使用来自互联网的 68 万小时监督数据进行训练,涵盖多种语言和任务。然而,它在像马拉地语和德拉威语这样的低资源语言上的表现较差,这可以通过微调来改善。然而,微调 Whisper 模型面临着较高的计算和存储要求。对于 Whisper 模型的完整微调,通常需要约 100 小时的 A100 GPU 计算,且每个微调后的检查点需要约 7 GB 的存储空间。这种高昂的计算和存储需求在资源有限的环境中造成了相当大的挑战,往往使得有效结果极难达到。
低秩适应LoRA以独特的方法应对微调挑战。它保持预训练模型权重不变,并在每个 Transformer 结构层中引入可训练的秩分解矩阵。此方法能够将下游任务所需的可训练参数数量减少 10000 倍,同时将 GPU 内存需求降低 3 倍。在模型质量方面,LoRA已经证明出可以匹配甚至超越传统微调方法的表现,同时可训练参数更少。此外,LoRA 还提供了提高训练吞吐量的好处。与适配器方法不同,LoRA 在推理过程中不会引入额外延迟,从而保持模型在部署阶段的高效性。
使用 LoRA 微调 Whisper 模型的结果非常可喜。例如,在一个 12 小时的通用语音数据集上运行 WhisperLargev2 仅需 6 到 8 小时,这比完整微调快 5 倍,同时性能相当。
Amazon SageMaker 是实现 Whisper 微调的理想平台。它允许用户为各种用例构建、训练和部署机器学习模型,配备完全托管的基础设施、工具和工作流程。使用 Managed Spot Training 可降低训练成本,通过分布式训练库可以将模型和训练数据集分配到 AWS GPU 实例上。此外,SageMaker 训练后模型可轻松地进行推理部署。在本文中,我们将提供在 SageMaker 中实现 LoRA 微调的逐步指南。相关的源代码可以在 GitHub找到。
我们在微调任务中使用低资源语言马拉地语。借助 Hugging Face datasets 库,您可以下载并分割 Common Voice 数据集为训练和测试数据集。请参考以下代码:
pythonfrom datasets import loaddataset DatasetDict
language = Marathilanguageabbr = mrtask = transcribedatasetname = mozillafoundation/commonvoice110
commonvoice = DatasetDict()commonvoice[train] = loaddataset(datasetname languageabbr split=trainvalidation useauthtoken=True)commonvoice[test] = loaddataset(datasetname languageabbr split=test useauthtoken=True)
Whisper 语音识别模型要求音频输入为 16kHz 单声道 16 位带符号整数的 WAV 文件。由于 Common Voice 数据集为 48K 采样率,您需要首先将音频文件下采样。之后,您需要应用 Whisper 的特征提取器提取 logmel 谱特征,并使用 Whisper 的分词器将每个句子转换为令牌 ID。以下是相关代码示例:
pythonfrom transformers import WhisperFeatureExtractorfrom transformers import WhisperTokenizer
featureextractor = WhisperFeatureExtractorfrompretrained(modelnameorpath)tokenizer = WhisperTokenizerfrompretrained(modelnameorpath language=language task=task)
def preparedataset(batch) # 加载并将音频数据从 48 kHz 重采样到 16 kHz audio = batch[audio]
飞跃加速器app最新版# 从输入音频数组计算 logMel 输入特征batch[inputfeatures] = featureextractor(audio[array] samplingrate=audio[samplingrate])inputfeatures[0]# 将目标文本编码为标签 IDbatch[labels] = tokenizer(batch[sentence])inputidsreturn batch
commonvoice = commonvoicemap(preparedataset removecolumns=commonvoicecolumnnames[train] numproc=2)commonvoicesavetodisk(marathicommonvoiceprocessed)!aws s3 cp recursive marathicommonvoiceprocessed s3//
在处理完所有训练样本后,将处理过的数据上传到 Amazon S3,从而在微调阶段使用处理过的训练数据时能够使用 FastFile 直接挂载 S3 文件,而不是复制到本地磁盘:
pythonfrom sagemakerinputs import TrainingInputtraininginputpath=s3uritraining = TrainingInput( s3datatype=S3Prefix # 可选项 S3Prefix ManifestFile AugmentedManifestFile s3data=traininginputpath distribution=FullyReplicated # 可选项 FullyReplicated ShardedByS3Key inputmode=FastFile)
为演示 purposes,我们使用 whisperlargev2 作为预训练模型目前已推出 whisper v3,可以通过 Hugging Face transformers 库导入。您还可以使用 8 位量化 进一步提升训练效率,8 位量化通过将浮点数四舍五入为 8 位整型来提供内存优化。这是一种常用的模型压缩技术,可以在不牺牲过多推理精度的前提下减少内存占用。
要以 8 位量化格式加载预训练模型,您只需在实例化模型时添加 loadin8bit=True 参数,如下所示。这将加载量化为 8 位的模型权重,从而减少内存占用。
pythonfrom transformers import WhisperForConditionalGeneration
model = WhisperForConditionalGenerationfrompretrained(modelnameorpath loadin8bit=True devicemap=auto)
我们使用 Hugging Face 的 peft 包中的 LoRA 实现。使用 LoRA 微调模型的步骤有四个:
实例化基础模型如上一步所示。创建一个配置LoraConfig,在其中定义 LoRA 特定参数。使用 getpeftmodel() 将基础模型包装为可训练的 PeftModel。按照基础模型的方式训练 PeftModel。请参考以下代码:
pythonfrom peft import LoraConfig getpeftmodel
config = LoraConfig(r=32 loraalpha=64 targetmodules=[qproj vproj] loradropout=005 bias=none)model = getpeftmodel(model config)
trainingargs = Seq2SeqTrainingArguments( outputdir=argsmodeldir perdevicetrainbatchsize=int(argstrainbatchsize) gradientaccumulationsteps=1 learningrate=float(argslearningrate) warmupsteps=argswarmupsteps numtrainepochs=argsnumtrainepochs evaluationstrategy=epoch fp16=True perdeviceevalbatchsize=argsevalbatchsize generationmaxlength=128 loggingsteps=25 removeunusedcolumns=False labelnames=[labels])trainer = Seq2SeqTrainer( args=trainingargs model=model traindataset=traindataset[train] evaldataset=traindatasetget(test traindataset[test]) datacollator=datacollator tokenizer=processorfeatureextractor)
要运行 SageMaker 训练 作业,我们需要自定义 Docker 容器。您可以从 GitHub 下载 Docker 镜像,其中 ffmpeg4 和 gitlfs 与其他 Python 依赖项一起打包。有关如何调整您的自定义 Docker 容器以与 SageMaker 配合使用的详细信息,请参阅 调整您的自定义训练容器。然后,您可以使用 Hugging Face Estimator 来启动 SageMaker 训练作业:
pythonOUTPUTPATH = fs3//{BUCKET}/{PREFIX}/{TRAININGJOBNAME}/output/
huggingfaceestimator = HuggingFace(entrypoint=trainsh sourcedir=/src outputpath=OUTPUTPATH instancetype=instancetype instancecount=1 # transformersversion=4170 # pytorchversion=1102 pyversion=py310 imageuri= role=ROLE metricdefinitions=metricdefinitions volumesize=200 distribution=distribution keepaliveperiodinseconds=1800 environment=environment)
huggingfaceestimatorfit(jobname=TRAININGJOBNAME wait=False)
LoRA 的实现使我们能够在单个 GPU 实例例如 mlg52xlarge上运行 Whisper 大型模型的微调任务。相比之下,Whisper 大型完整微调任务需要多个 GPU例如 mlp4d24xlarge和更长的训练时间。更具体地说,我们的实验表明,完整的微调任务需要的 GPU 小时是 LoRA 方法的 24 倍。
为了评估微调后 Whisper 模型的性能,我们计算在留出的测试集上的字错误率WER。WER 衡量预测的转录与真实转录之间的差异,较低的 WER 表示更好的性能。您可以运行以下脚本对比预训练模型与微调模型的 WER 差异:
pythonmetric = evaluateload(wer)
evaldataloader = DataLoader(commonvoice[test] batchsize=8 collatefn=datacollator)
modeleval()for step batch in enumerate(tqdm(evaldataloader)) with torchcudaampautocast() with torchnograd() generatedtokens = ( modelgenerate( inputfeatures=batch[inputfeatures]to(cuda) decoderinputids=batch[labels][ 4]to(cuda) maxnewtokens=255 ) cpu() numpy() ) labels = batch[labels]cpu()numpy() labels = npwhere(labels != 100 labels tokenizerpadtokenid) decodedpreds = tokenizerbatchdecode(generatedtokens skipspecialtokens=True) decodedlabels = tokenizerbatchdecode(labels skipspecialtokens=True) metricaddbatch( predictions=decodedpreds references=decodedlabels ) del generatedtokens labels batch gccollect()wer = 100 metriccompute()print(f{wer=})
在本文中,我们演示了如何微调 Whisper,这一先进的语音识别模型。特别地,我们使用了 Hugging Face 的 PEFT LoRA,并启用了 8 位量化以提高训练效率。同时,我们展示了如何在 SageMaker 上运行训练作业。
虽然这是一个重要的第一步,但仍有多种方式可以在此基础上进行进一步改善 Whisper 模型。未来,可以考虑使用 SageMaker 分布式训练,将训练扩展到更大数据集上。这将允许模型在更丰富和全面的数据上进行训练,提高准确性。此外,您还可以优化 Whisper 模型的服务延迟,以实现实时语音识别。此外,还可以考虑扩展处理更长音频转录的工作,这需要对模型架构和训练方案进行更改。
作者对 Paras Mehra、John Sol 和 Evandro Franco 表示感谢,感谢他们对本文的深入反馈和审阅。

Jun Shi 是亚马逊网络服务AWS的高级解决方案架构师。他目前专注于人工智能/机器学习基础设施和应用,拥有十多年金融科技行业作为软件工程师的经验。
Dr Changsha Ma 是 AWS 的 AI/ML 专家。她是计算机科学领域的技术专家,拥有博士学位和教育心理学硕士学位,并在数据科学和人工智能/机器学习独立咨询方面拥有多年经验。她热衷于研究机器智能和人类智能的方法论。在工作之外,她喜欢徒步旅行、烹饪、猎食和与朋友及家人共度时光。
标签:Amazon SageMaker、ASR、Fine Tuning