konduktor/pyserve/extensions.py
Илья Глазунов 83cb7d68b0 initial commit
2025-09-01 23:49:50 +03:00

201 lines
7.6 KiB
Python

import time
from abc import ABC, abstractmethod
from typing import Dict, Any, List, Optional

from starlette.requests import Request
from starlette.responses import Response

from .logging_utils import get_logger
logger = get_logger(__name__)
class Extension(ABC):
    """Abstract base for extensions plugged into the request/response pipeline.

    Concrete extensions must implement ``process_request`` and
    ``process_response``. ``initialize`` and ``cleanup`` are optional hooks
    that do nothing unless overridden.
    """

    def __init__(self, config: Dict[str, Any]):
        """Keep the raw configuration mapping and mark the extension enabled."""
        self.enabled = True
        self.config = config

    def initialize(self) -> None:
        """Optional setup hook; the base implementation does nothing."""

    def cleanup(self) -> None:
        """Optional teardown hook; the base implementation does nothing."""

    @abstractmethod
    async def process_request(self, request: Request) -> Optional[Response]:
        """Inspect an incoming request.

        Implementations return a ``Response`` or ``None``.
        """

    @abstractmethod
    async def process_response(self, request: Request, response: Response) -> Response:
        """Post-process ``response`` for ``request`` and return a response."""
class RoutingExtension(Extension):
    """Extension that hands requests to a router-backed ``RequestHandler``."""

    def __init__(self, config: Dict[str, Any]):
        """Build the router from ``config["regex_locations"]`` and wrap it in a handler."""
        super().__init__(config)
        # Kept as a function-scope import, as in the original code
        # (presumably to avoid an import cycle with .routing -- TODO confirm).
        from .routing import RequestHandler, create_router_from_config

        locations = config.get("regex_locations", {})
        self.router = create_router_from_config(locations)
        self.handler = RequestHandler(self.router)

    async def process_request(self, request: Request) -> Optional[Response]:
        """Return the handler's result, or ``None`` if the handler raised.

        Any exception from the handler is logged and swallowed.
        """
        try:
            outcome = await self.handler.handle(request)
        except Exception as exc:
            logger.error(f"Ошибка в RoutingExtension: {exc}")
            outcome = None
        return outcome

    async def process_response(self, request: Request, response: Response) -> Response:
        """Pass the response through unchanged."""
        return response
class SecurityExtension(Extension):
    """IP allow/block filtering plus security headers on responses."""

    def __init__(self, config: Dict[str, Any]):
        """Read the IP lists and the response headers from ``config``."""
        super().__init__(config)
        fallback_headers = {
            "X-Content-Type-Options": "nosniff",
            "X-Frame-Options": "DENY",
            "X-XSS-Protection": "1; mode=block"
        }
        self.allowed_ips = config.get("allowed_ips", [])
        self.blocked_ips = config.get("blocked_ips", [])
        # The fallback set is used only when the key is absent from config.
        self.security_headers = config.get("security_headers", fallback_headers)

    async def process_request(self, request: Request) -> Optional[Response]:
        """Return a 403 response for rejected clients, otherwise ``None``.

        A client is rejected when its IP is in a non-empty block list, or
        when a non-empty allow list exists and does not contain the IP.
        The block list is checked first. A request without client info is
        treated as coming from the literal address "unknown".
        """
        peer = request.client
        client_ip = peer.host if peer else "unknown"

        if self.blocked_ips and client_ip in self.blocked_ips:
            reason = f"Заблокирован запрос от IP: {client_ip}"
        elif self.allowed_ips and client_ip not in self.allowed_ips:
            reason = f"Запрещен доступ для IP: {client_ip}"
        else:
            return None

        logger.warning(reason)
        # Imported only on the rejection path, as in the original code.
        from starlette.responses import PlainTextResponse
        return PlainTextResponse("403 Forbidden", status_code=403)

    async def process_response(self, request: Request, response: Response) -> Response:
        """Set every configured security header on ``response`` and return it."""
        for name, header_value in self.security_headers.items():
            response.headers[name] = header_value
        return response
class CachingExtension(Extension):
    """Placeholder caching extension: settings are read, but nothing is cached yet."""

    def __init__(self, config: Dict[str, Any]):
        """Read ``cache_patterns`` and ``cache_ttl`` from ``config`` and start with an empty cache."""
        super().__init__(config)
        # Default TTL is 3600 (presumably seconds -- TODO confirm units).
        self.cache_ttl = config.get("cache_ttl", 3600)
        self.cache_patterns = config.get("cache_patterns", [])
        self.cache: Dict[str, Any] = {}

    async def process_request(self, request: Request) -> Optional[Response]:
        """Always return ``None``; cache lookup is not implemented."""
        # TODO: implement the cache lookup
        return None

    async def process_response(self, request: Request, response: Response) -> Response:
        """Return ``response`` unchanged; response caching is not implemented."""
        # TODO: implement response caching
        return response
