"""Module for funneling messages from one comm to another."""
import os
import numpy as np
import importlib
import threading
from cis_interface import backwards
from cis_interface.communication import new_comm
from cis_interface.drivers.Driver import Driver
[docs]class ConnectionDriver(Driver):
r"""Class that continuously passes messages from one comm to another.
Args:
name (str): Name that should be used to set names of input/output comms.
icomm_kws (dict, optional): Keyword arguments for the input communicator.
ocomm_kws (dict, optional): Keyword arguments for the output communicator.
translator (str, func, optional): Function or string specifying function
that should be used to translate messages from the input communicator
before passing them to the output communicator. If a string, the
format should be "<package.module>:<function>" so that <function>
can be imported from <package>. Defaults to None and messages are
passed directly.
timeout_send_1st (float, optional): Time in seconds that should be
waited before giving up on the first send. Defaults to self.timeout.
single_use (bool, optional): If True, the driver will be stopped after
one loop. Defaults to False.
**kwargs: Additional keyword arguments are passed to the parent class.
Attributes:
icomm_kws (dict): Keyword arguments for the input communicator.
ocomm_kws (dict): Keyword arguments for the output communicator.
icomm (CommBase): Input communicator.
ocomm (CommBase): Output communicator.
nrecv (int): Number of messages received.
nproc (int): Number of messages processed.
nsent (int): Number of messages sent.
nskip (int): Number of messages skipped.
state (str): Descriptor of last action taken.
translator (func): Function that will be used to translate messages from
the input communicator before passing them to the output communicator.
timeout_send_1st (float): Time in seconds that should be waited before
giving up on the first send.
single_use (bool): If True, the driver will be stopped after one
loop.
"""
def __init__(self, name, icomm_kws=None, ocomm_kws=None,
             translator=None, timeout_send_1st=None, single_use=False,
             **kwargs):
    r"""Create (but do not open) the input and output communicators.

    Args:
        name (str): Name passed to the parent class and used as the default
            name for both comms when their keywords do not contain 'name'.
        icomm_kws (dict, optional): Keyword arguments for the input
            communicator. The dictionary is copied, not modified.
        ocomm_kws (dict, optional): Keyword arguments for the output
            communicator. The dictionary is copied, not modified.
        translator (str, func, optional): Callable applied to each message,
            or a string of the form "<package.module>:<function>" naming
            one to import. Defaults to None (messages passed unchanged).
        timeout_send_1st (float, optional): Time in seconds to keep trying
            the first send. Defaults to self.timeout.
        single_use (bool, optional): If True, the driver is only valid
            until it has been used once. Defaults to False.
        **kwargs: Additional keyword arguments are passed to the parent
            class.

    Raises:
        ValueError: If a translator string does not contain exactly one
            ':' or if the resolved translator is not callable.
        AttributeError: If the named module lacks the named function.

    """
    super(ConnectionDriver, self).__init__(name, **kwargs)
    # Copy the keyword dictionaries: entries are added and 'name' is popped
    # below, which must not leak back into dictionaries owned by the caller.
    icomm_kws = dict() if icomm_kws is None else dict(icomm_kws)
    ocomm_kws = dict() if ocomm_kws is None else dict(ocomm_kws)
    # Translator
    if isinstance(translator, str):
        pkg_mod = translator.split(':')
        if len(pkg_mod) != 2:
            raise ValueError("Could not parse translator string: %s" % translator)
        mod, fun = pkg_mod
        modobj = importlib.import_module(mod)
        if not hasattr(modobj, fun):
            raise AttributeError("Module %s has no funciton %s" % (
                modobj, fun))
        translator = getattr(modobj, fun)
    if (translator is not None) and (not callable(translator)):
        raise ValueError("Translator %s not callable." % translator)
    # Input communicator
    self.debug("Creating input comm")
    icomm_kws['direction'] = 'recv'
    icomm_kws['dont_open'] = True
    icomm_kws['reverse_names'] = True
    icomm_kws['close_on_eof_recv'] = False
    icomm_name = icomm_kws.pop('name', name)
    self.icomm = new_comm(icomm_name, **icomm_kws)
    self.icomm_kws = icomm_kws
    self.env[self.icomm.name] = self.icomm.address
    # Output communicator
    self.debug("Creating output comm")
    ocomm_kws['direction'] = 'send'
    ocomm_kws['dont_open'] = True
    ocomm_kws['reverse_names'] = True
    ocomm_name = ocomm_kws.pop('name', name)
    try:
        self.ocomm = new_comm(ocomm_name, **ocomm_kws)
    except BaseException:
        # Do not leave the already created input comm dangling.
        self.icomm.close()
        raise
    self.ocomm_kws = ocomm_kws
    self.env[self.ocomm.name] = self.ocomm.address
    # Attributes
    self._is_input = ('Input' in str(self.__class__))
    self._is_output = ('Output' in str(self.__class__))
    self._eof_sent = False
    if timeout_send_1st is None:
        timeout_send_1st = self.timeout
    self.translator = translator
    self.single_use = single_use
    self.timeout_send_1st = timeout_send_1st
    self._first_send_done = False
    self._comm_opened = threading.Event()
    self._comm_closed = False
    self._used = False
    self._skip_after_loop = False
    self._model_exited = False
    self.nrecv = 0
    self.nproc = 0
    self.nsent = 0
    self.nskip = 0
    self.state = 'started'
    self.close_state = ''
    self.debug('')
    self.debug(80 * '=')
    self.debug('class = %s', self.__class__)
    self.debug(' input: name = %s, address = %s',
               self.icomm.name, self.icomm.address)
    self.debug(' output: name = %s, address = %s',
               self.ocomm.name, self.ocomm.address)
    self.debug(80 * '=')
