Instructions to use yangwang825/tdnn-aam with libraries, inference providers, notebooks, and local apps. Follow these links to get started.
- Libraries
- Transformers
How to use yangwang825/tdnn-aam with Transformers:
```python
# Use a pipeline as a high-level helper
from transformers import pipeline

pipe = pipeline("feature-extraction", model="yangwang825/tdnn-aam", trust_remote_code=True)

# Load model directly
from transformers import AutoModel

model = AutoModel.from_pretrained("yangwang825/tdnn-aam", trust_remote_code=True, dtype="auto")
```
- Notebooks
- Google Colab
- Kaggle
| import math | |
| import torch | |
| import torch.nn as nn | |
class Deltas(torch.nn.Module):
    """Computes delta coefficients (time derivatives).

    The deltas are estimated with the regression formula
    ``d[t] = sum_{k=-n..n} k * x[t + k] / (2 * sum_{k=1..n} k**2)``
    with ``n = (window_length - 1) // 2``. The time borders are padded by
    replicating the first/last frame.

    Arguments
    ---------
    input_size : int
        Number of features (last axis of a 3-D input ``(batch, time, feats)``,
        axis 2 of a 4-D input ``(batch, time, feats, channels)``).
    window_length : int
        Length of the window used to compute the time derivatives. Must be
        at least 3; even values are rounded down to the previous odd length.

    Raises
    ------
    ValueError
        If ``window_length`` < 3 (the normalization term would be zero).

    Example
    -------
    >>> inputs = torch.randn([10, 101, 20])
    >>> compute_deltas = Deltas(input_size=inputs.size(-1))
    >>> features = compute_deltas(inputs)
    >>> features.shape
    torch.Size([10, 101, 20])
    """

    def __init__(
        self, input_size, window_length=5,
    ):
        super().__init__()
        if window_length < 3:
            raise ValueError(
                "window_length must be >= 3, got %s" % (window_length,)
            )
        self.n = (window_length - 1) // 2
        # 2 * sum_{k=1..n} k^2 == n (n + 1) (2n + 1) / 3
        self.denom = self.n * (self.n + 1) * (2 * self.n + 1) / 3
        # Kernel [-n, ..., n] replicated once per feature, shape
        # (input_size, 1, 2n + 1), used with a grouped (depthwise) conv1d.
        self.register_buffer(
            "kernel",
            torch.arange(-self.n, self.n + 1, dtype=torch.float32,).repeat(
                input_size, 1, 1
            ),
        )

    def forward(self, x):
        """Returns the delta coefficients.

        Arguments
        ---------
        x : tensor
            A batch of tensors, ``(batch, time, feats)`` or
            ``(batch, time, feats, channels)``.

        Returns
        -------
        tensor
            Delta coefficients with the same shape as ``x``.
        """
        # (batch, time, feats[, channels]) -> (batch, feats[, channels], time)
        x = x.transpose(1, 2).transpose(2, -1)
        or_shape = x.shape
        if len(or_shape) == 4:
            # Fold channels into the batch axis. The reshape keeps every
            # length-T row intact; since all groups share the same kernel and
            # the convolution acts on time only, the row relabeling is
            # harmless and is undone below.
            x = x.reshape(or_shape[0] * or_shape[2], or_shape[1], or_shape[3])

        # Padding for time borders
        x = torch.nn.functional.pad(x, (self.n, self.n), mode="replicate")

        # Derivative estimation (with a fixed convolutional kernel); match the
        # input dtype so float64/float16 inputs do not fail in conv1d.
        kernel = self.kernel.to(device=x.device, dtype=x.dtype)
        delta_coeff = (
            torch.nn.functional.conv1d(x, kernel, groups=x.shape[1])
            / self.denom
        )

        # Retrieving the original dimensionality (for multi-channel case)
        if len(or_shape) == 4:
            delta_coeff = delta_coeff.reshape(
                or_shape[0], or_shape[1], or_shape[2], or_shape[3],
            )
        # Back to (batch, time, feats[, channels])
        delta_coeff = delta_coeff.transpose(1, -1).transpose(2, -1)
        return delta_coeff
