# ============================================================
# System: Unified Library, HSE University AI Centre
# Module: ExperimentPipeline
# Authors: Polezhaev V.A., Khritankov A.S.
# Created: 2024
# ============================================================

import logging
import urllib.parse
from contextlib import asynccontextmanager

from fastapi import FastAPI, Depends, HTTPException
from httpx import AsyncClient
from jsonschema.validators import Draft202012Validator
from kubernetes_asyncio.client import ApiClient
from sqlalchemy.ext.asyncio import AsyncSession
from starlette.responses import Response

from auth import get_user_id, allowed_to_platform_user
from config import UNIP_DOMAIN
from exp_pipeline.logs import configure_logging
from exp_pipeline.pipeline import start_trial, get_trial, get_pipeline_resource, \
    list_pipeline_resources, list_trials, add_condition_to_status, get_trial_user_id, OpenAPISpecCache, \
    get_api_resource, set_next_started
from exp_pipeline.schema import CreateTrialRequest, PipelineTrial, \
    PipelineVersionResponse, PipelinesResponse, PipelinesLinks, Link, PipelineLinks, EmbeddedPipeline, \
    PipelinesEmbedded, PipelineResponse, TrialsResponse, TrialsLinks, TrialsEmbedded, CreateStatusConditionRequest
from exp_pipeline.storage.db import init_db, get_db
from kube_config import async_unip_load_kube_config

# Process-wide clients shared by all request handlers; they are created and
# closed by the application lifespan handler below.
_http_client: AsyncClient | None = None
_k8s_api_client: ApiClient | None = None

configure_logging()
logger = logging.getLogger(__name__)


@asynccontextmanager
async def _lifespan(app: FastAPI):
    """Set up and tear down the resources shared by the whole application.

    On startup the Kubernetes configuration is loaded, the database is
    initialised and the shared HTTP / Kubernetes API clients are created.
    On shutdown both clients are closed.

    The database is initialised here because FastAPI does not call
    ``on_event("startup")`` handlers when the application is constructed
    with a ``lifespan`` argument.
    """
    global _http_client, _k8s_api_client
    await async_unip_load_kube_config(logger)
    await init_db()
    _http_client = AsyncClient(timeout=30)
    _k8s_api_client = ApiClient()
    try:
        yield
    finally:
        # Close both clients even if the application exits with an error or
        # closing the first client fails.
        try:
            await _http_client.aclose()
        finally:
            await _k8s_api_client.close()


app = FastAPI(lifespan=_lifespan)

# alternative, per-request variant:
# async def _get_http_client():
#     with AsyncClient() as http_client:
#         yield http_client
#
# def path_op(http_client: AsyncClient = Depends(_get_http_client))
#     ...


# NOTE(review): FastAPI's documentation states that startup/shutdown event
# handlers are not called when the application is created with a `lifespan`
# argument -- confirm that database initialisation is actually triggered at
# application startup.
@app.on_event("startup")
async def init_db_on_start():
    """Initialise the database."""
    await init_db()


openapi_spec_cache = OpenAPISpecCache.get_spec_cache()

# Page-size settings used by the list endpoints.
_DEFAULT_PAGE_LIMIT = 10
_MAX_TRIALS_PAGE_LIMIT = 30

# Route handlers below are documented with comments rather than docstrings:
# FastAPI publishes handler docstrings as operation descriptions in the
# generated OpenAPI document, and that output is intentionally left unchanged.


def _validate_create_trial_request(create_trial_request: CreateTrialRequest, openapi_spec: dict):
    """Validate a trial creation request against the pipeline's OpenAPI spec.

    The ``CreateTrialRequest`` schema is taken from
    ``openapi_spec['components']['schemas']`` and the request (dumped without
    ``None`` values) is checked with a Draft 2020-12 validator.

    Raises:
        HTTPException: status 400 whose ``detail`` is a list of
            ``{'path', 'message', 'sub'}`` dictionaries, one per validation
            error, where ``sub`` holds the nested (context) errors.
    """
    request_schema = openapi_spec['components']['schemas']['CreateTrialRequest']
    validator = Draft202012Validator(request_schema)
    request_obj = create_trial_request.model_dump(exclude_none=True)

    # Collect the errors in a single validation pass; an empty list means
    # that the request is valid.
    errors = sorted(validator.iter_errors(request_obj), key=lambda e: e.path)
    if not errors:
        return

    errors_strs = []
    errors_model = []
    for error in errors:
        schema_path, message = list(error.schema_path), error.message
        errors_strs.append('{0}, {1}'.format(schema_path, message))
        nested_errors = []
        for sub_error in sorted(error.context, key=lambda e: e.schema_path):
            sub_path, sub_message = list(sub_error.schema_path), sub_error.message
            errors_strs.append('{0}, {1}'.format(sub_path, sub_message))
            nested_errors.append({'path': sub_path, 'message': sub_message})
        errors_model.append({'path': schema_path, 'message': message, 'sub': nested_errors})

    logger.info('Validation error: %s', '\n'.join(errors_strs))
    raise HTTPException(status_code=400, detail=errors_model)


