import logging
import os.path as osp
import time
import math
from functools import partial
import numpy as np
import torch
import torch_geometric.transforms as T
from numpy.random import default_rng
from ogb.nodeproppred import PygNodePropPredDataset
from ogb.graphproppred import PygGraphPropPredDataset
from torch_geometric.datasets import (Actor, Planetoid, Amazon, Coauthor, GNNBenchmarkDataset, TUDataset,
WebKB, WikipediaNetwork, ZINC)
from torch_geometric.graphgym.config import cfg
from torch_geometric.graphgym.loader import load_pyg, load_ogb, set_dataset_attr
from torch_geometric.graphgym.register import register_loader
from opengt.loader.dataset.aqsol_molecules import AQSOL
from opengt.loader.dataset.coco_superpixels import COCOSuperpixels
from opengt.loader.dataset.malnet_tiny import MalNetTiny
from opengt.loader.dataset.voc_superpixels import VOCSuperpixels
from opengt.loader.split_generator import (prepare_splits,
set_dataset_splits)
from opengt.transform.posenc_stats import compute_posenc_stats
from opengt.transform.transforms import (pre_transform_in_memory,
generate_splits, # not the same as split_generator above.
typecast_x, concat_x_and_pos,
clip_graphs_to_size, move_node_feat_to_x)
from opengt.transform.expander_edges import generate_random_expander
from opengt.transform.dist_transforms import (add_dist_features, add_reverse_edges,
add_self_loops, effective_resistances,
effective_resistance_embedding,
effective_resistances_from_embedding)
from opengt.transform.multihop_prep import generate_multihop_adj
from opengt.transform.graph_partition import GraphPartitionTransform
def log_loaded_dataset(dataset, format, name):
    """Write a short summary of a freshly loaded dataset to the log.

    Logs the collated storage object, the number of graphs, the average
    number of nodes per graph, feature dimensions and, depending on which
    label attributes the storage exposes, a node/graph-level or an
    edge-level class summary.

    Args:
        dataset: loaded dataset; must support `len()` and expose `.data`,
            `.num_node_features` and `.num_edge_features`
        format: format identifier the dataset was loaded from (only logged)
        name: dataset name (only logged)
    """
    logging.info(f"[*] Loaded dataset '{name}' from '{format}':")
    storage = dataset.data
    logging.info(f" {storage}")
    logging.info(f" num graphs: {len(dataset)}")

    # Total node count: prefer the explicit attribute, fall back to the
    # number of rows in the feature matrix, otherwise report zero.
    if hasattr(storage, 'num_nodes'):
        node_total = storage.num_nodes
    elif hasattr(storage, 'x'):
        node_total = storage.x.size(0)
    else:
        node_total = 0
    avg_nodes = node_total // len(dataset)
    logging.info(f" avg num_nodes/graph: {avg_nodes}")
    logging.info(f" num node features: {dataset.num_node_features}")
    logging.info(f" num edge features: {dataset.num_edge_features}")
    if hasattr(dataset, 'num_tasks'):
        logging.info(f" num tasks: {dataset.num_tasks}")

    targets = getattr(storage, 'y', None)
    if targets is not None:
        if isinstance(targets, list):
            # A special case for ogbg-code2 dataset.
            logging.info(f" num classes: n/a")
        elif targets.numel() == targets.size(0) and \
                torch.is_floating_point(targets):
            # One floating point value per row is treated as regression.
            logging.info(f" num classes: (appears to be a regression task)")
        else:
            logging.info(f" num classes: {dataset.num_classes}")
        return

    # No `y`: look for edge/link prediction labels instead.
    if hasattr(storage, 'train_edge_label'):
        edge_targets = storage.train_edge_label  # Transductive link task
    elif hasattr(storage, 'edge_label'):
        edge_targets = storage.edge_label  # Inductive link task
    else:
        return

    if edge_targets.numel() == edge_targets.size(0) and \
            torch.is_floating_point(edge_targets):
        logging.info(f" num edge classes: (probably a regression task)")
    else:
        logging.info(f" num edge classes: {len(torch.unique(edge_targets))}")
