Official documentation: https://docs.dgl.ai/en/latest/guide/data.html
The dgl.data module implements many common graph datasets, which are subclasses of dgl.data.DGLDataset
DGL officially recommends implementing your own dataset by inheriting from dgl.data.DGLDataset, which makes it easier to load, process and save graph datasets
1. DGLDataset class
The data processing pipeline of the dgl.data.DGLDataset class consists of the following steps: downloading, processing, saving to disk and loading from disk
Custom dataset class:
from dgl.data import DGLDataset

class MyDataset(DGLDataset):
    def __init__(self):
        super().__init__(name='my_dataset', url='https://example.com/path/to/my_dataset.zip')

    def download(self):
        # download raw data to local disk
        pass

    def save(self):
        # save processed data to directory `self.save_path`
        pass

    def load(self):
        # load processed data from directory `self.save_path`
        pass

    def process(self):
        # process raw data to graphs, labels, splitting masks
        pass

    def has_cache(self):
        # check whether there are processed data in `self.save_path`
        pass

    def __getitem__(self, idx):
        # get one example by index
        pass

    def __len__(self):
        # number of data examples
        pass
Among these, process(), __getitem__(idx) and __len__() are the methods that must be implemented
The purpose of the DGLDataset class is to provide a standard, convenient way to load graph data. It can store the graphs, features, labels, dataset splits and other basic information about a dataset (such as the number of classes)
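Once these methods are implemented, the dataset is used like any other DGLDataset. A minimal usage sketch (assuming the MyDataset class above):

# first use: download() -> process() -> save(); later uses: load() from cache
dataset = MyDataset()
num_examples = len(dataset)   # calls __len__()
g = dataset[0]                # calls __getitem__(0)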
Best practices for implementing each of these methods are described below
1.1 Downloading raw data
The DGLDataset.download() method downloads the raw data from the URL specified by self.url and saves it to the self.raw_dir directory
- DGL provides a helper function, dgl.data.utils.download(), for downloading a file from a given URL
- The raw_dir attribute of DGLDataset is the download directory for the raw data. If the raw_dir parameter is given in the constructor, that directory is used; otherwise the directory given by the environment variable DGL_DOWNLOAD_DIR is used; if the environment variable is not set, it defaults to ~/.dgl
- The raw_path attribute of DGLDataset is os.path.join(self.raw_dir, self.name) (it can also be overridden in subclasses), and can be used as the extraction directory for the raw data (e.g. when the raw data is a zip archive)
Example:
import os
from dgl.data.utils import download, extract_archive

def download(self):
    zip_file_path = os.path.join(self.raw_dir, 'my_dataset.zip')
    download(self.url, path=zip_file_path)
    extract_archive(zip_file_path, self.raw_path)
1.2 Data processing
The DGLDataset.process() method converts the raw data in self.raw_dir or self.raw_path into DGLGraph format. This generally includes reading the raw data, cleaning it, constructing the graphs, reading vertex features and labels, and splitting the dataset. The concrete logic depends on the format of the raw data (possibly pkl, npz, mat, csv, txt, etc.); this is the main (and most tedious) part that you have to implement yourself
The basic framework is as follows:
def process(self):
    data = _read_raw_data(self.raw_path)
    data = _clean(data)
    g = dgl.graph(...)
    g.ndata['feat'] = ...
    g.ndata['label'] = ...
    g.ndata['train_mask'] = ...
    g.ndata['val_mask'] = ...
    g.ndata['test_mask'] = ...
    self.g = g
Note:
- _read_raw_data() and _clean() stand for the raw-data reading and cleaning logic that you need to implement yourself
- This example contains only a single homogeneous graph. In practice the graph may be heterogeneous, or the dataset may contain multiple graphs (e.g. a graph classification dataset), as sketched below
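For the multiple-graph case, here is a minimal process() sketch for a hypothetical graph classification dataset. The raw-data layout and the _read_raw_data() helper are assumptions made for illustration:

import dgl
import torch

def process(self):
    # hypothetical helper: yields one (src, dst, num_nodes, label) tuple per graph
    raw_samples = _read_raw_data(self.raw_path)
    self.graphs, labels = [], []
    for src, dst, num_nodes, label in raw_samples:
        g = dgl.graph((src, dst), num_nodes=num_nodes)
        self.graphs.append(g)
        labels.append(label)
    self.labels = torch.tensor(labels)
    self.num_classes = int(self.labels.max()) + 1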
1.3 Saving and loading data
DGL recommends implementing the save() and load() methods of a dataset to cache the preprocessed data on disk, so that the next time the dataset is used it can be loaded directly from disk without running process() again; the has_cache() method returns whether processed data has been cached on disk
DGL provides four functions:
- dgl.save_graphs() and dgl.load_graphs() save/load DGLGraph objects to/from disk
- dgl.data.utils.save_info() and dgl.data.utils.load_info() save/load information about the dataset to/from disk (they are essentially pickle.dump() and pickle.load())
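As a quick standalone illustration of the four functions (a minimal sketch; the toy graph and file names are arbitrary):

import dgl
import torch
from dgl.data.utils import save_info, load_info

g = dgl.graph(([0, 1], [1, 2]))   # toy graph: 3 nodes, 2 edges
dgl.save_graphs('toy_graph.bin', [g], {'labels': torch.tensor([0])})
graphs, label_dict = dgl.load_graphs('toy_graph.bin')

save_info('toy_info.pkl', {'num_classes': 2})
num_classes = load_info('toy_info.pkl')['num_classes']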
Save path:
- The save_dir attribute of DGLDataset is the directory where the processed data is saved. If the save_dir parameter is given in the constructor, that directory is used; otherwise it defaults to raw_dir
- The save_path attribute of DGLDataset is os.path.join(self.save_dir, self.name) (it can also be overridden in subclasses). The processed data is generally saved to the save_path directory
Typical usage:
(1) Vertex classification dataset (only one graph)
def save(self):
    # save the graph (features, labels and masks are stored in ndata)
    graph_path = os.path.join(self.save_path, self.name + '_dgl_graph.bin')
    save_graphs(graph_path, [self.g])

def load(self):
    # load processed data from directory `self.save_path`
    graph_path = os.path.join(self.save_path, self.name + '_dgl_graph.bin')
    graphs, _ = load_graphs(graph_path)
    self.g = graphs[0]

def has_cache(self):
    # check whether there are processed data in `self.save_path`
    graph_path = os.path.join(self.save_path, self.name + '_dgl_graph.bin')
    return os.path.exists(graph_path)
(2) Graph classification dataset (containing multiple graphs)
def save(self):
    # save graphs and labels
    graph_path = os.path.join(self.save_path, self.name + '_dgl_graph.bin')
    save_graphs(graph_path, self.graphs, {'labels': self.labels})
    # save other information in a python dict
    info_path = os.path.join(self.save_path, self.name + '_info.pkl')
    save_info(info_path, {'num_classes': self.num_classes})

def load(self):
    # load processed data from directory `self.save_path`
    graph_path = os.path.join(self.save_path, self.name + '_dgl_graph.bin')
    self.graphs, label_dict = load_graphs(graph_path)
    self.labels = label_dict['labels']
    info_path = os.path.join(self.save_path, self.name + '_info.pkl')
    self.num_classes = load_info(info_path)['num_classes']

def has_cache(self):
    # check whether there are processed data in `self.save_path`
    graph_path = os.path.join(self.save_path, self.name + '_dgl_graph.bin')
    info_path = os.path.join(self.save_path, self.name + '_info.pkl')
    return os.path.exists(graph_path) and os.path.exists(info_path)
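With graphs and labels stored this way, the remaining required methods for the graph classification case reduce to a few lines (a sketch, assuming the self.graphs / self.labels attributes used above):

def __getitem__(self, idx):
    # return the idx-th sample as a (graph, label) pair
    return self.graphs[idx], self.labels[idx]

def __len__(self):
    return len(self.graphs)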
2. Using graph datasets
2.1 Graph classification datasets
A graph classification dataset is similar to a traditional machine learning dataset: it contains a set of samples and their labels, but each sample is a dgl.DGLGraph, each label is a tensor, and the sample features are stored in vertex features or edge features
The following uses the QM7b dataset as an example
Creating the dataset
>>> from dgl.data import QM7bDataset
>>> qm7b = QM7bDataset()  # the dataset is downloaded on first use
>>> len(qm7b)
7211
>>> qm7b.num_labels
14
>>> g, label = qm7b[0]
>>> g
Graph(num_nodes=5, num_edges=25,
      ndata_schemes={}
      edata_schemes={'h': Scheme(shape=(1,), dtype=torch.float32)})
>>> g.edata
{'h': tensor([[36.8581],
        [ 2.8961],
        ...
        [ 0.5000]])}
>>> label
tensor([-4.2093e+02,  3.9695e+01,  6.2184e-01, -1.6013e+01,  4.1620e+00,
         3.6768e+01,  1.5725e+01, -3.9861e+00, -1.0949e+01,  1.3230e-01,
        -1.4134e+01,  1.0870e+00,  2.5346e+00,  2.4322e+00])
As shown, the dataset contains 7211 samples, each with 14 labels (corresponding to 14 prediction tasks). The first sample graph has 5 vertices and 25 edges, with a one-dimensional edge feature named h
Iterating over the dataset
You can use PyTorch's DataLoader to iterate over the dataset
import dgl
import torch
from torch.utils.data import DataLoader
from dgl.data import QM7bDataset

# load data
dataset = QM7bDataset()
num_labels = dataset.num_labels

# create collate_fn: batch a list of (graph, label) pairs
def _collate_fn(batch):
    graphs, labels = zip(*batch)
    g = dgl.batch(graphs)                # merge the graphs into one batched graph
    labels = torch.stack(labels, dim=0)  # QM7b labels are float regression targets
    return g, labels

# create dataloader
dataloader = DataLoader(dataset, batch_size=1, shuffle=True, collate_fn=_collate_fn)

# training
for epoch in range(100):
    for g, labels in dataloader:
        # your training code here
        pass
2.2 Vertex classification datasets
Vertex classification is usually performed on a single graph, so datasets of this kind contain only one graph, with the sample features and labels stored in vertex features
Take the Citeseer dataset as an example. It contains one graph with 3327 vertices and 9228 edges. The features, labels, and training/validation/test set masks are stored in the vertex features feat, label, train_mask, val_mask and test_mask respectively. The vertex features are 3703-dimensional and there are 6 classes (labels range over [0, 5])
>>> from dgl.data import CiteseerGraphDataset
>>> citeseer = CiteseerGraphDataset()
>>> len(citeseer)
1
>>> citeseer.num_classes
6
>>> g = citeseer[0]
>>> g
Graph(num_nodes=3327, num_edges=9228,
      ndata_schemes={'train_mask': Scheme(shape=(), dtype=torch.bool),
                     'val_mask': Scheme(shape=(), dtype=torch.bool),
                     'test_mask': Scheme(shape=(), dtype=torch.bool),
                     'label': Scheme(shape=(), dtype=torch.int64),
                     'feat': Scheme(shape=(3703,), dtype=torch.float32)}
      edata_schemes={})
>>> g.ndata['feat'].shape
torch.Size([3327, 3703])
>>> g.ndata['label'].shape
torch.Size([3327])
>>> g.ndata['label'][:10]
tensor([3, 1, 5, 5, 3, 1, 3, 0, 3, 5])
>>> train_idx = torch.nonzero(g.ndata['train_mask']).squeeze()
>>> train_set = g.ndata['feat'][train_idx]
>>> train_set.shape
torch.Size([120, 3703])
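During training, these masks are typically used to restrict the loss to the training vertices. A minimal sketch (model stands for any vertex classifier producing per-vertex logits; it is not defined here):

import torch.nn.functional as F

logits = model(g, g.ndata['feat'])   # hypothetical model: per-vertex logits
train_mask = g.ndata['train_mask']
loss = F.cross_entropy(logits[train_mask], g.ndata['label'][train_mask])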
2.3 Link prediction datasets
Link prediction datasets are similar to vertex classification datasets in that they contain only one graph, but the training/validation/test set masks are stored in the edge features. Datasets of this kind include the subclasses of dgl.data.KnowledgeGraphDataset, as in the sketch below
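For example, a minimal sketch with FB15k237Dataset (one of the knowledge graph subclasses; the training edges are looked up with find_edges()):

import torch
from dgl.data import FB15k237Dataset

dataset = FB15k237Dataset()
g = dataset[0]
train_mask = g.edata['train_mask']   # split masks live in the edge features
train_idx = torch.nonzero(train_mask, as_tuple=False).squeeze()
src, dst = g.find_edges(train_idx)   # endpoints of the training triples
rel = g.edata['etype'][train_idx]    # relation type of each training triple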
2.4 OGB dataset
Open Graph Benchmark (OGB): https://ogb.stanford.edu/docs/home/
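OGB datasets ship with their own DGL-compatible loaders. A minimal sketch for a graph property prediction dataset (assuming the ogb package is installed; 'ogbg-molhiv' is just one example dataset name):

import dgl
import torch
from ogb.graphproppred import DglGraphPropPredDataset
from torch.utils.data import DataLoader

dataset = DglGraphPropPredDataset(name='ogbg-molhiv')
split_idx = dataset.get_idx_split()   # OGB provides standard train/valid/test splits

def _collate_fn(batch):
    graphs, labels = zip(*batch)
    return dgl.batch(graphs), torch.stack(labels, dim=0)

train_loader = DataLoader(dataset[split_idx['train']], batch_size=32,
                          shuffle=True, collate_fn=_collate_fn)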