diff --git a/server/conversion.py b/server/conversion.py index 0998922..5b1d4f0 100644 --- a/server/conversion.py +++ b/server/conversion.py @@ -1,7 +1,9 @@ import pydetex.pipelines +import re import PyPDF2 import docx -from transformers import TrOCRProcessor, VisionEncoderDecoderModel +import torch +from transformers import DonutProcessor, VisionEncoderDecoderModel import requests from PIL import Image @@ -65,20 +67,40 @@ def jpg_to_txt(jpg_filename, output_filename): # file.write(generated_text) -def written_jpg_to_txt(written_jpg_filename, output_filename): - image = Image.open(written_jpg_filename).convert("RGB") +def written_jpg_to_txt(written_jpg_filename): + # Load Donut Processor and Model + processor = DonutProcessor.from_pretrained("naver-clova-ix/donut-base-finetuned-cord-v2") + model = VisionEncoderDecoderModel.from_pretrained("naver-clova-ix/donut-base-finetuned-cord-v2") - processor = TrOCRProcessor.from_pretrained('microsoft/trocr-large-handwritten') - model = VisionEncoderDecoderModel.from_pretrained('microsoft/trocr-large-handwritten') - pixel_values = processor(images=image, return_tensors="pt").pixel_values + device = "cuda" if torch.cuda.is_available() else "cpu" + model.to(device) - generated_ids = model.generate(pixel_values) - generated_text = processor.batch_decode(generated_ids, skip_special_tokens=True)[0] + # Load the image using PIL + image = Image.open(written_jpg_filename).convert("RGB") - return generated_text - # Print and save the generated text - # print("Generated Text: ", generated_text) + # Prepare decoder inputs + task_prompt = "" + decoder_input_ids = processor.tokenizer(task_prompt, add_special_tokens=False, return_tensors="pt").input_ids - # # Store the text in a file named output.txt - # with open(output_filename, "w", encoding="utf-8") as file: - # file.write(generated_text) \ No newline at end of file + # Prepare the image + pixel_values = processor(image, return_tensors="pt").pixel_values + + # Generate output from the model + outputs = model.generate( + pixel_values.to(device), + decoder_input_ids=decoder_input_ids.to(device), + max_length=model.decoder.config.max_position_embeddings, + pad_token_id=processor.tokenizer.pad_token_id, + eos_token_id=processor.tokenizer.eos_token_id, + use_cache=True, + bad_words_ids=[[processor.tokenizer.unk_token_id]], + return_dict_in_generate=True, + ) + + sequence = processor.batch_decode(outputs.sequences)[0] + sequence = sequence.replace(processor.tokenizer.eos_token, "").replace(processor.tokenizer.pad_token, "") + sequence = re.sub(r"<.*?>", "", sequence, count=1).strip() + sequence = re.sub(r"", "", sequence).strip() + sequence = re.sub(r"", "", sequence).strip() + + return sequence \ No newline at end of file diff --git a/server/requirements.txt b/server/requirements.txt index b5ae12b..939b367 100644 --- a/server/requirements.txt +++ b/server/requirements.txt @@ -7,6 +7,7 @@ certifi==2024.8.30 charset-normalizer==3.3.2 click==8.1.7 colour==0.1.5 +datasets==3.0.1 distro==1.9.0 docx==0.2.4 exceptiongroup==1.2.2