class MonitoringExtension(Extension):
    """Collects simple in-memory request metrics: counts, errors and timings."""

    def __init__(self, config: Dict[str, Any]):
        """Reset all counters and read the ``enable_metrics`` flag (default ``True``)."""
        super().__init__(config)
        self.request_count = 0
        self.error_count = 0
        # One entry per measured response; nothing in this class removes
        # entries, so the list grows for as long as the instance lives.
        self.response_times: List[float] = []
        self.enable_metrics = config.get("enable_metrics", True)

    async def process_request(self, request: Request) -> Optional[Response]:
        """Count the request and stamp its start time; always returns ``None``.

        When metrics are disabled, nothing is counted or stamped.
        """
        if self.enable_metrics:
            self.request_count += 1
            # Wall-clock timestamp, kept as time.time() so the value stored
            # in request.state.start_time keeps its original meaning.
            request.state.start_time = time.time()
        return None

    async def process_response(self, request: Request, response: Response) -> Response:
        """Record the elapsed time and error status for a stamped request.

        Only requests that carry ``request.state.start_time`` are measured,
        and only while metrics are enabled. Status codes of 400 and above
        count as errors. The response is returned unchanged.
        """
        if self.enable_metrics and hasattr(request.state, 'start_time'):
            response_time = time.time() - request.state.start_time
            self.response_times.append(response_time)
            if response.status_code >= 400:
                self.error_count += 1
            logger.info(f"Request: {request.method} {request.url.path} - "
                        f"Status: {response.status_code} - "
                        f"Time: {response_time:.3f}s")
        return response

    def get_metrics(self) -> Dict[str, Any]:
        """Return a snapshot of the collected metrics.

        ``avg_response_time`` is 0 when no timings were recorded, and
        ``error_rate`` divides by at least 1 to avoid a zero division.
        """
        avg_response_time = (sum(self.response_times) / len(self.response_times)
                             if self.response_times else 0)
        return {
            "request_count": self.request_count,
            "error_count": self.error_count,
            "error_rate": self.error_count / max(self.request_count, 1),
            "avg_response_time": avg_response_time,
            "total_response_times": len(self.response_times)
        }
class ExtensionManager:
    """Holds the loaded extensions and runs requests/responses through them in load order."""

    def __init__(self):
        """Start with no loaded extensions and the four built-in extension types."""
        self.extensions: List[Extension] = []
        self.extension_registry = dict(
            routing=RoutingExtension,
            security=SecurityExtension,
            caching=CachingExtension,
            monitoring=MonitoringExtension,
        )

    def register_extension_type(self, name: str, extension_class: type) -> None:
        """Add ``extension_class`` to the registry under ``name``, replacing any existing entry."""
        self.extension_registry[name] = extension_class

    def load_extension(self, extension_type: str, config: Dict[str, Any]) -> None:
        """Instantiate, initialize and append an extension of the given type.

        An unknown type, or any exception during construction or
        initialization, is logged; the extension is then not added.
        """
        if extension_type in self.extension_registry:
            try:
                factory = self.extension_registry[extension_type]
                instance = factory(config)
                instance.initialize()
                self.extensions.append(instance)
                logger.info(f"Загружено расширение: {extension_type}")
            except Exception as exc:
                logger.error(f"Ошибка загрузки расширения {extension_type}: {exc}")
        else:
            logger.error(f"Неизвестный тип расширения: {extension_type}")

    async def process_request(self, request: Request) -> Optional[Response]:
        """Return the first non-``None`` result from an enabled extension.

        Extensions are tried in load order. One that raises is logged and
        skipped. Returns ``None`` when no extension produced a response.
        """
        for ext in self.extensions:
            if ext.enabled:
                try:
                    result = await ext.process_request(request)
                except Exception as exc:
                    logger.error(f"Ошибка в расширении {type(ext).__name__}: {exc}")
                    continue
                if result is not None:
                    return result
        return None

    async def process_response(self, request: Request, response: Response) -> Response:
        """Pass the response through every enabled extension in load order.

        Each extension receives the previous extension's output. If one
        raises, the error is logged and the response it was given is kept.
        """
        current = response
        for ext in self.extensions:
            if ext.enabled:
                try:
                    current = await ext.process_response(request, current)
                except Exception as exc:
                    logger.error(f"Ошибка в расширении {type(ext).__name__}: {exc}")
        return current

    def cleanup(self) -> None:
        """Call ``cleanup`` on every loaded extension, then drop them all.

        A failure in one extension is logged and does not stop the rest.
        """
        for ext in self.extensions:
            try:
                ext.cleanup()
            except Exception as exc:
                logger.error(f"Ошибка при очистке расширения {type(ext).__name__}: {exc}")
        self.extensions.clear()