Task
SATGL uses wrappers to adapt the datasets, models, trainers, and evaluation metrics that different tasks invoke. Each module is described in detail below.
Dataset Wrapper
satgl.data.wrapper.SATDataWrapper has two main responsibilities: loading datasets and constructing dataloaders.
For the implementation of the datasets and dataloaders themselves, see Data.
Here is the sample code for the MaxSAT task:
```python
import os

from satgl.data.dataset.sat_dataset import MaxSATDataset
from satgl.data.dataloader.sat_dataloader import MaxSATDataLoader
# Logger is SATGL's logging utility; its import path is not shown in this excerpt.


class SATDataWrapper:
    def __init__(self, config):
        self.config = config
        self.logger = Logger(config, name="sat_data_wrapper")
        self._load_dataset()
        self._build_dataloader()

    def _load_dataset(self):
        # dispatch on the task name; only MaxSAT is shown here
        if self.config["task"] == "maxsat":
            self._load_maxsat_dataset()
        else:
            raise NotImplementedError(f"task {self.config['task']} not implemented.")

    def _build_dataloader(self):
        if self.config["task"] == "maxsat":
            self._build_maxsat_dataloader()
        else:
            raise NotImplementedError(f"task {self.config['task']} not implemented.")

    def _load_maxsat_dataset(self):
        if not self.config["load_split_dataset"]:
            raise NotImplementedError("loading an unsplit dataset is not implemented yet.")

        root = self.config["dataset_path"]
        # CNF files: <root>/{train,valid,test}/; labels: <root>/label/{train,valid,test}.csv
        train_cnf_dir = os.path.join(root, "train")
        valid_cnf_dir = os.path.join(root, "valid")
        test_cnf_dir = os.path.join(root, "test")
        train_label_path = os.path.join(root, "label", "train.csv")
        valid_label_path = os.path.join(root, "label", "valid.csv")
        test_label_path = os.path.join(root, "label", "test.csv")

        self.logger.info("processing train dataset ...")
        self.train_dataset = MaxSATDataset(self.config, train_cnf_dir, train_label_path)

        # the valid and test splits are optional: None when their directory is missing
        self.logger.info("processing valid dataset ...")
        self.valid_dataset = (MaxSATDataset(self.config, valid_cnf_dir, valid_label_path)
                              if os.path.exists(valid_cnf_dir) else None)
        self.logger.info("processing test dataset ...")
        self.test_dataset = (MaxSATDataset(self.config, test_cnf_dir, test_label_path)
                             if os.path.exists(test_cnf_dir) else None)

    def _build_maxsat_dataloader(self):
        batch_size = self.config["batch_size"]
        # shuffle only the training split
        self.train_dataloader = MaxSATDataLoader(self.train_dataset, batch_size, shuffle=True)
        self.valid_dataloader = (MaxSATDataLoader(self.valid_dataset, batch_size, shuffle=False)
                                 if self.valid_dataset is not None else None)
        self.test_dataloader = (MaxSATDataLoader(self.test_dataset, batch_size, shuffle=False)
                                if self.test_dataset is not None else None)
```
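When `load_split_dataset` is true, the MaxSAT loader reads CNF files from `train/`, `valid/`, and `test/` under `dataset_path` and labels from `label/<split>.csv`; the validation and test splits are optional and skipped when their directory is missing. The helper below is a hypothetical sketch (not part of SATGL) that resolves the same paths from a config dict, which can be handy for checking a dataset layout before training.

```python
import os
from typing import Dict, Optional, Tuple


def maxsat_split_paths(config: Dict) -> Dict[str, Optional[Tuple[str, str]]]:
    """Hypothetical helper, not part of SATGL: mirrors the path logic of
    SATDataWrapper._load_maxsat_dataset.

    Returns {split: (cnf_dir, label_csv)}, or None for an optional split
    (valid/test) whose CNF directory does not exist.
    """
    if not config["load_split_dataset"]:
        raise NotImplementedError("only pre-split datasets are handled")
    root = config["dataset_path"]
    paths: Dict[str, Optional[Tuple[str, str]]] = {}
    for split in ("train", "valid", "test"):
        cnf_dir = os.path.join(root, split)
        label_csv = os.path.join(root, "label", f"{split}.csv")
        # train is mandatory; valid/test are skipped when their directory is absent
        if split != "train" and not os.path.exists(cnf_dir):
            paths[split] = None
        else:
            paths[split] = (cnf_dir, label_csv)
    return paths
```

Like the wrapper, the sketch only checks that the CNF directories exist; a missing label CSV would surface later, when the dataset is built.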
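Model Wrapper
SATGL groups model wrappers by the type of graph constructed. It currently supports three graph types: LCG (literal-clause graph), VIG (variable incidence graph), and AIG (and-inverter graph). A wrapper implements operations such as loading the model, initializing node embeddings, running the forward pass used in training, and pooling node embeddings into predictions. Different tasks can add their own operation functions to the wrapper.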
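To make the LCG and VIG graph types concrete, the sketch below derives their edge sets from a CNF formula given as DIMACS-style clauses (positive or negative integers for literals). It is an illustration under that assumption, not SATGL's graph builder: SATGL's construction also attaches node features and may differ in detail, and some LCG formulations (for example NeuroSAT's) additionally link each literal to its negation, which is omitted here.

```python
from itertools import combinations
from typing import List, Set, Tuple

Clause = List[int]  # DIMACS-style literals: +v or -v for variable v (1-based)


def build_lcg_edges(clauses: List[Clause]) -> Set[Tuple[int, int]]:
    """Literal-clause graph: (literal, clause_index) for every literal occurrence."""
    edges: Set[Tuple[int, int]] = set()
    for c_idx, clause in enumerate(clauses):
        for lit in clause:
            edges.add((lit, c_idx))
    return edges


def build_vig_edges(clauses: List[Clause]) -> Set[Tuple[int, int]]:
    """Variable incidence graph: undirected edge (u, v), u < v, for variables sharing a clause."""
    edges: Set[Tuple[int, int]] = set()
    for clause in clauses:
        variables = sorted({abs(lit) for lit in clause})
        for u, v in combinations(variables, 2):
            edges.add((u, v))
    return edges
```

The LCG keeps polarity (literals `2` and `-2` are different nodes), while the VIG discards it and only records which variables co-occur.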
The following is an example of building a wrapper for a graph of type AIG.
```python
import torch
import torch.nn as nn
# DeepSAT and MLP are SATGL modules; their import paths are not shown in this excerpt.


class AIGWrapper(nn.Module):
    # AIG graphs have three node types, each with its own learnable initial feature
    NUM_NODE_TYPES = 3

    def __init__(self, config):
        super().__init__()
        self.config = config
        self.feature_type = config["feature_type"]
        self.task = config["task"]
        self.device = config["device"]
        model_settings = config["model_settings"]
        self.sigmoid = model_settings["sigmoid"]
        self.pooling = model_settings["pooling"]
        self.hidden_size = model_settings["hidden_size"]
        self.num_fc = model_settings["num_fc"]
        model_name = model_settings["model"]

        # embedding init layers: one learnable feature and one linear projection per node type
        self.init_feature_list = nn.ParameterList()
        self.init_embedding_list = nn.ModuleList()
        for _ in range(self.NUM_NODE_TYPES):
            self.init_feature_list.append(nn.Parameter(torch.randn(1, self.hidden_size)))
            self.init_embedding_list.append(nn.Linear(self.hidden_size, self.hidden_size))

        # task-specific init
        if self.task == "satisfiability":
            self._satisfiability_init()
        else:
            raise ValueError(f"task {self.task} not supported.")

        # SAT model init
        if model_name == "deepsat":
            self.model = DeepSAT(config)
        else:
            raise ValueError(f"model {model_name} not supported.")

    def _satisfiability_init(self):
        # graph-level readout: MLP mapping each output-node embedding to one score
        self.graph_readout = MLP(self.hidden_size, self.hidden_size, 1, num_layer=self.num_fc)
        self.graph_level_forward = self.graph_pooling

    def get_init_embedding(self, g):
        node_type = g.ndata["node_type"]
        num_nodes = g.number_of_nodes()
        node_embedding = torch.zeros((num_nodes, self.hidden_size), device=self.device)
        for i in range(self.NUM_NODE_TYPES):
            # view(-1) keeps the index 1-D even when exactly one node has type i
            node_type_idx = (node_type == i).nonzero().view(-1).to(self.device)
            init_embedding = self.init_embedding_list[i](self.init_feature_list[i].to(self.device))
            node_embedding[node_type_idx] = init_embedding.repeat(node_type_idx.shape[0], 1)
        return node_embedding

    def graph_pooling(self, node_embedding, data):
        g = data["g"]
        # output nodes of the AIG are those at backward level 0
        out_node_index = (g.ndata["backward_node_level"] == 0).nonzero().view(-1)
        out_node_index = out_node_index.to(node_embedding.device)
        # squeeze(-1) drops only the score dimension, so a batch of one stays 1-D
        graph_embedding = self.graph_readout(node_embedding[out_node_index]).squeeze(-1)
        if self.sigmoid:
            graph_embedding = torch.sigmoid(graph_embedding)
        return graph_embedding

    def forward(self, data):
        g = data["g"].to(self.device)
        node_embedding = self.get_init_embedding(g)
        node_embedding = self.model(g, node_embedding)
        # task-specific readout
        if self.task == "satisfiability":
            return self.graph_level_forward(node_embedding, data)
```
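`graph_pooling` selects the AIG output nodes (those with `backward_node_level == 0`), maps each output-node embedding to a scalar with the readout MLP, and optionally applies a sigmoid. The sketch below restates that logic in plain Python so it runs without PyTorch or DGL. The function name `aig_readout` is hypothetical, and a single linear layer (`weight`, `bias`) stands in for the `MLP` readout as a simplification.

```python
import math
from typing import List, Sequence


def aig_readout(node_embedding: Sequence[Sequence[float]],
                backward_node_level: Sequence[int],
                weight: Sequence[float],
                bias: float,
                sigmoid: bool = True) -> List[float]:
    """Hypothetical plain-Python sketch of AIGWrapper.graph_pooling.

    A single linear layer replaces the MLP readout (simplifying assumption).
    Returns one score per output node.
    """
    # output nodes are those at backward level 0
    out_idx = [i for i, level in enumerate(backward_node_level) if level == 0]
    scores = []
    for i in out_idx:
        # linear readout: w . h_i + b
        z = sum(w * x for w, x in zip(weight, node_embedding[i])) + bias
        # optional sigmoid turns the score into a probability-like value
        scores.append(1.0 / (1.0 + math.exp(-z)) if sigmoid else z)
    return scores
```

With a batch of AIGs, each graph contributes its own output node(s), so the result holds one score per output node rather than a single pooled value.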
Trainer
SATGL provides a general trainer for graph-based models across its supported tasks. It is implemented in the satgl.trainer.trainer.AbstractTrainer class, supports different training strategies, and manages the training process for the user.
The trainer handles optimizer configuration, learning-rate scheduler integration, evaluation metric tracking, and TensorBoard logging of training progress.
For detailed configuration parameters and training settings, see Customize Trainer.
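Evaluation metric tracking usually means remembering the best validation score and the epoch it occurred in, so the corresponding model can be kept. The sketch below shows that idea in plain Python; `BestMetricTracker` is a hypothetical name and is not SATGL's trainer API. Whether a larger score is better depends on the metric: larger is better for acc, smaller for mse, mae, and rmse.

```python
from dataclasses import dataclass
from typing import Optional


@dataclass
class BestMetricTracker:
    """Hypothetical helper, not SATGL's API: keeps the best validation score seen so far.

    Use higher_is_better=True for acc and False for error metrics (mse, mae, rmse).
    """
    higher_is_better: bool = True
    best_score: Optional[float] = None
    best_epoch: int = -1

    def is_better(self, score: float) -> bool:
        # the first score is always the best so far
        if self.best_score is None:
            return True
        return score > self.best_score if self.higher_is_better else score < self.best_score

    def update(self, epoch: int, score: float) -> bool:
        """Record this epoch's score; return True if it is a new best."""
        if self.is_better(score):
            self.best_score, self.best_epoch = score, epoch
            return True
        return False
```

In a training loop, a `True` return from `update` is the natural point to save a checkpoint of the current model.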
Evaluation Metrics
SATGL provides several evaluation metrics for different tasks:
| Metric | Description |
|---|---|
| Accuracy (acc) | The fraction of samples whose predicted class matches the ground-truth class. Suitable for classification tasks. |
| Mean Squared Error (mse) | The average squared difference between predicted and true values. Commonly used for regression tasks. |
| Mean Absolute Error (mae) | The average absolute difference between predicted and true values; less sensitive to large individual errors than mse. |
| Root Mean Squared Error (rmse) | The square root of mse, which expresses the error in the same units as the target values. Useful for regression tasks. |