unip-controller/controller/src/exp_pipeline/api.py
2025-04-15 20:56:15 +03:00

400 lines
16 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

# ============================================================
# Система: Единая библиотека, Центр ИИ НИУ ВШЭ
# Модуль: ExperimentPipeline
# Авторы: Полежаев В.А., Хританков А.С.
# Дата создания: 2024 г.
# ============================================================
import logging
import urllib.parse
from contextlib import asynccontextmanager
from fastapi import FastAPI, Depends, HTTPException
from httpx import AsyncClient
from jsonschema.validators import Draft202012Validator
from kubernetes_asyncio.client import ApiClient
from sqlalchemy.ext.asyncio import AsyncSession
from starlette.responses import Response
from auth import get_user_id, allowed_to_platform_user
from config import UNIP_DOMAIN
from exp_pipeline.logs import configure_logging
from exp_pipeline.pipeline import start_trial, get_trial, get_pipeline_resource, \
list_pipeline_resources, list_trials, add_condition_to_status, get_trial_user_id, OpenAPISpecCache, \
get_api_resource, set_next_started
from exp_pipeline.schema import CreateTrialRequest, PipelineTrial, \
PipelineVersionResponse, PipelinesResponse, PipelinesLinks, Link, PipelineLinks, EmbeddedPipeline, \
PipelinesEmbedded, PipelineResponse, TrialsResponse, TrialsLinks, TrialsEmbedded, CreateStatusConditionRequest
from exp_pipeline.storage.db import init_db, get_db
from kube_config import async_unip_load_kube_config
# Shared clients used by the route handlers. Both are assigned in _lifespan
# when the application starts and are None before that.
_http_client: AsyncClient | None = None
_k8s_api_client: ApiClient | None = None
# Logging is configured at import time, before the module logger is created.
configure_logging()
logger = logging.getLogger(__name__)
@asynccontextmanager
async def _lifespan(app: FastAPI):
    """Set up shared resources on startup and release them on shutdown.

    Startup loads the Kubernetes configuration, initializes the database and
    creates the module-level HTTP and Kubernetes API clients used by the route
    handlers. Shutdown closes both clients in ``finally`` blocks, so they are
    released even if the application exits with an error. The globals are then
    reset to ``None``.

    :param app: the FastAPI application (required by the lifespan protocol).
    """
    global _http_client, _k8s_api_client
    await async_unip_load_kube_config(logger)
    # FastAPI does not call @app.on_event("startup") handlers when a lifespan
    # is supplied, so the database must be initialized here.
    # NOTE(review): assumes init_db() is safe to run on every startup
    # (e.g. create-if-not-exists) — confirm.
    await init_db()
    _http_client = AsyncClient(timeout=30)
    _k8s_api_client = ApiClient()
    try:
        yield
    finally:
        # Nested finally: the Kubernetes client is closed even if closing the
        # HTTP client raises.
        try:
            await _http_client.aclose()
        finally:
            await _k8s_api_client.close()
            _http_client = None
            _k8s_api_client = None
# The ASGI application; shared clients are created and closed by _lifespan.
# NOTE: when lifespan= is passed, FastAPI does not run @app.on_event handlers.
app = FastAPI(lifespan=_lifespan)
# Alternative approach: a per-request HTTP client instead of the shared one:
# async def _get_http_client():
#     async with AsyncClient() as http_client:
#         yield http_client
#
# async def path_op(http_client: AsyncClient = Depends(_get_http_client))
# ...
# Startup hook that initializes the database via init_db().
# NOTE(review): FastAPI does not invoke on_event("startup") handlers when the
# app is constructed with lifespan= (as `app` is), so this handler may never
# run. Confirm, and initialize the database from the lifespan instead.
@app.on_event("startup")
async def init_db_on_start():
    await init_db()
