-
Notifications
You must be signed in to change notification settings - Fork 6
/
Copy pathextract.py
43 lines (27 loc) · 1.12 KB
/
extract.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
import torch
from model.load_model import load_model
from .model.descriptor_utils import DescGroupPoolandNorm
class ReFTDescriptor:
    """Callable that turns (image, keypoints) into pooled, normalized descriptors.

    Two stages are chained: the network built by ``load_model`` produces raw
    per-keypoint descriptors, and ``DescGroupPoolandNorm`` post-processes them
    via its ``desc_pool_and_norm_infer`` method.
    """

    def __init__(self, args):
        """Construct both stages from the same ``args`` configuration object.

        The network is built first, then the pool/normalize stage.
        """
        self.model = load_model(args)
        self.pool_and_norm = DescGroupPoolandNorm(args)

    def __call__(self, image, kpts):
        """Describe ``kpts`` located in ``image``.

        The original author's note gives the layouts as: ``kpts`` is a
        ``[B, K, 2]`` tensor and the raw network output is ``[B, K, CG]``.

        Returns:
            A 2-tuple ``(keypoints, descriptors)`` as produced by
            ``desc_pool_and_norm_infer``.

        Raises:
            ValueError: if the pooling stage does not yield exactly two items
                (from the tuple unpacking below).
        """
        raw_descriptors = self.model(image, kpts)
        pooled = self.pool_and_norm.desc_pool_and_norm_infer(kpts, raw_descriptors)
        # Unpack then repack so the caller always receives a plain 2-tuple.
        out_kpts, out_desc = pooled
        return out_kpts, out_desc
if __name__ == "__main__":
    # Smoke test: detect SuperPoint keypoints on one image, describe them with
    # ReFTDescriptor, and print the resulting shapes.
    import os
    import cv2
    from torchvision import transforms
    from config import get_config
    from baselines.extract_GIFT.utils.superpoint_utils import SuperPointWrapper

    args = get_config()
    extractor = ReFTDescriptor(args)
    det = SuperPointWrapper()

    # The demo image can be overridden with REFT_DEMO_IMAGE; the default is the
    # original author's local path.
    image_path = os.environ.get("REFT_DEMO_IMAGE", "/home/jongmin/Desktop/temp.jpg")
    image_np = cv2.imread(image_path)
    # cv2.imread signals failure by returning None instead of raising; fail
    # early with a clear message rather than deep inside ToTensor.
    if image_np is None:
        raise FileNotFoundError(f"could not read image: {image_path}")
    # NOTE(review): OpenCV loads channels in BGR order; confirm whether the
    # ReFT model expects BGR or RGB input.
    image = transforms.ToTensor()(image_np)
    print(image.shape)

    kpts, desc = det(image_np)
    # Inference only: no autograd graph is needed just to inspect shapes.
    with torch.no_grad():
        k1, d1 = extractor(
            image.unsqueeze(0).float().cuda(),
            torch.from_numpy(kpts).unsqueeze(0).float().cuda(),
        )
    print(kpts.shape, desc.shape)
    print(k1.shape, d1.shape)