Overall

First Step: Dataset Wrapper

For your task, add methods to satgl.data.wrapper.SATDataWrapper that load the dataset and build the dataloaders.

Example: loading the dataset for the MaxSAT task:

def _load_dataset(self):
    if self.config["task"] == "maxsat":
        self._load_maxsat_dataset()
    else:
        raise NotImplementedError(f"task {self.config['task']} not implemented.")

def _load_maxsat_dataset(self):
    if self.config["load_split_dataset"] == True:
        train_cnf_dir = os.path.join(self.config["dataset_path"], "train")
        valid_cnf_dir = os.path.join(self.config["dataset_path"], "valid")
        test_cnf_dir = os.path.join(self.config["dataset_path"], "test")
        train_label_path = os.path.join(self.config["dataset_path"], "label", "train.csv")
        valid_label_path = os.path.join(self.config["dataset_path"], "label", "valid.csv")
        test_label_path = os.path.join(self.config["dataset_path"], "label", "test.csv")

        self.logger.info("processing train dataset ...")
        self.train_dataset = MaxSATDataset(self.config, train_cnf_dir, train_label_path)
        self.logger.info("processing valid dataset ...")
        self.valid_dataset = MaxSATDataset(self.config, valid_cnf_dir,
                                                     valid_label_path) if os.path.exists(valid_cnf_dir) else None
        self.logger.info("processing test dataset ...")
        self.test_dataset = MaxSATDataset(self.config, test_cnf_dir, test_label_path) if os.path.exists(
            test_cnf_dir) else None
    else:
        raise NotImplementedError("todo ! Not implemented yet.")

Example: building the dataloaders for the MaxSAT task:

def _build_dataloader(self):
    if self.config["task"] == "maxsat":
        self._build_maxsat_dataloader()
    else:
        raise NotImplementedError(f"task {self.config['task']} not implemented.")

def _build_maxsat_dataloader(self):
    """Wrap the loaded MaxSAT datasets in MaxSATDataLoader instances.

    The training loader shuffles. The validation and test loaders do not
    shuffle, and are ``None`` whenever the corresponding dataset is ``None``.
    """
    batch_size = self.config["batch_size"]

    def eval_loader(dataset):
        # Optional splits: no dataset means no loader.
        if dataset is None:
            return None
        return MaxSATDataLoader(dataset, batch_size, shuffle=False)

    self.train_dataloader = MaxSATDataLoader(self.train_dataset, batch_size, shuffle=True)
    self.valid_dataloader = eval_loader(self.valid_dataset)
    self.test_dataloader = eval_loader(self.test_dataset)

Complete Code

import os

from satgl.data.dataset.sat_dataset import MaxSATDataset
from satgl.data.dataloader.sat_dataloader import MaxSATDataLoader

