namer / tests /test_models.py
Edwin Jose Palathinkal
Initial commit
2730fd2
Raw
History Blame Contribute Delete
2.89 kB
"""Tests for model classes."""
import pytest
import torch
from namer.models import NamerTransformer, PositionalEncoding
from namer.utils import VOCABULARY
class TestPositionalEncoding:
    """Checks for the PositionalEncoding module."""

    def test_shape(self) -> None:
        """The module's output has the same shape as its input."""
        batch, seq_len, width = 2, 10, 128
        encoder = PositionalEncoding(d_model=width)
        inputs = torch.randn(batch, seq_len, width)
        encoded = encoder(inputs)
        assert encoded.shape == (batch, seq_len, width)

    def test_adds_position(self) -> None:
        """An all-zero input must not come back unchanged."""
        encoder = PositionalEncoding(d_model=64)
        zeros = torch.zeros(1, 5, 64)
        encoded = encoder(zeros)
        # If the module added anything to the zero tensor, the result
        # can no longer be element-wise close to the input.
        assert not torch.allclose(encoded, zeros)
class TestNamerTransformer:
    """Checks for the NamerTransformer model."""

    @pytest.fixture
    def model(self) -> NamerTransformer:
        """Build a small NamerTransformer sized to VOCABULARY, dropout 0.0."""
        settings = {
            "vocab_size": len(VOCABULARY),
            "max_output_len": 20,
            "d_model": 64,
            "nhead": 4,
            "num_encoder_layers": 2,
            "dim_feedforward": 128,
            "dropout": 0.0,
        }
        return NamerTransformer(**settings)

    def test_forward_shape(self, model: NamerTransformer) -> None:
        """Logits have shape (batch, max_output_len, vocab_size)."""
        input_shape = (4, 10)  # (batch, sequence length)
        tokens = torch.randint(0, 10, input_shape)
        output = model(tokens)
        expected = (input_shape[0], model.max_output_len, model.vocab_size)
        assert output.shape == expected

    def test_forward_with_padding(self, model: NamerTransformer) -> None:
        """Inputs whose tail is filled with the value 10 still produce logits."""
        rows, cols = 2, 10
        # Start from a tensor made entirely of the padding value 10 ...
        tokens = torch.full((rows, cols), 10)
        # ... then overwrite the first five positions with digits 0-9.
        tokens[:, :5] = torch.randint(0, 10, (rows, 5))
        output = model(tokens)
        expected = (rows, model.max_output_len, model.vocab_size)
        assert output.shape == expected

    def test_forward_with_negative_padding(self, model: NamerTransformer) -> None:
        """Inputs whose tail is filled with -1 still produce logits."""
        rows, cols = 2, 10
        # Same layout as the padding test above, but with -1 as the filler.
        tokens = torch.full((rows, cols), -1)
        tokens[:, :5] = torch.randint(0, 10, (rows, 5))
        output = model(tokens)
        expected = (rows, model.max_output_len, model.vocab_size)
        assert output.shape == expected

    def test_output_is_logits(self, model: NamerTransformer) -> None:
        """At least one output value must fall outside the [0, 1] interval."""
        tokens = torch.randint(0, 10, (1, 5))
        output = model(tokens)
        # Probabilities (post-softmax) would all sit inside [0, 1]; raw
        # logits are expected to have values outside that range.
        inside_unit_interval = (output >= 0) & (output <= 1)
        assert not torch.all(inside_unit_interval)

    def test_gradient_flow(self, model: NamerTransformer) -> None:
        """Backpropagating a cross-entropy loss populates every parameter's grad."""
        tokens = torch.randint(0, 10, (2, 5))
        labels = torch.randint(0, len(VOCABULARY), (2, model.max_output_len))
        output = model(tokens)
        # Flatten (batch, position) into one axis so each position is
        # scored as an independent classification over the vocabulary.
        flat_logits = output.view(-1, model.vocab_size)
        flat_labels = labels.view(-1)
        loss = torch.nn.functional.cross_entropy(flat_logits, flat_labels)
        loss.backward()
        # Every parameter must have received a gradient tensor.
        assert all(weight.grad is not None for weight in model.parameters())