async def start_trial_op(app_name: str,
                         pipeline_name: str,
                         create_trial_request: CreateTrialRequest,
                         db: AsyncSession,
                         user_id: str):
    """Validate a trial request and start the trial.

    Fetches the pipeline resource, the API resource and the pipeline's
    OpenAPI specification, validates ``create_trial_request`` against that
    specification and then calls ``start_trial``.

    Returns:
        The value returned by ``start_trial``.

    Raises:
        HTTPException: status 400 if the request does not match the schema.
    """
    pipeline_resource = await get_pipeline_resource(api_client=_k8s_api_client,
                                                    app_name=app_name,
                                                    pipeline_name=pipeline_name)
    api_resource = await get_api_resource(api_client=_k8s_api_client,
                                          app_name=app_name,
                                          pipeline_name=pipeline_name)
    openapi_spec = await openapi_spec_cache.get_pipeline_api_spec(k8s_api_client=_k8s_api_client,
                                                                  app_name=app_name,
                                                                  api_resource=api_resource)
    _validate_create_trial_request(create_trial_request, openapi_spec)
    trial = await start_trial(db=db,
                              http_client=_http_client,
                              api_client=_k8s_api_client,
                              app_name=app_name,
                              user_id=user_id,
                              pipeline_name=pipeline_name,
                              api_resource=api_resource,
                              pipeline_resource=pipeline_resource,
                              trial_inputs=create_trial_request.inputs,
                              trial_output_vars=create_trial_request.output_vars)
    return trial


# Start a trial of the given pipeline on behalf of the authenticated user.
@app.post("/{app_name}/pipelines/{pipeline_name}/trials", response_model=PipelineTrial)
async def post_trial_op(app_name: str,
                        pipeline_name: str,
                        create_trial_request: CreateTrialRequest,
                        db: AsyncSession = Depends(get_db),
                        user_id: str = Depends(get_user_id)):
    trial = await start_trial_op(app_name=app_name,
                                 pipeline_name=pipeline_name,
                                 create_trial_request=create_trial_request,
                                 db=db,
                                 user_id=user_id)
    return trial


# Start a trial of `next_pipeline_name` for the user who owns the trial
# identified by `tracking_id`, then record the new trial's tracking id on the
# original trial via `set_next_started`.
@app.post("/{app_name}/pipelines/{pipeline_name}/trials/continue", response_model=PipelineTrial,
          dependencies=[Depends(allowed_to_platform_user)])
async def post_next_trial_op(app_name: str,
                             pipeline_name: str,
                             next_pipeline_name: str,
                             tracking_id: str,
                             create_trial_request: CreateTrialRequest,
                             db: AsyncSession = Depends(get_db)):
    user_id = await get_trial_user_id(db=db,
                                      app_name=app_name,
                                      pipeline_name=None,
                                      tracking_id=tracking_id)
    next_trial = await start_trial_op(app_name=app_name,
                                      pipeline_name=next_pipeline_name,
                                      create_trial_request=create_trial_request,
                                      db=db,
                                      user_id=user_id)
    await set_next_started(db=db,
                           tracking_id=tracking_id,
                           app_name=app_name,
                           pipeline_name=pipeline_name,
                           next_tracking_id=next_trial.tracking_id)
    return next_trial


# Return an empty response; the only work is done by the
# `allowed_to_platform_user` dependency.
@app.get("/{app_name}/pipelines/{pipeline_name}/check",
         dependencies=[Depends(allowed_to_platform_user)])
async def check_op(app_name: str, pipeline_name: str):
    return Response()


# Return a single trial by its tracking id, looked up for the current user.
@app.get("/{app_name}/trials/{tracking_id}", response_model=PipelineTrial)
async def get_trial_op(app_name: str,
                       tracking_id: str,
                       db: AsyncSession = Depends(get_db),
                       user_id: str = Depends(get_user_id)):
    trial = await get_trial(db=db,
                            app_name=app_name,
                            pipeline_name=None,
                            tracking_id=tracking_id,
                            user_id=user_id)
    return trial