# GraphGym entry point: registered below under the loader key 'custom_master_loader'.
def _log_done(start):
    """Log the wall-clock time elapsed since `start` (a `time.perf_counter()` value)."""
    elapsed = time.perf_counter() - start
    # HH:MM:SS followed by the two fractional-second digits.
    timestr = time.strftime('%H:%M:%S', time.gmtime(elapsed)) \
        + f'{elapsed:.2f}'[-3:]
    logging.info(f"Done! Took {timestr}")


def _split_and_symmetrize(dataset):
    """Generate splits, add reverse edges and, if configured, self-loops in memory."""
    pre_transform_in_memory(
        dataset, partial(generate_splits, g_split=cfg.dataset.split[0]))
    pre_transform_in_memory(dataset, partial(add_reverse_edges))
    if cfg.prep.add_self_loops:
        pre_transform_in_memory(dataset, partial(add_self_loops))


def _load_pyg_dataset(format, name, dataset_dir):
    """Instantiate the dataset selected by a `PyG-<DatasetClass>` format string."""
    pyg_dataset_id = format.split('-', 1)[1]
    dataset_dir = osp.join(dataset_dir, pyg_dataset_id)

    if pyg_dataset_id == 'Actor':
        if name != 'none':
            raise ValueError(f"Actor class provides only one dataset.")
        return Actor(dataset_dir)
    if pyg_dataset_id == 'GNNBenchmarkDataset':
        return preformat_GNNBenchmarkDataset(dataset_dir, name)
    if pyg_dataset_id == 'MalNetTiny':
        return preformat_MalNetTiny(dataset_dir, feature_set=name)
    if pyg_dataset_id == 'Amazon':
        dataset = Amazon(dataset_dir, name)
        if name == "photo" or name == "computers":
            _split_and_symmetrize(dataset)
        return dataset
    if pyg_dataset_id == 'Coauthor':
        dataset = Coauthor(dataset_dir, name)
        if name == "physics" or name == "cs":
            _split_and_symmetrize(dataset)
        return dataset
    if pyg_dataset_id == 'Planetoid':
        return Planetoid(dataset_dir, name)
    if pyg_dataset_id == 'TUDataset':
        return preformat_TUDataset(dataset_dir, name)
    if pyg_dataset_id == 'WebKB':
        return WebKB(dataset_dir, name)
    if pyg_dataset_id == 'VOCSuperpixels':
        return preformat_VOCSuperpixels(dataset_dir, name,
                                        cfg.dataset.slic_compactness)
    if pyg_dataset_id == 'COCOSuperpixels':
        return preformat_COCOSuperpixels(dataset_dir, name,
                                         cfg.dataset.slic_compactness)
    if pyg_dataset_id == 'WikipediaNetwork':
        if name == 'crocodile':
            raise NotImplementedError(f"crocodile not implemented yet")
        return WikipediaNetwork(dataset_dir, name)
    if pyg_dataset_id == 'ZINC':
        return preformat_ZINC(dataset_dir, name)
    if pyg_dataset_id == 'AQSOL':
        # `preformat_AQSOL` needs only the directory; passing `name` as a
        # second positional argument raised a TypeError.
        return preformat_AQSOL(dataset_dir)
    raise ValueError(f"Unexpected PyG Dataset identifier: {format}")


def _load_ogb_dataset(name, dataset_dir):
    """Instantiate an OGB or OGB-derived dataset selected by its name prefix."""
    if name.startswith('ogbg'):
        return preformat_OGB_Graph(dataset_dir, name.replace('_', '-'))
    if name.startswith('PCQM4Mv2-'):
        subset = name.split('-', 1)[1]
        return preformat_OGB_PCQM4Mv2(dataset_dir, subset)
    if name.startswith('ogbn'):
        return preformat_ogbn(dataset_dir, name)
    if name.startswith('peptides-'):
        return preformat_Peptides(dataset_dir, name)
    ### Link prediction datasets.
    if name.startswith('ogbl-'):
        # GraphGym default loader.
        dataset = load_ogb(name, dataset_dir)
        # OGB link prediction datasets are binary classification tasks,
        # however the default loader creates float labels => convert to int.
        for prop in ('train_edge_label', 'val_edge_label', 'test_edge_label'):
            as_int = getattr(dataset.data, prop).int()
            set_dataset_attr(dataset, prop, as_int, len(as_int))
        return dataset
    if name.startswith('PCQM4Mv2Contact-'):
        return preformat_PCQM4Mv2Contact(dataset_dir, name)
    raise ValueError(f"Unsupported OGB(-derived) dataset: {name}")


