from __future__ import absolute_import, division, print_function
from collections import OrderedDict
import gzip
import hashlib
from json.decoder import JSONDecodeError
import pathlib
import rdflib
from json.decoder import JSONDecodeError
# NOTE gw is optional for now
try:
import gwpy
from gwpy.timeseries.timeseries import TimeSeries
from gwpy.spectrogram import Spectrogram
except ModuleNotFoundError:
pass
from .data_products import (NumpyDataProduct,
BinaryData,
BinaryProduct,
ApiCatalog,
GWContoursDataProduct,
PictureProduct,
ODAAstropyTable,
TextLikeProduct)
from oda_api.token import TokenLocation
from builtins import (bytes, str, open, super, range,
zip, round, input, int, pow, object, map, zip)
__author__ = "Andrea Tramacere, Volodymyr Savchenko"
import warnings
import requests
import ast
import json
import re
try:
# compatibility in some remaining environments
import simplejson # type: ignore
except ImportError:
import json as simplejson # type: ignore
import random
import string
import time
import os
import inspect
import sys
from astropy.io import ascii
import copy
import pickle
from . import __version__
from . import custom_formatters
from . import colors as C
from itertools import cycle
import numpy as np
import traceback
from jsonschema import validate as validate_json
from typing import Union, Tuple
import oda_api.token
import oda_api.misc_helpers
import logging
logger = logging.getLogger("oda_api.api")
advice_logger = logging.getLogger("oda_api.advice")
# Public API of this module.
# BUG FIX: 'NoTraceBackWithLineNumber' was listed twice; deduplicated.
__all__ = ['Request', 'NoTraceBackWithLineNumber',
           'RemoteException', 'DispatcherAPI']
[docs]
class Request(object):
    """Placeholder base class for dispatcher requests; carries no state."""

    def __init__(self):
        # nothing to initialise here; kept for subclass compatibility
        pass
[docs]
class NoTraceBackWithLineNumber(Exception):
    """Exception rendered as a single line including the failing line number.

    The ``args`` tuple is overwritten with one formatted string of the form
    ``"<ClassName> (line <n>): <msg>"`` so that printing the exception is
    compact.
    """

    def __init__(self, msg):
        try:
            # if constructed while handling another exception, report the
            # line number of that exception's traceback
            ln = sys.exc_info()[-1].tb_lineno
        except AttributeError:
            # otherwise (no active exception) report the caller's line
            ln = inspect.currentframe().f_back.f_lineno
        self.args = "{0.__name__} (line {1}): {2}".format(type(self), ln, msg),
        # sys.exit(self)
class UserError(Exception):
    """Raised for invalid user input; never retried by ``safe_run``."""
[docs]
class RemoteException(NoTraceBackWithLineNumber):
    """Base class for errors reported by the remote dispatcher.

    Keeps both the short ``message`` and a longer ``debug_message`` as
    returned by the server.
    """

    def __init__(self, message='Remote analysis exception', debug_message=''):
        super(RemoteException, self).__init__(message)
        self.message = message
        self.debug_message = debug_message

    def __repr__(self):
        return f"RemoteException: {self.message}, {self.debug_message}"
# Specific remote-failure flavours. FailedToFindAnyUsefulResults is selected
# via the exception_by_message mapping; the other two are raised directly
# depending on the HTTP status code received from the dispatcher.
class FailedToFindAnyUsefulResults(RemoteException):
    pass


class UnexpectedDispatcherStatusCode(RemoteException):
    pass


class DispatcherNotAvailable(RemoteException):
    pass
class RequestNotUnderstood(Exception):
    """Dispatcher answered HTTP 400: the request itself was malformed."""

    def __init__(self, details_json) -> None:
        # raw JSON error payload from the dispatcher, kept for inspection
        self.details_json = details_json

    def __str__(self) -> str:
        return repr(self)

    def __repr__(self) -> str:
        return f"[ RequestNotUnderstood: {self.details_json['error']} ]"
# HTTP 403 from the dispatcher: token missing, invalid, or insufficient.
class Unauthorized(RemoteException):
    pass


# Raised instead of silently following a 301/302 redirect; see the note in
# request_to_json about requests' POST-redirect behavior.
class URLRedirected(Exception):
    pass
class DispatcherException(Exception):
    """Dispatcher answered HTTP 500; wraps the server's JSON error payload."""

    def __init__(self, response_json) -> None:
        self.response_json = response_json

    def __repr__(self) -> str:
        error_message = self.response_json.get(
            'error_message', '[no error message reported]')
        return f"[ {self.__class__.__name__}: {error_message} ]"
# Maps known dispatcher exit-status messages to dedicated exception classes;
# process_failure falls back to plain RemoteException for unknown messages.
# NOTE: the trailing space in the key matches the server's message verbatim.
exception_by_message = {
    'failed: get dataserver products ': FailedToFindAnyUsefulResults
}
def safe_run(func):
    """Decorator retrying *func* on transient connection-level failures.

    The wrapped callable must be a method of an object exposing
    ``n_max_tries`` and ``retry_sleep_s``. User errors and permanent
    dispatcher errors propagate immediately; connection errors and
    temporary dispatcher unavailability are retried up to ``n_max_tries``
    times before being re-raised as :class:`RemoteException`.
    """
    def func_wrapper(*args, **kwargs):
        logger = logging.getLogger("oda_api.api." + func.__name__)
        self = args[0]  # because it really is
        n_tries_left = self.n_max_tries
        retry_sleep_s = self.retry_sleep_s
        t0 = time.time()
        while True:
            try:
                return func(*args, **kwargs)
            except UserError as e:
                # bad input: retrying cannot help
                logger.exception("probably an unfortunate user input: %s", e)
                raise
            except (Unauthorized, RequestNotUnderstood, UnexpectedDispatcherStatusCode) as e:
                # permanent server-side refusal: retrying cannot help either
                logger.exception("something went quite wrong, and we think it's not likely to recover on its own: %s",
                                 e)
                raise
            except (ConnectionError,
                    requests.exceptions.ConnectionError,
                    requests.exceptions.Timeout,
                    DispatcherNotAvailable) as e:
                # TODO: these are probably all server or access errors,
                # TODO: and they may need to be communicated back to server (if possible)
                message = ''
                message += '\nunable to complete API call'
                message += '\nin ' + str(func) + ' called with:'
                message += '\n... ' + ", ".join([str(arg) for arg in args])
                # BUG FIX: iterating a dict yields keys only; .items() is
                # required to unpack (k, v) pairs (the old code raised
                # ValueError whenever kwargs was non-empty)
                message += '\n... ' + \
                    ", ".join([k + ": " + str(v) for k, v in kwargs.items()])
                message += '\npossible causes:'
                message += '\n- connection error'
                message += '\n- error on the remote server'
                message += '\n exception message: '
                message += '\n\n%s\n' % e
                message += traceback.format_exc()
                n_tries_left -= 1
                if n_tries_left > 0:
                    logger.debug("problem in API call, %i tries left:\n%s\n sleeping %i seconds until retry",
                                 n_tries_left, message, retry_sleep_s)
                    logger.warning(
                        "possibly temporary problem in calling server: %s in %.1f seconds, %i tries left, sleeping %i seconds until retry",
                        repr(e), time.time() - t0, n_tries_left, retry_sleep_s)
                    time.sleep(retry_sleep_s)
                else:
                    raise RemoteException(
                        message=message
                    )
    return func_wrapper
