# Source: unip-controller/controller/src/exp_pipeline/api.py
# (401 lines, 16 KiB, Python; snapshot of 2025-01-29 13:13:51 +00:00)
# ============================================================
# Система: Единая библиотека, Центр ИИ НИУ ВШЭ
# Модуль: ExperimentPipeline
# Авторы: Полежаев В.А., Хританков А.С.
# Дата создания: 2024 г.
# ============================================================
import logging
import urllib.parse
from contextlib import asynccontextmanager
from fastapi import FastAPI, Depends, HTTPException
from httpx import AsyncClient
from jsonschema.validators import Draft202012Validator
from kubernetes_asyncio.client import ApiClient
from sqlalchemy.ext.asyncio import AsyncSession
from starlette.responses import Response
from auth import get_user_id, allowed_to_platform_user
from config import UNIP_DOMAIN
from exp_pipeline.logs import configure_logging
from exp_pipeline.pipeline import start_trial, get_trial, get_pipeline_resource, \
list_pipeline_resources, list_trials, add_condition_to_status, get_trial_user_id, OpenAPISpecCache, \
get_api_resource, set_next_started
from exp_pipeline.schema import CreateTrialRequest, PipelineTrial, \
PipelineVersionResponse, PipelinesResponse, PipelinesLinks, Link, PipelineLinks, EmbeddedPipeline, \
PipelinesEmbedded, PipelineResponse, TrialsResponse, TrialsLinks, TrialsEmbedded, CreateStatusConditionRequest
from exp_pipeline.storage.db import init_db, get_db
from kube_config import async_unip_load_kube_config
# Process-wide clients shared by all request handlers. They are created by the
# lifespan handler (`_lifespan`) and closed on shutdown; before startup they are None.
_http_client: AsyncClient | None = None
_k8s_api_client: ApiClient | None = None
# Logging is configured before the module-level logger is obtained.
configure_logging()
logger = logging.getLogger(__name__)
@asynccontextmanager
async def _lifespan(app: FastAPI):
    """Application lifespan: acquire shared resources on startup, release them on shutdown.

    Startup loads the Kubernetes configuration, initialises the database and
    creates the process-wide HTTP and Kubernetes API clients stored in the module
    globals ``_http_client`` / ``_k8s_api_client``. Shutdown closes both clients,
    also when the application exits with an error.

    FastAPI does not call ``@app.on_event("startup")`` handlers when the app is
    created with ``lifespan=``, so database initialisation is done here.
    """
    global _http_client, _k8s_api_client
    await async_unip_load_kube_config(logger)
    await init_db()
    _http_client = AsyncClient(timeout=30)
    _k8s_api_client = ApiClient()
    try:
        yield
    finally:
        # Close the k8s client even if closing the HTTP client fails.
        try:
            await _http_client.aclose()
        finally:
            await _k8s_api_client.close()


app = FastAPI(lifespan=_lifespan)
# Alternative approach: a per-request HTTP client provided as a dependency:
# async def _get_http_client():
#     async with AsyncClient() as http_client:
#         yield http_client
#
# async def path_op(http_client: AsyncClient = Depends(_get_http_client)):
#     ...
# NOTE(review): FastAPI/Starlette do not run `on_event("startup")` handlers when the
# application is constructed with `lifespan=...` (as `app` is in this module), so this
# handler is never invoked. Database initialisation has to be triggered from the
# lifespan context instead.
@app.on_event("startup")
async def init_db_on_start():
    """Initialise the database by awaiting ``init_db()`` on application startup."""
    await init_db()
