[Feat] Implement the pipeline for loading vineyard graph as graphlearn_torch dataset and training on a single machine (alibaba#3156)

This PR introduces the following changes:
1. A new graphlearn_torch API for sessions.
2. A script to launch the graphlearn_torch server with a handle and config.
3. An example of GraphSAGE node classification on the ogbn-arxiv dataset in GraphScope on a single machine.

alibaba#3157
1 parent 5fb62ef · commit 0a9f03b
Showing 2 changed files with 157 additions and 1 deletion.
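Only the first changed file, the client-side training example, is shown below; the second changed file is the server launch script described in item 2 above. As a rough sketch of what that server half might look like (assuming graphlearn_torch's distributed server API, init_server and wait_and_shutdown_server, plus a hypothetical helper that builds the server-side dataset from the vineyard handle and config):

# Hypothetical sketch of the server launch script (the second changed file,
# not shown in this view). The init_server/wait_and_shutdown_server calls
# follow graphlearn_torch's distributed API; the dataset-building helper and
# the handle layout are assumptions, not the actual GraphScope code.
import graphscope.learning.graphlearn_torch as glt


def launch_server(handle, config, server_rank=0):
    # Assumption: a helper that materializes a graphlearn_torch Dataset
    # from the vineyard-backed graph described by handle and config.
    dataset = build_dataset_from_vineyard(handle, config)  # hypothetical

    glt.distributed.init_server(
        num_servers=1,
        num_clients=1,
        server_rank=server_rank,
        dataset=dataset,
        master_addr=handle["master_addr"],  # assumed handle layout
        master_port=handle["server_client_master_port"],
        num_rpc_threads=4,
    )
    # Block until the client disconnects, then release server resources.
    glt.distributed.wait_and_shutdown_server()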
@@ -0,0 +1,157 @@
import time

import graphscope as gs
import graphscope.learning.graphlearn_torch as glt
import torch
import torch.nn.functional as F
from graphscope.dataset import load_ogbn_arxiv
from graphscope.learning.graphlearn_torch.typing import Split
from ogb.nodeproppred import Evaluator
from torch_geometric.nn import GraphSAGE

@torch.no_grad()
def test(model, test_loader, dataset_name):
    # Evaluate on the test split with the official OGB evaluator.
    evaluator = Evaluator(name=dataset_name)
    model.eval()
    xs = []
    y_true = []
    for i, batch in enumerate(test_loader):
        if i == 0:
            device = batch.x.device
        batch.x = batch.x.to(torch.float32)  # TODO
        x = model(batch.x, batch.edge_index)[: batch.batch_size]
        xs.append(x.cpu())
        y_true.append(batch.y[: batch.batch_size].cpu())
        del batch

    # Gather predictions and labels, then compute accuracy.
    xs = [t.to(device) for t in xs]
    y_true = [t.to(device) for t in y_true]
    y_pred = torch.cat(xs, dim=0).argmax(dim=-1, keepdim=True)
    y_true = torch.cat(y_true, dim=0).unsqueeze(-1)
    test_acc = evaluator.eval(
        {
            "y_true": y_true,
            "y_pred": y_pred,
        }
    )["acc"]
    return test_acc


gs.set_option(show_log=True)

# Load the ogbn-arxiv graph as an example.
g = load_ogbn_arxiv()
glt_graph = gs.graphlearn_torch(
    g,
    edges=[
        ("paper", "citation", "paper"),
    ],
    node_features={
        "paper": [f"feat_{i}" for i in range(128)],
    },
    node_labels={
        "paper": "label",
    },
    edge_dir="out",
    random_node_split={
        "num_val": 0.1,
        "num_test": 0.1,
    },
)
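# The returned glt_graph is a handle to the server-side dataset: it carries
# the RPC endpoints (master address plus the server/loader master ports)
# that the client and the data loaders below connect to.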

epochs = 10
dataset_name = "ogbn-arxiv"

print("-- Initializing client ...")
glt.distributed.init_client(
    num_servers=1,
    num_clients=1,
    client_rank=0,
    master_addr=glt_graph.master_addr,
    master_port=glt_graph.server_client_master_port,
    num_rpc_threads=4,
)
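
# This process is the single training client; it registers with the one
# graphlearn_torch server that serves the vineyard-backed graph.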

device = torch.device("cpu")

# Create a distributed neighbor loader on the remote server for training.
print("-- Creating training dataloader ...")
train_loader = glt.distributed.DistNeighborLoader(
    data=None,  # the dataset lives on the server, not in this process
    num_neighbors=[15, 10, 5],
    input_nodes=Split.train,
    batch_size=512,
    shuffle=True,
    collect_features=True,
    to_device=device,
    worker_options=glt.distributed.RemoteDistSamplingWorkerOptions(
        server_rank=[0],
        num_workers=1,
        worker_devices=[torch.device("cpu")],
        worker_concurrency=1,
        master_addr=glt_graph.master_addr,
        master_port=glt_graph.train_loader_master_port,
        buffer_size="1GB",
        prefetch_size=1,
        worker_key=str(glt_graph.train_loader_master_port),
    ),
)
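
# Sampling runs in the server-side workers configured above; sampled
# subgraphs (with their features collected) are streamed to this client.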

# Create a distributed neighbor loader on the remote server for testing.
print("-- Creating testing dataloader ...")
test_loader = glt.distributed.DistNeighborLoader(
    data=None,
    num_neighbors=[15, 10, 5],
    input_nodes=Split.test,
    batch_size=512,
    shuffle=False,
    collect_features=True,
    to_device=device,
    worker_options=glt.distributed.RemoteDistSamplingWorkerOptions(
        server_rank=[0],
        num_workers=1,
        worker_devices=[torch.device("cpu")],
        worker_concurrency=1,
        master_addr=glt_graph.master_addr,
        master_port=glt_graph.test_loader_master_port,
        buffer_size="1GB",
        prefetch_size=1,
        worker_key=str(glt_graph.test_loader_master_port),
    ),
)

# Define model and optimizer.
print("-- Initializing model and optimizer ...")
model = GraphSAGE(
    in_channels=128,  # matches the 128 "feat_i" node features above
    hidden_channels=256,
    num_layers=3,
    out_channels=47,
).to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=0.01)

# Train and test.
print("-- Start training and testing ...")
for epoch in range(epochs):
    model.train()
    start = time.time()
    for batch in train_loader:
        optimizer.zero_grad()
        batch.x = batch.x.to(torch.float32)  # TODO
        out = model(batch.x, batch.edge_index)[: batch.batch_size].log_softmax(dim=-1)
        loss = F.nll_loss(out, batch.y[: batch.batch_size])
        loss.backward()
        optimizer.step()

    end = time.time()
    print(f"-- Epoch: {epoch:03d}, Loss: {loss:.4f}, Epoch Time: {end - start:.2f}s")
    # Evaluate on the first epoch and throughout the second half of training.
    if epoch == 0 or epoch > (epochs // 2):
        test_acc = test(model, test_loader, dataset_name)
        print(f"-- Test Accuracy: {test_acc:.4f}")

print("-- Shutting down ...")
glt.distributed.shutdown_client()

print("-- Exited ...")