Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add entity similarity endpoint #261

Merged
merged 17 commits into from
Nov 30, 2023
Merged
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -134,5 +134,6 @@ scratch/
docs/_site
docker/edges.tsv.gz
docker/nodes.tsv.gz
docker/embeddings.tsv.gz
mira/dkg/resources/ncit.obo
docker/epi.sh
7 changes: 5 additions & 2 deletions docker/Dockerfile
Original file line number Diff line number Diff line change
Expand Up @@ -16,19 +16,22 @@ RUN apt-get update && \

ARG version=2023-10-19
ARG domain=climate
ARG embeddings_path=/sw/embeddings.tsv.gz
# The environment variables below are read by the application code
ENV MIRA_DOMAIN=${domain}
ENV EMBEDDINGS_PATH=${embeddings_path}

# Download graph content and ingest into neo4j
RUN wget -O /sw/nodes.tsv.gz https://askem-mira.s3.amazonaws.com/dkg/$domain/build/$version/nodes.tsv.gz && \
wget -O /sw/edges.tsv.gz https://askem-mira.s3.amazonaws.com/dkg/$domain/build/$version/edges.tsv.gz && \
wget -O $embeddings_path https://askem-mira.s3.amazonaws.com/dkg/$domain/build/$version/embeddings.tsv.gz && \
sed -i 's/#dbms.default_listen_address/dbms.default_listen_address/' /etc/neo4j/neo4j.conf && \
sed -i 's/#dbms.security.auth_enabled/dbms.security.auth_enabled/' /etc/neo4j/neo4j.conf && \
neo4j-admin import --delimiter='TAB' --skip-duplicate-nodes=true --skip-bad-relationships=true --nodes /sw/nodes.tsv.gz --relationships /sw/edges.tsv.gz

# Python packages
RUN python -m pip install --upgrade pip && \
python -m pip install git+https://github.com/indralab/mira.git@main#egg=mira[web,uvicorn,dkg-client] && \
python -m pip install git+https://github.com/gyorilab/mira.git@main#egg=mira[web,uvicorn,dkg-client] && \
python -m pip uninstall -y flask_bootstrap && \
python -m pip uninstall -y bootstrap_flask && \
python -m pip install bootstrap_flask && \
Expand All @@ -37,7 +40,7 @@ RUN python -m pip install --upgrade pip && \
python -m pip install --no-dependencies --ignore-requires-python sbmlmath

# Copy the example json for reconstructing the ode semantics
RUN wget -O /sw/sir_flux_span.json https://raw.githubusercontent.com/indralab/mira/main/tests/sir_flux_span.json
RUN wget -O /sw/sir_flux_span.json https://raw.githubusercontent.com/gyorilab/mira/main/tests/sir_flux_span.json

COPY startup.sh startup.sh
ENTRYPOINT ["/bin/bash", "/sw/startup.sh"]
5 changes: 4 additions & 1 deletion docker/Dockerfile.local
Original file line number Diff line number Diff line change
Expand Up @@ -13,18 +13,21 @@ RUN apt-get update && \
ln -s /usr/bin/python3 /usr/bin/python

ARG branch=main
ARG embeddings_path=/sw/embeddings.tsv.gz
ENV EMBEDDINGS_PATH=${embeddings_path}

# Add graph content
COPY nodes.tsv.gz /sw/nodes.tsv.gz
COPY edges.tsv.gz /sw/edges.tsv.gz
COPY embeddings.tsv.gz ${embeddings_path}

# Ingest graph content into neo4j
RUN sed -i 's/#dbms.default_listen_address/dbms.default_listen_address/' /etc/neo4j/neo4j.conf && \
sed -i 's/#dbms.security.auth_enabled/dbms.security.auth_enabled/' /etc/neo4j/neo4j.conf && \
neo4j-admin import --delimiter='TAB' --skip-duplicate-nodes=true --skip-bad-relationships=true --nodes /sw/nodes.tsv.gz --relationships /sw/edges.tsv.gz

# Python packages
RUN python -m pip install git+https://github.com/indralab/mira.git@$branch#egg=mira[web,uvicorn,dkg-client] && \
RUN python -m pip install git+https://github.com/gyorilab/mira.git@$branch#egg=mira[web,uvicorn,dkg-client] && \
python -m pip uninstall -y flask_bootstrap && \
python -m pip uninstall -y bootstrap_flask && \
python -m pip install bootstrap_flask
Expand Down
11 changes: 5 additions & 6 deletions docker/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -22,24 +22,23 @@ this folder and use:
export DOMAIN=epi
cp ~/.data/mira/$DOMAIN/nodes.tsv.gz nodes.tsv.gz
cp ~/.data/mira/$DOMAIN/edges.tsv.gz edges.tsv.gz
cp ~/.data/mira/$DOMAIN/embeddings.tsv.gz embeddings.tsv.gz