@register_loader('custom_master_loader')
def load_dataset_master(format, name, dataset_dir):
    """
    Master loader that controls loading of all datasets, overshadowing execution
    of any default GraphGym dataset loader. Default GraphGym dataset loaders are
    instead called from this function; the format keywords `PyG` and `OGB` are
    reserved for these default GraphGym loaders.

    Custom transforms and dataset splitting are applied to each loaded dataset.
    Which preprocessing steps run is decided by the global `cfg` object.

    Args:
        format: dataset format name that identifies Dataset class
        name: dataset name to select from the class identified by `format`
        dataset_dir: path where to store the processed dataset

    Returns:
        PyG dataset object with applied perturbation transforms and data splits

    Raises:
        ValueError: if `format` or the dataset identifier is not recognised.
        NotImplementedError: for the WikipediaNetwork 'crocodile' dataset.
    """
    if format.startswith('PyG-'):
        dataset = _load_pyg_dataset(format, name, dataset_dir)
    # GraphGym default loader for Pytorch Geometric datasets
    elif format == 'PyG':
        dataset = load_pyg(name, dataset_dir)
    elif format == 'OGB':
        dataset = _load_ogb_dataset(name, dataset_dir)
    else:
        raise ValueError(f"Unknown data format: {format}")
    log_loaded_dataset(dataset, format, name)

    # Precompute necessary statistics for positional encodings.
    # The effective-resistance encodings (`posenc_ER*`) are handled separately
    # further below.
    pe_enabled_list = []
    for key, pecfg in cfg.items():
        if key.startswith('posenc_') and pecfg.enable and (not key.startswith('posenc_ER')):
            pe_name = key.split('_', 1)[1]
            pe_enabled_list.append(pe_name)
            if hasattr(pecfg, 'kernel'):
                # Generate kernel times if functional snippet is set.
                if pecfg.kernel.times_func:
                    # SECURITY: `times_func` is executed with eval(); it must
                    # only ever come from trusted configuration files.
                    pecfg.kernel.times = list(eval(pecfg.kernel.times_func))
                logging.info(f"Parsed {pe_name} PE kernel times / steps: "
                             f"{pecfg.kernel.times}")
    if pe_enabled_list:
        start = time.perf_counter()
        logging.info(f"Precomputing Positional Encoding statistics: "
                     f"{pe_enabled_list} for all graphs...")
        # Estimate directedness based on 10 graphs to save time.
        is_undirected = all(d.is_undirected() for d in dataset[:10])
        logging.info(f" ...estimated to be undirected: {is_undirected}")
        pre_transform_in_memory(dataset,
                                partial(compute_posenc_stats,
                                        pe_types=pe_enabled_list,
                                        is_undirected=is_undirected,
                                        cfg=cfg),
                                show_progress=True)
        _log_done(start)

    # Other preprocessings:
    # adding expander edges:
    if cfg.prep.exp:
        for j in range(cfg.prep.exp_count):
            start = time.perf_counter()
            logging.info(f"Adding expander edges (round {j}) ...")
            pre_transform_in_memory(dataset,
                                    partial(generate_random_expander,
                                            degree=cfg.prep.exp_deg,
                                            algorithm=cfg.prep.exp_algorithm,
                                            rng=None,
                                            max_num_iters=cfg.prep.exp_max_num_iters,
                                            exp_index=j),
                                    show_progress=True)
            _log_done(start)

    # adding shortest path features
    if cfg.prep.dist_enable:
        start = time.perf_counter()
        logging.info(f"Precalculating node distances and shortest paths ...")
        # Directedness is taken from the first graph only.
        is_undirected = dataset[0].is_undirected()
        Max_N = max(data.num_nodes for data in dataset)
        pre_transform_in_memory(dataset,
                                partial(add_dist_features,
                                        max_n=Max_N,
                                        is_undirected=is_undirected,
                                        cutoff=cfg.prep.dist_cutoff),
                                show_progress=True)
        _log_done(start)

    if cfg.prep.rb_order > 1:
        start = time.perf_counter()
        logging.info(f"Generating multi-hop adjacency matrices ...")
        pre_transform_in_memory(dataset,
                                partial(generate_multihop_adj,
                                        cfg=cfg),
                                show_progress=True)
        _log_done(start)

    # adding effective resistance features
    if cfg.posenc_ERN.enable or cfg.posenc_ERE.enable:
        start = time.perf_counter()
        logging.info(f"Precalculating effective resistance for graphs ...")
        # Embedding dimension: per graph the smaller of floor(n / 2) and
        # ceil(8 * ln(m) / accuracy^2), maximised over the dataset.
        # NOTE(review): the original wrapped `num_nodes // 2` in a no-op
        # math.ceil; if ceil(n / 2) was intended, this needs `/` instead.
        MaxK = max(
            min(
                data.num_nodes // 2,
                math.ceil(8 * math.log(data.num_edges) / (cfg.posenc_ERN.accuracy**2))
            )
            for data in dataset
        )
        cfg.posenc_ERN.er_dim = MaxK
        logging.info(f"Choosing ER pos enc dim = {MaxK}")
        pre_transform_in_memory(dataset,
                                partial(effective_resistance_embedding,
                                        MaxK=MaxK,
                                        accuracy=cfg.posenc_ERN.accuracy,
                                        which_method=0),
                                show_progress=True)
        pre_transform_in_memory(dataset,
                                partial(effective_resistances_from_embedding,
                                        normalize_per_node=False),
                                show_progress=True)
        _log_done(start)

    # graph partition transform
    if cfg.metis.patches > 0:
        start = time.perf_counter()
        logging.info(f"Precomputing graph partition transform ...")
        pre_transform_in_memory(dataset,
                                GraphPartitionTransform(n_patches=cfg.metis.patches,
                                                        metis=cfg.metis.enable,
                                                        drop_rate=cfg.metis.drop_rate,
                                                        num_hops=cfg.metis.num_hops,
                                                        is_directed=False,
                                                        patch_rw_dim=cfg.metis.patch_rw_dim,
                                                        patch_num_diff=cfg.metis.patch_num_diff),
                                show_progress=True)
        _log_done(start)

    # NOTE(review): set for every dataset here; confirm this is not meant to
    # apply only when the graph partition transform above is enabled.
    dataset.data['extra_loss'] = torch.Tensor([0.0])

    # This could not be done earlier because the training wants 'train_mask' etc.
    # Now after using gnn.head: inductive_node this is ok.
    if name == 'ogbn-arxiv' or name == 'ogbn-proteins':
        return dataset

    # Set standard dataset train/val/test splits
    if hasattr(dataset, 'split_idxs'):
        set_dataset_splits(dataset, dataset.split_idxs)
        delattr(dataset, 'split_idxs')

    # Verify or generate dataset train/val/test splits
    prepare_splits(dataset)

    # Precompute in-degree histogram if needed for PNAConv.
    if cfg.gt.layer_type.startswith('PNAConv') and len(cfg.gt.pna_degrees) == 0:
        cfg.gt.pna_degrees = compute_indegree_histogram(
            dataset[dataset.data['train_graph_index']])

    return dataset
