TrainerFlow

A trainerflow encapsulates a predefined pipeline for training and evaluating models on datasets for a particular application. It defines a specific training procedure built around a loss function, together with the samplers that draw the samples used in the loss computation.

A trainerflow is organized into three parts: Initialization; Train, Evaluate and Test; and Predict and Load Label Information.

  • Initialization

    During initialization, it sets up crucial components such as the device, model, logger, loss function, tensorboard writer, optimizer, and scheduler.

  • Train, Evaluate and Test

    During training, the trainer iterates over epochs, logs key metrics to TensorBoard, and steps the learning-rate scheduler based on validation results.

    For validation, the evaluate method measures the model's performance on the validation loader and reports the loss together with the configured evaluation metrics.

    Testing is integrated into the training process: if a test loader is provided, the model is evaluated on it after training finishes.

  • Predict and Load Label Information

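    The pred_fn and label_fn methods extract the model prediction and the ground-truth label from a batch. They are left unimplemented in AbstractTrainer and must be provided by subclasses (see TaskTrainer below).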

An example can be found in the following code:

class AbstractTrainer(object):
    """Base trainer implementing the shared train / evaluate / test loop.

    Subclasses must implement :meth:`pred_fn` and :meth:`label_fn`, which
    extract the model prediction and the target label from one batch.

    Args:
        config: run configuration. It is read both by attribute (e.g.
            ``config.device``, ``config.lr``, ``config.epochs``) and by key
            (e.g. ``config["scheduler_settings"]``, ``config["valid_metric"]``).
        model: the model to train; it is moved to ``config.device``.
    """

    def __init__(self, config, model):
        self.config = config
        self.device = config.device
        self.model = model.to(config.device)
        self.logger = Logger(config, name='trainer')
        self.loss = get_loss(config=config)
        self.tensorboard_writer = SummaryWriter(config["tensorboard_dir"])
        self._set_optimizer()
        self._set_scheduler()

    def pred_fn(self, batched_data):
        """Return the model prediction for one batch. Must be overridden."""
        raise NotImplementedError

    def label_fn(self, batched_data):
        """Return the target label for one batch. Must be overridden."""
        raise NotImplementedError

    def _set_optimizer(self):
        """Build ``self.optimizer`` from ``config.lr`` and ``config.weight_decay``."""
        self.optimizer = get_optimizer(config=self.config)(
            params=self.model.parameters(),
            lr=float(self.config.lr),
            weight_decay=float(self.config.weight_decay)
        )

    def _set_scheduler(self):
        """Build ``self.scheduler`` from ``config["scheduler_settings"]``.

        A falsy setting (e.g. ``False``) disables scheduling (``self.scheduler``
        is ``None``). Supported schedulers are ``ReduceLROnPlateau`` and
        ``StepLR``.

        Raises:
            NotImplementedError: for any other scheduler name.
        """
        settings = self.config["scheduler_settings"]
        if not settings:
            self.scheduler = None
            return

        name = settings["scheduler"]
        if name == "ReduceLROnPlateau":
            self.scheduler = ReduceLROnPlateau(
                self.optimizer,
                mode=settings["mode"],
                factor=settings["factor"],
                patience=settings["patience"],
            )
        elif name == "StepLR":
            self.scheduler = StepLR(
                self.optimizer,
                step_size=settings["step_size"],
                gamma=settings["gamma"],
            )
        else:
            raise NotImplementedError("scheduler not supported")

    def _resolve_batch_fns(self, kwargs):
        """Return ``(loss_fn, pred_fn, label_fn)``, preferring overrides in *kwargs*.

        A missing key or an explicit ``None`` falls back to the default: a
        fresh ``get_loss(config)``, ``self.pred_fn`` and ``self.label_fn``.
        """
        loss_fn = kwargs.get("loss_fn")
        if loss_fn is None:
            loss_fn = get_loss(config=self.config)
        pred_fn = kwargs.get("pred_fn")
        if pred_fn is None:
            pred_fn = self.pred_fn
        label_fn = kwargs.get("label_fn")
        if label_fn is None:
            label_fn = self.label_fn
        return loss_fn, pred_fn, label_fn

    def _batch_metrics(self, batched_pred, batched_label, kwargs):
        """Compute per-batch metrics with ``kwargs["eval_fn"]`` or ``eval_all_metric``."""
        eval_fn = kwargs.get("eval_fn")
        if eval_fn is not None:
            return eval_fn(pred=batched_pred, label=batched_label)
        return eval_all_metric(config=self.config, pred=batched_pred, label=batched_label)

    @torch.no_grad()
    def evaluate(self,
                 eval_loader,
                 **kwargs):
        """
            evaluate model on eval_loader
            Args:
                eval_loader: dataloader for evaluation
                kwargs:
                    loss_fn: function to calculate loss
                    pred_fn: function to get prediction from batched_data
                    label_fn: function to get label from batched_data
                    eval_fn: function to evaluate prediction and label
            Returns:
                the metrics reduced over all batches by ``eval_reduce``, plus
                a ``"loss"`` entry (batch-size-weighted loss / ``data_size``).
            Raises:
                ValueError: if eval_loader yields no batches.
        """
        self.model.eval()
        loss_fn, pred_fn, label_fn = self._resolve_batch_fns(kwargs)
        iter_data = tqdm(eval_loader, total=len(eval_loader), ncols=100, desc="eval ")

        result_dict = None
        sum_loss = 0
        for batched_data in iter_data:
            batched_pred = pred_fn(batched_data)
            batched_label = label_fn(batched_data)
            loss = loss_fn(batched_pred, batched_label)
            # Weighted by dim 0 of the prediction; presumably loss_fn returns a
            # per-batch mean and eval_reduce accumulates the sample count in
            # "data_size" — TODO confirm.
            sum_loss += loss.item() * batched_pred.shape[0]

            cur_result_dict = self._batch_metrics(batched_pred, batched_label, kwargs)
            result_dict = eval_reduce([result_dict, cur_result_dict])
            iter_data.set_postfix(**result_dict)

        if result_dict is None:
            raise ValueError("eval_loader yielded no batches")
        result_dict["loss"] = sum_loss / result_dict["data_size"]
        return result_dict

    def scheduler_step(self, valid_result):
        """Advance the learning-rate scheduler by one epoch.

        Args:
            valid_result: the dict returned by :meth:`evaluate`, or ``None``
                when no validation ran this epoch.

        Raises:
            NotImplementedError: for a scheduler name not handled here.
        """
        if self.scheduler is None:
            return
        name = self.config["scheduler_settings"]["scheduler"]
        if name == "ReduceLROnPlateau":
            # Plateau detection needs a validation loss; skip epochs without one.
            if valid_result is not None:
                self.scheduler.step(valid_result["loss"])
        elif name == "StepLR":
            self.scheduler.step()
        else:
            raise NotImplementedError("todo other scheduler")

    def train(self,
              train_loader,
              valid_loader=None,
              test_loader=None,
              load_best_model=True,
              **kwargs):
        """
            train model on train_loader
            Args:
                train_loader: dataloader for training
                valid_loader: dataloader for validation, evaluated every
                    ``config.eval_step`` epochs
                test_loader: dataloader for testing, evaluated once after training
                load_best_model: whether to restore the best model (by
                    ``config["valid_metric"]``) after training
                kwargs:
                    loss_fn: function to calculate loss
                    pred_fn: function to get prediction from batched_data
                    label_fn: function to get label from batched_data
                    eval_fn: function to evaluate prediction and label

            The best model is selected on validation results when a
            valid_loader is given, and on training results otherwise.
        """
        valid_metric = self.config["valid_metric"]
        early_stop = self.config["early_stop"]
        best_result = None
        best_epoch = 0
        best_state_dict = None
        last_epoch = None

        for epoch_idx in range(self.config.epochs):
            last_epoch = epoch_idx
            train_result = self._train_epoch(train_loader=train_loader, **kwargs)

            # evaluate model every eval_step epochs
            valid_result = None
            if valid_loader is not None and (epoch_idx + 1) % self.config.eval_step == 0:
                valid_result = self.evaluate(valid_loader, **kwargs)

            # tensorboard log
            self.tensorboard_writer.add_scalar("train/loss", train_result["loss"], epoch_idx)
            if valid_metric != "loss":
                self.tensorboard_writer.add_scalar(f"train/{valid_metric}", train_result[valid_metric], epoch_idx)

            # output result
            self.logger.info(f"epoch [{epoch_idx}/{self.config.epochs}]")
            self.logger.train_epoch_format(epoch_idx, train_result)
            if valid_result is not None:
                self.logger.valid_epoch_format(epoch_idx, valid_result)
                self.tensorboard_writer.add_scalar("eval/loss", valid_result["loss"], epoch_idx)
                if valid_metric != "loss":
                    self.tensorboard_writer.add_scalar(f"eval/{valid_metric}", valid_result[valid_metric], epoch_idx)

            self.scheduler_step(valid_result)

            if load_best_model:
                # Without a valid_loader, fall back to the training results;
                # with one, epochs skipped by eval_step have no result (None).
                reference = train_result if valid_loader is None else valid_result
                if reference is not None:
                    current_valid_result = reference[valid_metric]
                    if eval_compare(valid_metric, current_valid_result, best_result):
                        best_result = current_valid_result
                        best_epoch = epoch_idx
                        best_state_dict = deepcopy(self.model.state_dict())
                    elif early_stop is not False and epoch_idx - best_epoch >= early_stop:
                        self.logger.info(f"early stop at epoch {epoch_idx}")
                        break

        if load_best_model and best_state_dict is not None:
            self.model.load_state_dict(best_state_dict)
            model_epoch = best_epoch
            self.logger.info(f"load best model at epoch {best_epoch}")
        else:
            model_epoch = last_epoch
            self.logger.info(f"load last model at epoch {last_epoch}")

        # eval test
        if test_loader is not None:
            test_result = self.evaluate(test_loader, **kwargs)
            self.logger.info(f"test_result : {test_result}")

        # save model
        save_path = self.config.save_model
        if save_path:
            self.logger.info(f"save model at epoch {model_epoch} to {save_path}")
            save_dir = os.path.dirname(save_path)
            # A bare filename has no directory component to create.
            if save_dir:
                os.makedirs(save_dir, exist_ok=True)
            torch.save(self.model.state_dict(), save_path)

        # close writer
        self.tensorboard_writer.close()

    def _train_epoch(self,
                     train_loader,
                     **kwargs):
        """Run one optimization pass over train_loader.

        Accepts the same ``loss_fn`` / ``pred_fn`` / ``label_fn`` / ``eval_fn``
        overrides as :meth:`evaluate`, and returns the metrics reduced over all
        batches plus a ``"loss"`` entry.

        Raises:
            ValueError: if train_loader yields no batches.
        """
        loss_fn, pred_fn, label_fn = self._resolve_batch_fns(kwargs)
        iter_data = tqdm(train_loader, total=len(train_loader), ncols=100, desc="train ")

        sum_loss = 0
        result_dict = None
        for batched_data in iter_data:
            # Re-set every batch, presumably so a callback that switched the
            # model to eval mode cannot leak into the next optimization step.
            self.model.train()
            self.optimizer.zero_grad()
            batched_pred = pred_fn(batched_data)
            batched_label = label_fn(batched_data)
            loss = loss_fn(batched_pred, batched_label)

            loss.backward()
            self.optimizer.step()
            batch_loss = loss.item()
            sum_loss += batch_loss * batched_pred.shape[0]

            cur_result_dict = self._batch_metrics(batched_pred, batched_label, kwargs)
            result_dict = eval_reduce([result_dict, cur_result_dict])
            iter_data.set_postfix(loss=batch_loss, **result_dict)

        if result_dict is None:
            raise ValueError("train_loader yielded no batches")
        result_dict["loss"] = sum_loss / result_dict["data_size"]
        return result_dict


class TaskTrainer(AbstractTrainer):
    """Trainer that feeds the whole batch to the model and reads targets
    from ``batched_data["label"]``."""

    def __init__(self, config, model):
        super().__init__(config=config, model=model)

    def pred_fn(self, batched_data):
        """Run the model on the batch and return its raw output."""
        return self.model(batched_data)

    def label_fn(self, batched_data):
        """Return the batch labels as a float tensor on ``self.device``.

        A leading dimension of size 1 is squeezed away. If that leaves a
        0-dimensional tensor, it is expanded back to shape ``(1,)``.
        """
        label = batched_data["label"].to(self.device).float().squeeze(0)
        return label.unsqueeze(0) if label.dim() == 0 else label

For more details, refer to Customize Trainer.