[docs]
class DispatcherAPI:
    """Client for a remote MMODA dispatcher: submits queries, polls job
    status, and decodes returned products into a DataCollection."""
    # allowing token discovery by default changes the user interface in some cases,
    # but in desirable way
    token_discovery_methods = None
    # when True, completed responses are cached on disk (save_result/load_result)
    use_local_cache = False
    # lazily-populated site-name/alias -> dispatcher-URL map (known_sites_dict)
    _known_sites_dict = None
    @property
    def known_sites_dict(self):
        """Mapping of known site names (and their rdfs:label aliases) to API URLs.

        Parsed once from https://odahub.io/oda-sites.ttl and cached on the
        instance. NOTE(review): first access performs network I/O.
        """
        if self._known_sites_dict is None:
            self._known_sites_dict = {}
            G = rdflib.Graph()
            G.parse("https://odahub.io/oda-sites.ttl")
            for site, url in G.subject_objects(rdflib.URIRef("http://odahub.io/ontology#APIURL")):
                # the URI fragment is the canonical site name
                self._known_sites_dict[site.split("#")[1]] = str(url)
                for alias in G.objects(site, rdflib.URIRef("http://www.w3.org/2000/01/rdf-schema#label")):
                    self._known_sites_dict[str(alias)] = str(url)
        return self._known_sites_dict
    def setup_loggers(self):
        # per-instance child loggers: "oda_api.api.dispatcherapi" and its
        # ".progress" child used for the progress bar
        self.logger = logger.getChild(self.__class__.__name__.lower())
        self.progress_logger = self.logger.getChild("progress")
    def __init__(self,
                 instrument='mock',
                 url=None,
                 run_analysis_handle='run_analysis',
                 host=None,
                 port=None,
                 cookies=None,
                 protocol="https",
                 wait=True,
                 n_max_tries=200,
                 session_id=None,
                 use_local_cache=False,
                 token=None
                 ):
        """Configure a dispatcher client.

        ``url`` may be a full http(s) URL or a known site alias (resolved via
        ``known_sites_dict``). ``host``/``port``/``protocol`` are deprecated
        alternatives kept for compatibility. ``wait`` controls whether
        ``request()`` blocks until the job completes; ``n_max_tries`` and the
        fixed ``retry_sleep_s`` drive the ``safe_run`` retry loop. If ``token``
        is None, token discovery from the environment is attempted.
        """
        self.setup_loggers()
        if url is None:
            # default to the production site, preferring the registry entry
            if 'unige-production' not in self.known_sites_dict:
                url = "https://www.astro.unige.ch/mmoda/dispatch-data"
            else:
                url = self.known_sites_dict['unige-production']
        else:
            if not url.startswith("http://") and not url.startswith("https://"):
                # not a URL: try to resolve it as a site alias
                self.logger.info('url %s is not of http(s) schema, trying to interpretting url as an alias', url)
                if url in self.known_sites_dict:
                    self.logger.info('url %s interpretted an alias for %s', url, self.known_sites_dict[url])
                    url = self.known_sites_dict[url]
                else:
                    logger.debug(f'url %s does not match http(s) schema and is not one of the aliases (%s)', url, list(self.known_sites_dict))
        if host is not None:
            msg = '\n'
            msg += '----------------------------------------------------------------------------\n'
            msg += 'support for the parameter host will end soon \n'
            msg += 'please use "url" instead of "host" while providing dispatcher URL \n'
            msg += '----------------------------------------------------------------------------\n'
            warnings.warn(msg, DeprecationWarning)
            self.url = host
            # TODO: disregard this, but leave parameter for compatibility
            if host.startswith('http'):
                self.url = host
            else:
                if protocol != 'http' and protocol != 'https':
                    raise UserError('protocol must be either http or https')
                else:
                    self.url = protocol + "://" + host
        else:
            if not oda_api.misc_helpers.validate_url(url):
                raise UserError(f'{url} is not a valid url. \n'
                                'A valid url should be like `https://www.astro.unige.ch/mmoda/dispatch-data`, '
                                'you might verify if, for example, a valid schema is provided, '
                                'i.e. url should start with http:// or https:// .\n'
                                'Please check it and try to issue again the request')
            self.url = url
        if session_id is not None:
            self._session_id = session_id
        # an explicitly provided token wins over environment discovery
        self.token = token if token is not None else self.get_token_from_environment()
        self._carriage_return_progress = False
        self.run_analysis_handle = run_analysis_handle
        self.wait = wait
        self.strict_parameter_check = False
        self.cookies = cookies
        self.set_instr(instrument)
        self.n_max_tries = n_max_tries
        self.retry_sleep_s = 10.
        if port is not None:
            self.logger.warning(
                "please use 'url' to specify entire URL, no need to provide port separately")
        # spinner characters for the progress bar
        self._progress_iter = cycle(['|', '/', '-', '\\'])
        self.use_local_cache = use_local_cache
        # TODO this should really be just swagger/bravado; or at least derived from resources
        self.dispatcher_response_schema = {
            'type': 'object',
            'properties': {
                'exit_status': {
                    'type': 'object',
                    'properties': {
                        'status': {'type': 'number'},
                    },
                },
                'query_status': {'type': 'string'},
                'job_monitor': {
                    'type': 'object',
                    'properties': {
                        'job_id': {'type': 'string'},
                    },
                },
            }
        }
