from copy import deepcopy
from typing import Dict, Tuple, Union

import numpy as np
import torch
from torch import nn
from torch.amp import GradScaler, autocast
from torch.optim import Optimizer
from torch.utils.data import DataLoader
from tqdm import tqdm

from utils import barrier, reduce_mean, update_loss_info
from evaluate import evaluate


| def train( |
| model: nn.Module, |
| data_loader: DataLoader, |
| loss_fn: nn.Module, |
| optimizer: Optimizer, |
| grad_scaler: Union[GradScaler, None], |
| device: torch.device = torch.device("cuda"), |
| rank: int = 0, |
| nprocs: int = 1, |
| **kwargs, |
| ) -> Tuple[nn.Module, Optimizer, GradScaler, Dict[str, float]]: |
| info = None |
| data_iter = tqdm(data_loader) if rank == 0 else data_loader |
| ddp = nprocs > 1 |
|
|
| if "eval_data_loader" in kwargs: |
| assert "eval_freq" in kwargs and 0 < kwargs["eval_freq"] < 1, f"eval_freq should be a float between 0 and 1, but got {kwargs['eval_freq']}" |
| assert "sliding_window" in kwargs, "sliding_window should be provided in kwargs" |
| assert "max_input_size" in kwargs, "max_input_size should be provided in kwargs" |
| assert "window_size" in kwargs, "window_size should be provided in kwargs" |
| assert "stride" in kwargs, "stride should be provided in kwargs" |
| assert "max_num_windows" in kwargs, "max_num_windows should be provided in kwargs" |
|
|
| eval_within_epoch = True |
| eval_data_loader = kwargs["eval_data_loader"] |
| eval_freq = int(kwargs["eval_freq"] * len(data_loader)) |
| sliding_window = kwargs["sliding_window"] |
| max_input_size = kwargs["max_input_size"] |
| window_size = kwargs["window_size"] |
| stride = kwargs["stride"] |
| max_num_windows = kwargs["max_num_windows"] |
|
|
| best_scores = {} |
| best_weights = {} |
|
|
| else: |
| eval_within_epoch = False |
| best_scores = None |
| best_weights = None |
| |
| for batch_idx, (image, gt_points, gt_den_map) in enumerate(data_iter): |
| image = image.to(device) |
| gt_points = [p.to(device) for p in gt_points] |
| gt_den_map = gt_den_map.to(device) |
| model.train() |
| with torch.set_grad_enabled(True): |
| with autocast(device_type="cuda", enabled=grad_scaler is not None and grad_scaler.is_enabled()): |
| if (model.module.zero_inflated if ddp else model.zero_inflated): |
| pred_logit_pi_map, pred_logit_map, pred_lambda_map, pred_den_map = model(image) |
| total_loss, total_loss_info = loss_fn( |
| pred_logit_pi_map=pred_logit_pi_map, |
| pred_logit_map=pred_logit_map, |
| pred_lambda_map=pred_lambda_map, |
| pred_den_map=pred_den_map, |
| gt_den_map=gt_den_map, |
| gt_points=gt_points, |
| ) |
| else: |
| pred_logit_map, pred_den_map = model(image) |
| total_loss, total_loss_info = loss_fn( |
| pred_logit_map=pred_logit_map, |
| pred_den_map=pred_den_map, |
| gt_den_map=gt_den_map, |
| gt_points=gt_points, |
| ) |
|
|
| optimizer.zero_grad() |
| if grad_scaler is not None: |
| grad_scaler.scale(total_loss).backward() |
| grad_scaler.step(optimizer) |
| grad_scaler.update() |
| else: |
| total_loss.backward() |
| optimizer.step() |
|
|
| total_loss_info = {k: reduce_mean(v.detach(), nprocs).item() if ddp else v.detach().item() for k, v in total_loss_info.items()} |
| info = update_loss_info(info, total_loss_info) |
| barrier(ddp) |
| |
| if eval_within_epoch and ((batch_idx + 1) % eval_freq == 0 or batch_idx == len(data_loader) - 1): |
| batch_scores = evaluate( |
| model=model, |
| data_loader=eval_data_loader, |
| sliding_window=sliding_window, |
| max_input_size=max_input_size, |
| window_size=window_size, |
| stride=stride, |
| max_num_windows=max_num_windows, |
| device=device, |
| amp=grad_scaler is not None and grad_scaler.is_enabled(), |
| local_rank=rank, |
| nprocs=nprocs, |
| progress_bar=False, |
| ) |
| for k, v in batch_scores.items(): |
| if k not in best_scores: |
| best_scores[k] = v |
| best_weights[k] = deepcopy(model.module.state_dict() if ddp else model.state_dict()) |
| elif v < best_scores[k]: |
| best_scores[k] = v |
| best_weights[k] = deepcopy(model.module.state_dict() if ddp else model.state_dict()) |
|
|
| barrier(ddp) |
|
|
| torch.cuda.empty_cache() |
| return model, optimizer, grad_scaler, {k: np.mean(v) for k, v in info.items()}, best_scores, best_weights |
|
|