in trl/environment/base_environment.py [0:0]
def run(self, queries, **rewards_kwargs):
    """
    Run the environment on a list of queries.

    Each query is prefixed with `self.prompt`, tokenized, and wrapped in a `TextHistory`. The histories are then
    advanced turn by turn (`generate` -> `tasks_end_check` -> `step` -> `tasks_end_check`) until every history
    reports `completed` or `self.max_turns` turns have run. Finally `compute_reward` is called on the histories.

    Args:
        queries (list[str]): A list of queries to run the model in the environment on.
        **rewards_kwargs: Extra keyword arguments forwarded unchanged to `self.compute_reward`.

    Returns:
        tuple: `(queries, responses, masks, rewards, histories)`. The first three are lists assembled from each
        history's `split_query_response_tokens()` result, `rewards` holds each history's `reward` attribute, and
        `histories` is the list of histories. When `queries` is empty, all five lists are empty.
    """
    turns = 0

    queries = [self.prompt + task for task in queries]
    # Tokenize each prompt separately and place the ids on the same device as the wrapped model.
    device = self.model.pretrained_model.device if queries else None
    queries_tokens = [self.tokenizer(query, return_tensors="pt").input_ids[0].to(device) for query in queries]

    histories = [TextHistory(q, qt, system=True) for q, qt in zip(queries, queries_tokens)]

    while any(not history.completed for history in histories) and turns < self.max_turns:
        histories = self.generate(histories)
        histories = self.tasks_end_check(histories)
        # TODO: make this parallel rather than for-loop
        for i, history in enumerate(histories):
            histories[i] = self.step(history)
        histories = self.tasks_end_check(histories, model_turn=False)
        turns += 1
    self.compute_reward(histories, **rewards_kwargs)

    # convert a list of (q, r, m) tuples to lists of all qs, rs, and ms respectively
    splits = [history.split_query_response_tokens() for history in histories]
    if splits:
        queries, responses, masks = map(list, zip(*splits))
    else:
        # zip(*[]) yields nothing, so the three-way unpack would raise ValueError on an empty batch.
        queries, responses, masks = [], [], []

    rewards = [history.reward for history in histories]
    return queries, responses, masks, rewards, histories