""" BinaryTree inter-engine communication class use from bintree_script.py Provides parallel [all]reduce functionality """ from __future__ import print_function import cPickle as pickle import re import socket import uuid import zmq from IPython.parallel.util import disambiguate_url #---------------------------------------------------------------------------- # bintree-related construction/printing helpers #---------------------------------------------------------------------------- def bintree(ids, parent=None): """construct {child:parent} dict representation of a binary tree keys are the nodes in the tree, and values are the parent of each node. The root node has parent `parent`, default: None. >>> tree = bintree(range(7)) >>> tree {0: None, 1: 0, 2: 1, 3: 1, 4: 0, 5: 4, 6: 4} >>> print_bintree(tree) 0 1 2 3 4 5 6 """ parents = {} n = len(ids) if n == 0: return parents root = ids[0] parents[root] = parent if len(ids) == 1: return parents else: ids = ids[1:] n = len(ids) left = bintree(ids[:n/2], parent=root) right = bintree(ids[n/2:], parent=root) parents.update(left) parents.update(right) return parents def reverse_bintree(parents): """construct {parent:[children]} dict from {child:parent} keys are the nodes in the tree, and values are the lists of children of that node in the tree. reverse_tree[None] is the root node >>> tree = bintree(range(7)) >>> reverse_bintree(tree) {None: 0, 0: [1, 4], 4: [5, 6], 1: [2, 3]} """ children = {} for child,parent in parents.iteritems(): if parent is None: children[None] = child continue elif parent not in children: children[parent] = [] children[parent].append(child) return children def depth(n, tree): """get depth of an element in the tree""" d = 0 parent = tree[n] while parent is not None: d += 1 parent = tree[parent] return d def print_bintree(tree, indent=' '): """print a binary tree""" for n in sorted(tree.keys()): print("%s%s" % (indent * depth(n,tree), n)) #---------------------------------------------------------------------------- # Communicator class for a binary-tree map #---------------------------------------------------------------------------- ip_pat = re.compile(r'^\d+\.\d+\.\d+\.\d+$') def disambiguate_dns_url(url, location): """accept either IP address or dns name, and return IP""" if not ip_pat.match(location): location = socket.gethostbyname(location) return disambiguate_url(url, location) class BinaryTreeCommunicator(object): id = None pub = None sub = None downstream = None upstream = None pub_url = None tree_url = None def __init__(self, id, interface='tcp://*', root=False): self.id = id self.root = root # create context and sockets self._ctx = zmq.Context() if root: self.pub = self._ctx.socket(zmq.PUB) else: self.sub = self._ctx.socket(zmq.SUB) self.sub.setsockopt(zmq.SUBSCRIBE, b'') self.downstream = self._ctx.socket(zmq.PULL) self.upstream = self._ctx.socket(zmq.PUSH) # bind to ports interface_f = interface + ":%i" if self.root: pub_port = self.pub.bind_to_random_port(interface) self.pub_url = interface_f % pub_port tree_port = self.downstream.bind_to_random_port(interface) self.tree_url = interface_f % tree_port self.downstream_poller = zmq.Poller() self.downstream_poller.register(self.downstream, zmq.POLLIN) # guess first public IP from socket self.location = socket.gethostbyname_ex(socket.gethostname())[-1][0] def __del__(self): self.downstream.close() self.upstream.close() if self.root: self.pub.close() else: self.sub.close() self._ctx.term() @property def info(self): """return the connection info for this object's sockets.""" return (self.tree_url, self.location) def connect(self, peers, btree, pub_url, root_id=0): """connect to peers. `peers` will be a dict of 4-tuples, keyed by name. {peer : (ident, addr, pub_addr, location)} where peer is the name, ident is the XREP identity, addr,pub_addr are the """ # count the number of children we have self.nchildren = btree.values().count(self.id) if self.root: return # root only binds root_location = peers[root_id][-1] self.sub.connect(disambiguate_dns_url(pub_url, root_location)) parent = btree[self.id] tree_url, location = peers[parent] self.upstream.connect(disambiguate_dns_url(tree_url, location)) def serialize(self, obj): """serialize objects. Must return list of sendable buffers. Can be extended for more efficient/noncopying serialization of numpy arrays, etc. """ return [pickle.dumps(obj)] def unserialize(self, msg): """inverse of serialize""" return pickle.loads(msg[0]) def publish(self, value): assert self.root self.pub.send_multipart(self.serialize(value)) def consume(self): assert not self.root return self.unserialize(self.sub.recv_multipart()) def send_upstream(self, value, flags=0): assert not self.root self.upstream.send_multipart(self.serialize(value), flags=flags|zmq.NOBLOCK) def recv_downstream(self, flags=0, timeout=2000.): # wait for a message, so we won't block if there was a bug self.downstream_poller.poll(timeout) msg = self.downstream.recv_multipart(zmq.NOBLOCK|flags) return self.unserialize(msg) def reduce(self, f, value, flat=True, all=False): """parallel reduce on binary tree if flat: value is an entry in the sequence else: value is a list of entries in the sequence if all: broadcast final result to all nodes else: only root gets final result """ if not flat: value = reduce(f, value) for i in range(self.nchildren): value = f(value, self.recv_downstream()) if not self.root: self.send_upstream(value) if all: if self.root: self.publish(value) else: value = self.consume() return value def allreduce(self, f, value, flat=True): """parallel reduce followed by broadcast of the result""" return self.reduce(f, value, flat=flat, all=True)