Skip to content

Commit

Permalink
feat(avro): integrate writer.
Browse files Browse the repository at this point in the history
DO NOT MERGE!
  • Loading branch information
Uinelj committed Apr 7, 2022
1 parent f3e8d0e commit 14fb3f7
Show file tree
Hide file tree
Showing 7 changed files with 85 additions and 12 deletions.
6 changes: 6 additions & 0 deletions src/cli.rs
Original file line number Diff line number Diff line change
Expand Up @@ -170,4 +170,10 @@ pub struct Pipeline {
help = "Optional path to blocklist."
)]
pub blocklist: Option<PathBuf>,
#[structopt(
long = "format",
help = "corpus output format. ('avro' or 'jsonl')",
default_value = "jsonl"
)]
pub format: String,
}
28 changes: 27 additions & 1 deletion src/io/langfiles.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,8 @@ Each language (provided by [crate::lang::LANG]) is given a [self::Writer] wrappe
!*/
use std::{
collections::HashMap,
fs::File,
io::Write,
path::Path,
str::FromStr,
sync::{Arc, Mutex},
Expand All @@ -16,7 +18,7 @@ use crate::io::writer::Writer;
use crate::lang::LANG;
use crate::{error, lang::Lang};

use super::writer::{WriterDoc, WriterTrait};
use super::writer::{DocWriterAvro, WriterDoc, WriterTrait};
/// Holds references to [Writer].
pub struct LangFiles {
writers: HashMap<&'static str, Arc<Mutex<Writer>>>,
Expand All @@ -26,6 +28,30 @@ pub struct LangFilesDoc {
writers: HashMap<Lang, Arc<Mutex<WriterDoc>>>,
}

/// Holds one Avro writer per language.
///
/// Each [Lang] maps to a [DocWriterAvro] backed by a [File], wrapped in
/// `Arc<Mutex<…>>` so writers can be shared and locked across threads
/// (mirrors [LangFiles] / [LangFilesDoc] above).
pub struct LangFilesAvro<'a> {
    // one `<lang>.avro` writer per entry of LANG; populated by `new`
    writers: HashMap<Lang, Arc<Mutex<DocWriterAvro<'a, File>>>>,
}

impl<'a> LangFilesAvro<'a> {
pub fn new(dst: &Path) -> Result<Self, error::Error> {
let mut writers = HashMap::with_capacity(LANG.len());
let mut w;
for lang in LANG.iter() {
let mut dst = dst.to_path_buf();
dst.push(lang);
dst.set_extension("avro");
w = DocWriterAvro::from_file(&dst)?;
let lang = Lang::from_str(lang)?;
writers.insert(lang, Arc::new(Mutex::new(w)));
}

Ok(Self { writers })
}

pub fn writers(&'a self) -> &HashMap<Lang, Arc<Mutex<DocWriterAvro<File>>>> {
&self.writers
}
}
impl LangFiles {
/// Create a new LangFiles. `part_size_bytes` sets an indication of the maximum size
/// by part.
Expand Down
1 change: 1 addition & 0 deletions src/io/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,5 +9,6 @@ mod langfiles;
pub mod reader;
pub mod writer;
pub use langfiles::LangFiles;
pub use langfiles::LangFilesAvro;
pub use langfiles::LangFilesDoc;
pub use writer::Writer;
1 change: 1 addition & 0 deletions src/io/writer/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ use metawriter::MetaWriter;
use textwriter::TextWriter;
pub use writer::Writer;
pub use writer_doc::WriterDoc;
pub(crate) use writer_doc_avro::DocWriterAvro;
pub use writertrait::WriterTrait;

// pub enum WriterKind {
Expand Down
41 changes: 36 additions & 5 deletions src/io/writer/writer_doc_avro.rs
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
//! Avro version of [writer_doc::DocWriter].
use std::io::Write;
use std::{fmt::Debug, fs::File, io::Write, path::Path};

use avro_rs::{Codec, Schema, Writer};
use log::{debug, error};
use serde::Serialize;
use structopt::lazy_static::lazy_static;

Expand Down Expand Up @@ -93,7 +94,7 @@ let document_schema = r#"
.clone()
};
}
struct DocWriterAvro<'a, T>
pub struct DocWriterAvro<'a, T>
where
T: Write,
{
Expand All @@ -114,7 +115,16 @@ where
}
}

pub fn append_ser<S: Serialize>(&mut self, val: &S) -> Result<usize, Error> {
/// Serializes every item of `vals` into the underlying Avro writer.
///
/// Returns the byte count reported by [`avro_rs::Writer::extend_ser`];
/// any avro error is converted into the crate error type.
pub fn extend_ser<I, U: Serialize>(&mut self, vals: I) -> Result<usize, Error>
where
    I: IntoIterator<Item = U>,
{
    let nb_bytes = self.writer.extend_ser(vals)?;
    Ok(nb_bytes)
}
/// Serializes a single value into the underlying Avro writer.
///
/// Returns the byte count reported by [`avro_rs::Writer::append_ser`]
/// (may be 0 if the value is still buffered); any avro error is
/// converted into the crate error type.
pub fn append_ser<S>(&mut self, val: &S) -> Result<usize, Error>
where
    S: Serialize,
{
    match self.writer.append_ser(val) {
        Ok(nb_bytes) => Ok(nb_bytes),
        Err(e) => Err(e.into()),
    }
}

Expand All @@ -127,6 +137,17 @@ where
}
}

