Customize Trainer
First Step: Initialization
The TaskTrainer class inherits from AbstractTrainer, which is responsible for training, validating, and testing a machine
learning model. During initialization of AbstractTrainer:
The configuration and the model are passed as parameters.
The device, model (moved to the configured device), logger, loss function, TensorBoard writer, optimizer, and learning-rate scheduler are initialized.
class AbstractTrainer(object):
    """Base trainer holding the model, loss, optimizer, LR scheduler, logger
    and TensorBoard writer, all built from ``config``."""

    def __init__(self, config, model):
        """Build the training components.

        Args:
            config: experiment configuration, read both by attribute
                (``device``, ``lr``, ``weight_decay``) and by key
                (``"tensorboard_dir"``, ``"scheduler_settings"``).
            model: the model to train; it is moved to ``config.device``.
        """
        self.config = config
        self.device = config.device
        self.model = model.to(config.device)
        self.logger = Logger(config, name='trainer')
        self.loss = get_loss(config=config)
        self.tensorboard_writer = SummaryWriter(config["tensorboard_dir"])
        # The scheduler wraps the optimizer, so the optimizer must exist first.
        self._set_optimizer()
        self._set_scheduler()

    def _set_optimizer(self):
        """Create ``self.optimizer`` with the optimizer factory from ``get_optimizer``."""
        optimizer_factory = get_optimizer(config=self.config)
        self.optimizer = optimizer_factory(
            params=self.model.parameters(),
            lr=float(self.config.lr),
            weight_decay=float(self.config.weight_decay),
        )

    def _set_scheduler(self):
        """Create ``self.scheduler`` from ``config["scheduler_settings"]``.

        A falsy setting (``False``, ``None``, empty dict) disables scheduling.
        Supported ``"scheduler"`` names:
            * ``"ReduceLROnPlateau"`` - reads ``mode``, ``factor``, ``patience``.
            * ``"StepLR"`` - reads ``step_size``, ``gamma``.

        Raises:
            NotImplementedError: for any other scheduler name.
        """
        settings = self.config["scheduler_settings"]
        if not settings:
            self.scheduler = None
            return
        name = settings["scheduler"]
        if name == "ReduceLROnPlateau":
            self.scheduler = ReduceLROnPlateau(
                self.optimizer,
                mode=settings["mode"],
                factor=settings["factor"],
                patience=settings["patience"],
            )
        elif name == "StepLR":
            self.scheduler = StepLR(
                self.optimizer,
                step_size=settings["step_size"],
                gamma=settings["gamma"],
            )
        else:
            raise NotImplementedError(f"scheduler {name!r} not supported")

    def scheduler_step(self, valid_result):
        """Advance the learning-rate scheduler by one epoch.

        Args:
            valid_result: this epoch's validation result dict (must contain
                ``"loss"`` for ReduceLROnPlateau), or ``None`` when no
                validation ran this epoch.

        Raises:
            NotImplementedError: for a scheduler name this method cannot step.
        """
        if self.scheduler is None:
            return
        name = self.config["scheduler_settings"]["scheduler"]
        if name == "ReduceLROnPlateau":
            # Plateau detection needs a monitored value; skip epochs without validation.
            if valid_result is None:
                return
            self.scheduler.step(valid_result["loss"])
        elif name == "StepLR":
            # StepLR is epoch-based and needs no metric.
            self.scheduler.step()
        else:
            raise NotImplementedError(f"scheduler {name!r} not supported")