def inspect_state(self, job_id=None, group_by_job=False):
params = dict(token=oda_api.token.discover_token(), group_by_job=group_by_job)
if job_id is not None:
params['job_id'] = job_id
r = requests.get(self.url + "/inspect-state", params=params)
if r.status_code == 200:
return r.json()
else:
raise RuntimeError(r.text)
    def refresh_token(self,
                      token_to_refresh=None,
                      write_token=False,
                      token_write_methods: Union[Tuple[TokenLocation, ...], TokenLocation] = (TokenLocation.ODA_ENV_VAR,
                                                                                              TokenLocation.FILE_CUR_DIR),
                      discard_discovered_token=False):
        """Ask the dispatcher to refresh a token and optionally persist it.

        If ``token_to_refresh`` is None, a token is discovered from the
        environment first. On success returns the refreshed token string;
        with ``write_token=True`` it is also written back via the given
        ``token_write_methods``. Raises RuntimeError if no token can be
        found or the server refuses the refresh.
        """
        if token_to_refresh is None:
            token_to_refresh = oda_api.token.discover_token()
        if token_to_refresh is not None and token_to_refresh != '':
            params = dict(token=token_to_refresh,
                          query_status='new')
            r = requests.get(os.path.join(self.url, 'refresh_token'),
                             params=params)
            if r.status_code == 200:
                # the response body is the new token itself
                refreshed_token = r.text
                if write_token:
                    oda_api.token.rewrite_token(refreshed_token,
                                                old_token=token_to_refresh,
                                                token_write_methods=token_write_methods,
                                                discard_discovered_token=discard_discovered_token)
                return refreshed_token
            else:
                raise RuntimeError(r.text)
        else:
            raise RuntimeError("unable to refresh the token with any known method")
    def set_custom_progress_formatter(self, F):
        """Install a callable turning the job's report list into progress text."""
        self.custom_progress_formatter = F
[docs]
    @classmethod
    def build_from_envs(cls):
        """Build a client from the DISP_URL and ODA_API_TOKEN env variables.

        NOTE(review): ODA_API_TOKEN is used here as a *path* to a cookie
        file whose content becomes the _oauth2_proxy cookie, not as the
        token value itself -- confirm this is intended.
        """
        cookies_path = os.environ.get('ODA_API_TOKEN')
        cookies = dict(_oauth2_proxy=open(cookies_path).read().strip())
        host_url = os.environ.get('DISP_URL')
        return cls(host=host_url, instrument='mock', cookies=cookies, protocol='http')
[docs]
def generate_session_id(self, size=16):
chars = string.ascii_uppercase + string.digits
return ''.join(random.choice(chars) for _ in range(size))
    @classmethod
    def calculate_param_dict_id(cls,
                                par_dict: dict):
        """Return a stable hash identifying *par_dict*.

        Keys are sorted and None-valued entries dropped, so logically equal
        requests map to the same id.
        """
        ordered_par_dic = OrderedDict({
            k: par_dict[k] for k in sorted(par_dict.keys())
            if par_dict[k] is not None
        })
        return oda_api.misc_helpers.make_hash(ordered_par_dic)
    @property
    def session_id(self):
        """Lazily-generated random session id (unless one was set in __init__)."""
        if not hasattr(self, '_session_id'):
            self._session_id = self.generate_session_id()
        return self._session_id
[docs]
    def set_instr(self, instrument):
        """Select the instrument and its matching custom progress formatter."""
        self.instrument = instrument
        self.custom_progress_formatter = custom_formatters.find_custom_formatter(
            instrument)
    def _progress_bar(self, info=''):
        """Log one spinner line; optionally prefix ANSI codes to overwrite
        the previous line (currently disabled by default)."""
        if self._carriage_return_progress:
            c_r = '\x1b[80D' + '\x1b[K'  # TODO: this does not really work now
        else:
            c_r = ''
        self.progress_logger.info(
            f"{c_r}{C.GREY}\r {next(self._progress_iter)} the job is working remotely, please wait {info}{C.NC}")
def format_custom_progress(self, full_report_dict_list):
F = getattr(self, 'custom_progress_formatter', None)
if F is not None:
return F(full_report_dict_list)
return ""
    def note_request_time(self):
        """Append the duration of the last request to self.request_stats."""
        self.request_stats = getattr(self, 'request_stats', [])
        self.request_stats.append(
            self.last_request_t_complete - self.last_request_t0)
    @property
    def preferred_request_method(self):
        """HTTP method preferred by the user; defaults to GET."""
        return getattr(self, '_preferred_request_method', 'GET')

    @preferred_request_method.setter
    def preferred_request_method(self, v):
        # only GET and POST are supported by the dispatcher endpoints
        allowed_request_methods = ['POST', 'GET']
        if v in allowed_request_methods:
            self._preferred_request_method = v
        else:
            raise RuntimeError(f'unable to set preferred request method to {v}, allowed {allowed_request_methods}')
    @property
    def selected_request_method(self):
        """Effective HTTP method: the preferred one, forced to POST when the
        serialized payload exceeds max_get_method_size (default 1000 bytes)."""
        if self.parameters_dict_payload is not None:
            request_size = len(json.dumps(self.parameters_dict_payload))
            max_get_method_size = getattr(self, 'max_get_method_size', 1000)
            self.logger.debug('payload size %s, max for GET is %s', request_size, max_get_method_size)
            if request_size > max_get_method_size:
                self.logger.debug(
                    'switching to POST request due to large payload: %s > %s', request_size, max_get_method_size)
                return 'POST'
        return self.preferred_request_method
