Add logging to the checkpointers #8

Open · wants to merge 1 commit into main
17 changes: 15 additions & 2 deletions spark_matcher/table_checkpointer.py
@@ -2,11 +2,15 @@
 # Stan Leisink
 # Frits Hermans
 
+from typing import Optional
 import abc
 import os
+import logging
 
 from pyspark.sql import SparkSession, DataFrame
 
+from spark_matcher.utils import create_logger
+
 
 class TableCheckpointer(abc.ABC):
     """
@@ -36,8 +40,12 @@ class HiveCheckpointer(TableCheckpointer):
         database: a name of a database or storage system where the tables can be saved
         checkpoint_prefix: a prefix of the name that can be used to save tables
     """
-    def __init__(self, spark_session: SparkSession, database: str, checkpoint_prefix: str = "checkpoint_spark_matcher"):
+    def __init__(self, spark_session: SparkSession, database: str, checkpoint_prefix: str = "checkpoint_spark_matcher",
+                 logger: Optional[logging.Logger] = None):
         super().__init__(spark_session, database, checkpoint_prefix)
+        self.logger = logger
+        if not self.logger:
+            self.logger = create_logger()
 
     def checkpoint_table(self, sdf: DataFrame, checkpoint_name: str):
         """
@@ -53,6 +61,7 @@ def checkpoint_table(self, sdf: DataFrame, checkpoint_name: str):
         the same, unchanged, spark dataframe as the input dataframe. With the only difference that the
         dataframe is now read from disk as a checkpoint.
         """
+        self.logger.debug(f'caching {self.checkpoint_prefix}_{checkpoint_name}')
         sdf.write.saveAsTable(f"{self.database}.{self.checkpoint_prefix}_{checkpoint_name}",
                               mode="overwrite")
         sdf = self.spark_session.table(f"{self.database}.{self.checkpoint_prefix}_{checkpoint_name}")
@@ -67,8 +76,11 @@ class ParquetCheckPointer(TableCheckpointer):
         checkpoint_prefix: a prefix of the name that can be used to save tables
     """
     def __init__(self, spark_session: SparkSession, checkpoint_dir: str,
-                 checkpoint_prefix: str = "checkpoint_spark_matcher"):
+                 checkpoint_prefix: str = "checkpoint_spark_matcher", logger: Optional[logging.Logger] = None):
         super().__init__(spark_session, checkpoint_dir, checkpoint_prefix)
+        self.logger = logger
+        if not self.logger:
+            self.logger = create_logger()
 
     def checkpoint_table(self, sdf: DataFrame, checkpoint_name: str):
         """
@@ -85,6 +97,7 @@ def checkpoint_table(self, sdf: DataFrame, checkpoint_name: str):
         the same, unchanged, spark dataframe as the input dataframe. With the only difference that the
         dataframe is now read from disk as a checkpoint.
         """
+        self.logger.debug(f'caching {self.checkpoint_prefix}_{checkpoint_name}')
         file_name = os.path.join(f'{self.database}', f'{self.checkpoint_prefix}_{checkpoint_name}')
         sdf.write.parquet(file_name, mode='overwrite')
         return self.spark_session.read.parquet(file_name)
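For anyone trying this branch out, a minimal usage sketch of the new optional `logger` argument. The SparkSession setup, checkpoint directory, and logger name below are placeholders for illustration, not part of this PR:

import logging

from pyspark.sql import SparkSession
from spark_matcher.table_checkpointer import ParquetCheckPointer

spark = SparkSession.builder.getOrCreate()

# Hypothetical application logger; any logging.Logger instance works here.
app_logger = logging.getLogger('my_app')
app_logger.setLevel(logging.DEBUG)
app_logger.addHandler(logging.StreamHandler())

# Pass it explicitly, or omit `logger` to fall back to create_logger().
checkpointer = ParquetCheckPointer(spark, '/tmp/spark_matcher_checkpoints',
                                   logger=app_logger)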
23 changes: 22 additions & 1 deletion spark_matcher/utils.py
@@ -3,14 +3,35 @@
 # Frits Hermans
 
 from typing import List
 
+import logging
 import numpy as np
 import pandas as pd
 from pyspark.ml.feature import StopWordsRemover
 from pyspark.sql import DataFrame
 from pyspark.sql import functions as F
 
 
+def create_logger() -> logging.Logger:
+    """
+    Creates a logger
+    """
+    logger = logging.getLogger('debug_spark_matcher')
+    logger.setLevel(logging.DEBUG)
[Contributor review comment on the line above: I would propose to make the default logging level INFO; then all the cached tables are not shown by default.]
+    if not (logger.hasHandlers() and len(logger.handlers)):
+        ch = logging.StreamHandler()
+        logger.addHandler(ch)
+    else:
+        ch = logger.handlers[0]
+
+    ch.setLevel(logging.DEBUG)
[Contributor review comment on the line above: same here, i.e. the handler level should also default to INFO.]
+    formatter = logging.Formatter('%(asctime)s - %(levelname)s - %(message)s')
+    ch.setFormatter(formatter)
+
+    return logger


 def get_most_frequent_words(sdf: DataFrame, col_name: str, min_df=2, top_n_words=1_000) -> pd.DataFrame:
     """
     Count word frequencies in a Spark dataframe `sdf` column named `col_name` and return a Pandas dataframe containing
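To make the review suggestion above concrete, here is a sketch of create_logger with an INFO default. This is an assumed variant, not what this commit implements, and the `level` parameter is hypothetical; callers who want the per-table 'caching ...' debug messages would pass logging.DEBUG explicitly:

import logging

def create_logger(level: int = logging.INFO) -> logging.Logger:
    """
    Creates a logger; defaults to INFO so the debug caching messages stay hidden.
    """
    logger = logging.getLogger('debug_spark_matcher')
    logger.setLevel(level)
    if not (logger.hasHandlers() and len(logger.handlers)):
        ch = logging.StreamHandler()
        logger.addHandler(ch)
    else:
        ch = logger.handlers[0]
    # Keep the handler level in sync with the logger level, per the review note.
    ch.setLevel(level)
    formatter = logging.Formatter('%(asctime)s - %(levelname)s - %(message)s')
    ch.setFormatter(formatter)
    return logger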