| import torch |
| from math import gcd |
| from typing import Optional, Union |
| import joblib |
|
|
| import numpy as np |
| from scipy import signal |
|
|
def load_scaler_joblib(path: str) -> tuple[torch.Tensor, torch.Tensor]:
    """
    Load a fitted scaler (e.g. ecg_scaler.pkl) and return its center and scale as tensors.

    The stored object is read following scikit-learn scaler conventions:
    a per-feature offset in ``mean_`` (StandardScaler) or ``center_``
    (RobustScaler), and a per-feature divisor in ``scale_``. When the scaler
    was fitted with centering or scaling disabled, the identity values
    (0 for center, 1 for scale) are returned instead. This matches what the
    scaler's own ``transform`` would apply.

    Security: ``joblib.load`` unpickles the file and can execute arbitrary
    code. Only call this on trusted files.

    Args:
        path: Path to the joblib file.

    Returns:
        center: float32 tensor of per-feature offsets.
        scale: float32 tensor of per-feature divisors, clamped to >= 1e-8
            so later division never hits zero.

    Raises:
        ValueError: if the loaded object exposes neither a center nor a scale.
    """
    sc = joblib.load(path)  # NOTE: pickle-based; never load untrusted files.

    center = getattr(sc, "mean_", None)
    if center is None:
        center = getattr(sc, "center_", None)
    # StandardScaler(with_mean=False) still fits ``mean_`` but never subtracts
    # it, and RobustScaler(with_centering=False) skips centering. In both
    # cases the effective offset is 0.
    if not getattr(sc, "with_mean", True) or not getattr(sc, "with_centering", True):
        center = None
    scale = getattr(sc, "scale_", None)  # None when scaling was disabled

    if center is None and scale is None:
        raise ValueError(f"{path!r}: scaler exposes neither a center nor a scale")

    n_features = len(center) if center is not None else len(scale)
    if center is None:
        center = np.zeros(n_features)
    if scale is None:
        scale = np.ones(n_features)

    center_t = torch.from_numpy(np.asarray(center, dtype=np.float32))
    scale_t = torch.from_numpy(np.asarray(scale, dtype=np.float32)).clamp_min(1e-8)
    return center_t, scale_t
|
|
class ECGTransform:
    """
    Unified ECG preprocessing: optional band-pass filtering + resampling, then per-lead scaling.

    Usage:
        transform = ECGTransform(center, scale, src_fs=512, target_fs=100)
        ecg_out = transform(ecg_in)

    Attributes:
        center: float32 tensor of per-lead offsets subtracted during scaling.
        scale: float32 tensor of per-lead divisors (clamped to >= 1e-8).
            NOTE: this instance attribute would shadow any method called
            ``scale``, which is why the scaling step is named ``normalize``.
        src_fs / target_fs: input and output sampling rates (integers, Hz).
        band: (low, high) filter edges in Hz applied before resampling, or
            None to skip filtering. A low edge <= 0 gives a low-pass filter.
        bp_order: Butterworth order passed to ``scipy.signal.butter``.
        axis: time axis used for filtering and resampling.
    """

    def __init__(
        self,
        center: Union[np.ndarray, torch.Tensor],
        scale: Union[np.ndarray, torch.Tensor],
        src_fs: int = 100,
        target_fs: int = 100,
        band: Optional[tuple[float, float]] = (0.5, 40.0),
        bp_order: int = 4,
        axis: int = -1,
    ) -> None:
        self.center = torch.as_tensor(center, dtype=torch.float32)
        self.scale = torch.as_tensor(scale, dtype=torch.float32).clamp_min(1e-8)
        self.src_fs = src_fs
        self.target_fs = target_fs
        self.band = band
        self.bp_order = bp_order
        self.axis = axis

    def _design_filter(self) -> Optional[np.ndarray]:
        """Return SOS coefficients for the pre-resampling filter, or None if ``band`` is None."""
        if self.band is None:
            return None
        lowcut, highcut = self.band
        # Keep the upper edge below the Nyquist of BOTH rates. The target rate
        # matters for anti-aliasing. The source rate matters because the
        # filter runs at src_fs, so Wn must stay < 1 even when upsampling.
        highcut = min(highcut, 0.45 * min(self.src_fs, self.target_fs))
        nyq = self.src_fs / 2.0
        if lowcut <= 0:
            return signal.butter(self.bp_order, highcut / nyq, btype="low", output="sos")
        if lowcut >= highcut:
            raise ValueError(
                f"band low edge {lowcut} Hz must be below the effective high edge "
                f"{highcut} Hz (src_fs={self.src_fs}, target_fs={self.target_fs})"
            )
        return signal.butter(
            self.bp_order, (lowcut / nyq, highcut / nyq), btype="band", output="sos"
        )

    def downsample(self, x: np.ndarray) -> np.ndarray:
        """
        Filter (when ``band`` is set) and resample ``x`` from src_fs to target_fs along ``axis``.

        Despite the name, this also handles upsampling (target_fs > src_fs).

        Args:
            x: array-like signal; time runs along ``self.axis``.

        Returns:
            np.ndarray resampled by target_fs / src_fs along ``self.axis``.

        Raises:
            ValueError: if the band collapses after clamping the high edge.
        """
        x = np.asarray(x)
        sos = self._design_filter()
        if sos is not None:
            # Zero-phase filtering, so waveform features are not shifted in time.
            x = signal.sosfiltfilt(sos, x, axis=self.axis)
        g = gcd(self.src_fs, self.target_fs)
        up = self.target_fs // g
        down = self.src_fs // g
        return signal.resample_poly(
            x, up, down, axis=self.axis, window=("kaiser", 5.0), padtype="median"
        )

    def normalize(self, ecg: torch.Tensor) -> torch.Tensor:
        """
        Apply per-lead scaling: (ecg - center) / scale, as float32.

        center/scale are reshaped to (leads, 1). The leads must therefore be
        the second-to-last dimension and time the last one, i.e. the layout
        (..., leads, time) with the default ``axis=-1``.
        """
        ecg = ecg.to(torch.float32)
        # Follow the input's device so GPU tensors do not hit a device mismatch.
        center = self.center.to(ecg.device)[:, None]
        scale = self.scale.to(ecg.device)[:, None]
        return (ecg - center) / scale

    def __call__(self, x: np.ndarray) -> torch.Tensor:
        """
        Resample (if the rates differ) and scale ECG data.

        Args:
            x: np.ndarray (or array-like / CPU tensor), shape (leads, time)

        Returns:
            torch.Tensor (float32), shape (leads, time')
        """
        if self.src_fs != self.target_fs:
            x = self.downsample(x)
        if not isinstance(x, torch.Tensor):
            # torch.from_numpy rejects negative-stride views and non-arrays;
            # a contiguous copy (made only when needed) avoids both problems.
            x = torch.from_numpy(np.ascontiguousarray(x))
        return self.normalize(x)