class SATDataWrapper:
    """Load the datasets and build the dataloaders for the configured task.

    Construction reads ``config["task"]`` and populates ``train_dataset``,
    ``valid_dataset``, ``test_dataset`` and the matching ``*_dataloader``
    attributes. Only the ``"maxsat"`` task is wired up here.

    Args:
        config: experiment configuration; keys read here are ``task``,
            ``load_split_dataset``, ``dataset_path`` and ``batch_size``.

    Raises:
        NotImplementedError: for an unsupported task, or when
            ``load_split_dataset`` is falsy.
    """

    def __init__(self, config):
        self.config = config
        # NOTE(review): `Logger` is not imported in this listing; add the
        # satgl import that provides it.
        self.logger = Logger(config, name='sat_data_wrapper')
        self._load_dataset()
        self._build_dataloader()

    def _load_dataset(self):
        """Dispatch dataset loading to the loader for ``config["task"]``."""
        task = self.config["task"]
        if task != "maxsat":
            raise NotImplementedError(f"task {task} not implemented.")
        self._load_maxsat_dataset()

    def _build_dataloader(self):
        """Dispatch dataloader construction to the builder for ``config["task"]``."""
        task = self.config["task"]
        if task != "maxsat":
            raise NotImplementedError(f"task {task} not implemented.")
        self._build_maxsat_dataloader()

    def _load_maxsat_dataset(self):
        """Load the pre-split MaxSAT train/valid/test datasets.

        Expected layout under ``config["dataset_path"]``::

            train/ valid/ test/          CNF directories, one per split
            label/train.csv, label/valid.csv, label/test.csv

        The train split is always loaded. The valid and test splits are
        optional: when their CNF directory does not exist the attribute is
        set to ``None`` (and nothing is logged as "processing").
        """
        if not self.config["load_split_dataset"]:
            raise NotImplementedError("todo ! Not implemented yet.")

        dataset_path = self.config["dataset_path"]

        def load_split(split, required=False):
            # Build one split's dataset, or return None for a missing optional split.
            cnf_dir = os.path.join(dataset_path, split)
            if not required and not os.path.exists(cnf_dir):
                self.logger.info(f"no {split} dataset found at {cnf_dir}, skipping ...")
                return None
            label_path = os.path.join(dataset_path, "label", f"{split}.csv")
            self.logger.info(f"processing {split} dataset ...")
            return MaxSATDataset(self.config, cnf_dir, label_path)

        self.train_dataset = load_split("train", required=True)
        self.valid_dataset = load_split("valid")
        self.test_dataset = load_split("test")

    def _build_maxsat_dataloader(self):
        """Wrap the MaxSAT datasets in dataloaders.

        Train shuffles; valid/test do not and are ``None`` when their dataset is ``None``.
        """
        batch_size = self.config["batch_size"]

        def eval_loader(dataset):
            # Optional splits: no dataset means no loader.
            if dataset is None:
                return None
            return MaxSATDataLoader(dataset, batch_size, shuffle=False)

        self.train_dataloader = MaxSATDataLoader(self.train_dataset, batch_size, shuffle=True)
        self.valid_dataloader = eval_loader(self.valid_dataset)
        self.test_dataloader = eval_loader(self.test_dataset)

Note

Before this step, implement the dataset and dataloader classes for your task; see Customize Data for how to write them.

Second Step: Model Wrapper

Choose the model wrapper that matches the type of graph being constructed, and add a task-specific initialization method to it. Then make a few small changes to the wrapper so that it calls this method and uses its readout.

For example, if the graph construction mode is AIG and the task is satisfiability, you can add an initialization method as follows:

def _satisfiability_init(self):
    """Set up the graph-level readout used by the satisfiability task.

    Creates ``self.graph_readout`` as
    ``MLP(hidden_size, hidden_size, 1, num_layer=num_fc)`` and routes
    graph-level prediction through ``self.graph_pooling``.
    """
    hidden = self.hidden_size
    self.graph_readout = MLP(hidden, hidden, 1, num_layer=self.num_fc)
    self.graph_level_forward = self.graph_pooling

Then make the corresponding changes to the wrapper. The complete code looks like this:

