| |
"""
example_usage.py

Demonstrates how to use the BYOL Mammogram model for feature extraction
and classification tasks.
"""
|
|
| import torch |
| import torch.nn as nn |
| from torchvision import models, transforms |
| from PIL import Image |
| import numpy as np |
| from pathlib import Path |
|
|
| |
| from train_byol_mammo import MammogramBYOL |
|
|
|
|
def _format_loss(loss):
    """Render a checkpoint's loss entry for display.

    Returns "Unknown" when the entry is absent (``None``), a 4-decimal string
    when the value converts to ``float``, and ``str(loss)`` otherwise.
    """
    if loss is None:
        return "Unknown"
    try:
        return f"{float(loss):.4f}"
    except (TypeError, ValueError):
        # Never let a purely informational print abort model loading.
        return str(loss)


def load_byol_model(checkpoint_path: str, device: torch.device) -> nn.Module:
    """Load the pre-trained BYOL model for feature extraction.

    Args:
        checkpoint_path: Path of the checkpoint file handed to ``torch.load``.
            The loaded mapping must contain ``'model_state_dict'``;
            ``'epoch'`` and ``'loss'`` are optional and only reported.
        device: Device the model is moved to and the checkpoint is mapped onto.

    Returns:
        The ``MammogramBYOL`` instance with the checkpoint weights loaded,
        switched to eval mode.

    Raises:
        KeyError: If the loaded checkpoint has no ``'model_state_dict'`` entry.
    """
    print(f"📥 Loading BYOL model from: {checkpoint_path}")

    # Rebuild the architecture before loading weights: a ResNet50 without
    # pretrained weights (they come from the checkpoint), with its last child
    # module dropped so the backbone ends before the classification head.
    resnet = models.resnet50(weights=None)
    backbone = nn.Sequential(*list(resnet.children())[:-1])

    model = MammogramBYOL(
        backbone=backbone,
        input_dim=2048,
        hidden_dim=4096,
        proj_dim=256,
    ).to(device)

    # NOTE(security): torch.load unpickles the file unless weights-only
    # loading is in effect, so only load checkpoints from a trusted source.
    checkpoint = torch.load(checkpoint_path, map_location=device)
    model.load_state_dict(checkpoint["model_state_dict"])
    model.eval()

    print("✅ Model loaded successfully!")
    print(f" Epoch: {checkpoint.get('epoch', 'Unknown')}")
    # Formatting the 'Unknown' fallback with ':.4f' raises ValueError, so the
    # loss is rendered by a helper that tolerates a missing or odd value.
    print(f" Final loss: {_format_loss(checkpoint.get('loss'))}")

    return model
|
|
|
|
def create_inference_transform(tile_size: int = 512):
    """Build the preprocessing pipeline applied to an image before inference.

    The pipeline resizes the image to a ``tile_size`` x ``tile_size`` square
    (with antialiasing), converts it to a tensor, and normalizes three
    channels with mean 0.5 and standard deviation 0.5 each.

    Args:
        tile_size: Edge length, in pixels, of the resized square image.

    Returns:
        A ``transforms.Compose`` chaining the three steps in that order.
    """
    target_size = (tile_size, tile_size)
    channel_mean = [0.5, 0.5, 0.5]
    channel_std = [0.5, 0.5, 0.5]

    steps = [
        transforms.Resize(target_size, antialias=True),
        transforms.ToTensor(),
        transforms.Normalize(mean=channel_mean, std=channel_std),
    ]
    return transforms.Compose(steps)
