
[Feat] support random_node_split for graphlearn_torch
Committed-by: Hongyi ZHANG from Dev container
Zhanghyi committed Sep 7, 2023
1 parent fce7562 commit 7ec0a9d
Showing 7 changed files with 163 additions and 193 deletions.
6 changes: 5 additions & 1 deletion coordinator/gscoordinator/launch_graphlearn_torch.py
@@ -53,12 +53,16 @@ def run_server_proc(proc_rank, handle, config, server_rank, dataset)
 def launch_graphlearn_torch_server(handle, config, server_rank):
     logger.info(f"-- [Server {server_rank}] Initializing server ...")
 
-    dataset = glt.distributed.DistDataset()
+    edge_dir = config.pop("edge_dir")
+    random_node_split = config.pop("random_node_split")
+    dataset = glt.distributed.DistDataset(edge_dir=edge_dir)
     dataset.load_vineyard(
         vineyard_id=str(handle["vineyard_id"]),
         vineyard_socket=handle["vineyard_socket"],
         **config,
     )
+    if random_node_split is not None:
+        dataset.random_node_split(**random_node_split)
     logger.info(f"-- [Server {server_rank}] Initializing server ...")
 
     torch.multiprocessing.spawn(
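
For context, here is a minimal sketch of the split semantics this hunk wires up, assuming num_val and num_test are fractions of the node set (the keys come from the example below in this commit; the arithmetic is an illustration, not the library's implementation):

    import torch

    num_nodes = 169343  # paper nodes in ogbn-arxiv
    num_val, num_test = int(0.1 * num_nodes), int(0.1 * num_nodes)
    perm = torch.randperm(num_nodes)           # random node permutation
    val_ids = perm[:num_val]                   # 10% validation
    test_ids = perm[num_val:num_val + num_test]  # 10% test
    train_ids = perm[num_val + num_test:]      # remaining 80% train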
7 changes: 6 additions & 1 deletion coordinator/gscoordinator/local_launcher.py
@@ -333,6 +333,11 @@ def _create_graphlearn_torch_instance(self, object_id, handle, config):
         server_client_master_port = get_free_port("localhost")
         handle["server_client_master_port"] = server_client_master_port
 
+        server_list = [f"localhost:{server_client_master_port}"]
+        # for train, val and test
+        for _ in range(3):
+            server_list.append("localhost:" + str(get_free_port("localhost")))
+
         handle = base64.b64encode(pickle.dumps(handle))
 
         # launch the server
@@ -374,7 +379,7 @@ def _create_graphlearn_torch_instance(self, object_id, handle, config):
         )
         setattr(proc, "stdout_watcher", stdout_watcher)
         self._learning_instance_processes[object_id].append(proc)
-        return [f"localhost:{server_client_master_port}"]
+        return server_list
 
     def close_analytical_instance(self):
         self._stop_subprocess(self._analytical_engine_process, kill=True)
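
The value returned to the client thus grows from one endpoint to four. A sketch of the layout, with illustrative port numbers (the train/val/test ordering follows the comment in the hunk above; the train/test attribute names come from the client example below, while the val attribute is an assumption):

    server_list = [
        "localhost:50051",  # server_client_master_port: client/server RPC init
        "localhost:50052",  # exposed as glt_graph.train_loader_master_port
        "localhost:50053",  # val loader port (assumed glt_graph.val_loader_master_port)
        "localhost:50054",  # exposed as glt_graph.test_loader_master_port
    ]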
302 changes: 117 additions & 185 deletions examples/local_glt.py
@@ -1,14 +1,11 @@
-# python3 local_glt.py --dataset=ogbn-arxiv --epochs=10
-import argparse
-import os.path as osp
 import time
-from typing import List
 
 import graphscope as gs
 import graphscope.learning.graphlearn_torch as glt
 import torch
 import torch.nn.functional as F
 from graphscope.dataset import load_ogbn_arxiv
+from graphscope.learning.graphlearn_torch.typing import Split
 from ogb.nodeproppred import Evaluator
 from torch_geometric.nn import GraphSAGE
 
@@ -41,185 +38,120 @@ def test(model, test_loader, dataset_name):
     return test_acc
 
 
-def run(
-    glt_graph,
-    dataset_name: str,
-    train_path_list: List[str],
-    test_path_list: List[str],
-    epochs: int,
-    batch_size: int,
-    train_loader_master_port: int,
-    test_loader_master_port: int,
-):
-    print(f"-- Initializing client ...")
-    glt.distributed.init_client(
-        num_servers=1,
-        num_clients=1,
-        client_rank=0,
-        master_addr=glt_graph.master_addr,
-        master_port=glt_graph.server_client_master_port,
-        num_rpc_threads=4,
-    )
-
-    device = torch.device("cpu")
-
-    # Create distributed neighbor loader on remote server for training.
-    print(f"-- Creating training dataloader ...")
-    train_loader = glt.distributed.DistNeighborLoader(
-        data=None,
-        num_neighbors=[15, 10, 5],
-        input_nodes=train_path_list,
-        batch_size=batch_size,
-        shuffle=True,
-        collect_features=True,
-        to_device=device,
-        worker_options=glt.distributed.RemoteDistSamplingWorkerOptions(
-            server_rank=[0],
-            num_workers=1,
-            worker_devices=[torch.device("cpu")],
-            worker_concurrency=1,
-            master_addr=glt_graph.master_addr,
-            master_port=train_loader_master_port,
-            buffer_size="1GB",
-            prefetch_size=1,
-            worker_key="train",
-        ),
-    )
-
-    # Create distributed neighbor loader on remote server for testing.
-    print(f"-- Creating testing dataloader ...")
-    test_loader = glt.distributed.DistNeighborLoader(
-        data=None,
-        num_neighbors=[15, 10, 5],
-        input_nodes=test_path_list,
-        batch_size=batch_size,
-        shuffle=False,
-        collect_features=True,
-        to_device=device,
-        worker_options=glt.distributed.RemoteDistSamplingWorkerOptions(
-            server_rank=[0],
-            num_workers=1,
-            worker_devices=[torch.device("cpu")],
-            worker_concurrency=1,
-            master_addr=glt_graph.master_addr,
-            master_port=test_loader_master_port,
-            buffer_size="1GB",
-            prefetch_size=1,
-            worker_key="test",
-        ),
-    )
-
-    # Define model and optimizer.
-    print(f"-- Initializing model and optimizer ...")
-    model = GraphSAGE(
-        in_channels=128,
-        hidden_channels=256,
-        num_layers=3,
-        out_channels=47,
-    ).to(device)
-    optimizer = torch.optim.Adam(model.parameters(), lr=0.01)
-
-    # Train and test.
-    print(f"-- Start training and testing ...")
-    for epoch in range(0, epochs):
-        model.train()
-        start = time.time()
-        for batch in train_loader:
-            optimizer.zero_grad()
-            batch.x = batch.x.to(torch.float32)  # TODO
-            out = model(batch.x, batch.edge_index)[: batch.batch_size].log_softmax(
-                dim=-1
-            )
-            loss = F.nll_loss(out, batch.y[: batch.batch_size])
-            loss.backward()
-            optimizer.step()
-
-        end = time.time()
-        print(f"-- Epoch: {epoch:03d}, Loss: {loss:.4f}, Epoch Time: {end - start}")
-        # Test accuracy.
-        if epoch == 0 or epoch > (epochs // 2):
-            test_acc = test(model, test_loader, dataset_name)
-            print(f"-- Test Accuracy: {test_acc:.4f}")
-
-    print(f"-- Shutdowning ...")
-    glt.distributed.shutdown_client()
-
-    print(f"-- Exited ...")
-
-
-if __name__ == "__main__":
-    parser = argparse.ArgumentParser(
-        description="Arguments for distributed training of supervised SAGE with servers."
-    )
-    parser.add_argument(
-        "--dataset",
-        type=str,
-        default="ogbn-products",
-        help="The name of ogbn dataset.",
-    )
-    parser.add_argument(
-        "--epochs",
-        type=int,
-        default=10,
-        help="The number of training epochs. (client option)",
-    )
-    parser.add_argument(
-        "--batch_size",
-        type=int,
-        default=512,
-        help="Batch size for the training and testing dataloader.",
-    )
-    parser.add_argument(
-        "--train_loader_master_port",
-        type=int,
-        default=11112,
-        help="The port used for RPC initialization across all sampling workers of training loader.",
-    )
-    parser.add_argument(
-        "--test_loader_master_port",
-        type=int,
-        default=11113,
-        help="The port used for RPC initialization across all sampling workers of testing loader.",
-    )
-    args = parser.parse_args()
-
-    print(f"* dataset: {args.dataset}")
-    print(f"* epochs: {args.epochs}")
-    print(f"* batch size: {args.batch_size}")
-    print(f"* training loader master port: {args.train_loader_master_port}")
-    print(f"* testing loader master port: {args.test_loader_master_port}")
-
-    print("-- Loading training and testing seeds ...")
-
-    train_path = "/root/GraphScope/examples/train_seeds.pt"
-    val_path = "/root/GraphScope/examples/val_seeds.pt"
-    torch.save(torch.arange(10000), train_path)
-    torch.save(torch.arange(10000, 169343), val_path)
-
-    gs.set_option(show_log=True)
-
-    # load the ogbn_arxiv graph as example.
-    g = load_ogbn_arxiv()
-    glt_graph = gs.graphlearn_torch(
-        g,
-        edges=[
-            ("paper", "citation", "paper"),
-        ],
-        node_features={
-            "paper": [f"feat_{i}" for i in range(128)],
-        },
-        node_labels={
-            "paper": "label",
-        },
-    )
-
-    run(
-        glt_graph,
-        args.dataset,
-        [train_path],
-        [val_path],
-        args.epochs,
-        args.batch_size,
-        args.train_loader_master_port,
-        args.test_loader_master_port,
-    )
+gs.set_option(show_log=True)
+
+# load the ogbn_arxiv graph as an example.
+g = load_ogbn_arxiv()
+glt_graph = gs.graphlearn_torch(
+    g,
+    edges=[
+        ("paper", "citation", "paper"),
+    ],
+    node_features={
+        "paper": [f"feat_{i}" for i in range(128)],
+    },
+    node_labels={
+        "paper": "label",
+    },
+    edge_dir="out",
+    random_node_split={
+        "num_val": 0.1,
+        "num_test": 0.1,
+    },
+)
+
+epochs = 10
+dataset_name = "ogbn-arxiv"
+
+print("-- Initializing client ...")
+glt.distributed.init_client(
+    num_servers=1,
+    num_clients=1,
+    client_rank=0,
+    master_addr=glt_graph.master_addr,
+    master_port=glt_graph.server_client_master_port,
+    num_rpc_threads=4,
+)
+
+device = torch.device("cpu")
+
+# Create distributed neighbor loader on remote server for training.
+print("-- Creating training dataloader ...")
+train_loader = glt.distributed.DistNeighborLoader(
+    data=None,
+    num_neighbors=[15, 10, 5],
+    input_nodes=Split.train,
+    batch_size=512,
+    shuffle=True,
+    collect_features=True,
+    to_device=device,
+    worker_options=glt.distributed.RemoteDistSamplingWorkerOptions(
+        server_rank=[0],
+        num_workers=1,
+        worker_devices=[torch.device("cpu")],
+        worker_concurrency=1,
+        master_addr=glt_graph.master_addr,
+        master_port=glt_graph.train_loader_master_port,
+        buffer_size="1GB",
+        prefetch_size=1,
+        worker_key="train",
+    ),
+)
+
+# Create distributed neighbor loader on remote server for testing.
+print("-- Creating testing dataloader ...")
+test_loader = glt.distributed.DistNeighborLoader(
+    data=None,
+    num_neighbors=[15, 10, 5],
+    input_nodes=Split.test,
+    batch_size=512,
+    shuffle=False,
+    collect_features=True,
+    to_device=device,
+    worker_options=glt.distributed.RemoteDistSamplingWorkerOptions(
+        server_rank=[0],
+        num_workers=1,
+        worker_devices=[torch.device("cpu")],
+        worker_concurrency=1,
+        master_addr=glt_graph.master_addr,
+        master_port=glt_graph.test_loader_master_port,
+        buffer_size="1GB",
+        prefetch_size=1,
+        worker_key="test",
+    ),
+)
+
+# Define model and optimizer.
+print("-- Initializing model and optimizer ...")
+model = GraphSAGE(
+    in_channels=128,
+    hidden_channels=256,
+    num_layers=3,
+    out_channels=47,
+).to(device)
+optimizer = torch.optim.Adam(model.parameters(), lr=0.01)
+
+# Train and test.
+print("-- Start training and testing ...")
+for epoch in range(0, 10):
+    model.train()
+    start = time.time()
+    for batch in train_loader:
+        optimizer.zero_grad()
+        batch.x = batch.x.to(torch.float32)  # TODO
+        out = model(batch.x, batch.edge_index)[: batch.batch_size].log_softmax(dim=-1)
+        loss = F.nll_loss(out, batch.y[: batch.batch_size])
+        loss.backward()
+        optimizer.step()
+
+    end = time.time()
+    print(f"-- Epoch: {epoch:03d}, Loss: {loss:.4f}, Epoch Time: {end - start}")
+    # Test accuracy.
+    if epoch == 0 or epoch > (epochs // 2):
+        test_acc = test(model, test_loader, dataset_name)
+        print(f"-- Test Accuracy: {test_acc:.4f}")
+
+print("-- Shutdowning ...")
+glt.distributed.shutdown_client()
+
+print("-- Exited ...")
2 changes: 1 addition & 1 deletion learning_engine/graph-learn
2 changes: 1 addition & 1 deletion learning_engine/graphlearn-for-pytorch
Submodule graphlearn-for-pytorch updated 53 files
+4 −0 CMakeLists.txt
+313 −0 examples/distributed/dist_sage_unsup/dist_sage_unsup.py
+129 −0 examples/distributed/dist_sage_unsup/preprocess_template.py
+7 −7 examples/igbh/dist_train_rgnn.py
+13 −4 examples/igbh/rgnn.py
+11 −3 examples/igbh/train_rgnn.py
+12 −4 graphlearn_torch/csrc/cpu/graph.cc
+1 −0 graphlearn_torch/csrc/cpu/random_sampler.cc
+2 −0 graphlearn_torch/csrc/cpu/random_sampler.h
+123 −107 graphlearn_torch/csrc/cpu/vineyard_utils.cc
+194 −0 graphlearn_torch/csrc/cpu/weighted_sampler.cc
+63 −0 graphlearn_torch/csrc/cpu/weighted_sampler.h
+42 −0 graphlearn_torch/csrc/cuda/weighted_sampler.cuh
+1 −8 graphlearn_torch/include/common.h
+6 −1 graphlearn_torch/include/graph.h
+5 −7 graphlearn_torch/include/vineyard_utils.h
+162 −10 graphlearn_torch/python/data/dataset.py
+3 −4 graphlearn_torch/python/data/feature.py
+55 −25 graphlearn_torch/python/data/graph.py
+6 −18 graphlearn_torch/python/data/vineyard_utils.py
+98 −7 graphlearn_torch/python/distributed/dist_dataset.py
+2 −1 graphlearn_torch/python/distributed/dist_link_neighbor_loader.py
+1 −0 graphlearn_torch/python/distributed/dist_loader.py
+15 −8 graphlearn_torch/python/distributed/dist_neighbor_loader.py
+3 −1 graphlearn_torch/python/distributed/dist_neighbor_sampler.py
+3 −2 graphlearn_torch/python/distributed/dist_sampling_producer.py
+1 −1 graphlearn_torch/python/distributed/dist_server.py
+3 −1 graphlearn_torch/python/distributed/dist_subgraph_loader.py
+3 −2 graphlearn_torch/python/loader/link_loader.py
+2 −0 graphlearn_torch/python/loader/link_neighbor_loader.py
+2 −0 graphlearn_torch/python/loader/neighbor_loader.py
+28 −6 graphlearn_torch/python/partition/base.py
+3 −1 graphlearn_torch/python/partition/frequency_partitioner.py
+2 −1 graphlearn_torch/python/partition/random_partitioner.py
+10 −8 graphlearn_torch/python/py_export.cc
+31 −4 graphlearn_torch/python/sampler/base.py
+12 −4 graphlearn_torch/python/sampler/neighbor_sampler.py
+12 −1 graphlearn_torch/python/typing.py
+10 −1 graphlearn_torch/python/utils/build.py
+32 −15 graphlearn_torch/python/utils/topo.py
+114 −0 test/cpp/test_vineyard.cc
+57 −14 test/python/dist_test_utils.py
+22 −8 test/python/test_dist_neighbor_loader.py
+32 −11 test/python/test_graph.py
+117 −60 test/python/test_hetero_neighbor_sampler.py
+22 −0 test/python/test_neighbor_sampler.py
+27 −5 test/python/test_partition.py
+168 −23 test/python/test_vineyard.py
+36 −0 test/python/vineyard_data/config.json
+5 −0 test/python/vineyard_data/modern_graph/created.csv
+3 −0 test/python/vineyard_data/modern_graph/knows.csv
+5 −0 test/python/vineyard_data/modern_graph/person.csv
+3 −0 test/python/vineyard_data/modern_graph/software.csv