Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add device argument to PyTorch Hub models #3104

Merged
merged 3 commits into from
May 16, 2021
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
41 changes: 21 additions & 20 deletions hubconf.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
import torch


def _create(name, pretrained=True, channels=3, classes=80, autoshape=True, verbose=True):
def _create(name, pretrained=True, channels=3, classes=80, autoshape=True, verbose=True, device=None):
"""Creates a specified YOLOv5 model

Arguments:
Expand All @@ -18,6 +18,7 @@ def _create(name, pretrained=True, channels=3, classes=80, autoshape=True, verbo
classes (int): number of model classes
autoshape (bool): apply YOLOv5 .autoshape() wrapper to model
verbose (bool): print all information to screen
device (str, torch.device, None): device to use for model parameters

Returns:
YOLOv5 pytorch model
Expand Down Expand Up @@ -50,7 +51,7 @@ def _create(name, pretrained=True, channels=3, classes=80, autoshape=True, verbo
model.names = ckpt['model'].names # set class names attribute
if autoshape:
model = model.autoshape() # for file/URI/PIL/cv2/np inputs and NMS
device = select_device('0' if torch.cuda.is_available() else 'cpu') # default to GPU if available
device = select_device('0' if torch.cuda.is_available() else 'cpu') if device is None else torch.device(device)
return model.to(device)

except Exception as e:
Expand All @@ -59,49 +60,49 @@ def _create(name, pretrained=True, channels=3, classes=80, autoshape=True, verbo
raise Exception(s) from e


def custom(path='path/to/model.pt', autoshape=True, verbose=True):
def custom(path='path/to/model.pt', autoshape=True, verbose=True, device=None):
# YOLOv5 custom or local model
return _create(path, autoshape=autoshape, verbose=verbose)
return _create(path, autoshape=autoshape, verbose=verbose, device=device)


def yolov5s(pretrained=True, channels=3, classes=80, autoshape=True, verbose=True):
def yolov5s(pretrained=True, channels=3, classes=80, autoshape=True, verbose=True, device=None):
# YOLOv5-small model https://github.com/ultralytics/yolov5
return _create('yolov5s', pretrained, channels, classes, autoshape, verbose)
return _create('yolov5s', pretrained, channels, classes, autoshape, verbose, device)


def yolov5m(pretrained=True, channels=3, classes=80, autoshape=True, verbose=True):
def yolov5m(pretrained=True, channels=3, classes=80, autoshape=True, verbose=True, device=None):
# YOLOv5-medium model https://github.com/ultralytics/yolov5
return _create('yolov5m', pretrained, channels, classes, autoshape, verbose)
return _create('yolov5m', pretrained, channels, classes, autoshape, verbose, device)


def yolov5l(pretrained=True, channels=3, classes=80, autoshape=True, verbose=True):
def yolov5l(pretrained=True, channels=3, classes=80, autoshape=True, verbose=True, device=None):
# YOLOv5-large model https://github.com/ultralytics/yolov5
return _create('yolov5l', pretrained, channels, classes, autoshape, verbose)
return _create('yolov5l', pretrained, channels, classes, autoshape, verbose, device)


def yolov5x(pretrained=True, channels=3, classes=80, autoshape=True, verbose=True):
def yolov5x(pretrained=True, channels=3, classes=80, autoshape=True, verbose=True, device=None):
# YOLOv5-xlarge model https://github.com/ultralytics/yolov5
return _create('yolov5x', pretrained, channels, classes, autoshape, verbose)
return _create('yolov5x', pretrained, channels, classes, autoshape, verbose, device)


def yolov5s6(pretrained=True, channels=3, classes=80, autoshape=True, verbose=True):
def yolov5s6(pretrained=True, channels=3, classes=80, autoshape=True, verbose=True, device=None):
# YOLOv5-small-P6 model https://github.com/ultralytics/yolov5
return _create('yolov5s6', pretrained, channels, classes, autoshape, verbose)
return _create('yolov5s6', pretrained, channels, classes, autoshape, verbose, device)


def yolov5m6(pretrained=True, channels=3, classes=80, autoshape=True, verbose=True):
def yolov5m6(pretrained=True, channels=3, classes=80, autoshape=True, verbose=True, device=None):
# YOLOv5-medium-P6 model https://github.com/ultralytics/yolov5
return _create('yolov5m6', pretrained, channels, classes, autoshape, verbose)
return _create('yolov5m6', pretrained, channels, classes, autoshape, verbose, device)


def yolov5l6(pretrained=True, channels=3, classes=80, autoshape=True, verbose=True):
def yolov5l6(pretrained=True, channels=3, classes=80, autoshape=True, verbose=True, device=None):
# YOLOv5-large-P6 model https://github.com/ultralytics/yolov5
return _create('yolov5l6', pretrained, channels, classes, autoshape, verbose)
return _create('yolov5l6', pretrained, channels, classes, autoshape, verbose, device)


def yolov5x6(pretrained=True, channels=3, classes=80, autoshape=True, verbose=True):
def yolov5x6(pretrained=True, channels=3, classes=80, autoshape=True, verbose=True, device=None):
# YOLOv5-xlarge-P6 model https://github.com/ultralytics/yolov5
return _create('yolov5x6', pretrained, channels, classes, autoshape, verbose)
return _create('yolov5x6', pretrained, channels, classes, autoshape, verbose, device)


if __name__ == '__main__':
Expand Down