From 348c9636399535c566d20e8ebff2b7aa0775f136 Mon Sep 17 00:00:00 2001
From: Burcin Bozkaya
Date: Thu, 6 Jul 2023 13:12:01 -0400
Subject: [PATCH] adding unit test for end-to-end example (#669)

* adding unit test for multi-gpu example
* added test for notebook 03
* fixed formatting
* update
* update
* Update 01-ETL-with-NVTabular.ipynb

Day of week is between 0 and 6; it must be scaled with a max value of 6 to produce correct values in the 0-1 range. If we do col+1 and scale with 7, then a section of the 0-2pi range (for Sine purposes) will not be represented.

* Update 01-ETL-with-NVTabular.ipynb

Reversed the previous edit for weekday scaling. It is correct that it should be scaled between 0-7, because day 0 (unused/not applicable after the +1 is added) overlaps with day 7 for Sine purposes. Monday should scale to 1/7 and Sunday to 7/7 to achieve an even distribution of days along the sine curve.

* reduce num_rows
* Update test_end_to_end_session_based.py
* Update 01-ETL-with-NVTabular.ipynb
* updated test script and notebook
* updated file
* removed nb3 test due to multi-gpu freezing issue
* revised notebooks, added back nb3 test
* fixed test file with black
* update test py
* update test py
* Use `python -m torch.distributed.run` instead of `torchrun`

The `torchrun` script installed in the system is a python script with a shebang line starting with `#!/usr/bin/python3`. This picks up the wrong version of python when running in a virtualenv like our tox test environment. If instead this were `#!/usr/bin/env python3`, calling `torchrun` would work in a tox environment. However, until either the pytorch package is updated for this to happen or we update our CI image accordingly, running the python command directly is more reliable.

---------

Co-authored-by: rnyak
Co-authored-by: edknv <109497216+edknv@users.noreply.github.com>
Co-authored-by: rnyak <16246900+rnyak@users.noreply.github.com>
Co-authored-by: Oliver Holworthy <1216955+oliverholworthy@users.noreply.github.com>
---
 .../01-ETL-with-NVTabular.ipynb               | 160 +++++--
 ...end-session-based-with-Yoochoose-PyT.ipynb | 432 +++++++-----------
 ...ased-Yoochoose-multigpu-training-PyT.ipynb | 283 +++++++++++-
 .../test_end_to_end_session_based.py          |  85 ++++
 4 files changed, 631 insertions(+), 329 deletions(-)
 create mode 100644 tests/integration/notebooks/test_end_to_end_session_based.py

diff --git a/examples/end-to-end-session-based/01-ETL-with-NVTabular.ipynb b/examples/end-to-end-session-based/01-ETL-with-NVTabular.ipynb
index 6168597d34..bf127e8146 100644
--- a/examples/end-to-end-session-based/01-ETL-with-NVTabular.ipynb
+++ b/examples/end-to-end-session-based/01-ETL-with-NVTabular.ipynb
@@ -61,6 +61,8 @@
     "\n",
     "The dataset is available on [Kaggle](https://www.kaggle.com/chadgostopp/recsys-challenge-2015). You need to download it and copy to the `DATA_FOLDER` path. Note that we are only using the `yoochoose-clicks.dat` file.\n",
     "\n",
+    "Alternatively, you can generate a synthetic dataset with the same columns and dtypes as the `YOOCHOOSE` dataset and a default date range of 5 days. 
If the environment variable `USE_SYNTHETIC` is set to `True`, the code below will execute the function `generate_synthetic_data` and the rest of the notebook will run on a synthetic dataset.\n", + "\n", "First, let's start by importing several libraries:" ] }, @@ -75,9 +77,7 @@ "output_type": "stream", "text": [ "/usr/local/lib/python3.8/dist-packages/merlin/dtypes/mappings/tf.py:52: UserWarning: Tensorflow dtype mappings did not load successfully due to an error: No module named 'tensorflow'\n", - " warn(f\"Tensorflow dtype mappings did not load successfully due to an error: {exc.msg}\")\n", - "/usr/local/lib/python3.8/dist-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n", - " from .autonotebook import tqdm as notebook_tqdm\n" + " warn(f\"Tensorflow dtype mappings did not load successfully due to an error: {exc.msg}\")\n" ] } ], @@ -85,7 +85,10 @@ "import os\n", "import glob\n", "import numpy as np\n", + "import pandas as pd\n", "import gc\n", + "import calendar\n", + "import datetime\n", "\n", "import cudf\n", "import cupy\n", @@ -128,12 +131,14 @@ "metadata": {}, "outputs": [], "source": [ - "DATA_FOLDER = \"/workspace/data/\"\n", + "DATA_FOLDER = os.environ.get(\"DATA_FOLDER\", \"/workspace/data\")\n", "FILENAME_PATTERN = 'yoochoose-clicks.dat'\n", "DATA_PATH = os.path.join(DATA_FOLDER, FILENAME_PATTERN)\n", "\n", "OUTPUT_FOLDER = \"./yoochoose_transformed\"\n", - "OVERWRITE = False" + "OVERWRITE = False\n", + "\n", + "USE_SYNTHETIC = os.environ.get(\"USE_SYNTHETIC\", False)" ] }, { @@ -144,16 +149,89 @@ "## Load and clean raw data" ] }, + { + "cell_type": "markdown", + "id": "3fba8546-668c-4743-960e-ea2aef99ef24", + "metadata": {}, + "source": [ + "Execute the cell below if you would like to work with synthetic data. Otherwise you can skip and continue with the next cell." + ] + }, { "cell_type": "code", "execution_count": 5, + "id": "07d14289-c783-45f0-86e8-e5c1001bfd76", + "metadata": {}, + "outputs": [], + "source": [ + "def generate_synthetic_data(\n", + " start_date: datetime.date, end_date: datetime.date, rows_per_day: int = 10000\n", + ") -> pd.DataFrame:\n", + " assert end_date > start_date, \"end_date must be later than start_date\"\n", + "\n", + " number_of_days = (end_date - start_date).days\n", + " total_number_of_rows = number_of_days * rows_per_day\n", + "\n", + " # Generate a long-tail distribution of item interactions. 
This simulates that some items are\n", + " # more popular than others.\n", + " long_tailed_item_distribution = np.clip(\n", + " np.random.lognormal(3.0, 1.0, total_number_of_rows).astype(np.int64), 1, 50000\n", + " )\n", + "\n", + " # generate random item interaction features\n", + " df = pd.DataFrame(\n", + " {\n", + " \"session_id\": np.random.randint(70000, 80000, total_number_of_rows),\n", + " \"item_id\": long_tailed_item_distribution,\n", + " },\n", + " )\n", + "\n", + " # generate category mapping for each item-id\n", + " df[\"category\"] = pd.cut(df[\"item_id\"], bins=334, labels=np.arange(1, 335)).astype(\n", + " np.int64\n", + " )\n", + "\n", + " max_session_length = 60 * 60 # 1 hour\n", + "\n", + " def add_timestamp_to_session(session: pd.DataFrame):\n", + " random_start_date_and_time = calendar.timegm(\n", + " (\n", + " start_date\n", + " # Add day offset from start_date\n", + " + datetime.timedelta(days=np.random.randint(0, number_of_days))\n", + " # Add time offset within the random day\n", + " + datetime.timedelta(seconds=np.random.randint(0, 86_400))\n", + " ).timetuple()\n", + " )\n", + " session[\"timestamp\"] = random_start_date_and_time + np.clip(\n", + " np.random.lognormal(3.0, 1.0, len(session)).astype(np.int64),\n", + " 0,\n", + " max_session_length,\n", + " )\n", + " return session\n", + "\n", + " df = df.groupby(\"session_id\").apply(add_timestamp_to_session).reset_index()\n", + "\n", + " return df" + ] + }, + { + "cell_type": "code", + "execution_count": 6, "id": "f35dff52", "metadata": {}, "outputs": [], "source": [ - "interactions_df = cudf.read_csv(DATA_PATH, sep=',', \n", - " names=['session_id','timestamp', 'item_id', 'category'], \n", - " dtype=['int', 'datetime64[s]', 'int', 'int'])" + "if USE_SYNTHETIC:\n", + " START_DATE = os.environ.get(\"START_DATE\", \"2014/4/1\")\n", + " END_DATE = os.environ.get(\"END_DATE\", \"2014/4/5\")\n", + " interactions_df = generate_synthetic_data(datetime.datetime.strptime(START_DATE, '%Y/%m/%d'),\n", + " datetime.datetime.strptime(END_DATE, '%Y/%m/%d'))\n", + " interactions_df = cudf.from_pandas(interactions_df)\n", + "else:\n", + " interactions_df = cudf.read_csv(DATA_PATH, sep=',', \n", + " names=['session_id','timestamp', 'item_id', 'category'], \n", + " dtype=['int', 'datetime64[s]', 'int', 'int'])" ] }, { @@ -166,7 +244,7 @@ }, { "cell_type": "code", - "execution_count": 6, + "execution_count": 7, "id": "22c2df72", "metadata": {}, "outputs": [ @@ -181,13 +259,16 @@ ], "source": [ "print(\"Count with in-session repeated interactions: {}\".format(len(interactions_df)))\n", + "\n", "# Sorts the dataframe by session and timestamp, to remove consecutive repetitions\n", "interactions_df.timestamp = interactions_df.timestamp.astype(int)\n", "interactions_df = interactions_df.sort_values(['session_id', 'timestamp'])\n", "past_ids = interactions_df['item_id'].shift(1).fillna()\n", "session_past_ids = interactions_df['session_id'].shift(1).fillna()\n", + "\n", "# Keeping only no consecutive repeated in session interactions\n", "interactions_df = interactions_df[~((interactions_df['session_id'] == session_past_ids) & (interactions_df['item_id'] == past_ids))]\n", + "\n", "print(\"Count after removed in-session repeated interactions: {}\".format(len(interactions_df)))" ] }, @@ -201,7 +282,7 @@ }, { "cell_type": "code", - "execution_count": 7, + "execution_count": 8, "id": "66a1bd13", "metadata": {}, "outputs": [ @@ -234,17 +315,19 @@ }, { "cell_type": "code", - "execution_count": 8, + "execution_count": 9, "id": "a0f908a1", 
"metadata": {}, "outputs": [], "source": [ + "if os.path.isdir(DATA_FOLDER) == False:\n", + " os.mkdir(DATA_FOLDER)\n", "interactions_merged_df.to_parquet(os.path.join(DATA_FOLDER, 'interactions_merged_df.parquet'))" ] }, { "cell_type": "code", - "execution_count": 9, + "execution_count": 10, "id": "909f87c5-bff5-48c8-b714-cc556a4bc64d", "metadata": { "tags": [] @@ -265,17 +348,17 @@ }, { "cell_type": "code", - "execution_count": 10, + "execution_count": 11, "id": "04a3b5b7", "metadata": {}, "outputs": [ { "data": { "text/plain": [ - "0" + "517" ] }, - "execution_count": 10, + "execution_count": 11, "metadata": {}, "output_type": "execute_result" } @@ -330,7 +413,7 @@ }, { "cell_type": "code", - "execution_count": 11, + "execution_count": 13, "id": "86f80068", "metadata": {}, "outputs": [], @@ -425,7 +508,7 @@ }, { "cell_type": "code", - "execution_count": 12, + "execution_count": 14, "id": "10b5c96c", "metadata": {}, "outputs": [], @@ -447,7 +530,6 @@ "# Truncate sequence features to first interacted 20 items \n", "SESSIONS_MAX_LENGTH = 20 \n", "\n", - "\n", "item_feat = groupby_features['item_id-list'] >> nvt.ops.TagAsItemID()\n", "cont_feats = groupby_features['et_dayofweek_sin-list', 'product_recency_days_log_norm-list'] >> nvt.ops.AddMetadata(tags=[Tags.CONTINUOUS])\n", "\n", @@ -491,7 +573,7 @@ }, { "cell_type": "code", - "execution_count": 13, + "execution_count": 15, "id": "45803886", "metadata": {}, "outputs": [], @@ -513,7 +595,7 @@ }, { "cell_type": "code", - "execution_count": 14, + "execution_count": 16, "id": "4c10efb5-89c5-4458-a634-475eb459a47c", "metadata": { "tags": [] @@ -600,7 +682,7 @@ " \n", " 2\n", " item_id-list\n", - " (Tags.CATEGORICAL, Tags.ITEM, Tags.ID, Tags.LIST)\n", + " (Tags.CATEGORICAL, Tags.ID, Tags.LIST, Tags.ITEM)\n", " DType(name='int64', element_type=<ElementType....\n", " True\n", " True\n", @@ -697,10 +779,10 @@ "" ], "text/plain": [ - "[{'name': 'session_id', 'tags': {}, 'properties': {}, 'dtype': DType(name='int64', element_type=, element_size=64, element_unit=None, signed=True, shape=Shape(dims=(Dimension(min=0, max=None),))), 'is_list': False, 'is_ragged': False}, {'name': 'item_id-count', 'tags': {}, 'properties': {'num_buckets': None, 'freq_threshold': 0, 'max_size': 0, 'cat_path': './/categories/unique.item_id.parquet', 'domain': {'min': 0, 'max': 52741, 'name': 'item_id'}, 'embedding_sizes': {'cardinality': 52742, 'dimension': 512}}, 'dtype': DType(name='int32', element_type=, element_size=32, element_unit=None, signed=True, shape=Shape(dims=(Dimension(min=0, max=None),))), 'is_list': False, 'is_ragged': False}, {'name': 'item_id-list', 'tags': {, , , }, 'properties': {'num_buckets': None, 'freq_threshold': 0, 'max_size': 0, 'cat_path': './/categories/unique.item_id.parquet', 'domain': {'min': 0, 'max': 52741, 'name': 'item_id'}, 'embedding_sizes': {'cardinality': 52742, 'dimension': 512}, 'value_count': {'min': 0, 'max': 20}}, 'dtype': DType(name='int64', element_type=, element_size=64, element_unit=None, signed=True, shape=Shape(dims=(Dimension(min=0, max=None), Dimension(min=0, max=20)))), 'is_list': True, 'is_ragged': True}, {'name': 'et_dayofweek_sin-list', 'tags': {, }, 'properties': {'value_count': {'min': 0, 'max': 20}}, 'dtype': DType(name='float64', element_type=, element_size=64, element_unit=None, signed=True, shape=Shape(dims=(Dimension(min=0, max=None), Dimension(min=0, max=20)))), 'is_list': True, 'is_ragged': True}, {'name': 'product_recency_days_log_norm-list', 'tags': {, }, 'properties': {'value_count': {'min': 0, 'max': 
20}}, 'dtype': DType(name='float32', element_type=, element_size=32, element_unit=None, signed=True, shape=Shape(dims=(Dimension(min=0, max=None), Dimension(min=0, max=20)))), 'is_list': True, 'is_ragged': True}, {'name': 'category-list', 'tags': {, }, 'properties': {'num_buckets': None, 'freq_threshold': 0, 'max_size': 0, 'cat_path': './/categories/unique.category.parquet', 'domain': {'min': 0, 'max': 336, 'name': 'category'}, 'embedding_sizes': {'cardinality': 337, 'dimension': 42}, 'value_count': {'min': 0, 'max': 20}}, 'dtype': DType(name='int64', element_type=, element_size=64, element_unit=None, signed=True, shape=Shape(dims=(Dimension(min=0, max=None), Dimension(min=0, max=20)))), 'is_list': True, 'is_ragged': True}, {'name': 'day_index', 'tags': {}, 'properties': {}, 'dtype': DType(name='int64', element_type=, element_size=64, element_unit=None, signed=True, shape=Shape(dims=(Dimension(min=0, max=None),))), 'is_list': False, 'is_ragged': False}]" + "[{'name': 'session_id', 'tags': {}, 'properties': {}, 'dtype': DType(name='int64', element_type=, element_size=64, element_unit=None, signed=True, shape=Shape(dims=(Dimension(min=0, max=None),))), 'is_list': False, 'is_ragged': False}, {'name': 'item_id-count', 'tags': {}, 'properties': {'num_buckets': None, 'freq_threshold': 0, 'max_size': 0, 'cat_path': './/categories/unique.item_id.parquet', 'domain': {'min': 0, 'max': 52741, 'name': 'item_id'}, 'embedding_sizes': {'cardinality': 52742, 'dimension': 512}}, 'dtype': DType(name='int32', element_type=, element_size=32, element_unit=None, signed=True, shape=Shape(dims=(Dimension(min=0, max=None),))), 'is_list': False, 'is_ragged': False}, {'name': 'item_id-list', 'tags': {, , , }, 'properties': {'num_buckets': None, 'freq_threshold': 0, 'max_size': 0, 'cat_path': './/categories/unique.item_id.parquet', 'domain': {'min': 0, 'max': 52741, 'name': 'item_id'}, 'embedding_sizes': {'cardinality': 52742, 'dimension': 512}, 'value_count': {'min': 0, 'max': 20}}, 'dtype': DType(name='int64', element_type=, element_size=64, element_unit=None, signed=True, shape=Shape(dims=(Dimension(min=0, max=None), Dimension(min=0, max=20)))), 'is_list': True, 'is_ragged': True}, {'name': 'et_dayofweek_sin-list', 'tags': {, }, 'properties': {'value_count': {'min': 0, 'max': 20}}, 'dtype': DType(name='float64', element_type=, element_size=64, element_unit=None, signed=True, shape=Shape(dims=(Dimension(min=0, max=None), Dimension(min=0, max=20)))), 'is_list': True, 'is_ragged': True}, {'name': 'product_recency_days_log_norm-list', 'tags': {, }, 'properties': {'value_count': {'min': 0, 'max': 20}}, 'dtype': DType(name='float32', element_type=, element_size=32, element_unit=None, signed=True, shape=Shape(dims=(Dimension(min=0, max=None), Dimension(min=0, max=20)))), 'is_list': True, 'is_ragged': True}, {'name': 'category-list', 'tags': {, }, 'properties': {'num_buckets': None, 'freq_threshold': 0, 'max_size': 0, 'cat_path': './/categories/unique.category.parquet', 'domain': {'min': 0, 'max': 336, 'name': 'category'}, 'embedding_sizes': {'cardinality': 337, 'dimension': 42}, 'value_count': {'min': 0, 'max': 20}}, 'dtype': DType(name='int64', element_type=, element_size=64, element_unit=None, signed=True, shape=Shape(dims=(Dimension(min=0, max=None), Dimension(min=0, max=20)))), 'is_list': True, 'is_ragged': True}, {'name': 'day_index', 'tags': {}, 'properties': {}, 'dtype': DType(name='int64', element_type=, element_size=64, element_unit=None, signed=True, shape=Shape(dims=(Dimension(min=0, max=None),))), 'is_list': 
False, 'is_ragged': False}]" ] }, - "execution_count": 14, + "execution_count": 16, "metadata": {}, "output_type": "execute_result" } @@ -719,7 +801,7 @@ }, { "cell_type": "code", - "execution_count": 15, + "execution_count": 17, "id": "2d035a88-2146-4b9a-96fd-dd42be86e2a1", "metadata": {}, "outputs": [], @@ -747,19 +829,23 @@ }, { "cell_type": "code", - "execution_count": 16, + "execution_count": 18, "id": "2b4f5b73-459c-4356-87c8-9afb974cc77d", "metadata": {}, "outputs": [], "source": [ "# read in the processed train dataset\n", "sessions_gdf = cudf.read_parquet(os.path.join(DATA_FOLDER, \"processed_nvt/part_0.parquet\"))\n", - "sessions_gdf = sessions_gdf[sessions_gdf.day_index>=178]" + "if USE_SYNTHETIC:\n", + " THRESHOLD_DAY_INDEX = int(os.environ.get(\"THRESHOLD_DAY_INDEX\", '1'))\n", + " sessions_gdf = sessions_gdf[sessions_gdf.day_index>=THRESHOLD_DAY_INDEX]\n", + "else:\n", + " sessions_gdf = sessions_gdf[sessions_gdf.day_index>=178]" ] }, { "cell_type": "code", - "execution_count": 17, + "execution_count": 19, "id": "e18d9c63", "metadata": {}, "outputs": [ @@ -783,13 +869,13 @@ "6606149 [-0.7818309228245777, -0.7818309228245777] \n", "\n", " product_recency_days_log_norm-list \\\n", - "6606147 [1.5241553, 1.5238751, 1.5239341, 1.5241631, 1... \n", - "6606148 [-0.5330064, 1.521494] \n", - "6606149 [1.5338266, 1.5355074] \n", + "6606147 [1.5241561, 1.523876, 1.523935, 1.5241641, 1.5... \n", + "6606148 [-0.533007, 1.521495] \n", + "6606149 [1.5338274, 1.5355083] \n", "\n", " category-list day_index \n", - "6606147 [4, 4, 4, 4, 4, 4, 4, 4, 4, 3, 4, 4] 178 \n", - "6606148 [3, 3] 178 \n", + "6606147 [4, 4, 4, 4, 4, 4, 4, 4, 4, 1, 4, 4] 178 \n", + "6606148 [1, 3] 178 \n", "6606149 [8, 8] 180 \n" ] } @@ -800,7 +886,7 @@ }, { "cell_type": "code", - "execution_count": 18, + "execution_count": 20, "id": "5175aeaf", "metadata": {}, "outputs": [ @@ -808,7 +894,7 @@ "name": "stderr", "output_type": "stream", "text": [ - "Creating time-based splits: 100%|██████████| 5/5 [00:02<00:00, 2.37it/s]\n" + "Creating time-based splits: 100%|██████████| 5/5 [00:02<00:00, 2.24it/s]\n" ] } ], @@ -823,17 +909,17 @@ }, { "cell_type": "code", - "execution_count": 19, + "execution_count": 21, "id": "3bd1bad9", "metadata": {}, "outputs": [ { "data": { "text/plain": [ - "583" + "748" ] }, - "execution_count": 19, + "execution_count": 21, "metadata": {}, "output_type": "execute_result" } diff --git a/examples/end-to-end-session-based/02-End-to-end-session-based-with-Yoochoose-PyT.ipynb b/examples/end-to-end-session-based/02-End-to-end-session-based-with-Yoochoose-PyT.ipynb index 1b83844573..18a4affd14 100644 --- a/examples/end-to-end-session-based/02-End-to-end-session-based-with-Yoochoose-PyT.ipynb +++ b/examples/end-to-end-session-based/02-End-to-end-session-based-with-Yoochoose-PyT.ipynb @@ -98,6 +98,7 @@ "outputs": [], "source": [ "import os\n", + "os.environ[\"CUDA_VISIBLE_DEVICES\"]=\"0\"\n", "INPUT_DATA_DIR = os.environ.get(\"INPUT_DATA_DIR\", \"/workspace/data\")\n", "OUTPUT_DIR = os.environ.get(\"OUTPUT_DIR\", f\"{INPUT_DATA_DIR}/preproc_sessions_by_day\")" ] @@ -113,11 +114,7 @@ "output_type": "stream", "text": [ "/usr/local/lib/python3.8/dist-packages/merlin/dtypes/mappings/tf.py:52: UserWarning: Tensorflow dtype mappings did not load successfully due to an error: No module named 'tensorflow'\n", - " warn(f\"Tensorflow dtype mappings did not load successfully due to an error: {exc.msg}\")\n", - "/usr/local/lib/python3.8/dist-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. 
Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n", - " from .autonotebook import tqdm as notebook_tqdm\n", - "/usr/local/lib/python3.8/dist-packages/merlin/schema/tags.py:149: UserWarning: Compound tags like Tags.ITEM_ID have been deprecated and will be removed in a future version. Please use the atomic versions of these tags, like [, ].\n", - " warnings.warn(\n" + " warn(f\"Tensorflow dtype mappings did not load successfully due to an error: {exc.msg}\")\n" ] } ], @@ -194,7 +191,6 @@ " properties.num_buckets\n", " properties.freq_threshold\n", " properties.max_size\n", - " properties.start_index\n", " properties.cat_path\n", " properties.embedding_sizes.cardinality\n", " properties.embedding_sizes.dimension\n", @@ -209,19 +205,18 @@ " \n", " 0\n", " item_id-list\n", - " (Tags.ITEM_ID, Tags.CATEGORICAL, Tags.ID, Tags...\n", + " (Tags.CATEGORICAL, Tags.LIST, Tags.ID, Tags.ITEM)\n", " DType(name='int64', element_type=<ElementType....\n", " True\n", " True\n", " NaN\n", " 0.0\n", " 0.0\n", - " 1.0\n", " .//categories/unique.item_id.parquet\n", - " 52741.0\n", + " 52742.0\n", " 512.0\n", " 0.0\n", - " 52740.0\n", + " 52741.0\n", " item_id\n", " 0\n", " 20\n", @@ -236,12 +231,11 @@ " NaN\n", " 0.0\n", " 0.0\n", - " 1.0\n", " .//categories/unique.category.parquet\n", - " 336.0\n", + " 337.0\n", " 42.0\n", " 0.0\n", - " 335.0\n", + " 336.0\n", " category\n", " 0\n", " 20\n", @@ -262,7 +256,6 @@ " NaN\n", " NaN\n", " NaN\n", - " NaN\n", " 0\n", " 20\n", " \n", @@ -270,7 +263,7 @@ " 3\n", " et_dayofweek_sin-list\n", " (Tags.LIST, Tags.CONTINUOUS)\n", - " DType(name='float32', element_type=<ElementTyp...\n", + " DType(name='float64', element_type=<ElementTyp...\n", " True\n", " True\n", " NaN\n", @@ -282,7 +275,6 @@ " NaN\n", " NaN\n", " NaN\n", - " NaN\n", " 0\n", " 20\n", " \n", @@ -291,7 +283,7 @@ "" ], "text/plain": [ - "[{'name': 'item_id-list', 'tags': {, , , , }, 'properties': {'num_buckets': None, 'freq_threshold': 0.0, 'max_size': 0.0, 'start_index': 1.0, 'cat_path': './/categories/unique.item_id.parquet', 'embedding_sizes': {'cardinality': 52741.0, 'dimension': 512.0}, 'domain': {'min': 0, 'max': 52740, 'name': 'item_id'}, 'value_count': {'min': 0, 'max': 20}}, 'dtype': DType(name='int64', element_type=, element_size=64, element_unit=None, signed=True, shape=Shape(dims=(Dimension(min=0, max=None), Dimension(min=0, max=20)))), 'is_list': True, 'is_ragged': True}, {'name': 'category-list', 'tags': {, }, 'properties': {'num_buckets': None, 'freq_threshold': 0.0, 'max_size': 0.0, 'start_index': 1.0, 'cat_path': './/categories/unique.category.parquet', 'embedding_sizes': {'cardinality': 336.0, 'dimension': 42.0}, 'domain': {'min': 0, 'max': 335, 'name': 'category'}, 'value_count': {'min': 0, 'max': 20}}, 'dtype': DType(name='int64', element_type=, element_size=64, element_unit=None, signed=True, shape=Shape(dims=(Dimension(min=0, max=None), Dimension(min=0, max=20)))), 'is_list': True, 'is_ragged': True}, {'name': 'product_recency_days_log_norm-list', 'tags': {, }, 'properties': {'value_count': {'min': 0, 'max': 20}}, 'dtype': DType(name='float32', element_type=, element_size=32, element_unit=None, signed=True, shape=Shape(dims=(Dimension(min=0, max=None), Dimension(min=0, max=20)))), 'is_list': True, 'is_ragged': True}, {'name': 'et_dayofweek_sin-list', 'tags': {, }, 'properties': {'value_count': {'min': 0, 'max': 20}}, 'dtype': DType(name='float32', element_type=, element_size=32, element_unit=None, signed=True, 
shape=Shape(dims=(Dimension(min=0, max=None), Dimension(min=0, max=20)))), 'is_list': True, 'is_ragged': True}]" + "[{'name': 'item_id-list', 'tags': {, , , }, 'properties': {'num_buckets': None, 'freq_threshold': 0.0, 'max_size': 0.0, 'cat_path': './/categories/unique.item_id.parquet', 'embedding_sizes': {'cardinality': 52742.0, 'dimension': 512.0}, 'domain': {'min': 0, 'max': 52741, 'name': 'item_id'}, 'value_count': {'min': 0, 'max': 20}}, 'dtype': DType(name='int64', element_type=, element_size=64, element_unit=None, signed=True, shape=Shape(dims=(Dimension(min=0, max=None), Dimension(min=0, max=20)))), 'is_list': True, 'is_ragged': True}, {'name': 'category-list', 'tags': {, }, 'properties': {'num_buckets': None, 'freq_threshold': 0.0, 'max_size': 0.0, 'cat_path': './/categories/unique.category.parquet', 'embedding_sizes': {'cardinality': 337.0, 'dimension': 42.0}, 'domain': {'min': 0, 'max': 336, 'name': 'category'}, 'value_count': {'min': 0, 'max': 20}}, 'dtype': DType(name='int64', element_type=, element_size=64, element_unit=None, signed=True, shape=Shape(dims=(Dimension(min=0, max=None), Dimension(min=0, max=20)))), 'is_list': True, 'is_ragged': True}, {'name': 'product_recency_days_log_norm-list', 'tags': {, }, 'properties': {'value_count': {'min': 0, 'max': 20}}, 'dtype': DType(name='float32', element_type=, element_size=32, element_unit=None, signed=True, shape=Shape(dims=(Dimension(min=0, max=None), Dimension(min=0, max=20)))), 'is_list': True, 'is_ragged': True}, {'name': 'et_dayofweek_sin-list', 'tags': {, }, 'properties': {'value_count': {'min': 0, 'max': 20}}, 'dtype': DType(name='float64', element_type=, element_size=64, element_unit=None, signed=True, shape=Shape(dims=(Dimension(min=0, max=None), Dimension(min=0, max=20)))), 'is_list': True, 'is_ragged': True}]" ] }, "execution_count": 5, @@ -343,6 +335,8 @@ "name": "stderr", "output_type": "stream", "text": [ + "/usr/local/lib/python3.8/dist-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n", + " from .autonotebook import tqdm as notebook_tqdm\n", "Projecting inputs of NextItemPredictionTask to'64' As weight tying requires the input dimension '320' to be equal to the item-id embedding dimension '64'\n" ] } @@ -434,14 +428,16 @@ }, "outputs": [], "source": [ + "BATCH_SIZE_TRAIN = int(os.environ.get(\"BATCH_SIZE_TRAIN\", \"512\"))\n", + "BATCH_SIZE_VALID = int(os.environ.get(\"BATCH_SIZE_VALID\", \"256\"))\n", "training_args = tr.trainer.T4RecTrainingArguments(\n", " output_dir=\"./tmp\",\n", " max_sequence_length=20,\n", " data_loader_engine='merlin',\n", " num_train_epochs=10, \n", " dataloader_drop_last=False,\n", - " per_device_train_batch_size = 384,\n", - " per_device_eval_batch_size = 512,\n", + " per_device_train_batch_size = BATCH_SIZE_TRAIN,\n", + " per_device_eval_batch_size = BATCH_SIZE_VALID,\n", " learning_rate=0.0005,\n", " fp16=True,\n", " report_to = [],\n", @@ -494,7 +490,9 @@ "id": "709592a2", "metadata": {}, "source": [ - "In this demo, we will use the `fit_and_evaluate` method that allows us to conduct a time-based finetuning by iteratively training and evaluating using a sliding time window: At each iteration, we use the training data of a specific time index $t$ to train the model; then we evaluate on the validation data of the next index $t + 1$. Particularly, we set start time to 178 and end time to 180." 
+ "In this demo, we will use the `fit_and_evaluate` method that allows us to conduct a time-based finetuning by iteratively training and evaluating using a sliding time window: At each iteration, we use the training data of a specific time index $t$ to train the model; then we evaluate on the validation data of the next index $t + 1$. Particularly, we set start time to 178 and end time to 180.\n", + "\n", + "If you have generated a synthetic dataset in the previous notebook, remember to change the values 178 and 180 accordingly." ] }, { @@ -506,24 +504,24 @@ }, "outputs": [ { - "name": "stderr", + "name": "stdout", "output_type": "stream", "text": [ - "***** Running training *****\n", - " Num examples = 28800\n", - " Num Epochs = 10\n", - " Instantaneous batch size per device = 384\n", - " Total train batch size (w. parallel, distributed & accumulation) = 384\n", - " Gradient Accumulation steps = 1\n", - " Total optimization steps = 750\n" + "\n", + "***** Launch training for day 178: *****\n" ] }, { - "name": "stdout", + "name": "stderr", "output_type": "stream", "text": [ - "\n", - "***** Launch training for day 178: *****\n" + "***** Running training *****\n", + " Num examples = 28672\n", + " Num Epochs = 10\n", + " Instantaneous batch size per device = 512\n", + " Total train batch size (w. parallel, distributed & accumulation) = 512\n", + " Gradient Accumulation steps = 1\n", + " Total optimization steps = 560\n" ] }, { @@ -532,8 +530,8 @@ "\n", "
\n", " \n", - " \n", - " [750/750 00:20, Epoch 10/10]\n", + " \n", + " [560/560 00:25, Epoch 10/10]\n", "
\n", " \n", " \n", @@ -545,15 +543,11 @@ " \n", " \n", " \n", - " \n", + " \n", " \n", " \n", " \n", - " \n", - " \n", - " \n", - " \n", - " \n", + " \n", " \n", " \n", "
2007.6811007.585600
4006.656800
6006.3720006.608700

" @@ -584,8 +578,8 @@ "\n", "

\n", " \n", - " \n", - " [6/6 00:27]\n", + " \n", + " [11/11 00:33]\n", "
\n", " " ], @@ -601,12 +595,12 @@ "output_type": "stream", "text": [ "***** Running training *****\n", - " Num examples = 20736\n", + " Num examples = 20480\n", " Num Epochs = 10\n", - " Instantaneous batch size per device = 384\n", - " Total train batch size (w. parallel, distributed & accumulation) = 384\n", + " Instantaneous batch size per device = 512\n", + " Total train batch size (w. parallel, distributed & accumulation) = 512\n", " Gradient Accumulation steps = 1\n", - " Total optimization steps = 540\n" + " Total optimization steps = 400\n" ] }, { @@ -616,12 +610,12 @@ "\n", "***** Evaluation results for day 179:*****\n", "\n", - " eval_/next-item/avg_precision@10 = 0.08119537681341171\n", - " eval_/next-item/avg_precision@20 = 0.0857219472527504\n", - " eval_/next-item/ndcg@10 = 0.11199340969324112\n", - " eval_/next-item/ndcg@20 = 0.12857995927333832\n", - " eval_/next-item/recall@10 = 0.20809248089790344\n", - " eval_/next-item/recall@20 = 0.27398842573165894\n", + " eval_/next-item/avg_precision@10 = 0.07277625054121017\n", + " eval_/next-item/avg_precision@20 = 0.077287457883358\n", + " eval_/next-item/ndcg@10 = 0.1008271649479866\n", + " eval_/next-item/ndcg@20 = 0.11763089150190353\n", + " eval_/next-item/recall@10 = 0.18959537148475647\n", + " eval_/next-item/recall@20 = 0.25549131631851196\n", "\n", "***** Launch training for day 179: *****\n" ] @@ -632,8 +626,8 @@ "\n", "
\n", " \n", - " \n", - " [540/540 00:14, Epoch 10/10]\n", + " \n", + " [400/400 00:17, Epoch 10/10]\n", "
\n", " \n", " \n", @@ -645,11 +639,11 @@ " \n", " \n", " \n", - " \n", + " \n", " \n", " \n", " \n", - " \n", + " \n", " \n", " \n", "
2006.8059006.838400
4006.2503006.304600

" @@ -665,8 +659,6 @@ "name": "stderr", "output_type": "stream", "text": [ - "Saving model checkpoint to ./tmp/checkpoint-500\n", - "Trainer.model is not a `PreTrainedModel`, only saving its state dict.\n", "\n", "\n", "Training completed. Do not forget to share your model on huggingface.co/models =)\n", @@ -675,10 +667,10 @@ "***** Running training *****\n", " Num examples = 16896\n", " Num Epochs = 10\n", - " Instantaneous batch size per device = 384\n", - " Total train batch size (w. parallel, distributed & accumulation) = 384\n", + " Instantaneous batch size per device = 512\n", + " Total train batch size (w. parallel, distributed & accumulation) = 512\n", " Gradient Accumulation steps = 1\n", - " Total optimization steps = 440\n" + " Total optimization steps = 330\n" ] }, { @@ -688,12 +680,12 @@ "\n", "***** Evaluation results for day 180:*****\n", "\n", - " eval_/next-item/avg_precision@10 = 0.06664632260799408\n", - " eval_/next-item/avg_precision@20 = 0.07125943899154663\n", - " eval_/next-item/ndcg@10 = 0.09318044036626816\n", - " eval_/next-item/ndcg@20 = 0.110211081802845\n", - " eval_/next-item/recall@10 = 0.17855477333068848\n", - " eval_/next-item/recall@20 = 0.24568764865398407\n", + " eval_/next-item/avg_precision@10 = 0.059328265488147736\n", + " eval_/next-item/avg_precision@20 = 0.06352042406797409\n", + " eval_/next-item/ndcg@10 = 0.08318208903074265\n", + " eval_/next-item/ndcg@20 = 0.09845318645238876\n", + " eval_/next-item/recall@10 = 0.16083915531635284\n", + " eval_/next-item/recall@20 = 0.2209790199995041\n", "\n", "***** Launch training for day 180: *****\n" ] @@ -704,8 +696,8 @@ "\n", "

\n", " \n", - " \n", - " [440/440 00:12, Epoch 10/10]\n", + " \n", + " [330/330 00:14, Epoch 10/10]\n", "
\n", " \n", " \n", @@ -717,11 +709,7 @@ " \n", " \n", " \n", - " \n", - " \n", - " \n", - " \n", - " \n", + " \n", " \n", " \n", "
2006.608300
4006.0305006.708000

" @@ -751,18 +739,20 @@ "\n", "***** Evaluation results for day 181:*****\n", "\n", - " eval_/next-item/avg_precision@10 = 0.13680869340896606\n", - " eval_/next-item/avg_precision@20 = 0.14374792575836182\n", - " eval_/next-item/ndcg@10 = 0.18158714473247528\n", - " eval_/next-item/ndcg@20 = 0.2070869356393814\n", - " eval_/next-item/recall@10 = 0.3181818127632141\n", - " eval_/next-item/recall@20 = 0.4202226400375366\n" + " eval_/next-item/avg_precision@10 = 0.12736327946186066\n", + " eval_/next-item/avg_precision@20 = 0.13500627875328064\n", + " eval_/next-item/ndcg@10 = 0.16738776862621307\n", + " eval_/next-item/ndcg@20 = 0.19680777192115784\n", + " eval_/next-item/recall@10 = 0.29406309127807617\n", + " eval_/next-item/recall@20 = 0.41187384724617004\n" ] } ], "source": [ "from transformers4rec.torch.utils.examples_utils import fit_and_evaluate\n", - "OT_results = fit_and_evaluate(recsys_trainer, start_time_index=178, end_time_index=180, input_dir=OUTPUT_DIR)" + "start_time_idx = int(os.environ.get(\"START_TIME_INDEX\", \"178\"))\n", + "end_time_idx = int(os.environ.get(\"END_TIME_INDEX\", \"180\"))\n", + "OT_results = fit_and_evaluate(recsys_trainer, start_time_index=start_time_idx, end_time_index=end_time_idx, input_dir=OUTPUT_DIR)" ] }, { @@ -794,24 +784,24 @@ { "data": { "text/plain": [ - "{'indexed_by_time_eval_/next-item/avg_precision@10': [0.08119537681341171,\n", - " 0.06664632260799408,\n", - " 0.13680869340896606],\n", - " 'indexed_by_time_eval_/next-item/avg_precision@20': [0.0857219472527504,\n", - " 0.07125943899154663,\n", - " 0.14374792575836182],\n", - " 'indexed_by_time_eval_/next-item/ndcg@10': [0.11199340969324112,\n", - " 0.09318044036626816,\n", - " 0.18158714473247528],\n", - " 'indexed_by_time_eval_/next-item/ndcg@20': [0.12857995927333832,\n", - " 0.110211081802845,\n", - " 0.2070869356393814],\n", - " 'indexed_by_time_eval_/next-item/recall@10': [0.20809248089790344,\n", - " 0.17855477333068848,\n", - " 0.3181818127632141],\n", - " 'indexed_by_time_eval_/next-item/recall@20': [0.27398842573165894,\n", - " 0.24568764865398407,\n", - " 0.4202226400375366]}" + "{'indexed_by_time_eval_/next-item/avg_precision@10': [0.07277625054121017,\n", + " 0.059328265488147736,\n", + " 0.12736327946186066],\n", + " 'indexed_by_time_eval_/next-item/avg_precision@20': [0.077287457883358,\n", + " 0.06352042406797409,\n", + " 0.13500627875328064],\n", + " 'indexed_by_time_eval_/next-item/ndcg@10': [0.1008271649479866,\n", + " 0.08318208903074265,\n", + " 0.16738776862621307],\n", + " 'indexed_by_time_eval_/next-item/ndcg@20': [0.11763089150190353,\n", + " 0.09845318645238876,\n", + " 0.19680777192115784],\n", + " 'indexed_by_time_eval_/next-item/recall@10': [0.18959537148475647,\n", + " 0.16083915531635284,\n", + " 0.29406309127807617],\n", + " 'indexed_by_time_eval_/next-item/recall@20': [0.25549131631851196,\n", + " 0.2209790199995041,\n", + " 0.41187384724617004]}" ] }, "execution_count": 11, @@ -835,12 +825,12 @@ "name": "stdout", "output_type": "stream", "text": [ - " indexed_by_time_eval_/next-item/avg_precision@10 = 0.09488346427679062\n", - " indexed_by_time_eval_/next-item/avg_precision@20 = 0.10024310400088628\n", - " indexed_by_time_eval_/next-item/ndcg@10 = 0.12892033159732819\n", - " indexed_by_time_eval_/next-item/ndcg@20 = 0.14862599223852158\n", - " indexed_by_time_eval_/next-item/recall@10 = 0.23494302233060202\n", - " indexed_by_time_eval_/next-item/recall@20 = 0.3132995714743932\n" + " indexed_by_time_eval_/next-item/avg_precision@10 = 0.08648926516373952\n", 
+ " indexed_by_time_eval_/next-item/avg_precision@20 = 0.09193805356820424\n", + " indexed_by_time_eval_/next-item/ndcg@10 = 0.1171323408683141\n", + " indexed_by_time_eval_/next-item/ndcg@20 = 0.13763061662515005\n", + " indexed_by_time_eval_/next-item/recall@10 = 0.21483253935972849\n", + " indexed_by_time_eval_/next-item/recall@20 = 0.2961147278547287\n" ] } ], @@ -897,7 +887,7 @@ "metadata": {}, "outputs": [], "source": [ - "df = cudf.read_parquet(os.path.join(INPUT_DATA_DIR, \"./preproc_sessions_by_day/178/train.parquet\"), columns=model.input_schema.column_names)\n", + "df = cudf.read_parquet(os.path.join(INPUT_DATA_DIR, f\"preproc_sessions_by_day/{start_time_idx}/train.parquet\"), columns=model.input_schema.column_names)\n", "table = TensorTable.from_df(df.iloc[:100])\n", "for column in table.columns:\n", " table[column] = convert_col(table[column], TorchColumn)\n", @@ -919,7 +909,8 @@ "metadata": {}, "outputs": [], "source": [ - "model.top_k = 20" + "topk = 20\n", + "model.top_k = topk" ] }, { @@ -934,9 +925,7 @@ "cell_type": "code", "execution_count": 16, "id": "3158138f", - "metadata": { - "scrolled": false - }, + "metadata": {}, "outputs": [], "source": [ "model.eval()\n", @@ -962,47 +951,47 @@ { "data": { "text/plain": [ - "tensor([ 604, 878, 742, 90, 4777, 1583, 3446, 8083, 3446, 4018,\n", - " 742, 4777, 184, 12288, 2065, 430, 6, 30, 6, 157,\n", - " 1987, 2590, 10855, 8217, 4210, 8711, 4242, 81, 112, 4242,\n", - " 5732, 6810, 33, 6, 73, 664, 2312, 7124, 9113, 445,\n", - " 1157, 774, 685, 430, 1945, 475, 597, 289, 166, 29,\n", - " 342, 289, 33, 423, 166, 480, 2772, 288, 962, 4001,\n", - " 2050, 3274, 499, 1219, 395, 1636, 11839, 10714, 11107, 289,\n", - " 166, 650, 1085, 302, 88, 650, 214, 304, 177, 317,\n", - " 423, 6, 3818, 931, 2186, 1085, 206, 687, 3831, 687,\n", - " 202, 20, 43, 20, 2034, 10457, 20, 21, 1126, 2815,\n", - " 4210, 34264, 830, 774, 620, 2050, 1987, 1079, 4713, 1336,\n", - " 661, 289, 430, 863, 4829, 5786, 19156, 17270, 23365, 4209,\n", - " 3651, 1037, 4770, 224, 277, 1020, 650, 166, 1354, 206,\n", - " 1889, 2473, 1697, 997, 480, 774, 3841, 4316, 3841, 1230,\n", - " 3841, 1697, 3809, 475, 981, 804, 313, 613, 1219, 1334,\n", - " 1941, 2888, 2626, 1334, 2689, 804, 475, 981, 313, 804,\n", - " 1219, 206, 651, 429, 605, 101, 413, 1965, 627, 814,\n", - " 627, 814, 4713, 3675, 2789, 3769, 1283, 9540, 1251, 313,\n", - " 685, 2497, 395, 845, 3462, 2713, 5077, 388, 340, 297,\n", - " 388, 9641, 46, 61, 822, 61, 602, 2270, 719, 3274,\n", - " 2556, 2, 61, 24, 96, 61, 423, 61, 475, 3460,\n", - " 2693, 3460, 3044, 2556, 3988, 992, 1603, 122, 2704, 2787,\n", - " 3135, 550, 516, 44, 1551, 2702, 206, 1762, 31, 474,\n", - " 481, 198, 474, 2704, 2393, 1025, 20033, 72, 1334, 224,\n", - " 3460, 4774, 2050, 6485, 15953, 422, 1488, 2346, 4470, 2548,\n", - " 571, 1770, 1324, 453, 837, 123, 638, 4759, 3552, 6825,\n", - " 2740, 5347, 5390, 1169, 4100, 1230, 804, 3588, 2449, 185,\n", - " 16, 643, 274, 686, 18092, 10457, 609, 2969, 3480, 2969,\n", - " 37, 609, 2969, 609, 22, 8312, 257, 37, 22, 5361,\n", - " 7186, 7380, 6052, 7256, 13404, 557, 160, 1664, 4375, 3484,\n", - " 685, 651, 445, 429, 445, 774, 651, 4284, 1738, 3855,\n", - " 225, 210, 7245, 6731, 771, 1987, 157, 804, 442, 804,\n", - " 2091, 1169, 2091, 1169, 3484, 4375, 3484, 445, 429, 430,\n", - " 423, 1697, 1393, 1798, 2753, 206, 1153, 21588, 2189, 3704,\n", - " 4463, 5816, 7557, 507, 1797, 814, 627, 2016, 855, 1889,\n", - " 224, 597, 37, 597, 16533, 10255, 2, 2651, 4028, 2556,\n", - " 2788, 6379, 1830, 1070, 30, 312, 445, 1085, 
1569, 2222,\n", - " 2664, 1950, 2098, 1672, 224, 4336, 651, 997, 1157, 830,\n", - " 800, 597, 1085, 12430, 415, 12430, 651, 800, 1756, 1378,\n", - " 1413, 633, 2034, 7932, 6034, 6360, 4662, 576, 4662, 576,\n", - " 4662], device='cuda:0')" + "tensor([ 605, 879, 743, 91, 4778, 1584, 3447, 8084, 3447, 4019,\n", + " 743, 4778, 185, 12289, 2066, 431, 7, 31, 7, 158,\n", + " 1988, 2591, 10856, 8218, 4211, 8712, 4243, 82, 113, 4243,\n", + " 5733, 6811, 34, 7, 74, 665, 2313, 7125, 9114, 446,\n", + " 1158, 775, 686, 431, 1946, 476, 598, 290, 167, 30,\n", + " 343, 290, 34, 424, 167, 481, 2773, 289, 963, 4002,\n", + " 2051, 3275, 500, 1220, 396, 1637, 11840, 10715, 11108, 290,\n", + " 167, 651, 1086, 303, 89, 651, 215, 305, 178, 318,\n", + " 424, 7, 3819, 932, 2187, 1086, 207, 688, 3832, 688,\n", + " 203, 21, 44, 21, 2035, 10458, 21, 22, 1127, 2816,\n", + " 4211, 34265, 831, 775, 621, 2051, 1988, 1080, 4714, 1337,\n", + " 662, 290, 431, 864, 4830, 5787, 19157, 17271, 23366, 4210,\n", + " 3652, 1038, 4771, 225, 278, 1021, 651, 167, 1355, 207,\n", + " 1890, 2474, 1698, 998, 481, 775, 3842, 4317, 3842, 1231,\n", + " 3842, 1698, 3810, 476, 982, 805, 314, 614, 1220, 1335,\n", + " 1942, 2889, 2627, 1335, 2690, 805, 476, 982, 314, 805,\n", + " 1220, 207, 652, 430, 606, 102, 414, 1966, 628, 815,\n", + " 628, 815, 4714, 3676, 2790, 3770, 1284, 9541, 1252, 314,\n", + " 686, 2498, 396, 846, 3463, 2714, 5078, 389, 341, 298,\n", + " 389, 9642, 47, 62, 823, 62, 603, 2271, 720, 3275,\n", + " 2557, 3, 62, 25, 97, 62, 424, 62, 476, 3461,\n", + " 2694, 3461, 3045, 2557, 3989, 993, 1604, 123, 2705, 2788,\n", + " 3136, 551, 517, 45, 1552, 2703, 207, 1763, 32, 475,\n", + " 482, 199, 475, 2705, 2394, 1026, 20034, 73, 1335, 225,\n", + " 3461, 4775, 2051, 6486, 15954, 423, 1489, 2347, 4471, 2549,\n", + " 572, 1771, 1325, 454, 838, 124, 639, 4760, 3553, 6826,\n", + " 2741, 5348, 5391, 1170, 4101, 1231, 805, 3589, 2450, 186,\n", + " 17, 644, 275, 687, 18093, 10458, 610, 2970, 3481, 2970,\n", + " 38, 610, 2970, 610, 23, 8313, 258, 38, 23, 5362,\n", + " 7187, 7381, 6053, 7257, 13405, 558, 161, 1665, 4376, 3485,\n", + " 686, 652, 446, 430, 446, 775, 652, 4285, 1739, 3856,\n", + " 226, 211, 7246, 6732, 772, 1988, 158, 805, 443, 805,\n", + " 2092, 1170, 2092, 1170, 3485, 4376, 3485, 446, 430, 431,\n", + " 424, 1698, 1394, 1799, 2754, 207, 1154, 21589, 2190, 3705,\n", + " 4464, 5817, 7558, 508, 1798, 815, 628, 2017, 856, 1890,\n", + " 225, 598, 38, 598, 16534, 10256, 3, 2652, 4029, 2557,\n", + " 2789, 6380, 1831, 1071, 31, 313, 446, 1086, 1570, 2223,\n", + " 2665, 1951, 2099, 1673, 225, 4337, 652, 998, 1158, 831,\n", + " 801, 598, 1086, 12431, 416, 12431, 652, 801, 1757, 1379,\n", + " 1414, 634, 2035, 7933, 6035, 6361, 4663, 577, 4663, 577,\n", + " 4663], device='cuda:0')" ] }, "execution_count": 17, @@ -1077,7 +1066,9 @@ "cell_type": "code", "execution_count": 20, "id": "55c03a0a", - "metadata": {}, + "metadata": { + "tags": [] + }, "outputs": [], "source": [ "workflow = Workflow.load(os.path.join(INPUT_DATA_DIR, \"workflow_etl\"))" @@ -1087,17 +1078,10 @@ "cell_type": "code", "execution_count": 21, "id": "18eaa079", - "metadata": {}, - "outputs": [ - { - "name": "stderr", - "output_type": "stream", - "text": [ - "/usr/local/lib/python3.8/dist-packages/merlin/schema/tags.py:149: UserWarning: Compound tags like Tags.ITEM_ID have been deprecated and will be removed in a future version. 
Please use the atomic versions of these tags, like [, ].\n", - " warnings.warn(\n" - ] - } - ], + "metadata": { + "tags": [] + }, + "outputs": [], "source": [ "torch_op = workflow.input_schema.column_names >> TransformWorkflow(workflow) >> PredictPyTorch(\n", " traced_model, model.input_schema, model.output_schema\n", @@ -1118,7 +1102,9 @@ "cell_type": "code", "execution_count": 22, "id": "64178cff", - "metadata": {}, + "metadata": { + "tags": [] + }, "outputs": [ { "name": "stderr", @@ -1319,114 +1305,18 @@ "metadata": {}, "outputs": [ { - "name": "stdout", - "output_type": "stream", - "text": [ - "{'item_id_scores': array([[ 7.7149506, 7.3862066, 7.3664827, 7.2885814, 6.8977885,\n", - " 6.853054 , 6.6473446, 6.5615396, 6.5239835, 6.504697 ,\n", - " 6.4961066, 6.482874 , 6.478437 , 6.4518657, 6.403399 ,\n", - " 6.3845167, 6.372211 , 6.3541226, 6.2738447, 6.2641015],\n", - " [ 7.8986144, 7.575227 , 6.9020343, 6.8903823, 6.839163 ,\n", - " 6.833202 , 6.761392 , 6.7598944, 6.6728435, 6.6523266,\n", - " 6.640926 , 6.6118965, 6.470592 , 6.419156 , 6.4112453,\n", - " 6.4044685, 6.386283 , 6.3791957, 6.362708 , 6.32642 ],\n", - " [ 3.2538116, 3.1455884, 3.1179721, 3.0379417, 2.9422204,\n", - " 2.9090347, 2.8573742, 2.8552375, 2.8385096, 2.8210163,\n", - " 2.8062954, 2.7926047, 2.7883859, 2.7677355, 2.7579772,\n", - " 2.7548594, 2.7234986, 2.7234783, 2.7092912, 2.7081766],\n", - " [ 7.1264834, 6.9747195, 6.686076 , 6.55559 , 6.554246 ,\n", - " 6.4562297, 6.409179 , 6.303057 , 6.1918006, 6.18303 ,\n", - " 6.1193542, 6.107653 , 6.0896993, 6.0872216, 6.0593266,\n", - " 6.0522294, 5.960239 , 5.953763 , 5.9512544, 5.93457 ],\n", - " [ 8.472156 , 8.323568 , 8.279997 , 8.109349 , 8.0277195,\n", - " 7.8910418, 7.743165 , 7.7279477, 7.692122 , 7.6912303,\n", - " 7.6711187, 7.5660324, 7.5545163, 7.524415 , 7.4610972,\n", - " 7.424488 , 7.4123936, 7.411586 , 7.3956637, 7.334862 ],\n", - " [10.485718 , 10.375805 , 10.131399 , 9.950816 , 9.917899 ,\n", - " 9.840376 , 9.77128 , 9.676795 , 9.599488 , 9.544931 ,\n", - " 9.512772 , 9.488843 , 9.455492 , 9.438552 , 9.430434 ,\n", - " 9.37854 , 9.362637 , 9.329395 , 9.303702 , 9.279214 ],\n", - " [ 5.836771 , 5.828289 , 5.7385616, 5.6870475, 5.5709176,\n", - " 5.5614986, 5.5448956, 5.3459535, 5.344148 , 5.310749 ,\n", - " 5.2611475, 5.2548814, 5.1640043, 5.1203575, 5.0081887,\n", - " 4.987002 , 4.976223 , 4.90389 , 4.8933997, 4.86463 ],\n", - " [ 8.174524 , 7.422059 , 7.3884916, 7.272735 , 7.1977425,\n", - " 7.0858173, 6.645329 , 6.5097084, 6.470001 , 6.453685 ,\n", - " 6.2901287, 6.1282573, 6.1203957, 6.0265145, 6.0052614,\n", - " 5.9081306, 5.880647 , 5.78306 , 5.7507606, 5.7427897],\n", - " [ 6.7380514, 6.735104 , 6.6285214, 6.5611916, 6.527247 ,\n", - " 6.4949493, 6.4821672, 6.474738 , 6.4419727, 6.4345016,\n", - " 6.3947067, 6.352811 , 6.314824 , 6.293556 , 6.288764 ,\n", - " 6.2727513, 6.2707367, 6.2422667, 6.2342215, 6.196748 ],\n", - " [ 5.6076183, 5.472957 , 5.412912 , 5.3116655, 5.177665 ,\n", - " 5.112378 , 5.085245 , 4.9710717, 4.940496 , 4.93859 ,\n", - " 4.8099985, 4.760807 , 4.7229834, 4.7145324, 4.707185 ,\n", - " 4.667753 , 4.6421194, 4.630641 , 4.6156597, 4.5743294],\n", - " [ 8.147796 , 7.6630974, 7.5877905, 7.4908724, 7.4732018,\n", - " 7.404136 , 7.352083 , 7.3242574, 7.3009863, 7.242015 ,\n", - " 7.1836567, 7.1416435, 7.138858 , 7.134639 , 7.111288 ,\n", - " 7.0651016, 7.0538983, 7.0510435, 7.044665 , 7.035279 ],\n", - " [ 3.2990756, 3.1885252, 3.1802373, 3.0592136, 3.045468 ,\n", - " 3.0039759, 2.9578297, 2.9415162, 2.93026 , 2.9297981,\n", - 
" 2.8965662, 2.8944998, 2.877764 , 2.8689 , 2.8668582,\n", - " 2.8634927, 2.78628 , 2.767284 , 2.7598662, 2.7465074],\n", - " [ 8.852634 , 8.734692 , 8.606397 , 8.578839 , 8.453225 ,\n", - " 8.416724 , 8.407235 , 8.343583 , 8.280216 , 8.269384 ,\n", - " 8.20807 , 8.177954 , 8.15201 , 8.136789 , 8.103194 ,\n", - " 7.873078 , 7.852006 , 7.849255 , 7.8436575, 7.7874575],\n", - " [ 8.121338 , 7.9475265, 7.9030285, 7.8839293, 7.850343 ,\n", - " 7.8212285, 7.7303505, 7.655741 , 7.5204306, 7.4822707,\n", - " 7.4198914, 7.3865814, 7.3747654, 7.3459187, 7.3127384,\n", - " 7.2247725, 7.1183944, 7.111621 , 7.0977235, 7.089139 ],\n", - " [ 8.760972 , 8.427143 , 8.166194 , 8.004948 , 7.831641 ,\n", - " 7.8019366, 7.784935 , 7.6579127, 7.5422807, 7.531015 ,\n", - " 7.48702 , 7.448603 , 7.3950186, 7.355933 , 7.3510966,\n", - " 7.2111845, 7.183003 , 7.166962 , 7.1576095, 7.1405106]],\n", - " dtype=float32), 'item_ids': array([[ 127, 1245, 1271, 4074, 161, 532, 3928, 19290, 11874,\n", - " 1446, 346, 9285, 3452, 3334, 10987, 479, 7555, 3206,\n", - " 14633, 13677],\n", - " [ 2693, 3460, 10401, 2393, 9415, 9285, 14213, 2404, 13750,\n", - " 10987, 8084, 183, 2889, 5110, 16662, 10962, 7865, 14401,\n", - " 4158, 3211],\n", - " [ 4278, 3225, 5889, 953, 7577, 1420, 4185, 4591, 573,\n", - " 10287, 3514, 7923, 5623, 3269, 873, 2433, 15077, 5110,\n", - " 6357, 3593],\n", - " [10401, 573, 3211, 2404, 3452, 2763, 4014, 786, 3296,\n", - " 6030, 4380, 7481, 2067, 10962, 183, 809, 4651, 2806,\n", - " 2575, 8084],\n", - " [10987, 9415, 9285, 479, 7998, 7555, 2475, 2404, 8084,\n", - " 2693, 2661, 1102, 786, 14343, 3460, 161, 13677, 2889,\n", - " 2393, 1600],\n", - " [ 8084, 13677, 14633, 12754, 14213, 3452, 13821, 2984, 11816,\n", - " 2475, 1600, 4074, 13749, 9285, 17882, 10401, 10673, 1964,\n", - " 11874, 9801],\n", - " [ 224, 1453, 1889, 620, 2556, 520, 741, 1219, 633,\n", - " 2050, 6344, 2651, 2852, 1039, 3841, 375, 2473, 3225,\n", - " 4204, 2980],\n", - " [ 620, 633, 1889, 1625, 1453, 1219, 4962, 6344, 224,\n", - " 2034, 1413, 2980, 741, 1908, 597, 1085, 4770, 2305,\n", - " 1647, 1334],\n", - " [ 573, 14505, 3365, 9020, 5932, 7470, 7055, 6344, 10962,\n", - " 11937, 6488, 13821, 19263, 14213, 633, 8553, 10123, 10401,\n", - " 1453, 15439],\n", - " [15148, 13277, 1381, 24893, 11689, 37268, 7454, 1547, 2305,\n", - " 30765, 46450, 18630, 8747, 1085, 1889, 25487, 7506, 620,\n", - " 14220, 8291],\n", - " [19263, 13821, 15289, 11874, 8058, 17882, 7159, 12754, 6473,\n", - " 14213, 16233, 573, 3452, 12059, 1600, 8084, 10401, 13749,\n", - " 15439, 9801],\n", - " [ 8813, 1889, 464, 1855, 520, 127, 185, 4204, 637,\n", - " 31, 784, 2789, 4713, 6461, 11874, 840, 177, 5224,\n", - " 5818, 6473],\n", - " [ 7865, 5932, 7055, 14213, 9020, 4541, 8084, 6488, 8553,\n", - " 14505, 6473, 10962, 4204, 4136, 16233, 13821, 10401, 9459,\n", - " 9801, 15439],\n", - " [10401, 6488, 5932, 7865, 10962, 14213, 8553, 8084, 9020,\n", - " 573, 4204, 13821, 3175, 15439, 4136, 10123, 11874, 3452,\n", - " 17882, 9236],\n", - " [ 2693, 3460, 14213, 7865, 5932, 6488, 4204, 8084, 10962,\n", - " 8553, 1219, 9020, 4541, 13750, 9236, 1453, 4136, 2980,\n", - " 10401, 2651]])}\n" + "ename": "ValueError", + "evalue": "cannot convert NA to integer", + "output_type": "error", + "traceback": [ + "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", + "\u001b[0;31mValueError\u001b[0m Traceback (most recent call last)", + "Cell \u001b[0;32mIn[26], line 2\u001b[0m\n\u001b[1;32m 1\u001b[0m \u001b[38;5;28;01mfrom\u001b[39;00m 
\u001b[38;5;21;01mmerlin\u001b[39;00m\u001b[38;5;21;01m.\u001b[39;00m\u001b[38;5;21;01msystems\u001b[39;00m\u001b[38;5;21;01m.\u001b[39;00m\u001b[38;5;21;01mtriton\u001b[39;00m\u001b[38;5;21;01m.\u001b[39;00m\u001b[38;5;21;01mutils\u001b[39;00m \u001b[38;5;28;01mimport\u001b[39;00m send_triton_request\n\u001b[0;32m----> 2\u001b[0m response \u001b[38;5;241m=\u001b[39m \u001b[43msend_triton_request\u001b[49m\u001b[43m(\u001b[49m\u001b[43mworkflow\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43minput_schema\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mfiltered_batch\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mmodel\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43moutput_schema\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mcolumn_names\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 3\u001b[0m \u001b[38;5;28mprint\u001b[39m(response)\n", + "File \u001b[0;32m/usr/local/lib/python3.8/dist-packages/merlin/systems/triton/utils.py:226\u001b[0m, in \u001b[0;36msend_triton_request\u001b[0;34m(schema, inputs, outputs_list, client, endpoint, request_id, triton_model)\u001b[0m\n\u001b[1;32m 224\u001b[0m triton_inputs \u001b[38;5;241m=\u001b[39m triton\u001b[38;5;241m.\u001b[39mconvert_table_to_triton_input(schema, inputs, grpcclient\u001b[38;5;241m.\u001b[39mInferInput)\n\u001b[1;32m 225\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[0;32m--> 226\u001b[0m triton_inputs \u001b[38;5;241m=\u001b[39m \u001b[43mtriton\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mconvert_df_to_triton_input\u001b[49m\u001b[43m(\u001b[49m\u001b[43mschema\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43minputs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mgrpcclient\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mInferInput\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 228\u001b[0m outputs \u001b[38;5;241m=\u001b[39m [grpcclient\u001b[38;5;241m.\u001b[39mInferRequestedOutput(col) \u001b[38;5;28;01mfor\u001b[39;00m col \u001b[38;5;129;01min\u001b[39;00m outputs_list]\n\u001b[1;32m 230\u001b[0m response \u001b[38;5;241m=\u001b[39m client\u001b[38;5;241m.\u001b[39minfer(triton_model, triton_inputs, request_id\u001b[38;5;241m=\u001b[39mrequest_id, outputs\u001b[38;5;241m=\u001b[39moutputs)\n", + "File \u001b[0;32m/usr/local/lib/python3.8/dist-packages/merlin/systems/triton/__init__.py:88\u001b[0m, in \u001b[0;36mconvert_df_to_triton_input\u001b[0;34m(schema, batch, input_class, dtype)\u001b[0m\n\u001b[1;32m 68\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21mconvert_df_to_triton_input\u001b[39m(schema, batch, input_class\u001b[38;5;241m=\u001b[39mgrpcclient\u001b[38;5;241m.\u001b[39mInferInput, dtype\u001b[38;5;241m=\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mint32\u001b[39m\u001b[38;5;124m\"\u001b[39m):\n\u001b[1;32m 69\u001b[0m \u001b[38;5;250m \u001b[39m\u001b[38;5;124;03m\"\"\"\u001b[39;00m\n\u001b[1;32m 70\u001b[0m \u001b[38;5;124;03m Convert a dataframe to a set of Triton inputs\u001b[39;00m\n\u001b[1;32m 71\u001b[0m \n\u001b[0;32m (...)\u001b[0m\n\u001b[1;32m 86\u001b[0m \u001b[38;5;124;03m A list of Triton inputs of the requested input class\u001b[39;00m\n\u001b[1;32m 87\u001b[0m \u001b[38;5;124;03m \"\"\"\u001b[39;00m\n\u001b[0;32m---> 88\u001b[0m df_dict \u001b[38;5;241m=\u001b[39m \u001b[43m_convert_df_to_dict\u001b[49m\u001b[43m(\u001b[49m\u001b[43mschema\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mbatch\u001b[49m\u001b[43m,\u001b[49m\u001b[43m 
\u001b[49m\u001b[43mdtype\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 89\u001b[0m inputs \u001b[38;5;241m=\u001b[39m [\n\u001b[1;32m 90\u001b[0m _convert_array_to_triton_input(col_name, col_values, input_class)\n\u001b[1;32m 91\u001b[0m \u001b[38;5;28;01mfor\u001b[39;00m col_name, col_values \u001b[38;5;129;01min\u001b[39;00m df_dict\u001b[38;5;241m.\u001b[39mitems()\n\u001b[1;32m 92\u001b[0m ]\n\u001b[1;32m 93\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m inputs\n", + "File \u001b[0;32m/usr/local/lib/python3.8/dist-packages/merlin/systems/triton/__init__.py:183\u001b[0m, in \u001b[0;36m_convert_df_to_dict\u001b[0;34m(schema, batch, dtype)\u001b[0m\n\u001b[1;32m 181\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[1;32m 182\u001b[0m values \u001b[38;5;241m=\u001b[39m col\u001b[38;5;241m.\u001b[39mvalues \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28misinstance\u001b[39m(col, pd\u001b[38;5;241m.\u001b[39mSeries) \u001b[38;5;28;01melse\u001b[39;00m col\u001b[38;5;241m.\u001b[39mvalues_host\n\u001b[0;32m--> 183\u001b[0m values \u001b[38;5;241m=\u001b[39m \u001b[43mvalues\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mreshape\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mshape\u001b[49m\u001b[43m)\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mastype\u001b[49m\u001b[43m(\u001b[49m\u001b[43mcol_schema\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mdtype\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mto_numpy\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 184\u001b[0m df_dict[col_name] \u001b[38;5;241m=\u001b[39m values\n\u001b[1;32m 185\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m df_dict\n", + "File \u001b[0;32m/usr/local/lib/python3.8/dist-packages/pandas/core/arrays/masked.py:471\u001b[0m, in \u001b[0;36mBaseMaskedArray.astype\u001b[0;34m(self, dtype, copy)\u001b[0m\n\u001b[1;32m 469\u001b[0m \u001b[38;5;66;03m# to_numpy will also raise, but we get somewhat nicer exception messages here\u001b[39;00m\n\u001b[1;32m 470\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m is_integer_dtype(dtype) \u001b[38;5;129;01mand\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_hasna:\n\u001b[0;32m--> 471\u001b[0m \u001b[38;5;28;01mraise\u001b[39;00m \u001b[38;5;167;01mValueError\u001b[39;00m(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mcannot convert NA to integer\u001b[39m\u001b[38;5;124m\"\u001b[39m)\n\u001b[1;32m 472\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m is_bool_dtype(dtype) \u001b[38;5;129;01mand\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_hasna:\n\u001b[1;32m 473\u001b[0m \u001b[38;5;66;03m# careful: astype_nansafe converts np.nan to True\u001b[39;00m\n\u001b[1;32m 474\u001b[0m \u001b[38;5;28;01mraise\u001b[39;00m \u001b[38;5;167;01mValueError\u001b[39;00m(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mcannot convert float NaN to bool\u001b[39m\u001b[38;5;124m\"\u001b[39m)\n", + "\u001b[0;31mValueError\u001b[0m: cannot convert NA to integer" ] } ], diff --git a/examples/end-to-end-session-based/03-Session-based-Yoochoose-multigpu-training-PyT.ipynb b/examples/end-to-end-session-based/03-Session-based-Yoochoose-multigpu-training-PyT.ipynb index 4e1560b21d..3bd729412e 100644 --- a/examples/end-to-end-session-based/03-Session-based-Yoochoose-multigpu-training-PyT.ipynb +++ b/examples/end-to-end-session-based/03-Session-based-Yoochoose-multigpu-training-PyT.ipynb @@ -2,9 +2,11 @@ "cells": [ { "cell_type": "code", - "execution_count": null, + "execution_count": 1, "id": "0fb18ebc", - "metadata": {}, + "metadata": 
{ + "tags": [] + }, "outputs": [], "source": [ "# Copyright 2022 NVIDIA Corporation. All Rights Reserved.\n", @@ -99,18 +101,41 @@ "source": [ "Please note these important points that are relevant to multi-gpu training:\n", "- specifying multiple GPUs: PyTorch distributed launch environment will recognize that we have two GPUs, since the `--nproc_per_node` arg of `torch.distributed.launch` takes care of assigning one GPU per process and performs the training loop on multiple GPUs (2 in this case) using different batches of the data in a data-parallel fashion.\n", - "- data repartitioning: when training on multiple GPUs, data must be re-partitioned into >1 partitions where the number of partitions must be at least equal to the number of GPUs. The torch utility library in Transformers4Rec does this automatically and outputs a UserWarning message. If you would like to avoid this warning message, you may choose to manually re-partition your data files before you launch the training loop or function. See [this document](https://nvidia-merlin.github.io/Transformers4Rec/stable/multi_gpu_train.html#distributeddataparallel) for further information on how to do manual re-partitioning.\n", - "- training and evaluation batch sizes: in the default DistributedDataParallel mode we will be running, keeping the batch size unchanged means each worker will receive the same-size batch despite the fact that you are now using multiple GPUs. If you would like to keep the total batch size constant, you may want to divide the training and evaluation batch sizes by the number of GPUs you are running on, which is expected to reduce time it takes train and evaluate on each batch." + "- data repartitioning: when training on multiple GPUs, data must be re-partitioned into >1 partitions where the number of partitions must be at least equal to the number of GPUs. The torch utility library in Transformers4Rec does this automatically and outputs a UserWarning message. If you would like to avoid this warning message, you may choose to manually re-partition your data files before you launch the training loop or function. See [this document](https://nvidia-merlin.github.io/Transformers4Rec/main/multi_gpu_train.html#distributeddataparallel) for further information on how to do manual re-partitioning.\n", + "- training and evaluation batch sizes: in the default DistributedDataParallel mode we will be running, keeping the batch size unchanged means each worker will receive the same-size batch despite the fact that you are now using multiple GPUs. 
If you would like to keep the total batch size constant, you may want to divide the training and evaluation batch sizes by the number of GPUs you are running on, which is expected to reduce the time it takes to train and evaluate on each batch.\n", + "\n", + "Finally, if you have worked with the synthetic dataset in notebook 01, remember to change the values 178 and 181 (the default start and end time window indices) in the next cell accordingly:" ] }, { "cell_type": "code", - "execution_count": null, - "id": "6c45c899-9c88-4235-8402-03f7f3e40841", + "execution_count": 2, + "id": "2f2e0a31-4259-4cd3-9618-01a44d2492d0", "metadata": {}, "outputs": [], "source": [ - "%%writefile './pyt_trainer.py'\n", + "import os\n", + "TRAINER_FILE = os.path.join(os.environ.get(\"INPUT_DATA_DIR\", \"/workspace/data\"), \"pyt_trainer.py\")" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "id": "6c45c899-9c88-4235-8402-03f7f3e40841", + "metadata": { + "tags": [] + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Overwriting /workspace/data/pyt_trainer.py\n" + ] + } + ], + "source": [ + "%%writefile {TRAINER_FILE}\n", "\n", "import argparse\n", "import os\n", @@ -125,15 +150,14 @@ "from merlin.schema import Schema\n", "from merlin.io import Dataset\n", "\n", - "\n", "cupy.cuda.Device(int(os.environ[\"LOCAL_RANK\"])).use()\n", "\n", "# define arguments that can be passed to this python script\n", "parser = argparse.ArgumentParser(description='Hyperparameters for model training')\n", "parser.add_argument('--path', type=str, help='Directory with training and validation data')\n", "parser.add_argument('--learning-rate', type=float, default=0.0005, help='Learning rate for training')\n", - "parser.add_argument('--per-device-train-batch-size', type=int, default=384, help='Per device batch size for training')\n", - "parser.add_argument('--per-device-eval-batch-size', type=int, default=512, help='Per device batch size for evaluation')\n", + "parser.add_argument('--per-device-train-batch-size', type=int, default=64, help='Per device batch size for training')\n", + "parser.add_argument('--per-device-eval-batch-size', type=int, default=32, help='Per device batch size for evaluation')\n", "sh_args = parser.parse_args()\n", "\n", "# create the schema object by reading the processed train set generated in the previous 01-ETL-with-NVTabular notebook\n", @@ -198,8 +222,9 @@ "start = time.time()\n", "\n", "# main loop for training\n", - "start_time_window_index = 178\n", - "final_time_window_index = 181\n", + "start_time_window_index = int(os.environ.get(\"START_TIME_INDEX\", \"178\"))\n", + "final_time_window_index = int(os.environ.get(\"END_TIME_INDEX\", \"181\"))\n", + "\n", "# Iterating over days from 178 to 181\n", "for time_index in range(start_time_window_index, final_time_window_index):\n", " # Set data \n", " time_index_train = time_index\n", " time_index_eval = time_index + 1\n", " train_paths = glob.glob(os.path.join(OUTPUT_DIR, f\"{time_index_train}/train.parquet\"))\n", " eval_paths = glob.glob(os.path.join(OUTPUT_DIR, f\"{time_index_eval}/valid.parquet\"))\n", - " \n", + "\n", " # Train on day related to time_index \n", " print('*'*20)\n", " print(\"Launch training for day %s are:\" %time_index)\n", @@ -220,16 +245,23 @@ "\n", " # Evaluate on the following day\n", " recsys_trainer.eval_dataset_or_path = eval_paths\n", - " train_metrics = recsys_trainer.evaluate(metric_key_prefix='eval')\n", + " eval_metrics = recsys_trainer.evaluate(metric_key_prefix='eval')\n", " print('*'*20)\n", " print(\"Eval results for day %s are:\\t\" 
%time_index_eval)\n", " print('\\n' + '*'*20 + '\\n')\n", - " for key in sorted(train_metrics.keys()):\n", - " print(\" %s = %s\" % (key, str(train_metrics[key]))) \n", + " for key in sorted(eval_metrics.keys()):\n", + " print(\" %s = %s\" % (key, str(eval_metrics[key]))) \n", " wipe_memory()\n", "\n", + "# export evaluation metrics to a file\n", + "import json\n", + "fname = os.path.join(INPUT_DATA_DIR, \"eval_metrics.txt\")\n", + "f = open(fname, \"w\")\n", + "f.write(json.dumps(eval_metrics))\n", + "f.close()\n", + "\n", "end = time.time()\n", - "print('Total training time:',end-start)" + "print('Total training time:', end-start)" ] }, { @@ -256,12 +288,221 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 4, "id": "1e5678d8-71a9-465f-a986-9ec68cf351ff", - "metadata": {}, - "outputs": [], + "metadata": { + "tags": [] + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "WARNING:torch.distributed.run:\n", + "*****************************************\n", + "Setting OMP_NUM_THREADS environment variable for each process to be 1 in default, to avoid your system being overloaded, please further tune the variable for optimal performance in your application as needed. \n", + "*****************************************\n", + "/usr/local/lib/python3.8/dist-packages/merlin/dtypes/mappings/tf.py:52: UserWarning: Tensorflow dtype mappings did not load successfully due to an error: No module named 'tensorflow'\n", + " warn(f\"Tensorflow dtype mappings did not load successfully due to an error: {exc.msg}\")\n", + "/usr/local/lib/python3.8/dist-packages/merlin/dtypes/mappings/tf.py:52: UserWarning: Tensorflow dtype mappings did not load successfully due to an error: No module named 'tensorflow'\n", + " warn(f\"Tensorflow dtype mappings did not load successfully due to an error: {exc.msg}\")\n", + "Projecting inputs of NextItemPredictionTask to'64' As weight tying requires the input dimension '320' to be equal to the item-id embedding dimension '64'\n", + "Projecting inputs of NextItemPredictionTask to'64' As weight tying requires the input dimension '320' to be equal to the item-id embedding dimension '64'\n", + "********************\n", + "Launch training for day 178 are:\n", + "********************\n", + "\n", + "********************\n", + "Launch training for day 178 are:\n", + "********************\n", + "\n", + "UserWarning: User is advised to repartition the parquet file before training so npartitions>=global_size. Cudf or pandas can be used for repartitioning eg. pdf.to_parquet('file.parquet',row_group_size=N_ROWS/NPARTITIONS) for pandas or gdf.to_parquet('file.parquet',row_group_size_rows=N_ROWS/NPARTITIONS) for cudf so that npartitions=nr_rows/row_group_size. Also ensure npartitions is divisible by number of GPUs to be used (eg. 2 or 4 partitions, if 2 GPUs will be used).\n", + "UserWarning: User is advised to repartition the parquet file before training so npartitions>=global_size. Cudf or pandas can be used for repartitioning eg. pdf.to_parquet('file.parquet',row_group_size=N_ROWS/NPARTITIONS) for pandas or gdf.to_parquet('file.parquet',row_group_size_rows=N_ROWS/NPARTITIONS) for cudf so that npartitions=nr_rows/row_group_size. Also ensure npartitions is divisible by number of GPUs to be used (eg. 2 or 4 partitions, if 2 GPUs will be used).\n", + "***** Running training *****\n", + " Num examples = 14080\n", + " Num Epochs = 10\n", + " Instantaneous batch size per device = 256\n", + " Total train batch size (w. 
parallel, distributed & accumulation) = 512\n", + " Gradient Accumulation steps = 1\n", + " Total optimization steps = 550\n", + "{'loss': 7.608, 'learning_rate': 0.0003181818181818182, 'epoch': 3.64} \n", + "{'loss': 6.6604, 'learning_rate': 0.00013636363636363637, 'epoch': 7.27} \n", + " 91%|█████████████████████████████████████▎ | 500/550 [00:27<00:02, 20.71it/s]Saving model checkpoint to ./tmp/checkpoint-500\n", + "Trainer.model is not a `PreTrainedModel`, only saving its state dict.\n", + "100%|████████████████████████████████████████▊| 548/550 [00:29<00:00, 24.82it/s]\n", + "\n", + "Training completed. Do not forget to share your model on huggingface.co/models =)\n", + "\n", + "\n", + "finished\n", + "{'train_runtime': 29.4835, 'train_samples_per_second': 0.339, 'train_steps_per_second': 18.655, 'train_loss': 6.9447163529829545, 'epoch': 10.0}\n", + "100%|█████████████████████████████████████████| 550/550 [00:29<00:00, 18.66it/s]\n", + "finished\n", + "UserWarning: User is advised to repartition the parquet file before training so npartitions>=global_size. Cudf or pandas can be used for repartitioning eg. pdf.to_parquet('file.parquet',row_group_size=N_ROWS/NPARTITIONS) for pandas or gdf.to_parquet('file.parquet',row_group_size_rows=N_ROWS/NPARTITIONS) for cudf so that npartitions=nr_rows/row_group_size. Also ensure npartitions is divisible by number of GPUs to be used (eg. 2 or 4 partitions, if 2 GPUs will be used).\n", + "UserWarning: User is advised to repartition the parquet file before training so npartitions>=global_size. Cudf or pandas can be used for repartitioning eg. pdf.to_parquet('file.parquet',row_group_size=N_ROWS/NPARTITIONS) for pandas or gdf.to_parquet('file.parquet',row_group_size_rows=N_ROWS/NPARTITIONS) for cudf so that npartitions=nr_rows/row_group_size. Also ensure npartitions is divisible by number of GPUs to be used (eg. 
2 or 4 partitions, if 2 GPUs will be used).\n", + "100%|███████████████████████████████████████████| 10/10 [00:00<00:00, 39.28it/s]********************\n", + "Eval results for day 179 are:\t\n", + "\n", + "********************\n", + "\n", + " eval_/loss = 7.320416450500488\n", + " eval_/next-item/avg_precision_at_10 = 0.07864832133054733\n", + " eval_/next-item/avg_precision_at_20 = 0.08251794427633286\n", + " eval_/next-item/ndcg_at_10 = 0.10605569183826447\n", + " eval_/next-item/ndcg_at_20 = 0.12029790133237839\n", + " eval_/next-item/recall_at_10 = 0.19570313394069672\n", + " eval_/next-item/recall_at_20 = 0.2523437440395355\n", + " eval_runtime = 0.3992\n", + " eval_samples_per_second = 3206.723\n", + " eval_steps_per_second = 12.526\n", + "100%|███████████████████████████████████████████| 10/10 [00:00<00:00, 38.88it/s]\n", + "********************\n", + "Eval results for day 179 are:\t\n", + "\n", + "********************\n", + "\n", + " eval_/loss = 7.320416450500488\n", + " eval_/next-item/avg_precision_at_10 = 0.07864832133054733\n", + " eval_/next-item/avg_precision_at_20 = 0.08251794427633286\n", + " eval_/next-item/ndcg_at_10 = 0.10605569183826447\n", + " eval_/next-item/ndcg_at_20 = 0.12029790133237839\n", + " eval_/next-item/recall_at_10 = 0.19570313394069672\n", + " eval_/next-item/recall_at_20 = 0.2523437440395355\n", + " eval_runtime = 0.3989\n", + " eval_samples_per_second = 3209.046\n", + " eval_steps_per_second = 12.535\n", + "********************\n", + "Launch training for day 179 are:\n", + "********************\n", + "\n", + "********************\n", + "Launch training for day 179 are:\n", + "********************\n", + "\n", + "UserWarning: User is advised to repartition the parquet file before training so npartitions>=global_size. Cudf or pandas can be used for repartitioning eg. pdf.to_parquet('file.parquet',row_group_size=N_ROWS/NPARTITIONS) for pandas or gdf.to_parquet('file.parquet',row_group_size_rows=N_ROWS/NPARTITIONS) for cudf so that npartitions=nr_rows/row_group_size. Also ensure npartitions is divisible by number of GPUs to be used (eg. 2 or 4 partitions, if 2 GPUs will be used).\n", + "UserWarning: User is advised to repartition the parquet file before training so npartitions>=global_size. Cudf or pandas can be used for repartitioning eg. pdf.to_parquet('file.parquet',row_group_size=N_ROWS/NPARTITIONS) for pandas or gdf.to_parquet('file.parquet',row_group_size_rows=N_ROWS/NPARTITIONS) for cudf so that npartitions=nr_rows/row_group_size. Also ensure npartitions is divisible by number of GPUs to be used (eg. 2 or 4 partitions, if 2 GPUs will be used).\n", + "***** Running training *****\n", + " Num examples = 9984\n", + " Num Epochs = 10\n", + " Instantaneous batch size per device = 256\n", + " Total train batch size (w. parallel, distributed & accumulation) = 512\n", + " Gradient Accumulation steps = 1\n", + " Total optimization steps = 390\n", + "{'loss': 6.9038, 'learning_rate': 0.0002435897435897436, 'epoch': 5.13} \n", + "100%|█████████████████████████████████████████| 390/390 [00:16<00:00, 25.05it/s]\n", + "\n", + "Training completed. 
Do not forget to share your model on huggingface.co/models =)\n", + "\n", + "\n", + "finished\n", + "{'train_runtime': 16.8289, 'train_samples_per_second': 0.594, 'train_steps_per_second': 23.174, 'train_loss': 6.640231244991988, 'epoch': 10.0}\n", + "100%|█████████████████████████████████████████| 390/390 [00:16<00:00, 23.18it/s]\n", + "finished\n", + "UserWarning: User is advised to repartition the parquet file before training so npartitions>=global_size. Cudf or pandas can be used for repartitioning eg. pdf.to_parquet('file.parquet',row_group_size=N_ROWS/NPARTITIONS) for pandas or gdf.to_parquet('file.parquet',row_group_size_rows=N_ROWS/NPARTITIONS) for cudf so that npartitions=nr_rows/row_group_size. Also ensure npartitions is divisible by number of GPUs to be used (eg. 2 or 4 partitions, if 2 GPUs will be used).\n", + "UserWarning: User is advised to repartition the parquet file before training so npartitions>=global_size. Cudf or pandas can be used for repartitioning eg. pdf.to_parquet('file.parquet',row_group_size=N_ROWS/NPARTITIONS) for pandas or gdf.to_parquet('file.parquet',row_group_size_rows=N_ROWS/NPARTITIONS) for cudf so that npartitions=nr_rows/row_group_size. Also ensure npartitions is divisible by number of GPUs to be used (eg. 2 or 4 partitions, if 2 GPUs will be used).\n", + " 62%|████████████████████████████▏ | 5/8 [00:00<00:00, 46.46it/s]********************\n", + "Eval results for day 180 are:\t\n", + "\n", + "********************\n", + "\n", + " eval_/loss = 7.8266496658325195\n", + " eval_/next-item/avg_precision_at_10 = 0.0614391528069973\n", + " eval_/next-item/avg_precision_at_20 = 0.0654601976275444\n", + " eval_/next-item/ndcg_at_10 = 0.08422157913446426\n", + " eval_/next-item/ndcg_at_20 = 0.09865029156208038\n", + " eval_/next-item/recall_at_10 = 0.1591796875\n", + " eval_/next-item/recall_at_20 = 0.2158203125\n", + " eval_runtime = 0.3334\n", + " eval_samples_per_second = 3071.067\n", + " eval_steps_per_second = 11.996\n", + "100%|█████████████████████████████████████████████| 8/8 [00:00<00:00, 41.08it/s]\n", + "********************\n", + "Eval results for day 180 are:\t\n", + "\n", + "********************\n", + "\n", + " eval_/loss = 7.8266496658325195\n", + " eval_/next-item/avg_precision_at_10 = 0.0614391528069973\n", + " eval_/next-item/avg_precision_at_20 = 0.0654601976275444\n", + " eval_/next-item/ndcg_at_10 = 0.08422157913446426\n", + " eval_/next-item/ndcg_at_20 = 0.09865029156208038\n", + " eval_/next-item/recall_at_10 = 0.1591796875\n", + " eval_/next-item/recall_at_20 = 0.2158203125\n", + " eval_runtime = 0.3325\n", + " eval_samples_per_second = 3080.064\n", + " eval_steps_per_second = 12.031\n", + "********************\n", + "Launch training for day 180 are:\n", + "********************\n", + "\n", + "UserWarning: User is advised to repartition the parquet file before training so npartitions>=global_size. Cudf or pandas can be used for repartitioning eg. pdf.to_parquet('file.parquet',row_group_size=N_ROWS/NPARTITIONS) for pandas or gdf.to_parquet('file.parquet',row_group_size_rows=N_ROWS/NPARTITIONS) for cudf so that npartitions=nr_rows/row_group_size. Also ensure npartitions is divisible by number of GPUs to be used (eg. 2 or 4 partitions, if 2 GPUs will be used).\n", + "********************\n", + "Launch training for day 180 are:\n", + "********************\n", + "\n", + "UserWarning: User is advised to repartition the parquet file before training so npartitions>=global_size. Cudf or pandas can be used for repartitioning eg. 
pdf.to_parquet('file.parquet',row_group_size=N_ROWS/NPARTITIONS) for pandas or gdf.to_parquet('file.parquet',row_group_size_rows=N_ROWS/NPARTITIONS) for cudf so that npartitions=nr_rows/row_group_size. Also ensure npartitions is divisible by number of GPUs to be used (eg. 2 or 4 partitions, if 2 GPUs will be used).\n", + "***** Running training *****\n", + " Num examples = 8192\n", + " Num Epochs = 10\n", + " Instantaneous batch size per device = 256\n", + " Total train batch size (w. parallel, distributed & accumulation) = 512\n", + " Gradient Accumulation steps = 1\n", + " Total optimization steps = 320\n", + "{'loss': 6.6858, 'learning_rate': 0.0001875, 'epoch': 6.25} \n", + "100%|████████████████████████████████████████▊| 319/320 [00:14<00:00, 24.21it/s]\n", + "\n", + "Training completed. Do not forget to share your model on huggingface.co/models =)\n", + "\n", + "\n", + "finished\n", + "{'train_runtime': 14.4363, 'train_samples_per_second': 0.693, 'train_steps_per_second': 22.166, 'train_loss': 6.4950298309326175, 'epoch': 10.0}\n", + "100%|█████████████████████████████████████████| 320/320 [00:14<00:00, 22.17it/s]\n", + "finished\n", + "UserWarning: User is advised to repartition the parquet file before training so npartitions>=global_size. Cudf or pandas can be used for repartitioning eg. pdf.to_parquet('file.parquet',row_group_size=N_ROWS/NPARTITIONS) for pandas or gdf.to_parquet('file.parquet',row_group_size_rows=N_ROWS/NPARTITIONS) for cudf so that npartitions=nr_rows/row_group_size. Also ensure npartitions is divisible by number of GPUs to be used (eg. 2 or 4 partitions, if 2 GPUs will be used).\n", + "UserWarning: User is advised to repartition the parquet file before training so npartitions>=global_size. Cudf or pandas can be used for repartitioning eg. pdf.to_parquet('file.parquet',row_group_size=N_ROWS/NPARTITIONS) for pandas or gdf.to_parquet('file.parquet',row_group_size_rows=N_ROWS/NPARTITIONS) for cudf so that npartitions=nr_rows/row_group_size. Also ensure npartitions is divisible by number of GPUs to be used (eg. 
2 or 4 partitions, if 2 GPUs will be used).\n", + "100%|█████████████████████████████████████████████| 4/4 [00:00<00:00, 36.98it/s]********************\n", + "Eval results for day 181 are:\t\n", + "\n", + "********************\n", + "\n", + " eval_/loss = 5.795247554779053\n", + " eval_/next-item/avg_precision_at_10 = 0.12972043454647064\n", + " eval_/next-item/avg_precision_at_20 = 0.1380162239074707\n", + " eval_/next-item/ndcg_at_10 = 0.17062446475028992\n", + " eval_/next-item/ndcg_at_20 = 0.200949028134346\n", + " eval_/next-item/recall_at_10 = 0.3046875\n", + " eval_/next-item/recall_at_20 = 0.4248046875\n", + " eval_runtime = 0.2612\n", + " eval_samples_per_second = 1960.129\n", + " eval_steps_per_second = 7.657\n", + "100%|█████████████████████████████████████████████| 4/4 [00:00<00:00, 34.63it/s]\n", + "********************\n", + "Eval results for day 181 are:\t\n", + "\n", + "********************\n", + "\n", + " eval_/loss = 5.795247554779053\n", + " eval_/next-item/avg_precision_at_10 = 0.12972043454647064\n", + " eval_/next-item/avg_precision_at_20 = 0.1380162239074707\n", + " eval_/next-item/ndcg_at_10 = 0.17062446475028992\n", + " eval_/next-item/ndcg_at_20 = 0.200949028134346\n", + " eval_/next-item/recall_at_10 = 0.3046875\n", + " eval_/next-item/recall_at_20 = 0.4248046875\n", + " eval_runtime = 0.259\n", + " eval_samples_per_second = 1976.599\n", + " eval_steps_per_second = 7.721\n", + "Total training time: 66.7806625366211\n", + "Total training time: 66.76544284820557\n" + ] + } + ], "source": [ - "! torchrun --nproc_per_node 2 pyt_trainer.py --path \"/workspace/data/preproc_sessions_by_day\" --learning-rate 0.0005" + "import os\n", + "OUTPUT_DIR = os.environ.get(\"OUTPUT_DIR\", \"/workspace/data/preproc_sessions_by_day\")\n", + "LR = float(os.environ.get(\"LEARNING_RATE\", \"0.0005\"))\n", + "BATCH_SIZE_TRAIN = int(os.environ.get(\"BATCH_SIZE_TRAIN\", \"256\"))\n", + "BATCH_SIZE_VALID = int(os.environ.get(\"BATCH_SIZE_VALID\", \"128\"))\n", + "!python -m torch.distributed.run --nproc_per_node 2 {TRAINER_FILE} --path {OUTPUT_DIR} --learning-rate {LR} --per-device-train-batch-size {BATCH_SIZE_TRAIN} --per-device-eval-batch-size {BATCH_SIZE_VALID}" ] }, { diff --git a/tests/integration/notebooks/test_end_to_end_session_based.py b/tests/integration/notebooks/test_end_to_end_session_based.py new file mode 100644 index 0000000000..3dc27c6dc2 --- /dev/null +++ b/tests/integration/notebooks/test_end_to_end_session_based.py @@ -0,0 +1,85 @@ +import os + +import pytest +from merlin.core.dispatch import HAS_GPU +from testbook import testbook + +from tests.conftest import REPO_ROOT + +pytest.importorskip("transformers") + +# flake8: noqa + + +@pytest.mark.notebook +@pytest.mark.skipif(not HAS_GPU, reason="No GPU available") +def test_func(tmp_path): + with testbook( + REPO_ROOT / "examples" / "end-to-end-session-based" / "01-ETL-with-NVTabular.ipynb", + execute=False, + ) as tb1: + dirname = f"{tmp_path}/data" + os.mkdir(dirname) + tb1.inject( + f""" + import os + os.environ["DATA_FOLDER"] = f"{dirname}" + os.environ["USE_SYNTHETIC"] = "True" + os.environ["START_DATE"] = "2014/4/1" + os.environ["END_DATE"] = "2014/4/5" + os.environ["THRESHOLD_DAY_INDEX"] = "1" + """ + ) + tb1.execute() + assert os.path.isdir(f"{dirname}/processed_nvt") + assert os.path.isdir(f"{dirname}/preproc_sessions_by_day") + assert os.path.isdir(f"{dirname}/workflow_etl") + + with testbook( + REPO_ROOT + / "examples" + / "end-to-end-session-based" + / "02-End-to-end-session-based-with-Yoochoose-PyT.ipynb", + 
timeout=720, + execute=False, + ) as tb2: + dirname = f"{tmp_path}/data" + tb2.inject( + f""" + import os + os.environ["INPUT_DATA_DIR"] = f"{dirname}" + os.environ["OUTPUT_DIR"] = f"{dirname}/preproc_sessions_by_day" + os.environ["START_TIME_INDEX"] = "1" + os.environ["END_TIME_INDEX"] = "3" + os.environ["BATCH_SIZE_TRAIN"] = "64" + os.environ["BATCH_SIZE_VALID"] = "32" + """ + ) + NUM_OF_CELLS = len(tb2.cells) + tb2.execute_cell(list(range(0, NUM_OF_CELLS - 20))) + assert os.path.isdir(f"{dirname}/models") + assert os.listdir(f"{dirname}/models") + + with testbook( + REPO_ROOT + / "examples" + / "end-to-end-session-based" + / "03-Session-based-Yoochoose-multigpu-training-PyT.ipynb", + timeout=720, + execute=False, + ) as tb3: + dirname = f"{tmp_path}/data" + tb3.inject( + f""" + import os + os.environ["INPUT_DATA_DIR"] = f"{dirname}" + os.environ["OUTPUT_DIR"] = f"{dirname}/preproc_sessions_by_day" + os.environ["START_TIME_INDEX"] = "1" + os.environ["END_TIME_INDEX"] = "4" + os.environ["LEARNING_RATE"] = "0.0005" + os.environ["BATCH_SIZE_TRAIN"] = "64" + os.environ["BATCH_SIZE_VALID"] = "32" + """ + ) + tb3.execute() + assert os.path.isfile(f"{dirname}/eval_metrics.txt")
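
A note on running the new integration test: it is marked with `@pytest.mark.notebook`, it is skipped when no GPU is available, and it relies on the environment variables injected above to shrink the workload. The snippet below is only a sketch of a local invocation from the repository root; the exact CI command is not part of this patch, and the `-m notebook` filter simply matches the marker used in the test file above.

```python
# Hypothetical local run of the new notebook integration test via pytest's Python API.
# Assumes a GPU and the Merlin/Transformers4Rec stack are available in the environment.
import sys

import pytest

exit_code = pytest.main(
    [
        "-m", "notebook",  # select tests marked with @pytest.mark.notebook
        "tests/integration/notebooks/test_end_to_end_session_based.py",
        "-x",   # stop at the first failing notebook
        "-rA",  # print a summary of all executed tests
    ]
)
sys.exit(exit_code)
```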
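Similarly, the repartitioning UserWarning that repeats throughout the training logs above can be addressed ahead of time by rewriting each day's parquet file with enough row groups. The sketch below shows one way to do that with pandas, under assumed paths; `NUM_GPUS` and the example day directory are placeholders, and `row_group_size` is the pyarrow keyword the warning itself suggests.

```python
# Sketch of the manual repartitioning recommended by the UserWarning in the logs above:
# rewrite a day's parquet file so it has at least as many row groups as GPUs
# (and ideally a multiple of the GPU count). Paths and NUM_GPUS are illustrative only.
import pandas as pd

NUM_GPUS = 2
path = "/workspace/data/preproc_sessions_by_day/178/train.parquet"  # hypothetical example day

df = pd.read_parquet(path)
rows_per_group = max(1, len(df) // NUM_GPUS)  # e.g. 2 row groups for 2 GPUs
df.to_parquet(path, row_group_size=rows_per_group)  # row_group_size is passed through to pyarrow

# Relatedly, to keep the global batch size constant on 2 GPUs, the per-device batch sizes
# passed to pyt_trainer.py would be halved (e.g. 256 -> 128 for training, 128 -> 64 for eval).
```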