# Cache of pipeline OpenAPI specs used by start_trial_op and
# get_pipeline_version_op. Presumably a shared instance returned by
# get_spec_cache() — confirm in exp_pipeline.pipeline.
openapi_spec_cache = OpenAPISpecCache.get_spec_cache()
def _validate_create_trial_request(create_trial_request: CreateTrialRequest, openapi_spec: dict):
    """Validate a trial-creation request against the pipeline's OpenAPI schema.

    The JSON schema is ``openapi_spec['components']['schemas']['CreateTrialRequest']``.
    It is applied with a Draft 2020-12 validator to the request dumped with
    ``exclude_none=True``.

    :param create_trial_request: incoming request body.
    :param openapi_spec: OpenAPI document of the pipeline API.
    :raises KeyError: if the spec lacks the ``CreateTrialRequest`` component schema.
    :raises HTTPException: 400 when the request is invalid. ``detail`` is a list
        of ``{'path', 'message', 'sub'}`` dicts, where ``path`` is the schema
        path and ``sub`` holds the nested errors (from ``error.context``).
    """
    # A separate local keeps the openapi_spec argument from being overwritten.
    request_schema = openapi_spec['components']['schemas']['CreateTrialRequest']
    validator = Draft202012Validator(request_schema)
    request_obj = create_trial_request.model_dump(exclude_none=True)
    # One pass over iter_errors; an empty list means the request is valid.
    # This avoids validating twice (is_valid() followed by iter_errors()).
    errors = sorted(validator.iter_errors(request_obj), key=lambda e: e.path)
    if not errors:
        return
    errors_strs = []
    errors_model = []
    for error in errors:
        path, message = list(error.schema_path), error.message
        errors_strs.append(f'{path}, {message}')
        nested_errors = []
        # error.context holds sub-errors of combinators such as anyOf/oneOf.
        for sub_error in sorted(error.context, key=lambda e: e.schema_path):
            sub_path, sub_message = list(sub_error.schema_path), sub_error.message
            errors_strs.append(f'{sub_path}, {sub_message}')
            nested_errors.append({'path': sub_path, 'message': sub_message})
        errors_model.append({'path': path, 'message': message, 'sub': nested_errors})
    logger.info('Validation error: %s', '\n'.join(errors_strs))
    raise HTTPException(status_code=400, detail=errors_model)
async def start_trial_op(app_name: str,
                         pipeline_name: str,
                         create_trial_request: CreateTrialRequest,
                         db: AsyncSession,
                         user_id: str):
    """Resolve the pipeline, validate the request and start a trial.

    Fetches the pipeline and API resources through the shared Kubernetes
    client and obtains the pipeline's OpenAPI spec from ``openapi_spec_cache``.
    It then validates ``create_trial_request`` against that spec and delegates
    to ``start_trial``.

    :return: the trial object returned by ``start_trial``.
    """
    target = {'app_name': app_name, 'pipeline_name': pipeline_name}

    pipeline_resource = await get_pipeline_resource(api_client=_k8s_api_client, **target)
    api_resource = await get_api_resource(api_client=_k8s_api_client, **target)

    spec = await openapi_spec_cache.get_pipeline_api_spec(
        k8s_api_client=_k8s_api_client,
        app_name=app_name,
        api_resource=api_resource,
    )
    _validate_create_trial_request(create_trial_request, spec)

    return await start_trial(
        db=db,
        http_client=_http_client,
        api_client=_k8s_api_client,
        user_id=user_id,
        api_resource=api_resource,
        pipeline_resource=pipeline_resource,
        trial_inputs=create_trial_request.inputs,
        trial_output_vars=create_trial_request.output_vars,
        **target,
    )
@app.post("/{app_name}/pipelines/{pipeline_name}/trials", response_model=PipelineTrial)
async def post_trial_op(app_name: str,
                        pipeline_name: str,
                        create_trial_request: CreateTrialRequest,
                        db: AsyncSession = Depends(get_db),
                        user_id: str = Depends(get_user_id)):
    """Start a trial of ``pipeline_name`` for the user resolved by ``get_user_id``."""
    return await start_trial_op(
        app_name=app_name,
        pipeline_name=pipeline_name,
        create_trial_request=create_trial_request,
        db=db,
        user_id=user_id,
    )