# Module-level OpenAPI specification cache used by the handlers below to obtain the
# per-pipeline API specification (presumably a shared singleton — see OpenAPISpecCache).
openapi_spec_cache = OpenAPISpecCache.get_spec_cache()
def _validate_create_trial_request(create_trial_request: CreateTrialRequest, openapi_spec: dict):
    """Validate a trial creation request against the pipeline's OpenAPI schema.

    The JSON schema is ``openapi_spec['components']['schemas']['CreateTrialRequest']``.
    It is applied with a JSON Schema 2020-12 validator to the request dumped
    without ``None`` fields.

    :param create_trial_request: parsed request body.
    :param openapi_spec: OpenAPI specification of the pipeline API.
    :raises HTTPException: 400 whose ``detail`` is a list of
        ``{'path', 'message', 'sub'}`` dicts, one per validation error, where
        ``path`` is the schema path of the failing keyword and ``sub`` holds the
        nested errors from ``error.context``.
    :raises KeyError: if the specification lacks the ``CreateTrialRequest`` schema.
    """
    request_schema = openapi_spec['components']['schemas']['CreateTrialRequest']
    # NOTE(review): the validator only sees this sub-schema, so a '$ref' to
    # '#/components/...' inside it would not resolve — confirm the schema is self-contained.
    validator = Draft202012Validator(request_schema)
    request_obj = create_trial_request.model_dump(exclude_none=True)

    # Single validation pass: collect all errors directly instead of calling
    # is_valid() first and then validating again with iter_errors().
    errors = sorted(validator.iter_errors(request_obj), key=lambda e: e.path)
    if not errors:
        return

    errors_strs = []
    errors_model = []
    for error in errors:
        schema_path = list(error.schema_path)
        errors_strs.append('{0}, {1}'.format(schema_path, error.message))
        nested_errors = []
        for sub_error in sorted(error.context, key=lambda e: e.schema_path):
            sub_path = list(sub_error.schema_path)
            errors_strs.append('{0}, {1}'.format(sub_path, sub_error.message))
            nested_errors.append({'path': sub_path, 'message': sub_error.message})
        errors_model.append({'path': schema_path, 'message': error.message, 'sub': nested_errors})

    logger.info('Validation error: %s', '\n'.join(errors_strs))
    raise HTTPException(status_code=400, detail=errors_model)
async def start_trial_op(app_name: str,
                         pipeline_name: str,
                         create_trial_request: CreateTrialRequest,
                         db: AsyncSession,
                         user_id: str):
    """Validate a trial request against the pipeline's OpenAPI spec, then start the trial.

    Fetches the pipeline resource and its API resource through the shared
    Kubernetes client and obtains the pipeline's OpenAPI spec from the cache. It
    then validates the request (HTTP 400 on failure) and delegates to ``start_trial``.

    :returns: the value returned by ``start_trial``.
    """
    k8s = _k8s_api_client
    target = {'app_name': app_name, 'pipeline_name': pipeline_name}

    pipeline_resource = await get_pipeline_resource(api_client=k8s, **target)
    api_resource = await get_api_resource(api_client=k8s, **target)
    spec = await openapi_spec_cache.get_pipeline_api_spec(k8s_api_client=k8s,
                                                         app_name=app_name,
                                                         api_resource=api_resource)
    _validate_create_trial_request(create_trial_request, spec)

    return await start_trial(db=db,
                             http_client=_http_client,
                             api_client=k8s,
                             user_id=user_id,
                             api_resource=api_resource,
                             pipeline_resource=pipeline_resource,
                             trial_inputs=create_trial_request.inputs,
                             trial_output_vars=create_trial_request.output_vars,
                             **target)
@app.post("/{app_name}/pipelines/{pipeline_name}/trials", response_model=PipelineTrial)
async def post_trial_op(app_name: str,
                        pipeline_name: str,
                        create_trial_request: CreateTrialRequest,
                        db: AsyncSession = Depends(get_db),
                        user_id: str = Depends(get_user_id)):
    """Start a new trial of ``pipeline_name`` on behalf of the requesting user."""
    return await start_trial_op(app_name=app_name,
                                pipeline_name=pipeline_name,
                                create_trial_request=create_trial_request,
                                db=db,
                                user_id=user_id)
