400 lines
16 KiB
Python
400 lines
16 KiB
Python
# ============================================================
|
||
# Система: Единая библиотека, Центр ИИ НИУ ВШЭ
|
||
# Модуль: ExperimentPipeline
|
||
# Авторы: Полежаев В.А., Хританков А.С.
|
||
# Дата создания: 2024 г.
|
||
# ============================================================
|
||
import logging
|
||
import urllib.parse
|
||
from contextlib import asynccontextmanager
|
||
|
||
from fastapi import FastAPI, Depends, HTTPException
|
||
from httpx import AsyncClient
|
||
from jsonschema.validators import Draft202012Validator
|
||
from kubernetes_asyncio.client import ApiClient
|
||
from sqlalchemy.ext.asyncio import AsyncSession
|
||
from starlette.responses import Response
|
||
|
||
from auth import get_user_id, allowed_to_platform_user
|
||
from config import UNIP_DOMAIN
|
||
from exp_pipeline.logs import configure_logging
|
||
from exp_pipeline.pipeline import start_trial, get_trial, get_pipeline_resource, \
|
||
list_pipeline_resources, list_trials, add_condition_to_status, get_trial_user_id, OpenAPISpecCache, \
|
||
get_api_resource, set_next_started
|
||
from exp_pipeline.schema import CreateTrialRequest, PipelineTrial, \
|
||
PipelineVersionResponse, PipelinesResponse, PipelinesLinks, Link, PipelineLinks, EmbeddedPipeline, \
|
||
PipelinesEmbedded, PipelineResponse, TrialsResponse, TrialsLinks, TrialsEmbedded, CreateStatusConditionRequest
|
||
from exp_pipeline.storage.db import init_db, get_db
|
||
from kube_config import async_unip_load_kube_config
|
||
|
||
# Shared clients used by the endpoints below. Both are None at import time; they are
# assigned in `_lifespan` on startup and closed there on shutdown.
_http_client: AsyncClient | None = None

_k8s_api_client: ApiClient | None = None

# Logging is configured at import time, before the module logger is obtained.
configure_logging()

logger = logging.getLogger(__name__)
|
||
|
||
|
||
@asynccontextmanager
async def _lifespan(app: FastAPI):
    """Application lifespan: create shared resources on startup, release them on shutdown.

    On startup the Kubernetes configuration is loaded, the database is initialised and the
    module-level HTTP and Kubernetes API clients are created. On shutdown both clients are
    closed, even if the application exits because of an error.

    :param app: the FastAPI application (unused; required by the lifespan protocol).
    """
    global _http_client, _k8s_api_client

    await async_unip_load_kube_config(logger)

    # FastAPI does not run ``@app.on_event("startup")`` handlers when the application is
    # created with ``lifespan=...``. Database initialisation therefore has to happen here.
    await init_db()

    _http_client = AsyncClient(timeout=30)
    _k8s_api_client = ApiClient()
    try:
        yield
    finally:
        # Close the Kubernetes client even if closing the HTTP client fails.
        try:
            await _http_client.aclose()
        finally:
            await _k8s_api_client.close()
|
||
|
||
|
||
# ASGI application. Shared clients are created and closed by the `_lifespan` context.
app = FastAPI(lifespan=_lifespan)
|
||
|
||
|
||
# альтернативный вариант, per-request:
# async def _get_http_client():
#     async with AsyncClient() as http_client:
#         yield http_client
#
# def path_op(http_client: AsyncClient = Depends(_get_http_client)):
#     ...
|
||
|
||
|
||
# NOTE(review): FastAPI does not call `on_event("startup")` handlers when the application
# is created with a `lifespan` argument, as `app = FastAPI(lifespan=_lifespan)` does.
# This handler therefore presumably never runs; database initialisation belongs in the
# lifespan context. `on_event` is also deprecated in favour of lifespan handlers.
@app.on_event("startup")
async def init_db_on_start():
    """Initialise the database via `init_db()` on application startup."""
    await init_db()