def request_to_json(self, verbose=False):
if self.use_local_cache:
try:
return self.load_result()
except Exception as e:
logger.debug('unable to load result from %s: will need to compute', self.unique_response_json_fn)
self.progress_logger.info(
f'- waiting for remote response (since {time.strftime("%Y-%m-%d %H:%M:%S")}), please wait for {self.url}/{self.run_analysis_handle}')
try:
timeout = getattr(self, 'timeout', 120)
self.last_request_t0 = time.time()
url = "%s/%s" % (self.url, self.run_analysis_handle)
if self.selected_request_method == 'GET':
response = requests.get(
url,
params=self.parameters_dict_payload,
cookies=self.cookies,
headers={
'Request-Timeout': str(timeout),
'Connection-Timeout': str(timeout),
},
timeout=timeout,
allow_redirects=False
)
elif self.selected_request_method == 'POST':
response = requests.post(
url,
data=self.parameters_dict_payload,
cookies=self.cookies,
headers={
'Request-Timeout': str(timeout),
'Connection-Timeout': str(timeout),
},
timeout=timeout,
allow_redirects=False
)
else:
NotImplementedError
if response.status_code in (301, 302):
# we can not automatically redirect with POST due to unexpected behavior of requests module
# there is a very strange and mysterious story about this:
# * https://github.com/psf/requests/blob/1e5fad7433772b648fcbc921e2a79de5c4c6be8b/requests/sessions.py#L329-L332
# * https://github.com/psf/requests/issues/1704
# to avoid confusion, we will instruct the user to change the code:
raise URLRedirected(f"the service was moved{' permanently' if response.status_code == 301 else ''}, "
f"please reinitialize DispatcherAPI with \"{response.headers['Location']}\" (you asked for \"{url}\")")
if response.status_code == 403:
try:
response_json = response.json()
except JSONDecodeError:
raise Unauthorized(f"undecodable: {response.text}")
try:
raise Unauthorized(response_json['exit_status']['message'])
except KeyError:
raise Unauthorized(response_json['error'])
if response.status_code == 400:
raise RequestNotUnderstood(
response.json())
if response.status_code in [502, 503, 504]:
raise DispatcherNotAvailable()
if response.status_code == 500:
try:
raise DispatcherException(response.json())
except simplejson.JSONDecodeError:
raise DispatcherException({'error_message': response.text})
if response.status_code != 200:
raise UnexpectedDispatcherStatusCode(
f"status: {response.status_code}, raw: {response.text}")
self.last_request_t_complete = time.time()
self.note_request_time()
response_json = self._decode_res_json(response)
validate_json(response_json, self.dispatcher_response_schema)
self.returned_analysis_parameters = response_json['products'].get('analysis_parameters', None)
if self.use_local_cache and response_json.get('query_status') in ['done', 'failed']:
self.save_result(response_json)
return response_json
except json.decoder.JSONDecodeError as e:
self.logger.error(
f"{C.RED}{C.BOLD}unable to decode json from response:{C.NC}")
self.logger.error(f"{C.RED}{response.text}{C.NC}")
raise
def returned_analysis_parameters_consistency(self):
mismatching_parameters = []
for k in self.parameters_dict.keys():
# these do not correspond to meaning
'''
The dry_run parameter is not actually considered within the oda_api,
but we keep it here for consistency.
As discussed in:
* https://github.com/oda-hub/oda_api/pull/85
* https://github.com/oda-hub/oda_api/issues/84
'''
if k in ['query_status', 'off_line', 'verbose', 'dry_run']:
continue
returned = self.returned_analysis_parameters.get(k, None)
requested = self.parameters_dict.get(k, None)
if str(returned) != str(requested):
mismatching_parameters.append(f"{k}: returned {returned} != requested {requested}")
if mismatching_parameters != []:
raise RuntimeError(f"dispatcher return different parameters: {'; '.join(mismatching_parameters)}")
    @property
    def parameters_dict(self):
        """
        as provided in request, not modified by state changes
        """
        return getattr(self, '_parameters_dict', None)

    @parameters_dict.setter
    def parameters_dict(self, value):
        # setting new parameters resets the query life cycle
        self._parameters_dict = value
        self.query_status = 'prepared'

    @property
    def parameters_dict_payload(self):
        """Wire-format payload: parameters plus API markers, with container
        values JSON-encoded and, once submitted, job bookkeeping fields."""
        if self.parameters_dict is None:
            return None
        p = {
            **self.parameters_dict,
            'api': 'True',
            'oda_api_version': __version__,
        }
        for k, v in p.items():
            # these keys are handled server-side in their native form
            if isinstance(v, (list, dict, set)) and (k not in ['catalog_selected_objects', 'selected_catalog', 'scw_list']):
                p[k] = json.dumps(v)
        if self.is_submitted:
            return {
                **p,
                'job_id': self.job_id,
                'query_status': self.query_status,
            }
        else:
            return p

    @parameters_dict_payload.setter
    def parameters_dict_payload(self, value):
        # the payload is derived; only parameters_dict may be assigned
        raise UserError(
            "please set parameters_dict and not parameters_dict_payload")
    @property
    def job_id(self):
        """Job id assigned by the dispatcher; None until first poll."""
        return getattr(self, '_job_id', None)

    @job_id.setter
    def job_id(self, new_job_id):
        self._job_id = new_job_id
    @property
    def query_status(self):
        """Current life-cycle state of the query; 'not-prepared' initially."""
        return getattr(self, '_query_status', 'not-prepared')

    @query_status.setter
    def query_status(self, new_status):
        # closed set of states; anything else indicates a programming error
        possible_status = [
            "not-prepared",
            "prepared",
            "submitted",
            "progress",
            "done",
            "ready",
            "failed",
        ]
        if new_status in possible_status:
            self._query_status = new_status
        else:
            raise RuntimeError(
                f"unable to set status to {new_status}, possible values are {possible_status}")
    @property
    def is_submitted(self):
        """True once the query has been sent to the server at least once."""
        return self.query_status not in ['prepared', 'not-prepared']

    @property
    def is_prepared(self):
        """True once parameters_dict has been set."""
        return self.query_status not in ['not-prepared']

    @property
    def is_done(self):
        """True when the query finished successfully."""
        return self.query_status in ['done']

    @property
    def is_complete(self):
        """True when the query finished, successfully or not."""
        return self.query_status in ['done', 'failed']

    @property
    def is_failed(self):
        """True when the query finished with failure."""
        return self.query_status in ['failed']
    @safe_run
    def poll(self, verbose=None, silent=None):
        """
        Updates status of query at the remote server
        Relies on self.parameters_dict to set parameters for request
        Relies on self.query_status and self.job_id, which is created as necessary and submitted in paylad

        Returns a DataCollection once the query is complete, otherwise logs
        progress and returns None. verbose/silent are legacy switches mapped
        onto the standard logging configuration.
        """
        if verbose is not None or silent is not None:
            self.logger.warning(
                "please set verbosity with standard python \"logging\" module")
            self.logger.warning("these option will be removed in the future")
            if verbose:
                if silent:
                    self.logger.error(
                        "can not be verbose and silent at once! ignoring verbose and silent options")
                else:
                    self.logger.warning(
                        "legacy verbose option: setting oda_api logging level to DEBUG and one stream handler")
                    logging.getLogger('oda_api').setLevel(logging.DEBUG)
                    logging.getLogger('oda_api').addHandler(
                        logging.StreamHandler())
            else:
                if silent:
                    self.logger.warning(
                        "legacy silent option, no special logging config - silent by default")
                else:
                    self.logger.warning(
                        "legacy verbose but not silet option: setting oda_api logging level to INFO and one stream handler")
                    logging.getLogger('oda_api').setLevel(logging.INFO)
                    logging.getLogger('oda_api').addHandler(
                        logging.StreamHandler())
        if not self.is_prepared:
            raise UserError(
                f"can not poll query before parameters are set with {self}.request")
        # >
        self.response_json = self.request_to_json()
        # <
        logger.info("session: %s job: %s", self.response_json['job_monitor']['session_id'],
                    self.response_json['job_monitor']['job_id'])
        if 'query_status' not in self.response_json:
            logger.error(json.dumps(self.response_json, indent=4))
            raise RuntimeError(
                f"request json does not contain query_status: {self.response_json}")
        if self.response_json.get('query_status') != self.query_status:
            self.logger.info(
                f"\n... query status {C.PURPLE}{self.query_status}{C.NC} => {C.PURPLE}{self.response_json.get('query_status')}{C.NC}")
        self.query_status = self.response_json.get('query_status')
        returned_job_id = self.response_json['job_monitor']['job_id']
        if self.job_id is None:
            # first poll: adopt the job id assigned by the server
            self.job_id = returned_job_id
            self.logger.info(
                f"... assigned job id: {C.BROWN}{self.job_id}{C.NC}")
        else:
            # subsequent polls: server and client bookkeeping must agree
            if self.response_json['query_status'] != self.query_status:
                raise RuntimeError(
                    f"request returns query_status {self.response_json['query_status']} != recorded query_status {self.query_status}"
                    f"this should not happen! Server must be misbehaving, or client forgot correct query_status")
            if self.job_id != returned_job_id:
                raise RuntimeError(f"request returns job_id {returned_job_id} != recorded job_id {self.job_id}"
                                   f"this should not happen! Server must be misbehaving, or client forgot correct job id")
        if self.query_status == 'done':
            self.logger.info(
                f"\033[32mquery COMPLETED SUCCESSFULLY (state {self.query_status})\033[0m")
        elif self.query_status == 'failed':
            self.logger.info(
                f"\033[31mquery COMPLETED with FAILURE (state {self.query_status})\033[0m")
        else:
            self.show_progress()
        if self.is_complete:
            # TODO: something raising here does not help
            self.logger.debug("poll returing data: complete")
            return DataCollection.from_response_json(self.response_json, self.instrument, self.product)
    def show_progress(self):
        """Emit one progress-bar line summarising the running job."""
        full_report_dict_list = self.response_json['job_monitor'].get(
            'full_report_dict_list', [])
        # request_stats is filled by note_request_time after each request,
        # so mean/max are defined whenever a response has been received
        info = 'status=%s job_id=%s in %d messages since %d seconds (%.2g/%.2g)' % (
            self.query_status,
            str(self.job_id)[:8],
            len(full_report_dict_list),
            time.time() - self.t0,
            np.mean(self.request_stats),
            np.max(self.request_stats),
        )
        custom_info = self.format_custom_progress(full_report_dict_list)
        if custom_info != "":
            info += "; " + custom_info
        self._progress_bar(info=info)
    def print_parameters(self):
        """Log every current request parameter at INFO level."""
        for k, v in self.parameters_dict.items():
            self.logger.info(f"- {C.BLUE}{k}: {v}{C.NC}")