def compute_indegree_histogram(dataset):
    """Compute histogram of in-degree of nodes needed for PNAConv.

    The histogram grows to whatever the largest in-degree is, so graphs with
    hub nodes of degree >= 1000 no longer overflow a fixed-size buffer, and
    graphs without any nodes are skipped instead of crashing.

    Args:
        dataset: iterable of graph objects exposing `edge_index` (a 2 x E
            integer tensor) and `num_nodes`

    Returns:
        List where i-th value is the number of nodes with in-degree equal to
        `i`; its length is the largest in-degree plus one (`[0]` when there
        are no nodes at all).
    """
    hist = torch.zeros(1, dtype=torch.long)
    for data in dataset:
        # In-degree of a node = number of edges whose target row points at it.
        in_deg = torch.bincount(data.edge_index[1], minlength=data.num_nodes)
        if in_deg.numel() == 0:
            continue  # a graph without nodes contributes nothing
        counts = torch.bincount(in_deg).cpu()
        if counts.numel() > hist.numel():
            # Enlarge the running histogram to fit the new maximum degree.
            grown = torch.zeros(counts.numel(), dtype=torch.long)
            grown[:hist.numel()] = hist
            hist = grown
        hist[:counts.numel()] += counts
    return hist.tolist()
def preformat_GNNBenchmarkDataset(dataset_dir, name):
    """Load and preformat datasets from PyG's GNNBenchmarkDataset.

    The train/val/test parts are loaded separately and merged into a single
    dataset object. For MNIST and CIFAR10 the node features are additionally
    concatenated with the positions and cast to float.

    Args:
        dataset_dir: path where to store the cached dataset
        name: name of the specific dataset in the GNNBenchmarkDataset class

    Returns:
        PyG dataset object

    Raises:
        ValueError: if `name` is not one of the GNNBenchmarkDataset names.
    """
    if name in ('MNIST', 'CIFAR10'):
        tf_list = [concat_x_and_pos]  # concat pixel value and pos. coordinate
        tf_list.append(partial(typecast_x, type_str='float'))
    elif name in ('PATTERN', 'CLUSTER', 'TSP', 'CSL'):
        # These used to be accepted only because the ValueError below was
        # constructed but never raised; keep loading them without transforms.
        tf_list = []
    else:
        raise ValueError(f"Loading dataset '{name}' from "
                         f"GNNBenchmarkDataset is not supported.")

    dataset = join_dataset_splits(
        [GNNBenchmarkDataset(root=dataset_dir, name=name, split=split)
         for split in ['train', 'val', 'test']]
    )
    pre_transform_in_memory(dataset, T.Compose(tf_list))

    return dataset