@app.post("/{app_name}/pipelines/{pipeline_name}/trials/continue",
          response_model=PipelineTrial,
          dependencies=[Depends(allowed_to_platform_user)])
async def post_next_trial_op(app_name: str,
                             pipeline_name: str,
                             next_pipeline_name: str,
                             tracking_id: str,
                             create_trial_request: CreateTrialRequest,
                             db: AsyncSession = Depends(get_db)):
    """Start a trial of ``next_pipeline_name`` as the continuation of trial ``tracking_id``.

    Guarded by ``allowed_to_platform_user``. The new trial is started on behalf of
    the user that ``get_trial_user_id`` returns for ``tracking_id``. Afterwards
    ``set_next_started`` records the new trial's tracking id against the original
    trial of ``pipeline_name``.
    """
    owner_id = await get_trial_user_id(db=db,
                                       app_name=app_name,
                                       pipeline_name=None,
                                       tracking_id=tracking_id)
    started = await start_trial_op(app_name=app_name,
                                   pipeline_name=next_pipeline_name,
                                   create_trial_request=create_trial_request,
                                   db=db,
                                   user_id=owner_id)
    await set_next_started(db=db,
                           tracking_id=tracking_id,
                           app_name=app_name,
                           pipeline_name=pipeline_name,
                           next_tracking_id=started.tracking_id)
    return started
@app.get("/{app_name}/pipelines/{pipeline_name}/check",
         dependencies=[Depends(allowed_to_platform_user)])
async def check_op(app_name: str,
                   pipeline_name: str):
    """Access check: reply with an empty default ``Response`` once the dependency passes."""
    empty = Response()
    return empty
@app.get("/{app_name}/trials/{tracking_id}",
         response_model=PipelineTrial)
async def get_trial_op(app_name: str,
                       tracking_id: str,
                       db: AsyncSession = Depends(get_db),
                       user_id: str = Depends(get_user_id)):
    """Return trial ``tracking_id`` of ``app_name`` as seen by the requesting user.

    The lookup is not restricted to a pipeline (``pipeline_name=None``).
    """
    return await get_trial(db=db,
                           app_name=app_name,
                           pipeline_name=None,
                           tracking_id=tracking_id,
                           user_id=user_id)
@app.post("/{app_name}/pipelines/{pipeline_name}/trials/{tracking_id}/status/conditions",
          dependencies=[Depends(allowed_to_platform_user)])
async def post_trial_status_condition_op(response: Response,
                                         app_name: str,
                                         pipeline_name: str,
                                         tracking_id: str,
                                         create_condition_request: CreateStatusConditionRequest,
                                         db: AsyncSession = Depends(get_db)):
    """Append a status condition to a trial.

    Guarded by ``allowed_to_platform_user``. The response status is 200 when
    ``add_condition_to_status`` reports a modification and 204 otherwise; the
    handler itself returns ``None``.
    """
    condition = create_condition_request
    was_modified = await add_condition_to_status(db=db,
                                                 tracking_id=tracking_id,
                                                 type_=condition.type,
                                                 message=condition.message,
                                                 reason=condition.reason,
                                                 transition_time=condition.transition_time,
                                                 stage=condition.stage,
                                                 app_name=app_name,
                                                 pipeline_name=pipeline_name)
    if was_modified:
        response.status_code = 200
    else:
        response.status_code = 204
    return None
@app.get('/{app_name}/pipelines/{pipeline_name}/version',
         response_model=PipelineVersionResponse,
         dependencies=[Depends(get_user_id)])