|
||
|
||
|
||
# Cache of pipeline OpenAPI specifications, used by the trial-start and version endpoints.
# NOTE(review): `get_spec_cache()` presumably returns a shared (singleton) instance;
# confirm in exp_pipeline.pipeline.
openapi_spec_cache = OpenAPISpecCache.get_spec_cache()
|
||
|
||
|
||
def _validate_create_trial_request(create_trial_request: CreateTrialRequest, openapi_spec: dict):
    """Validate a trial creation request against the pipeline's OpenAPI specification.

    The request is dumped without ``None`` fields and validated with a JSON Schema
    draft 2020-12 validator against ``components.schemas.CreateTrialRequest`` of
    ``openapi_spec``.

    :param create_trial_request: request received from the client.
    :param openapi_spec: OpenAPI specification of the pipeline API.
    :raises HTTPException: status 400 whose ``detail`` is a list of
        ``{'path': ..., 'message': ..., 'sub': [...]}`` items, one per validation error.
        ``sub`` holds the errors from ``error.context`` (e.g. the failed alternatives of
        ``anyOf``/``oneOf``).
    :raises KeyError: if the specification contains no ``CreateTrialRequest`` schema.
    """
    request_schema = openapi_spec['components']['schemas']['CreateTrialRequest']
    validator = Draft202012Validator(request_schema)
    request_obj = create_trial_request.model_dump(exclude_none=True)

    # Collect the errors once. Calling is_valid() first would validate the instance
    # a second time.
    errors = sorted(validator.iter_errors(request_obj), key=lambda e: e.path)
    if not errors:
        return

    errors_strs = []
    errors_model = []
    for error in errors:
        path, message = list(error.schema_path), error.message
        errors_strs.append(f'{path}, {message}')

        nested_errors = []
        for sub_error in sorted(error.context, key=lambda e: e.schema_path):
            sub_path, sub_message = list(sub_error.schema_path), sub_error.message
            errors_strs.append(f'{sub_path}, {sub_message}')
            nested_errors.append({'path': sub_path, 'message': sub_message})

        errors_model.append({'path': path, 'message': message, 'sub': nested_errors})

    logger.info('Validation error: %s', '\n'.join(errors_strs))
    raise HTTPException(status_code=400, detail=errors_model)
|
||
|
||
|
||
async def start_trial_op(app_name: str,
                         pipeline_name: str,
                         create_trial_request: CreateTrialRequest,
                         db: AsyncSession,
                         user_id: str):
    """Validate a trial request against the pipeline's API spec and start the trial.

    Loads the pipeline and API resources through the Kubernetes API client and fetches
    the pipeline's OpenAPI specification from the cache. It then validates
    ``create_trial_request`` against that specification and starts the trial for
    ``user_id``.

    :return: the trial returned by ``start_trial``.
    :raises HTTPException: status 400 if the request does not match the specification.
    """
    k8s = _k8s_api_client
    resource_kwargs = {'api_client': k8s, 'app_name': app_name, 'pipeline_name': pipeline_name}

    pipeline_resource = await get_pipeline_resource(**resource_kwargs)
    api_resource = await get_api_resource(**resource_kwargs)

    spec = await openapi_spec_cache.get_pipeline_api_spec(k8s_api_client=k8s,
                                                          app_name=app_name,
                                                          api_resource=api_resource)
    _validate_create_trial_request(create_trial_request, spec)

    return await start_trial(db=db,
                             http_client=_http_client,
                             api_client=k8s,
                             app_name=app_name,
                             user_id=user_id,
                             pipeline_name=pipeline_name,
                             api_resource=api_resource,
                             pipeline_resource=pipeline_resource,
                             trial_inputs=create_trial_request.inputs,
                             trial_output_vars=create_trial_request.output_vars)
|
||
|
||
|
||
@app.post("/{app_name}/pipelines/{pipeline_name}/trials", response_model=PipelineTrial)
async def post_trial_op(app_name: str,
                        pipeline_name: str,
                        create_trial_request: CreateTrialRequest,
                        db: AsyncSession = Depends(get_db),
                        user_id: str = Depends(get_user_id)):
    """Start a new trial of ``pipeline_name`` for the user resolved by ``get_user_id``."""
    return await start_trial_op(app_name=app_name,
                                pipeline_name=pipeline_name,
                                create_trial_request=create_trial_request,
                                db=db,
                                user_id=user_id)
