Skip to content

Commit

Permalink
Merge pull request #248 from EduardoNicoleit/feature/CSN-258-restrict…
Browse files Browse the repository at this point in the history
…-ips

feat: only whitelisted ips can access to register api.
  • Loading branch information
mo0haned authored Jan 21, 2025
2 parents 6272056 + 71f98b8 commit 407c508
Showing 1 changed file with 26 additions and 1 deletion.
27 changes: 26 additions & 1 deletion neurons/register_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@
# Import FastAPI Libraries
import uvicorn
from fastapi import (
FastAPI,
FastAPI, HTTPException,
status,
Request,
WebSocket,
Expand All @@ -67,10 +67,17 @@
from fastapi.exceptions import RequestValidationError
from fastapi.concurrency import run_in_threadpool
from pydantic import BaseModel, Field
from starlette.middleware.base import BaseHTTPMiddleware
from starlette.status import HTTP_403_FORBIDDEN
from dotenv import load_dotenv
from typing import Optional, Union, List
from compute import (TRUSTED_VALIDATORS_HOTKEYS)

# Loads the .env file
load_dotenv()

# Constants
ENABLE_WHITELIST_IPS = False # False for disabling, True for enabling
DEFAULT_SSL_MODE = 2 # 1 for client CERT optional, 2 for client CERT_REQUIRED
DEFAULT_API_PORT = 8903 # default port for the API
DATA_SYNC_PERIOD = 600 # metagraph resync time
Expand All @@ -82,6 +89,22 @@
PUBLIC_WANDB_NAME = "opencompute"
PUBLIC_WANDB_ENTITY = "neuralinternet"

# IP Whitelist middleware
class IPWhitelistMiddleware(BaseHTTPMiddleware):
def __init__(self, app: FastAPI):
super().__init__(app)
self.whitelisted_ips = set(os.getenv("WHITELISTED_IPS", "").split(","))

async def dispatch(self, request: Request, call_next):
# Extracts the client's IP address
client_ip = request.client.host
if client_ip not in self.whitelisted_ips:
bt.logging.info(f"Access attempt from IP: {client_ip}")
raise HTTPException(status_code=HTTP_403_FORBIDDEN, detail="Access forbidden: IP not whitelisted")

# Process the request and get the response
response = await call_next(request)
return response

class UserConfig(BaseModel):
netuid: str = Field(default="15")
Expand Down Expand Up @@ -266,6 +289,8 @@ def __init__(

load_dotenv()
self._setup_routes()
if ENABLE_WHITELIST_IPS:
self.app.add_middleware(IPWhitelistMiddleware)
self.process = None
self.websocket_connection = None
self.allocation_table = []
Expand Down

0 comments on commit 407c508

Please sign in to comment.