Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: only whitelisted ips can access to register api. #248

Merged
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
27 changes: 26 additions & 1 deletion neurons/register_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@
# Import FastAPI Libraries
import uvicorn
from fastapi import (
FastAPI,
FastAPI, HTTPException,
status,
Request,
WebSocket,
Expand All @@ -67,10 +67,17 @@
from fastapi.exceptions import RequestValidationError
from fastapi.concurrency import run_in_threadpool
from pydantic import BaseModel, Field
from starlette.middleware.base import BaseHTTPMiddleware
from starlette.status import HTTP_403_FORBIDDEN
from dotenv import load_dotenv
from typing import Optional, Union, List
from compute import (TRUSTED_VALIDATORS_HOTKEYS)

# Loads the .env file
load_dotenv()

# Constants
ENABLE_WHITELIST_IPS = False # False for disabling, True for enabling
DEFAULT_SSL_MODE = 2 # 1 for client CERT optional, 2 for client CERT_REQUIRED
DEFAULT_API_PORT = 8903 # default port for the API
DATA_SYNC_PERIOD = 600 # metagraph resync time
Expand All @@ -82,6 +89,22 @@
PUBLIC_WANDB_NAME = "opencompute"
PUBLIC_WANDB_ENTITY = "neuralinternet"

# IP Whitelist middleware
class IPWhitelistMiddleware(BaseHTTPMiddleware):
def __init__(self, app: FastAPI):
super().__init__(app)
self.whitelisted_ips = set(os.getenv("WHITELISTED_IPS", "").split(","))

async def dispatch(self, request: Request, call_next):
# Extracts the client's IP address
client_ip = request.client.host
if client_ip not in self.whitelisted_ips:
bt.logging.info(f"Access attempt from IP: {client_ip}")
raise HTTPException(status_code=HTTP_403_FORBIDDEN, detail="Access forbidden: IP not whitelisted")

# Process the request and get the response
response = await call_next(request)
return response

class UserConfig(BaseModel):
netuid: str = Field(default="15")
Expand Down Expand Up @@ -266,6 +289,8 @@ def __init__(

load_dotenv()
self._setup_routes()
if ENABLE_WHITELIST_IPS:
self.app.add_middleware(IPWhitelistMiddleware)
self.process = None
self.websocket_connection = None
self.allocation_table = []
Expand Down
Loading