|
||
|
||
|
||
@app.post("/{app_name}/pipelines/{pipeline_name}/trials/continue",
          response_model=PipelineTrial,
          dependencies=[Depends(allowed_to_platform_user)])
async def post_next_trial_op(app_name: str,
                             pipeline_name: str,
                             next_pipeline_name: str,
                             tracking_id: str,
                             create_trial_request: CreateTrialRequest,
                             db: AsyncSession = Depends(get_db)):
    """Start a follow-up trial of ``next_pipeline_name`` after the trial ``tracking_id``.

    Guarded by ``allowed_to_platform_user``. The new trial is started for the user who
    owns trial ``tracking_id``; that user is looked up without filtering by pipeline.
    After the new trial starts, its tracking id is recorded on the original trial of
    ``pipeline_name`` via ``set_next_started``. ``next_pipeline_name`` and
    ``tracking_id`` are query parameters.
    """
    owner_id = await get_trial_user_id(db=db,
                                       app_name=app_name,
                                       pipeline_name=None,
                                       tracking_id=tracking_id)

    started = await start_trial_op(app_name=app_name,
                                   pipeline_name=next_pipeline_name,
                                   create_trial_request=create_trial_request,
                                   db=db,
                                   user_id=owner_id)

    await set_next_started(db=db,
                           tracking_id=tracking_id,
                           app_name=app_name,
                           pipeline_name=pipeline_name,
                           next_tracking_id=started.tracking_id)
    return started
|
||
|
||
|
||
@app.get("/{app_name}/pipelines/{pipeline_name}/check",
         dependencies=[Depends(allowed_to_platform_user)])
async def check_op(app_name: str,
                   pipeline_name: str):
    """Access check: return an empty 200 response once ``allowed_to_platform_user`` passes.

    The path parameters are accepted but not used by the handler itself.
    """
    empty_ok = Response()
    return empty_ok
|
||
|
||
|
||
@app.get("/{app_name}/trials/{tracking_id}",
         response_model=PipelineTrial)
async def get_trial_op(app_name: str,
                       tracking_id: str,
                       db: AsyncSession = Depends(get_db),
                       user_id: str = Depends(get_user_id)):
    """Return trial ``tracking_id`` of ``app_name`` for the user resolved by ``get_user_id``.

    The lookup is not restricted to a particular pipeline.
    """
    lookup = dict(db=db,
                  app_name=app_name,
                  pipeline_name=None,
                  tracking_id=tracking_id,
                  user_id=user_id)
    return await get_trial(**lookup)
|
||
|
||
|
||
@app.post("/{app_name}/pipelines/{pipeline_name}/trials/{tracking_id}/status/conditions",
          dependencies=[Depends(allowed_to_platform_user)])
async def post_trial_status_condition_op(response: Response,
                                         app_name: str,
                                         pipeline_name: str,
                                         tracking_id: str,
                                         create_condition_request: CreateStatusConditionRequest,
                                         db: AsyncSession = Depends(get_db)):
    """Add a status condition to trial ``tracking_id``.

    Guarded by ``allowed_to_platform_user``. Responds with 200 when
    ``add_condition_to_status`` reports a modification, otherwise with 204. The handler
    returns no content of its own.
    """
    condition = create_condition_request
    modified = await add_condition_to_status(db=db,
                                             tracking_id=tracking_id,
                                             type_=condition.type,
                                             message=condition.message,
                                             reason=condition.reason,
                                             transition_time=condition.transition_time,
                                             stage=condition.stage,
                                             app_name=app_name,
                                             pipeline_name=pipeline_name)
    if modified:
        response.status_code = 200
    else:
        response.status_code = 204
    return None
|
||
|
||
|
||
@app.get('/{app_name}/pipelines/{pipeline_name}/version',
         response_model=PipelineVersionResponse,
         dependencies=[Depends(get_user_id)])