def wait_for_route(self, timeout=None):
    r"""Wait until each received message has been sent or skipped.

    Args:
        timeout (float, optional): Passed to self.start_timeout. Defaults
            to None.

    Returns:
        bool: True if nrecv equals nsent + nskip when the wait ends.

    """
    clock = self.start_timeout(timeout)
    while True:  # pragma: debug
        if clock.is_out:
            break
        if self.nrecv == (self.nsent + self.nskip):
            break
        self.sleep()
    self.stop_timeout()
    return self.nrecv == (self.nsent + self.nskip)
@property
def is_valid(self):
    r"""bool: True if the parent class is valid, the comms are open, and
    the driver is not a single-use driver that has already been used."""
    with self.lock:
        parent_ok = super(ConnectionDriver, self).is_valid
        return (parent_ok and self.is_comm_open and
                not (self.single_use and self._used))
@property
def is_comm_open(self):
    r"""bool: True if both communicators report open and close_comm has
    not been called."""
    with self.lock:
        both_open = self.icomm.is_open and self.ocomm.is_open
        return both_open and not self._comm_closed
@property
def is_comm_closed(self):
    r"""bool: True if the input and output communicators both report
    closed."""
    with self.lock:
        input_closed = self.icomm.is_closed
        return input_closed and self.ocomm.is_closed
@property
def n_msg(self):
    r"""int: Value of n_msg_recv reported by the input communicator."""
    with self.lock:
        waiting = self.icomm.n_msg_recv
    return waiting
def open_comm(self):
    r"""Open the input communicator and then the output communicator.

    Nothing is opened if close_comm has already been called. If either
    open fails, close_comm is called and the error is re-raised.

    """
    self.debug('')
    with self.lock:
        aborted = self._comm_closed
        if aborted:
            self.debug('Aborted as comm closed')
        else:
            try:
                for comm_attr in ('icomm', 'ocomm'):
                    getattr(self, comm_attr).open()
            except BaseException:
                self.close_comm()
                raise
            self._comm_opened.set()
    if not aborted:
        self.debug('Returning')
def close_comm(self):
    r"""Close both communicators, attempting the second even if the first
    fails.

    The comms are marked closed and after_loop is flagged to be skipped
    before any close is attempted. An error from the input comm takes
    precedence over one from the output comm when re-raising.

    """
    self.debug('')
    with self.lock:
        self._comm_closed = True
        self._skip_after_loop = True
        # Collect one entry per comm so both closes are always attempted.
        failures = []
        for comm_attr in ('icomm', 'ocomm'):
            caught = None
            try:
                comm = getattr(self, comm_attr, None)
                if comm is not None:
                    comm.close()
            except BaseException as err:
                caught = err
            failures.append(caught)
        for caught in failures:
            if caught:
                raise caught
    self.debug('Returning')
