Basic XML support (mostly copy pasted from text) #7250

Merged · 1 commit · Oct 24, 2024
12 changes: 11 additions & 1 deletion docs/source/nlp_load.mdx
@@ -33,4 +33,14 @@ To load remote text files via HTTP, pass the URLs instead:

```py
>>> dataset = load_dataset("text", data_files="https://huggingface.co/datasets/lhoestq/test/resolve/main/some_text.txt")
```

To load XML data, use the "xml" loader, which is equivalent to the "text" loader with `sample_by="document"`:

```py
>>> from datasets import load_dataset
>>> dataset = load_dataset("xml", data_files={"train": ["my_xml_1.xml", "my_xml_2.xml"], "test": "my_xml_file.xml"})

# Load from a directory
>>> dataset = load_dataset("xml", data_dir="path/to/xml/dataset")
```
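
Since the "xml" loader is the "text" loader with a fixed sampling mode, the call above can also be written explicitly with "text" — a minimal equivalent sketch:

```py
>>> # equivalent call: read each file as a single document, one row per file
>>> dataset = load_dataset("text", data_files="my_xml_file.xml", sample_by="document")
```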
6 changes: 6 additions & 0 deletions docs/source/package_reference/loading_methods.mdx
@@ -49,6 +49,12 @@ load_dataset("csv", data_dir="path/to/data/dir", sep="\t")

[[autodoc]] datasets.packaged_modules.json.Json

### XML

[[autodoc]] datasets.packaged_modules.xml.XmlConfig

[[autodoc]] datasets.packaged_modules.xml.Xml

### Parquet

[[autodoc]] datasets.packaged_modules.parquet.ParquetConfig
3 changes: 3 additions & 0 deletions src/datasets/packaged_modules/__init__.py
@@ -15,6 +15,7 @@
from .sql import sql
from .text import text
from .webdataset import webdataset
from .xml import xml


def _hash_python_lines(lines: List[str]) -> str:
@@ -41,6 +42,7 @@ def _hash_python_lines(lines: List[str]) -> str:
"imagefolder": (imagefolder.__name__, _hash_python_lines(inspect.getsource(imagefolder).splitlines())),
"audiofolder": (audiofolder.__name__, _hash_python_lines(inspect.getsource(audiofolder).splitlines())),
"webdataset": (webdataset.__name__, _hash_python_lines(inspect.getsource(webdataset).splitlines())),
"xml": (xml.__name__, _hash_python_lines(inspect.getsource(xml).splitlines())),
}

# get importable module names and hash for caching
@@ -69,6 +71,7 @@ def _hash_python_lines(lines: List[str]) -> str:
".arrow": ("arrow", {}),
".txt": ("text", {}),
".tar": ("webdataset", {}),
".xml": ("xml", {}),
}
_EXTENSION_TO_MODULE.update({ext: ("imagefolder", {}) for ext in imagefolder.ImageFolder.EXTENSIONS})
_EXTENSION_TO_MODULE.update({ext.upper(): ("imagefolder", {}) for ext in imagefolder.ImageFolder.EXTENSIONS})
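The `.xml` entry in `_EXTENSION_TO_MODULE` above is what lets `load_dataset` infer the builder from file extensions when no module name is given — a small sketch, assuming a local directory that contains only `.xml` files:

```py
>>> from datasets import load_dataset
>>> # no explicit "xml" argument: the builder is inferred from the .xml extension
>>> dataset = load_dataset("path/to/xml/dataset")
```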
src/datasets/packaged_modules/xml/__init__.py
Empty file.
68 changes: 68 additions & 0 deletions src/datasets/packaged_modules/xml/xml.py
@@ -0,0 +1,68 @@
import itertools
from dataclasses import dataclass
from typing import Optional

import pyarrow as pa

import datasets
from datasets.features.features import require_storage_cast
from datasets.table import table_cast


logger = datasets.utils.logging.get_logger(__name__)


@dataclass
class XmlConfig(datasets.BuilderConfig):
"""BuilderConfig for xml files."""

features: Optional[datasets.Features] = None
encoding: str = "utf-8"
encoding_errors: Optional[str] = None


class Xml(datasets.ArrowBasedBuilder):
BUILDER_CONFIG_CLASS = XmlConfig

def _info(self):
return datasets.DatasetInfo(features=self.config.features)

def _split_generators(self, dl_manager):
"""The `data_files` kwarg in load_dataset() can be a str, List[str], Dict[str,str], or Dict[str,List[str]].

If str or List[str], then the dataset returns only the 'train' split.
If dict, then keys should be from the `datasets.Split` enum.
"""
if not self.config.data_files:
raise ValueError(f"At least one data file must be specified, but got data_files={self.config.data_files}")
dl_manager.download_config.extract_on_the_fly = True
data_files = dl_manager.download_and_extract(self.config.data_files)
splits = []
for split_name, files in data_files.items():
if isinstance(files, str):
files = [files]
files = [dl_manager.iter_files(file) for file in files]
splits.append(datasets.SplitGenerator(name=split_name, gen_kwargs={"files": files}))
return splits

def _cast_table(self, pa_table: pa.Table) -> pa.Table:
if self.config.features is not None:
schema = self.config.features.arrow_schema
if all(not require_storage_cast(feature) for feature in self.config.features.values()):
# cheaper cast
pa_table = pa_table.cast(schema)
else:
# more expensive cast; allows str <-> int/float or str to Audio for example
pa_table = table_cast(pa_table, schema)
return pa_table
else:
return pa_table.cast(pa.schema({"xml": pa.string()}))

def _generate_tables(self, files):
pa_table_names = list(self.config.features) if self.config.features is not None else ["xml"]
for file_idx, file in enumerate(itertools.chain.from_iterable(files)):
# open in text mode, by default translates universal newlines ("\n", "\r\n" and "\r") into "\n"
with open(file, encoding=self.config.encoding, errors=self.config.encoding_errors) as f:
xml = f.read()
pa_table = pa.Table.from_arrays([pa.array([xml])], names=pa_table_names)
yield file_idx, self._cast_table(pa_table)
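
To illustrate the default behavior: with no `features` specified, `_cast_table` falls back to a single string column named `"xml"`, so each file becomes one row holding its full contents — a minimal sketch, reusing a placeholder filename from the docs above:

```py
>>> from datasets import load_dataset
>>> ds = load_dataset("xml", data_files="my_xml_1.xml", split="train")
>>> ds.features  # default schema: one string column
{'xml': Value(dtype='string', id=None)}
>>> len(ds)  # one row per file, since each file is read whole
1
```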