[docs]
    @safe_run
    def request(self, parameters_dict, handle=None, url=None, wait=None, quiet=True):
        """
        sets request parameters, optionally polls them in a loop

        With wait enabled (default from __init__), blocks polling once per
        second until the query is complete; otherwise returns after the
        first poll. ``handle`` is accepted but ignored.
        """
        if wait is not None:
            self.logger.warning("overriding wait mode from request")
            self.wait = wait
        if url is not None:
            self.logger.warning("overriding dispatcher URL from request!")
            self.url = url
        if handle is not None:
            self.logger.warning(
                "overriding dispatcher handle from request not allowed, ignored!")
        self.parameters_dict = parameters_dict
        if 'scw_list' in self.parameters_dict.keys():
            self.logger.debug(self.parameters_dict['scw_list'])
        # the instrument may be overridden by the request parameters
        self.set_instr(self.parameters_dict.get('instrument', self.instrument))
        if not quiet:
            self.print_parameters()
        self.t0 = time.time()
        while True:
            self.poll()
            if not self.wait:
                self.logger.info("non-waiting dispatcher: terminating")
                return
            if self.is_complete:
                self.logger.info("query complete: terminating")
                return
            time.sleep(1)
def process_failure(self):
if self.response_json['exit_status']['status'] != 0:
self.failure_report(self.response_json)
if self.query_status != 'failed':
self.logger('query done succesfully!')
else:
logger.error("exception, message: \"%s\"",
self.response_json['exit_status']['message'])
logger.error("have exception message: keys \"%s\"",
exception_by_message.keys())
raise exception_by_message.get(self.response_json['exit_status']['message'], RemoteException)(
message=self.response_json['exit_status']['message'],
debug_message=self.response_json['exit_status']['error_message']
)
[docs]
    def failure_report(self, res_json):
        """Log the server-side message/error/debug fields of a failed query."""
        self.logger.error('query failed!')
        self.logger.error('Remote server message:-> %s',
                          res_json['exit_status']['message'])
        self.logger.error('Remote server error_message-> %s',
                          res_json['exit_status']['error_message'])
        self.logger.error('Remote server debug_message-> %s',
                          res_json['exit_status']['debug_message'])
    def show_status_comments(self, res_json):
        """Print any server-provided comment attached to the exit status."""
        if res_json['exit_status']['comment']:
            print(res_json['exit_status']['comment'])
        # TODO: warning field is not currently consistently used
        # could be enabled in the future (add test then!)
        # if res_json['exit_status']['warning']:
        #     self.logger.warning(res_json['exit_status']['warning'])
