| |
| |
|
|
| import torch |
| from torch.utils.data import Dataset, DataLoader |
| import os |
| import zipfile |
| import urllib.request |
| import numpy as np |
|
|
| class Enwik8Dataset(Dataset): |
| """ |
| Dataset for enwik8 (Hutter Prize). |
| Downloads and loads the first 100MB of Wikipedia XML dump. |
| """ |
| URL = "http://mattmahoney.net/dc/enwik8.zip" |
| FILE_NAME = "enwik8" |
| |
| def __init__(self, data_dir: str, seq_len: int = 1024, split: str = 'train'): |
| self.seq_len = seq_len |
| self.data_dir = data_dir |
| self.file_path = os.path.join(data_dir, self.FILE_NAME) |
| |
| if not os.path.exists(self.file_path): |
| self._download_and_extract() |
| |
| |
| with open(self.file_path, 'rb') as f: |
| data = np.frombuffer(f.read(), dtype=np.uint8) |
| |
| |
| n = len(data) |
| tr_split = int(n * 0.9) |
| val_split = int(n * 0.95) |
| |
| if split == 'train': |
| self.data = data[:tr_split] |
| elif split == 'val': |
| self.data = data[tr_split:val_split] |
| else: |
| self.data = data[val_split:] |
| |
| self.data = torch.from_numpy(self.data.copy()).long() |
| |
| def _download_and_extract(self): |
| print(f"Downloading {self.URL}...") |
| zip_path = os.path.join(self.data_dir, "enwik8.zip") |
| urllib.request.urlretrieve(self.URL, zip_path) |
| |
| print("Extracting...") |
| with zipfile.ZipFile(zip_path, 'r') as zip_ref: |
| zip_ref.extractall(self.data_dir) |
| |
| def __len__(self): |
| |
| return len(self.data) - self.seq_len - 1 |
|
|
| def __getitem__(self, idx): |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| chunk = self.data[idx : idx + self.seq_len] |
| return chunk, chunk |
|
|
| def get_enwik8_dataloader(data_dir: str, batch_size: int = 32, seq_len: int = 1024, split: str = 'train'): |
| dataset = Enwik8Dataset(data_dir, seq_len, split) |
| return DataLoader(dataset, batch_size=batch_size, shuffle=(split=='train'), num_workers=0) |
|
|