class Filterbank(torch.nn.Module):
    """Computes filter bank (FBANK) features given spectral magnitudes.

    Arguments
    ---------
    n_mels : int
        Number of Mel filters used to average the spectrogram.
    log_mel : bool
        If True, it computes the log of the FBANKs (in dB).
    filter_shape : str
        Shape of the filters ('triangular', 'rectangular', 'gaussian').
        Any value other than 'triangular' or 'rectangular' selects the
        gaussian filters.
    f_min : int
        Lowest frequency (Hz) for the Mel filters.
    f_max : int
        Highest frequency (Hz) for the Mel filters.
    n_fft : int
        Number of fft points of the STFT. It defines the frequency resolution
        (n_fft should be <= win_len). The input spectrogram is expected to
        have ``n_fft // 2 + 1`` frequency bins.
    sample_rate : int
        Sample rate of the input audio signal (e.g, 16000).
    power_spectrogram : float
        Exponent used for spectrogram computation. When 2, the dB conversion
        uses ``10 * log10``; otherwise ``20 * log10``.
    amin : float
        Minimum amplitude (used for numerical stability).
    ref_value : float
        Reference value used for the dB scale.
    top_db : float
        Minimum negative cut-off in decibels.
    param_change_factor : float
        If freeze=False, this parameter affects the speed at which the filter
        parameters (i.e., central_freqs and bands) can be changed. When high
        (e.g., param_change_factor=1) the filters change a lot during training.
        When low (e.g. param_change_factor=0.1) the filter parameters are more
        stable during training.
    param_rand_factor : float
        This parameter can be used to randomly change the filter parameters
        (i.e, central frequencies and bands) during training. It is thus a
        sort of regularization. param_rand_factor=0 does not affect, while
        param_rand_factor=0.15 allows random variations within +-15% of the
        standard values of the filter parameters (e.g., if the central freq
        is 100 Hz, we can randomly change it from 85 Hz to 115 Hz).
        Only applied when freeze=True and the module is in training mode.
    freeze : bool
        If False, the central frequency and the band of each filter are
        added into nn.parameters. If True, the standard frozen features
        are computed.

    Raises
    ------
    ValueError
        If ``f_min >= f_max``.

    Example
    -------
    >>> import torch
    >>> compute_fbanks = Filterbank()
    >>> inputs = torch.randn([10, 101, 201])
    >>> features = compute_fbanks(inputs)
    >>> features.shape
    torch.Size([10, 101, 40])
    """

    def __init__(
        self,
        n_mels=40,
        log_mel=True,
        filter_shape="triangular",
        f_min=0,
        f_max=8000,
        n_fft=400,
        sample_rate=16000,
        power_spectrogram=2,
        amin=1e-10,
        ref_value=1.0,
        top_db=80.0,
        param_change_factor=1.0,
        param_rand_factor=0.0,
        freeze=True,
    ):
        super().__init__()
        self.n_mels = n_mels
        self.log_mel = log_mel
        self.filter_shape = filter_shape
        self.f_min = f_min
        self.f_max = f_max
        self.n_fft = n_fft
        self.sample_rate = sample_rate
        self.power_spectrogram = power_spectrogram
        self.amin = amin
        self.ref_value = ref_value
        self.top_db = top_db
        self.freeze = freeze
        self.n_stft = self.n_fft // 2 + 1
        self.db_multiplier = math.log10(max(self.amin, self.ref_value))
        self.device_inp = torch.device("cpu")
        self.param_change_factor = param_change_factor
        self.param_rand_factor = param_rand_factor

        # 10*log10 for power spectrograms, 20*log10 otherwise
        if self.power_spectrogram == 2:
            self.multiplier = 10
        else:
            self.multiplier = 20

        # f_min >= f_max would give zero- or negative-width bands below, i.e.
        # degenerate filters (division by a zero band in the filter shapes).
        if self.f_min >= self.f_max:
            raise ValueError(
                "Require f_min: %f < f_max: %f" % (self.f_min, self.f_max)
            )

        # Filter definition: n_mels + 2 points evenly spaced on the Mel scale
        mel = torch.linspace(
            self._to_mel(self.f_min), self._to_mel(self.f_max), self.n_mels + 2
        )
        hz = self._to_hz(mel)

        # Computation of the filter bands (distance to the next Mel point)
        band = hz[1:] - hz[:-1]
        self.band = band[:-1]
        self.f_central = hz[1:-1]

        # Adding the central frequency and the band to the list of nn param
        if not self.freeze:
            self.f_central = torch.nn.Parameter(
                self.f_central / (self.sample_rate * self.param_change_factor)
            )
            self.band = torch.nn.Parameter(
                self.band / (self.sample_rate * self.param_change_factor)
            )

        # Frequency axis, replicated for all the filters: (n_mels, n_stft).
        # Kept as a plain attribute (not a buffer) so the state dict is
        # unchanged; it is moved to the filters' device when used.
        all_freqs = torch.linspace(0, self.sample_rate // 2, self.n_stft)
        self.all_freqs_mat = all_freqs.repeat(self.f_central.shape[0], 1)

    def forward(self, spectrogram):
        """Returns the FBANKs.

        Arguments
        ---------
        spectrogram : tensor
            Spectral magnitudes with ``n_stft`` frequency bins on the last
            axis: ``(time, n_stft)``, ``(batch, time, n_stft)`` or
            ``(batch, time, n_stft, channels)``.

        Returns
        -------
        tensor
            FBANKs with the frequency axis replaced by ``n_mels``.
        """
        # Central frequency and bandwidth of each filter, broadcast over all
        # STFT bins: (n_mels, n_stft)
        n_bins = self.all_freqs_mat.shape[1]
        f_central_mat = self.f_central.repeat(n_bins, 1).transpose(0, 1)
        band_mat = self.band.repeat(n_bins, 1).transpose(0, 1)

        if not self.freeze:
            # NOTE(review): __init__ divides by sample_rate * param_change_factor
            # but this multiplies by sample_rate * param_change_factor**2, so the
            # effective frequencies are scaled by param_change_factor. Kept as-is
            # for compatibility with trained models — confirm intent.
            f_central_mat = f_central_mat * (
                self.sample_rate
                * self.param_change_factor
                * self.param_change_factor
            )
            band_mat = band_mat * (
                self.sample_rate
                * self.param_change_factor
                * self.param_change_factor
            )
        # Regularization with random changes of filter central frequency and band
        elif self.param_rand_factor != 0 and self.training:
            rand_change = (
                1.0
                + torch.rand(2) * 2 * self.param_rand_factor
                - self.param_rand_factor
            )
            f_central_mat = f_central_mat * rand_change[0]
            band_mat = band_mat * rand_change[1]

        fbank_matrix = self._create_fbank_matrix(f_central_mat, band_mat)
        # Match the input's device and (floating) dtype for the matmul below.
        target_dtype = (
            spectrogram.dtype
            if spectrogram.is_floating_point()
            else fbank_matrix.dtype
        )
        fbank_matrix = fbank_matrix.to(
            device=spectrogram.device, dtype=target_dtype
        )

        sp_shape = spectrogram.shape
        # Multi-channel case (batch, time, freq, channels):
        # fold channels into the batch -> (batch * channels, time, freq)
        if len(sp_shape) == 4:
            spectrogram = spectrogram.permute(0, 3, 1, 2)
            spectrogram = spectrogram.reshape(
                sp_shape[0] * sp_shape[3], sp_shape[1], sp_shape[2]
            )

        # FBANK computation: (..., n_stft) @ (n_stft, n_mels)
        fbanks = torch.matmul(spectrogram, fbank_matrix)
        if self.log_mel:
            fbanks = self._amplitude_to_DB(fbanks)

        # Back to (batch, time, n_mels, channels) for multi-channel inputs
        if len(sp_shape) == 4:
            fb_shape = fbanks.shape
            fbanks = fbanks.reshape(
                sp_shape[0], sp_shape[3], fb_shape[1], fb_shape[2]
            )
            fbanks = fbanks.permute(0, 2, 3, 1)

        return fbanks

    @staticmethod
    def _to_mel(hz):
        """Returns the mel-frequency value corresponding to the input
        frequency value in Hz.

        Arguments
        ---------
        hz : float
            The frequency point in Hz.
        """
        return 2595 * math.log10(1 + hz / 700)

    @staticmethod
    def _to_hz(mel):
        """Returns the Hz-frequency value corresponding to the input
        mel-frequency value.

        Arguments
        ---------
        mel : float or Tensor
            The frequency point(s) in the mel-scale.
        """
        return 700 * (10 ** (mel / 2595) - 1)

    def _triangular_filters(self, all_freqs, f_central, band):
        """Returns fbank matrix using triangular filters.

        Each filter peaks at ``f_central`` and reaches zero at
        ``f_central +- band``.

        Arguments
        ---------
        all_freqs : Tensor
            Tensor gathering all the frequency points.
        f_central : Tensor
            Tensor gathering central frequencies of each filter.
        band : Tensor
            Tensor gathering the bands of each filter.

        Returns
        -------
        Tensor
            ``(n_stft, n_mels)`` filter matrix.
        """
        # Computing the slopes of the filters
        slope = (all_freqs - f_central) / band
        left_side = slope + 1.0
        right_side = -slope + 1.0
        # clamp(min=0) zeroes negative values without needing a separate
        # zero tensor (which would be pinned to one device).
        return torch.min(left_side, right_side).clamp(min=0).transpose(0, 1)

    def _rectangular_filters(self, all_freqs, f_central, band):
        """Returns fbank matrix using rectangular filters.

        Each filter is 1 on ``[f_central - band, f_central + band]`` and 0
        elsewhere.

        Arguments
        ---------
        all_freqs : Tensor
            Tensor gathering all the frequency points.
        f_central : Tensor
            Tensor gathering central frequencies of each filter.
        band : Tensor
            Tensor gathering the bands of each filter.
        """
        # cut-off frequencies of the filters
        low_hz = f_central - band
        high_hz = f_central + band
        inside = all_freqs.ge(low_hz) & all_freqs.le(high_hz)
        return inside.float().transpose(0, 1)

    def _gaussian_filters(
        self, all_freqs, f_central, band, smooth_factor=torch.tensor(2)
    ):
        """Returns fbank matrix using gaussian filters.

        Arguments
        ---------
        all_freqs : Tensor
            Tensor gathering all the frequency points.
        f_central : Tensor
            Tensor gathering central frequencies of each filter.
        band : Tensor
            Tensor gathering the bands of each filter.
        smooth_factor: Tensor
            Smoothing factor of the gaussian filter. It can be used to employ
            sharper or flatter filters (std = band / smooth_factor).
        """
        fbank_matrix = torch.exp(
            -0.5 * ((all_freqs - f_central) / (band / smooth_factor)) ** 2
        ).transpose(0, 1)
        return fbank_matrix

    def _create_fbank_matrix(self, f_central_mat, band_mat):
        """Returns the fbank matrix used to average the spectrum with the
        set of filter-banks, according to ``self.filter_shape``.

        Arguments
        ---------
        f_central_mat : Tensor
            ``(n_mels, n_stft)`` central frequencies of each filter.
        band_mat : Tensor
            ``(n_mels, n_stft)`` bands of each filter.
        """
        # all_freqs_mat is not a registered buffer; align it with the
        # (possibly learnable, possibly moved) filter parameters.
        all_freqs = self.all_freqs_mat.to(f_central_mat.device)
        if self.filter_shape == "triangular":
            return self._triangular_filters(all_freqs, f_central_mat, band_mat)
        if self.filter_shape == "rectangular":
            return self._rectangular_filters(all_freqs, f_central_mat, band_mat)
        return self._gaussian_filters(all_freqs, f_central_mat, band_mat)

    def _amplitude_to_DB(self, x):
        """Converts linear-FBANKs to log-FBANKs (dB), clipped to ``top_db``
        below the per-sequence maximum.

        Arguments
        ---------
        x : Tensor
            A batch of linear FBANK tensors whose last two axes are
            (time, frequency).
        """
        x_db = self.multiplier * torch.log10(torch.clamp(x, min=self.amin))
        x_db -= self.multiplier * self.db_multiplier

        # dB floor: max over time and frequency of each sequence minus top_db.
        # keepdim keeps it broadcastable for any number of leading dims.
        new_x_db_max = x_db.amax(dim=(-2, -1), keepdim=True) - self.top_db
        x_db = torch.max(x_db, new_x_db_max)
        return x_db
class STFT(torch.nn.Module):
    """Computes the Short-Term Fourier Transform (STFT).

    This class computes the Short-Term Fourier Transform of an audio signal.
    It supports multi-channel audio inputs (batch, time, channels).

    Arguments
    ---------
    sample_rate : int
        Sample rate of the input audio signal (e.g 16000).
    win_length : float
        Length (in ms) of the sliding window used to compute the STFT.
    hop_length : float
        Length (in ms) of the hop of the sliding window used to compute
        the STFT.
    n_fft : int
        Number of fft points of the STFT. It defines the frequency resolution
        (n_fft should be >= the window length in samples).
    window_fn : function
        A function that takes an integer (number of samples) and outputs a
        tensor to be multiplied with each window before fft.
    normalized_stft : bool
        Forwarded to ``torch.stft(normalized=...)``: if True, the normalized
        STFT results are returned (default is False).
    center : bool
        If True (default), the input will be padded on both sides so that the
        t-th frame is centered at time t×hop_length. Otherwise, the t-th frame
        begins at time t×hop_length.
    pad_mode : str
        Padding used when ``center`` is True: 'constant' (default, zero
        padding), 'reflect' (reflection of the input boundary), 'replicate'
        (replication of the input boundary) or 'circular' (circular
        replication).
    onesided : bool
        If True (default) only returns ``n_fft // 2 + 1`` frequency bins.
        The other bins are redundant due to the Fourier transform conjugate
        symmetry.

    Example
    -------
    >>> import torch
    >>> compute_STFT = STFT(
    ...     sample_rate=16000, win_length=25, hop_length=10, n_fft=400
    ... )
    >>> inputs = torch.randn([10, 16000])
    >>> features = compute_STFT(inputs)
    >>> features.shape
    torch.Size([10, 101, 201, 2])
    """

    def __init__(
        self,
        sample_rate,
        win_length=25,
        hop_length=10,
        n_fft=400,
        window_fn=torch.hamming_window,
        normalized_stft=False,
        center=True,
        pad_mode="constant",
        onesided=True,
    ):
        super().__init__()
        self.sample_rate = sample_rate
        self.win_length = win_length
        self.hop_length = hop_length
        self.n_fft = n_fft
        self.normalized_stft = normalized_stft
        self.center = center
        self.pad_mode = pad_mode
        self.onesided = onesided

        # Convert win_length and hop_length from ms to samples
        self.win_length = int(
            round((self.sample_rate / 1000.0) * self.win_length)
        )
        self.hop_length = int(
            round((self.sample_rate / 1000.0) * self.hop_length)
        )

        self.window = window_fn(self.win_length)

    def forward(self, x):
        """Returns the STFT generated from the input waveforms.

        Arguments
        ---------
        x : tensor
            A batch of audio signals to transform, ``(batch, time)`` or
            ``(batch, time, channels)``.

        Returns
        -------
        tensor
            ``(batch, frames, freq, 2)`` or ``(batch, frames, freq, 2,
            channels)``; the size-2 axis holds the real and imaginary parts.
        """
        # Multi-channel input: fold channels into the batch -> (batch*ch, time)
        or_shape = x.shape
        if len(or_shape) == 3:
            x = x.transpose(1, 2)
            x = x.reshape(or_shape[0] * or_shape[2], or_shape[1])

        # Align the window with the signal's device and (floating) dtype.
        window_dtype = x.dtype if x.is_floating_point() else self.window.dtype
        window = self.window.to(device=x.device, dtype=window_dtype)

        stft = torch.stft(
            x,
            n_fft=self.n_fft,
            hop_length=self.hop_length,
            win_length=self.win_length,
            window=window,
            center=self.center,
            pad_mode=self.pad_mode,
            normalized=self.normalized_stft,
            onesided=self.onesided,
            return_complex=True,
        )
        # (batch*ch, freq, frames) complex -> (batch*ch, freq, frames, 2) real
        stft = torch.view_as_real(stft)

        if len(or_shape) == 3:
            # (batch, ch, freq, frames, 2) -> (batch, frames, freq, 2, ch)
            stft = stft.reshape(
                or_shape[0],
                or_shape[2],
                stft.shape[1],
                stft.shape[2],
                stft.shape[3],
            )
            stft = stft.permute(0, 3, 2, 4, 1)
        else:
            # (batch, freq, frames, 2) -> (batch, frames, freq, 2)
            stft = stft.transpose(2, 1)

        return stft
def spectral_magnitude(
    stft, power: float = 1, log: bool = False, eps: float = 1e-14
):
    """Returns the magnitude of a complex spectrogram.

    Arguments
    ---------
    stft : torch.Tensor
        A tensor, output from the stft function: either a real tensor whose
        last axis (size 2) holds the real and imaginary parts, or a complex
        tensor.
    power : float
        What power to use in computing the magnitude.
        Use power=1 for the power spectrogram.
        Use power=0.5 for the magnitude spectrogram.
    log : bool
        Whether to apply log to the spectral features.
    eps : float
        Small constant added before ``pow`` (when power < 1) and before
        ``log`` to avoid NaN/-inf for zero-energy bins.

    Returns
    -------
    torch.Tensor
        ``(re**2 + im**2) ** power`` (optionally logged). For real inputs the
        last axis is reduced away.

    Example
    -------
    >>> a = torch.Tensor([[3, 4]])
    >>> spectral_magnitude(a, power=0.5)
    tensor([5.])
    """
    if torch.is_complex(stft):
        # Avoid abs(): its gradient is undefined at 0.
        spectr = stft.real.pow(2) + stft.imag.pow(2)
    else:
        spectr = stft.pow(2).sum(-1)

    # Add eps avoids NaN when spectr is zero
    if power < 1:
        spectr = spectr + eps
    spectr = spectr.pow(power)

    if log:
        return torch.log(spectr + eps)
    return spectr
class ContextWindow(torch.nn.Module):
    """Computes the context window.

    This class applies a context window by gathering multiple time steps
    in a single feature vector. The operation is performed with a
    convolutional layer based on a fixed kernel designed for that. Frames
    outside the sequence are zero-padded. The output feature axis is
    feature-major: for each input feature, its ``left_frames + 1 +
    right_frames`` context values (oldest to newest) are contiguous.

    Arguments
    ---------
    left_frames : int
        Number of left frames (i.e, past frames) to collect.
    right_frames : int
        Number of right frames (i.e, future frames) to collect.

    Example
    -------
    >>> import torch
    >>> compute_cw = ContextWindow(left_frames=5, right_frames=5)
    >>> inputs = torch.randn([10, 101, 20])
    >>> features = compute_cw(inputs)
    >>> features.shape
    torch.Size([10, 101, 220])
    """

    def __init__(
        self, left_frames=0, right_frames=0,
    ):
        super().__init__()
        self.left_frames = left_frames
        self.right_frames = right_frames
        self.context_len = self.left_frames + self.right_frames + 1
        self.kernel_len = 2 * max(self.left_frames, self.right_frames) + 1

        # Base kernel (context_len, kernel_len): row i selects one frame
        # offset. With symmetric padding of max(left, right), the identity
        # covers offsets -max(..) ..; rolling by (right - left) shifts it so
        # the rows cover exactly -left_frames .. +right_frames.
        kernel = torch.eye(self.context_len, self.kernel_len)
        if self.right_frames > self.left_frames:
            lag = self.right_frames - self.left_frames
            kernel = torch.roll(kernel, lag, 1)
        self.kernel = kernel

    def _expanded_kernel(self, n_feats, like):
        """Replicates the base kernel for a grouped conv1d over ``n_feats``
        features -> ``(n_feats * context_len, 1, kernel_len)``, on the same
        device/dtype as ``like``."""
        return (
            self.kernel.repeat(n_feats, 1, 1)
            .view(n_feats * self.context_len, self.kernel_len)
            .unsqueeze(1)
            .to(device=like.device, dtype=like.dtype)
        )

    def forward(self, x):
        """Returns the tensor with the surrounding context.

        Arguments
        ---------
        x : tensor
            A batch of tensors, ``(batch, time, feats)`` or
            ``(batch, time, feats, channels)``.

        Returns
        -------
        tensor
            ``(batch, time, feats * context_len)`` or
            ``(batch, time, feats * context_len, channels)``.
        """
        or_shape = x.shape
        multi_channel = len(or_shape) == 4
        if multi_channel:
            # (batch, time, feats, ch) -> (batch*ch, feats, time): time must be
            # the last axis so conv1d slides over it.
            x = x.permute(0, 3, 2, 1).reshape(
                or_shape[0] * or_shape[3], or_shape[2], or_shape[1]
            )
        else:
            # (batch, time, feats) -> (batch, feats, time)
            x = x.transpose(1, 2)

        # Built per call so inputs with any feature size are supported.
        n_feats = x.shape[1]
        cw_x = torch.nn.functional.conv1d(
            x,
            self._expanded_kernel(n_feats, x),
            groups=n_feats,
            padding=max(self.left_frames, self.right_frames),
        )

        if multi_channel:
            # (batch*ch, feats*ctx, time) -> (batch, time, feats*ctx, ch)
            cw_x = cw_x.reshape(
                or_shape[0], or_shape[3], cw_x.shape[1], cw_x.shape[2]
            ).permute(0, 3, 2, 1)
        else:
            cw_x = cw_x.transpose(1, 2)
        return cw_x
class Fbank(torch.nn.Module):
    """Computes filter bank features from raw waveforms.

    Processing chain: ``STFT`` -> ``spectral_magnitude`` (power spectrum)
    -> ``Filterbank``; then, optionally, first- and second-order ``Deltas``
    concatenated on dimension 2, and a ``ContextWindow``.

    Arguments
    ---------
    deltas : bool
        If True, append first- and second-order delta coefficients.
    context : bool
        If True, apply a context window of ``left_frames``/``right_frames``.
    requires_grad : bool
        If True, the filter bank parameters are learnable (passed to
        Filterbank as ``freeze=not requires_grad``).
    sample_rate : int
        Sample rate of the input audio signal.
    f_min : int
        Lowest frequency for the Mel filters.
    f_max : int or None
        Highest frequency for the Mel filters; ``sample_rate / 2`` if None.
    n_fft : int
        Number of fft points of the STFT.
    n_mels : int
        Number of Mel filters.
    filter_shape : str
        Filter shape forwarded to Filterbank.
    param_change_factor : float
        Forwarded to Filterbank.
    param_rand_factor : float
        Forwarded to Filterbank.
    left_frames : int
        Past frames gathered by the context window.
    right_frames : int
        Future frames gathered by the context window.
    win_length : float
        STFT window length in ms (forwarded to STFT).
    hop_length : float
        STFT hop length in ms (forwarded to STFT).
    """

    def __init__(
        self,
        deltas=False,
        context=False,
        requires_grad=False,
        sample_rate=16000,
        f_min=0,
        f_max=None,
        n_fft=400,
        n_mels=40,
        filter_shape="triangular",
        param_change_factor=1.0,
        param_rand_factor=0.0,
        left_frames=5,
        right_frames=5,
        win_length=25,
        hop_length=10,
    ):
        super().__init__()
        self.deltas = deltas
        self.context = context
        self.requires_grad = requires_grad

        upper_freq = sample_rate / 2 if f_max is None else f_max

        # Submodule attribute names are kept stable: they key the state dict.
        self.compute_STFT = STFT(
            sample_rate=sample_rate,
            n_fft=n_fft,
            win_length=win_length,
            hop_length=hop_length,
        )
        self.compute_fbanks = Filterbank(
            sample_rate=sample_rate,
            n_fft=n_fft,
            n_mels=n_mels,
            f_min=f_min,
            f_max=upper_freq,
            freeze=not requires_grad,
            filter_shape=filter_shape,
            param_change_factor=param_change_factor,
            param_rand_factor=param_rand_factor,
        )
        self.compute_deltas = Deltas(input_size=n_mels)
        self.context_window = ContextWindow(
            left_frames=left_frames, right_frames=right_frames,
        )

    def forward(self, wav):
        """Returns a set of features generated from the input waveforms.

        Arguments
        ---------
        wav : tensor
            A batch of audio signals to transform to features.
        """
        spec = self.compute_STFT(wav)
        power_spec = spectral_magnitude(spec)
        feats = self.compute_fbanks(power_spec)

        if self.deltas:
            first_order = self.compute_deltas(feats)
            second_order = self.compute_deltas(first_order)
            feats = torch.cat([feats, first_order, second_order], dim=2)

        if self.context:
            feats = self.context_window(feats)

        return feats