def start(self):
    r"""Open the comms, wait for them to report open, then start the
    parent class.

    Raises:
        Exception: If the comms are still not open when the timeout
            expires.

    """
    self.open_comm()
    timer = self.start_timeout()
    while not (self.is_comm_open or timer.is_out):
        self.sleep()
    self.stop_timeout()
    if not self.is_comm_open:
        raise Exception("Connection never finished opening.")
    super(ConnectionDriver, self).start()
def graceful_stop(self, timeout=None, **kwargs):
    r"""Stop the driver, first waiting for the input comm to be empty.

    The input is drained, then routing of received messages is awaited,
    then the output is drained, each step using the same timeout.

    Args:
        timeout (float, optional): Max time that should be waited in each
            step. Defaults to None.
        **kwargs: Additional keyword arguments are passed to the parent
            class's graceful_stop method.

    """
    self.debug('')
    with self.lock:
        self.set_close_state('stop')
        self._skip_after_loop = True
    self.drain_input(timeout=timeout)
    self.wait_for_route(timeout=timeout)
    self.drain_output(timeout=timeout)
    # Forward the extra keywords as documented; they used to be dropped.
    super(ConnectionDriver, self).graceful_stop(**kwargs)
    self.debug('Returning')
def on_model_exit(self):
    r"""Drain the input and close comms once the model has exited.

    Drivers whose class name contains 'Input' close both comms; drivers
    whose class name contains 'Output' close only the input comm. The
    parent class's on_model_exit is always called last.

    """
    self.debug('')
    self.set_close_state('model exit')
    if self._is_input:
        self.drain_input(timeout=self.timeout)
        with self.lock:
            for comm in (self.icomm, self.ocomm):
                comm.close()
    if self._is_output:
        self.drain_input(timeout=self.timeout)
        with self.lock:
            self.icomm.close()
    super(ConnectionDriver, self).on_model_exit()
def do_terminate(self):
    r"""Record the 'terminate' close state, close both comms, and then
    run the parent class's termination."""
    parent = super(ConnectionDriver, self)
    self.debug('')
    self.set_close_state('terminate')
    self.close_comm()
    parent.do_terminate()
def cleanup(self):
    r"""Close both comms and then run the parent class's cleanup."""
    parent = super(ConnectionDriver, self)
    self.debug('')
    self.close_comm()
    parent.cleanup()
def printStatus(self, beg_msg='', end_msg=''):
    r"""Print a one line summary of the driver's counters and state.

    Arguments:
        beg_msg (str, optional): Additional message to print at beginning.
        end_msg (str, optional): Additional message to print at end.

    """
    label = self.__module__.split('.')[-1] + '(' + self.name + '): '
    # Each entry is a fixed-width, left-justified column.
    columns = [
        beg_msg,
        '%-50s' % label,
        '%-30s' % ('last action: ' + self.state),
        '%-15s' % (str(self.nrecv) + ' received, '),
        '%-15s' % (str(self.nproc) + ' processed, '),
        '%-15s' % (str(self.nskip) + ' skipped, '),
        '%-15s' % (str(self.nsent) + ' sent, '),
        '%-20s' % (str(self.icomm.n_msg) + ' ready to recv'),
        '%-20s' % (str(self.ocomm.n_msg) + ' ready to send'),
    ]
    with self.lock:
        if self.close_state:
            columns.append('%-30s' % ('close state: ' + self.close_state))
    columns.append(end_msg)
    print(''.join(columns))
def drain_output(self, timeout=None):
    r"""Wait for the output comm to finish sending queued messages.

    Returns as soon as the output comm is not open, reports zero for
    n_msg_send_drain, or the timeout expires.

    Args:
        timeout (float, optional): Passed to self.start_timeout. Defaults
            to None.

    """
    timer = self.start_timeout(timeout)
    while not timer.is_out:
        with self.lock:
            done = ((not self.ocomm.is_open) or
                    (self.ocomm.n_msg_send_drain == 0))
        if done:
            break
        self.sleep()
    self.stop_timeout()
