Source code for domino_data.training_sets.client

"""Domino TrainingSet client library."""


from typing import List, Mapping, Optional

import json
import os
import re
import shutil
from stat import S_IRGRP, S_IROTH, S_IRUSR, S_IWUSR, S_IXGRP, S_IXOTH, S_IXUSR

import pandas as pd

from training_set_api_client.api.default import (
    delete_training_set_name,
    delete_training_set_name_number,
    get_training_set_name,
    get_training_set_name_number,
    post_find,
    post_training_set_name,
    post_version_find,
    put_training_set_name,
    put_training_set_name_number,
)
from training_set_api_client.models import (
    CreateTrainingSetVersionRequest,
    CreateTrainingSetVersionRequestMeta,
    MonitoringMeta,
    TrainingSet,
    TrainingSetFilter,
    TrainingSetFilterMeta,
    TrainingSetVersion,
    TrainingSetVersionFilter,
    TrainingSetVersionFilterMeta,
    TrainingSetVersionFilterTrainingSetMeta,
    UpdateTrainingSetRequest,
    UpdateTrainingSetRequestMeta,
    UpdateTrainingSetVersionRequest,
    UpdateTrainingSetVersionRequestMeta,
)
from training_set_api_client.types import Response

from ..auth import AuthenticatedClient
from ..training_sets import model

# TrainingSet names may contain only ASCII letters, digits, dashes, and underscores.
_trainingset_name_pat = re.compile("^[-A-Za-z0-9_]+$")


class ServerException(Exception):
    """This exception is raised when the TrainingSet server rejects a request."""

    def __init__(self, message: str, server_msg: str):
        self.message = message
        self.server_msg = server_msg


class SchemaMismatchException(Exception):
    """This exception is raised when the TrainingSet data columns do not match the metadata."""


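# Example (illustrative only; the TrainingSet name is hypothetical): callers can
# catch ServerException to inspect the server-side error detail, and
# SchemaMismatchException when column roles do not match the DataFrame.
#
#     try:
#         ts = get_training_set("customer-churn")
#     except ServerException as e:
#         print(e.message, e.server_msg)

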
def get_training_set(name: str) -> model.TrainingSet:
    """Get a TrainingSet by name.

    Args:
        name: Name of the training set.

    Returns:
        The TrainingSet, if found.
    """
    _validate_trainingset_name(name)
    response = get_training_set_name.sync_detailed(
        client=_get_client(),
        training_set_name=name,
    )

    if response.status_code != 200:
        _raise_response_exn(response, "could not get TrainingSet")

    return _to_TrainingSet(response.parsed)


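# Example (hypothetical name; assumes the standard Domino environment variables
# such as DOMINO_API_HOST or DOMINO_API_PROXY and DOMINO_PROJECT_ID are set):
#
#     ts = get_training_set("customer-churn")
#     print(ts.name, ts.description, ts.meta)

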
def list_training_sets(
    meta: Optional[Mapping[str, str]] = None,
    asc: bool = True,
    offset: int = 0,
    limit: int = 10000,
) -> List[model.TrainingSet]:
    """Query training sets.

    Args:
        meta: Metadata key-value pairs to match.
        asc: Sort order by creation time; True for ascending, False for descending.
        offset: Offset.
        limit: Limit.

    Returns:
        A list of matching TrainingSets.
    """
    if meta is None:
        meta = {}

    project_id = _get_project_id()

    response = post_find.sync_detailed(
        client=_get_client(),
        json_body=TrainingSetFilter(
            project_id=project_id,
            meta=TrainingSetFilterMeta.from_dict(meta),
        ),
        offset=offset,
        limit=limit,
        asc=asc,
    )

    if response.status_code != 200:
        _raise_response_exn(response, "could not list TrainingSets")

    return [_to_TrainingSet(ts) for ts in response.parsed]


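# Example (hypothetical metadata): list the newest TrainingSets in the current
# project whose metadata contains the given key-value pairs.
#
#     for ts in list_training_sets(meta={"team": "fraud"}, asc=False, limit=10):
#         print(ts.name)

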
def update_training_set(
    updated: model.TrainingSet,
) -> model.TrainingSet:
    """Update a TrainingSet.

    Args:
        updated: Updated TrainingSet.

    Returns:
        The updated TrainingSet from the server.
    """
    _validate_trainingset_name(updated.name)
    response = put_training_set_name.sync_detailed(
        training_set_name=updated.name,
        client=_get_client(),
        json_body=UpdateTrainingSetRequest(
            meta=UpdateTrainingSetRequestMeta.from_dict(updated.meta),
            description=updated.description,
        ),
    )

    if response.status_code != 200:
        _raise_response_exn(response, "could not update TrainingSet")

    return _to_TrainingSet(response.parsed)


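# Example (hypothetical values): fetch, modify, and push back a TrainingSet.
# Only the description and metadata are sent in the update request.
#
#     ts = get_training_set("customer-churn")
#     ts.description = "churn features, v2 pipeline"
#     ts.meta["owner"] = "data-science"
#     ts = update_training_set(ts)

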
def delete_training_set(name: str) -> bool:
    """Delete a TrainingSet.

    **Note:** This deletes the TrainingSet only if it has no versions.

    Args:
        name: Name of the TrainingSet.

    Returns:
        True if TrainingSet was deleted.
    """
    _validate_trainingset_name(name)
    response = delete_training_set_name.sync_detailed(
        training_set_name=name,
        client=_get_client(),
    )

    if response.status_code != 200:
        _raise_response_exn(response, "could not delete TrainingSet")

    return True


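# Example (hypothetical name): a TrainingSet can only be deleted after all of its
# versions have been deleted; otherwise the server rejects the request and a
# ServerException is raised.
#
#     delete_training_set("customer-churn")

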
def create_training_set_version(
    training_set_name: str,
    df: pd.DataFrame,
    description: Optional[str] = None,
    key_columns: Optional[List[str]] = None,
    target_columns: Optional[List[str]] = None,
    exclude_columns: Optional[List[str]] = None,
    monitoring_meta: Optional[model.MonitoringMeta] = None,
    meta: Optional[Mapping[str, str]] = None,
    **kwargs,
) -> model.TrainingSetVersion:
    """Create a TrainingSetVersion.

    Args:
        training_set_name: Name of the TrainingSet this version belongs to.
            ``training_set_name`` must be a string containing only alphanumeric
            characters in the basic Latin alphabet, dashes, and underscores:
            ``[-A-Za-z0-9_]``.
        df: A DataFrame holding the data.
        description: Description of this version.
        key_columns: Names of columns that represent IDs for retrieving features.
        target_columns: Target variables for prediction.
        exclude_columns: Columns to exclude when generating the training DataFrame.
        monitoring_meta: Monitoring specific metadata.
        meta: User defined metadata.
        **kwargs: Arbitrary keyword arguments.

    Returns:
        The created TrainingSetVersion.
    """
    if key_columns is None:
        key_columns = []
    if target_columns is None:
        target_columns = []
    if exclude_columns is None:
        exclude_columns = []
    if monitoring_meta is None:
        monitoring_meta = model.MonitoringMeta()
    if meta is None:
        meta = {}

    all_columns = list(df.columns)
    _validate_trainingset_name(training_set_name)
    _check_columns(
        all_columns,
        key_columns
        + target_columns
        + exclude_columns
        + monitoring_meta.timestamp_columns
        + monitoring_meta.categorical_columns
        + monitoring_meta.ordinal_columns,
    )

    project_id = _get_project_id()

    response = post_training_set_name.sync_detailed(
        client=_get_client(),
        training_set_name=training_set_name,
        json_body=CreateTrainingSetVersionRequest(
            project_id=project_id,
            key_columns=key_columns,
            target_columns=target_columns,
            exclude_columns=exclude_columns,
            all_columns=all_columns,
            monitoring_meta=MonitoringMeta(
                timestamp_columns=monitoring_meta.timestamp_columns,
                categorical_columns=monitoring_meta.categorical_columns,
                ordinal_columns=monitoring_meta.ordinal_columns,
            ),
            meta=CreateTrainingSetVersionRequestMeta.from_dict(meta),
            description=description,
        ),
    )

    if response.status_code != 200:
        _raise_response_exn(response, "could not create Training Set version")

    tsv = _to_TrainingSetVersion(response.parsed)

    # Write the DataFrame into the version's data directory, then make it read-only.
    os.makedirs(tsv.absolute_container_path)
    df.to_parquet(os.path.join(tsv.absolute_container_path, "data.parquet"))
    os.chmod(
        tsv.absolute_container_path,
        S_IRUSR | S_IRGRP | S_IROTH | S_IXUSR | S_IXGRP | S_IXOTH,
    )

    tsv.pending = False
    return update_training_set_version(tsv)


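# Example (hypothetical DataFrame and column names; assumes model.MonitoringMeta
# accepts these keyword arguments): create a new version from a pandas DataFrame.
# Column roles must reference columns present in df, or SchemaMismatchException
# is raised before any request is sent.
#
#     df = pd.DataFrame(
#         {
#             "customer_id": [1, 2],
#             "signup_ts": ["2023-01-01", "2023-02-01"],
#             "plan": ["basic", "pro"],
#             "churned": [0, 1],
#         }
#     )
#     tsv = create_training_set_version(
#         training_set_name="customer-churn",
#         df=df,
#         key_columns=["customer_id"],
#         target_columns=["churned"],
#         monitoring_meta=model.MonitoringMeta(
#             timestamp_columns=["signup_ts"],
#             categorical_columns=["plan"],
#         ),
#         description="initial version",
#     )

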
def get_training_set_version(training_set_name: str, number: int) -> model.TrainingSetVersion:
    """Gets a TrainingSetVersion by version number.

    Args:
        training_set_name: Name of the TrainingSet.
        number: Version number.

    Returns:
        The requested TrainingSetVersion.
    """
    _validate_trainingset_name(training_set_name)
    response = get_training_set_name_number.sync_detailed(
        client=_get_client(),
        training_set_name=training_set_name,
        number=number,
    )

    if response.status_code != 200:
        _raise_response_exn(response, "could not get TrainingSetVersion")

    return _to_TrainingSetVersion(response.parsed)


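# Example (hypothetical name and number): fetch a specific version. The raw data
# for a version is written by create_training_set_version as "data.parquet" under
# tsv.absolute_container_path, so it can be read back directly with pandas.
#
#     tsv = get_training_set_version("customer-churn", 1)
#     df = pd.read_parquet(os.path.join(tsv.absolute_container_path, "data.parquet"))

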
def update_training_set_version(version: model.TrainingSetVersion) -> model.TrainingSetVersion:
    """Updates this TrainingSetVersion.

    Args:
        version: TrainingSetVersion to update.

    Returns:
        The updated TrainingSetVersion from the server.
    """
    _validate_trainingset_name(version.training_set_name)
    response = put_training_set_name_number.sync_detailed(
        training_set_name=version.training_set_name,
        number=version.number,
        client=_get_client(),
        json_body=UpdateTrainingSetVersionRequest(
            key_columns=version.key_columns,
            target_columns=version.target_columns,
            exclude_columns=version.exclude_columns,
            monitoring_meta=MonitoringMeta(
                timestamp_columns=version.monitoring_meta.timestamp_columns,
                categorical_columns=version.monitoring_meta.categorical_columns,
                ordinal_columns=version.monitoring_meta.ordinal_columns,
            ),
            meta=UpdateTrainingSetVersionRequestMeta.from_dict(version.meta),
            pending=version.pending,
            description=version.description,
        ),
    )

    if response.status_code != 200:
        _raise_response_exn(response, "could not update TrainingSetVersion")

    return _to_TrainingSetVersion(response.parsed)


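# Example (hypothetical values): adjust an existing version's description and
# metadata, then push the change to the server.
#
#     tsv = get_training_set_version("customer-churn", 1)
#     tsv.description = "backfilled labels"
#     tsv.meta["source"] = "warehouse"
#     tsv = update_training_set_version(tsv)

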
def delete_training_set_version(training_set_name: str, number: int) -> bool:
    """Deletes a TrainingSetVersion.

    Args:
        training_set_name: Name of the TrainingSet.
        number: TrainingSetVersion number.

    Returns:
        True if TrainingSetVersion was deleted.
    """
    _validate_trainingset_name(training_set_name)
    tsv = get_training_set_version(training_set_name, number)

    response = delete_training_set_name_number.sync_detailed(
        training_set_name=training_set_name,
        number=number,
        client=_get_client(),
    )

    if response.status_code != 200:
        _raise_response_exn(response, "could not delete TrainingSetVersion")

    # Restore write permission on the version's data directory before removing it.
    stat = os.stat(tsv.absolute_container_path)
    os.chmod(tsv.absolute_container_path, stat.st_mode | S_IWUSR)
    shutil.rmtree(tsv.absolute_container_path)

    return True


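# Example (hypothetical name and number): deleting a version removes both the
# server-side record and the on-disk data directory for that version.
#
#     delete_training_set_version("customer-churn", 1)

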
def list_training_set_versions(
    meta: Optional[Mapping[str, str]] = None,
    training_set_name: Optional[str] = None,
    training_set_meta: Optional[Mapping[str, str]] = None,
    asc: bool = True,
    offset: int = 0,
    limit: int = 10000,
) -> List[model.TrainingSetVersion]:
    """List training set versions.

    Args:
        meta: Version metadata key-value pairs to match.
        training_set_name: Name of the TrainingSet to filter on.
        training_set_meta: TrainingSet metadata key-value pairs to match.
        asc: Sort order by creation time; True for ascending, False for descending.
        offset: Offset.
        limit: Limit.

    Returns:
        A list of matching TrainingSetVersions.
    """
    if meta is None:
        meta = {}
    if training_set_meta is None:
        training_set_meta = {}

    project_id = _get_project_id()

    response = post_version_find.sync_detailed(
        client=_get_client(),
        json_body=TrainingSetVersionFilter(
            training_set_meta=TrainingSetVersionFilterTrainingSetMeta.from_dict(
                training_set_meta,
            ),
            meta=TrainingSetVersionFilterMeta.from_dict(meta),
            project_id=project_id,
            training_set_name=training_set_name,
        ),
        offset=offset,
        limit=limit,
        asc=asc,
    )

    if response.status_code != 200:
        _raise_response_exn(response, "could not find TrainingSetVersion")

    return [_to_TrainingSetVersion(tsv) for tsv in response.parsed]


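# Example (hypothetical name): list all versions of one TrainingSet in the current
# project, oldest first.
#
#     for tsv in list_training_set_versions(training_set_name="customer-churn"):
#         print(tsv.number, tsv.pending, tsv.description)

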
def _get_client() -> AuthenticatedClient:
    domino_host = os.getenv("DOMINO_API_HOST", os.getenv("DOMINO_USER_HOST"))
    api_key = os.getenv("DOMINO_USER_API_KEY")
    token_file = os.getenv("DOMINO_TOKEN_FILE")
    token_url = os.getenv("DOMINO_API_PROXY")

    return AuthenticatedClient(
        base_url=f"{domino_host}/trainingset",
        api_key=api_key,
        token_file=token_file,
        token_url=token_url,
    )


def _to_TrainingSet(ts: TrainingSet) -> model.TrainingSet:
    return model.TrainingSet(
        name=ts.name,
        description=ts.description,
        meta=ts.meta.to_dict(),
        project_id=ts.project_id,
    )


def _to_TrainingSetVersion(tsv: TrainingSetVersion) -> model.TrainingSetVersion:
    return model.TrainingSetVersion(
        training_set_name=tsv.training_set_name,
        number=tsv.number,
        description=tsv.description,
        key_columns=tsv.key_columns,
        target_columns=tsv.target_columns,
        exclude_columns=tsv.exclude_columns,
        all_columns=tsv.all_columns,
        monitoring_meta=model.MonitoringMeta(
            timestamp_columns=tsv.monitoring_meta.timestamp_columns,
            categorical_columns=tsv.monitoring_meta.categorical_columns,
            ordinal_columns=tsv.monitoring_meta.ordinal_columns,
        ),
        meta=tsv.meta.to_dict(),
        path=tsv.path,
        container_path=tsv.container_path,
        pending=tsv.pending,
    )


def _raise_response_exn(response: Response, msg: str):
    # Surface the server's "errors" field when the response body is JSON.
    try:
        response_json = json.loads(response.content.decode("utf8"))
        server_msg = response_json.get("errors")
    except Exception:
        server_msg = None

    raise ServerException(msg, server_msg)


def _check_columns(all_columns: List[str], expected_columns: List[str]):
    diff = set(expected_columns) - set(all_columns)
    if diff:
        raise SchemaMismatchException(f"DataFrame missing columns: {diff}")


def _get_project_id() -> Optional[str]:
    return os.getenv("DOMINO_PROJECT_ID")


def _validate_trainingset_name(name: str):
    if _trainingset_name_pat.match(name) is None:
        raise ValueError(f"bad TrainingSet name '{name}'")