[docs]
def dig_list(self, b, only_prod=False):
if isinstance(b, (set, tuple, list)):
for c in b:
self.dig_list(c)
else:
try:
original_b = b
b = ast.literal_eval(str(b)) # uh
except Exception as e:
logger.debug(
"dig_list unable to literal_eval %s; problem %s", b, e)
return str(b)
if isinstance(b, dict):
_s = ''
for k, v in b.items():
if 'query_name' == k or 'instrument' == k and only_prod == False:
self.logger.info('')
self.logger.info('--------------')
_s += '%s' % k + ': ' + v
if 'product_name' == k:
_s += ' %s' % k + ': ' + v
for k in ['name', 'value', 'units']:
if k in b.keys():
_s += ' %s' % k + ': '
if b[k] is not None:
_s += '%s,' % str(b[k])
else:
_s += 'None,'
_s += ' '
if _s != '':
self.logger.info(_s)
else:
self.logger.debug(
'unable to dig list, instance not a dict by %s; object was %s', type(b), b)
if original_b != b:
self.dig_list(b)
    @safe_run
    def _decode_res_json(self, res):
        """Decode a response (or raw object) into python structures.

        NOTE(review): uses ast.literal_eval on the stringified JSON with
        'null' replaced by 'None' -- a legacy workaround kept for python 3.5
        compatibility. On any failure raises RemoteException with a
        diagnostic dump of the response.
        """
        try:
            if hasattr(res, 'content'):
                # _js = json.loads(res.content)
                # fixed issue with python 3.5
                _js = res.json()
                res = ast.literal_eval(str(_js).replace('null', 'None'))
            else:
                res = ast.literal_eval(str(res).replace('null', 'None'))
            self.dig_list(res)
            return res
        except Exception as e:
            msg = 'remote/connection error, server response is not valid \n'
            msg += f'exception: {e}'
            msg += 'possible causes: \n'
            msg += '- connection error\n'
            msg += '- wrong credentials\n'
            msg += '- wrong remote address\n'
            msg += '- error on the remote server\n'
            msg += "--------------------------------------------------------------\n"
            if hasattr(res, 'status_code'):
                msg += '--- status code:-> %s\n' % res.status_code
            if hasattr(res, 'text'):
                msg += '--- response text ---\n %s\n' % res.text
            if hasattr(res, 'content'):
                msg += '--- res content ---\n %s\n' % res.content
            msg += "--------------------------------------------------------------"
            raise RemoteException(message=msg)
def get_token_from_environment(self):
token = oda_api.token.discover_token(self.token_discovery_methods)
if token is not None:
logger.info("discovered token in environment")
return token
[docs]
    @safe_run
    def get_instrument_description(self, instrument=None):
        """Fetch the meta-data description of *instrument* (default: current)."""
        if instrument is None:
            instrument = self.instrument
        res = requests.get("%s/api/meta-data" % self.url,
                           params=dict(instrument=instrument, token=self.token), cookies=self.cookies)
        if res.status_code != 200:
            raise UnexpectedDispatcherStatusCode(
                f"status: {res.status_code}, raw: {res.text}")
        return self._decode_res_json(res)
[docs]
    @safe_run
    def get_product_description(self, instrument, product_name):
        """Fetch and log the parameter description for one product of an instrument."""
        res = requests.get("%s/api/meta-data" % self.url, params=dict(
            instrument=instrument, product_type=product_name, token=self.token), cookies=self.cookies)
        if res.status_code != 200:
            raise UnexpectedDispatcherStatusCode(
                f"status: {res.status_code}, raw: {res.text}")
        self.logger.info('--------------')
        self.logger.info(
            'parameters for product %s and instrument %s', product_name, instrument)
        return self._decode_res_json(res)
[docs]
    @safe_run
    def get_instruments_list(self):
        """Fetch the list of instruments available on the dispatcher."""
        res = requests.get("%s/api/instr-list" % self.url,
                           params=dict(instrument=self.instrument, token=self.token), cookies=self.cookies)
        if res.status_code != 200:
            raise UnexpectedDispatcherStatusCode(
                f"status: {res.status_code}, raw: {res.text}")
        return self._decode_res_json(res)
    def report_last_request(self):
        """Log the wall-clock duration of the most recent request."""
        self.logger.info(
            f"{C.GREY}last request completed in {self.last_request_t_complete - self.last_request_t0} seconds{C.NC}")
[docs]
def get_product(self,
product: str,
instrument: str,
verbose=None,
product_type: str = 'Real',
silent=False,
**kwargs):
"""
submit query, wait (if allowed by self.wait), decode output when found
"""
if not silent:
advice_logger.warning('please beware that by default, in a typical setup, oda_api will not output much. '
'To learn how to increase the verbosity, please refer to the documentation: '
'https://oda-api.readthedocs.io/en/latest/user_guide/ScienceWindowList.html?highlight=logging#Let\'s-get-some-logging . \n'
'To disable this message you can pass `.get_product(..., silent=True)`'
)
self.job_id = None
# TODO: it's confusing when and where these are passed
self.product = product
self.instrument = instrument
kwargs['instrument'] = instrument
kwargs['product_type'] = product
kwargs['query_type'] = product_type
kwargs['off_line'] = False,
kwargs['query_status'] = 'new',
kwargs['verbose'] = verbose,
kwargs['session_id'] = self.session_id
if 'dry_run' in kwargs:
warnings.warn('The dry_run parameter you included is not going to have any effect on the execution.\n'
'However the oda_api will perform a check of the list of valid parameters for your request.')
del kwargs['dry_run']
res = requests.get("%s/api/par-names" % self.url, params=dict(
instrument=instrument, product_type=product), cookies=self.cookies)
if res.status_code != 200:
warnings.warn(
'parameter check not available on remote server, check carefully parameters name')
else:
_ignore_list = ['instrument', 'product_type', 'query_type',
'off_line', 'query_status', 'verbose', 'session_id']
validation_dict = copy.deepcopy(kwargs)
for _i in _ignore_list:
del validation_dict[_i]
valid_names = self._decode_res_json(res)
for n in validation_dict.keys():
if n not in valid_names:
if self.strict_parameter_check:
raise UserError(f'the parameter: {n} is not among the valid ones: {valid_names}'
f'(you can set {self}.strict_parameter_check=False, but beware!')
else:
msg = '\n'
msg += '----------------------------------------------------------------------------\n'
msg += 'the parameter: %s ' % n
msg += ' is not among valid ones:'
msg += '\n'
msg += '%s' % valid_names
msg += '\n'
# msg += 'this will throw an error in a future version \n'
# msg += 'and might break the current request!\n '
msg += '----------------------------------------------------------------------------\n'
warnings.warn(msg)
if kwargs.get('token', None) is None and self.token_discovery_methods is not None:
discovered_token = oda_api.token.discover_token(self.token_discovery_methods)
if discovered_token is not None:
logger.info("discovered token in environment")
kwargs['token'] = discovered_token
# >
self.request(kwargs)
if self.is_failed:
return self.process_failure()
elif self.is_done:
res_json = self.response_json
elif not self.is_complete:
if self.wait:
raise RuntimeError(
"should have waited, but did not - programming error!")
else:
self.logger.info(
f"\n{C.BROWN}query not complete, please poll again later{C.NC}")
return
else:
raise RuntimeError(
"not failed, not, but complete? programming error for client!")
self.show_status_comments(res_json)
d = DataCollection.from_response_json(
res_json, instrument, product)
del (res)
return d
    @staticmethod
    def set_api_code(query_dict, url="www.astro.unige.ch/mmoda/dispatch-data"):
        """Return a python snippet that reproduces *query_dict* as an API call.

        Bookkeeping keys are dropped, a couple of web-form names are aliased
        to their API equivalents, and JSON booleans are rewritten as python
        ones so the emitted snippet is directly executable.
        """
        query_dict = OrderedDict(sorted(query_dict.items()))
        _skip_list_ = ['job_id', 'query_status',
                       'session_id', 'use_resolver[local]', 'use_scws']
        _alias_dict = {}
        _alias_dict['product_type'] = 'product'
        _alias_dict['query_type'] = 'product_type'
        # NOTE: continuation lines inside the template strings are deliberately
        # unindented so the generated snippet is valid top-level python
        _header = f'''from oda_api.api import DispatcherAPI
disp=DispatcherAPI(url='{url}', instrument='mock')'''
        _api_dict = {}
        for k in query_dict.keys():
            if k not in _skip_list_:
                if k in _alias_dict.keys():
                    n = _alias_dict[k]
                else:
                    n = k
                if query_dict[k] is not None:
                    _api_dict[n] = query_dict[k]
        python_compatible_par_dict_str = json.dumps(_api_dict, indent=4)
        python_compatible_par_dict_str = python_compatible_par_dict_str.replace('false', 'False')
        python_compatible_par_dict_str = python_compatible_par_dict_str.replace('true', 'True')
        _cmd_ = f'''{_header}
par_dict={python_compatible_par_dict_str}
data_collection = disp.get_product(**par_dict)
'''
        return _cmd_
