| | from typing import Dict, List, Any |
| | import torch |
| | from transformers import PegasusForConditionalGeneration, PegasusTokenizer |
| | import re |
| |
|
class EndpointHandler:
    """
    Paraphrase free-form text sentence by sentence with a Pegasus model.

    The input is split into paragraphs, each paragraph into sentences, and
    every sufficiently long sentence is run through the model. Paragraph
    boundaries are re-created in the output with a blank line.
    """

    # Token limit applied both when encoding the input and when generating.
    _MAX_TOKENS = 80
    # Sentences with fewer whitespace-separated words than this are kept as-is.
    _MIN_WORDS = 3
    # A paraphrase containing any of these (case-insensitive) is discarded in
    # favour of the original sentence.
    # NOTE(review): presumably filler the model tends to emit - confirm.
    _REJECTED_PHRASES = ("it's like", "in other words")

    # Two or more consecutive line breaks, tolerating CRLF endings and
    # "blank" lines that only hold spaces or tabs.
    _PARAGRAPH_BREAK = re.compile(r'(?:\r?\n[ \t]*){2,}')
    # Whitespace that directly follows sentence-ending punctuation.
    _SENTENCE_BREAK = re.compile(r'(?<=[.!?])\s+')

    def __init__(self, path=""):
        """
        Initialize the endpoint handler with the model and tokenizer.

        :param path: Path to the model weights
        """
        self.torch_device = 'cuda' if torch.cuda.is_available() else 'cpu'

        self.tokenizer = PegasusTokenizer.from_pretrained(path)
        self.model = PegasusForConditionalGeneration.from_pretrained(path).to(self.torch_device)

    def split_into_paragraphs(self, text: str) -> List[str]:
        """
        Split text into paragraphs on blank lines.

        Paragraphs are stripped of surrounding whitespace and empty ones are
        dropped. Both LF and CRLF line endings are recognised, and a line
        that contains only spaces or tabs counts as blank.

        :param text: Input text
        :return: List of non-empty paragraphs
        """
        stripped = (chunk.strip() for chunk in self._PARAGRAPH_BREAK.split(text))
        return [chunk for chunk in stripped if chunk]

    def split_into_sentences(self, paragraph: str) -> List[str]:
        """
        Split paragraph into sentences using regex.

        A boundary is any whitespace run that follows ``.``, ``!`` or ``?``;
        the punctuation stays attached to the preceding sentence.

        :param paragraph: Input paragraph
        :return: List of non-empty sentences
        """
        stripped = (piece.strip() for piece in self._SENTENCE_BREAK.split(paragraph))
        return [piece for piece in stripped if piece]

    def get_response(self, input_text: str, num_return_sequences: int = 1) -> str:
        """
        Generate paraphrased text for a single input.

        The input is truncated to 80 tokens before generation. Even when
        several sequences are requested, only the first decoded sequence is
        returned.

        :param input_text: Input sentence to paraphrase
        :param num_return_sequences: Number of alternative paraphrases to generate
        :return: Paraphrased text
        """
        # Calling the tokenizer directly replaces the deprecated
        # ``prepare_seq2seq_batch`` helper; the arguments are unchanged.
        batch = self.tokenizer(
            [input_text],
            truncation=True,
            padding='longest',
            max_length=self._MAX_TOKENS,
            return_tensors="pt",
        ).to(self.torch_device)

        translated = self.model.generate(
            **batch,
            num_beams=10,
            num_return_sequences=num_return_sequences,
            temperature=1.0,
            repetition_penalty=2.8,
            length_penalty=1.2,
            max_length=self._MAX_TOKENS,
            min_length=5,
            no_repeat_ngram_size=3,
        )

        tgt_text = self.tokenizer.batch_decode(translated, skip_special_tokens=True)
        return tgt_text[0]

    def _paraphrase_sentence(self, sentence: str) -> str:
        """Return a paraphrase of *sentence*, or *sentence* itself as a fallback."""
        # Very short fragments are passed through without calling the model.
        if len(sentence.split()) < self._MIN_WORDS:
            return sentence

        try:
            paraphrased = self.get_response(sentence)
        except Exception as e:
            # Deliberate best-effort: one failing sentence must not fail the
            # whole request, so report the error and keep the original text.
            print(f"Error processing sentence: {e}")
            return sentence

        lowered = paraphrased.lower()
        if any(phrase in lowered for phrase in self._REJECTED_PHRASES):
            return sentence
        return paraphrased

    def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
        """
        Process the incoming request and generate paraphrased text.

        :param data: Request payload; the text is read from the ``"inputs"`` key
        :return: Dict with the paraphrased text under the ``"outputs"`` key
        :raises ValueError: If the payload has no string under ``"inputs"``
        """
        # ``get`` rather than ``pop`` so the caller's payload is left intact.
        # When the key is missing the payload itself is used, which then
        # fails the string check below.
        inputs = data.get("inputs", data)

        if not isinstance(inputs, str):
            raise ValueError("Input must be a string")

        paraphrased_paragraphs = [
            ' '.join(
                self._paraphrase_sentence(sentence)
                for sentence in self.split_into_sentences(paragraph)
            )
            for paragraph in self.split_into_paragraphs(inputs)
        ]

        return {"outputs": '\n\n'.join(paraphrased_paragraphs)}