import base64
import io
import os
from typing import Any, Dict, List

import numpy as np
import torch
from transformers import BarkModel, BarkProcessor
|
class EndpointHandler:
    """Serverless inference handler that converts text to speech with Bark.

    Lifecycle: construction is lightweight; the first request triggers
    ``setup()``, which loads the model and processor. ``__call__`` runs
    preprocess -> inference -> postprocess and returns a JSON-safe dict
    containing base64-encoded WAV audio, or ``{"error": ...}`` on failure.
    """

    def __init__(self, path=""):
        """Record the model path and pick a device; defer heavy loading.

        Args:
            path (str, optional): Model directory or hub id passed to
                ``from_pretrained``. Defaults to "".
        """
        self.path = path
        self.model = None
        self.processor = None
        # Prefer GPU when available; Bark generation is very slow on CPU.
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self.initialized = False

    def setup(self, **kwargs):
        """Load the Bark model and processor onto ``self.device``.

        Args:
            **kwargs: Unused; accepted for handler-framework compatibility.
        """
        self.model = BarkModel.from_pretrained(self.path)
        self.model.to(self.device)
        self.processor = BarkProcessor.from_pretrained(self.path)
        self.initialized = True
        print(f"Bark model loaded on {self.device}")

    def preprocess(self, request: Dict) -> Dict:
        """Extract text and generation options from the raw request.

        Args:
            request (Dict): Payload with ``"inputs"`` (a string, or a list of
                strings of which only the first is used) and an optional
                ``"parameters"`` dict (``speaker_id``/``voice_preset``,
                ``temperature``).

        Returns:
            Dict: Normalized inputs for :meth:`inference`.
        """
        if not self.initialized:
            self.setup()

        inputs: Dict[str, Any] = {}

        raw = request.get("inputs")
        if isinstance(raw, str):
            inputs["text"] = raw
        elif isinstance(raw, list) and raw:
            # Bark generates one clip at a time; batch requests are truncated
            # to their first element. (Guarding non-empty also fixes the
            # IndexError the original raised on an empty list.)
            inputs["text"] = raw[0]

        params = request.get("parameters", {})

        # "speaker_id" and "voice_preset" are both treated as a Bark voice
        # preset name (e.g. "v2/en_speaker_6"); speaker_id wins if both exist.
        if "speaker_id" in params:
            inputs["speaker_id"] = params["speaker_id"]
        elif "voice_preset" in params:
            inputs["voice_preset"] = params["voice_preset"]

        if "temperature" in params:
            # Fix: the original re-read the key with a redundant
            # ``params.get("temperature", 0.7)`` although the key is known
            # to exist inside this branch.
            inputs["temperature"] = params["temperature"]

        return inputs

    def inference(self, inputs: Dict) -> Dict:
        """Generate speech audio for the prepared inputs.

        Args:
            inputs (Dict): Output of :meth:`preprocess`.

        Returns:
            Dict: ``{"audio_array": np.ndarray, "sample_rate": int}`` on
            success, or ``{"error": str}`` when no text was supplied.
        """
        text = inputs.get("text", "")
        if not text:
            return {"error": "No text provided for speech generation"}

        # Either key names a Bark voice preset such as "v2/en_speaker_6".
        voice_preset = inputs.get("speaker_id") or inputs.get("voice_preset")
        temperature = inputs.get("temperature", 0.7)

        # Fix: the voice preset must be given to the *processor* (it is
        # resolved into a history_prompt), not to ``model.generate``; and the
        # processor returns a BatchFeature that has to be unpacked with **
        # rather than passed wholesale as ``input_ids``.
        if voice_preset:
            model_inputs = self.processor(text, voice_preset=voice_preset)
        else:
            model_inputs = self.processor(text)
        model_inputs = model_inputs.to(self.device)

        with torch.no_grad():
            # NOTE(review): un-prefixed generation kwargs such as
            # ``temperature`` are forwarded to all Bark sub-models — confirm
            # against the installed transformers version.
            speech_output = self.model.generate(
                **model_inputs,
                temperature=temperature,
            )

        # (1, n_samples) tensor on the device -> flat float waveform on CPU.
        audio_array = speech_output.cpu().numpy().squeeze()

        return {
            "audio_array": audio_array,
            "sample_rate": self.model.generation_config.sample_rate,
        }

    def postprocess(self, inference_output: Dict) -> Dict:
        """Encode the generated waveform as a base64 WAV payload.

        Args:
            inference_output (Dict): Output of :meth:`inference`.

        Returns:
            Dict: ``{"audio": <base64 wav str>, "sample_rate": int,
            "format": "wav"}`` on success, or ``{"error": str}``.
        """
        if "error" in inference_output:
            return {"error": inference_output["error"]}

        audio_array = inference_output.get("audio_array")
        sample_rate = inference_output.get("sample_rate", 24000)
        if audio_array is None:
            # Robustness fix: the original would crash wav.write with None.
            return {"error": "Error converting audio: no audio data"}

        try:
            # Local import keeps scipy optional until audio is produced.
            import scipy.io.wavfile as wav

            # Robustness fix: Bark emits float32 in [-1, 1]; some players
            # reject IEEE-float WAV files, so clip and convert to 16-bit PCM.
            if np.issubdtype(np.asarray(audio_array).dtype, np.floating):
                audio_array = (
                    np.clip(audio_array, -1.0, 1.0) * 32767
                ).astype(np.int16)

            audio_buffer = io.BytesIO()
            wav.write(audio_buffer, sample_rate, audio_array)
            audio_data = audio_buffer.getvalue()

            audio_base64 = base64.b64encode(audio_data).decode("utf-8")

            return {
                "audio": audio_base64,
                "sample_rate": sample_rate,
                "format": "wav",
            }
        except Exception as e:
            # Best-effort boundary: surface the failure to the caller rather
            # than crashing the endpoint (also covers a missing scipy).
            return {"error": f"Error converting audio: {str(e)}"}

    def __call__(self, data: Dict) -> Dict:
        """Handle one request end to end.

        Args:
            data (Dict): Raw request payload.

        Returns:
            Dict: Base64 WAV response, or ``{"error": str}``.
        """
        if not self.initialized:
            self.setup()

        try:
            inputs = self.preprocess(data)
            outputs = self.inference(inputs)
            return self.postprocess(outputs)
        except Exception as e:
            # Top-level boundary: any unexpected failure becomes a
            # structured error response instead of a 500 with a traceback.
            return {"error": f"Error processing request: {str(e)}"}
| |
|