def save_result(self, response_json):
fn = self.unique_response_json_fn
os.makedirs(os.path.dirname(fn), exist_ok=True)
with gzip.open(fn, "wt") as f:
json.dump(response_json, f)
logger.info('saved result in %s', fn)
def load_result(self):
fn = self.unique_response_json_fn
logger.info('trying to load result from %s', fn)
t0 = time.time()
with gzip.open(fn, 'rt') as f:
r = json.load(f)
logger.info('\033[32mmanaged to load result\033[0m from %s in %.2f seconds', fn, time.time() - t0)
return r
@property
def unique_response_json_fn(self):
request_hash = oda_api.misc_helpers.make_hash(self.set_api_code(self.parameters_dict))
return pathlib.Path(os.getenv('ODA_CACHE', pathlib.Path(os.getenv('HOME')) / ".cache/oda-api")) / f"cache/oda_api_data_collection_{request_hash}.json.gz"
def __repr__(self):
    """Short identification: class name plus the dispatcher URL."""
    return "[ {}: {} ]".format(type(self).__name__, self.url)
class DataCollection(object):
    """
    Container for the data products returned by a dispatcher request.

    Products are stored in order in ``_p_list``; a sanitized, unique name for
    each product is stored in ``_n_list`` and also set as an attribute on the
    collection, so products are reachable as ``collection.<prod_name>``.
    """

    def __init__(self, data_list, add_meta_to_name=('src_name', 'product'),
                 instrument=None, product=None, request_job_id=None):
        """
        :param data_list: iterable of data products
        :param add_meta_to_name: meta-data keys whose values are appended to
            each product name. Tuple default avoids the shared-mutable-default
            pitfall; callers passing a list still work (it is only iterated).
        :param instrument: fallback name component for unnamed products
        :param product: preferred fallback name component for unnamed products
        :param request_job_id: job id of the dispatcher request that produced the data
        """
        self._p_list = []
        self._n_list = []
        self.request_job_id = request_job_id
        for ID, data in enumerate(data_list):
            name = ''
            if hasattr(data, 'name'):
                name = data.name
            if name is None or name.strip() == '':
                # unnamed product: fall back to product, then instrument, then a stub
                if product is not None:
                    name = '%s' % product
                elif instrument is not None:
                    name = '%s' % instrument
                else:
                    name = 'prod'
            # positional suffix keeps attribute names unique within the collection
            name = '%s_%d' % (name, ID)
            name, var_name = self._build_prod_name(
                data, name, add_meta_to_name)
            setattr(self, var_name, data)
            self._p_list.append(data)
            self._n_list.append(var_name)

    def show(self):
        """Print a one-line summary (ID, product name, meta-data) for each product."""
        for ID, prod_name in enumerate(self._n_list):
            if hasattr(self._p_list[ID], 'meta_data'):
                meta_data = self._p_list[ID].meta_data
            else:
                meta_data = ''
            print('ID=%s prod_name=%s' %
                  (ID, prod_name), ' meta_data:', meta_data)
            print()

    def as_list(self):
        """Return the same per-product summary as :meth:`show`, as a list of dicts."""
        L = []
        for ID, prod_name in enumerate(self._n_list):
            if hasattr(self._p_list[ID], 'meta_data'):
                meta_data = self._p_list[ID].meta_data
            else:
                meta_data = ''
            L.append({
                # NOTE: the trailing colon in the 'meta_data:' key is odd but
                # kept byte-identical for backward compatibility with callers
                'ID': ID, 'prod_name': prod_name, 'meta_data:': meta_data
            })
        return L

    def _build_prod_name(self, prod, name, add_meta_to_name):
        """Append selected meta-data values to ``name``; return (name, sanitized_var_name)."""
        for kw in add_meta_to_name:
            if hasattr(prod, 'meta_data'):
                if kw in prod.meta_data:
                    s = prod.meta_data[kw].replace(' ', '')
                    if s.strip() != '':
                        name += '_' + s.strip()
        return name, oda_api.misc_helpers.clean_var_name(name)

    def save_all_data(self, prenpend_name=None):
        """
        Write each product to ``<name>.fits`` via its ``write_fits_file`` method.

        :param prenpend_name: optional prefix for every file name
            (parameter name kept as-is, typo included, for backward compatibility)
        """
        for pname, prod in zip(self._n_list, self._p_list):
            if prenpend_name is not None:
                file_name = prenpend_name + '_' + pname
            else:
                file_name = pname
            file_name = file_name + '.fits'
            prod.write_fits_file(file_name)

    def save(self, file_name):
        """Pickle the whole collection to ``file_name``."""
        # context manager fixes the file-handle leak of the original
        # pickle.dump(self, open(file_name, 'wb'), ...) form
        with open(file_name, 'wb') as f:
            pickle.dump(self, f, protocol=pickle.HIGHEST_PROTOCOL)

    def new_from_metadata(self, key, val):
        """
        Return a new DataCollection containing the products with
        ``meta_data[key] == val``, or None when nothing matches
        (None return kept for backward compatibility).
        """
        dc = None
        _l = [p for p in self._p_list if p.meta_data[key] == val]
        if _l != []:
            dc = DataCollection(_l)
        return dc

    @classmethod
    def from_response_json(cls, res_json, instrument, product):
        """
        Build a DataCollection from a dispatcher response json.

        Each recognized key in ``res_json['products']`` is decoded with the
        matching product class. The gw products use gwpy names from the
        module-level guarded import, so they raise NameError if gwpy is absent.
        """
        data = []
        products = res_json['products']
        if 'numpy_data_product' in products:
            data.append(NumpyDataProduct.decode(
                products['numpy_data_product']))
        elif 'numpy_data_product_list' in products:
            data.extend([NumpyDataProduct.decode(d)
                         for d in products['numpy_data_product_list']])
        if 'binary_data_product_list' in products:
            try:
                data.extend([BinaryProduct.decode(d)
                             for d in products['binary_data_product_list']])
            except Exception:
                # legacy responses need the older BinaryData decoder;
                # narrowed from a bare except, which also swallowed
                # SystemExit/KeyboardInterrupt
                data.extend([BinaryData().decode(d)
                             for d in products['binary_data_product_list']])
        if 'catalog' in products:
            data.append(ApiCatalog(
                products['catalog'], name='dispatcher_catalog'))
        if 'astropy_table_product_ascii_list' in products:
            data.extend([ODAAstropyTable.decode(table_text, use_binary=False)
                         for table_text in products['astropy_table_product_ascii_list']])
        if 'astropy_table_product_binary_list' in products:
            data.extend([ODAAstropyTable.decode(table_binary, use_binary=True)
                         for table_binary in products['astropy_table_product_binary_list']])
        if 'binary_image_product_list' in products:
            data.extend([PictureProduct.decode(bin_image_data)
                         for bin_image_data in products['binary_image_product_list']])
        if 'text_product_list' in products:
            try:
                data.extend([TextLikeProduct.decode(text_data)
                             for text_data in products['text_product_list']])
            except (JSONDecodeError, KeyError):
                # not decodable as TextLikeProduct: keep the raw entries
                data.extend([text_data
                             for text_data in products['text_product_list']])
        if 'gw_strain_product_list' in products:
            data.extend([TimeSeries(strain_data['value'],
                                    name=strain_data['name'],
                                    t0=strain_data['t0'],
                                    dt=strain_data['dt'])
                         for strain_data in products['gw_strain_product_list']])
        if 'gw_spectrogram_product' in products:
            sgram = products['gw_spectrogram_product']
            data.append(Spectrogram(sgram['value'],
                                    name='Spectrogram',
                                    unit='s',
                                    t0=sgram['x0'],
                                    dt=sgram['dx'],
                                    frequencies=sgram['yindex']))
        if 'gw_skymap_product' in products:
            skmap = products['gw_skymap_product']
            for event in skmap['skymaps'].keys():
                data.append(NumpyDataProduct.decode(skmap['skymaps'][event]))
            if 'contours' in skmap:
                data.append(GWContoursDataProduct(skmap['contours']))
        if 'job_id' not in res_json['job_monitor']:
            # TODO use the incident-report endpoint from the dispatcher (https://github.com/oda-hub/dispatcher-app/issues/393)
            logger.warning(f"job_monitor response json does not contain job_id: {res_json['job_monitor']}")
        request_job_id = res_json['job_monitor'].get('job_id', None)
        d = cls(data, instrument=instrument, product=product, request_job_id=request_job_id)
        for p in d._p_list:
            # some product classes expose 'meta' instead of 'meta_data': normalize
            if hasattr(p, 'meta_data') is False and hasattr(p, 'meta') is True:
                p.meta_data = p.meta
        return d
