diff --git a/tests/test_tinydb.py b/tests/test_tinydb.py index bcd71a33..655027a5 100644 --- a/tests/test_tinydb.py +++ b/tests/test_tinydb.py @@ -306,6 +306,7 @@ def test_upsert(db: TinyDB): assert db.upsert({'int': 9, 'char': 'x'}, where('char') == 'x') == [4] assert db.count(where('int') == 9) == 1 + def test_upsert_by_id(db: TinyDB): assert len(db) == 3 @@ -313,6 +314,7 @@ def test_upsert_by_id(db: TinyDB): extant_doc = Document({'char': 'v'}, doc_id=1) assert db.upsert(extant_doc) == [1] doc = db.get(where('char') == 'v') + assert isinstance(doc, Document) assert doc is not None assert doc.doc_id == 1 assert len(db) == 3 @@ -321,6 +323,7 @@ def test_upsert_by_id(db: TinyDB): missing_doc = Document({'int': 5, 'char': 'w'}, doc_id=5) assert db.upsert(missing_doc) == [5] doc = db.get(where('char') == 'w') + assert isinstance(doc, Document) assert doc is not None assert doc.doc_id == 5 assert len(db) == 4 @@ -357,6 +360,7 @@ def test_search_no_results_cache(db: TinyDB): def test_get(db: TinyDB): item = db.get(where('char') == 'b') + assert isinstance(item, Document) assert item is not None assert item['char'] == 'b' @@ -366,10 +370,12 @@ def test_get_ids(db: TinyDB): assert db.get(doc_id=el.doc_id) == el assert db.get(doc_id=float('NaN')) is None # type: ignore + def test_get_multiple_ids(db: TinyDB): el = db.all() - assert db.get(doc_id=[x.doc_id for x in el]) == el - + assert db.get(doc_ids=[x.doc_id for x in el]) == el + + def test_get_invalid(db: TinyDB): with pytest.raises(RuntimeError): db.get() diff --git a/tinydb/table.py b/tinydb/table.py index ea325a2c..60a8798f 100644 --- a/tinydb/table.py +++ b/tinydb/table.py @@ -279,13 +279,13 @@ def search(self, cond: QueryLike) -> List[Document]: def get( self, cond: Optional[QueryLike] = None, - doc_id: Optional[Union[int , List]] = None, + doc_id: Optional[int] = None, doc_ids: Optional[List] = None - ) -> Optional[Union[Document , List[Document]]]: + ) -> Optional[Union[Document, List[Document]]]: """ Get exactly one document specified by a query or a document ID. - However if muliple document IDs are given then returns all docu- - ments in a list. + However, if multiple document IDs are given then returns all + documents in a list. Returns ``None`` if the document doesn't exist. @@ -294,8 +294,9 @@ def get( :param doc_ids: the document's IDs(multiple) :returns: the document(s) or ``None`` - """ + """ table = self._read_table() + if doc_id is not None: # Retrieve a document specified by its ID raw_doc = table.get(str(doc_id), None) @@ -305,17 +306,22 @@ def get( # Convert the raw data to the document class return self.document_class(raw_doc, doc_id) + elif doc_ids is not None: - # Filter the table by extracting out all those documents which have doc id - # specified in the doc_id list. - set_doc_id = set(doc_ids) # Since Doc Ids will be unique, making it a set to make sure constant lookup - raw_docs = dict(filter(lambda item: int(item[0]) in set_doc_id, table.items())) - if raw_docs is None: - return None + # Filter the table by extracting out all those documents which + # have doc id specified in the doc_id list. + + # Since document IDs will be unique, we make it a set to ensure + # constant time lookup + doc_ids_set = set(str(doc_id) for doc_id in doc_ids) + + # Now return the filtered documents in form of list + return [ + self.document_class(doc, self.document_id_class(doc_id)) + for doc_id, doc in table.items() + if doc_id in doc_ids_set + ] - ## Now return the filtered documents in form of list - return list(map(lambda x:self.document_class(raw_docs[str(x)] , int(x)) , raw_docs.keys())) - elif cond is not None: # Find a document specified by a query # The trailing underscore in doc_id_ is needed so MyPy