def before_loop(self):
    r"""Actions to perform prior to sending messages.

    Opens the comms and sleeps once. On any error the failure is logged,
    the comms are closed, and the break flag is set so the loop does not
    run.

    """
    self.state = 'before loop'
    try:
        self.open_comm()
        self.sleep()  # Help ensure senders/receivers connected before messages
        self.debug('Running in %s, is_valid = %s', os.getcwd(), str(self.is_valid))
    except BaseException:  # pragma: debug
        self.exception('Could not prep for loop.')
        try:
            self.close_comm()
        finally:
            # close_comm re-raises comm errors; the loop must still be
            # told to stop in that case.
            self.set_break_flag()
def after_loop(self):
    r"""Actions to perform after the message loop ends.

    The input is drained; then, unless after-loop handling has been
    flagged to be skipped, the input comm is closed and (for drivers that
    are not single use) EOF is sent on the output comm. The output comm
    itself is left open here.

    """
    self.state = 'after loop'
    self.debug('')
    # Close input comm in case loop did not
    self.drain_input(timeout=False)
    with self.lock:
        if not self._skip_after_loop:
            self.icomm.close()
            # Send EOF in case the model didn't
            if not self.single_use:
                self.send_eof()
        else:
            self.debug("After loop skipped.")
def recv_message(self, **kwargs):
    r"""Receive one message from the input comm.

    Args:
        **kwargs: Additional keyword arguments are passed to the input
            comm's recv method. 'timeout' defaults to 0 if not given.

    Returns:
        str, bool: False if the input comm is closed; the result of on_eof
            if the EOF message was received; the message if the receive
            succeeded; otherwise the failed receive flag.

    """
    if 'timeout' not in kwargs:
        kwargs['timeout'] = 0
    with self.lock:
        if self.icomm.is_closed:
            return False
        success, received = self.icomm.recv(**kwargs)
    is_eof = (isinstance(received, backwards.bytes_type) and
              (received == self.icomm.eof_msg))
    if is_eof:
        return self.on_eof()
    return received if success else success
def on_eof(self):
    r"""Actions to take when EOF is received on the input comm.

    EOF is forwarded to the output comm, the remaining input is drained,
    and the input comm is closed with the close state recorded as 'eof'.

    Returns:
        str, bool: Value that should be returned by recv_message on EOF
            (always False here).

    """
    self.debug('EOF received')
    self.state = 'eof'
    with self.lock:
        self.send_eof()
    self.drain_input(timeout=False)
    with self.lock:
        self.set_close_state('eof')
        self.icomm.close()
    eof_result = False
    return eof_result
def on_message(self, msg):
    r"""Process a message before it is sent.

    If the output comm has a truthy _send_serializer attribute, its
    serializer is updated first. The translator, when set, is then applied.

    Args:
        msg (bytes, str): Message to be processed.

    Returns:
        bytes, str: The message itself if there is no translator, otherwise
            the translator's result.

    """
    if self.ocomm._send_serializer:
        self.update_serializer(msg)
    translate = self.translator
    return msg if translate is None else translate(msg)
def update_serializer(self, msg):
    r"""Update the serializer for the output comm based on input.

    Only acts when the output serializer's serializer_type is 0. The
    output comm then adopts the input comm's serializer, updated with the
    output serializer's previous info minus its 'stype' entry.

    Args:
        msg (bytes, str): Message being processed (not used here).

    """
    out_serializer = self.ocomm.serializer
    if out_serializer.serializer_type != 0:
        return
    carried_info = out_serializer.serializer_info
    carried_info.pop('stype')
    self.ocomm.serializer = self.icomm.serializer
    self.ocomm.serializer.update_serializer(**carried_info)
def _send_message(self, *args, **kwargs):
    r"""Send a single message on the output comm.

    Args:
        *args: Arguments are passed to the output comm send method.
        **kwargs: Keyword arguments are passed to the output comm send
            method.

    Returns:
        bool: False if the output comm is closed, otherwise the value
            returned by the output comm's send method.

    """
    with self.lock:
        if not self.ocomm.is_closed:
            return self.ocomm.send(*args, **kwargs)
    return False