async def get_pipeline_version_op(app_name: str,
                                  pipeline_name: str):
    """Return version information for a pipeline: its OpenAPI spec and license.

    The spec lookup is best-effort. Any error while obtaining it is logged with
    its traceback, and ``openapi_spec`` is then ``None``. ``license`` is always
    ``None``.
    """
    api_resource = await get_api_resource(api_client=_k8s_api_client,
                                          app_name=app_name,
                                          pipeline_name=pipeline_name)
    spec = None
    try:
        spec = await openapi_spec_cache.get_pipeline_api_spec(k8s_api_client=_k8s_api_client,
                                                             app_name=app_name,
                                                             api_resource=api_resource)
    except Exception:
        # logger.exception attaches the active exception's traceback itself.
        logger.exception('OpenAPI specification generation error, '
                         'specification not available')
    return PipelineVersionResponse(openapi_spec=spec, license=None)
def _get_list_request_self_link(cursor, limit, uri):
    """Build the ``self`` link of a paginated list response.

    :param cursor: cursor from the request; added to the query only if truthy.
    :param limit: limit as sent by the client; added to the query only if truthy.
    :param uri: path relative to ``https://{UNIP_DOMAIN}``, without a query string.
    :returns: ``Link`` to the requested page.
    """
    params = {}
    if cursor:
        params['cursor'] = cursor
    if limit:
        params['limit'] = limit
    self_href = urllib.parse.urljoin(f'https://{UNIP_DOMAIN}', uri)
    if params:
        # The cursor is echoed back from the client's query string, so values are
        # percent-encoded; a raw '&', '=' or '#' would otherwise inject extra
        # parameters or break the link.
        self_href += '?' + urllib.parse.urlencode(params)
    return Link(href=self_href)
def _get_list_request_next_link(new_cursor, limit_set, uri):
    """Build the ``next`` link of a paginated list response.

    :param new_cursor: cursor of the following page; a falsy value means there is none.
    :param limit_set: effective page size used for the current request.
    :param uri: path relative to ``https://{UNIP_DOMAIN}``, without a query string.
    :returns: ``Link`` to the next page, or ``None`` when there is no next page.
    """
    if not new_cursor:
        return None
    # Percent-encode the query values so arbitrary cursor strings yield a valid URL.
    query = urllib.parse.urlencode({'cursor': new_cursor, 'limit': limit_set})
    cursor_href = urllib.parse.urljoin(f'https://{UNIP_DOMAIN}', f'{uri}?{query}')
    return Link(href=cursor_href)
@app.get('/{app_name}/pipelines/{pipeline_name}/trials',
         response_model=TrialsResponse,
         dependencies=[Depends(get_user_id)])
async def get_pipeline_trials_op(app_name: str,
                                 pipeline_name: str,
                                 db: AsyncSession = Depends(get_db),
                                 user_id: str = Depends(get_user_id),
                                 cursor: int | None = None,
                                 limit: int | None = None):
    """List trials of one pipeline for the requesting user, one page at a time.

    Page size: ``limit`` if given and non-zero (default 10), capped at 30; a
    negative value falls back to 10.
    """
    page_size = min(limit or 10, 30)
    if page_size < 0:
        page_size = 10

    trials, new_cursor, (total_item_count, remaining_item_count) = await list_trials(
        db=db,
        app_name=app_name,
        pipeline_name=pipeline_name,
        user_id=user_id,
        cursor=cursor,
        limit=page_size)

    uri = f'{app_name}/pipelines/{pipeline_name}/trials'
    links = TrialsLinks(self=_get_list_request_self_link(cursor, limit, uri),
                        next=_get_list_request_next_link(new_cursor, page_size, uri))
    return TrialsResponse(_links=links,
                          total_item_count=total_item_count,
                          remaining_item_count=remaining_item_count,
                          _embedded=TrialsEmbedded(trials=trials))
@app.get('/{app_name}/trials',
         response_model=TrialsResponse,
         dependencies=[Depends(get_user_id)])