@app.post("/{app_name}/pipelines/{pipeline_name}/trials/continue",
          response_model=PipelineTrial,
          dependencies=[Depends(allowed_to_platform_user)])
async def post_next_trial_op(app_name: str,
                             pipeline_name: str,
                             next_pipeline_name: str,
                             tracking_id: str,
                             create_trial_request: CreateTrialRequest,
                             db: AsyncSession = Depends(get_db)):
    """Start ``next_pipeline_name`` as a continuation of an existing trial.

    ``next_pipeline_name`` and ``tracking_id`` are query parameters. The new
    trial is started for the user that ``get_trial_user_id`` returns for
    ``tracking_id``. The original trial is then updated via
    ``set_next_started`` with the new trial's tracking id.
    """
    owner_id = await get_trial_user_id(db=db,
                                       app_name=app_name,
                                       pipeline_name=None,
                                       tracking_id=tracking_id)

    started = await start_trial_op(app_name=app_name,
                                   pipeline_name=next_pipeline_name,
                                   create_trial_request=create_trial_request,
                                   db=db,
                                   user_id=owner_id)

    await set_next_started(db=db,
                           tracking_id=tracking_id,
                           app_name=app_name,
                           pipeline_name=pipeline_name,
                           next_tracking_id=started.tracking_id)
    return started
@app.get("/{app_name}/pipelines/{pipeline_name}/check",
         dependencies=[Depends(allowed_to_platform_user)])
async def check_op(app_name: str,
                   pipeline_name: str):
    """Return an empty 200 response.

    The route itself performs no work; access is controlled by the
    ``allowed_to_platform_user`` dependency.
    """
    return Response(status_code=200)
@app.get("/{app_name}/trials/{tracking_id}",
         response_model=PipelineTrial)
async def get_trial_op(app_name: str,
                       tracking_id: str,
                       db: AsyncSession = Depends(get_db),
                       user_id: str = Depends(get_user_id)):
    """Return the trial ``tracking_id`` of ``app_name`` for the requesting user.

    The lookup is not restricted to a pipeline (``pipeline_name=None``).
    """
    lookup = {
        'app_name': app_name,
        'pipeline_name': None,
        'tracking_id': tracking_id,
        'user_id': user_id,
    }
    return await get_trial(db=db, **lookup)
@app.post("/{app_name}/pipelines/{pipeline_name}/trials/{tracking_id}/status/conditions",
          dependencies=[Depends(allowed_to_platform_user)])
async def post_trial_status_condition_op(response: Response,
                                         app_name: str,
                                         pipeline_name: str,
                                         tracking_id: str,
                                         create_condition_request: CreateStatusConditionRequest,
                                         db: AsyncSession = Depends(get_db)):
    """Append a status condition to a trial.

    Responds 200 when ``add_condition_to_status`` reports a modification and
    204 otherwise. The handler returns no payload.
    """
    req = create_condition_request
    modified = await add_condition_to_status(
        db=db,
        tracking_id=tracking_id,
        type_=req.type,
        message=req.message,
        reason=req.reason,
        transition_time=req.transition_time,
        stage=req.stage,
        app_name=app_name,
        pipeline_name=pipeline_name,
    )
    if modified:
        response.status_code = 200
    else:
        response.status_code = 204
    return None
@app.get('/{app_name}/pipelines/{pipeline_name}/version',
         response_model=PipelineVersionResponse,
         dependencies=[Depends(get_user_id)])
