Skip to content

Commit

Permalink
Merge pull request #9 from EbookFoundation/dev-david
Browse files Browse the repository at this point in the history
Added blipLocal DescEngine
  • Loading branch information
xxmistacruzxx authored Feb 2, 2024
2 parents 14493d7 + b0a4d94 commit bda65d9
Show file tree
Hide file tree
Showing 2 changed files with 48 additions and 12 deletions.
29 changes: 28 additions & 1 deletion src/alttext/descengine.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,9 @@
from abc import ABC, abstractmethod
import base64
import os
import shutil
import subprocess
import uuid

import replicate
import vertexai
Expand All @@ -24,6 +27,7 @@ def genDesc(self, imgData: bytes, src: str, context: str = None) -> str:
pass


### IMPLEMENTATIONS
REPLICATE_MODELS = {
"blip": "salesforce/blip:2e1dddc8621f72155f24cf2e0adbde548458d3cab9f00c0139eea840d0ac4746",
"clip_prefix_caption": "rmokady/clip_prefix_caption:9a34a6339872a03f45236f114321fb51fc7aa8269d38ae0ce5334969981e4cd8",
Expand All @@ -34,7 +38,6 @@ def genDesc(self, imgData: bytes, src: str, context: str = None) -> str:
}


### IMPLEMENTATIONS
class ReplicateAPI(DescEngine):
def __init__(self, key: str, model: str = "blip") -> None:
self.__setKey(key)
Expand Down Expand Up @@ -73,6 +76,30 @@ def genDesc(self, imgData: bytes, src: str, context: str = None) -> str:
return output


class BlipLocal(DescEngine):
def __init__(self, path: str) -> None:
self.__setPath(path)
return None

def __setPath(self, path: str) -> str:
self.path = path
return self.path

def genDesc(self, imgData: bytes, src: str, context: str = None) -> str:
folderName = uuid.uuid4()
ext = src.split(".")[-1]
os.makedirs(f"{self.path}/{folderName}")
open(f"{self.path}/{folderName}/image.{ext}", "wb+").write(imgData)
subprocess.call(
f"python {self.path}/inference.py -i ./{folderName} --batch 1 --gpu 0",
cwd=f"{self.path}",
)
desc = open(f"{self.path}/{folderName}/0_captions.txt", "r").read()
shutil.rmtree(f"{self.path}/{folderName}")
desc = desc.split(",")
return desc[1]


class GoogleVertexAPI(DescEngine):
def __init__(self, project_id: str, location: str, gac_path: str) -> None:
self.project_id = project_id
Expand Down
31 changes: 20 additions & 11 deletions tests/test.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,24 +24,33 @@
def testHTML():
print("TESTING HTML")

# alt: alttext.AltTextHTML = alttext.AltTextHTML(
# # descengine.ReplicateAPI(keys.ReplicateEricKey(), "blip"),
# # ocrengine.Tesseract(),
# # langengine.PrivateGPT(HOST1),
# )

# alt: alttext.AltTextHTML = alttext.AltTextHTML(
# descengine.BlipLocal("C:/Users/dacru/Desktop/Codebase/ALT/image-captioning"),
# options={"version": 1},
# )

alt: alttext.AltTextHTML = alttext.AltTextHTML(
# descengine.GoogleVertexAPI(
# keys.VertexProject(), keys.VertexRegion(), keys.VertexGAC()
# ),
descengine.ReplicateAPI(keys.ReplicateEricKey(), "blip"),
# ocrengine.Tesseract(),
# langengine.PrivateGPT(HOST1),
options={"version": 1},
descengine.BlipLocal("C:/Users/dacru/Desktop/Codebase/ALT/image-captioning"),
ocrengine.Tesseract(),
langengine.PrivateGPT(HOST1),
)

alt.parseFile(HTML_HUNTING)
imgs = alt.getAllImgs()
# src = imgs[5].attrs["src"]
# print(src)
src = imgs[4].attrs["src"]
print(src)
print(alt.genAltText(src))

# desc = alt.genDesc(alt.getImgData(src), src)
# print(desc)
associations = alt.genAltAssociations(imgs)
print(associations)
# associations = alt.genAltAssociations(imgs)
# print(associations)


if __name__ == "__main__":
Expand Down

0 comments on commit bda65d9

Please sign in to comment.