def _send_1st_message(self, *args, **kwargs):
    r"""Send the first message, retrying until timeout_send_1st expires.

    Retries stop once a send succeeds, the timeout runs out, or the output
    comm is no longer open. The output comm's suppress_special_debug flag
    is set while retrying and cleared afterwards.

    Args:
        *args: Arguments are passed to the output comm send method.
        **kwargs: Keyword arguments are passed to the output comm send
            method.

    Returns:
        bool: Success or failure of send.

    """
    self.ocomm._multiple_first_send = False
    timer = self.start_timeout(self.timeout_send_1st)
    sent = self._send_message(*args, **kwargs)
    self.ocomm.suppress_special_debug = True
    if not sent:
        self.debug("1st send failed, will keep trying for %f s in silence.",
                   float(self.timeout_send_1st))
        while True:  # pragma: debug
            if timer.is_out or sent or (not self.ocomm.is_open):
                break
            sent = self._send_message(*args, **kwargs)
            if not sent:
                self.sleep()
    self.stop_timeout()
    self.ocomm.suppress_special_debug = False
    self._first_send_done = True
    if sent:
        self.debug("1st send succeded")
    else:
        self.error("1st send failed.")
    return sent
def send_eof(self):
    r"""Send the output comm's EOF message, at most once per driver.

    Returns:
        bool: False if EOF was already sent, otherwise the success or
            failure of the send.

    """
    with self.lock:
        already_sent = self._eof_sent
        if not already_sent:
            self._eof_sent = True
    if already_sent:
        return False
    self.debug('')
    return self.send_message(self.ocomm.eof_msg)
def send_message(self, *args, **kwargs):
    r"""Send a single message, marking the driver as used.

    The first send made through the driver goes through the retrying
    first-send path; later sends are attempted once.

    Args:
        *args: Arguments are passed to the output comm send method.
        **kwargs: Keyword arguments are passed to the output comm send
            method.

    Returns:
        bool: Success or failure of send.

    """
    self.debug('')
    with self.lock:
        self._used = True
    if self._first_send_done:
        sender = self._send_message
    else:
        sender = self._send_1st_message
    return sender(*args, **kwargs)
def set_close_state(self, state):
    r"""Record why the driver is closing; only the first reason is kept.

    Args:
        state (str): Close state to record if none has been set yet.

    """
    with self.lock:
        if self.close_state:
            return
        self.close_state = state
def run_loop(self):
    r"""Run one iteration of the driver loop: receive, process, send.

    The break flag is set (and a close state recorded) when the driver is
    no longer valid, when there are no more messages, or when processing
    or sending fails. An empty input message causes a sleep; an empty
    processed message is counted as skipped.

    """
    self.state = 'in loop'
    if not self.is_valid:
        self.debug("Breaking loop")
        self.set_close_state('invalid')
        self.set_break_flag()
        return
    # Receive a message
    self.state = 'receiving'
    incoming = self.recv_message()
    if incoming is False:
        self.debug('No more messages')
        self.set_break_flag()
        self.set_close_state('receiving')
        return
    in_empty = self.icomm.serializer.empty_msg
    if isinstance(incoming, type(in_empty)) and (incoming == in_empty):
        self.state = 'waiting'
        self.verbose_debug(':run: Waiting for next message.')
        self.sleep()
        return
    self.nrecv += 1
    self.state = 'received'
    source = self.icomm.address
    if isinstance(incoming, backwards.bytes_type):
        self.debug('Received message that is %d bytes from %s.',
                   len(incoming), source)
    elif isinstance(incoming, np.ndarray):
        self.debug('Received array with shape %s and data type %s from %s',
                   incoming.shape, incoming.dtype, source)
    else:
        self.debug('Received message of type %s from %s',
                   type(incoming), source)
    # Process message
    self.state = 'processing'
    outgoing = self.on_message(incoming)
    if outgoing is False:  # pragma: debug
        self.error('Could not process message.')
        self.set_break_flag()
        self.set_close_state('processing')
        return
    out_empty = self.ocomm.serializer.empty_msg
    if isinstance(outgoing, type(out_empty)) and (outgoing == out_empty):
        self.debug('Message skipped.')
        self.nskip += 1
        return
    self.nproc += 1
    self.state = 'processed'
    self.debug('Processed message.')
    # Send a message
    self.state = 'sending'
    if self.send_message(outgoing) is False:
        self.error('Could not send message.')
        self.set_break_flag()
        self.set_close_state('sending')
        return
    self.nsent += 1
    self.state = 'sent'
    self.debug('Sent message to %s.', self.ocomm.address)