Customize DataLoader

In this section, we’ll first present an overview of the modules in DataLoader and then customize a new one.

1. Overview

In XGCN, the dataloader is called by the Trainer during batch training, and it only needs to be an iterable object. To add a new dataloader, you can either implement it as an iterable object yourself and add it to XGCN.create_DataLoader() (see create_DataLoader() in XGCN/dataloading/create.py), or use the infrastructure provided by XGCN.
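
Because the Trainer only iterates over the dataloader, a hand-written one can be quite small. The following is a minimal sketch rather than XGCN code: the class name ToyEdgeDataLoader, its constructor arguments, and the (src, pos) batch format are all assumptions made for illustration. Check what your model's training step actually expects from each batch before reusing it.

```python
import random


class ToyEdgeDataLoader:
    """Hypothetical iterable dataloader that yields mini-batches of edges."""

    def __init__(self, edges, batch_size, shuffle=True, seed=0):
        self.edges = list(edges)
        self.batch_size = batch_size
        self.shuffle = shuffle
        self.rng = random.Random(seed)

    def __len__(self):
        # number of batches per epoch (the last batch may be smaller)
        return (len(self.edges) + self.batch_size - 1) // self.batch_size

    def __iter__(self):
        # runs once per epoch, i.e. once per `for batch in dataloader` loop
        order = list(range(len(self.edges)))
        if self.shuffle:
            self.rng.shuffle(order)
        for start in range(0, len(order), self.batch_size):
            batch = [self.edges[i] for i in order[start:start + self.batch_size]]
            src = [e[0] for e in batch]
            pos = [e[1] for e in batch]
            yield src, pos  # assumed batch format: source nodes, positive nodes


if __name__ == "__main__":
    toy_edges = [(0, 1), (0, 2), (1, 2), (2, 3), (3, 4)]
    for src, pos in ToyEdgeDataLoader(toy_edges, batch_size=2):
        print(src, pos)
```

Reshuffling inside __iter__() gives a fresh edge order every epoch without any extra hook, which is the main reason to prefer an iterable class over a pre-built list of batches.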

In this section, we focus on introducing the main components of the XGCN dataloader infrastructure. The UML class diagram is shown in the figure below.

UML class diagram of DataLoader

The interface classes are defined in XGCN/dataloading/base.py; they describe a series of interface functions. The BaseDataset class requires three functions: __len__(), __getitem__(), and on_epoch_start(). Note that __getitem__() is supposed to return a batch of training samples given the batch index. NodeListDataset further requires that the data returned by __getitem__() include a list of tensors of node IDs. The Sampler is used to generate positive/negative training samples given the sample index, and the BatchSampleIndicesGenerator is used to generate sample indices given the batch index.
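
To make the division of labor concrete, here is a rough sketch of how the three roles could fit together. It is not the XGCN implementation: the Toy* classes do not inherit from the real base classes, they use plain lists instead of tensors, and every signature other than the three BaseDataset functions named above is an assumption.

```python
class ToyBatchSampleIndicesGenerator:
    """Hypothetical: maps a batch index to the sample indices of that batch."""

    def __init__(self, num_samples, batch_size):
        self.num_samples = num_samples
        self.batch_size = batch_size

    def __len__(self):
        return (self.num_samples + self.batch_size - 1) // self.batch_size

    def __getitem__(self, batch_idx):
        start = batch_idx * self.batch_size
        stop = min(start + self.batch_size, self.num_samples)
        return list(range(start, stop))

    def on_epoch_start(self):
        pass  # a real generator could reshuffle the sample order here


class ToyPosSampler:
    """Hypothetical: returns the observed edges for the given sample indices."""

    def __init__(self, edges):
        self.edges = edges

    def __call__(self, sample_indices):
        src = [self.edges[i][0] for i in sample_indices]
        pos = [self.edges[i][1] for i in sample_indices]
        return {'src': src, 'pos': pos}


class ToyLinkDataset:
    """Provides __len__(), __getitem__() and on_epoch_start()."""

    def __init__(self, indices_generator, pos_sampler):
        self.indices_generator = indices_generator
        self.pos_sampler = pos_sampler

    def __len__(self):
        return len(self.indices_generator)

    def __getitem__(self, batch_idx):
        # batch index -> sample indices -> one batch of training samples
        if not 0 <= batch_idx < len(self):
            raise IndexError(batch_idx)
        sample_indices = self.indices_generator[batch_idx]
        return self.pos_sampler(sample_indices)

    def on_epoch_start(self):
        self.indices_generator.on_epoch_start()


if __name__ == "__main__":
    edges = [(0, 1), (0, 2), (1, 2), (2, 3), (3, 4)]
    dataset = ToyLinkDataset(
        ToyBatchSampleIndicesGenerator(len(edges), batch_size=2),
        ToyPosSampler(edges),
    )
    dataset.on_epoch_start()
    for batch_idx in range(len(dataset)):
        print(dataset[batch_idx])
```

Keeping index generation separate from sampling means a new sampling strategy can be swapped in without touching the batching logic.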

To train large-scale message-passing GNNs, mini-graph sampling is often needed. In XGCN, the BlockDataset class utilizes the NodeListDataset and DGL’s BlockSampler to conduct mini-graph sampling.

The function XGCN.create_DataLoader() is used to initialize a dataloader. You can refer to the functions in XGCN/dataloading/create.py.

2. Implement a Sampler

In the following, let’s customize a new dataloader and apply it to a GNN model. Suppose we want to sample negative nodes according to their degrees. This can be done by adding a new Sampler in a new file, XGCN/dataloading/sample/WeightedNeg_Sampler.py:

# XGCN/dataloading/sample/WeightedNeg_Sampler.py

from XGCN.dataloading.base import BaseSampler
from XGCN.utils import io, csr

import torch
import os.path as osp


class WeightedNeg_Sampler(BaseSampler):

    def __init__(self, config, data):
        self.num_neg = config['num_neg']

        data_root = config['data_root']
        indptr = io.load_pickle(osp.join(data_root, 'indptr.pkl'))
        indices = io.load_pickle(osp.join(data_root, 'indices.pkl'))
        indptr, indices = csr.get_undirected(indptr, indices)
        degrees = indptr[1:] - indptr[:-1]

        info = io.load_yaml(osp.join(data_root, 'info.yaml'))
        if info['graph_type'] == 'user-item':
            self.num_neg_total = info['num_items']
            self.offset = info['num_users']
        else:
            self.num_neg_total = info['num_nodes']
            self.offset = 0

        # the probability that a candidate node is sampled is proportional
        # to its weight, here degree ** 0.75 (zero-degree nodes get weight 0):
        self.weights = torch.FloatTensor(
            degrees[self.offset : self.offset + self.num_neg_total]
        ) ** 0.75

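    def __call__(self, pos_sample_data):
        src = pos_sample_data['src']
        # draw one negative per positive sample (with replacement); the drawn
        # values index into self.weights, so add the offset to get node IDs
        neg = torch.multinomial(
            self.weights, num_samples=len(src), replacement=True
        ) + self.offset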
        return neg
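
The arithmetic behind the sampler can be checked on a toy graph. The sketch below is only an illustration: it uses the standard library's random.choices in place of torch.multinomial so that it runs without PyTorch, and the helper names and toy indptr values are made up. Raising degrees to the power 0.75 flattens the distribution: candidates with degrees 1, 16 and 81 get weights 1, 8 and 27, so the most popular one is drawn 27 times as often as the least popular one instead of 81 times.

```python
import random


def csr_degrees(indptr):
    """Degree of node i is the length of row i in the CSR indices array."""
    return [indptr[i + 1] - indptr[i] for i in range(len(indptr) - 1)]


def candidate_weights(degrees, offset, num_candidates, power=0.75):
    """Unnormalized sampling weights for nodes offset .. offset+num_candidates-1."""
    return [d ** power for d in degrees[offset:offset + num_candidates]]


def sample_negatives(weights, offset, num_samples, rng):
    """Draw node IDs with replacement, proportionally to the weights."""
    candidates = range(offset, offset + len(weights))
    return rng.choices(candidates, weights=weights, k=num_samples)


if __name__ == "__main__":
    # hypothetical user-item graph: nodes 0-1 are users, nodes 2-4 are items
    indptr = [0, 3, 5, 6, 8, 10]
    degrees = csr_degrees(indptr)      # one degree per node, users first
    num_users, num_items = 2, 3

    # only items are negative candidates, so skip the user rows
    weights = candidate_weights(degrees, offset=num_users, num_candidates=num_items)
    negatives = sample_negatives(weights, num_users, num_samples=5, rng=random.Random(0))
    print(degrees, weights, negatives)
```

Note that a node with degree zero gets weight zero and is therefore never drawn as a negative.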

Also remember to register it in XGCN/dataloading/create.py, in the function that builds the samplers (create_LinkDataset() in the snippet below), so that XGCN can find the new Sampler.

# in XGCN/dataloading/create.py

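# the module path must match the file location XGCN/dataloading/sample/WeightedNeg_Sampler.py
from XGCN.dataloading.sample.WeightedNeg_Sampler import WeightedNeg_Sampler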

def create_LinkDataset(config, data):
    pos_sampler = {
        'ObservedEdges_Sampler': ObservedEdges_Sampler,
    }[config['pos_sampler']](config, data)

    neg_sampler = {
        'RandomNeg_Sampler': RandomNeg_Sampler,
        'WeightedNeg_Sampler': WeightedNeg_Sampler,  # <-- add the new Sampler here
    }[config['neg_sampler']](config, data)

    ...

3. Config and run!

Now that we have added the new Sampler to XGCN, you can use it by simply adding the --neg_sampler WeightedNeg_Sampler argument. For example, we can modify the script XGCN/script/model/GraphSAGE/run_GraphSGAE-facebook.sh:

# in XGCN/script/model/GraphSAGE/run_GraphSGAE-facebook.sh

python -m XGCN.main.run_model --seed $seed \
    ...
    --neg_sampler WeightedNeg_Sampler \