# Add a condition to a trial's status. Responds with 200 when
# `add_condition_to_status` reports a modification and with 204 otherwise.
@app.post("/{app_name}/pipelines/{pipeline_name}/trials/{tracking_id}/status/conditions",
          dependencies=[Depends(allowed_to_platform_user)])
async def post_trial_status_condition_op(response: Response,
                                         app_name: str,
                                         pipeline_name: str,
                                         tracking_id: str,
                                         create_condition_request: CreateStatusConditionRequest,
                                         db: AsyncSession = Depends(get_db)):
    modified = await add_condition_to_status(db=db,
                                             tracking_id=tracking_id,
                                             type_=create_condition_request.type,
                                             message=create_condition_request.message,
                                             reason=create_condition_request.reason,
                                             transition_time=create_condition_request.transition_time,
                                             stage=create_condition_request.stage,
                                             app_name=app_name,
                                             pipeline_name=pipeline_name)
    response.status_code = 200 if modified else 204
    return


# Return the pipeline's OpenAPI specification. If the specification cannot be
# obtained the error is logged and `openapi_spec` is returned as None.
@app.get('/{app_name}/pipelines/{pipeline_name}/version', response_model=PipelineVersionResponse,
         dependencies=[Depends(get_user_id)])
async def get_pipeline_version_op(app_name: str, pipeline_name: str):
    api_resource = await get_api_resource(api_client=_k8s_api_client,
                                          app_name=app_name,
                                          pipeline_name=pipeline_name)
    try:
        openapi_spec = await openapi_spec_cache.get_pipeline_api_spec(k8s_api_client=_k8s_api_client,
                                                                      app_name=app_name,
                                                                      api_resource=api_resource)
    except Exception as exc:
        logger.exception('OpenAPI specification generation error, '
                         'specification not available', exc_info=exc)
        openapi_spec = None
    return PipelineVersionResponse(openapi_spec=openapi_spec, license=None)


def _get_list_request_self_link(cursor, limit, uri):
    """Build the ``self`` link of a list request.

    Args:
        cursor: cursor passed by the client; omitted from the link when falsy.
        limit: limit passed by the client; omitted from the link when falsy.
        uri: path of the list endpoint, joined onto ``https://{UNIP_DOMAIN}``.

    Returns:
        A ``Link`` whose ``href`` is the absolute URL of the request.
    """
    query_params = {}
    if cursor:
        query_params['cursor'] = cursor
    if limit:
        query_params['limit'] = limit
    self_href = urllib.parse.urljoin(f'https://{UNIP_DOMAIN}', uri)
    if query_params:
        # urlencode percent-encodes the values, so an opaque string cursor
        # cannot break the query string.
        self_href += f'?{urllib.parse.urlencode(query_params)}'
    return Link(href=self_href)


def _get_list_request_next_link(new_cursor, limit_set, uri):
    """Build the ``next`` link of a list response.

    Args:
        new_cursor: cursor of the next page; when falsy there is no link.
        limit_set: effective page size used for the current request.
        uri: path of the list endpoint, joined onto ``https://{UNIP_DOMAIN}``.

    Returns:
        A ``Link`` to the next page, or ``None`` when ``new_cursor`` is falsy.
    """
    if not new_cursor:
        return None
    # Values are percent-encoded for the same reason as in the self link.
    query = urllib.parse.urlencode({'cursor': new_cursor, 'limit': limit_set})
    cursor_href = urllib.parse.urljoin(f'https://{UNIP_DOMAIN}', f'{uri}?{query}')
    return Link(href=cursor_href)


def _normalize_trials_limit(limit: int | None) -> int:
    """Return the effective page size for the trial list endpoints.

    A missing or zero limit becomes the default of 10, values above 30 are
    capped at 30 and negative values fall back to 10.
    """
    limit_set = limit or _DEFAULT_PAGE_LIMIT
    if limit_set > _MAX_TRIALS_PAGE_LIMIT:
        return _MAX_TRIALS_PAGE_LIMIT
    if limit_set < 0:
        return _DEFAULT_PAGE_LIMIT
    return limit_set


# List the current user's trials of one pipeline, one page at a time.
@app.get('/{app_name}/pipelines/{pipeline_name}/trials', response_model=TrialsResponse,
         dependencies=[Depends(get_user_id)])