def preformat_MalNetTiny(dataset_dir, feature_set):
    """Load the Tiny version (5k graphs) of MalNet and attach node features.

    Args:
        dataset_dir: path where to store the cached dataset
        feature_set: which node features to precompute, since MalNet
            originally has neither node nor edge features; one of 'none',
            'Constant', 'OneHotDegree' or 'LocalDegreeProfile'

    Returns:
        PyG dataset object with `split_idxs` set to [train, valid, test]

    Raises:
        ValueError: if `feature_set` is not one of the supported choices.
    """
    if feature_set == 'OneHotDegree':
        node_feature_tf = T.OneHotDegree()
    elif feature_set == 'LocalDegreeProfile':
        node_feature_tf = T.LocalDegreeProfile()
    elif feature_set in ('none', 'Constant'):
        node_feature_tf = T.Constant()
    else:
        raise ValueError(f"Unexpected transform function: {feature_set}")

    dataset = MalNetTiny(dataset_dir)
    dataset.name = 'MalNetTiny'
    logging.info(f'Computing "{feature_set}" node features for MalNetTiny.')
    pre_transform_in_memory(dataset, node_feature_tf)

    # Expose the dataset's own split in the order train, valid, test.
    official_split = dataset.get_idx_split()
    dataset.split_idxs = [official_split[part]
                          for part in ('train', 'valid', 'test')]

    return dataset
def preformat_OGB_Graph(dataset_dir, name):
    """Load and preformat OGB Graph Property Prediction datasets.

    The dataset's own index split is stored on the returned object as
    `split_idxs` in the order train, valid, test. Two datasets receive
    extra treatment: 'ogbg-ppa' gets an on-the-fly transform adding zero
    node features, and 'ogbg-code2' gets its target vocabulary, on-the-fly
    transforms and a graph-size clipping pre-transform.

    Args:
        dataset_dir: path where to store the cached dataset
        name: name of the specific OGB Graph dataset

    Returns:
        PyG dataset object
    """
    dataset = PygGraphPropPredDataset(name=name, root=dataset_dir)
    s_dict = dataset.get_idx_split()
    dataset.split_idxs = [s_dict['train'], s_dict['valid'], s_dict['test']]

    if name == 'ogbg-ppa':
        # ogbg-ppa doesn't have any node features, therefore add zeros but do
        # so dynamically as a 'transform' and not as a cached 'pre-transform'
        # because the dataset is big (~38.5M nodes), already taking ~31GB space
        def add_zeros(data):
            data.x = torch.zeros(data.num_nodes, dtype=torch.long)
            return data

        dataset.transform = add_zeros
        return dataset

    if name != 'ogbg-code2':
        return dataset

    from opengt.loader.ogbg_code2_utils import idx2vocab, \
        get_vocab_mapping, augment_edge, encode_y_to_arr

    num_vocab = 5000  # The number of vocabulary used for sequence prediction
    max_seq_len = 5  # The maximum sequence length to predict

    target_lengths = np.array([len(seq) for seq in dataset.data.y])
    short_fraction = np.sum(target_lengths <= max_seq_len) / len(target_lengths)
    logging.info(f"Target sequences less or equal to {max_seq_len} is "
                 f"{short_fraction}")

    # Building vocabulary for sequence prediction. Only use training data.
    train_targets = [dataset.data.y[i] for i in s_dict['train']]
    vocab2idx, idx2vocab_local = get_vocab_mapping(train_targets, num_vocab)
    logging.info(f"Final size of vocabulary is {len(vocab2idx)}")
    # Set to global variable to later access in CustomLogger
    idx2vocab.extend(idx2vocab_local)

    def encode_target(data):
        # Adds the encoded target to the data object via encode_y_to_arr.
        return encode_y_to_arr(data, vocab2idx, max_seq_len)

    # Set the transform function:
    # augment_edge: add next-token edge as well as inverse edges. add edge attributes.
    # encode_target: add y_arr to PyG data object.
    dataset.transform = T.Compose([augment_edge, encode_target])

    # Subset graphs to a maximum size (number of nodes) limit.
    pre_transform_in_memory(dataset, partial(clip_graphs_to_size,
                                             size_limit=1000))

    return dataset
