Skip to content

Commit

Permalink
adding unit test for end-to-end example (NVIDIA-Merlin#669)
Browse files Browse the repository at this point in the history
* adding unit test for multi-gpu example

* added test for notebook 03

* fixed formatting

* update

* update

* Update 01-ETL-with-NVTabular.ipynb

day of week is between 0 and 6; it must be scaled with a max value of 6 to produce correct values from the 0-1 range. If we do col+1 and scale with 7, then a section of the 0-2pi range (for Sine purposes) will not be represented.

* Update 01-ETL-with-NVTabular.ipynb

Reversed the previous edit for weekday scaling. It is correct that it should be scaled between 0-7, because day 0 (unused/nonapplicable after +1 added) overlaps with day 7 for Sine purposes. Monday should scale to 1/7, Sunday should scale to 7/7 to achieve even distribution of days along the sinus curve.

* reduce num_rows

* Update test_end_to_end_session_based.py

* Update 01-ETL-with-NVTabular.ipynb

* updated test script and notebook

* updated file

* removed nb3 test due to multi-gpu freezing issue

* revised notebooks, added back nb3 test

* fixed test file with black

* update test py

* update test py

* Use `python -m torch.distributed.run` instead of `torchrun`

The `torchrun` script installed in the system is a python script with
a shebang line starting with `#!/usr/bin/python3`

This picks up the wrong version of python when running in a virtualenv
like our tox test environment.

If instead this were `#!/usr/bin/env python3` it would work ok in a
tox environment to call `torchrun`.

However, until either the pytorch package is updated for this to
happen or we update our CI image for this to take place. Running the
python command directly is more reliable.

---------

Co-authored-by: rnyak <[email protected]>
Co-authored-by: edknv <[email protected]>
Co-authored-by: rnyak <[email protected]>
Co-authored-by: Oliver Holworthy <[email protected]>
  • Loading branch information
5 people authored Jul 6, 2023
1 parent f3c4d2a commit 348c963
Show file tree
Hide file tree
Showing 4 changed files with 631 additions and 329 deletions.
160 changes: 123 additions & 37 deletions examples/end-to-end-session-based/01-ETL-with-NVTabular.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,8 @@
"\n",
"The dataset is available on [Kaggle](https://www.kaggle.com/chadgostopp/recsys-challenge-2015). You need to download it and copy to the `DATA_FOLDER` path. Note that we are only using the `yoochoose-clicks.dat` file.\n",
"\n",
"Alternatively, you can generate a synthetic dataset with the same columns and dtypes as the `YOOCHOOSE` dataset and a default date range of 5 days. If the environment variable `USE_SYNTHETIC` is set to `True`, the code below will execute the function `generate_synthetic_data` and the rest of the notebook will run on a synthetic dataset.\n",
"\n",
"First, let's start by importing several libraries:"
]
},
Expand All @@ -75,17 +77,18 @@
"output_type": "stream",
"text": [
"/usr/local/lib/python3.8/dist-packages/merlin/dtypes/mappings/tf.py:52: UserWarning: Tensorflow dtype mappings did not load successfully due to an error: No module named 'tensorflow'\n",
" warn(f\"Tensorflow dtype mappings did not load successfully due to an error: {exc.msg}\")\n",
"/usr/local/lib/python3.8/dist-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n",
" from .autonotebook import tqdm as notebook_tqdm\n"
" warn(f\"Tensorflow dtype mappings did not load successfully due to an error: {exc.msg}\")\n"
]
}
],
"source": [
"import os\n",
"import glob\n",
"import numpy as np\n",
"import pandas as pd\n",
"import gc\n",
"import calendar\n",
"import datetime\n",
"\n",
"import cudf\n",
"import cupy\n",
Expand Down Expand Up @@ -128,12 +131,14 @@
"metadata": {},
"outputs": [],
"source": [
"DATA_FOLDER = \"/workspace/data/\"\n",
"DATA_FOLDER = os.environ.get(\"DATA_FOLDER\", \"/workspace/data\")\n",
"FILENAME_PATTERN = 'yoochoose-clicks.dat'\n",
"DATA_PATH = os.path.join(DATA_FOLDER, FILENAME_PATTERN)\n",
"\n",
"OUTPUT_FOLDER = \"./yoochoose_transformed\"\n",
"OVERWRITE = False"
"OVERWRITE = False\n",
"\n",
"USE_SYNTHETIC = os.environ.get(\"USE_SYNTHETIC\", False)"
]
},
{
Expand All @@ -144,16 +149,89 @@
"## Load and clean raw data"
]
},
{
"cell_type": "markdown",
"id": "3fba8546-668c-4743-960e-ea2aef99ef24",
"metadata": {},
"source": [
"Execute the cell below if you would like to work with synthetic data. Otherwise you can skip and continue with the next cell."
]
},
{
"cell_type": "code",
"execution_count": 5,
"id": "07d14289-c783-45f0-86e8-e5c1001bfd76",
"metadata": {},
"outputs": [],
"source": [
"def generate_synthetic_data(\n",
" start_date: datetime.date, end_date: datetime.date, rows_per_day: int = 10000\n",
") -> pd.DataFrame:\n",
" assert end_date > start_date, \"end_date must be later than start_date\"\n",
"\n",
" number_of_days = (end_date - start_date).days\n",
" total_number_of_rows = number_of_days * rows_per_day\n",
"\n",
" # Generate a long-tail distribution of item interactions. This simulates that some items are\n",
" # more popular than others.\n",
" long_tailed_item_distribution = np.clip(\n",
" np.random.lognormal(3.0, 1.0, total_number_of_rows).astype(np.int64), 1, 50000\n",
" )\n",
"\n",
" # generate random item interaction features\n",
" df = pd.DataFrame(\n",
" {\n",
" \"session_id\": np.random.randint(70000, 80000, total_number_of_rows),\n",
" \"item_id\": long_tailed_item_distribution,\n",
" },\n",
" )\n",
"\n",
" # generate category mapping for each item-id\n",
" df[\"category\"] = pd.cut(df[\"item_id\"], bins=334, labels=np.arange(1, 335)).astype(\n",
" np.int64\n",
" )\n",
"\n",
" max_session_length = 60 * 60 # 1 hour\n",
"\n",
" def add_timestamp_to_session(session: pd.DataFrame):\n",
" random_start_date_and_time = calendar.timegm(\n",
" (\n",
" start_date\n",
" # Add day offset from start_date\n",
" + datetime.timedelta(days=np.random.randint(0, number_of_days))\n",
" # Add time offset within the random day\n",
" + datetime.timedelta(seconds=np.random.randint(0, 86_400))\n",
" ).timetuple()\n",
" )\n",
" session[\"timestamp\"] = random_start_date_and_time + np.clip(\n",
" np.random.lognormal(3.0, 1.0, len(session)).astype(np.int64),\n",
" 0,\n",
" max_session_length,\n",
" )\n",
" return session\n",
"\n",
" df = df.groupby(\"session_id\").apply(add_timestamp_to_session).reset_index()\n",
"\n",
" return df"
]
},
{
"cell_type": "code",
"execution_count": 6,
"id": "f35dff52",
"metadata": {},
"outputs": [],
"source": [
"interactions_df = cudf.read_csv(DATA_PATH, sep=',', \n",
" names=['session_id','timestamp', 'item_id', 'category'], \n",
" dtype=['int', 'datetime64[s]', 'int', 'int'])"
"if USE_SYNTHETIC:\n",
" START_DATE = os.environ.get(\"START_DATE\", \"2014/4/1\")\n",
" END_DATE = os.environ.get(\"END_DATE\", \"2014/4/5\")\n",
" interactions_df = generate_synthetic_data(datetime.datetime.strptime(START_DATE, '%Y/%m/%d'),\n",
" datetime.datetime.strptime(END_DATE, '%Y/%m/%d'))\n",
" interactions_df = cudf.from_pandas(interactions_df)\n",
"else:\n",
" interactions_df = cudf.read_csv(DATA_PATH, sep=',', \n",
" names=['session_id','timestamp', 'item_id', 'category'], \n",
" dtype=['int', 'datetime64[s]', 'int', 'int'])"
]
},
{
Expand All @@ -166,7 +244,7 @@
},
{
"cell_type": "code",
"execution_count": 6,
"execution_count": 7,
"id": "22c2df72",
"metadata": {},
"outputs": [
Expand All @@ -181,13 +259,16 @@
],
"source": [
"print(\"Count with in-session repeated interactions: {}\".format(len(interactions_df)))\n",
"\n",
"# Sorts the dataframe by session and timestamp, to remove consecutive repetitions\n",
"interactions_df.timestamp = interactions_df.timestamp.astype(int)\n",
"interactions_df = interactions_df.sort_values(['session_id', 'timestamp'])\n",
"past_ids = interactions_df['item_id'].shift(1).fillna()\n",
"session_past_ids = interactions_df['session_id'].shift(1).fillna()\n",
"\n",
"# Keeping only no consecutive repeated in session interactions\n",
"interactions_df = interactions_df[~((interactions_df['session_id'] == session_past_ids) & (interactions_df['item_id'] == past_ids))]\n",
"\n",
"print(\"Count after removed in-session repeated interactions: {}\".format(len(interactions_df)))"
]
},
Expand All @@ -201,7 +282,7 @@
},
{
"cell_type": "code",
"execution_count": 7,
"execution_count": 8,
"id": "66a1bd13",
"metadata": {},
"outputs": [
Expand Down Expand Up @@ -234,17 +315,19 @@
},
{
"cell_type": "code",
"execution_count": 8,
"execution_count": 9,
"id": "a0f908a1",
"metadata": {},
"outputs": [],
"source": [
"if os.path.isdir(DATA_FOLDER) == False:\n",
" os.mkdir(DATA_FOLDER)\n",
"interactions_merged_df.to_parquet(os.path.join(DATA_FOLDER, 'interactions_merged_df.parquet'))"
]
},
{
"cell_type": "code",
"execution_count": 9,
"execution_count": 10,
"id": "909f87c5-bff5-48c8-b714-cc556a4bc64d",
"metadata": {
"tags": []
Expand All @@ -265,17 +348,17 @@
},
{
"cell_type": "code",
"execution_count": 10,
"execution_count": 11,
"id": "04a3b5b7",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"0"
"517"
]
},
"execution_count": 10,
"execution_count": 11,
"metadata": {},
"output_type": "execute_result"
}
Expand Down Expand Up @@ -330,7 +413,7 @@
},
{
"cell_type": "code",
"execution_count": 11,
"execution_count": 13,
"id": "86f80068",
"metadata": {},
"outputs": [],
Expand Down Expand Up @@ -425,7 +508,7 @@
},
{
"cell_type": "code",
"execution_count": 12,
"execution_count": 14,
"id": "10b5c96c",
"metadata": {},
"outputs": [],
Expand All @@ -447,7 +530,6 @@
"# Truncate sequence features to first interacted 20 items \n",
"SESSIONS_MAX_LENGTH = 20 \n",
"\n",
"\n",
"item_feat = groupby_features['item_id-list'] >> nvt.ops.TagAsItemID()\n",
"cont_feats = groupby_features['et_dayofweek_sin-list', 'product_recency_days_log_norm-list'] >> nvt.ops.AddMetadata(tags=[Tags.CONTINUOUS])\n",
"\n",
Expand Down Expand Up @@ -491,7 +573,7 @@
},
{
"cell_type": "code",
"execution_count": 13,
"execution_count": 15,
"id": "45803886",
"metadata": {},
"outputs": [],
Expand All @@ -513,7 +595,7 @@
},
{
"cell_type": "code",
"execution_count": 14,
"execution_count": 16,
"id": "4c10efb5-89c5-4458-a634-475eb459a47c",
"metadata": {
"tags": []
Expand Down Expand Up @@ -600,7 +682,7 @@
" <tr>\n",
" <th>2</th>\n",
" <td>item_id-list</td>\n",
" <td>(Tags.CATEGORICAL, Tags.ITEM, Tags.ID, Tags.LIST)</td>\n",
" <td>(Tags.CATEGORICAL, Tags.ID, Tags.LIST, Tags.ITEM)</td>\n",
" <td>DType(name='int64', element_type=&lt;ElementType....</td>\n",
" <td>True</td>\n",
" <td>True</td>\n",
Expand Down Expand Up @@ -697,10 +779,10 @@
"</div>"
],
"text/plain": [
"[{'name': 'session_id', 'tags': {<Tags.CATEGORICAL: 'categorical'>}, 'properties': {}, 'dtype': DType(name='int64', element_type=<ElementType.Int: 'int'>, element_size=64, element_unit=None, signed=True, shape=Shape(dims=(Dimension(min=0, max=None),))), 'is_list': False, 'is_ragged': False}, {'name': 'item_id-count', 'tags': {<Tags.CATEGORICAL: 'categorical'>}, 'properties': {'num_buckets': None, 'freq_threshold': 0, 'max_size': 0, 'cat_path': './/categories/unique.item_id.parquet', 'domain': {'min': 0, 'max': 52741, 'name': 'item_id'}, 'embedding_sizes': {'cardinality': 52742, 'dimension': 512}}, 'dtype': DType(name='int32', element_type=<ElementType.Int: 'int'>, element_size=32, element_unit=None, signed=True, shape=Shape(dims=(Dimension(min=0, max=None),))), 'is_list': False, 'is_ragged': False}, {'name': 'item_id-list', 'tags': {<Tags.CATEGORICAL: 'categorical'>, <Tags.ITEM: 'item'>, <Tags.ID: 'id'>, <Tags.LIST: 'list'>}, 'properties': {'num_buckets': None, 'freq_threshold': 0, 'max_size': 0, 'cat_path': './/categories/unique.item_id.parquet', 'domain': {'min': 0, 'max': 52741, 'name': 'item_id'}, 'embedding_sizes': {'cardinality': 52742, 'dimension': 512}, 'value_count': {'min': 0, 'max': 20}}, 'dtype': DType(name='int64', element_type=<ElementType.Int: 'int'>, element_size=64, element_unit=None, signed=True, shape=Shape(dims=(Dimension(min=0, max=None), Dimension(min=0, max=20)))), 'is_list': True, 'is_ragged': True}, {'name': 'et_dayofweek_sin-list', 'tags': {<Tags.CONTINUOUS: 'continuous'>, <Tags.LIST: 'list'>}, 'properties': {'value_count': {'min': 0, 'max': 20}}, 'dtype': DType(name='float64', element_type=<ElementType.Float: 'float'>, element_size=64, element_unit=None, signed=True, shape=Shape(dims=(Dimension(min=0, max=None), Dimension(min=0, max=20)))), 'is_list': True, 'is_ragged': True}, {'name': 'product_recency_days_log_norm-list', 'tags': {<Tags.CONTINUOUS: 'continuous'>, <Tags.LIST: 'list'>}, 'properties': {'value_count': {'min': 0, 'max': 20}}, 'dtype': DType(name='float32', element_type=<ElementType.Float: 'float'>, element_size=32, element_unit=None, signed=True, shape=Shape(dims=(Dimension(min=0, max=None), Dimension(min=0, max=20)))), 'is_list': True, 'is_ragged': True}, {'name': 'category-list', 'tags': {<Tags.CATEGORICAL: 'categorical'>, <Tags.LIST: 'list'>}, 'properties': {'num_buckets': None, 'freq_threshold': 0, 'max_size': 0, 'cat_path': './/categories/unique.category.parquet', 'domain': {'min': 0, 'max': 336, 'name': 'category'}, 'embedding_sizes': {'cardinality': 337, 'dimension': 42}, 'value_count': {'min': 0, 'max': 20}}, 'dtype': DType(name='int64', element_type=<ElementType.Int: 'int'>, element_size=64, element_unit=None, signed=True, shape=Shape(dims=(Dimension(min=0, max=None), Dimension(min=0, max=20)))), 'is_list': True, 'is_ragged': True}, {'name': 'day_index', 'tags': {<Tags.CATEGORICAL: 'categorical'>}, 'properties': {}, 'dtype': DType(name='int64', element_type=<ElementType.Int: 'int'>, element_size=64, element_unit=None, signed=True, shape=Shape(dims=(Dimension(min=0, max=None),))), 'is_list': False, 'is_ragged': False}]"
"[{'name': 'session_id', 'tags': {<Tags.CATEGORICAL: 'categorical'>}, 'properties': {}, 'dtype': DType(name='int64', element_type=<ElementType.Int: 'int'>, element_size=64, element_unit=None, signed=True, shape=Shape(dims=(Dimension(min=0, max=None),))), 'is_list': False, 'is_ragged': False}, {'name': 'item_id-count', 'tags': {<Tags.CATEGORICAL: 'categorical'>}, 'properties': {'num_buckets': None, 'freq_threshold': 0, 'max_size': 0, 'cat_path': './/categories/unique.item_id.parquet', 'domain': {'min': 0, 'max': 52741, 'name': 'item_id'}, 'embedding_sizes': {'cardinality': 52742, 'dimension': 512}}, 'dtype': DType(name='int32', element_type=<ElementType.Int: 'int'>, element_size=32, element_unit=None, signed=True, shape=Shape(dims=(Dimension(min=0, max=None),))), 'is_list': False, 'is_ragged': False}, {'name': 'item_id-list', 'tags': {<Tags.CATEGORICAL: 'categorical'>, <Tags.ID: 'id'>, <Tags.LIST: 'list'>, <Tags.ITEM: 'item'>}, 'properties': {'num_buckets': None, 'freq_threshold': 0, 'max_size': 0, 'cat_path': './/categories/unique.item_id.parquet', 'domain': {'min': 0, 'max': 52741, 'name': 'item_id'}, 'embedding_sizes': {'cardinality': 52742, 'dimension': 512}, 'value_count': {'min': 0, 'max': 20}}, 'dtype': DType(name='int64', element_type=<ElementType.Int: 'int'>, element_size=64, element_unit=None, signed=True, shape=Shape(dims=(Dimension(min=0, max=None), Dimension(min=0, max=20)))), 'is_list': True, 'is_ragged': True}, {'name': 'et_dayofweek_sin-list', 'tags': {<Tags.CONTINUOUS: 'continuous'>, <Tags.LIST: 'list'>}, 'properties': {'value_count': {'min': 0, 'max': 20}}, 'dtype': DType(name='float64', element_type=<ElementType.Float: 'float'>, element_size=64, element_unit=None, signed=True, shape=Shape(dims=(Dimension(min=0, max=None), Dimension(min=0, max=20)))), 'is_list': True, 'is_ragged': True}, {'name': 'product_recency_days_log_norm-list', 'tags': {<Tags.CONTINUOUS: 'continuous'>, <Tags.LIST: 'list'>}, 'properties': {'value_count': {'min': 0, 'max': 20}}, 'dtype': DType(name='float32', element_type=<ElementType.Float: 'float'>, element_size=32, element_unit=None, signed=True, shape=Shape(dims=(Dimension(min=0, max=None), Dimension(min=0, max=20)))), 'is_list': True, 'is_ragged': True}, {'name': 'category-list', 'tags': {<Tags.CATEGORICAL: 'categorical'>, <Tags.LIST: 'list'>}, 'properties': {'num_buckets': None, 'freq_threshold': 0, 'max_size': 0, 'cat_path': './/categories/unique.category.parquet', 'domain': {'min': 0, 'max': 336, 'name': 'category'}, 'embedding_sizes': {'cardinality': 337, 'dimension': 42}, 'value_count': {'min': 0, 'max': 20}}, 'dtype': DType(name='int64', element_type=<ElementType.Int: 'int'>, element_size=64, element_unit=None, signed=True, shape=Shape(dims=(Dimension(min=0, max=None), Dimension(min=0, max=20)))), 'is_list': True, 'is_ragged': True}, {'name': 'day_index', 'tags': {<Tags.CATEGORICAL: 'categorical'>}, 'properties': {}, 'dtype': DType(name='int64', element_type=<ElementType.Int: 'int'>, element_size=64, element_unit=None, signed=True, shape=Shape(dims=(Dimension(min=0, max=None),))), 'is_list': False, 'is_ragged': False}]"
]
},
"execution_count": 14,
"execution_count": 16,
"metadata": {},
"output_type": "execute_result"
}
Expand All @@ -719,7 +801,7 @@
},
{
"cell_type": "code",
"execution_count": 15,
"execution_count": 17,
"id": "2d035a88-2146-4b9a-96fd-dd42be86e2a1",
"metadata": {},
"outputs": [],
Expand Down Expand Up @@ -747,19 +829,23 @@
},
{
"cell_type": "code",
"execution_count": 16,
"execution_count": 18,
"id": "2b4f5b73-459c-4356-87c8-9afb974cc77d",
"metadata": {},
"outputs": [],
"source": [
"# read in the processed train dataset\n",
"sessions_gdf = cudf.read_parquet(os.path.join(DATA_FOLDER, \"processed_nvt/part_0.parquet\"))\n",
"sessions_gdf = sessions_gdf[sessions_gdf.day_index>=178]"
"if USE_SYNTHETIC:\n",
" THRESHOLD_DAY_INDEX = int(os.environ.get(\"THRESHOLD_DAY_INDEX\", '1'))\n",
" sessions_gdf = sessions_gdf[sessions_gdf.day_index>=THRESHOLD_DAY_INDEX]\n",
"else:\n",
" sessions_gdf = sessions_gdf[sessions_gdf.day_index>=178]"
]
},
{
"cell_type": "code",
"execution_count": 17,
"execution_count": 19,
"id": "e18d9c63",
"metadata": {},
"outputs": [
Expand All @@ -783,13 +869,13 @@
"6606149 [-0.7818309228245777, -0.7818309228245777] \n",
"\n",
" product_recency_days_log_norm-list \\\n",
"6606147 [1.5241553, 1.5238751, 1.5239341, 1.5241631, 1... \n",
"6606148 [-0.5330064, 1.521494] \n",
"6606149 [1.5338266, 1.5355074] \n",
"6606147 [1.5241561, 1.523876, 1.523935, 1.5241641, 1.5... \n",
"6606148 [-0.533007, 1.521495] \n",
"6606149 [1.5338274, 1.5355083] \n",
"\n",
" category-list day_index \n",
"6606147 [4, 4, 4, 4, 4, 4, 4, 4, 4, 3, 4, 4] 178 \n",
"6606148 [3, 3] 178 \n",
"6606147 [4, 4, 4, 4, 4, 4, 4, 4, 4, 1, 4, 4] 178 \n",
"6606148 [1, 3] 178 \n",
"6606149 [8, 8] 180 \n"
]
}
Expand All @@ -800,15 +886,15 @@
},
{
"cell_type": "code",
"execution_count": 18,
"execution_count": 20,
"id": "5175aeaf",
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"Creating time-based splits: 100%|██████████| 5/5 [00:02<00:00, 2.37it/s]\n"
"Creating time-based splits: 100%|██████████| 5/5 [00:02<00:00, 2.24it/s]\n"
]
}
],
Expand All @@ -823,17 +909,17 @@
},
{
"cell_type": "code",
"execution_count": 19,
"execution_count": 21,
"id": "3bd1bad9",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"583"
"748"
]
},
"execution_count": 19,
"execution_count": 21,
"metadata": {},
"output_type": "execute_result"
}
Expand Down
Loading

0 comments on commit 348c963

Please sign in to comment.