in table_bert/dataset.py [0:0]
def __example_worker_process_zmq(tokenizer, db):
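    """Worker process: pull serialized examples over ZeroMQ, tokenize and
    validate them, and batch-write the survivors into the Redis cache."""

    # Job channel: a PULL socket fed by the producer's PUSH socket.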
    context = zmq.Context()
    job_receiver = context.socket(zmq.PULL)
    job_receiver.connect("tcp://127.0.0.1:5557")
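
    # Control channel: a SUB socket that receives the producer's broadcast
    # announcing that no more jobs will arrive.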
    controller = context.socket(zmq.SUB)
    controller.connect("tcp://127.0.0.1:5558")
    controller.setsockopt(zmq.SUBSCRIBE, b"")
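
    # Watch both sockets with one poller so job handling and the shutdown
    # signal share a single loop.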
    poller = zmq.Poller()
    poller.register(job_receiver, zmq.POLLIN)
    poller.register(controller, zmq.POLLIN)

    cache_client = redis.Redis(host='localhost', port=6379, db=0)
    buffer_size = 20000
    buffer = []
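
    # Flush the in-memory buffer to Redis in a single round trip.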
    def _add_to_cache():
        if buffer:
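            # Atomically reserve a contiguous block of indices from the
            # shared counter so concurrent workers never collide on keys.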
            with db._cur_index.get_lock():
                index_end = db._cur_index.value + len(buffer)
                db._cur_index.value = index_end
            index_start = index_end - len(buffer)
            values = {str(i): val for i, val in zip(range(index_start, index_end), buffer)}
            cache_client.mset(values)
            del buffer[:]

    cnt = 0
    can_exit = False

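    # Main loop: poll with a 2-second timeout; an empty poll after the exit
    # signal means the job queue has been drained.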
    while True:
        triggered = False
        socks = dict(poller.poll(timeout=2000))

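        # A job arrived: deserialize it, keep it only if it passes
        # validation, and flush once the buffer is full.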
        if socks.get(job_receiver) == zmq.POLLIN:
            triggered = True
            job = job_receiver.recv_string()
            if job:
                cnt += 1
                example = Example.from_dict(ujson.loads(job), tokenizer, suffix=None)
                if TableDatabase.is_valid_example(example):
                    data = example.serialize()
                    buffer.append(msgpack.packb(data, use_bin_type=True))
                    if len(buffer) >= buffer_size:
                        _add_to_cache()
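
        # The producer announced end-of-stream; keep draining jobs and only
        # exit after a subsequent poll comes back empty.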
        if socks.get(controller) == zmq.POLLIN:
            triggered = True
            print(controller.recv_string())
            can_exit = True

        # Poll timed out with nothing pending: safe to exit now that the
        # shutdown notice has been received.
        if not socks and can_exit:
            print('Processor exit...')
            break
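
        # Defensive trace: the poller reported events that neither branch
        # above consumed.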
        if socks and not triggered:
            print(socks)

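    # Write out any partially filled buffer, then tear down ZeroMQ cleanly.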
    _add_to_cache()
    job_receiver.close()
    controller.close()
    context.destroy()
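

# For reference, a minimal sketch of the matching producer side, assuming the
# same two endpoints and the module-level zmq/ujson imports; the function
# name, the `example_dicts` iterable, and the 'exit' payload are illustrative
# and not taken from this file:
def _push_examples_zmq(example_dicts):
    import time

    context = zmq.Context()
    # PUSH distributes jobs round-robin across all connected PULL workers.
    sender = context.socket(zmq.PUSH)
    sender.bind("tcp://127.0.0.1:5557")
    # PUB broadcasts the end-of-stream notice to every SUB worker at once.
    controller = context.socket(zmq.PUB)
    controller.bind("tcp://127.0.0.1:5558")

    for d in example_dicts:
        sender.send_string(ujson.dumps(d))

    # Give workers time to drain the queue, then broadcast shutdown; the
    # worker above exits only after a quiet poll that follows this message.
    time.sleep(5)
    controller.send_string('exit')

    sender.close()
    controller.close()
    context.destroy()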