GraphRNN
questions in main.py
Hi, I have some questions about main.py.
- Do you save the whole graph list as both the training data and the test data?
    # To get train and test set, after loading you need to manually slice
    save_graph_list(graphs, args.graph_save_path + args.fname_train + '0.dat')
    save_graph_list(graphs, args.graph_save_path + args.fname_test + '0.dat')
    print('train and test graphs saved at: ', args.graph_save_path + args.fname_test + '0.dat')
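If I read that comment correctly, the same full list is written to both files and the split has to be done after loading. A minimal sketch of what I think the slicing would look like; `split_graphs` and the 0.8 ratio are my own assumptions for illustration, not names or values taken from the code quoted here:

```python
def split_graphs(graphs, train_ratio=0.8):
    """Split a list of graphs into (train, test) by position.

    Hypothetical helper: the name and the default ratio are assumptions.
    """
    if not 0.0 < train_ratio < 1.0:
        raise ValueError("train_ratio must be strictly between 0 and 1")
    cut = int(train_ratio * len(graphs))
    # The first `cut` graphs go to training, the remainder to testing.
    return graphs[:cut], graphs[cut:]


# Stand-in for the list that would be loaded back from the saved .dat file.
graphs = list(range(10))
graphs_train, graphs_test = split_graphs(graphs)
print(len(graphs_train), len(graphs_test))
```

Since the slice is positional, the list would probably need to be shuffled first if the graphs are stored in some meaningful order.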
- And here, should I assign args.max_prev_node myself?
    # dataset initialization
    if 'nobfs' in args.note:
        print('nobfs')
        dataset = Graph_sequence_sampler_pytorch_nobfs(graphs_train, max_num_node=args.max_num_node)
        args.max_prev_node = args.max_num_node-1
    if 'barabasi_noise' in args.graph_type:
        print('barabasi_noise')
        dataset = Graph_sequence_sampler_pytorch_canonical(graphs_train,max_prev_node=args.max_prev_node)
        args.max_prev_node = args.max_num_node - 1
    else:
        dataset = Graph_sequence_sampler_pytorch(graphs_train,max_prev_node=args.max_prev_node,max_num_node=args.max_num_node)
    sample_strategy = torch.utils.data.sampler.WeightedRandomSampler([1.0 / len(dataset) for i in range(len(dataset))],
                                                                     num_samples=args.batch_size*args.batch_ratio, replacement=True)
Otherwise I get an error like "TypeError: new(): argument 'size' must be tuple of ints, but found element of type NoneType at pos 2" here:
    elif 'GraphRNN_RNN' in args.note:
        rnn = GRU_plain(input_size=args.max_prev_node, embedding_size=args.embedding_size_rnn,
                        hidden_size=args.hidden_size_rnn, num_layers=args.num_layers, has_input=True,
                        has_output=True, output_size=args.hidden_size_rnn_output)
        output = GRU_plain(input_size=1, embedding_size=args.embedding_size_rnn_output,
                           hidden_size=args.hidden_size_rnn_output, num_layers=args.num_layers, has_input=True,
                           has_output=True, output_size=1)
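My guess is that the NoneType in the error is args.max_prev_node reaching input_size without ever being set. As a workaround I am considering a guard like the sketch below. `resolve_max_prev_node` is a hypothetical helper of mine, and the fallback to max_num_node - 1 only mirrors what the 'nobfs' branch above does; I am not sure it is the value you intended for the default case.

```python
def resolve_max_prev_node(max_prev_node, max_num_node):
    """Return a usable max_prev_node, or fail early with a clear message.

    Hypothetical helper: falls back to max_num_node - 1 when
    max_prev_node is None, mirroring the 'nobfs' branch.
    """
    if max_prev_node is not None:
        return max_prev_node
    if max_num_node is None:
        raise ValueError("both max_prev_node and max_num_node are unset")
    return max_num_node - 1


# Example with plain values standing in for the args fields.
print(resolve_max_prev_node(None, 20))
print(resolve_max_prev_node(5, 20))
```

Calling something like this before the models are built would at least turn the TypeError into an explicit message when neither value is available.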
Looking forward to your reply =]