async def get_pipeline_version_op(app_name: str,
                                  pipeline_name: str):
    """Return version information for a pipeline.

    ``openapi_spec`` is the cached pipeline API spec, or ``None`` if obtaining
    it raised (the error is logged). ``license`` is always ``None``.
    """
    api_resource = await get_api_resource(api_client=_k8s_api_client,
                                          app_name=app_name,
                                          pipeline_name=pipeline_name)
    spec = None
    try:
        spec = await openapi_spec_cache.get_pipeline_api_spec(k8s_api_client=_k8s_api_client,
                                                             app_name=app_name,
                                                             api_resource=api_resource)
    except Exception as exc:
        # Best effort: a missing spec must not fail the whole version request.
        logger.exception('OpenAPI specification generation error, '
                         'specification not available', exc_info=exc)
    return PipelineVersionResponse(openapi_spec=spec, license=None)
def _get_list_request_self_link(cursor, limit, uri):
    """Build the ``self`` link of a paginated list response.

    :param cursor: cursor from the request; added to the query only if truthy.
    :param limit: limit as sent by the client; added to the query only if truthy.
    :param uri: path resolved against ``https://{UNIP_DOMAIN}`` with ``urljoin``.
    :return: ``Link`` with the absolute URL of the current request.
    """
    params = [(name, value)
              for name, value in (('cursor', cursor), ('limit', limit))
              if value]
    self_href = urllib.parse.urljoin(f'https://{UNIP_DOMAIN}', uri)
    if params:
        # urlencode percent-escapes values, so opaque cursor tokens containing
        # '+', '&', '=' or '/' round-trip correctly ('+' would otherwise be
        # decoded as a space).
        self_href += '?' + urllib.parse.urlencode(params)
    return Link(href=self_href)
def _get_list_request_next_link(new_cursor, limit_set, uri):
    """Build the ``next`` link of a paginated list response.

    :param new_cursor: cursor of the next page; a falsy value means no next page.
    :param limit_set: effective page size to carry over to the next request.
    :param uri: path resolved against ``https://{UNIP_DOMAIN}`` with ``urljoin``.
    :return: ``Link`` to the next page, or ``None`` if there is none.
    """
    if not new_cursor:
        return None
    # Percent-encode values so an opaque cursor cannot break the query string.
    query = urllib.parse.urlencode([('cursor', new_cursor), ('limit', limit_set)])
    cursor_href = urllib.parse.urljoin(f'https://{UNIP_DOMAIN}', f'{uri}?{query}')
    return Link(href=cursor_href)
@app.get('/{app_name}/pipelines/{pipeline_name}/trials',
         response_model=TrialsResponse,
         dependencies=[Depends(get_user_id)])
async def get_pipeline_trials_op(app_name: str,
                                 pipeline_name: str,
                                 db: AsyncSession = Depends(get_db),
                                 user_id: str = Depends(get_user_id),
                                 cursor: int | None = None,
                                 limit: int | None = None):
    """List the user's trials of one pipeline, one page at a time.

    The page size defaults to 10 (also for ``limit`` 0 or negative) and is
    capped at 30. The response carries ``self`` and ``next`` links plus
    total and remaining item counts.
    """
    page_size = min(limit or 10, 30)
    if page_size < 0:
        page_size = 10

    trials, new_cursor, (total_item_count, remaining_item_count) = await list_trials(
        db=db,
        app_name=app_name,
        pipeline_name=pipeline_name,
        user_id=user_id,
        cursor=cursor,
        limit=page_size,
    )

    uri = f'{app_name}/pipelines/{pipeline_name}/trials'
    links = TrialsLinks(self=_get_list_request_self_link(cursor, limit, uri),
                        next=_get_list_request_next_link(new_cursor, page_size, uri))
    return TrialsResponse(_links=links,
                          total_item_count=total_item_count,
                          remaining_item_count=remaining_item_count,
                          _embedded=TrialsEmbedded(trials=trials))
@app.get('/{app_name}/trials',
         response_model=TrialsResponse,
         dependencies=[Depends(get_user_id)])