impl<'a> DocWriterAvro<'a, File> {
    /// Creates a [DocWriterAvro] backed by a new file at `path`, using
    /// the module [SCHEMA] and Snappy compression.
    ///
    /// Fails if `path` already exists. Uses `create_new` so the
    /// existence check and the creation are one atomic operation —
    /// the original `path.exists()` + `File::create` pair had a
    /// check-then-act race (another process could create the file
    /// between the two calls).
    pub fn from_file(path: &Path) -> Result<Self, Error> {
        let fh = File::options()
            .write(true)
            .create_new(true)
            .open(path)
            .map_err(|e| {
                // keep the original diagnostic for the already-exists case
                if e.kind() == std::io::ErrorKind::AlreadyExists {
                    error!("{:?} already exists!", path);
                }
                e
            })?;
        Ok(DocWriterAvro::new(&SCHEMA, fh, Codec::Snappy))
    }
}
impl<'a, T> WriterTrait for DocWriterAvro<'a, T>
where
T: Write,
Expand All @@ -145,7 +166,8 @@ where
}

fn write(&mut self, vals: Vec<Self::Item>) -> Result<(), crate::error::Error> {
    // Serialize the whole batch in one extend_ser call instead of
    // per-item appends; the stray `todo!()` left above the real body
    // (unreachable leftover) is removed.
    self.extend_ser(&vals)?;
    Ok(())
}

fn write_single(&mut self, val: &Self::Item) -> Result<(), crate::error::Error> {
Expand Down Expand Up @@ -191,8 +213,16 @@ mod test {
content.push_str(&i.to_string());
let mut headers = HashMap::new();
headers.insert(WarcHeader::ContentType, "conversion".as_bytes().to_owned());
headers.insert(
WarcHeader::Unknown("warc-identified-language".to_string()),
"fr".as_bytes().to_owned(),
);
let default_id = Identification::new(Lang::En, 1.0);
let metadata = Metadata::new(&default_id, &vec![Some(default_id.clone()); 3]);
let mut metadata = Metadata::new(
&default_id,
&vec![Some(default_id.clone()), Some(default_id.clone()), None],
);
metadata.set_annotation("adult".to_string());
let d = Document::new(content, headers, metadata);
documents.push(d);
}
Expand All @@ -212,6 +242,7 @@ mod test {
from_avro.push(deserialized);
}

println!("{from_avro:#?}");
//check equality
assert_eq!(documents, from_avro);
}
Expand Down
2 changes: 1 addition & 1 deletion src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@ async fn main() -> Result<(), error::Error> {
cli::Ungoliant::Pipeline(p) => {
let mut schema_filepath = p.dst.clone();
// let p = pipeline::OscarMetadata::new(p.src, p.dst, p.lid_path);
let p = pipelines::OscarDoc::new(p.src, p.dst, p.lid_path, p.blocklist);
let p = pipelines::OscarDoc::new(p.src, p.dst, p.lid_path, p.blocklist, p.format);
p.run()?;

schema_filepath.push("metadata_schema.json");
Expand Down
18 changes: 13 additions & 5 deletions src/pipelines/oscardoc/pipeline.rs
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ use ut1_blocklist::Blocklist;
use warc::BufferedBody;
use warc::{Record, WarcHeader};

use crate::io::LangFilesDoc;
use crate::io::{LangFilesAvro, LangFilesDoc};

const DOC_THRESHOLD: f32 = 0.6f32;
pub struct OscarDoc {
Expand All @@ -52,7 +52,13 @@ pub struct OscarDoc {
}

impl OscarDoc {
pub fn new(src: PathBuf, dst: PathBuf, lid_path: PathBuf, blocklist: Option<PathBuf>) -> Self {
pub fn new(
src: PathBuf,
dst: PathBuf,
lid_path: PathBuf,
blocklist: Option<PathBuf>,
format: String,
) -> Self {
if blocklist.is_none() {
warn!("No blocklist folder specified! No adult content tagging will be done.");
}
Expand Down Expand Up @@ -308,7 +314,7 @@ impl OscarDoc {

/// concurrently write documents
fn write_documents<'a>(
langfiles: &LangFilesDoc,
langfiles: &'a LangFilesAvro<'a>,
avrowriters: &'a RebuildWriters<'a, File>,
shard_id: usize,
documents: HashMap<Lang, Vec<(Document, Location)>>,
Expand All @@ -333,10 +339,11 @@ impl OscarDoc {
let sr = ShardResult::new(shard_id as i64, locations, metadata_cloned);

// write docs and rebuild files
writer_lock.write(docs)?;
writer_lock.extend_ser(docs)?;
avrowriter_lock.append_ser(sr)?;

//TODO: not sure that we need the flush
writer_lock.flush()?;
avrowriter_lock.flush()?;

Ok(())
Expand Down Expand Up @@ -385,7 +392,8 @@ impl Pipeline<()> for OscarDoc {
// ourselves.
let results = results.enumerate().par_bridge();

let langfiles = LangFilesDoc::new(&self.dst, None)?;
// let langfiles = LangFilesDoc::new(&self.dst, None)?;
let langfiles = LangFilesAvro::new(&self.dst)?;
let mut dst_rebuild = self.dst.clone();
dst_rebuild.push("rebuild");

Expand Down

0 comments on commit 14fb3f7

Please sign in to comment.