Source code for oda_api.api

__author__ = "Andrea Tramacere, Volodymyr Savchenko"
__all__ = ['Request', 'NoTraceBackWithLineNumber',
           'NoTraceBackWithLineNumber', 'RemoteException', 'DispatcherAPI']

import ast
import copy
import gzip
import inspect
import json
import logging
import os
import pathlib
import pickle
import random
import re
import string
import sys
import time
import traceback
import warnings
from collections import OrderedDict
from itertools import cycle
from json.decoder import JSONDecodeError
from typing import Tuple, Union, cast

import numpy as np
import rdflib
import requests
from jsonschema import validate as validate_json

import oda_api.misc_helpers
import oda_api.token
from oda_api.token import TokenLocation

from . import __version__
from . import colors as C
from . import custom_formatters
from .data_products import (ApiCatalog, BinaryData, BinaryProduct, DataProduct,
                            GWContoursDataProduct, NumpyDataProduct,
                            ODAAstropyTable, PictureProduct, TextLikeProduct)

# NOTE gw is optional for now
try:
    from gwpy.spectrogram import Spectrogram
    from gwpy.timeseries.timeseries import TimeSeries
except ModuleNotFoundError:
    pass

logger = logging.getLogger("oda_api.api")
advice_logger = logging.getLogger("oda_api.advice")

