-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathjoin.py
72 lines (61 loc) · 1.64 KB
/
join.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
import pyspark.sql.functions as F
from pyspark.sql import SparkSession
from pyspark import SparkConf
def _attach_text(pairs, texts, key, text_col):
    """Left-join ``texts.content`` onto ``pairs`` where ``pairs[key] == texts.doi``.

    The lookup side is renamed before the join (``doi`` -> ``key``,
    ``content`` -> ``text_col``). The join can then use one shared key column
    instead of a qualified ``texts.doi`` reference. This keeps repeated joins
    against the same ``texts`` DataFrame unambiguous.

    Args:
        pairs: DataFrame containing a ``key`` column.
        texts: DataFrame with ``doi`` and ``content`` columns.
        key: Name of the DOI column in ``pairs`` to match on.
        text_col: Name to give the joined ``content`` column.

    Returns:
        ``pairs`` with ``text_col`` appended. Rows without a match keep a null
        ``text_col`` (left join). Note that a using-join places ``key`` first.
    """
    lookup = texts.select(
        F.col("doi").alias(key),
        F.col("content").alias(text_col),
    )
    return pairs.join(lookup, on=key, how="left")


def process(sc, pairs, texts):
    """Attach the text of both DOIs of every pair.

    Args:
        sc: A ``SparkSession``, despite the name. ``.read`` is the
            SparkSession reader API, and the script entry point passes
            ``spark`` here.
        pairs: Parquet path or glob whose data has columns ``A`` and ``B``
            (the two DOIs of a pair).
        texts: Parquet path or glob whose data has columns ``doi`` and
            ``content``.

    Returns:
        A lazy DataFrame with columns ``doi_a, doi_b, text_a, text_b``.
        ``text_a``/``text_b`` are null when no text row matches. If ``texts``
        holds several rows for one DOI, the pair row is repeated once per
        match (plain join semantics).

    Note:
        The texts DataFrame is ``persist()``-ed because it is scanned twice.
        It cannot be unpersisted here, because the returned plan is lazy and
        still needs it.
    """
    pair_df = sc.read.parquet(pairs).select(
        F.col("A").alias("doi_a"),
        F.col("B").alias("doi_b"),
    )
    # Referenced by both lookups below; cache so the parquet is read once.
    text_df = sc.read.parquet(texts).persist()

    joined = _attach_text(pair_df, text_df, "doi_a", "text_a")
    joined = _attach_text(joined, text_df, "doi_b", "text_b")
    # Using-joins move the key column to the front; restore the output schema order.
    return joined.select("doi_a", "doi_b", "text_a", "text_b")
def run(sc, args):
    """Join pairs with their texts, print the query plan, and write parquet.

    Args:
        sc: SparkSession forwarded to ``process``.
        args: Positional arguments ``[pairs_path, texts_path, output_path]``.
            Only the first three entries are read. Extra entries are ignored,
            and fewer than three raises ``IndexError``.

    Side effects:
        Prints the physical plan via ``explain()``. Overwrites ``output_path``.
    """
    pairs_path, texts_path, out_path = args[0], args[1], args[2]

    joined = process(sc, pairs_path, texts_path)
    joined.explain()

    writer = joined.write.mode("overwrite")
    writer.parquet(out_path)
def batch_glob(root, batch, part_digits=5):
    """Return a glob selecting the Spark part files under ``root`` for one batch.

    A batch is the set of part files whose zero-padded part number ends in the
    digit string ``batch``. For example, batch ``"00"`` selects
    ``part-00000-*``, ``part-00100-*`` and so on.

    The part-number prefix is spelled as ``[0-9]`` classes rather than ``*``.
    A bare ``*`` could run past the part number into the job UUID that follows
    it. A file like ``part-00012-...-9e37-...`` would then also match batch
    ``"37"``, and because every file of one write shares that UUID, a single
    batch could pick up the entire dataset.

    Args:
        root: Directory (any Hadoop-style URI) holding the part files.
        batch: Decimal digit string; at most ``part_digits`` characters.
        part_digits: Width of the part number. Spark writes ``part-%05d``,
            hence the default of 5.

    Raises:
        ValueError: If ``batch`` is not all digits or is longer than
            ``part_digits``.
    """
    if not batch.isdigit() or len(batch) > part_digits:
        raise ValueError(
            f"batch must be 1-{part_digits} decimal digits, got {batch!r}"
        )
    prefix = "[0-9]" * (part_digits - len(batch))
    return f"{root}/part-{prefix}{batch}-*.parquet"


if __name__ == '__main__':
    # Input/output locations and the batch of pair files to process.
    PAIR_PATH = "stereo-paired.parquet"
    TEXT_PATH = "stereo-filtered.parquet/*"
    OUTPUT_PATH = "stereo-joined.parquet"
    BATCH = "00"

    spark = SparkSession.builder.config(conf=SparkConf()).getOrCreate()

    # Process only this batch's pair files; each batch writes its own subdirectory.
    pair_batched = batch_glob(PAIR_PATH, BATCH)
    output_batched = f"{OUTPUT_PATH}/{BATCH}"
    (
        process(spark, pair_batched, TEXT_PATH)
        .write
        .mode('overwrite')
        .parquet(output_batched)
    )