async def get_pipeline_version_op(app_name: str,
                                  pipeline_name: str):
    """Return version information of a pipeline: its API OpenAPI spec (license is always None).

    Retrieving the specification is best effort. Any error is logged and the response
    then carries ``openapi_spec=None`` instead of failing the request.
    """
    api_resource = await get_api_resource(api_client=_k8s_api_client,
                                          app_name=app_name,
                                          pipeline_name=pipeline_name)

    try:
        spec = await openapi_spec_cache.get_pipeline_api_spec(k8s_api_client=_k8s_api_client,
                                                              app_name=app_name,
                                                              api_resource=api_resource)
    except Exception:
        # logger.exception attaches the active exception's traceback automatically.
        logger.exception('OpenAPI specification generation error, '
                         'specification not available')
        spec = None

    return PipelineVersionResponse(openapi_spec=spec, license=None)
|
||
|
||
|
||
def _get_list_request_self_link(cursor, limit, uri):
    """Build the ``self`` link of a paginated list request.

    :param cursor: cursor as sent by the client; omitted from the link when falsy.
    :param limit: page size as sent by the client (not the clamped value); omitted
        when falsy.
    :param uri: path of the list resource, relative to ``https://{UNIP_DOMAIN}``.
    :return: ``Link`` pointing at the list resource with the request's query parameters.
    """
    query_params = []
    if cursor:
        query_params.append(('cursor', cursor))
    if limit:
        query_params.append(('limit', limit))

    self_href = urllib.parse.urljoin(f'https://{UNIP_DOMAIN}', uri)
    if query_params:
        # Percent-encode the values: the cursor is echoed from the client's query string
        # and may contain characters such as '&', '#' or spaces that would otherwise
        # corrupt the link.
        self_href += '?' + urllib.parse.urlencode(query_params)

    return Link(href=self_href)
|
||
|
||
|
||
def _get_list_request_next_link(new_cursor, limit_set, uri):
    """Build the ``next`` link of a paginated list response.

    :param new_cursor: cursor of the next page; falsy when there is no next page.
    :param limit_set: effective page size used for the current request.
    :param uri: path of the list resource, relative to ``https://{UNIP_DOMAIN}``.
    :return: ``Link`` to the next page, or ``None`` when there is no next page.
    """
    if not new_cursor:
        return None

    # Cursors are opaque tokens produced by the listing functions. Percent-encode them
    # so any reserved characters cannot break the query string.
    query = urllib.parse.urlencode([('cursor', new_cursor), ('limit', limit_set)])
    cursor_href = urllib.parse.urljoin(f'https://{UNIP_DOMAIN}', f'{uri}?{query}')
    return Link(href=cursor_href)
|
||
|
||
|
||
@app.get('/{app_name}/pipelines/{pipeline_name}/trials',
         response_model=TrialsResponse,
         dependencies=[Depends(get_user_id)])
async def get_pipeline_trials_op(app_name: str,
                                 pipeline_name: str,
                                 db: AsyncSession = Depends(get_db),
                                 user_id: str = Depends(get_user_id),
                                 cursor: int | None = None,
                                 limit: int | None = None):
    """List trials of ``pipeline_name`` filtered by ``user_id``, one page at a time.

    Page size rules: a missing or zero ``limit`` means 10, negative values fall back
    to 10, and anything above 30 is capped at 30.
    """
    page_size = limit or 10
    if page_size < 0:
        page_size = 10
    page_size = min(page_size, 30)

    trials, new_cursor, (total_item_count, remaining_item_count) = await list_trials(
        db=db,
        app_name=app_name,
        pipeline_name=pipeline_name,
        user_id=user_id,
        cursor=cursor,
        limit=page_size)

    list_uri = f'{app_name}/pipelines/{pipeline_name}/trials'
    links = TrialsLinks(self=_get_list_request_self_link(cursor, limit, list_uri),
                        next=_get_list_request_next_link(new_cursor, page_size, list_uri))

    return TrialsResponse(_links=links,
                          total_item_count=total_item_count,
                          remaining_item_count=remaining_item_count,
                          _embedded=TrialsEmbedded(trials=trials))
|
||
|
||
|
||
@app.get('/{app_name}/trials',
         response_model=TrialsResponse,
         dependencies=[Depends(get_user_id)])