Second Step: Train, Evaluate and Test
The train method runs the training loop over config.epochs epochs. It validates every config.eval_step epochs, logs to TensorBoard, and selects the best model with optional early stopping. All of this is controlled by configuration parameters and by optional keyword arguments (loss_fn, pred_fn, label_fn, eval_fn).
For validation, the evaluate method runs the model on the validation loader and returns the accumulated metrics together with the average loss, showing progress with a tqdm bar.
Testing is integrated into the training process: if a test loader is provided, the model is evaluated on it once training finishes.
@torch.no_grad()
def evaluate(self, eval_loader, **kwargs):
    """
    evaluate model on eval_loader (gradients are not tracked)
    Args:
        eval_loader: dataloader for evaluation
        kwargs (a missing or None entry falls back to the default):
            loss_fn: function to calculate loss (default: self.loss)
            pred_fn: function to get prediction from batched_data (default: self.pred_fn)
            label_fn: function to get label from batched_data (default: self.label_fn)
            eval_fn: function to evaluate prediction and label (default: eval_all_metric)
    Returns:
        the metric dict accumulated with eval_reduce, plus "loss": the
        batch-size-weighted mean loss over eval_loader
    Raises:
        ValueError: if eval_loader yields no batches
    """
    self.model.eval()
    loss_fn = kwargs.get("loss_fn")
    if loss_fn is None:
        loss_fn = self.loss
    pred_fn = kwargs.get("pred_fn")
    if pred_fn is None:
        pred_fn = self.pred_fn
    label_fn = kwargs.get("label_fn")
    if label_fn is None:
        label_fn = self.label_fn
    eval_fn = kwargs.get("eval_fn")

    total_loss = 0.0
    result_dict = None
    progress = tqdm(eval_loader, total=len(eval_loader), ncols=100, desc="eval ")
    for batched_data in progress:
        batched_pred = pred_fn(batched_data)
        batched_label = label_fn(batched_data)
        batch_loss = loss_fn(batched_pred, batched_label).item()
        # Weight by batch size so the final value is a per-sample average
        # (assumes loss_fn returns a per-batch mean - TODO confirm for custom losses).
        total_loss += batch_loss * batched_pred.shape[0]
        if eval_fn is not None:
            batch_metrics = eval_fn(pred=batched_pred, label=batched_label)
        else:
            batch_metrics = eval_all_metric(config=self.config, pred=batched_pred, label=batched_label)
        result_dict = eval_reduce([result_dict, batch_metrics])
        progress.set_postfix(**result_dict)
    if result_dict is None:
        raise ValueError("eval_loader yielded no batches")
    result_dict["loss"] = total_loss / result_dict["data_size"]
    return result_dict
def train(self, train_loader, valid_loader=None, test_loader=None, load_best_model=True, **kwargs):
    """
    train model on train_loader, then optionally test and save it

    Every config.eval_step epochs the model is evaluated on valid_loader (if
    given). config["valid_metric"] is monitored on the validation results,
    or on the training results when no valid_loader is given. It is compared
    with eval_compare to pick the best epoch. Training stops early once
    epoch_idx - best_epoch reaches config["early_stop"] (False/None disables
    early stopping).

    Args:
        train_loader: dataloader for training
        valid_loader: dataloader for validation
        test_loader: dataloader for testing, evaluated once after training
        load_best_model: whether to restore the weights of the best monitored
            epoch before testing/saving; if no epoch was ever monitored, the
            last weights are kept
        kwargs (forwarded to _train_epoch and evaluate):
            loss_fn: function to calculate loss
            pred_fn: function to get prediction from batched_data
            label_fn: function to get label from batched_data
            eval_fn: function to evaluate prediction and label
    """
    valid_metric = self.config["valid_metric"]
    early_stop = self.config["early_stop"]
    # 0 is a valid patience, so only False/None disable early stopping.
    use_early_stop = early_stop is not False and early_stop is not None
    best_result = None
    best_epoch = 0
    best_state_dict = None
    last_epoch = -1  # stays -1 if config.epochs == 0
    try:
        for epoch_idx in range(self.config.epochs):
            last_epoch = epoch_idx
            train_result = self._train_epoch(train_loader=train_loader, **kwargs)
            valid_result = None
            # evaluate model every eval_step epochs
            if valid_loader is not None and (epoch_idx + 1) % self.config.eval_step == 0:
                valid_result = self.evaluate(valid_loader, **kwargs)

            # tensorboard log
            self.tensorboard_writer.add_scalar("train/loss", train_result["loss"], epoch_idx)
            if valid_metric != "loss":
                self.tensorboard_writer.add_scalar(f"train/{valid_metric}", train_result[valid_metric], epoch_idx)
            # output result
            self.logger.info(f"epoch [{epoch_idx}/{self.config.epochs}]")
            self.logger.train_epoch_format(epoch_idx, train_result)
            if valid_result is not None:
                self.logger.valid_epoch_format(epoch_idx, valid_result)
                self.tensorboard_writer.add_scalar("eval/loss", valid_result["loss"], epoch_idx)
                if valid_metric != "loss":
                    self.tensorboard_writer.add_scalar(f"eval/{valid_metric}", valid_result[valid_metric], epoch_idx)

            # scheduler step (valid_result is None on epochs without validation)
            self.scheduler_step(valid_result)

            # Model selection uses validation results when a validation set
            # exists (never the training loss); otherwise it uses training results.
            monitored = valid_result if valid_loader is not None else train_result
            if monitored is None:
                continue
            current_result = monitored[valid_metric]
            if eval_compare(valid_metric, current_result, best_result):
                best_result = current_result
                best_epoch = epoch_idx
                if load_best_model:
                    best_state_dict = deepcopy(self.model.state_dict())
            elif use_early_stop and epoch_idx - best_epoch >= early_stop:
                self.logger.info(f"early stop at epoch {epoch_idx}")
                break

        if load_best_model and best_state_dict is not None:
            self.model.load_state_dict(best_state_dict)
            self.logger.info(f"load best model at epoch {best_epoch}")
            saved_epoch = best_epoch
        else:
            if load_best_model:
                self.logger.info("no epoch was monitored, keeping last model")
            self.logger.info(f"load last model at epoch {last_epoch}")
            saved_epoch = last_epoch

        # eval test
        if test_loader is not None:
            test_result = self.evaluate(test_loader, **kwargs)
            self.logger.info(f"test_result : {test_result}")

        # save model (a falsy save_model, e.g. False/None/"", disables saving)
        save_path = self.config.save_model
        if save_path:
            self.logger.info(f"save model at epoch {saved_epoch} to {save_path}")
            save_dir = os.path.dirname(save_path)
            # A bare filename has no directory part; os.makedirs("") would raise.
            if save_dir:
                os.makedirs(save_dir, exist_ok=True)
            torch.save(self.model.state_dict(), save_path)
    finally:
        # close writer even if training fails
        self.tensorboard_writer.close()