def preformat_OGB_PCQM4Mv2(dataset_dir, name):
    """Load and preformat PCQM4Mv2 from OGB LSC.

    OGB-LSC provides 4 data index splits:
    2 with labeled molecules: 'train', 'valid' meant for training and dev
    2 unlabeled: 'test-dev', 'test-challenge' for the LSC challenge submission

    A random 150k graphs (fixed seed 42) are taken out of 'train' to serve as
    the validation set, and the original 'valid' is used as the testing set.

    Note: PygPCQM4Mv2Dataset requires rdkit

    Args:
        dataset_dir: path where to store the cached dataset
        name: select 'subset' or 'full' version of the training set

    Returns:
        PyG dataset object with `split_idxs` set to [train, valid, test]

    Raises:
        ValueError: if `name` is neither 'full' nor 'subset'.
    """
    try:
        # Load locally to avoid RDKit dependency until necessary.
        from ogb.lsc import PygPCQM4Mv2Dataset
    except Exception as e:
        logging.error('ERROR: Failed to import PygPCQM4Mv2Dataset, '
                      'make sure RDKit is installed.')
        raise e

    dataset = PygPCQM4Mv2Dataset(root=dataset_dir)
    official = dataset.get_idx_split()

    # Deterministically shuffle the official training indices.
    shuffled = default_rng(seed=42).permutation(official['train'].numpy())
    shuffled = torch.from_numpy(shuffled)

    # Leave out 150k graphs for a new validation set.
    num_held_out = 150000
    valid_idx = shuffled[:num_held_out]
    train_idx = shuffled[num_held_out:]

    if name == 'full':
        dataset.split_idxs = [
            train_idx,          # Subset of original 'train'.
            valid_idx,          # Subset of original 'train' as validation set.
            official['valid'],  # The original 'valid' as testing set.
        ]
        return dataset

    if name == 'subset':
        # Further subset the training set for faster debugging.
        subset_ratio = 0.1
        parts = [
            train_idx[:int(subset_ratio * len(train_idx))],
            valid_idx[:50000],
            official['valid'],  # The original 'valid' as testing set.
        ]
        dataset = dataset[torch.cat(parts)]

        # After subsetting, the three parts occupy consecutive index ranges.
        split_idxs = []
        offset = 0
        for part in parts:
            split_idxs.append(list(range(offset, offset + len(part))))
            offset += len(part)
        dataset.split_idxs = split_idxs
        return dataset

    raise ValueError(f'Unexpected OGB PCQM4Mv2 subset choice: {name}')