class AIGWrapper(nn.Module):
    """Model wrapper for graphs built in AIG mode.

    Holds one learnable initial feature (plus a linear projection) per AIG
    node type, the task-specific readout head, and the underlying SAT model.

    Args:
        config: experiment configuration. Reads ``feature_type``, ``task``,
            ``device`` and ``model_settings`` (``sigmoid``, ``pooling``,
            ``model``, ``hidden_size``, ``num_fc``).

    Raises:
        ValueError: if ``config["task"]`` or ``model_settings["model"]`` is
            not supported.
    """

    # The AIG graph has 3 node types; values of g.ndata["node_type"] index these.
    NUM_NODE_TYPES = 3

    def __init__(self, config):
        super().__init__()
        model_settings = config["model_settings"]
        self.config = config
        self.feature_type = config["feature_type"]
        self.task = config["task"]
        self.sigmoid = model_settings["sigmoid"]
        self.pooling = model_settings["pooling"]
        self.device = config["device"]
        self.hidden_size = model_settings["hidden_size"]
        self.num_fc = model_settings["num_fc"]
        # Kept as a local so `self.model` only ever holds the model module.
        model_name = model_settings["model"]

        # Embedding init layer: one learnable feature + linear projection per node type.
        self.init_feature_list = nn.ParameterList()
        self.init_embedding_list = nn.ModuleList()
        for _ in range(self.NUM_NODE_TYPES):
            self.init_feature_list.append(nn.Parameter(torch.randn(1, self.hidden_size)))
            self.init_embedding_list.append(nn.Linear(self.hidden_size, self.hidden_size))

        # Task-specific init: sets graph_readout and graph_level_forward.
        if self.task == "satisfiability":
            self._satisfiability_init()
        else:
            raise ValueError(f"task {self.task} not supported.")

        # SAT model init
        if model_name == "deepsat":
            self.model = DeepSAT(config)
        else:
            raise ValueError(f"model {model_name} not supported.")

    def _satisfiability_init(self):
        """Set up ``graph_readout`` and route graph-level prediction through ``graph_pooling``."""
        self.graph_readout = MLP(self.hidden_size, self.hidden_size, 1, num_layer=self.num_fc)
        self.graph_level_forward = self.graph_pooling

    def get_init_embedding(self, g):
        """Build the initial node embeddings from each node's AIG node type.

        Args:
            g: graph whose ``ndata["node_type"]`` holds an integer type per
                node (assumed 1-D with values in ``[0, NUM_NODE_TYPES)`` — TODO confirm).

        Returns:
            Tensor of shape ``(g.number_of_nodes(), hidden_size)`` on
            ``self.device``. Each node of type ``i`` receives the projected
            feature of type ``i``; nodes of any other type stay zero.
        """
        node_type = g.ndata["node_type"].to(self.device)
        num_nodes = g.number_of_nodes()
        node_embedding = torch.zeros((num_nodes, self.hidden_size), device=self.device)
        for i in range(self.NUM_NODE_TYPES):
            # A boolean mask instead of nonzero().squeeze(): squeeze() turned a
            # single match into a 0-d tensor, making shape[0] raise IndexError.
            mask = node_type == i
            init_embedding = self.init_embedding_list[i](self.init_feature_list[i].to(self.device))
            # The (1, hidden_size) embedding broadcasts over every selected row.
            node_embedding[mask] = init_embedding
        return node_embedding

    def graph_pooling(self, node_embedding, data):
        """Read out a prediction from the output nodes of the graph batch.

        Args:
            node_embedding: node embeddings, ``(num_nodes, hidden_size)``.
            data: batch dict; ``data["g"]`` is the graph whose nodes with
                ``ndata["backward_node_level"] == 0`` are read out
                (presumably one output node per graph — TODO confirm).

        Returns:
            1-D tensor with one entry per output node, passed through a
            sigmoid when ``model_settings["sigmoid"]`` is set.
        """
        g = data["g"]
        out_node_mask = g.ndata["backward_node_level"] == 0
        out_embedding = node_embedding[out_node_mask.to(node_embedding.device)]
        # squeeze(-1) drops only the readout's trailing size-1 dim (assumes the
        # readout returns (k, 1)), so a single output node yields shape (1,)
        # rather than a 0-d scalar.
        graph_embedding = self.graph_readout(out_embedding).squeeze(-1)
        if self.sigmoid:
            graph_embedding = torch.sigmoid(graph_embedding)
        return graph_embedding

    def forward(self, data):
        """Embed nodes, propagate them with the SAT model, and read out the task prediction."""
        g = data["g"].to(self.device)
        node_embedding = self.get_init_embedding(g)
        node_embedding = self.model(g, node_embedding)

        # readout (the constructor only accepts the satisfiability task)
        if self.task == "satisfiability":
            return self.graph_level_forward(node_embedding, data)

Note

Before this step, implement the model itself; see Customize Model for how to write it.

Third Step: Trainer

For detailed configuration parameters and training settings, see Customize Trainer.