Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add underlying models #2

Merged
merged 6 commits into from
Dec 8, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -113,7 +113,7 @@ _That's it_! You have ObjectCut running on port 80 routing traffic using _traefi

### Change underlying model

This project was built using [BASNet](https://github.com/NathanUA/BASNet) as the model for inferring the Salient Object Detection. However, in order to test other ones we added the support to select also [U^2-Net](https://github.com/NathanUA/U-2-Net), also implemented by [Xuebin Qin](https://github.com/NathanUA), in the Inference container specifying it as a environment variable called `MODEL`. You can do that setting your model name at [docker-compose.yml](docker-compose.yml):
This project was built using [BASNet](https://github.com/NathanUA/BASNet) as the model for inferring the Salient Object Detection. However, in order to test other ones we added the support to select also the different versions of [U^2-Net](https://github.com/NathanUA/U-2-Net) (`U2NET`, `U2NETP` and `U2NETPORTRAIT`), also implemented by [Xuebin Qin](https://github.com/NathanUA), in the Inference container specifying it as a environment variable called `MODEL`. You can do that setting your model name at [docker-compose.yml](docker-compose.yml):

```yaml
inference:
Expand All @@ -130,7 +130,7 @@ inference:
- object_cut
restart: always
environment:
- MODEL=BASNet # Can also be `U2NET`
- MODEL=BASNet # Can also be `U2NET`, `U2NETP` or `U2NETPORTRAIT`
```

### Integrations
Expand Down
2 changes: 1 addition & 1 deletion docker-compose.yml
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ services:
- object_cut
restart: always
environment:
- MODEL=BASNet # Can also be `U2NET`
- MODEL=U2NETP # Can also be `BASNet`, `U2NET` or `U2NETPORTRAIT`
labels:
- 'traefik.enable=true'
- 'traefik.docker.network=object_cut'
Expand Down
1 change: 1 addition & 0 deletions inference/Dockerfile
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@ ADD ./requirements.lock ${HOME}/requirements.lock
RUN ${HOME}/gdrive_download.sh 1s52ek_4YTDRt_EOkx1FS53u-vJa0c4nu ${HOME}/data/basnet.pth
RUN ${HOME}/gdrive_download.sh 1ao1ovG1Qtx4b7EoskHXmi2E9rp5CHLcZ ${HOME}/data/u2net.pth
RUN ${HOME}/gdrive_download.sh 1rbSTGKAE-MTxBYHd-51l2hMOQPT_7EPy ${HOME}/data/u2netp.pth
RUN ${HOME}/gdrive_download.sh 1IG3HdpcRiDoWNookbncQjeaPN28t90yW ${HOME}/data/u2netportrait.pth

# Install dependencies
RUN python3 -m pip install pip --upgrade
Expand Down
1 change: 1 addition & 0 deletions inference/src/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@

# Load model
model_name = os.environ.get('MODEL', Model.BASNet.name) # BASNet as default
log.info('Model name: [{}]'.format(model_name))
assert model_name in Model.list()
model_path = os.path.join('data', '{}.pth'.format(model_name.lower()))
log.info('Model path: [{}]'.format(model_path))
Expand Down
6 changes: 4 additions & 2 deletions inference/src/utils/model_enum.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,19 @@
from enum import Enum

from src.u2_net.model import U2NET
from src.u2_net.model import U2NET, U2NETP
from src.bas_net.model import BASNet


class Model(Enum):

U2NET = U2NET # U2NET
U2NETP = U2NETP # U2NETP
U2NETPORTRAIT = U2NET # U2NETPORTRAIT
BASNet = BASNet # BASNet

def __str__(self):
return self.name

@staticmethod
def list():
return [m.name for m in Model]
return [m for m in Model.__members__.keys()]
2 changes: 1 addition & 1 deletion multiplexer/test/api/test_remove.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ class MultiplexerRemoveTest(BaseTestClass):

def setUp(self):
self.secret_access = env.get_secret_access()
self.img_url = 'https://objectcut.com/docs/images/object-cut.png'
self.img_url = 'https://objectcut.com/assets/img/raven.jpg'
self.img_url_wrong = 'https://example.com/not-existing.jpg'
self.img_base64_wrong = 'not-a-base64'
self.img_path = os.path.join('test', 'data', 'person.jpg')
Expand Down