def preformat_PCQM4Mv2Contact(dataset_dir, name):
    """Load PCQM4Mv2-derived molecular contact link prediction dataset.

    Note: This dataset requires RDKit dependency!

    Args:
        dataset_dir: path where to store the cached dataset
        name: full dataset name; the part after the first '-' selects the
            type of dataset split: 'shuffle', 'num-atoms'

    Returns:
        PyG dataset object with `split_idxs` set to [train, val, test]
    """
    try:
        # Load locally to avoid RDKit dependency until necessary
        from opengt.loader.dataset.pcqm4mv2_contact import \
            PygPCQM4Mv2ContactDataset, \
            structured_neg_sampling_transform
    except Exception as e:
        logging.error('ERROR: Failed to import PygPCQM4Mv2ContactDataset, '
                      'make sure RDKit is installed.')
        raise e

    _, split_name = name.split('-', 1)
    dataset = PygPCQM4Mv2ContactDataset(dataset_dir, subset='530k')

    # Inductive graph-level split (there is no train/test edge split).
    graph_split = dataset.get_idx_split(split_name)
    dataset.split_idxs = [graph_split['train'],
                          graph_split['val'],
                          graph_split['test']]

    if cfg.dataset.resample_negative:
        dataset.transform = structured_neg_sampling_transform

    return dataset
def preformat_Peptides(dataset_dir, name):
    """Load Peptides dataset, functional or structural.

    Note: This dataset requires RDKit dependency!

    Args:
        dataset_dir: path where to store the cached dataset
        name: the type of dataset split:
            - 'peptides-functional' (10-task classification)
            - 'peptides-structural' (11-task regression)

    Returns:
        PyG dataset object with `split_idxs` set to [train, val, test]

    Raises:
        ValueError: if `name` selects neither the functional nor the
            structural variant.
    """
    # Validate first so that a bad name fails with a clear message instead of
    # leaving `dataset` unbound further down.
    dataset_type = name.partition('-')[2]
    if dataset_type not in ('functional', 'structural'):
        raise ValueError(f"Unexpected Peptides dataset choice: {name}")

    try:
        # Load locally to avoid RDKit dependency until necessary.
        from opengt.loader.dataset.peptides_functional import \
            PeptidesFunctionalDataset
        from opengt.loader.dataset.peptides_structural import \
            PeptidesStructuralDataset
    except Exception as e:
        logging.error('ERROR: Failed to import Peptides dataset class, '
                      'make sure RDKit is installed.')
        raise e

    if dataset_type == 'functional':
        dataset = PeptidesFunctionalDataset(dataset_dir)
    else:
        dataset = PeptidesStructuralDataset(dataset_dir)

    s_dict = dataset.get_idx_split()
    dataset.split_idxs = [s_dict[s] for s in ['train', 'val', 'test']]

    return dataset
def preformat_TUDataset(dataset_dir, name):
    """Load and preformat datasets from PyG's TUDataset.

    'DD', 'NCI1', 'ENZYMES' and 'PROTEINS' are loaded without a
    pre-transform; 'IMDB-*' and 'COLLAB' are loaded with `T.Constant()` as
    pre-transform.

    Args:
        dataset_dir: path where to store the cached dataset
        name: name of the specific dataset in the TUDataset class

    Returns:
        PyG dataset object

    Raises:
        ValueError: if `name` is not one of the supported datasets.
    """
    if name in ['DD', 'NCI1', 'ENZYMES', 'PROTEINS']:
        func = None
    elif name.startswith('IMDB-') or name == "COLLAB":
        func = T.Constant()
    else:
        # The exception used to be constructed without `raise`, which left
        # `func` unbound and surfaced as an UnboundLocalError below.
        raise ValueError(
            f"Loading dataset '{name}' from TUDataset is not supported.")
    dataset = TUDataset(dataset_dir, name, pre_transform=func)
    return dataset
def preformat_ogbn(dataset_dir, name):
    """Load and preformat the OGB node property prediction datasets.

    For 'ogbn-arxiv' reverse edges and, if configured, self-loops are added.
    For 'ogbn-proteins' node features are moved into `x` and cast to float.
    The dataset's index split is stored on the returned object as
    `split_idx` (a dict), with the 'valid' key renamed to 'val'.

    Args:
        dataset_dir: path where to store the cached dataset (currently not
            forwarded to the dataset class, see note below)
        name: 'ogbn-arxiv' or 'ogbn-proteins'

    Returns:
        PyG dataset object

    Raises:
        ValueError: if `name` is not one of the supported datasets.
    """
    if name not in ('ogbn-arxiv', 'ogbn-proteins'):
        # Previously the exception was constructed without `raise`, so the
        # function silently returned None for unsupported names.
        raise ValueError(f"Unknown ogbn dataset '{name}'.")

    # NOTE(review): `dataset_dir` is not passed on, so the dataset class
    # falls back to its own default root. Confirm whether `root=dataset_dir`
    # is wanted; changing it would relocate already cached data.
    dataset = PygNodePropPredDataset(name=name)
    if name == 'ogbn-arxiv':
        pre_transform_in_memory(dataset, partial(add_reverse_edges))
        if cfg.prep.add_self_loops:
            pre_transform_in_memory(dataset, partial(add_self_loops))
    if name == 'ogbn-proteins':
        pre_transform_in_memory(dataset, partial(move_node_feat_to_x))
        pre_transform_in_memory(dataset, partial(typecast_x, type_str='float'))

    split_dict = dataset.get_idx_split()
    split_dict['val'] = split_dict.pop('valid')
    dataset.split_idx = split_dict
    return dataset