async def get_trials_op(app_name: str,
                        db: AsyncSession = Depends(get_db),
                        user_id: str = Depends(get_user_id),
                        cursor: int | None = None,
                        limit: int | None = None):
    """List the user's trials across all pipelines of ``app_name``.

    The page size defaults to 10 (also for ``limit`` 0 or negative) and is
    capped at 30. The response carries ``self`` and ``next`` links plus
    total and remaining item counts.
    """
    page_size = min(limit or 10, 30)
    if page_size < 0:
        page_size = 10

    trials, new_cursor, (total_item_count, remaining_item_count) = await list_trials(
        db=db,
        app_name=app_name,
        user_id=user_id,
        cursor=cursor,
        limit=page_size,
    )

    uri = f'{app_name}/trials'
    links = TrialsLinks(self=_get_list_request_self_link(cursor, limit, uri),
                        next=_get_list_request_next_link(new_cursor, page_size, uri))
    return TrialsResponse(_links=links,
                          total_item_count=total_item_count,
                          remaining_item_count=remaining_item_count,
                          _embedded=TrialsEmbedded(trials=trials))
def _get_pipeline_links(app_name, pipeline_name):
    """Build the ``self``, ``trials`` and ``version`` links of a pipeline.

    All three are absolute URLs under ``https://{UNIP_DOMAIN}``; the latter two
    are the pipeline URL suffixed with ``/trials`` and ``/version``.
    """
    base = urllib.parse.urljoin(f'https://{UNIP_DOMAIN}', f'{app_name}/pipelines/{pipeline_name}')
    return PipelineLinks(self=Link(href=base),
                         trials=Link(href=base + '/trials'),
                         version=Link(href=base + '/version'))
@app.get('/{app_name}/pipelines',
         response_model=PipelinesResponse,
         dependencies=[Depends(get_user_id)])
async def get_pipelines_op(response: Response,
                           app_name: str,
                           cursor: str | None = None,
                           limit: int | None = None):
    """List the pipelines of ``app_name``, one page at a time.

    The page size is normalized like the trial list endpoints: it defaults to
    10 (also for ``limit`` 0 or negative) and is capped at 30. Previously the
    raw value was passed to Kubernetes unbounded. If
    ``list_pipeline_resources`` reports the cursor as expired, the page is
    still returned but with HTTP status 410.
    """
    limit_set = limit or 10
    if limit_set > 30:
        limit_set = 30
    if limit_set < 0:
        limit_set = 10
    pipelines, new_cursor, remaining_item_count, expired = await list_pipeline_resources(
        api_client=_k8s_api_client,
        app_name=app_name,
        cursor=cursor,
        limit=limit_set,
    )
    uri = f'{app_name}/pipelines'
    self_link = _get_list_request_self_link(cursor, limit, uri)
    cursor_link = _get_list_request_next_link(new_cursor, limit_set, uri)
    links = PipelinesLinks(self=self_link, next=cursor_link)
    embedded_pipelines = []
    for p in pipelines:
        pipeline_name = p['metadata']['name']
        embedded_pipelines.append(
            EmbeddedPipeline(name=pipeline_name,
                             _links=_get_pipeline_links(app_name, pipeline_name)))
    pe = PipelinesEmbedded(pipelines=embedded_pipelines)
    resp = PipelinesResponse(_links=links, remaining_item_count=remaining_item_count, _embedded=pe)
    if expired:
        response.status_code = 410
    return resp
@app.get('/{app_name}/pipelines/{pipeline_name}',
         response_model=PipelineResponse,
         dependencies=[Depends(get_user_id)])
async def get_pipeline_op(app_name: str,
                          pipeline_name: str):
    """Return a pipeline's name, links and definition.

    The name is taken from the resource's ``metadata.name`` and the definition
    is its ``spec``.
    """
    resource = await get_pipeline_resource(api_client=_k8s_api_client,
                                           app_name=app_name,
                                           pipeline_name=pipeline_name)
    resource_name = resource['metadata']['name']
    return PipelineResponse(name=resource_name,
                            _links=_get_pipeline_links(app_name, resource_name),
                            definition=resource['spec'])