-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathapp.py
111 lines (101 loc) · 3.07 KB
/
app.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
import functools
import os

import gradio as gr
import pymongo
import torch
import torch.nn.functional as F
from dotenv import load_dotenv
from huggingface_hub import hf_hub_download
from mega import Mega

from src.core.helper import load_model, read_image
from src.core.load_config import load_config
# Load variables from a local .env file into the process environment before
# any credentials are read below.
load_dotenv()

# Log in to MEGA at import time using credentials taken from the environment.
# NOTE(review): the resulting `mega` session is not referenced anywhere else in
# this file (result images are fetched with hf_hub_download); consider removing
# this login once nothing external depends on `app.mega`.
mega = Mega().login(os.getenv("mega_email"), os.getenv("mega_password"))
def main(img_path):
if img_path:
client = pymongo.MongoClient(os.getenv("MONGOLAB_URI"))
db = client.simsearch
collection = db['feature512']
img = read_image(img_path, target_size=cfg.DATA.IMG_SIZE)
with torch.no_grad():
feat = F.normalize(model(img),1)
vector_query = feat[0].detach().cpu().numpy().tolist()
pipeline = [
{
"$search": {
"index": "feature512",
"knnBeta": {
"vector": vector_query,
"path": "embedding",
"k": 10
}
}
},
{
"$project": {
"embedding": 0,
"_id": 0,
'score': {
'$meta': 'searchScore'
}
}
},
]
res = collection.aggregate(pipeline)
list_retrieval = []
for r in res:
list_retrieval.append((r['image_name'], r['score']))
## download from mega
list_fnames = []
for item in list_retrieval:
cat, fname = item[0].replace('\\','/').split('/')
img_url = hf_hub_download(
token=os.getenv("HUGGINGFACE"),
repo_id="taindp98/fashion-recsys",
filename=fname,
subfolder=cat,
repo_type="dataset",
)
print(f"img_url: {img_url}")
list_fnames.append(img_url)
return list_fnames
if __name__ == '__main__':
    # Fetch the trained checkpoint from the Hugging Face Hub. The call returns
    # the path of the locally cached file.
    checkpoint_path = hf_hub_download(
        repo_id="taindp98/siamese-model",
        filename="checkpoint.pth",
        repo_type="dataset",
        token=os.getenv("HUGGINGFACE"),
    )

    # Unlock the deploy config, point SAVE_DIR at the downloaded checkpoint,
    # and re-lock it before building the model.
    # NOTE(review): presumably load_model reads the weights from cfg.SAVE_DIR;
    # confirm in src.core.helper.
    cfg = load_config("configs/deploy.yaml")
    cfg.defrost()
    cfg.SAVE_DIR = checkpoint_path
    cfg.freeze()

    # `cfg` and `model` must stay module globals: main() reads both.
    model = load_model(cfg)

    # NOTE(review): redundant; load_dotenv() already ran at module import.
    load_dotenv()

    query_input = gr.Image(type='filepath', label='Input Image')
    # NOTE(review): Component.style() is deprecated in later Gradio 3.x and
    # removed in 4.x; switch to constructor kwargs when upgrading.
    results_gallery = gr.Gallery(label='Search Results').style(
        columns=[5],
        object_fit='contain',
        height='auto',
    )

    demo = gr.Interface(
        fn=main,
        inputs=[query_input],
        outputs=[results_gallery],
        title='Demo Fashion Image Search',
        examples=[
            "./resources/blazer.jpg",
            "./resources/dress.jpg",
            "./resources/trouser.jpg",
        ],
    )
    demo.launch()