# unip-controller/controller/src/exp_pipeline/var.py
# Snapshot: 2025-04-15 20:56:15 +03:00 — 508 lines, 24 KiB, Python.
# NOTE: the original web-viewer snapshot warned about ambiguous Unicode
# characters (Cyrillic comments); retained intentionally.
# ============================================================
# System: Unified Library, AI Centre, HSE University
# Module: ExperimentPipeline
# Authors: V.A. Polezhaev, A.S. Khritankov
# Created: 2024
# ============================================================
import json
import logging
import posixpath
import urllib.parse
import uuid
from functools import reduce
from operator import mul
from fastapi import HTTPException
from httpx import AsyncClient
from exp_pipeline.box import get_path_inside_box, get_default_box
from exp_pipeline.config import UNIP_FILES_API
from exp_pipeline.schema import ConnectedBox, TrialInput, TrialOutputVar, TrialStartVar, StartVarBoxPath, \
DATATYPE_TO_CONTENT_TYPE
from location import parse_location
UNIP_FILES_USER_PASSWORD = '**'
logger = logging.getLogger(__name__)
def _reshape(lst: list, shape: list[int]):
if len(shape) == 1:
return lst
n = reduce(mul, shape[1:])
return [_reshape(lst[i * n:(i + 1) * n], shape[1:]) for i in range(len(lst) // n)]
async def _save_data_to_file(http_client: AsyncClient,
                             app_name: str,
                             user_id: str,
                             s3_box_name: str,
                             input_data: object,
                             shape: list[int],
                             content_type: str) -> str:
    """Serialize input_data to JSON and upload it as a new file in the box.

    The Files API allocates a file name and a presigned PUT URL; the payload
    is then uploaded to that URL. Returns the file path relative to the file
    group prefix.

    Raises HTTPException(503) when either the Files API or the S3 upload fails.
    """
    data = input_data
    if isinstance(data, list):
        data = _reshape(data, shape)
    # In the current implementation everything is stored as JSON text, so the
    # caller-supplied content_type is intentionally overridden.
    str_data = json.dumps(data)
    content_type = 'application/json'
    # Create a file in the given file group; if the group does not exist yet,
    # the Files API creates it implicitly.
    url_prefix = f'{app_name}/files/{s3_box_name}/'
    create_file_url = urllib.parse.urljoin(UNIP_FILES_API, url_prefix)
    res = await http_client.post(create_file_url, auth=(user_id, UNIP_FILES_USER_PASSWORD))
    if res.status_code != 201:
        # Fixed: this is a POST request; the log previously said "PUT".
        logger.error(f'Files API request error (create file); POST {create_file_url}; '
                     f'Returned code: {res.status_code}')
        logger.error(f'Response content {res.content}')
        raise HTTPException(status_code=503, detail="Service unavailable")
    res_data = res.json()
    put_url = res_data['presigned_put_url']
    new_file_name: str = res_data['name']
    # The API returns the full object name; strip the group prefix to get the
    # path relative to the file group.
    if new_file_name.startswith(url_prefix):
        new_file_name = new_file_name[len(url_prefix):]
    path = new_file_name
    binary_data = str_data.encode('utf-8')
    # Fixed: Content-Length must be the byte length of the encoded payload,
    # not the character count of the string (they differ for non-ASCII text).
    headers = {'Content-Type': content_type, 'Content-Length': str(len(binary_data))}
    res = await http_client.put(put_url, content=binary_data, headers=headers,
                                auth=(user_id, UNIP_FILES_USER_PASSWORD))
    if res.status_code not in (200, 201):
        logger.error(f'S3 request error (put file content); PUT {put_url}; '
                     f'Returned code: {res.status_code}')
        logger.error(f'Response content {res.content}')
        raise HTTPException(status_code=503, detail="Service unavailable")
    return path
def _get_var_box(user_id: str,
pipeline_name: str,
connected_boxes: list[ConnectedBox],
box_section: dict | None
) -> ConnectedBox:
connected_boxes_by_name = {cb.name: cb for cb in connected_boxes}
if box_section and 'name' in box_section:
name = box_section['name']
if name not in connected_boxes_by_name:
logger.error(f'Connected box {name} not specified in pipeline {pipeline_name}')
raise HTTPException(status_code=503, detail="Service unavailable")
return connected_boxes_by_name[name]
return get_default_box(user_id, pipeline_name, connected_boxes)
async def _get_input_var_location(http_client: AsyncClient,
                                  app_name: str,
                                  user_id: str,
                                  var_box: ConnectedBox,
                                  passed_var_input: TrialInput) -> tuple[str, str]:
    """Resolve the Files API location and content type for one input var.

    For FILE/WEBSITE inputs the passed data is itself a location string;
    for every other datatype the data is serialized to a temporary JSON file
    inside the var's box.

    Returns (path, content_type); a path ending with '/' denotes a whole
    file group, otherwise a single file.
    Raises HTTPException(400) on invalid input data.
    """
    if passed_var_input.datatype in {"FILE", "WEBSITE"}:
        if not isinstance(passed_var_input.data, str):
            raise HTTPException(status_code=400, detail=f"Passed input var {passed_var_input.name} "
                                                        f"datatype is {passed_var_input.datatype},"
                                                        f" but data is not instance of 'str'")
        if passed_var_input.datatype == 'FILE' and not passed_var_input.content_type:
            raise HTTPException(status_code=400, detail=f"Passed input var {passed_var_input.name} "
                                                        f"datatype is FILE, content_type is required")
        if passed_var_input.datatype == 'WEBSITE':
            # WEBSITE inputs always get the fixed content type from the mapping.
            passed_var_input.content_type = DATATYPE_TO_CONTENT_TYPE['WEBSITE']
        file_group_name, file_name = parse_location(passed_var_input.data)
        if not file_group_name:
            raise HTTPException(status_code=400,
                                detail=f"At least file group must be passed for input of var {passed_var_input.name}, "
                                       f"given: {passed_var_input.data}")
        if not file_name:
            # parse_location returns values that neither start nor end with '/';
            # when only a file group name is given, the path must end with '/' -
            # a directory gets mounted.
            path = file_group_name + '/'
        else:
            # When both a file group and a file name are given, join them -
            # a single file gets mounted.
            path = posixpath.join(file_group_name, file_name)
        return path, passed_var_input.content_type
    else:
        if not passed_var_input.data:
            raise HTTPException(status_code=400, detail=f"Passed input var {passed_var_input.name} "
                                                        f"datatype is {passed_var_input.datatype}, but data not passed")
        # A temporary file is created in a temporary file group and its path
        # is returned - a single file gets mounted.
        path = await _save_data_to_file(http_client=http_client,
                                        app_name=app_name,
                                        user_id=user_id,
                                        s3_box_name=var_box.s3_box_name,
                                        input_data=passed_var_input.data,
                                        shape=passed_var_input.shape,
                                        content_type=passed_var_input.content_type)
        return path, 'application/json'
def _get_input_var_datatype(passed_var_input: TrialInput):
if passed_var_input.datatype == 'WEBSITE':
return 'WEBSITE'
return 'FILE'
def _get_input_var_in_box_path(user_id: str,
                               input_location: str | None) -> str:
    """Map an input var's Files API location to its path inside the box."""
    return get_path_inside_box(user_id, input_location)
def _get_var_mount_path(var_section: dict):
return var_section.get('path', None)
async def _construct_input_var(http_client: AsyncClient,
                               app_name: str,
                               user_id: str,
                               pipeline_name: str,
                               connected_boxes: list[ConnectedBox],
                               var_section: dict,
                               var_input: TrialInput) -> TrialStartVar:
    """Build a TrialStartVar for one input var from its pipeline spec section
    and the value passed with the trial.

    Raises HTTPException(503) when the resolved box is a dataset box
    (input data cannot be written into a dataset).
    """
    var_name = var_section['name']
    # Optional explicit box selection: var_section['mountFrom']['box'].
    box_section = None
    if 'mountFrom' in var_section:
        mount_from_section = var_section['mountFrom']
        if 'box' in mount_from_section:
            box_section = mount_from_section['box']
    var_box = _get_var_box(user_id=user_id,
                           pipeline_name=pipeline_name,
                           connected_boxes=connected_boxes,
                           box_section=box_section)
    if var_box.dataset_ref_box_name:
        logger.info(f'Input var {var_name} connected box {var_box.name} cannot be dataset')
        raise HTTPException(status_code=503, detail="Service unavailable")
    input_location, content_type = await _get_input_var_location(http_client=http_client,
                                                                 app_name=app_name,
                                                                 user_id=user_id,
                                                                 var_box=var_box,
                                                                 passed_var_input=var_input)
    datatype = _get_input_var_datatype(passed_var_input=var_input)
    logger.info(f'Input var: {var_name}; input location: {input_location}')
    in_box_path = _get_input_var_in_box_path(user_id=user_id, input_location=input_location)
    mount_path = _get_var_mount_path(var_section=var_section)
    var_box_path = StartVarBoxPath(in_box_path=in_box_path, box=var_box, mount_path=mount_path)
    sv = TrialStartVar(name=var_name, box_path=var_box_path, files_location=input_location,
                       datatype=datatype, content_type=content_type, shape=var_input.shape)
    return sv
async def _create_file_group(http_client: AsyncClient,
                             app_name: str,
                             user_id: str,
                             box_name: str,
                             location: str):
    """Create a file group at `location` inside the box via the Files API.

    Raises HTTPException(503) when the Files API returns a non-200 status.
    """
    # NOTE(review): the prefix here starts with '/', so urljoin discards any
    # path component of UNIP_FILES_API; _save_data_to_file builds its URL
    # without the leading slash — confirm which form is intended.
    url_prefix = f'/{app_name}/files/{box_name}'
    url_prefix = posixpath.join(url_prefix, location)
    create_fg_url = urllib.parse.urljoin(UNIP_FILES_API, url_prefix)
    res = await http_client.put(create_fg_url, auth=(user_id, UNIP_FILES_USER_PASSWORD))
    if res.status_code != 200:
        logger.error(f'Files API request error (create file-group); PUT {create_fg_url}; '
                     f'Returned code: {res.status_code}')
        logger.error(f'Response content {res.content}')
        raise HTTPException(status_code=503, detail="Service unavailable")
async def _get_output_var_location(http_client: AsyncClient,
                                   app_name: str,
                                   user_id: str,
                                   var_box: ConnectedBox,
                                   passed_var: TrialOutputVar | None):
    """Resolve (and create) the file group used to collect an output var.

    Returns (path, content_type): path is the file group location ending
    with '/'; content_type is taken from the passed var, or None when no
    var was passed.
    Raises HTTPException(400) if the passed location names a file instead
    of a file group.
    """
    passed = passed_var.data if passed_var else None
    file_group_name, file_name = parse_location(passed)
    if file_name:
        # Output vars must always be whole file groups, never single files.
        # (file_name is only non-empty when passed_var is not None.)
        raise HTTPException(status_code=400,
                            detail=f'File name must NOT be specified for output var {passed_var.name}; '
                                   f'given location {passed}')
    # Fixed: the signature allows passed_var=None (and the first line guards
    # it), but the attribute accesses below were unguarded and raised
    # AttributeError on the None path.
    if passed_var and passed_var.datatype == 'WEBSITE':
        passed_var.content_type = DATATYPE_TO_CONTENT_TYPE['WEBSITE']
    if not file_group_name:
        # Empty location: generate a unique file group name, since the
        # Files API does not support name generation when creating empty
        # file groups.
        path = str(uuid.uuid4()) + '/'
    else:
        # A file group was passed; append '/', since parse_location always
        # returns values that neither start nor end with '/'.
        path = file_group_name + '/'
    await _create_file_group(http_client, app_name, user_id, var_box.s3_box_name, path)
    return path, passed_var.content_type if passed_var else None
def _get_output_var_in_box_path(
        user_id: str,
        output_location: str | None
) -> str:
    """Map an output var's Files API location to its path inside the box."""
    return get_path_inside_box(user_id, output_location)
def _get_output_var_datatype(passed_var: TrialOutputVar):
if not passed_var.datatype:
return 'FILE'
return passed_var.datatype
async def _construct_output_var(http_client: AsyncClient,
                                app_name: str,
                                user_id: str,
                                pipeline_name: str,
                                connected_boxes: list[ConnectedBox],
                                var_section: dict,
                                trial_output_var: TrialOutputVar) -> TrialStartVar:
    """Build a TrialStartVar for one output var from its pipeline spec
    section and the output description passed with the trial.

    Raises HTTPException(503) when the resolved box is a dataset box
    (outputs cannot be written into a dataset).
    """
    var_name = var_section['name']
    # Optional explicit box selection: var_section['mountFrom']['box'].
    box_section = None
    if 'mountFrom' in var_section:
        mount_from_section = var_section['mountFrom']
        if 'box' in mount_from_section:
            box_section = mount_from_section['box']
    var_box = _get_var_box(user_id=user_id,
                           pipeline_name=pipeline_name,
                           connected_boxes=connected_boxes,
                           box_section=box_section)
    if var_box.dataset_ref_box_name:
        logger.info(f'Output var {var_name} connected box {var_box.name} cannot be dataset')
        raise HTTPException(status_code=503, detail="Service unavailable")
    output_location, content_type = await _get_output_var_location(http_client=http_client,
                                                                   app_name=app_name,
                                                                   user_id=user_id,
                                                                   var_box=var_box,
                                                                   passed_var=trial_output_var)
    datatype = _get_output_var_datatype(passed_var=trial_output_var)
    logger.info(f'Output var: {var_name}; output location: {output_location}')
    in_box_path = _get_output_var_in_box_path(user_id=user_id, output_location=output_location)
    mount_path = _get_var_mount_path(var_section=var_section)
    var_box_path = StartVarBoxPath(in_box_path=in_box_path, box=var_box, mount_path=mount_path)
    sv = TrialStartVar(name=var_name, box_path=var_box_path, files_location=output_location,
                       datatype=datatype, content_type=content_type)
    return sv
async def _get_internal_var_location(http_client: AsyncClient,
                                     app_name: str,
                                     user_id: str,
                                     var_box: ConnectedBox,
                                     var_section: dict):
    """Resolve (and create, when needed) the storage location of an
    internal var.

    Returns None when no location is required (dataset box, or an explicit
    boxPath in the spec); otherwise creates an empty file group with a
    generated unique name and returns its path (ending with '/').
    """
    var_name = var_section['name']
    if var_box.dataset_ref_box_name:
        # Dataset boxes are mounted as-is; no file group is created.
        logger.info(f'For internal var {var_name} dataset is used')
        return None
    box_section = None
    if 'mountFrom' in var_section:
        mount_from_section = var_section['mountFrom']
        if 'box' in mount_from_section:
            box_section = mount_from_section['box']
    # When boxPath is specified, no location is created.
    if box_section and ('boxPath' in box_section):
        logger.info(f'For internal var {var_name} boxPath is used')
        return None
    # Otherwise an empty location must be created; a unique file group name
    # is generated, since the Files API does not support name generation
    # when creating empty file groups.
    path = str(uuid.uuid4()) + '/'
    await _create_file_group(http_client, app_name, user_id, var_box.s3_box_name, path)
    logger.info(f'Internal var: {var_name}; internal location: {path}')
    return path
def _get_internal_var_in_box_path(user_id: str,
var_box: ConnectedBox,
var_section: dict,
internal_location: str | None):
var_name = var_section['name']
box_section = None
if 'mountFrom' in var_section:
mount_from_section = var_section['mountFrom']
if 'box' in mount_from_section:
box_section = mount_from_section['box']
static_box_path = var_box.dataset_ref_box_name or (box_section and ('boxPath' in box_section))
if static_box_path and internal_location:
logger.error(f'Bug. Cannot use internal var {var_name} location'
f' if boxPath or mountDataset is defined in pipeline specification')
raise HTTPException(status_code=500,
detail=f'Internal server error')
if not static_box_path and not internal_location:
logger.error(f'Bug. Internal var location not passed and bot boxPath, mountDataset not specified '
f'for internal var {var_name}.')
raise HTTPException(status_code=500,
detail=f'Internal server error')
in_box_path = None
# либо boxPath и mountDataset не указаны в спецификации, и создается пустая файловая группа,
# internal_location содержит ее значение;
if internal_location:
in_box_path = get_path_inside_box(user_id, internal_location)
# либо boxPath или mountDataset указаны в спецификации;
if static_box_path:
if var_box.dataset_ref_box_name:
in_box_path = ""
elif box_section and ('boxPath' in box_section):
in_box_path = box_section['boxPath']
if in_box_path is not None and in_box_path.startswith('/'):
in_box_path = in_box_path.lstrip('/')
return in_box_path
def _get_internal_var_datatype():
return 'FILE'
async def _construct_internal_var(http_client: AsyncClient,
                                  app_name: str,
                                  user_id: str,
                                  pipeline_name: str,
                                  connected_boxes: list[ConnectedBox],
                                  var_section: dict) -> TrialStartVar:
    """Build a TrialStartVar for one internal var from its pipeline spec
    section (no trial value is passed for internal vars).
    """
    var_name = var_section['name']
    # Optional explicit box selection: var_section['mountFrom']['box'].
    box_section = None
    if 'mountFrom' in var_section:
        mount_from_section = var_section['mountFrom']
        if 'box' in mount_from_section:
            box_section = mount_from_section['box']
    var_box = _get_var_box(user_id=user_id,
                           pipeline_name=pipeline_name,
                           connected_boxes=connected_boxes,
                           box_section=box_section)
    # None when the var maps to a dataset box or an explicit boxPath.
    internal_location = await _get_internal_var_location(http_client=http_client,
                                                         app_name=app_name,
                                                         user_id=user_id,
                                                         var_box=var_box,
                                                         var_section=var_section)
    datatype = _get_internal_var_datatype()
    in_box_path = _get_internal_var_in_box_path(user_id=user_id,
                                                var_box=var_box,
                                                var_section=var_section,
                                                internal_location=internal_location)
    mount_path = _get_var_mount_path(var_section=var_section)
    var_box_path = StartVarBoxPath(in_box_path=in_box_path, box=var_box, mount_path=mount_path)
    sv = TrialStartVar(name=var_name, box_path=var_box_path, files_location=internal_location,
                       datatype=datatype)
    return sv
def _determine_start_vars_spec(pipeline_api_spec, pipeline_spec):
result = input_vars_specs, output_vars_specs, internal_vars_specs = {}, {}, {}
if 'apiSpec' not in pipeline_api_spec:
return result
pipeline_vars = pipeline_spec['vars']
api_spec = pipeline_api_spec['apiSpec']
inputs_names = {i['name'] for i in api_spec['inputs']} if 'inputs' in api_spec else set()
input_vars_specs.update({v['name']: v for v in pipeline_vars if v['name'] in inputs_names})
outputs_names = {o['name'] for o in api_spec['outputs']} if 'outputs' in api_spec else set()
output_vars_specs.update({v['name']: v for v in pipeline_vars if v['name'] in outputs_names})
internal_vars_specs.update({v['name']: v for v in pipeline_vars if v['name'] not in (inputs_names | outputs_names)})
return result
async def construct_start_vars(http_client: AsyncClient,
                               app_name: str,
                               user_id: str,
                               pipeline_name: str,
                               pipeline_api_spec: dict,
                               pipeline_spec: dict,
                               connected_boxes: list[ConnectedBox],
                               passed_trial_inputs: list[TrialInput],
                               passed_trial_output_vars: list[TrialOutputVar]) \
        -> tuple[list[TrialStartVar], list[TrialStartVar], list[TrialStartVar]]:
    """Build the (input, output, internal) TrialStartVar lists for a trial
    from the pipeline's API spec, its var specs and the passed values.
    """
    input_vars_specs, output_vars_specs, internal_vars_specs = _determine_start_vars_spec(pipeline_api_spec,
                                                                                          pipeline_spec)
    input_vars = []
    # The passed inputs have already been validated and are assumed correct,
    # so each passed input is matched against its spec in the pipeline resource;
    for i in passed_trial_inputs:
        # exception: when another pipeline follows this one and the passed
        # data belongs to that pipeline;
        if i.name not in input_vars_specs:
            continue
        input_var_spec = input_vars_specs[i.name]
        input_var = await _construct_input_var(http_client=http_client,
                                               app_name=app_name,
                                               user_id=user_id,
                                               pipeline_name=pipeline_name,
                                               connected_boxes=connected_boxes,
                                               var_section=input_var_spec,
                                               var_input=i)
        input_vars.append(input_var)
    output_vars = []
    # Same matching logic as for inputs.
    for o in passed_trial_output_vars:
        if o.name not in output_vars_specs:
            continue
        output_var_spec = output_vars_specs[o.name]
        output_var = await _construct_output_var(http_client=http_client,
                                                 app_name=app_name,
                                                 user_id=user_id,
                                                 pipeline_name=pipeline_name,
                                                 connected_boxes=connected_boxes,
                                                 var_section=output_var_spec,
                                                 trial_output_var=o)
        output_vars.append(output_var)
    # Internal vars, by contrast, are taken entirely from the pipeline spec
    # (no values are passed for them).
    internal_vars = [await _construct_internal_var(http_client=http_client,
                                                   app_name=app_name,
                                                   user_id=user_id,
                                                   pipeline_name=pipeline_name,
                                                   connected_boxes=connected_boxes,
                                                   var_section=v)
                     for v in internal_vars_specs.values()]
    return input_vars, output_vars, internal_vars