Model Training
Train from scratch
There are three steps to train a model:
1. Prepare the config Dict, which contains all the needed arguments.
2. Create the model: model = XGCN.create_model(config). The results_root directory will be created automatically if it does not exist.
3. Start training: model.fit(). The best model on the validation set and the training information will be saved at results_root.
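As a rough illustration of step 1, the sketch below assembles a config dict in plain Python. It assumes the dict keys match the command-line flag names used by the example script later in this section (the run_model module reads config['results_root'], which suggests this, but it is not confirmed for every key). The helper names build_config and train_from_scratch and the model_options argument are hypothetical, and the model-specific entries that normally come from the YAML config file must be supplied by the caller.

```python
import os.path as osp


def build_config(all_data_root, dataset, model, seed=0, device="cuda:0",
                 model_options=None):
    """Assemble a config dict whose keys mirror the run_model flags (assumption)."""
    data_root = osp.join(all_data_root, "dataset", "instance_" + dataset)
    results_root = osp.join(all_data_root, "model_output", dataset, model,
                            "[seed%d]" % seed)
    # Model-specific entries, normally read from <model>-config.yaml.
    config = dict(model_options or {})
    config.update({
        "seed": seed,
        "data_root": data_root,
        "results_root": results_root,
        "val_method": "one_pos_k_neg",
        "file_val_set": osp.join(data_root, "val-one_pos_k_neg.pkl"),
        "key_score_metric": "r20",
        "test_method": "multi_pos_whole_graph",
        "file_test_set": osp.join(data_root, "test-multi_pos_whole_graph.pkl"),
        "emb_table_device": device,
        "forward_device": device,
        "out_emb_table_device": device,
    })
    return config


def train_from_scratch(config):
    """Steps 2 and 3. XGCN is imported lazily so build_config works without it."""
    import XGCN
    model = XGCN.create_model(config)  # results_root is created if missing
    model.fit()                        # best model is saved under results_root
    return model.test()
```

Keeping the dict construction separate from the XGCN calls makes it easy to inspect or log the configuration before any training starts.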
XGCN provides a simple module, XGCN.main.run_model, to run models from the command line.
It has the following contents:
import XGCN
from XGCN.data import io
from XGCN.utils.parse_arguments import parse_arguments

import os.path as osp


def main():
    config = parse_arguments()
    model = XGCN.create_model(config)
    model.fit()
    test_results = model.test()
    print("test:", test_results)
    io.save_json(osp.join(config['results_root'], 'test_results.json'), test_results)


if __name__ == '__main__':
    main()
We provide shell scripts to run all the models in script/examples.
For example, script/examples/facebook/run_xGCN.sh:
# set to your own path:
all_data_root='/home/sxr/code/XGCN_and_data/XGCN_data'
config_file_root='/home/sxr/code/XGCN_and_data/XGCN_library/config'
dataset=facebook
model=xGCN
seed=0
device='cuda:0'
emb_table_device=$device
forward_device=$device
out_emb_table_device=$device
data_root=$all_data_root/dataset/instance_$dataset
results_root=$all_data_root/model_output/$dataset/$model/[seed$seed]
# file_pretrained_emb=$all_data_root/model_output/$dataset/Node2vec/[seed$seed]/model/out_emb_table.pt
python -m XGCN.main.run_model --seed $seed \
--config_file $config_file_root/$model-config.yaml \
--data_root $data_root --results_root $results_root \
--val_method one_pos_k_neg \
--file_val_set $data_root/val-one_pos_k_neg.pkl \
--key_score_metric r20 \
--test_method multi_pos_whole_graph \
--file_test_set $data_root/test-multi_pos_whole_graph.pkl \
--emb_table_device $emb_table_device \
--forward_device $forward_device \
--out_emb_table_device $out_emb_table_device \
# --from_pretrained 1 --file_pretrained_emb $file_pretrained_emb \
To run a script, you only need to modify all_data_root and
config_file_root to your own paths.
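The same invocation can also be launched from Python. The following is a minimal sketch, not part of XGCN: build_command and run are hypothetical helper names, and the flags are copied from the shell script above. Passing the arguments as a list to subprocess.run bypasses the shell, so the square brackets in the [seed0] directory name are never subject to shell pathname expansion.

```python
import subprocess
import sys


def build_command(all_data_root, config_file_root, dataset="facebook",
                  model="xGCN", seed=0, device="cuda:0"):
    """Return the argument list equivalent to the example shell script."""
    data_root = f"{all_data_root}/dataset/instance_{dataset}"
    results_root = f"{all_data_root}/model_output/{dataset}/{model}/[seed{seed}]"
    return [
        sys.executable, "-m", "XGCN.main.run_model",
        "--seed", str(seed),
        "--config_file", f"{config_file_root}/{model}-config.yaml",
        "--data_root", data_root,
        "--results_root", results_root,
        "--val_method", "one_pos_k_neg",
        "--file_val_set", f"{data_root}/val-one_pos_k_neg.pkl",
        "--key_score_metric", "r20",
        "--test_method", "multi_pos_whole_graph",
        "--file_test_set", f"{data_root}/test-multi_pos_whole_graph.pkl",
        "--emb_table_device", device,
        "--forward_device", device,
        "--out_emb_table_device", device,
    ]


def run(command):
    """Run the training command; raises CalledProcessError on a non-zero exit."""
    return subprocess.run(command, check=True)
```

For example, run(build_command(all_data_root, config_file_root)) starts the same training job as the script, provided both paths point to your own directories.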
Once a model is trained, the output data will be saved at results_root:
XGCN_data
└── model_output
    └── facebook
        └── xGCN
            └── [seed0]
                ├── model (directory)       # the best model on the validation set
                ├── config.yaml             # configuration of this run
                ├── mean_time.json          # time consumption information in seconds
                ├── test_results.json       # test results
                ├── train_record_best.json  # validation results of the best epoch
                └── train_record.txt        # validation results of all the epochs
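To inspect a finished run, the JSON files listed above can be read with the standard library alone. This is a small sketch under the assumption that the three .json outputs are ordinary JSON documents; their internal structure is not described here, so the helper (load_json_outputs, a hypothetical name) simply returns whatever each file contains and skips files that are missing.

```python
import json
import os.path as osp

JSON_OUTPUTS = ("test_results.json", "train_record_best.json", "mean_time.json")


def load_json_outputs(results_root, names=JSON_OUTPUTS):
    """Load the JSON outputs of a run, keyed by file name; absent files are skipped."""
    outputs = {}
    for name in names:
        path = osp.join(results_root, name)
        if osp.isfile(path):
            with open(path) as f:
                outputs[name] = json.load(f)
    return outputs
```

Calling load_json_outputs(results_root) on several seed directories gives a quick way to compare test results across runs.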
Load and continue to train
XGCN can also load trained models and continue to train.
In this case, please specify the previously saved config.yaml and call the
XGCN.load_model() function:
import XGCN
from XGCN.data import io

config = io.load_yaml(...)       # the previously saved config.yaml
config['emb_lr'] = 0.0001        # change some hyper-parameters
model = XGCN.load_model(config)  # load the saved model
model.fit()                      # train with the new hyper-parameters
new_results = model.test()
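Because the config is edited by key, a misspelled hyper-parameter name would silently add a new entry instead of changing the intended one. The sketch below guards against that. It assumes io.load_yaml returns a dict-like object and that every hyper-parameter you want to change already exists in the saved config.yaml; apply_overrides and continue_training are hypothetical helpers, not part of XGCN.

```python
def apply_overrides(saved_config, **overrides):
    """Return a copy of saved_config with overrides applied; reject unknown keys."""
    unknown = sorted(k for k in overrides if k not in saved_config)
    if unknown:
        raise KeyError(f"not in the saved config: {unknown}")
    config = dict(saved_config)  # leave the loaded config untouched
    config.update(overrides)
    return config


def continue_training(config_file, **overrides):
    """Load a saved run, apply new hyper-parameters, and resume training."""
    import XGCN                  # imported lazily; only needed for actual training
    from XGCN.data import io
    config = apply_overrides(io.load_yaml(config_file), **overrides)
    model = XGCN.load_model(config)
    model.fit()
    return model.test()
```

If you intentionally want to introduce a key that was absent from the saved config, set it directly on the dict instead of going through the check.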