From 71f98b8797888f2f206e11b3284840453fd0d9fc Mon Sep 17 00:00:00 2001 From: Eduardo Date: Wed, 15 Jan 2025 23:19:28 -0300 Subject: [PATCH] feat: only whitelisted ips can access to register api. --- neurons/register_api.py | 27 ++++++++++++++++++++++++++- 1 file changed, 26 insertions(+), 1 deletion(-) diff --git a/neurons/register_api.py b/neurons/register_api.py index 0daff4e7..709533dc 100644 --- a/neurons/register_api.py +++ b/neurons/register_api.py @@ -56,7 +56,7 @@ # Import FastAPI Libraries import uvicorn from fastapi import ( - FastAPI, + FastAPI, HTTPException, status, Request, WebSocket, @@ -67,10 +67,17 @@ from fastapi.exceptions import RequestValidationError from fastapi.concurrency import run_in_threadpool from pydantic import BaseModel, Field +from starlette.middleware.base import BaseHTTPMiddleware +from starlette.status import HTTP_403_FORBIDDEN +from dotenv import load_dotenv from typing import Optional, Union, List from compute import (TRUSTED_VALIDATORS_HOTKEYS) +# Loads the .env file +load_dotenv() + # Constants +ENABLE_WHITELIST_IPS = False # False for disabling, True for enabling DEFAULT_SSL_MODE = 2 # 1 for client CERT optional, 2 for client CERT_REQUIRED DEFAULT_API_PORT = 8903 # default port for the API DATA_SYNC_PERIOD = 600 # metagraph resync time @@ -82,6 +89,22 @@ PUBLIC_WANDB_NAME = "opencompute" PUBLIC_WANDB_ENTITY = "neuralinternet" +# IP Whitelist middleware +class IPWhitelistMiddleware(BaseHTTPMiddleware): + def __init__(self, app: FastAPI): + super().__init__(app) + self.whitelisted_ips = set(os.getenv("WHITELISTED_IPS", "").split(",")) + + async def dispatch(self, request: Request, call_next): + # Extracts the client's IP address + client_ip = request.client.host + if client_ip not in self.whitelisted_ips: + bt.logging.info(f"Access attempt from IP: {client_ip}") + raise HTTPException(status_code=HTTP_403_FORBIDDEN, detail="Access forbidden: IP not whitelisted") + + # Process the request and get the response + response = await call_next(request) + return response class UserConfig(BaseModel): netuid: str = Field(default="15") @@ -266,6 +289,8 @@ def __init__( load_dotenv() self._setup_routes() + if ENABLE_WHITELIST_IPS: + self.app.add_middleware(IPWhitelistMiddleware) self.process = None self.websocket_connection = None self.allocation_table = []