From 33743424aeda6204dbc5e111b154c014b1363cbc Mon Sep 17 00:00:00 2001 From: LiSu Date: Mon, 25 Sep 2023 20:05:01 +0800 Subject: [PATCH] feat(learning): Integrate GLTorch into GraphScope (#3230) This PR integrates [GLTorch](https://github.com/alibaba/graphlearn-for-pytorch) into GraphScope, allowing training GNNs using GLTorch in the local mode of GraphScope. --------- Co-authored-by: Hongyi ZHANG <50618951+Zhanghyi@users.noreply.github.com> Co-authored-by: Jia Zhibin <56682441+Jia-zb@users.noreply.github.com> --- .gitmodules | 4 + Makefile | 5 + coordinator/gscoordinator/coordinator.py | 8 +- .../gscoordinator/kubernetes_launcher.py | 4 +- .../{learning.py => launch_graphlearn.py} | 5 +- .../gscoordinator/launch_graphlearn_torch.py | 93 +++++++++ coordinator/gscoordinator/launcher.py | 4 +- coordinator/gscoordinator/local_launcher.py | 79 +++++++- docs/learning_engine/guide_and_examples.md | 29 ++- .../tutorial_node_classification_pyg_local.md | 170 +++++++++++++++++ learning_engine/graphlearn-for-pytorch | 1 + proto/message.proto | 6 + python/graphscope/__init__.py | 1 + python/graphscope/client/rpc.py | 3 +- python/graphscope/client/session.py | 79 +++++++- .../learning/gl_torch_examples/local.py | 154 +++++++++++++++ python/graphscope/learning/gl_torch_graph.py | 179 ++++++++++++++++++ python/graphscope/learning/graphlearn_torch | 1 + python/requirements.txt | 1 + python/setup.cfg | 3 +- python/setup.py | 97 +++++++++- 21 files changed, 906 insertions(+), 20 deletions(-) rename coordinator/gscoordinator/{learning.py => launch_graphlearn.py} (94%) create mode 100644 coordinator/gscoordinator/launch_graphlearn_torch.py create mode 100644 docs/learning_engine/tutorial_node_classification_pyg_local.md create mode 160000 learning_engine/graphlearn-for-pytorch create mode 100644 python/graphscope/learning/gl_torch_examples/local.py create mode 100644 python/graphscope/learning/gl_torch_graph.py create mode 120000 python/graphscope/learning/graphlearn_torch diff --git a/.gitmodules b/.gitmodules index dabe493f60bd..1db7618bdbe0 100644 --- a/.gitmodules +++ b/.gitmodules @@ -5,3 +5,7 @@ [submodule "flex/grin"] path = flex/grin url = https://github.com/GraphScope/GRIN.git + +[submodule "learning_engine/graphlearn-for-pytorch"] + path = learning_engine/graphlearn-for-pytorch + url = https://github.com/alibaba/graphlearn-for-pytorch.git diff --git a/Makefile b/Makefile index 0892e87515ab..81ecd19aee9b 100644 --- a/Makefile +++ b/Makefile @@ -18,6 +18,8 @@ NETWORKX ?= ON # testing build option BUILD_TEST ?= OFF +# whether to build graphlearn-torch extension (graphlearn is built by default) +WITH_GLTORCH ?= ON # INSTALL_PREFIX is environment variable, but if it is not set, then set default value ifeq ($(INSTALL_PREFIX),) @@ -76,6 +78,9 @@ client: learning python3 -m pip install -r requirements.txt -r requirements-dev.txt --user && \ export PATH=$(PATH):$(HOME)/.local/bin && \ python3 setup.py build_ext --inplace --user && \ + if [ $(WITH_GLTORCH) = ON ]; then \ + python3 setup.py build_gltorch_ext --inplace --user; \ + fi && \ python3 -m pip install --user --no-build-isolation --editable $(CLIENT_DIR) && \ rm -rf $(CLIENT_DIR)/*.egg-info diff --git a/coordinator/gscoordinator/coordinator.py b/coordinator/gscoordinator/coordinator.py index 7799801a5465..17f324418806 100644 --- a/coordinator/gscoordinator/coordinator.py +++ b/coordinator/gscoordinator/coordinator.py @@ -531,10 +531,14 @@ def _match_frontend_endpoint(pattern, lines): def CreateLearningInstance(self, request, context): object_id = request.object_id logger.info("Create learning instance with object id %ld", object_id) - handle, config = request.handle, request.config + handle, config, learning_backend = ( + request.handle, + request.config, + request.learning_backend, + ) try: endpoints = self._launcher.create_learning_instance( - object_id, handle, config + object_id, handle, config, learning_backend ) self._object_manager.put(object_id, LearningInstanceManager(object_id)) except Exception as e: diff --git a/coordinator/gscoordinator/kubernetes_launcher.py b/coordinator/gscoordinator/kubernetes_launcher.py index 6e8be09dabcd..3c48e6a13f3d 100644 --- a/coordinator/gscoordinator/kubernetes_launcher.py +++ b/coordinator/gscoordinator/kubernetes_launcher.py @@ -1291,7 +1291,7 @@ def _distribute_learning_process( self._learning_instance_processes[object_id] = [] for pod_index, pod in enumerate(self._pod_name_list): container = LEARNING_CONTAINER_NAME - sub_cmd = f"python3 -m gscoordinator.learning {handle} {config} {pod_index}" + sub_cmd = f"python3 -m gscoordinator.launch_graphlearn {handle} {config} {pod_index}" cmd = f"kubectl -n {self._namespace} exec -it -c {container} {pod} -- {sub_cmd}" logger.debug("launching learning server: %s", " ".join(cmd)) proc = subprocess.Popen( @@ -1321,7 +1321,7 @@ def _distribute_learning_process( self._api_client, object_id, pod_host_ip_list ) - def create_learning_instance(self, object_id, handle, config): + def create_learning_instance(self, object_id, handle, config, learning_backend): pod_name_list, _, pod_host_ip_list = self._allocate_learning_engine(object_id) if not pod_name_list or not pod_host_ip_list: raise RuntimeError("Failed to allocate learning engine") diff --git a/coordinator/gscoordinator/learning.py b/coordinator/gscoordinator/launch_graphlearn.py similarity index 94% rename from coordinator/gscoordinator/learning.py rename to coordinator/gscoordinator/launch_graphlearn.py index 8acc0493a0a9..39e72daafd03 100644 --- a/coordinator/gscoordinator/learning.py +++ b/coordinator/gscoordinator/launch_graphlearn.py @@ -67,7 +67,10 @@ def launch_server(handle, config, server_index): if __name__ == "__main__": if len(sys.argv) < 3: - print("Usage: ./learning.py ", file=sys.stderr) + print( + "Usage: ./launch_graphlearn.py ", + file=sys.stderr, + ) sys.exit(-1) handle = decode_arg(sys.argv[1]) diff --git a/coordinator/gscoordinator/launch_graphlearn_torch.py b/coordinator/gscoordinator/launch_graphlearn_torch.py new file mode 100644 index 000000000000..4650161aaa08 --- /dev/null +++ b/coordinator/gscoordinator/launch_graphlearn_torch.py @@ -0,0 +1,93 @@ +#!/usr/bin/env python3 +# -*- coding: utf-8 -*- +# +# Copyright 2020 Alibaba Group Holding Limited. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +import base64 +import json +import logging +import os.path as osp +import sys + +import graphscope.learning.graphlearn_torch as glt +import torch +from graphscope.learning.gl_torch_graph import GLTorchGraph + +logger = logging.getLogger("graphscope") + + +def decode_arg(arg): + if isinstance(arg, dict): + return arg + return json.loads( + base64.b64decode(arg.encode("utf-8", errors="ignore")).decode( + "utf-8", errors="ignore" + ) + ) + + +def run_server_proc(proc_rank, handle, config, server_rank, dataset): + glt.distributed.init_server( + num_servers=handle["num_servers"], + server_rank=server_rank, + dataset=dataset, + master_addr=handle["master_addr"], + master_port=handle["server_client_master_port"], + num_rpc_threads=16, + is_dynamic=True, + ) + logger.info(f"-- [Server {server_rank}] Waiting for exit ...") + glt.distributed.wait_and_shutdown_server() + logger.info(f"-- [Server {server_rank}] Exited ...") + + +def launch_graphlearn_torch_server(handle, config, server_rank): + logger.info(f"-- [Server {server_rank}] Initializing server ...") + + edge_dir = config.pop("edge_dir") + random_node_split = config.pop("random_node_split") + dataset = glt.distributed.DistDataset(edge_dir=edge_dir) + dataset.load_vineyard( + vineyard_id=str(handle["vineyard_id"]), + vineyard_socket=handle["vineyard_socket"], + **config, + ) + if random_node_split is not None: + dataset.random_node_split(**random_node_split) + logger.info(f"-- [Server {server_rank}] Initializing server ...") + + torch.multiprocessing.spawn( + fn=run_server_proc, args=(handle, config, server_rank, dataset), nprocs=1 + ) + + +if __name__ == "__main__": + if len(sys.argv) < 3: + logger.info( + "Usage: ./launch_graphlearn_torch.py ", + file=sys.stderr, + ) + sys.exit(-1) + + handle = decode_arg(sys.argv[1]) + config = decode_arg(sys.argv[2]) + server_index = int(sys.argv[3]) + config = GLTorchGraph.reverse_transform_config(config) + + logger.info( + f"launch_graphlearn_torch_server handle: {handle} config: {config} server_index: {server_index}" + ) + launch_graphlearn_torch_server(handle, config, server_index) diff --git a/coordinator/gscoordinator/launcher.py b/coordinator/gscoordinator/launcher.py index b4f6ac7d2386..09f43f2da860 100644 --- a/coordinator/gscoordinator/launcher.py +++ b/coordinator/gscoordinator/launcher.py @@ -95,7 +95,9 @@ def create_interactive_instance( pass @abstractmethod - def create_learning_instance(self, object_id: int, handle: str, config: str): + def create_learning_instance( + self, object_id: int, handle: str, config: str, learning_backend: int + ): pass @abstractmethod diff --git a/coordinator/gscoordinator/local_launcher.py b/coordinator/gscoordinator/local_launcher.py index 814516cb529f..bd1aeed7140e 100644 --- a/coordinator/gscoordinator/local_launcher.py +++ b/coordinator/gscoordinator/local_launcher.py @@ -33,6 +33,7 @@ from graphscope.framework.utils import get_java_version from graphscope.framework.utils import get_tempdir from graphscope.framework.utils import is_free_port +from graphscope.proto import message_pb2 from graphscope.proto import types_pb2 from gscoordinator.launcher import AbstractLauncher @@ -245,7 +246,19 @@ def _popen_helper(cmd, cwd, env, stdout=None, stderr=None): ) return process - def create_learning_instance(self, object_id, handle, config): + def create_learning_instance(self, object_id, handle, config, learning_backend): + if learning_backend == message_pb2.LearningBackend.GRAPHLEARN: + return self._create_graphlearn_instance( + object_id=object_id, handle=handle, config=config + ) + elif learning_backend == message_pb2.LearningBackend.GRAPHLEARN_TORCH: + return self._create_graphlearn_torch_instance( + object_id=object_id, handle=handle, config=config + ) + else: + raise ValueError("invalid learning backend") + + def _create_graphlearn_instance(self, object_id, handle, config): # prepare argument handle = json.loads( base64.b64decode(handle.encode("utf-8", errors="ignore")).decode( @@ -275,12 +288,12 @@ def create_learning_instance(self, object_id, handle, config): cmd = [ sys.executable, "-m", - "gscoordinator.learning", + "gscoordinator.launch_graphlearn", handle, config, str(index), ] - logger.debug("launching learning server: %s", " ".join(cmd)) + logger.debug("launching graphlearn server: %s", " ".join(cmd)) proc = self._popen_helper(cmd, cwd=None, env=env) stdout_watcher = PipeWatcher(proc.stdout, sys.stdout) @@ -289,6 +302,66 @@ def create_learning_instance(self, object_id, handle, config): self._learning_instance_processes[object_id].append(proc) return server_list + def _create_graphlearn_torch_instance(self, object_id, handle, config): + handle = json.loads( + base64.b64decode(handle.encode("utf-8", errors="ignore")).decode( + "utf-8", errors="ignore" + ) + ) + + server_client_master_port = get_free_port("localhost") + handle["server_client_master_port"] = server_client_master_port + + server_list = [f"localhost:{server_client_master_port}"] + # for train, val and test + for _ in range(3): + server_list.append("localhost:" + str(get_free_port("localhost"))) + + handle = base64.b64encode( + json.dumps(handle).encode("utf-8", errors="ignore") + ).decode("utf-8", errors="ignore") + + # launch the server + env = os.environ.copy() + # set coordinator dir to PYTHONPATH + python_path = ( + env.get("PYTHONPATH", "") + + os.pathsep + + os.path.dirname(os.path.dirname(__file__)) + ) + env["PYTHONPATH"] = python_path + + self._learning_instance_processes[object_id] = [] + for index in range(self._num_workers): + cmd = [ + sys.executable, + "-m", + "gscoordinator.launch_graphlearn_torch", + handle, + config, + str(index), + ] + logger.debug("launching graphlearn_torch server: %s", " ".join(str(cmd))) + + proc = subprocess.Popen( + cmd, + env=env, + stdout=subprocess.PIPE, + stderr=subprocess.STDOUT, + encoding="utf-8", + errors="replace", + universal_newlines=True, + bufsize=1, + ) + stdout_watcher = PipeWatcher( + proc.stdout, + sys.stdout, + suppressed=(not logger.isEnabledFor(logging.DEBUG)), + ) + setattr(proc, "stdout_watcher", stdout_watcher) + self._learning_instance_processes[object_id].append(proc) + return server_list + def close_analytical_instance(self): self._stop_subprocess(self._analytical_engine_process, kill=True) self._analytical_engine_endpoint = None diff --git a/docs/learning_engine/guide_and_examples.md b/docs/learning_engine/guide_and_examples.md index cd477e23dfc3..fa97be8bbf53 100644 --- a/docs/learning_engine/guide_and_examples.md +++ b/docs/learning_engine/guide_and_examples.md @@ -14,10 +14,12 @@ tutorial_node_classification_k8s This section contains a guide for the learning engine and a number of examples. ```{tip} -We assume you has read the [getting_started](getting_started.md) section and know how to launch a GraphScope session. +We assume you has read the [getting_started](getting_started.md) section and +know how to launch a GraphScope session. ``` -We present an end-to-end example, demonstrating how GLE trains a node classification model on a citation network using the local mode of GraphScope. +We present an end-to-end example, demonstrating how GLE trains a node +classification model on a citation network using the local mode of GraphScope. ````{panels} :header: text-center @@ -31,7 +33,11 @@ We present an end-to-end example, demonstrating how GLE trains a node classifica Training a Node Classification Model on Your Local Machine. ```` -GraphScope is designed for processing large graphs, which are usually hard to fit in the memory of a single machine. With vineyard as the distributed in-memory data manager, GraphScope supports run on a cluster managed by Kubernetes(k8s). Next, we revisit the example we present in the first tutorial, showing how GraphScope process the node classification task on a Kubernetes cluster. +GraphScope is designed for processing large graphs, which are usually hard to +fit in the memory of a single machine. With vineyard as the distributed +in-memory data manager, GraphScope supports run on a cluster managed by +Kubernetes(k8s). Next, we revisit the example we present in the first tutorial, +showing how GraphScope process the node classification task on a Kubernetes cluster. ````{panels} @@ -45,3 +51,20 @@ GraphScope is designed for processing large graphs, which are usually hard to fi ^^^^^^^^^^^^^^ Training a Node Classification Model on K8s Cluster ```` + + +GraphScope is also compatible with PyG models, the following examples shows +ho2 to train a PyG model using GraphScope on your local machine. + + +````{panels} +:header: text-center +:column: col-lg-12 p-2 + +```{link-button} tutorial_node_classification_pyg_local.html +:text: Tutorial +:classes: btn-block stretched-link +``` +^^^^^^^^^^^^^^ +Training a Node Classification Model(PyG) on Your Local Machine +```` \ No newline at end of file diff --git a/docs/learning_engine/tutorial_node_classification_pyg_local.md b/docs/learning_engine/tutorial_node_classification_pyg_local.md new file mode 100644 index 000000000000..ff85e500f166 --- /dev/null +++ b/docs/learning_engine/tutorial_node_classification_pyg_local.md @@ -0,0 +1,170 @@ +# Tutorial: Training a Node Classification Model (PyG) on Your Local Machine + +This tutorial presents an end-to-end example that illustrates how GraphScope +trains the GraphSAGE model (implemented in PyG) for a node classification task. + +## Load Graph +```python +import time + +import torch +import torch.nn.functional as F +from ogb.nodeproppred import Evaluator +from torch_geometric.nn import GraphSAGE + +import graphscope as gs +import graphscope.learning.graphlearn_torch as glt +from graphscope.dataset import load_ogbn_arxiv +from graphscope.learning.graphlearn_torch.typing import Split + +gs.set_option(show_log=True) + +# load the ogbn_arxiv graph as an example. +g = load_ogbn_arxiv() +``` +## Define the evaluation function +```python +@torch.no_grad() +def test(model, test_loader, dataset_name): + evaluator = Evaluator(name=dataset_name) + model.eval() + xs = [] + y_true = [] + for i, batch in enumerate(test_loader): + if i == 0: + device = batch.x.device + batch.x = batch.x.to(torch.float32) # TODO + x = model(batch.x, batch.edge_index)[: batch.batch_size] + xs.append(x.cpu()) + y_true.append(batch.y[: batch.batch_size].cpu()) + del batch + + xs = [t.to(device) for t in xs] + y_true = [t.to(device) for t in y_true] + y_pred = torch.cat(xs, dim=0).argmax(dim=-1, keepdim=True) + y_true = torch.cat(y_true, dim=0).unsqueeze(-1) + test_acc = evaluator.eval( + { + "y_true": y_true, + "y_pred": y_pred, + } + )["acc"] + return test_acc +``` + +## Launch the Learning Engine +```python +glt_graph = gs.graphlearn_torch( + g, + edges=[ + ("paper", "citation", "paper"), + ], + node_features={ + "paper": [f"feat_{i}" for i in range(128)], + }, + node_labels={ + "paper": "label", + }, + edge_dir="out", + random_node_split={ + "num_val": 0.1, + "num_test": 0.1, + }, +) + +print("-- Initializing client ...") +glt.distributed.init_client( + num_servers=1, + num_clients=1, + client_rank=0, + master_addr=glt_graph.master_addr, + master_port=glt_graph.server_client_master_port, + num_rpc_threads=4, + is_dynamic=True, +) +``` + +## Create neighbor loaderfor training, testing and validation +```python +device = torch.device("cpu") +# Create distributed neighbor loader on remote server for training. +print("-- Creating training dataloader ...") +train_loader = glt.distributed.DistNeighborLoader( + data=None, + num_neighbors=[15, 10, 5], + input_nodes=Split.train, + batch_size=512, + shuffle=True, + collect_features=True, + to_device=device, + worker_options=glt.distributed.RemoteDistSamplingWorkerOptions( + num_workers=1, + worker_devices=[torch.device("cpu")], + worker_concurrency=1, + buffer_size="1GB", + prefetch_size=1, + glt_graph=glt_graph, + workload_type="train", + ), +) + +# Create distributed neighbor loader on remote server for testing. +print("-- Creating testing dataloader ...") +test_loader = glt.distributed.DistNeighborLoader( + data=None, + num_neighbors=[15, 10, 5], + input_nodes=Split.test, + batch_size=512, + shuffle=False, + collect_features=True, + to_device=device, + worker_options=glt.distributed.RemoteDistSamplingWorkerOptions( + num_workers=1, + worker_devices=[torch.device("cpu")], + worker_concurrency=1, + buffer_size="1GB", + prefetch_size=1, + glt_graph=glt_graph, + workload_type="test", + ), +) +``` + +## Define the PyG GraphSage Model and optimizer +```python +print("-- Initializing model and optimizer ...") +model = GraphSAGE( + in_channels=128, + hidden_channels=256, + num_layers=3, + out_channels=47, +).to(device) +optimizer = torch.optim.Adam(model.parameters(), lr=0.01) +``` + +## Train and test +```python +print("-- Start training and testing ...") +epochs = 10 +dataset_name = "ogbn-arxiv" +for epoch in range(0, epochs): + model.train() + start = time.time() + for batch in train_loader: + optimizer.zero_grad() + batch.x = batch.x.to(torch.float32) # TODO + out = model(batch.x, batch.edge_index)[: batch.batch_size].log_softmax(dim=-1) + loss = F.nll_loss(out, batch.y[: batch.batch_size]) + loss.backward() + optimizer.step() + + end = time.time() + print(f"-- Epoch: {epoch:03d}, Loss: {loss:.4f}, Epoch Time: {end - start}") + # Test accuracy. + if epoch == 0 or epoch > (epochs // 2): + test_acc = test(model, test_loader, dataset_name) + print(f"-- Test Accuracy: {test_acc:.4f}") + +print("-- Shutdowning ...") +glt.distributed.shutdown_client() +``` \ No newline at end of file diff --git a/learning_engine/graphlearn-for-pytorch b/learning_engine/graphlearn-for-pytorch new file mode 160000 index 000000000000..7068173fd51f --- /dev/null +++ b/learning_engine/graphlearn-for-pytorch @@ -0,0 +1 @@ +Subproject commit 7068173fd51f4956767e9bfae92b4810a26a07f9 diff --git a/proto/message.proto b/proto/message.proto index 998b314b49a1..2a7bbf6ab54b 100644 --- a/proto/message.proto +++ b/proto/message.proto @@ -194,11 +194,17 @@ message CreateInteractiveInstanceResponse { int64 object_id = 3; }; +enum LearningBackend { + GRAPHLEARN = 0; + GRAPHLEARN_TORCH = 1; +} + message CreateLearningInstanceRequest { string session_id = 1; int64 object_id = 2; string handle = 3; string config = 4; + LearningBackend learning_backend = 5; }; message CreateLearningInstanceResponse { diff --git a/python/graphscope/__init__.py b/python/graphscope/__init__.py index ad1799b3f20f..a3ed3ea76005 100644 --- a/python/graphscope/__init__.py +++ b/python/graphscope/__init__.py @@ -41,6 +41,7 @@ from graphscope.client.session import g from graphscope.client.session import get_default_session from graphscope.client.session import graphlearn +from graphscope.client.session import graphlearn_torch from graphscope.client.session import gremlin from graphscope.client.session import has_default_session from graphscope.client.session import interactive diff --git a/python/graphscope/client/rpc.py b/python/graphscope/client/rpc.py index 8e18f0ce5bb1..cf65ca5932d2 100644 --- a/python/graphscope/client/rpc.py +++ b/python/graphscope/client/rpc.py @@ -219,11 +219,12 @@ def create_interactive_instance(self, object_id, schema_path, params, with_cyphe response = self._stub.CreateInteractiveInstance(request) return response.gremlin_endpoint, response.cypher_endpoint - def create_learning_instance(self, object_id, handle, config): + def create_learning_instance(self, object_id, handle, config, learning_backend): request = message_pb2.CreateLearningInstanceRequest(session_id=self._session_id) request.object_id = object_id request.handle = handle request.config = config + request.learning_backend = learning_backend response = self._stub.CreateLearningInstance(request) return response.handle, response.config, response.endpoints diff --git a/python/graphscope/client/session.py b/python/graphscope/client/session.py index d5ea6d61808d..56e7b8faf26a 100755 --- a/python/graphscope/client/session.py +++ b/python/graphscope/client/session.py @@ -1312,7 +1312,7 @@ def graphlearn(self, graph, nodes=None, edges=None, gen_labels=None): json.dumps(config).encode("utf-8", errors="ignore") ).decode("utf-8", errors="ignore") handle, config, endpoints = self._grpc_client.create_learning_instance( - graph.vineyard_id, handle, config + graph.vineyard_id, handle, config, message_pb2.LearningBackend.GRAPHLEARN ) handle = json.loads(base64.b64decode(handle).decode("utf-8", errors="ignore")) @@ -1325,6 +1325,57 @@ def graphlearn(self, graph, nodes=None, edges=None, gen_labels=None): graph._attach_learning_instance(g) return g + def graphlearn_torch( + self, + graph, + edges, + edge_weights=None, + node_features=None, + edge_features=None, + node_labels=None, + edge_dir="out", + random_node_split=None, + ): + from graphscope.learning.gl_torch_graph import GLTorchGraph + + handle = { + "vineyard_socket": self._engine_config["vineyard_socket"], + "vineyard_id": graph.vineyard_id, + "fragments": graph.fragments, + "master_addr": "localhost", + "num_servers": 1, + "num_clients": 1, + } + + handle = base64.b64encode( + json.dumps(handle).encode("utf-8", errors="ignore") + ).decode("utf-8", errors="ignore") + config = { + "edges": edges, + "edge_weights": edge_weights, + "node_features": node_features, + "edge_features": edge_features, + "node_labels": node_labels, + "edge_dir": edge_dir, + "random_node_split": random_node_split, + } + GLTorchGraph.check_params(graph.schema, config) + config = GLTorchGraph.transform_config(config) + config = base64.b64encode( + json.dumps(config).encode("utf-8", errors="ignore") + ).decode("utf-8", errors="ignore") + handle, config, endpoints = self._grpc_client.create_learning_instance( + graph.vineyard_id, + handle, + config, + message_pb2.LearningBackend.GRAPHLEARN_TORCH, + ) + + g = GLTorchGraph(endpoints) + self._learning_instance_dict[graph.vineyard_id] = g + graph._attach_learning_instance(g) + return g + def nx(self): if not self.eager(): raise RuntimeError( @@ -1611,3 +1662,29 @@ def graphlearn(graph, nodes=None, edges=None, gen_labels=None): return graph._session.graphlearn( graph, nodes, edges, gen_labels ) # pylint: disable=protected-access + + +def graphlearn_torch( + graph, + edges, + edge_weights=None, + node_features=None, + edge_features=None, + node_labels=None, + edge_dir="out", + random_node_split=None, +): + assert graph is not None, "graph cannot be None" + assert ( + graph._session is not None + ), "The graph object is invalid" # pylint: disable=protected-access + return graph._session.graphlearn_torch( + graph, + edges, + edge_weights, + node_features, + edge_features, + node_labels, + edge_dir, + random_node_split, + ) # pylint: disable=protected-access diff --git a/python/graphscope/learning/gl_torch_examples/local.py b/python/graphscope/learning/gl_torch_examples/local.py new file mode 100644 index 000000000000..2f71ffd0cf66 --- /dev/null +++ b/python/graphscope/learning/gl_torch_examples/local.py @@ -0,0 +1,154 @@ +import time + +import torch +import torch.nn.functional as F +from ogb.nodeproppred import Evaluator +from torch_geometric.nn import GraphSAGE + +import graphscope as gs +import graphscope.learning.graphlearn_torch as glt +from graphscope.dataset import load_ogbn_arxiv +from graphscope.learning.graphlearn_torch.typing import Split + + +@torch.no_grad() +def test(model, test_loader, dataset_name): + evaluator = Evaluator(name=dataset_name) + model.eval() + xs = [] + y_true = [] + for i, batch in enumerate(test_loader): + if i == 0: + device = batch.x.device + batch.x = batch.x.to(torch.float32) # TODO + x = model(batch.x, batch.edge_index)[: batch.batch_size] + xs.append(x.cpu()) + y_true.append(batch.y[: batch.batch_size].cpu()) + del batch + + xs = [t.to(device) for t in xs] + y_true = [t.to(device) for t in y_true] + y_pred = torch.cat(xs, dim=0).argmax(dim=-1, keepdim=True) + y_true = torch.cat(y_true, dim=0).unsqueeze(-1) + test_acc = evaluator.eval( + { + "y_true": y_true, + "y_pred": y_pred, + } + )["acc"] + return test_acc + + +gs.set_option(show_log=True) + +# load the ogbn_arxiv graph as an example. +g = load_ogbn_arxiv() +glt_graph = gs.graphlearn_torch( + g, + edges=[ + ("paper", "citation", "paper"), + ], + node_features={ + "paper": [f"feat_{i}" for i in range(128)], + }, + node_labels={ + "paper": "label", + }, + edge_dir="out", + random_node_split={ + "num_val": 0.1, + "num_test": 0.1, + }, +) + +print("-- Initializing client ...") +glt.distributed.init_client( + num_servers=1, + num_clients=1, + client_rank=0, + master_addr=glt_graph.master_addr, + master_port=glt_graph.server_client_master_port, + num_rpc_threads=4, + is_dynamic=True, +) + + +device = torch.device("cpu") +# Create distributed neighbor loader on remote server for training. +print("-- Creating training dataloader ...") +train_loader = glt.distributed.DistNeighborLoader( + data=None, + num_neighbors=[15, 10, 5], + input_nodes=Split.train, + batch_size=512, + shuffle=True, + collect_features=True, + to_device=device, + worker_options=glt.distributed.RemoteDistSamplingWorkerOptions( + num_workers=1, + worker_devices=[torch.device("cpu")], + worker_concurrency=1, + buffer_size="1GB", + prefetch_size=1, + glt_graph=glt_graph, + workload_type="train", + ), +) + +# Create distributed neighbor loader on remote server for testing. +print("-- Creating testing dataloader ...") +test_loader = glt.distributed.DistNeighborLoader( + data=None, + num_neighbors=[15, 10, 5], + input_nodes=Split.test, + batch_size=512, + shuffle=False, + collect_features=True, + to_device=device, + worker_options=glt.distributed.RemoteDistSamplingWorkerOptions( + num_workers=1, + worker_devices=[torch.device("cpu")], + worker_concurrency=1, + buffer_size="1GB", + prefetch_size=1, + glt_graph=glt_graph, + workload_type="test", + ), +) + +# Define model and optimizer. +print("-- Initializing model and optimizer ...") +model = GraphSAGE( + in_channels=128, + hidden_channels=256, + num_layers=3, + out_channels=47, +).to(device) +optimizer = torch.optim.Adam(model.parameters(), lr=0.01) + +# Train and test. +print("-- Start training and testing ...") +epochs = 10 +dataset_name = "ogbn-arxiv" +for epoch in range(0, epochs): + model.train() + start = time.time() + for batch in train_loader: + optimizer.zero_grad() + batch.x = batch.x.to(torch.float32) # TODO + out = model(batch.x, batch.edge_index)[: batch.batch_size].log_softmax(dim=-1) + loss = F.nll_loss(out, batch.y[: batch.batch_size]) + loss.backward() + optimizer.step() + + end = time.time() + print(f"-- Epoch: {epoch:03d}, Loss: {loss:.4f}, Epoch Time: {end - start}") + # Test accuracy. + if epoch == 0 or epoch > (epochs // 2): + test_acc = test(model, test_loader, dataset_name) + print(f"-- Test Accuracy: {test_acc:.4f}") + +print("-- Shutdowning ...") +glt.distributed.shutdown_client() + +print("-- Exited ...") diff --git a/python/graphscope/learning/gl_torch_graph.py b/python/graphscope/learning/gl_torch_graph.py new file mode 100644 index 000000000000..f7b1d0191e6a --- /dev/null +++ b/python/graphscope/learning/gl_torch_graph.py @@ -0,0 +1,179 @@ +class GLTorchGraph(object): + def __init__(self, server_list): + assert len(server_list) == 4 + self._master_addr, self._server_client_master_port = server_list[0].split(":") + self._train_master_addr, self._train_loader_master_port = server_list[1].split( + ":" + ) + self._val_master_addr, self._val_loader_master_port = server_list[2].split(":") + self._test_master_addr, self._test_loader_master_port = server_list[3].split( + ":" + ) + assert ( + self._master_addr + == self._train_master_addr + == self._val_master_addr + == self._test_master_addr + ) + + @property + def master_addr(self): + return self._master_addr + + @property + def server_client_master_port(self): + return self._server_client_master_port + + @property + def train_loader_master_port(self): + return self._train_loader_master_port + + @property + def val_loader_master_port(self): + return self._val_loader_master_port + + @property + def test_loader_master_port(self): + return self._test_loader_master_port + + @staticmethod + def check_edge(schema, edge): + if not isinstance(edge, tuple) or len(edge) != 3: + raise ValueError("Each edge should be a tuple of length 3") + for vertex_label in [edge[0], edge[2]]: + if vertex_label not in schema.vertex_labels: + raise ValueError(f"Invalid edge label: {vertex_label}") + if edge[1] not in schema.edge_labels: + raise ValueError(f"Invalid edge label: {edge[1]}") + + @staticmethod + def check_edges(schema, edges): + for edge in edges: + GLTorchGraph.check_edge(schema, edge) + if edge in edges[edges.index(edge) + 1 :]: + raise ValueError(f"Duplicated edge: {edge}") + + @staticmethod + def check_features(feature_names, properties): + data_type = None + property_name = "" + property_dict = {property.name: property for property in properties} + for feature in feature_names: + if feature not in property_dict: + raise ValueError(f"Feature '{feature}' does not exist") + property = property_dict[feature] + if data_type is None: + data_type = property.data_type + property_name = property.name + if data_type != property.data_type: + raise ValueError( + f"Inconsistent DataType: '{data_type}' for {property_name} \ + and '{property.data_type}' for {property.name}" + ) + + @staticmethod + def check_node_features(schema, node_features): + if node_features is None: + return + for label, feature_names in node_features.items(): + if label not in schema.vertex_labels: + raise ValueError(f"Invalid vertex label: {label}") + GLTorchGraph.check_features( + feature_names, schema.get_vertex_properties(label) + ) + + @staticmethod + def check_edge_features(schema, edge_features): + if edge_features is None: + return + for edge, feature_names in edge_features.items(): + GLTorchGraph.check_edge(edge) + GLTorchGraph.check_features( + feature_names, schema.get_edge_properties(edge[1]) + ) + + @staticmethod + def check_node_labels(schema, node_labels): + if node_labels is None: + return + for label, property_name in node_labels.items(): + if label not in schema.vertex_labels: + raise ValueError(f"Invalid vertex label: {label}") + vertex_property_names = [ + property.name for property in schema.get_vertex_properties(label) + ] + if property_name not in vertex_property_names: + raise ValueError( + f"Invalid property name '{property_name}' for vertex label '{label}'" + ) + + @staticmethod + def check_edge_weights(schema, edge_weights): + if edge_weights is None: + return + for edge, property_name in edge_weights.items(): + GLTorchGraph.check_edge(edge) + edge_property_names = [ + property.name for property in schema.get_edge_properties(edge[1]) + ] + if property_name not in edge_property_names: + raise ValueError( + f"Invalid property name '{property_name}' for edge '{edge}'" + ) + + @staticmethod + def check_random_node_split(random_node_split): + if random_node_split is None: + return + if not isinstance(random_node_split, dict): + raise ValueError("Random node split should be a dictionary") + if "num_val" not in random_node_split: + raise ValueError("Missing 'num_val' in random node split") + if "num_test" not in random_node_split: + raise ValueError("Missing 'num_test' in random node split") + if len(random_node_split) != 2: + raise ValueError("Invalid parameters in random node split") + + @staticmethod + def check_params(schema, config): + GLTorchGraph.check_edges(schema, config.get("edges")) + GLTorchGraph.check_node_features(schema, config.get("node_features")) + GLTorchGraph.check_edge_features(schema, config.get("edge_features")) + GLTorchGraph.check_node_labels(schema, config.get("node_labels")) + GLTorchGraph.check_random_node_split(config.get("random_node_split")) + GLTorchGraph.check_edge_weights(schema, config.get("edge_weights")) + + @staticmethod + def transform_config(config): + # transform config to a format that is compatible with json dumps and loads + transformed_config = config.copy() + transformed_config["edges"] = [ + [node for node in edge] for edge in config["edges"] + ] + if config["edge_weights"]: + transformed_config["edge_weights"] = { + config["edges"].index(edge): weights + for edge, weights in config["edge_weights"].items() + } + if config["edge_features"]: + transformed_config["edge_features"] = { + config["edges"].index(edge): features + for edge, features in config["edge_features"].items() + } + return transformed_config + + @staticmethod + def reverse_transform_config(config): + reversed_config = config.copy() + reversed_config["edges"] = [tuple(edge) for edge in config["edges"]] + if config["edge_weights"]: + reversed_config["edge_weights"] = { + reversed_config["edges"][int(index)]: weights + for index, weights in config["edge_weights"].items() + } + if config["edge_features"]: + reversed_config["edge_features"] = { + reversed_config["edges"][int(index)]: features + for index, features in config["edge_features"].items() + } + return reversed_config diff --git a/python/graphscope/learning/graphlearn_torch b/python/graphscope/learning/graphlearn_torch new file mode 120000 index 000000000000..2c5e552172b8 --- /dev/null +++ b/python/graphscope/learning/graphlearn_torch @@ -0,0 +1 @@ +../../../learning_engine/graphlearn-for-pytorch/graphlearn_torch/python \ No newline at end of file diff --git a/python/requirements.txt b/python/requirements.txt index aeecca8269b6..270f49ee39ae 100644 --- a/python/requirements.txt +++ b/python/requirements.txt @@ -24,3 +24,4 @@ tqdm click vineyard>=0.16.3;sys_platform!="win32" simple-parsing +torch==1.13 diff --git a/python/setup.cfg b/python/setup.cfg index 35d63de4b6a3..c8711af74a18 100644 --- a/python/setup.cfg +++ b/python/setup.cfg @@ -16,7 +16,7 @@ upload_dir = docs/_build/html ensure_newline_before_comments = True line_length = 88 force_single_line = True -skip = build/,dist/,jupyter/graphscope/node_modules/,graphscope/learning/examples/,graphscope/learning/graphlearn/ +skip = build/,dist/,jupyter/graphscope/node_modules/,graphscope/learning/examples/,graphscope/learning/graphlearn/,graphscope/learning/graphlearn_torch/ skip_glob = *_pb2.py,*_pb2_grpc.py [flake8] @@ -41,6 +41,7 @@ extend-exclude = graphscope/proto/** graphscope/learning/examples/** graphscope/learning/graphlearn/** + graphscope/learning/graphlearn_torch/** jupyter/graphscope/node_modules/.* per-file-ignores = graphscope/nx/classes/function.py:F405 diff --git a/python/setup.py b/python/setup.py index 1680934ac70f..36aee891f21b 100644 --- a/python/setup.py +++ b/python/setup.py @@ -18,6 +18,7 @@ import os import platform +import shutil import site import subprocess import sys @@ -32,6 +33,12 @@ from setuptools.command.sdist import sdist from wheel.bdist_wheel import bdist_wheel +try: + import torch + import torch.utils.cpp_extension +except ImportError: + torch = None + # Enables --editable install with --user # https://github.com/pypa/pip/issues/7953 site.ENABLE_USER_SITE = "--user" in sys.argv[1:] @@ -45,6 +52,15 @@ else: os.environ["ARCHFLAGS"] = "-arch x86_64" +GL_EXT_NAME = "graphscope.learning.graphlearn.pywrap_graphlearn" +GLTORCH_EXT_NAME = "graphscope.learning.graphlearn_torch.py_graphlearn_torch" +GLTORCH_V6D_EXT_NAME = ( + "graphscope.learning.graphlearn_torch.py_graphlearn_torch_vineyard" +) +glt_root_path = os.path.abspath( + os.path.join(pkg_root, "..", "learning_engine", "graphlearn-for-pytorch") +) + class BuildProto(Command): description = "build protobuf file" @@ -111,12 +127,56 @@ def run(self): build_py.run(self) -class CustomBuildExt(build_ext): +class BuildGLExt(build_ext): def run(self): + self.extensions = [ext for ext in self.extensions if ext.name == GL_EXT_NAME] self.run_command("build_proto") build_ext.run(self) +class BuildGLTorchExt(torch.utils.cpp_extension.BuildExtension if torch else build_ext): + def run(self): + assert ( + torch + ), "Building graphlearn-torch extension requires installing pytorch first. Let WITH_GLTORCH=OFF if you don't need it." + self.extensions = [ + ext + for ext in self.extensions + if ext.name in [GLTORCH_EXT_NAME, GLTORCH_V6D_EXT_NAME] + ] + torch.utils.cpp_extension.BuildExtension.run(self) + + def _get_gcc_use_cxx_abi(self): + if hasattr(self, "_gcc_use_cxx_abi"): + return self._gcc_use_cxx_abi + build_dir = os.path.join(glt_root_path, "cmake-build") + os.makedirs(build_dir, exist_ok=True) + output = subprocess.run( + [shutil.which("cmake"), ".."], + cwd=build_dir, + capture_output=True, + text=True, + ) + import re + + match = re.search(r"GCC_USE_CXX11_ABI: (\d)", str(output)) + if match: + self._gcc_use_cxx_abi = match.group(1) + else: + return None + + return self._gcc_use_cxx_abi + + def _add_gnu_cpp_abi_flag(self, extension): + gcc_use_cxx_abi = ( + self._get_gcc_use_cxx_abi() + if extension.name == GLTORCH_V6D_EXT_NAME + else str(int(torch._C._GLIBCXX_USE_CXX11_ABI)) + ) + print(f"GCC_USE_CXX11_ABI for {extension.name}: {gcc_use_cxx_abi}") + self._add_compile_flag(extension, "-D_GLIBCXX_USE_CXX11_ABI=" + gcc_use_cxx_abi) + + class CustomDevelop(develop): def run(self): develop.run(self) @@ -201,6 +261,32 @@ def parsed_package_data(): def build_learning_engine(): + ext_modules = [graphlearn_ext()] + if torch and os.path.exists(os.path.join(glt_root_path, "graphlearn_torch")): + sys.path.append( + os.path.join(glt_root_path, "graphlearn_torch", "python", "utils") + ) + from build import glt_ext_module + from build import glt_v6d_ext_module + + ext_modules.append( + glt_ext_module( + name=GLTORCH_EXT_NAME, + root_path=glt_root_path, + with_cuda=False, + release=False, + ) + ) + ext_modules.append( + glt_v6d_ext_module( + name=GLTORCH_V6D_EXT_NAME, + root_path=glt_root_path, + ) + ) + return ext_modules + + +def graphlearn_ext(): import numpy ROOT_PATH = os.path.abspath( @@ -252,8 +338,9 @@ def build_learning_engine(): # KNN not enabled # ROOT_PATH + "/graphlearn/python/c/py_contrib.cc", ] - ext = Extension( - "graphscope.learning.graphlearn.pywrap_graphlearn", + + return Extension( + GL_EXT_NAME, sources, extra_compile_args=extra_compile_args, extra_link_args=extra_link_args, @@ -261,7 +348,6 @@ def build_learning_engine(): library_dirs=library_dirs, libraries=libraries, ) - return [ext] def parse_version(root, **kwargs): @@ -318,7 +404,8 @@ def parse_version(root, **kwargs): package_data=parsed_package_data(), ext_modules=build_learning_engine(), cmdclass={ - "build_ext": CustomBuildExt, + "build_ext": BuildGLExt, + "build_gltorch_ext": BuildGLTorchExt, "build_proto": BuildProto, "build_py": CustomBuildPy, "bdist_wheel": CustomBDistWheel,