Customize DataLoader
In this section, we’ll first present an overview of the modules in DataLoader and then customize a new one.
1. Overview
In XGCN, the Trainer iterates over the dataloader during batch training,
so the dataloader only needs to be an iterable object.
To add a new dataloader, you can simply implement it as an iterable object
on your own and add it to XGCN.create_DataLoader()
(see create_DataLoader() in XGCN/dataloading/create.py),
or you can use the infrastructure provided by XGCN.
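As a minimal sketch of the first option, the class below is a plain iterable that yields mini-batches of edges. The class name ToyEdgeDataLoader, its constructor arguments, and the batch format (a dict with 'src' and 'pos' lists) are assumptions made for illustration. They are not part of XGCN, and the batch format that a particular XGCN Trainer or model expects may differ.

```python
import random


class ToyEdgeDataLoader:
    """Hypothetical iterable dataloader: yields shuffled mini-batches of edges."""

    def __init__(self, edges, batch_size, seed=0):
        self.edges = list(edges)      # list of (src, dst) pairs
        self.batch_size = batch_size
        self.rng = random.Random(seed)

    def __len__(self):
        # number of batches per epoch (the last batch may be smaller)
        return (len(self.edges) + self.batch_size - 1) // self.batch_size

    def __iter__(self):
        # reshuffle at the start of every epoch, then yield batch by batch
        order = list(range(len(self.edges)))
        self.rng.shuffle(order)
        for start in range(0, len(order), self.batch_size):
            idx = order[start:start + self.batch_size]
            yield {
                'src': [self.edges[i][0] for i in idx],
                'pos': [self.edges[i][1] for i in idx],
            }


toy_edges = [(0, 1), (0, 2), (1, 2), (2, 3), (3, 4)]
loader = ToyEdgeDataLoader(toy_edges, batch_size=2)
batches = list(loader)  # one epoch: every edge appears exactly once
```

Because the object defines __iter__(), a training loop can consume it with a plain for loop, once per epoch.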
In this section, we focus on the main components of the XGCN dataloader infrastructure. The UML class diagram is shown in the figure below.
The interface classes are defined in XGCN/dataloading/base.py. They describe
a series of interface functions.
The BaseDataset class requires three functions: __len__(), __getitem__(),
and on_epoch_start().
Note that the __getitem__() function is supposed to return a batch of training samples
given the batch index.
NodeListDataset further requires that the data returned by __getitem__()
include a list of tensors of node IDs.
The Sampler generates positive/negative training samples
given the sample indices, and the BatchSampleIndicesGenerator
generates sample indices given the batch index.
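The sketch below shows one way these roles could fit together: an indices generator turns a batch index into sample indices, a positive sampler turns sample indices into edges, a negative sampler adds negatives, and a dataset exposes __len__(), __getitem__(), and on_epoch_start(). All class names here (ToyBatchIndicesGenerator, ToyPosSampler, ToyNegSampler, ToyLinkDataset) and their method signatures are hypothetical stand-ins written in plain Python. The actual XGCN base classes may define these interfaces differently.

```python
import random


class ToyBatchIndicesGenerator:
    """Hypothetical: maps a batch index to a list of sample indices."""

    def __init__(self, num_samples, batch_size, seed=0):
        self.order = list(range(num_samples))
        self.batch_size = batch_size
        self.rng = random.Random(seed)

    def __len__(self):
        return (len(self.order) + self.batch_size - 1) // self.batch_size

    def on_epoch_start(self):
        self.rng.shuffle(self.order)  # new sample order every epoch

    def __getitem__(self, batch_idx):
        start = batch_idx * self.batch_size
        return self.order[start:start + self.batch_size]


class ToyPosSampler:
    """Hypothetical: maps sample indices to observed (src, pos) edges."""

    def __init__(self, edges):
        self.edges = edges

    def __call__(self, sample_indices):
        return {
            'src': [self.edges[i][0] for i in sample_indices],
            'pos': [self.edges[i][1] for i in sample_indices],
        }


class ToyNegSampler:
    """Hypothetical: one uniformly random negative node per positive sample."""

    def __init__(self, num_nodes, seed=0):
        self.num_nodes = num_nodes
        self.rng = random.Random(seed)

    def __call__(self, pos_sample_data):
        return [self.rng.randrange(self.num_nodes) for _ in pos_sample_data['src']]


class ToyLinkDataset:
    """Hypothetical dataset with __len__, __getitem__, and on_epoch_start."""

    def __init__(self, indices_generator, pos_sampler, neg_sampler):
        self.indices_generator = indices_generator
        self.pos_sampler = pos_sampler
        self.neg_sampler = neg_sampler

    def __len__(self):
        return len(self.indices_generator)

    def on_epoch_start(self):
        self.indices_generator.on_epoch_start()

    def __getitem__(self, batch_idx):
        # batch index -> sample indices -> positive samples -> negatives
        sample_indices = self.indices_generator[batch_idx]
        batch = self.pos_sampler(sample_indices)
        batch['neg'] = self.neg_sampler(batch)
        return batch


edges = [(0, 1), (0, 2), (1, 2), (2, 3), (3, 4)]
dataset = ToyLinkDataset(
    ToyBatchIndicesGenerator(num_samples=len(edges), batch_size=2),
    ToyPosSampler(edges),
    ToyNegSampler(num_nodes=5),
)
dataset.on_epoch_start()
epoch_batches = [dataset[i] for i in range(len(dataset))]
```

Keeping the negative sampler as a separate callable is what lets a new sampling strategy be swapped in without touching the dataset itself.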
To train large-scale message-passing GNNs, mini-graph sampling is often needed.
In XGCN, the BlockDataset class utilizes the NodeListDataset
and DGL’s BlockSampler to conduct mini-graph sampling.
The function XGCN.create_DataLoader() is used to initialize a dataloader.
You can refer to the functions in XGCN/dataloading/create.py.
2. Implement a Sampler
In the following, let’s customize a new dataloader and apply it to a GNN model.
Suppose we want to sample negative nodes according to their degrees. This can be
done by adding a new Sampler (add a XGCN/dataloading/sample/WeightedNeg_Sampler.py):
# XGCN/dataloading/sample/WeightedNeg_Sampler.py
from XGCN.dataloading.base import BaseSampler
from XGCN.utils import io, csr

import torch
import os.path as osp


class WeightedNeg_Sampler(BaseSampler):

    def __init__(self, config, data):
        self.num_neg = config['num_neg']
        data_root = config['data_root']

        # load the graph in CSR format and compute node degrees
        indptr = io.load_pickle(osp.join(data_root, 'indptr.pkl'))
        indices = io.load_pickle(osp.join(data_root, 'indices.pkl'))
        indptr, indices = csr.get_undirected(indptr, indices)
        degrees = indptr[1:] - indptr[:-1]

        # negatives are items in a user-item graph, otherwise any node
        info = io.load_yaml(osp.join(data_root, 'info.yaml'))
        if info['graph_type'] == 'user-item':
            self.num_neg_total = info['num_items']
            self.offset = info['num_users']
        else:
            self.num_neg_total = info['num_nodes']
            self.offset = 0

        # the probability a node is sampled is proportional to the weights:
        self.weights = torch.FloatTensor(
            degrees[self.offset : self.offset + self.num_neg_total]
        ) ** 0.75

    def __call__(self, pos_sample_data):
        src = pos_sample_data['src']
        # draws one negative per positive sample (self.num_neg is not used here)
        neg = torch.multinomial(
            self.weights, num_samples=len(src), replacement=True
        ) + self.offset
        return neg
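To get a feel for what the degree ** 0.75 weighting does, the sketch below reproduces the sampling step with Python's standard library instead of torch.multinomial, so it runs without PyTorch. The helper name sample_weighted_negatives, the toy degree list, and the offset value are assumptions for illustration and are not part of XGCN.

```python
import random
from collections import Counter


def sample_weighted_negatives(degrees, num_samples, offset=0, power=0.75, seed=0):
    """Draw node IDs with probability proportional to degree ** power.

    Stdlib stand-in for:
        torch.multinomial(weights, num_samples, replacement=True) + offset
    """
    rng = random.Random(seed)
    weights = [d ** power for d in degrees]
    # sampling with replacement from local IDs 0..len(degrees)-1
    local_ids = rng.choices(range(len(degrees)), weights=weights, k=num_samples)
    # shift local IDs into the global node ID range
    return [i + offset for i in local_ids]


# Toy user-item graph: 3 users (IDs 0-2), so the 4 items get IDs 3-6.
item_degrees = [1, 4, 16, 0]
neg = sample_weighted_negatives(item_degrees, num_samples=10000, offset=3)
counts = Counter(neg)
```

With these weights (1, about 2.83, 8, and 0), the degree-16 item is drawn roughly 8 times as often as the degree-1 item rather than 16 times, and the zero-degree item is never drawn. An exponent below 1 flattens the degree distribution so that popular nodes do not dominate the negatives entirely.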
Also remember to register it in XGCN/dataloading/create.py
(in the create_LinkDataset() function shown below),
so that XGCN can find the new Sampler.
# in XGCN/dataloading/create.py
from XGCN.dataloading.sample.WeightedNeg_Sampler import WeightedNeg_Sampler


def create_LinkDataset(config, data):
    pos_sampler = {
        'ObservedEdges_Sampler': ObservedEdges_Sampler,
    }[config['pos_sampler']](config, data)
    neg_sampler = {
        'RandomNeg_Sampler': RandomNeg_Sampler,
        'WeightedNeg_Sampler': WeightedNeg_Sampler,  # <-- add the new Sampler here
    }[config['neg_sampler']](config, data)
    ...
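The lookup used here is a plain dictionary from sampler name to class, indexed by a config value and then called with (config, data). The self-contained sketch below shows the same pattern with placeholder classes. The stub classes, the NEG_SAMPLERS table, and the create_neg_sampler() helper with its explicit error message are illustrative assumptions, not XGCN code.

```python
class RandomNegStub:
    """Placeholder standing in for a uniform negative sampler."""

    def __init__(self, config, data):
        self.num_neg = config['num_neg']


class WeightedNegStub:
    """Placeholder standing in for a degree-weighted negative sampler."""

    def __init__(self, config, data):
        self.num_neg = config['num_neg']


# name -> class; registering a new sampler means adding one entry here
NEG_SAMPLERS = {
    'RandomNeg_Sampler': RandomNegStub,
    'WeightedNeg_Sampler': WeightedNegStub,
}


def create_neg_sampler(config, data):
    name = config['neg_sampler']
    if name not in NEG_SAMPLERS:
        # a bare dict lookup would raise KeyError; this message lists the options
        raise ValueError(
            f"unknown neg_sampler {name!r}, expected one of {sorted(NEG_SAMPLERS)}"
        )
    return NEG_SAMPLERS[name](config, data)


sampler = create_neg_sampler({'neg_sampler': 'WeightedNeg_Sampler', 'num_neg': 1}, None)
```

If the new class is missing from the table, the name passed on the command line cannot be resolved, which is why the registration step matters.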
3. Config and run!
Now that we have added a new sampler to XGCN, you can use it by simply
adding a --neg_sampler WeightedNeg_Sampler argument.
For example, we can modify the script: XGCN/script/model/GraphSAGE/run_GraphSGAE-facebook.sh.
# in XGCN/script/model/GraphSAGE/run_GraphSGAE-facebook.sh
python -m XGCN.main.run_model --seed $seed \
    ...
    --neg_sampler WeightedNeg_Sampler \