Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[#190] Implement custom_query() method in DAS API #198

Merged
merged 8 commits into from
Mar 22, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions CHANGELOG
Original file line number Diff line number Diff line change
Expand Up @@ -2,3 +2,4 @@
[#136] Implement methods in the DAS API to create indexes in the database
[#BUGFIX] Fix Mock in unit tests
[#90] OpenFaas is not serializing/deserializing query answers
[#190] Implement custom_query() method in DAS API
57 changes: 48 additions & 9 deletions hyperon_das/cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -147,7 +147,8 @@ class BaseLinksIterator(QueryAnswerIterator, ABC):
def __init__(self, source: ListIterator, **kwargs) -> None:
super().__init__(source)
if not self.source.is_empty():
self.backend = kwargs.get('backend')
if not hasattr(self, 'backend'):
self.backend = kwargs.get('backend')
self.chunk_size = kwargs.get('chunk_size', 1000)
self.cursor = kwargs.get('cursor', 0)
self.buffer_queue = deque()
Expand All @@ -161,19 +162,21 @@ def __init__(self, source: ListIterator, **kwargs) -> None:
def __next__(self) -> Any:
if self.iterator:
try:
self.get_next_value()
return self.get_next_value()
except StopIteration as e:
self.current_value = None
self.iterator = None
if self.fetch_data_thread.is_alive():
self.fetch_data_thread.join()
self.iterator = None
if self.cursor == 0 and len(self.buffer_queue) == 0:
self.current_value = None
raise e
self._refresh_iterator()
self.fetch_data_thread = Thread(target=self._fetch_data)
if self.cursor != 0:
self.fetch_data_thread.start()
return self.get()
return self.__next__()
raise StopIteration

def _fetch_data(self) -> None:
kwargs = self.get_fetch_data_kwargs()
Expand Down Expand Up @@ -201,7 +204,7 @@ def is_empty(self) -> bool:
return not self.iterator

@abstractmethod
def get_next_value(self) -> None:
def get_next_value(self) -> Any:
raise NotImplementedError("Subclasses must implement get_next_value method")

@abstractmethod
Expand All @@ -223,13 +226,14 @@ def __init__(self, source: ListIterator, **kwargs) -> None:
self.targets_document = kwargs.get('targets_document', False)
super().__init__(source, **kwargs)

def get_next_value(self) -> None:
def get_next_value(self) -> Any:
if not self.is_empty() and self.backend:
link_handle = next(self.iterator)
link_document = self.backend.get_atom(
link_handle, targets_document=self.targets_document
)
self.current_value = link_document
return self.current_value

def get_current_value(self) -> Any:
if self.backend:
Expand All @@ -255,7 +259,7 @@ def __init__(self, source: ListIterator, **kwargs) -> None:
self.returned_handles = set()
super().__init__(source, **kwargs)

def get_next_value(self) -> None:
def get_next_value(self) -> Any:
if not self.is_empty():
while True:
link_document = next(self.iterator)
Expand All @@ -267,6 +271,7 @@ def get_next_value(self) -> None:
self.returned_handles.add(handle)
self.current_value = link_document
break
return self.current_value

def get_current_value(self) -> Any:
try:
Expand Down Expand Up @@ -294,10 +299,11 @@ def __init__(self, source: ListIterator, **kwargs) -> None:
self.toplevel_only = kwargs.get('toplevel_only')
super().__init__(source, **kwargs)

def get_next_value(self) -> None:
def get_next_value(self) -> Any:
if not self.is_empty() and self.backend:
value = next(self.iterator)
self.current_value = self.backend._to_link_dict_list([value])[0]
return self.current_value

def get_current_value(self) -> Any:
if self.backend:
Expand Down Expand Up @@ -330,13 +336,14 @@ def __init__(self, source: ListIterator, **kwargs) -> None:
self.returned_handles = set()
super().__init__(source, **kwargs)

def get_next_value(self) -> None:
def get_next_value(self) -> Any:
if not self.is_empty():
value = next(self.iterator)
handle = value.get('handle')
if handle not in self.returned_handles:
self.returned_handles.add(handle)
self.current_value = value
return self.current_value

def get_current_value(self) -> Any:
try:
Expand All @@ -358,6 +365,38 @@ def get_fetch_data(self, **kwargs) -> tuple:
)


class CustomQuery(BaseLinksIterator):
def __init__(self, source: ListIterator, **kwargs) -> None:
self.index_id = kwargs.pop('index_id', None)
self.backend = kwargs.pop('backend', None)
self.is_remote = kwargs.pop('is_remote', False)
self.kwargs = kwargs
super().__init__(source, **kwargs)

def get_next_value(self) -> Any:
if not self.is_empty():
self.current_value = next(self.iterator)
return self.current_value

def get_current_value(self) -> Any:
try:
return self.source.get()
except StopIteration:
return None

def get_fetch_data_kwargs(self) -> Dict[str, Any]:
kwargs = self.kwargs
kwargs.update({'cursor': self.cursor, 'chunk_size': self.chunk_size})
return kwargs

def get_fetch_data(self, **kwargs) -> tuple:
if self.backend:
if self.is_remote:
return self.backend.custom_query(self.index_id, **kwargs)
else:
return self.backend.get_atoms_by_index(self.index_id, **kwargs)


class TraverseLinksIterator(QueryAnswerIterator):
def __init__(self, source: Union[LocalIncomingLinks, RemoteIncomingLinks], **kwargs) -> None:
super().__init__(source)
Expand Down
19 changes: 17 additions & 2 deletions hyperon_das/client.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,13 @@
import contextlib
import pickle
from typing import Any, Dict, List, Optional, Tuple, Union

from hyperon_das_atomdb import AtomDoesNotExist, LinkDoesNotExist, NodeDoesNotExist
from requests import exceptions, sessions

from hyperon_das.utils import serialize, deserialize
from hyperon_das.exceptions import ConnectionError, HTTPError, RequestError, TimeoutError
from hyperon_das.logger import logger
from hyperon_das.utils import deserialize, serialize


class FunctionsClient:
Expand Down Expand Up @@ -152,13 +153,27 @@ def get_incoming_links(
return None, [] if kwargs.get('cursor') is not None else []
return response

def create_field_index(self, atom_type: str, field: str, type: str = None):
def create_field_index(
self,
atom_type: str,
field: str,
type: Optional[str] = None,
composite_type: Optional[List[Any]] = None,
) -> str:
payload = {
'action': 'create_field_index',
'input': {
'atom_type': atom_type,
'field': field,
'type': type,
'composite_type': composite_type,
},
}
return self._send_request(payload)

def custom_query(self, index_id: str, **kwargs) -> List[Dict[str, Any]]:
payload = {
'action': 'custom_query',
'input': {'index_id': index_id, 'kwargs': kwargs},
}
return self._send_request(payload)
51 changes: 46 additions & 5 deletions hyperon_das/das.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import Any, Dict, List, Optional, Tuple, Union
from typing import Any, Dict, Iterator, List, Optional, Tuple, Union

from hyperon_das_atomdb import AtomDB, AtomDoesNotExist
from hyperon_das_atomdb.adapters import InMemoryDB, RedisMongoDB
Expand Down Expand Up @@ -364,6 +364,32 @@ def query(
"""
return self.query_engine.query(query, parameters)

def custom_query(self, index_id: str, **kwargs) -> Union[Iterator, List[Dict[str, Any]]]:
"""
Perform a custom query on the knowledge base using a custom index id and return an iterator.
If no_iterator is set to True, the method returns a list of dict containing detailed atom information
(But this way only works with Local Das RedisMongo).

Args:
index_id (str): the custom index id to be used in the query

Raises:
NotImplementedError: If the custom_query method is called for the Local DAS in Ram only

Returns:
Union[Iterator, List[Dict[str, Any]]]: An iterator or list of dict containing detailed atom information

Examples:
>>> das.custom_query(index_id='index_123456789', tag='DAS')
>>> das.custom_query(index_id='index_123456789', tag='DAS', no_iterator=True)
"""
if isinstance(self.query_engine, LocalQueryEngine) and isinstance(self.backend, InMemoryDB):
raise NotImplementedError(
"The custom_query method is not implemented for the Local DAS in Ram only"
)

return self.query_engine.custom_query(index_id, **kwargs)

def commit_changes(self):
"""This method applies changes made locally to the remote server"""
self.query_engine.commit()
Expand Down Expand Up @@ -514,23 +540,38 @@ def get_traversal_cursor(self, handle: str, **kwargs) -> TraverseEngine:
except AtomDoesNotExist:
raise GetTraversalCursorException(message="Cannot start Traversal. Atom does not exist")

def create_field_index(self, atom_type: str, field: str, type: str = None) -> str:
def create_field_index(
self,
atom_type: str,
field: str,
type: Optional[str] = None,
composite_type: Optional[List[Any]] = None,
) -> str:
"""Create an index for a field for all Atoms of the specified type

Args:
atom_type (str): Type of the Atom. Could be 'link' or 'node'
field (str): field where the index will be created
type (str, optional): Only atoms of the passed type will be indexed. Defaults to None.
composite_type (List[Any], optional): Only Atoms type of the passed composite type will be indexed. Defaults to None.

Raises:
ValueError: If the type of the Atom is not a string
ValueError: If the type of the Atom is not a string or if both type and composite_type are specified

Returns:
str: The index ID. This ID should be used to make queries that should use the newly created index.

Examples:
>>> index_id = das.create_field_index('link', 'tag', 'Expression')
>>> index_id = das.create_field_index('link', 'tag', type='Expression')
>>> index_id = das.create_field_index('link', 'tag', composite_type=['Expression', 'Symbol', 'Symbol', ['Expression', 'Symbol', 'Symbol', 'Symbol']])
"""

if type and composite_type:
raise ValueError("Only one of 'type' or 'composite_type' can be specified")

if type and not isinstance(type, str):
raise ValueError('The type of the Atom must be a string')
return self.query_engine.create_field_index(atom_type, field, type=type)

return self.query_engine.create_field_index(
atom_type, field, type=type, composite_type=composite_type
)
Loading
Loading