def _train_epoch(self, train_loader, **kwargs):
    """
    run one training epoch over train_loader
    Args:
        train_loader: dataloader for training
        kwargs (a missing or None entry falls back to the default):
            loss_fn: function to calculate loss (default: self.loss)
            pred_fn: function to get prediction from batched_data (default: self.pred_fn)
            label_fn: function to get label from batched_data (default: self.label_fn)
            eval_fn: function to evaluate prediction and label (default: eval_all_metric)
    Returns:
        the metric dict accumulated with eval_reduce, plus "loss": the
        batch-size-weighted mean training loss of the epoch
    Raises:
        ValueError: if train_loader yields no batches
    """
    loss_fn = kwargs.get("loss_fn")
    if loss_fn is None:
        loss_fn = self.loss
    pred_fn = kwargs.get("pred_fn")
    if pred_fn is None:
        pred_fn = self.pred_fn
    label_fn = kwargs.get("label_fn")
    if label_fn is None:
        label_fn = self.label_fn
    eval_fn = kwargs.get("eval_fn")

    total_loss = 0.0
    result_dict = None
    progress = tqdm(train_loader, total=len(train_loader), ncols=100, desc="train ")
    for batched_data in progress:
        # Re-asserted every batch in case a user callback switched the model mode.
        self.model.train()
        self.optimizer.zero_grad()
        batched_pred = pred_fn(batched_data)
        batched_label = label_fn(batched_data)
        loss = loss_fn(batched_pred, batched_label)
        loss.backward()
        self.optimizer.step()

        batch_loss = loss.item()
        # Weight by batch size so the final value is a per-sample average
        # (assumes loss_fn returns a per-batch mean - TODO confirm for custom losses).
        total_loss += batch_loss * batched_pred.shape[0]
        # Metrics are bookkeeping only: detach so they never extend the autograd graph.
        metric_pred = batched_pred.detach()
        if eval_fn is not None:
            batch_metrics = eval_fn(pred=metric_pred, label=batched_label)
        else:
            batch_metrics = eval_all_metric(config=self.config, pred=metric_pred, label=batched_label)
        result_dict = eval_reduce([result_dict, batch_metrics])
        progress.set_postfix(loss=batch_loss, **result_dict)
    if result_dict is None:
        raise ValueError("train_loader yielded no batches")
    result_dict["loss"] = total_loss / result_dict["data_size"]
    return result_dict
