| import numpy as np |
|
|
def solver(u0_batch, t_coordinate, beta):
    """Solve the 1-D linear advection equation u_t + beta * u_x = 0.

    The spatial domain is periodic with unit length: the N grid points are
    uniformly spaced with dx = 1 / N. On such a domain the exact solution is
    a pure translation, u(x, t) = u0(x - beta * t). This function evaluates
    that translation spectrally, by multiplying every Fourier mode of the
    initial condition by the phase factor exp(-2*pi*i*k*beta*t).

    Because there is no time stepping:
      - there is no CFL restriction;
      - no error accumulates from step to step;
      - every output is taken exactly at the requested time.

    The previous forward-Euler / centred-difference scheme was
    unconditionally unstable for advection, so it is no longer used.

    The result is exact for band-limited initial data. Discontinuous initial
    data may show Gibbs oscillations, as with any Fourier method.

    Args:
        u0_batch (np.ndarray): Initial condition [batch_size, N],
            where batch_size is the number of different initial conditions,
            and N is the number of spatial grid points.
        t_coordinate (np.ndarray): Time coordinates of shape [T+1].
            It begins with t_0=0 and follows the time steps t_1, ..., t_T.
        beta (float): Constant advection speed. May be zero or negative.

    Returns:
        solutions (np.ndarray): float64 array of shape [batch_size, T+1, N].
            solutions[:, 0, :] contains the initial conditions (u0_batch),
            solutions[:, i, :] contains the solutions at time t_coordinate[i].

    Raises:
        ValueError: If u0_batch is not 2-D or t_coordinate is not 1-D.
    """
    u0 = np.asarray(u0_batch, dtype=float)
    times = np.asarray(t_coordinate, dtype=float)
    if u0.ndim != 2:
        raise ValueError(
            f"u0_batch must have shape [batch_size, N], got {u0.shape}"
        )
    if times.ndim != 1:
        raise ValueError(
            f"t_coordinate must have shape [T+1], got {times.shape}"
        )

    batch_size, N = u0.shape
    solutions = np.zeros((batch_size, times.size, N))
    if times.size == 0:
        return solutions
    # By contract t_coordinate[0] == 0, so the first slot is the initial data.
    solutions[:, 0, :] = u0
    if N == 0:
        return solutions

    # Integer wavenumbers 0..N//2 of the unit-length periodic domain.
    # rfft keeps only k >= 0 because the spectrum of real data is Hermitian.
    k = np.fft.rfftfreq(N, d=1.0 / N)
    u_hat = np.fft.rfft(u0, axis=-1)

    for i in range(1, times.size):
        # A translation by beta*t in x multiplies mode k by
        # exp(-2*pi*i*k*beta*t). Each time is computed independently,
        # so no error carries over from one output to the next.
        phase = np.exp(-2j * np.pi * k * (beta * times[i]))
        solutions[:, i, :] = np.fft.irfft(u_hat * phase, n=N, axis=-1)

    return solutions