# predict.py (forked from chenxwh/cog-whisper)
"""
Download the model checkpoints into ./weights before running the predictor:
wget https://openaipublic.azureedge.net/main/whisper/models/65147644a518d12f04e32d6f3b26facc3f8dd46e5390956a9424a650c0ce22b9/tiny.pt -P ./weights
wget https://openaipublic.azureedge.net/main/whisper/models/ed3a0b6b1c0edf879ad9b11b1af5a0e6ab5db9205f891f668f8b0e6c6326e34e/base.pt -P ./weights
wget https://openaipublic.azureedge.net/main/whisper/models/9ecf779972d90ba49c06d968637d720dd632c55bbf19d441fb42bf17a411e794/small.pt -P ./weights
wget https://openaipublic.azureedge.net/main/whisper/models/345ae4da62f9b3d59415adc60127b97c714f32e89e936602e85993674d08dcb1/medium.pt -P ./weights
wget https://openaipublic.azureedge.net/main/whisper/models/e4b87e7e0bf463eb8e6956e646f1e277e901512310def2c24bf0e11bd3c28e9a/large.pt -P ./weights
"""
import io
import os
from typing import Optional, Any
import torch
from cog import BasePredictor, Input, Path, BaseModel
import whisper
from whisper.model import Whisper, ModelDimensions
from whisper.tokenizer import LANGUAGES
from whisper.utils import format_timestamp
class ModelOutput(BaseModel):
    """Structured result of one transcription request.

    The optional fields carry explicit ``None`` defaults so the model can be
    built without them regardless of how the underlying validation library
    treats ``Optional`` annotations that have no default.
    """

    # Entry of whisper's LANGUAGES table for the detected language code.
    detected_language: str
    # Transcript rendered in the requested format (plain text, SRT or VTT).
    transcription: str
    # Segment data exactly as produced by the transcription call.
    segments: Any
    # Translated text; None when no translation was requested.
    translation: Optional[str] = None
    # Optional file outputs; None when not provided.
    txt_file: Optional[Path] = None
    srt_file: Optional[Path] = None
class Predictor(BasePredictor):
    """Cog predictor that transcribes, and optionally translates, audio with Whisper."""

    def setup(self):
        """Load every Whisper checkpoint from ``weights/`` into CPU memory.

        The models are built once here and cached in ``self.models`` (keyed by
        size name) so that individual predictions do not pay the load cost.
        """
        self.models = {}
        for model in ["tiny", "base", "small", "medium", "large"]:
            # Use a context manager so the checkpoint file handle is closed
            # as soon as the weights have been deserialised.
            with open(f"weights/{model}.pt", "rb") as fp:
                checkpoint = torch.load(fp, map_location="cpu")
            dims = ModelDimensions(**checkpoint["dims"])
            whisper_model = Whisper(dims)
            whisper_model.load_state_dict(checkpoint["model_state_dict"])
            self.models[model] = whisper_model

    def predict(
        self,
        audio: Path = Input(description="Audio file"),
        model: str = Input(
            default="base",
            choices=["tiny", "base", "small", "medium", "large"],
            description="Choose a Whisper model.",
        ),
        transcription: str = Input(
            choices=["plain text", "srt", "vtt"],
            default="plain text",
            description="Choose the format for the transcription",
        ),
        translate: bool = Input(
            default=False,
            description="Translate the text to English when set to True",
        ),
    ) -> ModelOutput:
        """Run a single prediction on the model.

        Args:
            audio: Path of the audio file to transcribe.
            model: Name of the cached Whisper model to use.
            transcription: Output format: "plain text", "srt" or "vtt".
            translate: When True, run a second pass with ``task="translate"``.

        Returns:
            A ``ModelOutput`` holding the detected language, the formatted
            transcript, the raw segments and, if requested, the translation.
        """
        print(f"Transcribe with {model} model")
        # NOTE(review): Module.to() moves the cached model in place, so each
        # model that has been used stays on the GPU for later requests.
        whisper_model = self.models[model].to("cuda")
        audio_path = str(audio)
        result = whisper_model.transcribe(audio_path)

        if transcription == "plain text":
            transcript_text = result["text"]
        elif transcription == "srt":
            transcript_text = write_srt(result["segments"])
        else:
            transcript_text = write_vtt(result["segments"])

        # Always bound, so the return statement does not depend on `translate`.
        translated_text = None
        if translate:
            translated = whisper_model.transcribe(audio_path, task="translate")
            translated_text = translated["text"]

        return ModelOutput(
            detected_language=LANGUAGES[result["language"]],
            transcription=transcript_text,
            translation=translated_text,
            segments=result["segments"],
        )
def write_vtt(transcript):
    """Render transcript segments as a WebVTT document.

    Args:
        transcript: Iterable of segment mappings with ``start``, ``end`` and
            ``text`` keys.

    Returns:
        The WebVTT text: the ``WEBVTT`` header followed by a blank line, then
        one cue per segment, each followed by a blank line. The blank lines
        are what separate the header and the individual cues from each other.
    """
    cues = [
        f"{format_timestamp(segment['start'])} --> {format_timestamp(segment['end'])}\n"
        # "-->" is the timing separator, so it must not appear in cue text.
        f"{segment['text'].strip().replace('-->', '->')}\n"
        "\n"
        for segment in transcript
    ]
    return "WEBVTT\n\n" + "".join(cues)
def write_srt(transcript):
    """Render transcript segments as SubRip (SRT) text.

    Args:
        transcript: Iterable of segment mappings with ``start``, ``end`` and
            ``text`` keys.

    Returns:
        The SRT text: for each segment a 1-based index line, a timing line
        (hours always included, comma as decimal marker) and the text line,
        followed by a blank line that separates it from the next entry.
        An empty transcript yields an empty string.
    """
    entries = [
        f"{index}\n"
        f"{format_timestamp(segment['start'], always_include_hours=True, decimal_marker=',')} --> "
        f"{format_timestamp(segment['end'], always_include_hours=True, decimal_marker=',')}\n"
        # "-->" is the timing separator, so it must not appear in the text.
        f"{segment['text'].strip().replace('-->', '->')}\n"
        "\n"
        for index, segment in enumerate(transcript, start=1)
    ]
    return "".join(entries)