Task

SATGL uses wrappers to adapt the dataset, model, trainer, and evaluation metrics to the requirements of each task. Each module is described in detail below.

Dataset Wrapper

satgl.data.wrapper.SATDataWrapper has two main responsibilities: loading datasets and constructing dataloaders. For the implementation of the datasets and dataloaders themselves, see Data.

Here is the sample code for the MaxSAT task:

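import os

from satgl.data.dataset.sat_dataset import MaxSATDataset
from satgl.data.dataloader.sat_dataloader import MaxSATDataLoader
# Logger (used below) must also be imported; its module path is not shown in this excerpt.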

class SATDataWrapper():
    def __init__(self, config):
        self.config = config
        self.logger = Logger(config, name='sat_data_wrapper')
        self._load_dataset()
        self._build_dataloader()

    def _load_dataset(self):
        if self.config["task"] == "maxsat":
            self._load_maxsat_dataset()
        else:
            raise NotImplementedError(f"task {self.config['task']} not implemented.")

    def _build_dataloader(self):
        if self.config["task"] == "maxsat":
            self._build_maxsat_dataloader()
        else:
            raise NotImplementedError(f"task {self.config['task']} not implemented.")

    def _load_maxsat_dataset(self):
        if self.config["load_split_dataset"]:
            # Expected layout: <dataset_path>/{train,valid,test}/ holds the CNF files,
            # <dataset_path>/label/{train,valid,test}.csv holds the labels.
            dataset_path = self.config["dataset_path"]
            train_cnf_dir = os.path.join(dataset_path, "train")
            valid_cnf_dir = os.path.join(dataset_path, "valid")
            test_cnf_dir = os.path.join(dataset_path, "test")
            train_label_path = os.path.join(dataset_path, "label", "train.csv")
            valid_label_path = os.path.join(dataset_path, "label", "valid.csv")
            test_label_path = os.path.join(dataset_path, "label", "test.csv")

            # The train split is always loaded; valid and test are optional
            # and are set to None when their directory does not exist.
            self.logger.info("processing train dataset ...")
            self.train_dataset = MaxSATDataset(self.config, train_cnf_dir, train_label_path)
            self.logger.info("processing valid dataset ...")
            self.valid_dataset = (
                MaxSATDataset(self.config, valid_cnf_dir, valid_label_path)
                if os.path.exists(valid_cnf_dir) else None
            )
            self.logger.info("processing test dataset ...")
            self.test_dataset = (
                MaxSATDataset(self.config, test_cnf_dir, test_label_path)
                if os.path.exists(test_cnf_dir) else None
            )
        else:
            raise NotImplementedError("Loading an unsplit dataset is not implemented yet.")

    def _build_maxsat_dataloader(self):
        batch_size = self.config["batch_size"]
        self.train_dataloader = MaxSATDataLoader(self.train_dataset, batch_size, shuffle=True)
        self.valid_dataloader = MaxSATDataLoader(self.valid_dataset, batch_size, shuffle=False) if self.valid_dataset is not None else None
        self.test_dataloader = MaxSATDataLoader(self.test_dataset, batch_size, shuffle=False) if self.test_dataset is not None else None
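
The directory layout that _load_maxsat_dataset expects can be checked before the wrapper is constructed. The following is a minimal sketch and not part of SATGL: the helper name maxsat_split_paths and the example values in config are hypothetical, while the config keys and the train/valid/test and label/*.csv layout are taken from the code above.

```python
import os


def maxsat_split_paths(dataset_path):
    """Return {split: (cnf_dir, label_path)} for the layout used above.

    Hypothetical helper for illustration; SATGL does not provide it.
    """
    return {
        split: (
            os.path.join(dataset_path, split),
            os.path.join(dataset_path, "label", f"{split}.csv"),
        )
        for split in ("train", "valid", "test")
    }


# Config keys read by SATDataWrapper in the code above; the values are examples.
config = {
    "task": "maxsat",
    "load_split_dataset": True,
    "dataset_path": os.path.join("data", "maxsat"),
    "batch_size": 32,
}

# Report which split directories are present on disk.
for split, (cnf_dir, label_path) in maxsat_split_paths(config["dataset_path"]).items():
    status = "found" if os.path.exists(cnf_dir) else "missing"
    print(f"{split}: {cnf_dir} ({status}), labels at {label_path}")
```

In the wrapper above only the train split is loaded unconditionally; a missing valid or test directory yields a None dataset and a None dataloader for that split.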

Model Wrapper

SATGL categorizes model wrappers by the type of graph constructed. It currently supports three graph types: LCG, VIG, and AIG. A wrapper implements operations such as model loading, embedding initialization, training, and pooling, and each task can add its own operations to the wrapper.

The following is an example of a wrapper for graphs of type AIG.

class AIGWrapper(nn.Module):
    def __init__(self, config):
        super(AIGWrapper, self).__init__()
        self.config = config
        self.feature_type = config["feature_type"]
        self.task = config["task"]
        self.sigmoid = config["model_settings"]["sigmoid"]
        self.pooling = config["model_settings"]["pooling"]
        self.model = config["model_settings"]["model"]
        self.device = config["device"]
        self.hidden_size = config["model_settings"]["hidden_size"]
        self.num_fc = config["model_settings"]["num_fc"]

        # embedding init layer: an AIG has 3 node types
        self.init_feature_list = nn.ParameterList()
        self.init_embedding_list = nn.ModuleList()
        for node_type in range(3):
            self.init_feature_list.append(nn.Parameter(torch.randn(1, self.hidden_size)))
            self.init_embedding_list.append(nn.Linear(self.hidden_size, self.hidden_size))

        # task specific init
        if self.task == "satisfiability":
            self._satisfiability_init()
        else:
            raise ValueError(f"task {self.task} not supported.")

        # sat model init
        if self.model == "deepsat":
            self.model = DeepSAT(config)
        else:
            raise ValueError(f"{self.model} not supported.")

    def _satisfiability_init(self):
        # readout
        self.graph_readout = MLP(self.hidden_size, self.hidden_size, 1, num_layer=self.num_fc)
        self.graph_level_forward = self.graph_pooling

    def get_init_embedding(self, g):
        node_type = g.ndata["node_type"]
        num_nodes = g.number_of_nodes()
        num_classes = 3
        node_embedding = torch.zeros((num_nodes, self.hidden_size)).to(self.device)
        for i in range(num_classes):
            # squeeze(-1) keeps the index 1-D even when exactly one node has this type
            node_type_idx = (node_type == i).nonzero().squeeze(-1).to(self.device)
            init_embedding = self.init_embedding_list[i](self.init_feature_list[i].to(self.device))
            node_embedding[node_type_idx] = init_embedding.repeat(node_type_idx.shape[0], 1)
        return node_embedding

    def graph_pooling(self, node_embedding, data):
        g = data["g"]
        out_node_index = (g.ndata["backward_node_level"] == 0).nonzero().squeeze()
        graph_embedding = self.graph_readout(node_embedding[out_node_index]).squeeze()
        if self.sigmoid:
            graph_embedding = torch.sigmoid(graph_embedding)
        return graph_embedding

    def forward(self, data):
        g = data["g"].to(self.device)
        node_embedding = self.get_init_embedding(g)
        node_embedding = self.model(g, node_embedding)

        # readout
        if self.task == "satisfiability":
            return self.graph_level_forward(node_embedding, data)

Trainer

SATGL provides a trainer for graph-based models across tasks. It is implemented in the satgl.trainer.trainer.AbstractTrainer class, which manages the training process and supports different training strategies.

The trainer handles optimizer configuration, learning rate scheduling, evaluation metric tracking, and TensorBoard logging for visualizing training progress.

For detailed configuration parameters and training settings, see Customize Trainer.
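
Evaluation metric tracking during training usually amounts to remembering the best validation score and the epoch that produced it. A minimal, framework-free sketch of that idea follows; the BestMetricTracker class and its methods are hypothetical and do not reflect the AbstractTrainer API, and the accuracy values are made up for illustration.

```python
class BestMetricTracker:
    """Keep the best value of one validation metric across epochs.

    Hypothetical helper for illustration; not part of SATGL.
    """

    def __init__(self, higher_is_better=True):
        self.higher_is_better = higher_is_better
        self.best_value = None
        self.best_epoch = None

    def update(self, epoch, value):
        """Record `value` and return True if it is a new best."""
        if self.best_value is None:
            improved = True
        elif self.higher_is_better:
            improved = value > self.best_value
        else:
            improved = value < self.best_value
        if improved:
            self.best_value = value
            self.best_epoch = epoch
        return improved


# Example: validation accuracy per epoch (made-up numbers).
tracker = BestMetricTracker(higher_is_better=True)
for epoch, acc in enumerate([0.61, 0.70, 0.68]):
    if tracker.update(epoch, acc):
        print(f"epoch {epoch}: new best acc = {acc}")
```

For error metrics such as mse, where lower values are better, the tracker would be created with higher_is_better=False.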

Evaluation Metrics

SATGL provides multiple evaluation metrics for different tasks:

Accuracy (acc): Fraction of predictions whose class matches the ground-truth class. Suitable for classification tasks.

Mean Squared Error (mse): Average squared difference between predicted and true values. Commonly used for regression tasks.

Mean Absolute Error (mae): Average absolute difference between predicted and true values. Less sensitive to outliers than mse.

Root Mean Squared Error (rmse): Square root of the mean squared error. Comparable to mse, but expressed in the original scale of the data. Useful for regression tasks.
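
The four metrics follow their standard definitions. The sketch below is written in plain Python for illustration and is independent of SATGL's own implementation; the function names are assumptions, and the inputs are plain lists of equal length.

```python
import math


def accuracy(preds, labels):
    """Fraction of predicted classes equal to the ground-truth classes."""
    return sum(p == y for p, y in zip(preds, labels)) / len(labels)


def mse(preds, targets):
    """Mean of the squared differences."""
    return sum((p - t) ** 2 for p, t in zip(preds, targets)) / len(targets)


def mae(preds, targets):
    """Mean of the absolute differences."""
    return sum(abs(p - t) for p, t in zip(preds, targets)) / len(targets)


def rmse(preds, targets):
    """Square root of mse, in the same unit as the targets."""
    return math.sqrt(mse(preds, targets))


print(accuracy([1, 0, 1, 1], [1, 0, 0, 1]))  # 0.75 (3 of 4 correct)
print(mse([2.0, 4.0], [1.0, 1.0]))           # 5.0  ((1 + 9) / 2)
print(mae([2.0, 4.0], [1.0, 1.0]))           # 2.0  ((1 + 3) / 2)
print(rmse([2.0, 4.0], [1.0, 1.0]))          # square root of 5.0
```

In the second example the single large error (3) dominates mse, while mae weights both errors linearly.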