from pathlib import Path

import torch
from torch import nn

from src.utils import get_batch
|
|
|
|
@torch.no_grad()
def estimate_loss(model: nn.Module, eval_iters, block_size, batch_size, device):
    """Return the mean loss over `eval_iters` batches for the train and val splits.

    Puts the model into eval mode for the measurement and restores train
    mode afterwards; gradients are disabled by the decorator.
    """
    model.eval()
    results = {}
    for split in ("train", "val"):
        batch_losses = torch.zeros(eval_iters)
        for i in range(eval_iters):
            inputs, targets = get_batch(split, block_size, batch_size)
            inputs, targets = inputs.to(device), targets.to(device)
            # Only the loss is needed here; logits are discarded.
            _, loss = model(inputs, targets)
            batch_losses[i] = loss.item()
        results[split] = batch_losses.mean()
    model.train()
    return results
|
|
|
|
def train(
    model,
    optimizer,
    max_iters,
    eval_interval,
    eval_iters,
    block_size,
    batch_size,
    device,
):
    """Run the training loop for `max_iters` steps.

    Every `eval_interval` steps, estimates train/val loss and saves the model
    to ``checkpoints/model.pth`` whenever the validation loss improves on the
    best value seen so far.

    Bug fixed: the original never updated its ``val_loss`` baseline after the
    first evaluation, so every later eval was compared against the *first*
    val loss (any merely-better-than-initial model overwrote the checkpoint).
    We now track the best validation loss and update it on every save.
    """
    best_val_loss = None
    for step in range(max_iters):  # renamed from `iter` to avoid shadowing the builtin
        if step % eval_interval == 0:
            losses = estimate_loss(model, eval_iters, block_size, batch_size, device)
            print(
                f"Step {step}: train loss {losses['train']:.4f}, val loss {losses['val']:.4f}"
            )
            # Checkpoint only on a new best validation loss (including the
            # very first eval, so an early crash still leaves a checkpoint).
            if best_val_loss is None or losses["val"] < best_val_loss:
                best_val_loss = losses["val"]
                # torch.save fails if the target directory is missing.
                Path("checkpoints").mkdir(parents=True, exist_ok=True)
                torch.save(model, "checkpoints/model.pth")

        xb, yb = get_batch("train", block_size, batch_size)
        xb, yb = xb.to(device), yb.to(device)

        # Forward pass returns (logits, loss); logits are unused here.
        logits, loss = model(xb, yb)

        optimizer.zero_grad(set_to_none=True)
        loss.backward()
        optimizer.step()
|
|