[docs] class Request(object): def __init__(self): pass
[docs] class NoTraceBackWithLineNumber(Exception): def __init__(self, msg): try: ln = sys.exc_info()[-1].tb_lineno # pyright: ignore[reportOptionalMemberAccess] except AttributeError: ln = inspect.currentframe().f_back.f_lineno # pyright: ignore[reportOptionalMemberAccess] self.args = "{0.__name__} (line {1}): {2}".format(type(self), ln, msg),
# sys.exit(self) class UserError(Exception): pass
[docs] class RemoteException(NoTraceBackWithLineNumber): def __init__(self, message='Remote analysis exception', debug_message=''): super(RemoteException, self).__init__(message) self.message = message self.debug_message = debug_message def __repr__(self): return f"RemoteException: {self.message}, {self.debug_message}"
class FailedToFindAnyUsefulResults(RemoteException): pass class UnexpectedDispatcherStatusCode(RemoteException): pass class DispatcherNotAvailable(RemoteException): pass class RequestNotUnderstood(Exception): def __init__(self, details_json) -> None: self.details_json = details_json def __repr__(self) -> str: return f"[ RequestNotUnderstood: {self.details_json['error']} ]" def __str__(self) -> str: return repr(self) class Unauthorized(RemoteException): pass class URLRedirected(Exception): pass class DispatcherException(Exception): def __init__(self, response_json) -> None: self.response_json = response_json def __repr__(self) -> str: return f"[ {self.__class__.__name__}: {self.response_json.get('error_message', '[no error message reported]')} ]" exception_by_message = { 'failed: get dataserver products ': FailedToFindAnyUsefulResults } def safe_run(func): def func_wrapper(*args, **kwargs): logger = logging.getLogger("oda_api.api." + func.__name__) self = args[0] # because it really is n_tries_left = self.n_max_tries retry_sleep_s = self.retry_sleep_s t0 = time.time() while True: try: return func(*args, **kwargs) except UserError as e: logger.exception("probably an unfortunate user input: %s", e) raise except (Unauthorized, RequestNotUnderstood, UnexpectedDispatcherStatusCode) as e: logger.exception("something went quite wrong, and we think it's not likely to recover on its own: %s", e) raise except (ConnectionError, requests.exceptions.ConnectionError, requests.exceptions.Timeout, DispatcherNotAvailable) as e: # TODO: these are probably all server or access errors, # TODO: and they may need to be communicated back to server (if possible) message = '' message += '\nunable to complete API call' message += '\nin ' + str(func) + ' called with:' message += '\n... ' + ", ".join([str(arg) for arg in args]) message += '\n... ' + \ ", ".join([k + ": " + str(v) for k, v in kwargs]) message += '\npossible causes:' message += '\n- connection error' message += '\n- error on the remote server' message += '\n exception message: ' message += '\n\n%s\n' % e message += traceback.format_exc() n_tries_left -= 1 if n_tries_left > 0: logger.debug("problem in API call, %i tries left:\n%s\n sleeping %i seconds until retry", n_tries_left, message, retry_sleep_s) logger.warning( "possibly temporary problem in calling server: %s in %.1f seconds, %i tries left, sleeping %i seconds until retry", repr(e), time.time() - t0, n_tries_left, retry_sleep_s) time.sleep(retry_sleep_s) else: raise RemoteException( message=message ) from e return func_wrapper
[docs] class DispatcherAPI: # allowing token discovery by default changes the user interface in some cases, # but in desirable way token_discovery_methods = None use_local_cache = False _known_sites_dict = None @property def known_sites_dict(self): if self._known_sites_dict is None: self._known_sites_dict = {} G = rdflib.Graph() G.parse("https://odahub.io/oda-sites.ttl") for site, url in G.subject_objects(rdflib.URIRef("http://odahub.io/ontology#APIURL")): self._known_sites_dict[site.split("#")[1]] = str(url) # pyright: ignore[reportAttributeAccessIssue] for alias in G.objects(site, rdflib.URIRef("http://www.w3.org/2000/01/rdf-schema#label")): self._known_sites_dict[str(alias)] = str(url) return self._known_sites_dict def setup_loggers(self): self.logger = logger.getChild(self.__class__.__name__.lower()) self.progress_logger = self.logger.getChild("progress") def __init__(self, instrument='mock', url=None, run_analysis_handle='run_analysis', host=None, port=None, cookies=None, protocol="https", wait=True, n_max_tries=200, session_id=None, use_local_cache=False, token=None ): self.setup_loggers() if url is None: if 'unige-production' not in self.known_sites_dict: url = "https://www.astro.unige.ch/mmoda/dispatch-data" else: url = self.known_sites_dict['unige-production'] else: if not url.startswith("http://") and not url.startswith("https://"): self.logger.info('url %s is not of http(s) schema, trying to interpretting url as an alias', url) if url in self.known_sites_dict: self.logger.info('url %s interpretted an alias for %s', url, self.known_sites_dict[url]) url = self.known_sites_dict[url] else: logger.debug(f'url {url} does not match http(s) schema and is not one of the aliases (%s) {list(self.known_sites_dict)}') if host is not None: msg = '\n' msg += '----------------------------------------------------------------------------\n' msg += 'support for the parameter host will end soon \n' msg += 'please use "url" instead of "host" while providing dispatcher URL \n' msg += '----------------------------------------------------------------------------\n' warnings.warn(msg, DeprecationWarning, stacklevel=1) self.url = host # TODO: disregard this, but leave parameter for compatibility if host.startswith('http'): self.url = host else: if protocol != 'http' and protocol != 'https': raise UserError('protocol must be either http or https') else: self.url = protocol + "://" + host else: if not oda_api.misc_helpers.validate_url(url): raise UserError(f'{url} is not a valid url. \n' 'A valid url should be like `https://www.astro.unige.ch/mmoda/dispatch-data`, ' 'you might verify if, for example, a valid schema is provided, ' 'i.e. url should start with http:// or https:// .\n' 'Please check it and try to issue again the request') self.url = url if session_id is not None: self._session_id = session_id self.token = token if token is not None else self.get_token_from_environment() self._carriage_return_progress = False self.run_analysis_handle = run_analysis_handle self.wait = wait self.strict_parameter_check = False self.cookies = cookies self.set_instr(instrument) self.n_max_tries = n_max_tries self.retry_sleep_s = 10. if port is not None: self.logger.warning( "please use 'url' to specify entire URL, no need to provide port separately") self._progress_iter = cycle(['|', '/', '-', '\\']) self.use_local_cache = use_local_cache # TODO this should really be just swagger/bravado; or at least derived from resources self.dispatcher_response_schema = { 'type': 'object', 'properties': { 'exit_status': { 'type': 'object', 'properties': { 'status': {'type': 'number'}, }, }, 'query_status': {'type': 'string'}, 'job_monitor': { 'type': 'object', 'properties': { 'job_id': {'type': 'string'}, }, }, } } def inspect_state(self, job_id=None, group_by_job=False): params = dict(token=oda_api.token.discover_token(), group_by_job=group_by_job) if job_id is not None: params['job_id'] = job_id r = requests.get(self.url + "/inspect-state", params=params) if r.status_code == 200: return r.json() else: raise RuntimeError(r.text) def refresh_token(self, token_to_refresh=None, write_token=False, token_write_methods: Union[Tuple[TokenLocation, ...], TokenLocation] = (TokenLocation.ODA_ENV_VAR, TokenLocation.FILE_CUR_DIR), discard_discovered_token=False): if token_to_refresh is None: token_to_refresh = oda_api.token.discover_token() if token_to_refresh is not None and token_to_refresh != '': params = dict(token=token_to_refresh, query_status='new') r = requests.get(os.path.join(self.url, 'refresh_token'), params=params) if r.status_code == 200: refreshed_token = r.text if write_token: oda_api.token.rewrite_token(refreshed_token, old_token=token_to_refresh, token_write_methods=token_write_methods, discard_discovered_token=discard_discovered_token) return refreshed_token else: raise RuntimeError(r.text) else: raise RuntimeError("unable to refresh the token with any known method") def disable_email_token(self, token_to_update=None, write_token=False, token_write_methods: Union[Tuple[TokenLocation, ...], TokenLocation] = (TokenLocation.ODA_ENV_VAR, TokenLocation.FILE_CUR_DIR), discard_discovered_token=False): if token_to_update is None: token_to_update = oda_api.token.discover_token() if token_to_update is not None and token_to_update != '': params = dict(token=token_to_update, msfail=False, mssub=False, msdone=False, query_status='new') r = requests.get(os.path.join(self.url, 'update_token_email_options'), params=params) if r.status_code == 200: refreshed_token = r.text if write_token: oda_api.token.rewrite_token(refreshed_token, old_token=token_to_update, token_write_methods=token_write_methods, discard_discovered_token=discard_discovered_token, force_rewrite=True) return refreshed_token else: raise RuntimeError(r.text) else: raise RuntimeError("unable to refresh the token with any known method") def set_custom_progress_formatter(self, F): self.custom_progress_formatter = F
[docs] @classmethod def build_from_envs(cls): cookies_path = os.environ.get('ODA_API_TOKEN') if cookies_path is None: raise RuntimeError("ODA_API_TOKEN environment variable is not set") cookies = dict(_oauth2_proxy=open(cookies_path).read().strip()) host_url = os.environ.get('DISP_URL') return cls(host=host_url, instrument='mock', cookies=cookies, protocol='http')
[docs] def generate_session_id(self, size=16): chars = string.ascii_uppercase + string.digits return ''.join(random.choice(chars) for _ in range(size))
@classmethod def calculate_param_dict_id(cls, par_dict: dict): ordered_par_dic = OrderedDict({ k: par_dict[k] for k in sorted(par_dict.keys()) if par_dict[k] is not None }) return oda_api.misc_helpers.make_hash(ordered_par_dic) @property def session_id(self): if not hasattr(self, '_session_id'): self._session_id = self.generate_session_id() return self._session_id
[docs] def set_instr(self, instrument): self.instrument = instrument self.custom_progress_formatter = custom_formatters.find_custom_formatter( instrument)
def _progress_bar(self, info=''): if self._carriage_return_progress: c_r = '\x1b[80D' + '\x1b[K' # TODO: this does not really work now else: c_r = '' self.progress_logger.info( f"{c_r}{C.GREY}\r {next(self._progress_iter)} the job is working remotely, please wait {info}{C.NC}") def format_custom_progress(self, full_report_dict_list): F = getattr(self, 'custom_progress_formatter', None) if F is not None: return F(full_report_dict_list) return "" def note_request_time(self): self.request_stats = getattr(self, 'request_stats', []) self.request_stats.append( self.last_request_t_complete - self.last_request_t0) @property def preferred_request_method(self): return getattr(self, '_preferred_request_method', 'GET') @preferred_request_method.setter def preferred_request_method(self, v): allowed_request_methods = ['POST', 'GET'] if v in allowed_request_methods: self._preferred_request_method = v else: raise RuntimeError(f'unable to set preferred request method to {v}, allowed {allowed_request_methods}') @property def selected_request_method(self): if self.parameters_dict_payload is not None: request_size = len(json.dumps(self.parameters_dict_payload)) max_get_method_size = getattr(self, 'max_get_method_size', 1000) self.logger.debug('payload size %s, max for GET is %s', request_size, max_get_method_size) if request_size > max_get_method_size: self.logger.debug( 'switching to POST request due to large payload: %s > %s', request_size, max_get_method_size) return 'POST' return self.preferred_request_method def request_to_json(self, verbose=False): if self.use_local_cache: try: return self.load_result() except Exception: logger.debug('unable to load result from %s: will need to compute', self.unique_response_json_fn) self.progress_logger.info( f'- waiting for remote response (since {time.strftime("%Y-%m-%d %H:%M:%S")}), please wait for {self.url}/{self.run_analysis_handle}') response = None try: timeout = getattr(self, 'timeout', 120) self.last_request_t0 = time.time() url = "%s/%s" % (self.url, self.run_analysis_handle) if self.selected_request_method == 'GET': response = requests.get( url, params=self.parameters_dict_payload, cookies=self.cookies, headers={ 'Request-Timeout': str(timeout), 'Connection-Timeout': str(timeout), }, timeout=timeout, allow_redirects=False ) elif self.selected_request_method == 'POST': response = requests.post( url, data=self.parameters_dict_payload, cookies=self.cookies, headers={ 'Request-Timeout': str(timeout), 'Connection-Timeout': str(timeout), }, timeout=timeout, allow_redirects=False ) else: raise NotImplementedError if response.status_code in (301, 302): # we can not automatically redirect with POST due to unexpected behavior of requests module # there is a very strange and mysterious story about this: # * https://github.com/psf/requests/blob/1e5fad7433772b648fcbc921e2a79de5c4c6be8b/requests/sessions.py#L329-L332 # * https://github.com/psf/requests/issues/1704 # to avoid confusion, we will instruct the user to change the code: raise URLRedirected(f"the service was moved{' permanently' if response.status_code == 301 else ''}, " f"please reinitialize DispatcherAPI with \"{response.headers['Location']}\" (you asked for \"{url}\")") if response.status_code == 403: try: response_json = response.json() except JSONDecodeError as e: raise Unauthorized(f"undecodable: {response.text}") from e try: raise Unauthorized(response_json['exit_status']['message']) except KeyError as e: raise Unauthorized(response_json['error']) from e if response.status_code == 400: raise RequestNotUnderstood( response.json()) if response.status_code in [502, 503, 504]: raise DispatcherNotAvailable() if response.status_code == 500: try: raise DispatcherException(response.json()) except JSONDecodeError as e: raise DispatcherException({'error_message': response.text}) from e if response.status_code != 200: raise UnexpectedDispatcherStatusCode( f"status: {response.status_code}, raw: {response.text}") self.last_request_t_complete = time.time() self.note_request_time() response_json = self._decode_res_json(response) validate_json(response_json, self.dispatcher_response_schema) self.returned_analysis_parameters = response_json['products'].get('analysis_parameters', None) if self.use_local_cache and response_json.get('query_status') in ['done', 'failed']: self.save_result(response_json) return response_json except json.decoder.JSONDecodeError: self.logger.error( f"{C.RED}{C.BOLD}unable to decode json from response:{C.NC}") if response is not None: self.logger.error(f"{C.RED}{response.text}{C.NC}") raise def returned_analysis_parameters_consistency(self): mismatching_parameters = [] for k in self.parameters_dict.keys(): # these do not correspond to meaning ''' The dry_run parameter is not actually considered within the oda_api, but we keep it here for consistency. As discussed in: * https://github.com/oda-hub/oda_api/pull/85 * https://github.com/oda-hub/oda_api/issues/84 ''' if k in ['query_status', 'off_line', 'verbose', 'dry_run']: continue returned = self.returned_analysis_parameters.get(k, None) requested = self.parameters_dict.get(k, None) if str(returned) != str(requested): mismatching_parameters.append(f"{k}: returned {returned} != requested {requested}") if mismatching_parameters != []: raise RuntimeError(f"dispatcher return different parameters: {'; '.join(mismatching_parameters)}") @property def parameters_dict(self): """ as provided in request, not modified by state changes """ return getattr(self, '_parameters_dict', {}) @parameters_dict.setter def parameters_dict(self, value): self._parameters_dict = value self.query_status = 'prepared' @property def parameters_dict_payload(self): if self.parameters_dict is None: return None p = { **self.parameters_dict, 'api': 'True', 'oda_api_version': __version__, } for k, v in p.items(): if isinstance(v, (list, dict, set)) and (k not in ['catalog_selected_objects', 'selected_catalog', 'scw_list']): p[k] = json.dumps(v) if v is None and k != 'token': p[k] = '\x00' if self.is_submitted: return { **p, 'job_id': self.job_id, 'query_status': self.query_status, } else: return p @parameters_dict_payload.setter def parameters_dict_payload(self, value): raise UserError( "please set parameters_dict and not parameters_dict_payload") @property def job_id(self): return getattr(self, '_job_id', None) @job_id.setter def job_id(self, new_job_id): self._job_id = new_job_id @property def query_status(self): return getattr(self, '_query_status', 'not-prepared') @query_status.setter def query_status(self, new_status): possible_status = [ "not-prepared", "prepared", "submitted", "progress", "done", "ready", "failed", ] if new_status in possible_status: self._query_status = new_status else: raise RuntimeError( f"unable to set status to {new_status}, possible values are {possible_status}") @property def is_submitted(self): return self.query_status not in ['prepared', 'not-prepared'] @property def is_prepared(self): return self.query_status not in ['not-prepared'] @property def is_done(self): return self.query_status in ['done'] @property def is_complete(self): return self.query_status in ['done', 'failed'] @property def is_failed(self): return self.query_status in ['failed'] @safe_run def poll(self, verbose=None, silent=None): """ Updates status of query at the remote server Relies on self.parameters_dict to set parameters for request Relies on self.query_status and self.job_id, which is created as necessary and submitted in paylad """ if verbose is not None or silent is not None: self.logger.warning( "please set verbosity with standard python \"logging\" module") self.logger.warning("these option will be removed in the future") if verbose: if silent: self.logger.error( "can not be verbose and silent at once! ignoring verbose and silent options") else: self.logger.warning( "legacy verbose option: setting oda_api logging level to DEBUG and one stream handler") logging.getLogger('oda_api').setLevel(logging.DEBUG) logging.getLogger('oda_api').addHandler( logging.StreamHandler()) else: if silent: self.logger.warning( "legacy silent option, no special logging config - silent by default") else: self.logger.warning( "legacy verbose but not silet option: setting oda_api logging level to INFO and one stream handler") logging.getLogger('oda_api').setLevel(logging.INFO) logging.getLogger('oda_api').addHandler( logging.StreamHandler()) if not self.is_prepared: raise UserError( f"can not poll query before parameters are set with {self}.request") # > self.response_json = self.request_to_json() # < logger.info("session: %s job: %s", self.response_json['job_monitor']['session_id'], self.response_json['job_monitor']['job_id']) if 'query_status' not in self.response_json: logger.error(json.dumps(self.response_json, indent=4)) raise RuntimeError( f"request json does not contain query_status: {self.response_json}") if self.response_json.get('query_status') != self.query_status: self.logger.info( f"\n... query status {C.PURPLE}{self.query_status}{C.NC} => {C.PURPLE}{self.response_json.get('query_status')}{C.NC}") self.query_status = self.response_json.get('query_status') returned_job_id = self.response_json['job_monitor']['job_id'] if self.job_id is None: self.job_id = returned_job_id self.logger.info( f"... assigned job id: {C.BROWN}{self.job_id}{C.NC}") else: if self.response_json['query_status'] != self.query_status: raise RuntimeError( f"request returns query_status {self.response_json['query_status']} != recorded query_status {self.query_status}" f"this should not happen! Server must be misbehaving, or client forgot correct query_status") if self.job_id != returned_job_id: raise RuntimeError(f"request returns job_id {returned_job_id} != recorded job_id {self.job_id}" f"this should not happen! Server must be misbehaving, or client forgot correct job id") if self.query_status == 'done': self.logger.info( f"\033[32mquery COMPLETED SUCCESSFULLY (state {self.query_status})\033[0m") elif self.query_status == 'failed': self.logger.info( f"\033[31mquery COMPLETED with FAILURE (state {self.query_status})\033[0m") else: self.show_progress() if self.is_complete: # TODO: something raising here does not help self.logger.debug("poll returing data: complete") return DataCollection.from_response_json(self.response_json, self.instrument, self.product) def show_progress(self): full_report_dict_list = self.response_json['job_monitor'].get( 'full_report_dict_list', []) info = 'status=%s job_id=%s in %d messages since %d seconds (%.2g/%.2g)' % ( self.query_status, str(self.job_id)[:8], len(full_report_dict_list), time.time() - self.t0, np.mean(self.request_stats), np.max(self.request_stats), ) custom_info = self.format_custom_progress(full_report_dict_list) if custom_info != "": info += "; " + custom_info self._progress_bar(info=info) def print_parameters(self): for k, v in self.parameters_dict.items(): self.logger.info(f"- {C.BLUE}{k}: {v}{C.NC}")
[docs] @safe_run def request(self, parameters_dict, handle=None, url=None, wait=None, quiet=True): """ sets request parameters, optionally polls them in a loop """ if wait is not None: self.logger.warning("overriding wait mode from request") self.wait = wait if url is not None: self.logger.warning("overriding dispatcher URL from request!") self.url = url if handle is not None: self.logger.warning( "overriding dispatcher handle from request not allowed, ignored!") self.parameters_dict = parameters_dict if 'scw_list' in self.parameters_dict.keys(): self.logger.debug(self.parameters_dict['scw_list']) self.set_instr(self.parameters_dict.get('instrument', self.instrument)) if not quiet: self.print_parameters() self.t0 = time.time() while True: self.poll() if not self.wait: self.logger.info("non-waiting dispatcher: terminating") return if self.is_complete: self.logger.info("query complete: terminating") return time.sleep(1)
def process_failure(self): if self.response_json['exit_status']['status'] != 0: self.failure_report(self.response_json) if self.query_status != 'failed': self.logger.info('query done succesfully!') else: logger.error("exception, message: \"%s\"", self.response_json['exit_status']['message']) logger.error("have exception message: keys \"%s\"", exception_by_message.keys()) raise exception_by_message.get(self.response_json['exit_status']['message'], RemoteException)( message=self.response_json['exit_status']['message'], debug_message=self.response_json['exit_status']['error_message'] )
[docs] def failure_report(self, res_json): self.logger.error('query failed!') self.logger.error('Remote server message:-> %s', res_json['exit_status']['message']) self.logger.error('Remote server error_message-> %s', res_json['exit_status']['error_message']) self.logger.error('Remote server debug_message-> %s', res_json['exit_status']['debug_message'])
def show_status_comments(self, res_json): if res_json['exit_status']['comment']: print(res_json['exit_status']['comment']) # TODO: warning field is not currently consistently used # could be enabled in the future (add test then!) # if res_json['exit_status']['warning']: # self.logger.warning(res_json['exit_status']['warning'])
[docs] def dig_list(self, b, only_prod=False): if isinstance(b, (set, tuple, list)): for c in b: self.dig_list(c) else: try: original_b = b b = ast.literal_eval(str(b)) # uh except Exception as e: logger.debug( "dig_list unable to literal_eval %s; problem %s", b, e) return str(b) if isinstance(b, dict): _s = '' for k, v in b.items(): if 'query_name' == k or 'instrument' == k and not only_prod: self.logger.info('') self.logger.info('--------------') _s += '%s' % k + ': ' + v if 'product_name' == k: _s += ' %s' % k + ': ' + v for k in ['name', 'value', 'units']: if k in b.keys(): _s += ' %s' % k + ': ' if b[k] is not None: _s += '%s,' % str(b[k]) else: _s += 'None,' _s += ' ' if _s != '': self.logger.info(_s) else: self.logger.debug( 'unable to dig list, instance not a dict by %s; object was %s', type(b), b) if original_b != b: self.dig_list(b)
@safe_run def _decode_res_json(self, res): try: if hasattr(res, 'content'): # _js = json.loads(res.content) # fixed issue with python 3.5 _js = res.json() res = ast.literal_eval(str(_js).replace('null', 'None')) else: res = ast.literal_eval(str(res).replace('null', 'None')) # what is it for? self.dig_list(res) return res except Exception as e: msg = 'remote/connection error, server response is not valid \n' msg += f'exception: {e}' msg += 'possible causes: \n' msg += '- connection error\n' msg += '- wrong credentials\n' msg += '- wrong remote address\n' msg += '- error on the remote server\n' msg += "--------------------------------------------------------------\n" if hasattr(res, 'status_code'): msg += '--- status code:-> %s\n' % res.status_code if hasattr(res, 'text'): msg += '--- response text ---\n %s\n' % res.text if hasattr(res, 'content'): msg += '--- res content ---\n %s\n' % res.content msg += "--------------------------------------------------------------" raise RemoteException(message=msg) from e def get_token_from_environment(self): token = oda_api.token.discover_token(allow_invalid=False, token_discovery_methods=self.token_discovery_methods) if token is not None: logger.info("discovered token in environment") return token
[docs] @safe_run def get_instrument_description(self, instrument=None): if instrument is None: instrument = self.instrument res = requests.get("%s/api/meta-data" % self.url, params=dict(instrument=instrument, token=self.token), cookies=self.cookies) if res.status_code != 200: raise UnexpectedDispatcherStatusCode( f"status: {res.status_code}, raw: {res.text}") return self._decode_res_json(res)
[docs] @safe_run def get_product_description(self, instrument, product_name): res = requests.get("%s/api/meta-data" % self.url, params=dict( instrument=instrument, product_type=product_name, token=self.token), cookies=self.cookies) if res.status_code != 200: raise UnexpectedDispatcherStatusCode( f"status: {res.status_code}, raw: {res.text}") self.logger.info('--------------') self.logger.info( 'parameters for product %s and instrument %s', product_name, instrument) return self._decode_res_json(res)
[docs] @safe_run def get_instruments_list(self): res = requests.get("%s/api/instr-list" % self.url, params=dict(instrument=self.instrument, token=self.token), cookies=self.cookies) if res.status_code != 200: raise UnexpectedDispatcherStatusCode( f"status: {res.status_code}, raw: {res.text}") return self._decode_res_json(res)
def report_last_request(self): self.logger.info( f"{C.GREY}last request completed in {self.last_request_t_complete - self.last_request_t0} seconds{C.NC}")
[docs] def get_product(self, product: str, instrument: str, verbose=None, product_type: str = 'Real', silent=False, **kwargs): """ submit query, wait (if allowed by self.wait), decode output when found """ if not silent: advice_logger.warning('please beware that by default, in a typical setup, oda_api will not output much. ' 'To learn how to increase the verbosity, please refer to the documentation: ' 'https://oda-api.readthedocs.io/en/latest/user_guide/ScienceWindowList.html?highlight=logging#Let\'s-get-some-logging . \n' 'To disable this message you can pass `.get_product(..., silent=True)`' ) self.job_id = None # TODO: it's confusing when and where these are passed self.product = product self.instrument = instrument kwargs['instrument'] = instrument kwargs['product_type'] = product kwargs['query_type'] = product_type kwargs['off_line'] = False, kwargs['query_status'] = 'new', kwargs['verbose'] = verbose, kwargs['session_id'] = self.session_id if 'dry_run' in kwargs: warnings.warn('The dry_run parameter you included is not going to have any effect on the execution.\n' 'However the oda_api will perform a check of the list of valid parameters for your request.', stacklevel = 1) del kwargs['dry_run'] res = requests.get("%s/api/par-names" % self.url, params=dict( instrument=instrument, product_type=product), cookies=self.cookies) if res.status_code != 200: warnings.warn( 'parameter check not available on remote server, check carefully parameters name', stacklevel = 1) else: _ignore_list = ['instrument', 'product_type', 'query_type', 'off_line', 'query_status', 'verbose', 'session_id'] validation_dict = copy.deepcopy(kwargs) for _i in _ignore_list: del validation_dict[_i] valid_names = self._decode_res_json(res) for n in validation_dict.keys(): if n not in valid_names: if self.strict_parameter_check: raise UserError(f'the parameter: {n} is not among the valid ones: {valid_names}' f'(you can set {self}.strict_parameter_check=False, but beware!') else: msg = '\n' msg += '----------------------------------------------------------------------------\n' msg += 'the parameter: %s ' % n msg += ' is not among valid ones:' msg += '\n' msg += '%s' % valid_names msg += '\n' # msg += 'this will throw an error in a future version \n' # msg += 'and might break the current request!\n ' msg += '----------------------------------------------------------------------------\n' warnings.warn(msg, stacklevel = 1) if kwargs.get('token', None) is None and self.token_discovery_methods is not None: discovered_token = oda_api.token.discover_token(self.token_discovery_methods) if discovered_token is not None: logger.info("discovered token in environment") kwargs['token'] = discovered_token # > self.request(kwargs) if self.is_failed: return self.process_failure() elif self.is_done: res_json = self.response_json elif not self.is_complete: if self.wait: raise RuntimeError( "should have waited, but did not - programming error!") else: self.logger.info( f"\n{C.BROWN}query not complete, please poll again later{C.NC}") return else: raise RuntimeError( "not failed, not, but complete? programming error for client!") self.show_status_comments(res_json) d = DataCollection.from_response_json( res_json, instrument, product) del (res) return d
@staticmethod def set_api_code(query_dict, url="www.astro.unige.ch/mmoda/dispatch-data"): query_dict = OrderedDict(sorted(query_dict.items())) _skip_list_ = ['job_id', 'query_status', 'session_id', 'use_resolver[local]', 'use_scws'] _alias_dict = {} _alias_dict['product_type'] = 'product' _alias_dict['query_type'] = 'product_type' _header = f'''from oda_api.api import DispatcherAPI disp=DispatcherAPI(url='{url}', instrument='mock')''' _api_dict = {} for k in query_dict.keys(): if k not in _skip_list_: if k in _alias_dict.keys(): n = _alias_dict[k] else: n = k if query_dict[k] is not None: _api_dict[n] = query_dict[k] python_compatible_par_dict_str = json.dumps(_api_dict, indent=4) python_compatible_par_dict_str = python_compatible_par_dict_str.replace('false', 'False') python_compatible_par_dict_str = python_compatible_par_dict_str.replace('true', 'True') _cmd_ = f'''{_header} par_dict={python_compatible_par_dict_str} data_collection = disp.get_product(**par_dict) ''' return _cmd_ def save_result(self, response_json): fn = self.unique_response_json_fn os.makedirs(os.path.dirname(fn), exist_ok=True) with gzip.open(fn, "wt") as f: json.dump(response_json, f) logger.info('saved result in %s', fn) def load_result(self): fn = self.unique_response_json_fn logger.info('trying to load result from %s', fn) t0 = time.time() with gzip.open(fn, 'rt') as f: r = json.load(f) logger.info('\033[32mmanaged to load result\033[0m from %s in %.2f seconds', fn, time.time() - t0) return r @property def unique_response_json_fn(self): request_hash = oda_api.misc_helpers.make_hash(self.set_api_code(self.parameters_dict)) return (pathlib.Path(os.getenv('ODA_CACHE', pathlib.Path(os.getenv('HOME', '/tmp')) / ".cache/oda-api")) / f"cache/oda_api_data_collection_{request_hash}.json.gz") def __repr__(self): return f"[ {self.__class__.__name__}: {self.url} ]"
class DataCollection(object): def __init__(self, data_list: list[DataProduct], add_meta_to_name: list[str] | None = None, instrument: str | None = None, product: str | None =None, request_job_id: str | None=None ): if add_meta_to_name is None: add_meta_to_name = ['src_name', 'product'] self._p_list: list[DataProduct] = [] self._n_list: list[str] = [] self.request_job_id = request_job_id for ID, data in enumerate(data_list): name = '' if hasattr(data, 'name'): name = data.name # type: ignore[assignment] if name is None or name.strip() == '': if product is not None: name = '%s' % product elif instrument is not None: name = '%s' % instrument else: name = 'prod' name = '%s_%d' % (name, ID) name, var_name = self._build_prod_name( data, name, add_meta_to_name) setattr(self, var_name, data) self._p_list.append(data) self._n_list.append(var_name) def show(self): for ID, prod_name in enumerate(self._n_list): if hasattr(self._p_list[ID], 'meta_data'): meta_data = self._p_list[ID].meta_data else: meta_data = '' print('ID=%s prod_name=%s' % (ID, prod_name), ' meta_data:', meta_data) print() def as_list(self): L = [] for ID, prod_name in enumerate(self._n_list): if hasattr(self._p_list[ID], 'meta_data'): meta_data = self._p_list[ID].meta_data else: meta_data = '' L.append({ 'ID': ID, 'prod_name': prod_name, 'meta_data:': meta_data }) return L def _build_prod_name(self, prod, name, add_meta_to_name): for kw in add_meta_to_name: if hasattr(prod, 'meta_data'): if kw in prod.meta_data: s = prod.meta_data[kw].replace(' ', '') if s.strip() != '': name += '_' + s.strip() return name, oda_api.misc_helpers.clean_var_name(name) def save_all_data(self, prenpend_name = None, overwrite=True): # NOTE: prepend_name also determines file path for pname, prod in zip(self._n_list, self._p_list, strict=False): if not isinstance(prod, DataProduct): logger.warning(f"Writing on disk is not implemented for product {pname} of type {pname.__class__.__name__}, skipping.") continue if prenpend_name is not None: file_name = prenpend_name + '_' + pname else: file_name = pname fn_extension = prod.suggest_fn_extension() file_name = f"{file_name}.{fn_extension}" prod.write_file(file_name, overwrite=overwrite) def save(self, file_name): pickle.dump(self, open(file_name, 'wb'), protocol=pickle.HIGHEST_PROTOCOL) def new_from_metadata(self, key, val): dc = None _l = [] for p in self._p_list: if p.meta_data[key] == val: _l.append(p) if _l != []: dc = DataCollection(_l) return dc @classmethod def from_response_json(cls, res_json, instrument, product): data = [] if 'numpy_data_product' in res_json['products'].keys(): data.append(NumpyDataProduct.decode( res_json['products']['numpy_data_product'])) elif 'numpy_data_product_list' in res_json['products'].keys(): data.extend([NumpyDataProduct.decode(d) for d in res_json['products']['numpy_data_product_list']]) if 'binary_data_product_list' in res_json['products'].keys(): try: data.extend([BinaryProduct.decode(d) for d in res_json['products']['binary_data_product_list']]) except Exception: data.extend([BinaryData().decode(d) for d in res_json['products']['binary_data_product_list']]) if 'catalog' in res_json['products'].keys(): data.append(ApiCatalog( res_json['products']['catalog'], name='dispatcher_catalog')) if 'astropy_table_product_ascii_list' in res_json['products'].keys(): data.extend([ODAAstropyTable.decode(table_text, use_binary=False) for table_text in res_json['products']['astropy_table_product_ascii_list']]) if 'astropy_table_product_binary_list' in res_json['products'].keys(): data.extend([ODAAstropyTable.decode(table_binary, use_binary=True) for table_binary in res_json['products']['astropy_table_product_binary_list']]) if 'binary_image_product_list' in res_json['products'].keys(): data.extend([PictureProduct.decode(bin_image_data) for bin_image_data in res_json['products']['binary_image_product_list']]) if 'text_product_list' in res_json['products'].keys(): try: data.extend([TextLikeProduct.decode(text_data) for text_data in res_json['products']['text_product_list']]) except (JSONDecodeError, KeyError): data.extend([text_data for text_data in res_json['products']['text_product_list']]) if 'gw_strain_product_list' in res_json['products'].keys(): data.extend([TimeSeries(strain_data['value'], # pyright: ignore[reportPossiblyUnboundVariable] name=strain_data['name'], t0=strain_data['t0'], dt=strain_data['dt']) for strain_data in res_json['products']['gw_strain_product_list']]) if 'gw_spectrogram_product' in res_json['products'].keys(): sgram = res_json['products']['gw_spectrogram_product'] data.append(Spectrogram(sgram['value'], # pyright: ignore[reportPossiblyUnboundVariable] name='Spectrogram', unit='s', t0=sgram['x0'], dt=sgram['dx'], frequencies=sgram['yindex'] ) ) if 'gw_skymap_product' in res_json['products'].keys(): skmap = res_json['products']['gw_skymap_product'] for event in skmap['skymaps'].keys(): data.append(NumpyDataProduct.decode(skmap['skymaps'][event])) if 'contours' in skmap.keys(): data.append(GWContoursDataProduct(skmap['contours'])) if 'job_id' not in res_json['job_monitor']: # TODO use the incident-report endpoint from the dispatcher (https://github.com/oda-hub/dispatcher-app/issues/393) logger.warning(f"job_monitor response json does not contain job_id: {res_json['job_monitor']}") request_job_id = res_json['job_monitor'].get('job_id', None) d = cls(data, instrument=instrument, product=product, request_job_id=request_job_id) for p in d._p_list: if hasattr(p, 'meta_data') is False and hasattr(p, 'meta') is True: p.meta_data = p.meta # type:ignore return d def get_context(): """ load context from file .oda_api_context in the notebook dir """ from oda_api import context_file if not os.path.isfile(context_file): return {} with open(context_file, 'r') as file: context = json.load(file) return context class ProgressReporter(object): """ The class allows to report task progress to end user """ def __init__(self): callback = get_context().get('callback', None) if callback: callback = callback.strip() else: # backward compatibility callback_file = ".oda_api_callback" if os.path.isfile(callback_file): logger.warning(f'reading callback from the deprecated location: {callback_file}') with open(callback_file, 'r') as file: callback = file.read().strip() self._callback = callback @property def enabled(self): return self._callback is not None def report_progress(self, stage: str | None = None, progress: float=50., progress_max: float=100., substage: str | None = None, subprogress: float | None = None, subprogress_max: float=100., message: str | None = None): """ Report progress via callback URL :param stage: current stage description string :param progress: current stage progress :param progress_max: maximal progress value :param substage: current substage description string :param subprogress: current substage progress :param subprogress_max: maximal substage progress value :param message: message to pass """ callback_payload = dict(stage=stage, progress=progress, progress_max=progress_max, substage=substage, subprogress=subprogress, subprogress_max=subprogress_max, message=message) callback_payload = {k: v for k, v in callback_payload.items() if v is not None} callback_payload['action'] = 'progress' if not self.enabled: logger.info('no callback registered, skipping') return self._callback = cast(str, self._callback) logger.info('will perform callback: %s', self._callback) if re.match(r'^file://', self._callback): with open(self._callback.replace('file://', ''), "w") as f: json.dump(callback_payload, f) logger.info('stored callback in a file %s', self._callback) elif re.match(r'^https?://', self._callback): r = requests.get(self._callback, params=callback_payload) logger.info('callback %s returns %s : %s', self._callback, r, r.text) else: raise NotImplementedError