class ProgressReporter(object):
    """
    Reports task progress to the end user.

    The callback target is read once, at construction time, from the
    ``.oda_api_callback`` file in the current working directory; it may be
    an ``http(s)://`` URL or a ``file://`` path.
    """

    def __init__(self):
        self._callback = None
        callback_file = ".oda_api_callback"  # perhaps it would be better to define this constant in a common lib
        if not os.path.isfile(callback_file):
            return
        with open(callback_file, 'r') as file:
            self._callback = file.read().strip()

    @property
    def enabled(self):
        # True when a callback target was found at construction time
        return self._callback is not None

    def report_progress(self, stage: Union[str, None] = None, progress: float = 50., progress_max: float = 100.,
                        substage: Union[str, None] = None, subprogress: Union[float, None] = None,
                        subprogress_max: float = 100., message: Union[str, None] = None):
        """
        Report progress via callback URL

        :param stage: current stage description string
        :param progress: current stage progress
        :param progress_max: maximal progress value
        :param substage: current substage description string
        :param subprogress: current substage progress
        :param subprogress_max: maximal substage progress value
        :param message: message to pass
        :raises NotImplementedError: if the callback is neither ``file://`` nor ``http(s)://``
        """
        # check enabled first: no point building the payload when there is no callback
        if not self.enabled:
            logger.info('no callback registered, skipping')
            return
        callback_payload = dict(stage=stage, progress=progress, progress_max=progress_max, substage=substage,
                                subprogress=subprogress, subprogress_max=subprogress_max, message=message)
        # drop unset entries so the callback only receives explicit values
        callback_payload = {k: v for k, v in callback_payload.items() if v is not None}
        callback_payload['action'] = 'progress'
        logger.info('will perform callback: %s', self._callback)
        if re.match('^file://', self._callback):
            with open(self._callback.replace('file://', ''), "w") as f:
                json.dump(callback_payload, f)
            logger.info('stored callback in a file %s', self._callback)
        elif re.match('^https?://', self._callback):
            # NOTE(review): no timeout is set, so a stuck callback server blocks
            # this process indefinitely -- consider requests.get(..., timeout=...)
            r = requests.get(self._callback, params=callback_payload)
            logger.info('callback %s returns %s : %s', self._callback, r, r.text)
        else:
            raise NotImplementedError