Skip to content

Commit

Permalink
Merge pull request #2 from AlbertSuarez/add-underlying-models
Browse files Browse the repository at this point in the history
Add underlying models
  • Loading branch information
AlbertSuarez authored Dec 8, 2020
2 parents 113acf1 + 5fe9ac6 commit 346052b
Show file tree
Hide file tree
Showing 6 changed files with 10 additions and 6 deletions.
4 changes: 2 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -113,7 +113,7 @@ _That's it_! You have ObjectCut running on port 80 routing traffic using _traefi
### Change underlying model
This project was built using [BASNet](https://github.com/NathanUA/BASNet) as the model for inferring the Salient Object Detection. However, in order to test other ones we added the support to select also [U^2-Net](https://github.com/NathanUA/U-2-Net), also implemented by [Xuebin Qin](https://github.com/NathanUA), in the Inference container specifying it as a environment variable called `MODEL`. You can do that setting your model name at [docker-compose.yml](docker-compose.yml):
This project was built using [BASNet](https://github.com/NathanUA/BASNet) as the model for inferring the Salient Object Detection. However, in order to test other ones we added the support to select also the different versions of [U^2-Net](https://github.com/NathanUA/U-2-Net) (`U2NET`, `U2NETP` and `U2NETPORTRAIT`), also implemented by [Xuebin Qin](https://github.com/NathanUA), in the Inference container specifying it as a environment variable called `MODEL`. You can do that setting your model name at [docker-compose.yml](docker-compose.yml):
```yaml
inference:
Expand All @@ -130,7 +130,7 @@ inference:
- object_cut
restart: always
environment:
- MODEL=BASNet # Can also be `U2NET`
- MODEL=BASNet # Can also be `U2NET`, `U2NETP` or `U2NETPORTRAIT`
```

### Integrations
Expand Down
2 changes: 1 addition & 1 deletion docker-compose.yml
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ services:
- object_cut
restart: always
environment:
- MODEL=BASNet # Can also be `U2NET`
- MODEL=U2NETP # Can also be `BASNet`, `U2NET` or `U2NETPORTRAIT`
labels:
- 'traefik.enable=true'
- 'traefik.docker.network=object_cut'
Expand Down
1 change: 1 addition & 0 deletions inference/Dockerfile
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@ ADD ./requirements.lock ${HOME}/requirements.lock
RUN ${HOME}/gdrive_download.sh 1s52ek_4YTDRt_EOkx1FS53u-vJa0c4nu ${HOME}/data/basnet.pth
RUN ${HOME}/gdrive_download.sh 1ao1ovG1Qtx4b7EoskHXmi2E9rp5CHLcZ ${HOME}/data/u2net.pth
RUN ${HOME}/gdrive_download.sh 1rbSTGKAE-MTxBYHd-51l2hMOQPT_7EPy ${HOME}/data/u2netp.pth
RUN ${HOME}/gdrive_download.sh 1IG3HdpcRiDoWNookbncQjeaPN28t90yW ${HOME}/data/u2netportrait.pth

# Install dependencies
RUN python3 -m pip install pip --upgrade
Expand Down
1 change: 1 addition & 0 deletions inference/src/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@

# Load model
model_name = os.environ.get('MODEL', Model.BASNet.name) # BASNet as default
log.info('Model name: [{}]'.format(model_name))
assert model_name in Model.list()
model_path = os.path.join('data', '{}.pth'.format(model_name.lower()))
log.info('Model path: [{}]'.format(model_path))
Expand Down
6 changes: 4 additions & 2 deletions inference/src/utils/model_enum.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,19 @@
from enum import Enum

from src.u2_net.model import U2NET
from src.u2_net.model import U2NET, U2NETP
from src.bas_net.model import BASNet


class Model(Enum):

U2NET = U2NET # U2NET
U2NETP = U2NETP # U2NETP
U2NETPORTRAIT = U2NET # U2NETPORTRAIT
BASNet = BASNet # BASNet

def __str__(self):
return self.name

@staticmethod
def list():
return [m.name for m in Model]
return [m for m in Model.__members__.keys()]
2 changes: 1 addition & 1 deletion multiplexer/test/api/test_remove.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ class MultiplexerRemoveTest(BaseTestClass):

def setUp(self):
self.secret_access = env.get_secret_access()
self.img_url = 'https://objectcut.com/docs/images/object-cut.png'
self.img_url = 'https://objectcut.com/assets/img/raven.jpg'
self.img_url_wrong = 'https://example.com/not-existing.jpg'
self.img_base64_wrong = 'not-a-base64'
self.img_path = os.path.join('test', 'data', 'person.jpg')
Expand Down

0 comments on commit 346052b

Please sign in to comment.