|
|
|
|
def extract_features(model, image_tensor, device):
    """Run the model's feature extractor on a batch and return a NumPy array.

    The input is moved to ``device``, passed through ``model.get_features``
    with gradient tracking disabled, and the result is copied back to the CPU.
    The demo in this file reports the output as 2048-dimensional vectors.

    Args:
        model: Object exposing ``get_features(tensor)``.
        image_tensor: Input tensor; moved to ``device`` before the forward pass.
        device: Device on which the forward pass runs.

    Returns:
        The output of ``model.get_features`` as a ``numpy.ndarray``.
    """
    # No gradients are needed for inference; skipping them also lets the
    # result be converted with .numpy() directly.
    with torch.no_grad():
        on_device = image_tensor.to(device)
        embedding = model.get_features(on_device)
    return embedding.cpu().numpy()
|
|
|
|
def _random_rgb_tile(size: int = 512):
    """Return a ``size`` x ``size`` random-noise PIL image converted to RGB."""
    # randint's upper bound is exclusive, so 256 is needed to cover the full
    # uint8 range (0..255).
    pixels = np.random.randint(0, 256, (size, size), dtype=np.uint8)
    return Image.fromarray(pixels).convert("RGB")


def main(checkpoint_path: str = "mammogram_byol_best.pth"):
    """Demonstrate model usage.

    Loads the BYOL checkpoint, then runs three examples on random dummy
    tiles: single-image feature extraction, batch feature extraction, and
    cosine similarity between two feature vectors. Results are printed.

    Args:
        checkpoint_path: Checkpoint file passed to ``load_byol_model``.
            Defaults to the file name the demo has always used.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"🖥️ Using device: {device}")

    model = load_byol_model(checkpoint_path, device)
    transform = create_inference_transform(tile_size=512)
    rule = "-" * 40

    # NOTE(review): the leading "π" in the section banners below is what is
    # left of an emoji whose bytes were lost to a bad decode. The original
    # code point cannot be recovered from this file, so the text is kept
    # verbatim; restore it from the upstream copy if one is available.

    # --- Example 1: one image -> one feature vector -----------------------
    print("\nπ Example 1: Feature Extraction")
    print(rule)

    # unsqueeze(0) adds the batch dimension expected for a single image.
    image_tensor = transform(_random_rgb_tile()).unsqueeze(0)
    features = extract_features(model, image_tensor, device)

    print(f"✅ Input shape: {image_tensor.shape}")
    print(f"✅ Feature shape: {features.shape}")
    print(f"✅ Feature vector (first 10 values): {features[0][:10]}")

    # --- Example 2: a batch of images -------------------------------------
    print("\nπ Example 2: Batch Feature Extraction")
    print(rule)

    batch_size = 4
    dummy_batch = torch.stack(
        [transform(_random_rgb_tile()) for _ in range(batch_size)]
    )
    batch_features = extract_features(model, dummy_batch, device)

    print(f"✅ Batch input shape: {dummy_batch.shape}")
    print(f"✅ Batch features shape: {batch_features.shape}")
    print(f"✅ Features per image: {batch_features.shape[1]} dimensions")

    # --- Example 3: similarity between two feature vectors ----------------
    print("\nπ Example 3: Feature Similarity")
    print(rule)

    # Imported here so the first two examples run without scikit-learn.
    from sklearn.metrics.pairwise import cosine_similarity

    # Slices (not plain indexing) keep each operand 2-D, as the pairwise
    # function expects; [0][0] then picks the single resulting entry.
    similarity = cosine_similarity(
        batch_features[0:1],
        batch_features[1:2],
    )[0][0]

    print(f"✅ Cosine similarity between image 1 and 2: {similarity:.4f}")

    print("\n🎯 Next Steps:")
    print("- Use these 2048D features for downstream classification")
    print("- Train a classifier using train_classification.py")
    print("- Fine-tune the entire model for specific tasks")
    print("- Use for similarity search or clustering")

    print("\nπ Model Summary:")
    print("- Architecture: ResNet50 + BYOL")
    print("- Input: 512x512 RGB mammogram tiles")
    print("- Output: 2048-dimensional feature vectors")
    print("- Training: Self-supervised on breast tissue tiles")
    print("- Use case: Medical image analysis and classification")
|
|
|
|
# Run the demo only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()