def preformat_ZINC(dataset_dir, name):
    """Load the three ZINC splits and merge them into one dataset.

    Args:
        dataset_dir: path where to store the cached dataset
        name: select 'subset' or 'full' version of ZINC

    Returns:
        PyG dataset object

    Raises:
        ValueError: if `name` is neither 'subset' nor 'full'.
    """
    if name == 'subset':
        use_subset = True
    elif name == 'full':
        use_subset = False
    else:
        raise ValueError(f"Unexpected subset choice for ZINC dataset: {name}")

    parts = []
    for split_name in ('train', 'val', 'test'):
        parts.append(ZINC(root=dataset_dir, subset=use_subset,
                          split=split_name))
    return join_dataset_splits(parts)
def preformat_AQSOL(dataset_dir, name=None):
    """Load and preformat AQSOL datasets.

    Args:
        dataset_dir: path where to store the cached dataset
        name: accepted and ignored. AQSOL has no variants to select, but the
            master loader passes the configured dataset name as a second
            argument, which previously raised a TypeError.

    Returns:
        PyG dataset object
    """
    dataset = join_dataset_splits(
        [AQSOL(root=dataset_dir, split=split)
         for split in ['train', 'val', 'test']]
    )
    return dataset
def preformat_VOCSuperpixels(dataset_dir, name, slic_compactness):
    """Load the three VOCSuperpixels splits and merge them into one dataset.

    Args:
        dataset_dir: path where to store the cached dataset
        name: forwarded to `VOCSuperpixels` as `name`
        slic_compactness: forwarded to `VOCSuperpixels` as `slic_compactness`

    Returns:
        PyG dataset object
    """
    shared_kwargs = dict(root=dataset_dir, name=name,
                         slic_compactness=slic_compactness)
    parts = []
    for split_name in ('train', 'val', 'test'):
        parts.append(VOCSuperpixels(split=split_name, **shared_kwargs))
    return join_dataset_splits(parts)
def preformat_COCOSuperpixels(dataset_dir, name, slic_compactness):
    """Load the three COCOSuperpixels splits and merge them into one dataset.

    Args:
        dataset_dir: path where to store the cached dataset
        name: forwarded to `COCOSuperpixels` as `name`
        slic_compactness: forwarded to `COCOSuperpixels` as `slic_compactness`

    Returns:
        PyG dataset object
    """
    shared_kwargs = dict(root=dataset_dir, name=name,
                         slic_compactness=slic_compactness)
    parts = []
    for split_name in ('train', 'val', 'test'):
        parts.append(COCOSuperpixels(split=split_name, **shared_kwargs))
    return join_dataset_splits(parts)
def join_dataset_splits(datasets):
    """Merge train, val and test datasets into the first dataset object.

    The first dataset is modified in place: its index selection is reset,
    its data list and collated storage are replaced by the concatenation of
    all three parts, and it is returned.

    Args:
        datasets: list of 3 PyG datasets to merge

    Returns:
        joint dataset with `split_idxs` property storing the split indices
    """
    assert len(datasets) == 3, "Expecting train, val, test datasets"

    sizes = [len(part) for part in datasets]
    data_list = [part.get(i)
                 for part, size in zip(datasets, sizes)
                 for i in range(size)]

    merged = datasets[0]
    merged._indices = None
    merged._data_list = data_list
    merged.data, merged.slices = merged.collate(data_list)

    # Each part occupies a consecutive index range in the merged dataset.
    split_idxs = []
    offset = 0
    for size in sizes:
        split_idxs.append(list(range(offset, offset + size)))
        offset += size
    merged.split_idxs = split_idxs

    return merged