Customize Data¶
Customize Dataset¶
First Step: Load Dataset¶
First, you should construct function _load_dataset to read labels and the raw file, such as CNF files or AIG files.
Next, construct the _build_graph function based on the type of graph to be built. Also, you can add multiple functions
to build different types of graphs. _build_graph function can also be used directly by inheriting from the satgl.data.dataset.sat_dataset.SATTaskAbstractDataset.
More functions can be found in Raw Input.
Here’s an example function of reading label information stored as a csv and reading CNF files to build graphs:
def _load_dataset(self):
label_df = pd.read_csv(self.label_path, sep=',')
self.data_list = []
for idx, row in tqdm(label_df.iterrows(), total=label_df.shape[0]):
name = row['name']
label = eval(row['maxsat'])
cnf_path = os.path.join(self.cnf_dir, name)
num_variable, num_clause, clause_list = parse_cnf_file(cnf_path)
cnf_graph = self._build_graph(num_variable, num_clause, clause_list, cnf_path)
info = self._get_info(num_variable, num_clause, clause_list)
self.data_list.append({"g": cnf_graph, "label": label, "info": info})
Second Step. Output dataset¶
Since the SATTaskAbstractDataset provided by SATGL inherits from dgl.data.dgl_dataset.DGLDataset, it is necessary to implement two functions,
__getitem__ and __len__, to get the data for each batch during training.
An example can be found in the following code:
def __getitem__(self, idx):
return self.data_list[idx]
def __len__(self):
return len(self.data_list)
Complete Code
class MaxSATDataset(SATTaskAbstractDataset):
def __init__(self, config, cnf_dir, label_path):
super().__init__(config, cnf_dir, label_path, name="maxsat_dataset")
self._load_dataset()
def _load_dataset(self):
label_df = pd.read_csv(self.label_path, sep=',')
self.data_list = []
for idx, row in tqdm(label_df.iterrows(), total=label_df.shape[0]):
name = row['name']
label = eval(row['maxsat'])
cnf_path = os.path.join(self.cnf_dir, name)
num_variable, num_clause, clause_list = parse_cnf_file(cnf_path)
cnf_graph = self._build_graph(num_variable, num_clause, clause_list, cnf_path)
info = self._get_info(num_variable, num_clause, clause_list)
self.data_list.append({"g": cnf_graph, "label": label, "info": info})
def __getitem__(self, idx):
return self.data_list[idx]
def __len__(self):
return len(self.data_list)
Customize DataLoader¶
The dataloader provided by SATGL inherits from dgl.dataloading.GraphDataLoader, so the only thing to do to customize your dataloader
is to design the collate_fn function to pass to dgl.dataloading.GraphDataLoader based on the dataset.
An example can be found in the following code:
def maxsat_collate_fn(data):
if (len(data) == 0):
return None
elem = data[0]
if isinstance(elem, dict):
return {key: maxsat_collate_fn([d[key] for d in data]) for key in elem}
elif isinstance(elem, dgl.DGLGraph):
return dgl.batch(data)
elif isinstance(elem, torch.Tensor):
if elem.dim() == 0:
return torch.stack(data)
else:
return torch.cat(data, dim=0)
elif isinstance(elem, list):
return torch.cat([maxsat_collate_fn(e) for e in data])
elif isinstance(elem, int):
return torch.tensor(data)
elif isinstance(elem, float):
return torch.tensor(data)
def MaxSATDataLoader(dataset, batch_size, shuffle=False):
return GraphDataLoader(
dataset,
batch_size=batch_size,
shuffle=shuffle,
collate_fn=maxsat_collate_fn,
)