Skip to content

Commit

Permalink
Merge pull request #1 from abesmon/hf-key
Browse files Browse the repository at this point in the history
custom hugging face models
  • Loading branch information
abesmon authored Dec 13, 2022
2 parents ea9caac + 43eb735 commit 924e09b
Show file tree
Hide file tree
Showing 2 changed files with 7 additions and 3 deletions.
4 changes: 3 additions & 1 deletion gradio_app.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,8 @@
parser.add_argument('--light_phi', type=float, default=0, help="default GUI light direction in [0, 360), azimuth")
parser.add_argument('--max_spp', type=int, default=1, help="GUI rendering max sample per pixel")

parser.add_argument('--hf_key', type=str, default=None, help="hugging face Stable diffusion model key")

opt = parser.parse_args()

# default to use -O !!!
Expand All @@ -91,7 +93,7 @@

if opt.guidance == 'stable-diffusion':
from nerf.sd import StableDiffusion
guidance = StableDiffusion(device)
guidance = StableDiffusion(device, opt.hf_key)
elif opt.guidance == 'clip':
from nerf.clip import CLIP
guidance = CLIP(device)
Expand Down
6 changes: 4 additions & 2 deletions nerf/sd.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,15 +17,17 @@ def seed_everything(seed):
#torch.backends.cudnn.benchmark = True

class StableDiffusion(nn.Module):
def __init__(self, device, sd_version='2.0'):
def __init__(self, device, sd_version='2.0', hf_key=None):
super().__init__()

self.device = device
self.sd_version = sd_version

print(f'[INFO] loading stable diffusion...')

if self.sd_version == '2.0':
if hf_key:
model_key = hf_key
elifif self.sd_version == '2.0':
model_key = "stabilityai/stable-diffusion-2-base"
elif self.sd_version == '1.5':
model_key = "runwayml/stable-diffusion-v1-5"
Expand Down

0 comments on commit 924e09b

Please sign in to comment.