# Build docker
docker build --file Dockerfile.local --tag mira_$DOMAIN_dkg:latest .
docker build --file Dockerfile.local --tag mira:latest .
```

Once the build has finished, you can run the container locally as:

```shell
# Option 1: run in the background
docker run --detach -p 8771:8771 -e MIRA_NEO4J_URL=bolt://0.0.0.0:7687 --name mira_$DOMAIN_dkg mira_$DOMAIN_dkg:latest
docker run --detach -p 8771:8771 -p 7687:7687 -e MIRA_NEO4J_URL=bolt://0.0.0.0:7687 --name mira mira:latest

# Option 2: run ephemerally
docker run -p 8771:8771 -e MIRA_NEO4J_URL=bolt://0.0.0.0:7687 mira_$DOMAIN_dkg:latest
docker run -p 8771:8771 -p 7687:7687 -e MIRA_NEO4J_URL=bolt://0.0.0.0:7687 mira:latest
```

This exposes a REST API at `http://localhost:8771`. Note that the `--detach` flag
runs the container in the background. If you want to expose Neo4j's bolt port, also
add `-p 7687:7687`. Note that
This exposes a REST API at `http://localhost:8771`. This also exposes Neo4j's bolt port at port 7687.

## MIRA Metaregistry

Expand Down
72 changes: 70 additions & 2 deletions mira/dkg/api.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,16 @@
"""API endpoints."""

import itertools as itt
from typing import Any, List, Mapping, Optional, Union

import pydantic
from fastapi import APIRouter, Body, Path, Query, Request, HTTPException
from fastapi import APIRouter, Body, HTTPException, Path, Query, Request
from neo4j.graph import Relationship
from pydantic import BaseModel, Field
from scipy.spatial import distance
from typing_extensions import Literal

from mira.dkg.client import Entity, AskemEntity
from mira.dkg.client import AskemEntity, Entity
from mira.dkg.utils import DKG_REFINER_RELS

