## What is this?
A deep learning system that uses a modified ResNet-50 architecture to detect intracranial hemorrhages (brain bleeds) in CT scans. The project tackles a critical medical challenge where rapid, accurate detection can significantly impact patient outcomes.
**Why This Matters:** Every minute counts in diagnosing brain bleeds. Delayed treatment can lead to severe neurological damage or death, yet CT scan interpretation traditionally requires significant time and expertise from radiologists. An automated detection system could:
- Accelerate diagnosis in time-critical situations
- Serve as a second opinion tool for radiologists
- Help prioritize urgent cases in hospital workflows
- Improve access to expert-level diagnosis in underserved areas
The system processes dual-window CT scans (brain and bone windows) simultaneously through a custom 6-channel input layer, attempting to learn hemorrhage patterns from a limited, imbalanced dataset.
---
```python
#%%
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
import torchvision.transforms as transforms
from torchvision.models import resnet50, ResNet50_Weights
from PIL import Image
import os
import pandas as pd
from sklearn.model_selection import KFold
from sklearn.model_selection import train_test_split
from tqdm import tqdm
import numpy as np
import matplotlib.pyplot as plt
from torchvision.utils import make_grid
from torch.utils.data import WeightedRandomSampler
import copy
```
```python
#%%
def calculate_metrics(outputs, labels):
predictions = (torch.sigmoid(outputs) > 0.5).float()
tp = ((predictions == 1) & (labels == 1)).sum().float()
fp = ((predictions == 1) & (labels == 0)).sum().float()
fn = ((predictions == 0) & (labels == 1)).sum().float()
precision = tp / (tp + fp + 1e-8) # Add small epsilon to avoid division by zero
recall = tp / (tp + fn + 1e-8)
f1 = 2 * (precision * recall) / (precision + recall + 1e-8)
return precision.item(), recall.item(), f1.item()
#%%
def train_classifier(model, train_loader, criterion, optimizer, device):
model.train()
total_loss = 0
correct = 0
total = 0
total_precision = 0
total_recall = 0
total_f1 = 0
num_batches = 0
for images, labels in tqdm(train_loader, desc='Training'):
images = images.to(device)
labels = labels.to(device).view(-1, 1)
optimizer.zero_grad()
outputs = model(images)
loss = criterion(outputs, labels)
loss.backward()
#gradient clipping
torch.nn.utils.clip_grad_norm_(model.parameters(), GRAD_CLIP)
optimizer.step()
total_loss += loss.item()
probabilities = torch.sigmoid(outputs)
predictions = (probabilities > 0.5).float()
# basic accuracy
correct += (predictions == labels).sum().item()
total += labels.size(0)
tp = ((predictions == 1) & (labels == 1)).sum().float()
fp = ((predictions == 1) & (labels == 0)).sum().float()
fn = ((predictions == 0) & (labels == 1)).sum().float()
precision = tp / (tp + fp + 1e-8)
recall = tp / (tp + fn + 1e-8)
f1 = 2 * (precision * recall) / (precision + recall + 1e-8)
total_precision += precision.item()
total_recall += recall.item()
total_f1 += f1.item()
num_batches += 1
accuracy = correct / total
avg_precision = total_precision / num_batches
avg_recall = total_recall / num_batches
avg_f1 = total_f1 / num_batches
return (
total_loss / len(train_loader),
accuracy,
avg_precision,
avg_recall,
avg_f1
)
def test_classifier(model, test_loader, criterion, device):
model.eval()
total_loss = 0
correct = 0
total = 0
total_precision = 0
total_recall = 0
total_f1 = 0
num_batches = 0
with torch.no_grad():
for images, labels in tqdm(test_loader, desc='Testing'):
images = images.to(device)
labels = labels.to(device).view(-1, 1)
outputs = model(images)
loss = criterion(outputs, labels)
total_loss += loss.item()
probabilities = torch.sigmoid(outputs)
predictions = (probabilities > 0.5).float()
correct += (predictions == labels).sum().item()
total += labels.size(0)
tp = ((predictions == 1) & (labels == 1)).sum().float()
fp = ((predictions == 1) & (labels == 0)).sum().float()
fn = ((predictions == 0) & (labels == 1)).sum().float()
precision = tp / (tp + fp + 1e-8)
recall = tp / (tp + fn + 1e-8)
f1 = 2 * (precision * recall) / (precision + recall + 1e-8)
total_precision += precision.item()
total_recall += recall.item()
total_f1 += f1.item()
num_batches += 1
accuracy = correct / total
avg_precision = total_precision / num_batches
avg_recall = total_recall / num_batches
avg_f1 = total_f1 / num_batches
return (
total_loss / len(test_loader),
accuracy,
avg_precision,
avg_recall,
avg_f1
)
```
```python
#%%
class BrainCTDataset(Dataset):
def __init__(
self, brain_images_dir, bone_images_dir, labels_df=None, transform=None,
phase='classification', mask_dir=None
):
self.phase = phase
if isinstance(brain_images_dir, list):
self.brain_images = brain_images_dir
self.bone_images = bone_images_dir
else:
self.brain_images = sorted([
os.path.join(brain_images_dir, f)
for f in os.listdir(brain_images_dir)
])
self.bone_images = sorted([
os.path.join(bone_images_dir, f)
for f in os.listdir(bone_images_dir)
])
if len(self.brain_images) != len(self.bone_images):
raise ValueError(
f"Number of brain images ({len(self.brain_images)}) does not match "
f"number of bone images ({len(self.bone_images)})"
)
self.transform = transform
self.labels_df = labels_df
self.mask_dir = mask_dir
# Phase 2
if phase == 'segmentation' and mask_dir:
valid_indices = []
for idx, (brain_path, bone_path) in enumerate(zip(self.brain_images, self.bone_images)):
parts = brain_path.split(os.sep)
patient_num = parts[-3]
slice_num = os.path.splitext(parts[-1])[0]
mask_path = os.path.join(mask_dir, f"patient{patient_num}_slice{slice_num}_mask.png")
if os.path.exists(mask_path):
valid_indices.append(idx)
self.brain_images = [self.brain_images[i] for i in valid_indices]
self.bone_images = [self.bone_images[i] for i in valid_indices]
def __len__(self):
return len(self.brain_images)
def __getitem__(self, idx):
try:
brain_path = self.brain_images[idx]
bone_path = self.bone_images[idx]
brain_img = Image.open(brain_path).convert('RGB')
bone_img = Image.open(bone_path).convert('RGB')
if self.transform:
brain_img = self.transform(brain_img)
bone_img = self.transform(bone_img)
image = torch.cat([brain_img, bone_img], dim=0)
parts = brain_path.split(os.sep)
patient_num = int(parts[-3])
slice_num = int(os.path.splitext(parts[-1])[0])
if self.labels_df is not None:
label_row = self.labels_df[
(self.labels_df['PatientNumber'] == patient_num) &
(self.labels_df['SliceNumber'] == slice_num)
]
if len(label_row) == 0:
label = 0
else:
hemorrhage_types = [
'Intraventricular', 'Intraparenchymal', 'Subarachnoid',
'Epidural', 'Subdural'
]
label = 1 if label_row[hemorrhage_types].values.sum() > 0 else 0
else:
label = 0
if self.phase == 'classification':
return image, torch.tensor(label, dtype=torch.float32)
if self.phase == 'segmentation' and label == 1:
mask_path = os.path.join(
os.path.dirname(brain_path),
f"{slice_num}_HGE_Seg.jpg"
)
if os.path.exists(mask_path):
mask = Image.open(mask_path).convert('L')
if self.transform:
mask = transforms.Resize((224, 224))(mask)
mask = transforms.ToTensor()(mask)
return image, mask, torch.tensor(label, dtype=torch.float32)
return image, torch.zeros((1, 224, 224)), torch.tensor(label, dtype=torch.float32)
except Exception as e:
print(f"Error loading image at index {idx}")
print(f"Brain image path: {brain_path}")
print(f"Bone image path: {bone_path}")
raise e
class EMA:
def __init__(self, model, decay=0.999):
self.model = copy.deepcopy(model)
self.decay = decay
def update(self, model):
with torch.no_grad():
for ema_p, p in zip(self.model.parameters(),
model.parameters()):
ema_p.data.mul_(self.decay).add_(
p.data, alpha=1 - self.decay
)
class HemorrhageClassifier(nn.Module):
def __init__(self, n_channels=6, dropout_rate=0.7):
super().__init__()
self.encoder = resnet50(weights=ResNet50_Weights.IMAGENET1K_V1)
self.encoder.conv1 = nn.Conv2d(
n_channels, 64, kernel_size=7, stride=2, padding=3, bias=False
)
self.classification_head = nn.Sequential(
nn.AdaptiveAvgPool2d(1),
nn.Flatten(),
nn.BatchNorm1d(2048),
nn.Linear(2048, 512),
nn.ReLU(inplace=True),
nn.BatchNorm1d(512),
nn.Dropout(dropout_rate),
nn.Linear(512, 256),
nn.ReLU(inplace=True),
nn.Dropout(dropout_rate),
nn.Linear(256, 1)
)
def forward(self, x):
x = self.encoder.conv1(x)
x = self.encoder.bn1(x)
x = self.encoder.relu(x)
x = self.encoder.maxpool(x)
x = self.encoder.layer1(x)
x = self.encoder.layer2(x)
x = self.encoder.layer3(x)
x = self.encoder.layer4(x)
return self.classification_head(x)
#%%
def create_patient_splits(patient_dirs, test_size=0.2, random_state=42):
np.random.seed(random_state)
n_patients = len(patient_dirs)
n_test = int(n_patients * test_size)
test_indices = np.random.choice(
n_patients, n_test, replace=False
)
test_patients = [patient_dirs[i] for i in test_indices]
train_patients = [p for i, p in enumerate(patient_dirs) if i not in test_indices]
return train_patients, test_patients
#%%
class EarlyStopping:
def __init__(self, patience=7, min_delta=1e-4):
self.patience = patience
self.min_delta = min_delta
self.counter = 0
self.best_loss = None
self.early_stop = False
def __call__(self, val_loss):
if self.best_loss is None:
self.best_loss = val_loss
elif val_loss > self.best_loss - self.min_delta:
self.counter += 1
if self.counter >= self.patience:
self.early_stop = True
else:
self.best_loss = val_loss
self.counter = 0
class FocalLoss(nn.Module):
def __init__(self, alpha=0.25, gamma=2):
super().__init__()
self.alpha = alpha
self.gamma = gamma
def forward(self, inputs, targets):
bce_loss = F.binary_cross_entropy_with_logits(
inputs, targets, reduction='none'
)
pt = torch.exp(-bce_loss)
focal_loss = self.alpha * (1-pt)**self.gamma * bce_loss
return focal_loss.mean()
def get_balanced_sampler(dataset):
labels = []
for _, label in dataset:
labels.append(label.item())
class_counts = np.bincount(labels)
weights = np.zeros_like(labels, dtype=np.float32)
for idx, label in enumerate(labels):
weights[idx] = 1.0 / class_counts[int(label)]
sampler = WeightedRandomSampler(
weights=weights,
num_samples=len(labels),
replacement=True
)
return sampler
```
```python
#%%
#==========================
# Training
#==========================
BATCH_SIZE = 16
EPOCHS = 30
LEARNING_RATE = 3e-4
WEIGHT_DECAY = 5e-4
GRAD_CLIP = 1.0
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
transform = transforms.Compose([
transforms.Resize((224, 224)),
transforms.RandomHorizontalFlip(),
transforms.RandomRotation(
10,
fill=0
),
transforms.RandomAffine(
degrees=0,
translate=(0.05, 0.05),
scale=(0.95, 1.05),
fill=0
),
transforms.ToTensor(),
transforms.Normalize(mean=[0.485, 0.456, 0.406],
std=[0.229, 0.224, 0.225])
])
dataset_dir = (
'/kaggle/input/computed-tomography-ct-images/computed-tomography-images-for-intracranial-hemorrhage-detection-and-segmentation-1.0.0'
)
patients_dir = os.path.join(dataset_dir, 'Patients_CT')
labels_file = os.path.join(dataset_dir, 'hemorrhage_diagnosis.csv')
labels_df = pd.read_csv(labels_file)
patient_dirs = sorted([
d for d in os.listdir(patients_dir)
if os.path.isdir(os.path.join(patients_dir, d))
])
train_patients, test_patients = create_patient_splits(patient_dirs)
print(f"Split dataset into {len(train_patients)} train and {len(test_patients)} test patients")
brain_images = []
bone_images = []
print("Loading training dataset...")
for patient in tqdm(train_patients):
patient_path = os.path.join(patients_dir, patient)
brain_dir = os.path.join(patient_path, 'brain')
bone_dir = os.path.join(patient_path, 'bone')
if os.path.exists(brain_dir) and os.path.exists(bone_dir):
brain_files = sorted(os.listdir(brain_dir))
bone_files = sorted(os.listdir(bone_dir))
for brain_file, bone_file in zip(brain_files, bone_files):
brain_path = os.path.join(brain_dir, brain_file)
bone_path = os.path.join(bone_dir, bone_file)
if (os.path.exists(brain_path) and os.path.exists(bone_path) and
os.path.splitext(brain_file)[0] == os.path.splitext(bone_file)[0]):
brain_images.append(brain_path)
bone_images.append(bone_path)
print(f"Found {len(brain_images)} paired training images")
if len(brain_images) == 0:
raise ValueError("No valid image pairs found! Check the dataset structure.")
print("\nPhase 1: Classification Model Training")
classification_dataset = BrainCTDataset(
brain_images_dir=brain_images,
bone_images_dir=bone_images,
labels_df=labels_df,
transform=transform,
phase='classification'
)
DEBUG = True
if DEBUG:
subset_size = len(classification_dataset) // 1
subset_size = len(classification_dataset) // 1
all_labels = []
for idx in range(len(classification_dataset)):
_, label = classification_dataset[idx]
all_labels.append(int(label.item()))
subset_indices = np.arange(len(classification_dataset))
train_indices, val_indices = train_test_split(
subset_indices[:subset_size],
test_size=0.2,
stratify=[all_labels[i] for i in range(subset_size)],
random_state=42
)
train_subset = torch.utils.data.Subset(classification_dataset, train_indices)
val_subset = torch.utils.data.Subset(classification_dataset, val_indices)
train_labels = [classification_dataset[i][1].item() for i in train_indices]
val_labels = [classification_dataset[i][1].item() for i in val_indices]
print("\nClass distribution:")
print(f"Training - Negative: {train_labels.count(0)}, Positive: {train_labels.count(1)}")
print(f"Validation - Negative: {val_labels.count(0)}, Positive: {val_labels.count(1)}")
train_sampler = get_balanced_sampler(train_subset)
train_loader = DataLoader(
train_subset,
batch_size=BATCH_SIZE,
sampler=train_sampler
)
val_loader = DataLoader(
val_subset,
batch_size=BATCH_SIZE,
shuffle=False
)
classifier = HemorrhageClassifier().to(DEVICE)
optimizer = optim.AdamW(
classifier.parameters(),
lr=LEARNING_RATE,
weight_decay=WEIGHT_DECAY
)
ema = EMA(classifier)
criterion = FocalLoss(alpha=0.75)
early_stopping = EarlyStopping(patience=3)
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
optimizer,
mode='max',
factor=0.5,
patience=3,
verbose=True,
min_lr=1e-6
)
best_val_loss = float('inf')
best_f1 = 0.0
for epoch in range(EPOCHS):
train_loss, train_acc, train_prec, train_rec, train_f1 = train_classifier(
classifier, train_loader, criterion, optimizer, DEVICE
)
val_loss, val_acc, val_prec, val_rec, val_f1 = test_classifier(
classifier, val_loader, criterion, DEVICE
)
print(f'Epoch {epoch+1}/{EPOCHS}:')
print(f' Training - Loss: {train_loss:.4f}, Acc: {train_acc:.2%}, '
f'Prec: {train_prec:.2%}, Rec: {train_rec:.2%}, F1: {train_f1:.2%}')
print(f' Validation - Loss: {val_loss:.4f}, Acc: {val_acc:.2%}, '
f'Prec: {val_prec:.2%}, Rec: {val_rec:.2%}, F1: {val_f1:.2%}')
print(f' Current LR: {optimizer.param_groups[0]["lr"]:.2e}')
scheduler.step(val_f1)
if early_stopping(val_loss):
print("Early stopping triggered")
break
if val_f1 > best_f1:
best_f1 = val_f1
ema.update(classifier)
torch.save(classifier.state_dict(), 'best_classifier.pth')
else:
pass
```
Split dataset into 66 train and 16 test patients
Loading training dataset...
100%|██████████| 66/66 [00:01<00:00, 35.02it/s]
Found 1393 paired training images
Phase 1: Classification Model Training
Class distribution:
Training - Negative: 1093, Positive: 21
Validation - Negative: 274, Positive: 5
/opt/conda/lib/python3.10/site-packages/torch/optim/lr_scheduler.py:60: UserWarning: The verbose parameter is deprecated. Please use get_last_lr() to access the learning rate.
warnings.warn(
Training: 100%|██████████| 70/70 [00:26<00:00, 2.61it/s]
Testing: 100%|██████████| 18/18 [00:05<00:00, 3.54it/s]
Epoch 1/30:
Training - Loss: 0.1506, Acc: 68.31%, Prec: 67.64%, Rec: 70.28%, F1: 67.40%
Validation - Loss: 0.2676, Acc: 63.80%, Prec: 3.83%, Rec: 22.22%, F1: 6.47%
Current LR: 3.00e-04
Training: 100%|██████████| 70/70 [00:27<00:00, 2.59it/s]
Testing: 100%|██████████| 18/18 [00:05<00:00, 3.53it/s]
Epoch 2/30:
Training - Loss: 0.1410, Acc: 71.54%, Prec: 71.15%, Rec: 71.67%, F1: 69.76%
Validation - Loss: 0.0849, Acc: 80.29%, Prec: 3.89%, Rec: 13.89%, F1: 5.93%
Current LR: 3.00e-04
Training: 100%|██████████| 70/70 [00:27<00:00, 2.51it/s]
Testing: 100%|██████████| 18/18 [00:05<00:00, 3.43it/s]
Epoch 3/30:
Training - Loss: 0.1107, Acc: 77.65%, Prec: 76.10%, Rec: 81.42%, F1: 77.46%
Validation - Loss: 0.0526, Acc: 86.02%, Prec: 4.63%, Rec: 11.11%, F1: 6.48%
Current LR: 3.00e-04
Training: 100%|██████████| 70/70 [00:27<00:00, 2.56it/s]
Testing: 100%|██████████| 18/18 [00:05<00:00, 3.50it/s]
Epoch 4/30:
Training - Loss: 0.0885, Acc: 80.97%, Prec: 80.15%, Rec: 83.37%, F1: 80.54%
Validation - Loss: 0.0286, Acc: 96.77%, Prec: 0.00%, Rec: 0.00%, F1: 0.00%
Current LR: 3.00e-04
Training: 100%|██████████| 70/70 [00:27<00:00, 2.52it/s]
Testing: 100%|██████████| 18/18 [00:04<00:00, 3.89it/s]
Epoch 5/30:
Training - Loss: 0.0982, Acc: 83.66%, Prec: 82.02%, Rec: 88.91%, F1: 84.10%
Validation - Loss: 0.0575, Acc: 86.02%, Prec: 6.02%, Rec: 13.89%, F1: 8.15%
Current LR: 3.00e-04
Training: 100%|██████████| 70/70 [00:27<00:00, 2.55it/s]
Testing: 100%|██████████| 18/18 [00:04<00:00, 3.91it/s]
Epoch 6/30:
Training - Loss: 0.0684, Acc: 85.37%, Prec: 81.81%, Rec: 91.34%, F1: 85.39%
Validation - Loss: 0.0806, Acc: 78.49%, Prec: 4.40%, Rec: 13.89%, F1: 6.23%
Current LR: 3.00e-04
Training: 100%|██████████| 70/70 [00:27<00:00, 2.54it/s]
Testing: 100%|██████████| 18/18 [00:04<00:00, 3.79it/s]
Epoch 7/30:
Training - Loss: 0.0699, Acc: 87.25%, Prec: 86.36%, Rec: 90.10%, F1: 86.86%
Validation - Loss: 0.0946, Acc: 78.49%, Prec: 5.89%, Rec: 16.67%, F1: 8.06%
Current LR: 3.00e-04
Training: 100%|██████████| 70/70 [00:27<00:00, 2.56it/s]
Testing: 100%|██████████| 18/18 [00:04<00:00, 3.97it/s]
Epoch 8/30:
Training - Loss: 0.0692, Acc: 87.16%, Prec: 84.03%, Rec: 92.05%, F1: 86.97%
Validation - Loss: 0.0634, Acc: 87.81%, Prec: 6.02%, Rec: 13.89%, F1: 8.15%
Current LR: 3.00e-04
Training: 100%|██████████| 70/70 [00:28<00:00, 2.48it/s]
Testing: 100%|██████████| 18/18 [00:04<00:00, 3.94it/s]
Epoch 9/30:
Training - Loss: 0.0476, Acc: 90.66%, Prec: 90.13%, Rec: 92.87%, F1: 90.89%
Validation - Loss: 0.0511, Acc: 88.89%, Prec: 7.41%, Rec: 16.67%, F1: 10.00%
Current LR: 3.00e-04
Training: 100%|██████████| 70/70 [00:27<00:00, 2.54it/s]
Testing: 100%|██████████| 18/18 [00:04<00:00, 3.66it/s]
Epoch 10/30:
Training - Loss: 0.0557, Acc: 89.14%, Prec: 87.53%, Rec: 93.51%, F1: 89.38%
Validation - Loss: 0.0406, Acc: 91.40%, Prec: 9.26%, Rec: 13.89%, F1: 9.26%
Current LR: 3.00e-04
Training: 100%|██████████| 70/70 [00:27<00:00, 2.53it/s]
Testing: 100%|██████████| 18/18 [00:04<00:00, 3.98it/s]
Epoch 11/30:
Training - Loss: 0.0414, Acc: 92.19%, Prec: 89.75%, Rec: 95.47%, F1: 91.88%
Validation - Loss: 0.0459, Acc: 92.11%, Prec: 10.19%, Rec: 13.89%, F1: 10.19%
Current LR: 3.00e-04
Training: 100%|██████████| 70/70 [00:27<00:00, 2.55it/s]
Testing: 100%|██████████| 18/18 [00:04<00:00, 3.97it/s]
Epoch 12/30:
Training - Loss: 0.0426, Acc: 92.64%, Prec: 90.41%, Rec: 95.60%, F1: 92.22%
Validation - Loss: 0.0603, Acc: 89.25%, Prec: 10.19%, Rec: 13.89%, F1: 10.19%
Current LR: 3.00e-04
Training: 100%|██████████| 70/70 [00:27<00:00, 2.55it/s]
Testing: 100%|██████████| 18/18 [00:04<00:00, 3.94it/s]
Epoch 13/30:
Training - Loss: 0.0579, Acc: 91.20%, Prec: 89.63%, Rec: 94.82%, F1: 91.20%
Validation - Loss: 0.0366, Acc: 92.83%, Prec: 9.26%, Rec: 16.67%, F1: 11.11%
Current LR: 3.00e-04
Training: 100%|██████████| 70/70 [00:27<00:00, 2.57it/s]
Testing: 100%|██████████| 18/18 [00:04<00:00, 3.91it/s]
Epoch 14/30:
Training - Loss: 0.0420, Acc: 92.82%, Prec: 91.29%, Rec: 95.15%, F1: 92.11%
Validation - Loss: 0.0512, Acc: 91.04%, Prec: 1.85%, Rec: 5.56%, F1: 2.78%
Current LR: 3.00e-04
Training: 100%|██████████| 70/70 [00:27<00:00, 2.55it/s]
Testing: 100%|██████████| 18/18 [00:04<00:00, 4.00it/s]
Epoch 15/30:
Training - Loss: 0.0375, Acc: 94.08%, Prec: 92.37%, Rec: 96.27%, F1: 93.78%
Validation - Loss: 0.0392, Acc: 97.13%, Prec: 5.56%, Rec: 5.56%, F1: 5.56%
Current LR: 3.00e-04
Training: 100%|██████████| 70/70 [00:27<00:00, 2.54it/s]
Testing: 100%|██████████| 18/18 [00:04<00:00, 3.96it/s]
Epoch 16/30:
Training - Loss: 0.0333, Acc: 94.97%, Prec: 94.88%, Rec: 95.96%, F1: 94.96%
Validation - Loss: 0.0303, Acc: 94.62%, Prec: 11.11%, Rec: 8.33%, F1: 9.26%
Current LR: 3.00e-04
Training: 100%|██████████| 70/70 [00:27<00:00, 2.53it/s]
Testing: 100%|██████████| 18/18 [00:04<00:00, 3.95it/s]
Epoch 17/30:
Training - Loss: 0.0310, Acc: 96.14%, Prec: 94.80%, Rec: 97.54%, F1: 95.62%
Validation - Loss: 0.0546, Acc: 94.62%, Prec: 0.00%, Rec: 0.00%, F1: 0.00%
Current LR: 3.00e-04
Training: 100%|██████████| 70/70 [00:27<00:00, 2.51it/s]
Testing: 100%|██████████| 18/18 [00:04<00:00, 3.82it/s]
Epoch 18/30:
Training - Loss: 0.0336, Acc: 94.88%, Prec: 95.94%, Rec: 95.54%, F1: 95.12%
Validation - Loss: 0.0373, Acc: 96.77%, Prec: 5.56%, Rec: 2.78%, F1: 3.70%
Current LR: 1.50e-04
Training: 100%|██████████| 70/70 [00:27<00:00, 2.55it/s]
Testing: 100%|██████████| 18/18 [00:04<00:00, 3.97it/s]
Epoch 19/30:
Training - Loss: 0.0195, Acc: 96.68%, Prec: 95.38%, Rec: 97.62%, F1: 96.09%
Validation - Loss: 0.0336, Acc: 96.06%, Prec: 9.72%, Rec: 13.89%, F1: 10.56%
Current LR: 1.50e-04
Training: 100%|██████████| 70/70 [00:27<00:00, 2.53it/s]
Testing: 100%|██████████| 18/18 [00:04<00:00, 3.94it/s]
Epoch 20/30:
Training - Loss: 0.0215, Acc: 96.68%, Prec: 95.82%, Rec: 97.97%, F1: 96.55%
Validation - Loss: 0.0337, Acc: 94.62%, Prec: 6.94%, Rec: 8.33%, F1: 5.93%
Current LR: 1.50e-04
Training: 100%|██████████| 70/70 [00:27<00:00, 2.57it/s]
Testing: 100%|██████████| 18/18 [00:04<00:00, 4.03it/s]
Epoch 21/30:
Training - Loss: 0.0192, Acc: 97.13%, Prec: 96.71%, Rec: 98.01%, F1: 97.09%
Validation - Loss: 0.0373, Acc: 95.34%, Prec: 2.78%, Rec: 5.56%, F1: 3.70%
Current LR: 1.50e-04
Training: 100%|██████████| 70/70 [00:27<00:00, 2.56it/s]
Testing: 100%|██████████| 18/18 [00:04<00:00, 3.96it/s]
Epoch 22/30:
Training - Loss: 0.0070, Acc: 98.65%, Prec: 98.03%, Rec: 99.13%, F1: 98.39%
Validation - Loss: 0.0314, Acc: 97.85%, Prec: 11.11%, Rec: 8.33%, F1: 9.26%
Current LR: 7.50e-05
Training: 100%|██████████| 70/70 [00:27<00:00, 2.57it/s]
Testing: 100%|██████████| 18/18 [00:04<00:00, 3.89it/s]
Epoch 23/30:
Training - Loss: 0.0148, Acc: 98.03%, Prec: 95.84%, Rec: 99.58%, F1: 97.43%
Validation - Loss: 0.0296, Acc: 98.21%, Prec: 5.56%, Rec: 5.56%, F1: 5.56%
Current LR: 7.50e-05
Training: 100%|██████████| 70/70 [00:28<00:00, 2.46it/s]
Testing: 100%|██████████| 18/18 [00:04<00:00, 3.89it/s]
Epoch 24/30:
Training - Loss: 0.0118, Acc: 98.65%, Prec: 98.25%, Rec: 98.82%, F1: 98.38%
Validation - Loss: 0.0386, Acc: 97.13%, Prec: 11.11%, Rec: 8.33%, F1: 9.26%
Current LR: 7.50e-05
Training: 100%|██████████| 70/70 [00:27<00:00, 2.56it/s]
Testing: 100%|██████████| 18/18 [00:04<00:00, 3.64it/s]
Epoch 25/30:
Training - Loss: 0.0127, Acc: 97.94%, Prec: 96.27%, Rec: 99.05%, F1: 97.27%
Validation - Loss: 0.0254, Acc: 98.92%, Prec: 11.11%, Rec: 8.33%, F1: 9.26%
Current LR: 7.50e-05
Training: 100%|██████████| 70/70 [00:27<00:00, 2.54it/s]
Testing: 100%|██████████| 18/18 [00:04<00:00, 3.95it/s]
Epoch 26/30:
Training - Loss: 0.0136, Acc: 98.56%, Prec: 98.57%, Rec: 98.86%, F1: 98.62%
Validation - Loss: 0.0235, Acc: 97.13%, Prec: 11.11%, Rec: 8.33%, F1: 9.26%
Current LR: 3.75e-05
Training: 100%|██████████| 70/70 [00:27<00:00, 2.52it/s]
Testing: 100%|██████████| 18/18 [00:05<00:00, 3.54it/s]
Epoch 27/30:
Training - Loss: 0.0085, Acc: 98.83%, Prec: 98.48%, Rec: 99.07%, F1: 98.60%
Validation - Loss: 0.0343, Acc: 97.85%, Prec: 5.56%, Rec: 5.56%, F1: 5.56%
Current LR: 3.75e-05
Training: 100%|██████████| 70/70 [00:27<00:00, 2.57it/s]
Testing: 100%|██████████| 18/18 [00:04<00:00, 3.79it/s]
Epoch 28/30:
Training - Loss: 0.0105, Acc: 97.94%, Prec: 98.14%, Rec: 98.18%, F1: 97.99%
Validation - Loss: 0.0421, Acc: 97.85%, Prec: 5.56%, Rec: 5.56%, F1: 5.56%
Current LR: 3.75e-05
Training: 100%|██████████| 70/70 [00:27<00:00, 2.56it/s]
Testing: 100%|██████████| 18/18 [00:04<00:00, 3.87it/s]
Epoch 29/30:
Training - Loss: 0.0054, Acc: 99.19%, Prec: 99.15%, Rec: 99.39%, F1: 99.20%
Validation - Loss: 0.0301, Acc: 97.49%, Prec: 11.11%, Rec: 8.33%, F1: 9.26%
Current LR: 3.75e-05
Training: 100%|██████████| 70/70 [00:27<00:00, 2.57it/s]
Testing: 100%|██████████| 18/18 [00:04<00:00, 3.92it/s]
Epoch 30/30:
Training - Loss: 0.0056, Acc: 99.37%, Prec: 99.50%, Rec: 99.50%, F1: 99.47%
Validation - Loss: 0.0343, Acc: 97.85%, Prec: 5.56%, Rec: 2.78%, F1: 3.70%
Current LR: 1.87e-05
```python
def visualize_predictions(classifier_path, sample_images, device):
classifier = HemorrhageClassifier().to(device)
classifier.load_state_dict(torch.load(classifier_path, weights_only=True))
classifier.eval()
transform = transforms.Compose([
transforms.Resize((224, 224)),
transforms.ToTensor(),
transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])
with torch.no_grad():
for sample in sample_images:
brain_img = Image.open(sample['brain_path']).convert('RGB')
bone_img = Image.open(sample['bone_path']).convert('RGB')
brain_tensor = transform(brain_img)
bone_tensor = transform(bone_img)
combined = torch.cat([brain_tensor, bone_tensor], dim=0).unsqueeze(0).to(device)
clf_pred = classifier(combined)
clf_prob = torch.sigmoid(clf_pred).cpu().numpy()[0, 0]
plt.figure(figsize=(15, 5))
plt.subplot(1, 3, 1)
plt.imshow(brain_img)
plt.title('Brain Window')
plt.axis('off')
plt.subplot(1, 3, 2)
plt.imshow(bone_img)
plt.title('Bone Window')
plt.axis('off')
plt.subplot(1, 3, 3)
plt.imshow(brain_img)
plt.title(f'Prediction (prob: {clf_prob:.3f})\nGround Truth: {"Positive" if sample["has_hemorrhage"] else "Negative"}')
if clf_prob > 0.5:
plt.text(10, 30, 'HEMORRHAGE', color='red',
bbox=dict(facecolor='white', alpha=0.8))
plt.axis('off')
plt.suptitle(f'Patient {sample["patient"]}, Slice {sample["slice"]}')
plt.tight_layout()
plt.show()
```
```python
print("\nVisualizing predictions on unseen test images...")
sample_images = []
hemorrhage_count = 0
non_hemorrhage_count = 0
target_samples_per_class = 5
for patient_dir in sorted(test_patients):
if hemorrhage_count >= target_samples_per_class and non_hemorrhage_count >= target_samples_per_class:
break
brain_dir = os.path.join(patients_dir, patient_dir, 'brain')
bone_dir = os.path.join(patients_dir, patient_dir, 'bone')
if os.path.exists(brain_dir) and os.path.exists(bone_dir):
brain_files = sorted(os.listdir(brain_dir))
bone_files = sorted(os.listdir(bone_dir))
for brain_file in brain_files:
if brain_file.endswith('.jpg') and not brain_file.endswith('_HGE_Seg.jpg'):
brain_path = os.path.join(brain_dir, brain_file)
bone_path = os.path.join(bone_dir, brain_file)
if os.path.exists(bone_path):
patient_num = int(patient_dir)
slice_num = int(os.path.splitext(brain_file)[0])
label_row = labels_df[
(labels_df['PatientNumber'] == patient_num) &
(labels_df['SliceNumber'] == slice_num)
]
has_hemorrhage = len(label_row) > 0 and any(
label_row[['Intraventricular', 'Intraparenchymal',
'Subarachnoid', 'Epidural', 'Subdural']].values[0]
)
if has_hemorrhage and hemorrhage_count < target_samples_per_class:
sample_images.append({
'brain_path': brain_path,
'bone_path': bone_path,
'patient': patient_num,
'slice': slice_num,
'has_hemorrhage': has_hemorrhage
})
hemorrhage_count += 1
break
elif not has_hemorrhage and non_hemorrhage_count < target_samples_per_class:
sample_images.append({
'brain_path': brain_path,
'bone_path': bone_path,
'patient': patient_num,
'slice': slice_num,
'has_hemorrhage': has_hemorrhage
})
non_hemorrhage_count += 1
break
print(f"Collected {hemorrhage_count} hemorrhage and {non_hemorrhage_count} non-hemorrhage samples")
visualize_predictions(
'best_classifier.pth',
sample_images,
DEVICE
)
```
Visualizing predictions on unseen test images...
Collected 5 hemorrhage and 5 non-hemorrhage samples










```python
```