in reagent/ope/estimators/sequential_estimators.py [0:0]
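# Relies on module-level imports: logging, random, time, itertools.zip_longest, torch.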
def evaluate(self, input: RLEstimatorInput, **kwargs) -> EstimatorResults:
    assert input.value_function is not None
    logging.info(f"{self}: start evaluating")
    stime = time.process_time()
    results = EstimatorResults()
    num_resamples = kwargs.get("num_resamples", 200)
    loss_threshold = kwargs.get("loss_threshold", 0.00001)
    lr = kwargs.get("lr", 0.0001)
    logging.info(
        f" params: num_resamples[{num_resamples}], "
        f"loss_threshold[{loss_threshold}], "
        f"lr[{lr}]"
    )
    # Compute MAGIC estimate
    n = len(input.log)
    horizon = max(len(traj) for traj in input.log)
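    # Per-trajectory, per-step importance weights of the target policy relative
    # to the logging policy (trajectories are padded to a common horizon via
    # zip_longest).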
    ws = self._calc_weights(
        n, horizon, zip_longest(*input.log), input.target_policy
    )
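    # last_ws holds the weights shifted back by one step (1/n at step 0); it
    # multiplies the state-value baseline in the doubly-robust correction term.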
    last_ws = torch.zeros((n, horizon), device=self._device)
    last_ws[:, 0] = 1.0 / n
    last_ws[:, 1:] = ws[:, :-1]
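    # Discount vector [1, gamma, gamma^2, ...] over the horizon.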
    discount = torch.full((horizon,), input.gamma, device=self._device)
    discount[0] = 1.0
    discount = discount.cumprod(0)
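    # Fill per-step rewards (rs), state values (vs) and Q-values (qs) from the
    # logged transitions; input.value_function is queried for both V(s) and Q(s, a).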
    rs = torch.zeros((n, horizon))
    vs = torch.zeros((n, horizon))
    qs = torch.zeros((n, horizon))
    for j, ts in enumerate(zip_longest(*input.log)):
        for i, t in enumerate(ts):
            if t is not None and t.action is not None:
                qs[i, j] = input.value_function(t.last_state, t.action)
                vs[i, j] = input.value_function(t.last_state)
                rs[i, j] = t.reward
    vs = vs.to(device=self._device)
    qs = qs.to(device=self._device)
    rs = rs.to(device=self._device)
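    # Per-step weighted doubly-robust (WDR) partial returns, accumulated over
    # time; summing the last column over trajectories gives the full-horizon
    # WDR estimate.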
    wdrs = ((ws * (rs - qs) + last_ws * vs) * discount).cumsum(1)
    wdr = wdrs[:, -1].sum(0)
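    # g_j blends the WDR return truncated at step j with the model's discounted
    # value of the next state, giving one candidate estimate per truncation length.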
    next_vs = torch.zeros((n, horizon), device=self._device)
    next_vs[:, :-1] = vs[:, 1:]
    gs = wdrs + ws * next_vs * discount
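    # Omega: covariance term for the candidate estimates, computed from the
    # per-trajectory deviations of g (scaled by n / (n - 1)).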
    gs_normal = gs.sub(torch.mean(gs, 0))
    assert n > 1
    omega = (n / (n - 1.0)) * torch.einsum("ij,ik->jk", gs_normal, gs_normal)
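    # Bootstrap the WDR estimate by resampling trajectories with replacement.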
    resample_wdrs = torch.zeros((num_resamples,))
    for i in range(num_resamples):
        samples = random.choices(range(n), k=n)
        sws = ws[samples, :]
        last_sws = last_ws[samples, :]
        srs = rs[samples, :]
        svs = vs[samples, :]
        sqs = qs[samples, :]
        resample_wdrs[i] = (
            ((sws * (srs - sqs) + last_sws * svs).sum(0) * discount).sum().item()
        )
    resample_wdrs, _ = resample_wdrs.to(device=self._device).sort(0)
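    # 5% / 95% bootstrap quantiles, widened so the interval always contains the
    # WDR point estimate.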
    lb = torch.min(wdr, resample_wdrs[int(round(0.05 * num_resamples))])
    ub = torch.max(wdr, resample_wdrs[int(round(0.95 * num_resamples)) - 1])
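    # Bias proxy per truncation length: how far each candidate estimate
    # (gs summed over trajectories) falls outside [lb, ub], zero if inside.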
    b = torch.tensor(
        list(
            map(
                lambda a: a - ub if a > ub else (a - lb if a < lb else 0.0),
                gs.sum(0),
            )
        ),
        device=self._device,
    )
    b.unsqueeze_(0)
    bb = b * b.t()
    cov = omega + bb
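    # MAGIC blends the candidate estimates with weights x on the simplex, chosen
    # to minimize x (Omega + b b^T) x^T, i.e. variance plus squared bias.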
    # x = torch.rand((1, horizon), device=self.device, requires_grad=True)
    x = torch.zeros((1, horizon), device=self._device, requires_grad=True)
    # use SGD to find the x that minimizes y; softmax keeps the blend weights
    # on the probability simplex
    optimizer = torch.optim.SGD([x], lr=lr)
    last_y = 0.0
    for i in range(100):
        weights = torch.nn.functional.softmax(x, dim=1)
        y = torch.mm(torch.mm(weights, cov), weights.t())
        if abs(y.item() - last_y) < loss_threshold:
            logging.info(f"{i}: {last_y} -> {y.item()}")
            break
        last_y = y.item()
        optimizer.zero_grad()
        y.backward()
        optimizer.step()
    x = torch.nn.functional.softmax(x, dim=1)
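    # Final MAGIC estimate: the x-weighted combination of the candidate estimates.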
    estimate = torch.mm(x, gs.sum(0, keepdim=True).t()).cpu().item()
    results.append(
        EstimatorResult(
            self._log_reward(input.gamma, input.log),
            estimate,
            None
            if input.ground_truth is None
            else self._estimate_value(input.gamma, input.log, input.ground_truth),
        )
    )
    logging.info(
        f"{self}: finished evaluating["
        f"process_time={time.process_time() - stime}]"
    )
    return results