DataLoader
Our DataLoader builds on DGL's dgl.dataloading.GraphDataLoader and dgl.dataloading.GraphCollator: it wraps GraphDataLoader and supplies its own collate function in place of DGL's default collator.
You should implement the collate_fn function so that it matches the structure of the items your Dataset returns.
An example can be found in the following code:
import dgl
import torch
from dgl.dataloading import GraphCollator, GraphDataLoader
def maxsat_collate_fn(data):
    """Collate a list of per-sample items into a single batched item.

    The type of ``data[0]`` decides how the whole list is merged; every other
    item is assumed to have the same type (and, for dicts, the same keys).

    * empty list      -> ``None``
    * ``dict``        -> dict with the same keys, each value collated recursively
    * ``torch.Tensor``-> 0-dim tensors are stacked into a 1-D tensor; higher-dim
                         tensors are concatenated along dim 0
    * ``list``        -> the per-sample lists are concatenated and the combined
                         items are collated together (empty lists contribute
                         nothing)
    * ``int``/``float``-> ``torch.tensor(data)``
    * ``dgl.DGLGraph``-> ``dgl.batch(data)``

    Args:
        data: list of items produced by the dataset (or a nested field of them).

    Returns:
        The batched item, or ``None`` if ``data`` (after flattening, for lists)
        is empty.

    Raises:
        TypeError: if the element type is not one of the supported types above.
    """
    if not data:
        return None
    elem = data[0]
    if isinstance(elem, dict):
        return {key: maxsat_collate_fn([d[key] for d in data]) for key in elem}
    if isinstance(elem, torch.Tensor):
        if elem.dim() == 0:
            return torch.stack(data)
        return torch.cat(data, dim=0)
    if isinstance(elem, list):
        # Flatten one level and collate once: for tensors/numbers this equals
        # concatenating per-sample results, but it also works for lists of
        # graphs or dicts and simply skips empty per-sample lists.
        return maxsat_collate_fn([item for sample in data for item in sample])
    if isinstance(elem, (int, float)):
        return torch.tensor(data)
    if isinstance(elem, dgl.DGLGraph):
        return dgl.batch(data)
    # Previously this fell through and silently returned None.
    raise TypeError(
        f"maxsat_collate_fn: unsupported element type {type(elem).__name__}"
    )
def MaxSATDataLoader(dataset, batch_size, shuffle=False, **kwargs):
    """Build a ``GraphDataLoader`` over ``dataset`` using ``maxsat_collate_fn``.

    Args:
        dataset: the dataset to iterate over.
        batch_size: number of samples per batch.
        shuffle: whether to shuffle samples each epoch.
        **kwargs: extra keyword arguments (e.g. ``num_workers``, ``drop_last``,
            ``pin_memory``) forwarded unchanged to ``GraphDataLoader``. Do not
            pass ``collate_fn``; it is always ``maxsat_collate_fn``.

    Returns:
        A configured ``GraphDataLoader``.
    """
    return GraphDataLoader(
        dataset,
        batch_size=batch_size,
        shuffle=shuffle,
        collate_fn=maxsat_collate_fn,
        **kwargs,
    )