async def get_trials_op(app_name: str,
                        db: AsyncSession = Depends(get_db),
                        user_id: str = Depends(get_user_id),
                        cursor: int | None = None,
                        limit: int | None = None):
    """List trials of ``app_name`` across all pipelines, filtered by ``user_id``, page by page.

    Page size rules: a missing or zero ``limit`` means 10, negative values fall back
    to 10, and anything above 30 is capped at 30.
    """
    page_size = limit or 10
    if page_size < 0:
        page_size = 10
    page_size = min(page_size, 30)

    trials, new_cursor, (total_item_count, remaining_item_count) = await list_trials(
        db=db,
        app_name=app_name,
        user_id=user_id,
        cursor=cursor,
        limit=page_size)

    list_uri = f'{app_name}/trials'
    links = TrialsLinks(self=_get_list_request_self_link(cursor, limit, list_uri),
                        next=_get_list_request_next_link(new_cursor, page_size, list_uri))

    return TrialsResponse(_links=links,
                          total_item_count=total_item_count,
                          remaining_item_count=remaining_item_count,
                          _embedded=TrialsEmbedded(trials=trials))
|
||
|
||
|
||
def _get_pipeline_links(app_name, pipeline_name):
    """Build the ``self``, ``trials`` and ``version`` links of a pipeline resource.

    All three links share the base ``https://{UNIP_DOMAIN}/{app_name}/pipelines/{pipeline_name}``.
    """
    base_href = urllib.parse.urljoin(f'https://{UNIP_DOMAIN}',
                                     f'{app_name}/pipelines/{pipeline_name}')
    return PipelineLinks(self=Link(href=base_href),
                         trials=Link(href=base_href + '/trials'),
                         version=Link(href=base_href + '/version'))
|
||
|
||
|
||
@app.get('/{app_name}/pipelines',
         response_model=PipelinesResponse,
         dependencies=[Depends(get_user_id)])
async def get_pipelines_op(response: Response,
                           app_name: str,
                           cursor: str | None = None,
                           limit: int | None = None):
    """List the pipelines of ``app_name`` one page at a time.

    :param cursor: continuation cursor taken from a previous page's ``next`` link.
    :param limit: requested page size. A missing or zero value means 10, negative values
        fall back to 10, and values above 30 are capped at 30. These are the same rules
        as the trial list endpoints.
    :return: ``PipelinesResponse`` with ``self``/``next`` links, the remaining item count
        and the page of pipelines. The status code is set to 410 when
        ``list_pipeline_resources`` returns ``expired`` (presumably an expired
        continuation cursor).
    """
    limit_set = limit or 10
    # Bound the page size the way the trial list endpoints do. Without this, a negative
    # or very large limit would be forwarded as-is to the Kubernetes listing.
    if limit_set > 30:
        limit_set = 30
    if limit_set < 0:
        limit_set = 10

    pipelines, new_cursor, remaining_item_count, expired = await list_pipeline_resources(
        api_client=_k8s_api_client,
        app_name=app_name,
        cursor=cursor,
        limit=limit_set)

    uri = f'{app_name}/pipelines'
    self_link = _get_list_request_self_link(cursor, limit, uri)
    cursor_link = _get_list_request_next_link(new_cursor, limit_set, uri)
    links = PipelinesLinks(self=self_link, next=cursor_link)

    embedded_pipelines = []
    for p in pipelines:
        pipeline_name = p['metadata']['name']
        pl = _get_pipeline_links(app_name, pipeline_name)
        embedded_pipelines.append(EmbeddedPipeline(name=pipeline_name, _links=pl))

    pe = PipelinesEmbedded(pipelines=embedded_pipelines)
    resp = PipelinesResponse(_links=links, remaining_item_count=remaining_item_count, _embedded=pe)

    if expired:
        response.status_code = 410

    return resp
|
||
|
||
|
||
@app.get('/{app_name}/pipelines/{pipeline_name}',
         response_model=PipelineResponse,
         dependencies=[Depends(get_user_id)])
async def get_pipeline_op(app_name: str,
                          pipeline_name: str):
    """Return a single pipeline: its name, links and definition (the resource's ``spec``).

    The name in the response comes from the resource's ``metadata.name``, not from the
    path parameter.
    """
    resource = await get_pipeline_resource(api_client=_k8s_api_client,
                                           app_name=app_name,
                                           pipeline_name=pipeline_name)
    resource_name = resource['metadata']['name']

    return PipelineResponse(name=resource_name,
                            _links=_get_pipeline_links(app_name, resource_name),
                            definition=resource['spec'])
|