Complete Code
class AbstractTrainer(object):
    """Generic train / validate / test loop shared by task-specific trainers.

    Subclasses implement ``pred_fn`` and ``label_fn``. The loss, optimizer,
    learning-rate scheduler, logger, TensorBoard writer and checkpointing are
    configured from ``config``.
    """

    def __init__(self, config, model):
        """Build the training components.

        Args:
            config: experiment configuration, read both by attribute
                (``device``, ``lr``, ``weight_decay``, ``epochs``,
                ``eval_step``, ``save_model``) and by key
                (``"tensorboard_dir"``, ``"scheduler_settings"``,
                ``"valid_metric"``, ``"early_stop"``).
            model: the model to train; it is moved to ``config.device``.
        """
        self.config = config
        self.device = config.device
        self.model = model.to(config.device)
        self.logger = Logger(config, name='trainer')
        # Default loss for evaluate/_train_epoch when no loss_fn kwarg is given.
        self.loss = get_loss(config=config)
        self.tensorboard_writer = SummaryWriter(config["tensorboard_dir"])
        # The scheduler wraps the optimizer, so the optimizer must exist first.
        self._set_optimizer()
        self._set_scheduler()

    def pred_fn(self, batched_data):
        """Return the batched prediction for ``batched_data``; subclasses must override."""
        raise NotImplementedError

    def label_fn(self, batched_data):
        """Return the batched label for ``batched_data``; subclasses must override."""
        raise NotImplementedError

    def _set_optimizer(self):
        """Create ``self.optimizer`` with the optimizer factory from ``get_optimizer``."""
        optimizer_factory = get_optimizer(config=self.config)
        self.optimizer = optimizer_factory(
            params=self.model.parameters(),
            lr=float(self.config.lr),
            weight_decay=float(self.config.weight_decay),
        )

    def _set_scheduler(self):
        """Create ``self.scheduler`` from ``config["scheduler_settings"]``.

        A falsy setting (``False``, ``None``, empty dict) disables scheduling.
        Supported ``"scheduler"`` names:
            * ``"ReduceLROnPlateau"`` - reads ``mode``, ``factor``, ``patience``.
            * ``"StepLR"`` - reads ``step_size``, ``gamma``.

        Raises:
            NotImplementedError: for any other scheduler name.
        """
        settings = self.config["scheduler_settings"]
        if not settings:
            self.scheduler = None
            return
        name = settings["scheduler"]
        if name == "ReduceLROnPlateau":
            self.scheduler = ReduceLROnPlateau(
                self.optimizer,
                mode=settings["mode"],
                factor=settings["factor"],
                patience=settings["patience"],
            )
        elif name == "StepLR":
            self.scheduler = StepLR(
                self.optimizer,
                step_size=settings["step_size"],
                gamma=settings["gamma"],
            )
        else:
            raise NotImplementedError(f"scheduler {name!r} not supported")

    def _resolve_fns(self, kwargs):
        """Return ``(loss_fn, pred_fn, label_fn, eval_fn)`` from ``kwargs``.

        A missing or ``None`` entry falls back to ``self.loss``,
        ``self.pred_fn``, ``self.label_fn`` and ``eval_all_metric`` respectively.
        """
        loss_fn = kwargs.get("loss_fn")
        if loss_fn is None:
            loss_fn = self.loss
        pred_fn = kwargs.get("pred_fn")
        if pred_fn is None:
            pred_fn = self.pred_fn
        label_fn = kwargs.get("label_fn")
        if label_fn is None:
            label_fn = self.label_fn
        eval_fn = kwargs.get("eval_fn")
        if eval_fn is None:
            def eval_fn(pred, label):
                return eval_all_metric(config=self.config, pred=pred, label=label)
        return loss_fn, pred_fn, label_fn, eval_fn

    @torch.no_grad()
    def evaluate(self,
                 eval_loader,
                 **kwargs):
        """
        evaluate model on eval_loader (gradients are not tracked)
        Args:
            eval_loader: dataloader for evaluation
            kwargs (a missing or None entry falls back to the default):
                loss_fn: function to calculate loss (default: self.loss)
                pred_fn: function to get prediction from batched_data (default: self.pred_fn)
                label_fn: function to get label from batched_data (default: self.label_fn)
                eval_fn: function to evaluate prediction and label (default: eval_all_metric)
        Returns:
            the metric dict accumulated with eval_reduce, plus "loss": the
            batch-size-weighted mean loss over eval_loader
        Raises:
            ValueError: if eval_loader yields no batches
        """
        self.model.eval()
        loss_fn, pred_fn, label_fn, eval_fn = self._resolve_fns(kwargs)
        total_loss = 0.0
        result_dict = None
        progress = tqdm(eval_loader, total=len(eval_loader), ncols=100, desc="eval ")
        for batched_data in progress:
            batched_pred = pred_fn(batched_data)
            batched_label = label_fn(batched_data)
            batch_loss = loss_fn(batched_pred, batched_label).item()
            # Weight by batch size so the final value is a per-sample average
            # (assumes loss_fn returns a per-batch mean - TODO confirm for custom losses).
            total_loss += batch_loss * batched_pred.shape[0]
            batch_metrics = eval_fn(pred=batched_pred, label=batched_label)
            result_dict = eval_reduce([result_dict, batch_metrics])
            progress.set_postfix(**result_dict)
        if result_dict is None:
            raise ValueError("eval_loader yielded no batches")
        result_dict["loss"] = total_loss / result_dict["data_size"]
        return result_dict

    def scheduler_step(self, valid_result):
        """Advance the learning-rate scheduler by one epoch.

        Args:
            valid_result: this epoch's validation result dict (must contain
                ``"loss"`` for ReduceLROnPlateau), or ``None`` when no
                validation ran this epoch.

        Raises:
            NotImplementedError: for a scheduler name this method cannot step.
        """
        if self.scheduler is None:
            return
        name = self.config["scheduler_settings"]["scheduler"]
        if name == "ReduceLROnPlateau":
            # Plateau detection needs a monitored value; skip epochs without validation.
            if valid_result is None:
                return
            self.scheduler.step(valid_result["loss"])
        elif name == "StepLR":
            # StepLR is epoch-based and needs no metric.
            self.scheduler.step()
        else:
            raise NotImplementedError(f"scheduler {name!r} not supported")

    def train(self,
              train_loader,
              valid_loader=None,
              test_loader=None,
              load_best_model=True,
              **kwargs):
        """
        train model on train_loader, then optionally test and save it

        Every config.eval_step epochs the model is evaluated on valid_loader
        (if given). config["valid_metric"] is monitored on the validation
        results, or on the training results when no valid_loader is given.
        It is compared with eval_compare to pick the best epoch. Training
        stops early once epoch_idx - best_epoch reaches config["early_stop"]
        (False/None disables early stopping).

        Args:
            train_loader: dataloader for training
            valid_loader: dataloader for validation
            test_loader: dataloader for testing, evaluated once after training
            load_best_model: whether to restore the weights of the best
                monitored epoch before testing/saving; if no epoch was ever
                monitored, the last weights are kept
            kwargs (forwarded to _train_epoch and evaluate):
                loss_fn: function to calculate loss
                pred_fn: function to get prediction from batched_data
                label_fn: function to get label from batched_data
                eval_fn: function to evaluate prediction and label
        """
        valid_metric = self.config["valid_metric"]
        early_stop = self.config["early_stop"]
        # 0 is a valid patience, so only False/None disable early stopping.
        use_early_stop = early_stop is not False and early_stop is not None
        best_result = None
        best_epoch = 0
        best_state_dict = None
        last_epoch = -1  # stays -1 if config.epochs == 0
        try:
            for epoch_idx in range(self.config.epochs):
                last_epoch = epoch_idx
                train_result = self._train_epoch(train_loader=train_loader, **kwargs)
                valid_result = None
                # evaluate model every eval_step epochs
                if valid_loader is not None and (epoch_idx + 1) % self.config.eval_step == 0:
                    valid_result = self.evaluate(valid_loader, **kwargs)

                # tensorboard log
                self.tensorboard_writer.add_scalar("train/loss", train_result["loss"], epoch_idx)
                if valid_metric != "loss":
                    self.tensorboard_writer.add_scalar(f"train/{valid_metric}", train_result[valid_metric], epoch_idx)
                # output result
                self.logger.info(f"epoch [{epoch_idx}/{self.config.epochs}]")
                self.logger.train_epoch_format(epoch_idx, train_result)
                if valid_result is not None:
                    self.logger.valid_epoch_format(epoch_idx, valid_result)
                    self.tensorboard_writer.add_scalar("eval/loss", valid_result["loss"], epoch_idx)
                    if valid_metric != "loss":
                        self.tensorboard_writer.add_scalar(f"eval/{valid_metric}", valid_result[valid_metric], epoch_idx)

                # scheduler step (valid_result is None on epochs without validation)
                self.scheduler_step(valid_result)

                # Model selection uses validation results when a validation set
                # exists (never the training loss); otherwise it uses training results.
                monitored = valid_result if valid_loader is not None else train_result
                if monitored is None:
                    continue
                current_result = monitored[valid_metric]
                if eval_compare(valid_metric, current_result, best_result):
                    best_result = current_result
                    best_epoch = epoch_idx
                    if load_best_model:
                        best_state_dict = deepcopy(self.model.state_dict())
                elif use_early_stop and epoch_idx - best_epoch >= early_stop:
                    self.logger.info(f"early stop at epoch {epoch_idx}")
                    break

            if load_best_model and best_state_dict is not None:
                self.model.load_state_dict(best_state_dict)
                self.logger.info(f"load best model at epoch {best_epoch}")
                saved_epoch = best_epoch
            else:
                if load_best_model:
                    self.logger.info("no epoch was monitored, keeping last model")
                self.logger.info(f"load last model at epoch {last_epoch}")
                saved_epoch = last_epoch

            # eval test
            if test_loader is not None:
                test_result = self.evaluate(test_loader, **kwargs)
                self.logger.info(f"test_result : {test_result}")

            # save model (a falsy save_model, e.g. False/None/"", disables saving)
            save_path = self.config.save_model
            if save_path:
                self.logger.info(f"save model at epoch {saved_epoch} to {save_path}")
                save_dir = os.path.dirname(save_path)
                # A bare filename has no directory part; os.makedirs("") would raise.
                if save_dir:
                    os.makedirs(save_dir, exist_ok=True)
                torch.save(self.model.state_dict(), save_path)
        finally:
            # close writer even if training fails
            self.tensorboard_writer.close()

    def _train_epoch(self,
                     train_loader,
                     **kwargs):
        """
        run one training epoch over train_loader
        Args:
            train_loader: dataloader for training
            kwargs: loss_fn / pred_fn / label_fn / eval_fn overrides, resolved
                by _resolve_fns (missing or None entries use the defaults)
        Returns:
            the metric dict accumulated with eval_reduce, plus "loss": the
            batch-size-weighted mean training loss of the epoch
        Raises:
            ValueError: if train_loader yields no batches
        """
        loss_fn, pred_fn, label_fn, eval_fn = self._resolve_fns(kwargs)
        total_loss = 0.0
        result_dict = None
        progress = tqdm(train_loader, total=len(train_loader), ncols=100, desc="train ")
        for batched_data in progress:
            # Re-asserted every batch in case a user callback switched the model mode.
            self.model.train()
            self.optimizer.zero_grad()
            batched_pred = pred_fn(batched_data)
            batched_label = label_fn(batched_data)
            loss = loss_fn(batched_pred, batched_label)
            loss.backward()
            self.optimizer.step()

            batch_loss = loss.item()
            # Weight by batch size so the final value is a per-sample average
            # (assumes loss_fn returns a per-batch mean - TODO confirm for custom losses).
            total_loss += batch_loss * batched_pred.shape[0]
            # Metrics are bookkeeping only: detach so they never extend the autograd graph.
            batch_metrics = eval_fn(pred=batched_pred.detach(), label=batched_label)
            result_dict = eval_reduce([result_dict, batch_metrics])
            progress.set_postfix(loss=batch_loss, **result_dict)
        if result_dict is None:
            raise ValueError("train_loader yielded no batches")
        result_dict["loss"] = total_loss / result_dict["data_size"]
        return result_dict
Third Step: Predict and Load Label Information
The pred_fn and label_fn methods in TaskTrainer produce the model's predictions and load the labels for each batch:
class TaskTrainer(AbstractTrainer):
    """Trainer for batches that carry their target under the ``"label"`` key."""

    def __init__(self, config, model):
        super().__init__(config=config, model=model)

    def pred_fn(self, batched_data):
        """Feed the whole batch to the model and return its output unchanged."""
        return self.model(batched_data)

    def label_fn(self, batched_data):
        """Return ``batched_data["label"]`` as a float tensor on ``self.device``.

        Dimension 0 is squeezed away when its size is 1; if that leaves a
        scalar, it is turned back into a one-element 1-D tensor.
        """
        label = batched_data["label"].to(self.device).float().squeeze(0)
        return label.unsqueeze(0) if label.dim() == 0 else label
Fourth Step: Evaluate Metrics
You can add your own evaluation function to satgl.metric and call it from your trainer, or pass it to train/evaluate through the eval_fn keyword argument.