async def get_trials_op(app_name: str,
                        db: AsyncSession = Depends(get_db),
                        user_id: str = Depends(get_user_id),
                        cursor: int | None = None,
                        limit: int | None = None):
    """List trials across all pipelines of ``app_name`` for the requesting user, paginated.

    Page size: ``limit`` if given and non-zero (default 10), capped at 30; a
    negative value falls back to 10.
    """
    page_size = min(limit or 10, 30)
    if page_size < 0:
        page_size = 10

    trials, new_cursor, (total_item_count, remaining_item_count) = await list_trials(
        db=db,
        app_name=app_name,
        user_id=user_id,
        cursor=cursor,
        limit=page_size)

    uri = f'{app_name}/trials'
    links = TrialsLinks(self=_get_list_request_self_link(cursor, limit, uri),
                        next=_get_list_request_next_link(new_cursor, page_size, uri))
    return TrialsResponse(_links=links,
                          total_item_count=total_item_count,
                          remaining_item_count=remaining_item_count,
                          _embedded=TrialsEmbedded(trials=trials))
def _get_pipeline_links(app_name, pipeline_name):
    """Build the ``self``, ``trials`` and ``version`` links of a pipeline.

    All three share the base ``https://{UNIP_DOMAIN}/{app_name}/pipelines/{pipeline_name}``.
    """
    base = urllib.parse.urljoin(f'https://{UNIP_DOMAIN}', f'{app_name}/pipelines/{pipeline_name}')
    return PipelineLinks(self=Link(href=base),
                         trials=Link(href=base + '/trials'),
                         version=Link(href=base + '/version'))
@app.get('/{app_name}/pipelines',
         response_model=PipelinesResponse,
         dependencies=[Depends(get_user_id)])
async def get_pipelines_op(response: Response,
                           app_name: str,
                           cursor: str | None = None,
                           limit: int | None = None):
    """List the pipelines of ``app_name`` one page at a time.

    ``cursor`` and the effective page size are passed to ``list_pipeline_resources``.
    The page size is ``limit`` if given and non-zero, otherwise 10; a negative
    value also falls back to 10. When ``list_pipeline_resources`` reports the
    cursor as expired, the status is set to 410 and the body is still returned.
    """
    limit_set = limit or 10
    if limit_set < 0:
        # Same fallback as the trial list endpoints: never forward a negative limit.
        limit_set = 10
    # NOTE(review): unlike the trial list endpoints there is no upper cap (30) here;
    # confirm whether unbounded page sizes are intended.
    pipelines, new_cursor, remaining_item_count, expired = await list_pipeline_resources(api_client=_k8s_api_client,
                                                                                         app_name=app_name,
                                                                                         cursor=cursor,
                                                                                         limit=limit_set)
    uri = f'{app_name}/pipelines'
    self_link = _get_list_request_self_link(cursor, limit, uri)
    cursor_link = _get_list_request_next_link(new_cursor, limit_set, uri)
    links = PipelinesLinks(self=self_link, next=cursor_link)
    embedded_pipelines = []
    for p in pipelines:
        pipeline_name = p['metadata']['name']
        pl = _get_pipeline_links(app_name, pipeline_name)
        embedded_pipelines.append(EmbeddedPipeline(name=pipeline_name, _links=pl))
    pe = PipelinesEmbedded(pipelines=embedded_pipelines)
    resp = PipelinesResponse(_links=links, remaining_item_count=remaining_item_count, _embedded=pe)
    if expired:
        response.status_code = 410
    return resp
@app.get('/{app_name}/pipelines/{pipeline_name}',
         response_model=PipelineResponse,
         dependencies=[Depends(get_user_id)])
async def get_pipeline_op(app_name: str,
                          pipeline_name: str):
    """Return a pipeline's definition (the resource's ``spec``) together with its links."""
    resource = await get_pipeline_resource(api_client=_k8s_api_client,
                                           app_name=app_name,
                                           pipeline_name=pipeline_name)
    # Name the response after the resource's metadata rather than the path parameter.
    name = resource['metadata']['name']
    return PipelineResponse(name=name,
                            _links=_get_pipeline_links(app_name, name),
                            definition=resource['spec'])