Skip to content

Commit

Permalink
Create new IntersectCorrespondingFields operator (#1531)
Browse files Browse the repository at this point in the history
* filter for entity types intro

* code update

* optimization

* improv

* remove filter and add its functionality to intersect

* typo

* Created a new type of intersect operator

Signed-off-by: Yoav Katz <[email protected]>

* Updated documentation

Signed-off-by: Yoav Katz <[email protected]>

---------

Signed-off-by: Yoav Katz <[email protected]>
Co-authored-by: Przemysław Klocek <[email protected]>
Co-authored-by: Yoav Katz <[email protected]>
Co-authored-by: Yoav Katz <[email protected]>
  • Loading branch information
4 people authored Jan 30, 2025
1 parent bc65c5c commit 121c268
Show file tree
Hide file tree
Showing 2 changed files with 237 additions and 1 deletion.
109 changes: 108 additions & 1 deletion src/unitxt/operators.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,7 @@
from .random_utils import new_random_generator
from .settings_utils import get_settings
from .stream import DynamicStream, Stream
from .text_utils import nested_tuple_to_string
from .text_utils import nested_tuple_to_string, to_pretty_string
from .type_utils import isoftype
from .utils import (
LRUCache,
Expand Down Expand Up @@ -1477,6 +1477,113 @@ def process_value(self, value: Any) -> Any:
return [e for e in value if e in self.allowed_values]


class IntersectCorrespondingFields(InstanceOperator):
    """Intersects the value of a list field with a given list, and removes the corresponding elements from other list fields.

    For example:

    Assume the instances contain a 'label' field and a 'position' field holding
    the position in the text of each label (same length, same order).

        IntersectCorrespondingFields(field="label",
                                     allowed_values=["b", "f"],
                                     corresponding_fields_to_intersect=["position"])

    would keep only the "b" and "f" values in the 'label' field and
    their respective values in the 'position' field.
    (All other fields are not affected.)

    Given this input:

        [
            {"label": ["a", "b"], "position": [0, 1], "other": "not"},
            {"label": ["a", "c", "d"], "position": [0, 1, 2], "other": "relevant"},
            {"label": ["a", "b", "f"], "position": [0, 1, 2], "other": "field"},
        ]

    the output would be:

        [
            {"label": ["b"], "position": [1], "other": "not"},
            {"label": [], "position": [], "other": "relevant"},
            {"label": ["b", "f"], "position": [1, 2], "other": "field"},
        ]

    Args:
        field: the field to be intersected (must contain list values).
        allowed_values (list): list of values to keep.
        corresponding_fields_to_intersect (list): additional list fields from which
            values are removed, based on the indices of the values removed from 'field'.

    Raises:
        ValueError: if 'allowed_values' is not a list, if 'field' or one of the
            corresponding fields is missing from an instance, if any of them is
            not a list, or if a corresponding field differs in length from 'field'.
    """

    field: str
    allowed_values: List[str]
    corresponding_fields_to_intersect: List[str]

    def verify(self):
        super().verify()

        if not isinstance(self.allowed_values, list):
            # Report the attribute that actually exists; the previous code read
            # 'allowed_field_values' and failed with AttributeError instead.
            raise ValueError(
                f"The allowed_values is not a type list but '{type(self.allowed_values)}'"
            )

    def process(
        self, instance: Dict[str, Any], stream_name: Optional[str] = None
    ) -> Dict[str, Any]:
        if self.field not in instance:
            raise ValueError(
                f"Field '{self.field}' is not in provided instance.\n"
                + to_pretty_string(instance)
            )

        for corresponding_field in self.corresponding_fields_to_intersect:
            if corresponding_field not in instance:
                raise ValueError(
                    f"Field '{corresponding_field}' is not in provided instance.\n"
                    + to_pretty_string(instance)
                )

        if not isinstance(instance[self.field], list):
            raise ValueError(
                f"Value of field '{self.field}' is not a list, so IntersectCorrespondingFields can not intersect with allowed values. Field value:\n"
                + to_pretty_string(instance, keys=[self.field])
            )

        num_values_in_field = len(instance[self.field])

        # Build the lookup set once per instance rather than once per element.
        allowed = set(self.allowed_values)
        indices_to_keep = [
            i for i, value in enumerate(instance[self.field]) if value in allowed
        ]

        # No "nothing to remove" shortcut here: every corresponding field is
        # always validated, so malformed instances are reported consistently
        # regardless of which labels they happen to contain.
        fields_to_filter = set(self.corresponding_fields_to_intersect)
        fields_to_filter.add(self.field)

        result_instance = {}
        for field_name, field_value in instance.items():
            if field_name not in fields_to_filter:
                result_instance[field_name] = field_value
                continue

            if not isinstance(field_value, list):
                raise ValueError(
                    f"Value of field '{field_name}' is not a list, IntersectCorrespondingFields can not intersect with allowed values."
                )
            if len(field_value) != num_values_in_field:
                raise ValueError(
                    f"Number of elements in field '{field_name}' is not the same as the number of elements in field '{self.field}' so the IntersectCorrespondingFields can not remove corresponding values.\n"
                    + to_pretty_string(instance, keys=[self.field, field_name])
                )
            # Lengths match, so every kept index is valid for this field.
            result_instance[field_name] = [field_value[i] for i in indices_to_keep]
        return result_instance


class RemoveValues(FieldOperator):
"""Removes elements in a field, which must be a list, using a given list of unallowed.
Expand Down
129 changes: 129 additions & 0 deletions tests/library/test_operators.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
FromIterables,
IndexOf,
Intersect,
IntersectCorrespondingFields,
IterableSource,
JoinStr,
LengthBalancer,
Expand Down Expand Up @@ -658,6 +659,134 @@ def test_intersect(self):
tester=self,
)

def test_intersect_corresponding_fields(self):
# Covers IntersectCorrespondingFields: one successful filtering case followed by
# four error cases (missing primary field, missing corresponding field,
# non-list primary field, and mismatched list lengths).
inputs = [
{"label": ["a", "b"], "position": [0, 1], "other": "not"},
{"label": ["a", "c", "d"], "position": [0, 1, 2], "other": "relevant"},
{"label": ["a", "b", "f"], "position": [0, 1, 2], "other": "field"},
]

# Only "b" and "f" remain in 'label'; 'position' keeps the entries at the same
# indices; 'other' is passed through unchanged.
targets = [
{"label": ["b"], "position": [1], "other": "not"},
{"label": [], "position": [], "other": "relevant"},
{"label": ["b", "f"], "position": [1, 2], "other": "field"},
]

check_operator(
operator=IntersectCorrespondingFields(
field="label",
allowed_values=["b", "f"],
corresponding_fields_to_intersect=["position"],
),
inputs=inputs,
targets=targets,
tester=self,
)

# Error case 1: the primary 'field' ("acme_field") does not exist in the instance.
# The first expected text appears to be the stream-level wrapper message and the
# second the operator's own error followed by a dump of the instance.
exception_texts = [
"Error processing instance '0' from stream 'test' in IntersectCorrespondingFields due to the exception above.",
"""Field 'acme_field' is not in provided instance.
label (list):
[0] (str):
a
[1] (str):
b
position (list):
[0] (int):
0
[1] (int):
1
other (str):
not
""",
]
check_operator_exception(
operator=IntersectCorrespondingFields(
field="acme_field",
allowed_values=["b", "f"],
corresponding_fields_to_intersect=["other"],
),
inputs=inputs,
exception_texts=exception_texts,
tester=self,
)

# Error case 2: the primary field exists, but a corresponding field
# ("acme_field") does not. The expected texts are the same as in case 1.
exception_texts = [
"Error processing instance '0' from stream 'test' in IntersectCorrespondingFields due to the exception above.",
"""Field 'acme_field' is not in provided instance.
label (list):
[0] (str):
a
[1] (str):
b
position (list):
[0] (int):
0
[1] (int):
1
other (str):
not
""",
]
check_operator_exception(
operator=IntersectCorrespondingFields(
field="label",
allowed_values=["b", "f"],
corresponding_fields_to_intersect=["acme_field"],
),
inputs=inputs,
exception_texts=exception_texts,
tester=self,
)

# Error case 3: the primary field ('other') holds a string rather than a list.
exception_texts = [
"Error processing instance '0' from stream 'test' in IntersectCorrespondingFields due to the exception above.",
"Value of field 'other' is not a list, so IntersectCorrespondingFields can not intersect with allowed values. Field value:\nother (str):\n not\n",
]
check_operator_exception(
operator=IntersectCorrespondingFields(
field="other",
allowed_values=["b", "f"],
corresponding_fields_to_intersect=["other"],
),
inputs=inputs,
exception_texts=exception_texts,
tester=self,
)

# Error case 4: in the first instance 'position' has three elements while
# 'label' has only two, so corresponding indices cannot be matched.
inputs = [
{"label": ["a", "b"], "position": [0, 1, 2], "other": "not"},
{"label": ["a", "c", "d"], "position": [0, 1, 2], "other": "relevant"},
{"label": ["a", "b", "f"], "position": [0, 1, 2], "other": "field"},
]
exception_texts = [
"Error processing instance '0' from stream 'test' in IntersectCorrespondingFields due to the exception above.",
"""Number of elements in field 'position' is not the same as the number of elements in field 'label' so the IntersectCorrespondingFields can not remove corresponding values.
label (list):
[0] (str):
a
[1] (str):
b
position (list):
[0] (int):
0
[1] (int):
1
[2] (int):
2
""",
]
check_operator_exception(
operator=IntersectCorrespondingFields(
field="label",
allowed_values=["b", "f"],
corresponding_fields_to_intersect=["position"],
),
inputs=inputs,
exception_texts=exception_texts,
tester=self,
)

def test_remove_none(self):
inputs = [
{"references": [["none"], ["none"]]},
Expand Down

0 comments on commit 121c268

Please sign in to comment.