__all__ = [
Expand Down Expand Up @@ -444,3 +446,69 @@ def common_parent(
entity = request.app.state.client.get_common_parents(query.curie1,
query.curie2)
return entity


class Distance(BaseModel):
    """Represents the distance between two entities."""

    # Response item of the /entity_similarity endpoint: one record per
    # compared (source, target) pair.

    # CURIE of the entity taken from the "sources" list of the query
    source: str = Field(..., title="source CURIE")
    # CURIE of the entity taken from the "targets" list of the query
    target: str = Field(..., title="target CURIE")
    # Cosine distance between the two entities' vectors, as computed by
    # scipy.spatial.distance.cosine
    distance: float = Field(..., title="cosine distance")


@api_blueprint.post(
    "/entity_similarity", response_model=List[Distance], tags=["entities"]
)
def entity_similarity(
    request: Request,
    sources: List[str] = Body(
        ...,
        title="source CURIEs",
        examples=[["ido:0000511", "ido:0000592", "ido:0000597", "ido:0000514"]],
    ),
    targets: Optional[List[str]] = Body(
        default=None,
        title="target CURIEs",
        description="If not given, source queries used for all-by-all comparison",
        examples=[["ido:0000566", "ido:0000567"]],
    ),
):
    """Get the pairwise similarities between elements referenced by CURIEs in the first list and second list."""
    # Developer notes are kept as comments rather than in the docstring,
    # because FastAPI publishes the docstring as the endpoint description.
    #
    # Returns one ``Distance`` per (source, target) pair, in
    # ``itertools.product(sources, targets)`` order. Pairs where both CURIEs
    # are the same string, and pairs where either CURIE has no vector, are
    # omitted. Responds with HTTP 500 if the app has no vectors loaded.
    #
    # Test locally with:
    #
    #     import requests
    #
    #     def main():
    #         curies = ["probonto:k0000000", "probonto:k0000007", "probonto:k0000008"]
    #         res = requests.post(
    #             "http://0.0.0.0:8771/api/entity_similarity",
    #             json={"sources": curies, "targets": curies},
    #         )
    #         res.raise_for_status()
    #         print(res.json())
    #
    #     if __name__ == "__main__":
    #         main()
    vectors = request.app.state.vectors
    if not vectors:
        raise HTTPException(
            status_code=500, detail="No entity vectors available"
        )
    if targets is None:
        targets = sources

    def _with_vectors(curies):
        """Pair each CURIE with its vector, dropping CURIEs that have none."""
        resolved = []
        for curie in curies:
            vector = vectors.get(curie)
            if vector is not None:
                resolved.append((curie, vector))
        return resolved

    # Look every CURIE up once instead of once per pair. Lists (not dicts)
    # are used so that duplicates and input order are preserved.
    source_items = _with_vectors(sources)
    target_items = (
        source_items if targets is sources else _with_vectors(targets)
    )

    return [
        Distance(
            source=source,
            target=target,
            distance=distance.cosine(source_vector, target_vector),
        )
        for (source, source_vector), (target, target_vector) in itt.product(
            source_items, target_items
        )
        # An entity is never compared against itself
        if source != target
    ]
10 changes: 7 additions & 3 deletions mira/dkg/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -728,14 +728,18 @@ def askepetrinet_model_comparison(
return resp


flux_span_path = docker_test_file_path if docker_test_file_path.exists() else \
test_file_path
if docker_test_file_path.exists():
flux_span_query_example = json.loads(docker_test_file_path.read_text())
elif test_file_path.exists():
flux_span_query_example = json.loads(test_file_path.read_text())
else:
flux_span_query_example = None


class FluxSpanQuery(BaseModel):
model: Dict[str, Any] = Field(
...,
example=json.load(flux_span_path.open()),
example=flux_span_query_example,
description="The model to recover the ODE-semantics from.",
)

Expand Down
13 changes: 10 additions & 3 deletions mira/dkg/utils.py
Original file line number Diff line number Diff line change
@@ -1,28 +1,32 @@
"""Utilities and constants for the MIRA app."""

from dataclasses import dataclass
from typing import List
from pathlib import Path
from typing import Dict, List

import numpy as np
from gilda.grounder import Grounder

from mira.dkg.client import Neo4jClient, Entity
from mira.dkg.client import Entity, Neo4jClient
from mira.metamodel import RefinementClosure

__all__ = [
"MiraState",
"PREFIXES",
"DKG_REFINER_RELS",
"DOCKER_FILES_ROOT",
]


@dataclass
class MiraState:
    """Represents the state associated with the MIRA app."""

    # Client for the Neo4j database backing the app
    client: Neo4jClient
    # Gilda grounder used by the app
    grounder: Grounder
    # Refinement closure used by the app
    refinement_closure: RefinementClosure
    # List of entities making up the lexical dump
    lexical_dump: List[Entity]
    # Entity vectors keyed by CURIE string. Annotated with the array type
    # ``np.ndarray``; ``np.array`` is the factory function, not a type.
    vectors: Dict[str, np.ndarray]


#: A list of all prefixes used in MIRA
Expand Down Expand Up @@ -69,3 +73,6 @@ class MiraState:

#: A list of all relation types that are considered refinement relations
DKG_REFINER_RELS = ["subclassof", "part_of"]

#: The root path of the MIRA app when running in a container
DOCKER_FILES_ROOT = Path("/sw")
32 changes: 27 additions & 5 deletions mira/dkg/wsgi.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,14 @@
"""Neo4j client module."""

import csv
import gzip
import logging
import os
from pathlib import Path
from textwrap import dedent

import flask
import numpy as np
from fastapi import FastAPI
from fastapi.middleware.wsgi import WSGIMiddleware
from flask_bootstrap import Bootstrap5
Expand All @@ -13,17 +17,19 @@
from mira.dkg.client import Neo4jClient
from mira.dkg.grounding import grounding_blueprint
from mira.dkg.ui import ui_blueprint
from mira.dkg.utils import PREFIXES, MiraState
from mira.dkg.utils import PREFIXES, MiraState, DOCKER_FILES_ROOT
from mira.metamodel import RefinementClosure

logger = logging.getLogger(__name__)


__all__ = [
"flask_app",
"app",
]

logger = logging.getLogger(__name__)

EMBEDDINGS_PATH_DOCKER = Path(
os.getenv("EMBEDDINGS_PATH", DOCKER_FILES_ROOT / "embeddings.tsv.gz")
)
DOMAIN = os.getenv("MIRA_DOMAIN")

tags_metadata = [
Expand All @@ -46,7 +52,7 @@
{
"name": "relations",
"description": "Query relation data",
}
},
]


Expand Down Expand Up @@ -87,6 +93,21 @@ def startup_event():
logger.info("Running app startup function")
Bootstrap5(flask_app)

if not EMBEDDINGS_PATH_DOCKER.is_file():
logger.warning(
f"Embeddings file {EMBEDDINGS_PATH_DOCKER} not found, skipping "
f"loading of embeddings"
)
vectors = {}
else:
with gzip.open(EMBEDDINGS_PATH_DOCKER, "rt") as file:
reader = csv.reader(file, delimiter="\t")
next(reader) # skip header
vectors = {
curie: np.array([float(p) for p in parts])
for curie, *parts in reader
}

# Set MIRA_NEO4J_URL in the environment
# to point this somewhere specific
client = Neo4jClient()
Expand All @@ -95,6 +116,7 @@ def startup_event():
grounder=client.get_grounder(PREFIXES),
refinement_closure=RefinementClosure(client.get_transitive_closure()),
lexical_dump=client.get_lexical(),
vectors=vectors,
)

flask_app.register_blueprint(ui_blueprint)
Expand Down
2 changes: 2 additions & 0 deletions setup.cfg
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,8 @@ web =
python-libsbml
lxml
bioregistry
scipy
numpy
uvicorn =
uvicorn
gunicorn =
Expand Down
Loading