async def get_pipeline_trials_op(app_name: str,
                                 pipeline_name: str,
                                 db: AsyncSession = Depends(get_db),
                                 user_id: str = Depends(get_user_id),
                                 cursor: int | None = None,
                                 limit: int | None = None):
    limit_set = _normalize_trials_limit(limit)
    trials, new_cursor, counts = await list_trials(db=db,
                                                   app_name=app_name,
                                                   pipeline_name=pipeline_name,
                                                   user_id=user_id,
                                                   cursor=cursor,
                                                   limit=limit_set)
    total_item_count, remaining_item_count = counts

    uri = f'{app_name}/pipelines/{pipeline_name}/trials'
    # The self link echoes the limit the client sent, the next link carries
    # the effective one.
    links = TrialsLinks(self=_get_list_request_self_link(cursor, limit, uri),
                        next=_get_list_request_next_link(new_cursor, limit_set, uri))
    return TrialsResponse(_links=links,
                          total_item_count=total_item_count,
                          remaining_item_count=remaining_item_count,
                          _embedded=TrialsEmbedded(trials=trials))


# List the current user's trials across the application, one page at a time.
@app.get('/{app_name}/trials', response_model=TrialsResponse,
         dependencies=[Depends(get_user_id)])
async def get_trials_op(app_name: str,
                        db: AsyncSession = Depends(get_db),
                        user_id: str = Depends(get_user_id),
                        cursor: int | None = None,
                        limit: int | None = None):
    limit_set = _normalize_trials_limit(limit)
    trials, new_cursor, counts = await list_trials(db=db,
                                                   app_name=app_name,
                                                   user_id=user_id,
                                                   cursor=cursor,
                                                   limit=limit_set)
    total_item_count, remaining_item_count = counts

    uri = f'{app_name}/trials'
    # The self link echoes the limit the client sent, the next link carries
    # the effective one.
    links = TrialsLinks(self=_get_list_request_self_link(cursor, limit, uri),
                        next=_get_list_request_next_link(new_cursor, limit_set, uri))
    return TrialsResponse(_links=links,
                          total_item_count=total_item_count,
                          remaining_item_count=remaining_item_count,
                          _embedded=TrialsEmbedded(trials=trials))


def _get_pipeline_links(app_name, pipeline_name):
    """Build the ``self``, ``trials`` and ``version`` links of a pipeline."""
    self_href = urllib.parse.urljoin(f'https://{UNIP_DOMAIN}',
                                     f'{app_name}/pipelines/{pipeline_name}')
    return PipelineLinks(self=Link(href=self_href),
                         trials=Link(href=f'{self_href}/trials'),
                         version=Link(href=f'{self_href}/version'))


# List the pipelines of an application, one page at a time. The status code
# is set to 410 when `list_pipeline_resources` reports the cursor as expired.
# NOTE(review): unlike the trial listings, `limit` has no upper bound and
# negative values are passed through unchanged -- confirm this is intended.
@app.get('/{app_name}/pipelines', response_model=PipelinesResponse,
         dependencies=[Depends(get_user_id)])
async def get_pipelines_op(response: Response,
                           app_name: str,
                           cursor: str | None = None,
                           limit: int | None = None):
    limit_set = limit or _DEFAULT_PAGE_LIMIT
    pipelines, new_cursor, remaining_item_count, expired = await list_pipeline_resources(
        api_client=_k8s_api_client,
        app_name=app_name,
        cursor=cursor,
        limit=limit_set)

    uri = f'{app_name}/pipelines'
    links = PipelinesLinks(self=_get_list_request_self_link(cursor, limit, uri),
                           next=_get_list_request_next_link(new_cursor, limit_set, uri))

    embedded_pipelines = []
    for pipeline in pipelines:
        pipeline_name = pipeline['metadata']['name']
        pipeline_links = _get_pipeline_links(app_name, pipeline_name)
        embedded_pipelines.append(EmbeddedPipeline(name=pipeline_name, _links=pipeline_links))

    resp = PipelinesResponse(_links=links,
                             remaining_item_count=remaining_item_count,
                             _embedded=PipelinesEmbedded(pipelines=embedded_pipelines))
    if expired:
        response.status_code = 410
    return resp


# Return one pipeline: its name, its links and its `spec` as the definition.
@app.get('/{app_name}/pipelines/{pipeline_name}', response_model=PipelineResponse,
         dependencies=[Depends(get_user_id)])
async def get_pipeline_op(app_name: str, pipeline_name: str):
    pipeline_res = await get_pipeline_resource(api_client=_k8s_api_client,
                                               app_name=app_name,
                                               pipeline_name=pipeline_name)
    # Use the name stored in the resource itself for the response and links.
    pipeline_name = pipeline_res['metadata']['name']
    pipeline_links = _get_pipeline_links(app_name, pipeline_name)
    return PipelineResponse(name=pipeline_name,
                            _links=pipeline_links,
                            definition=pipeline_res['spec'])