in examples/baseball.py [0:0]
def main(args):
    """Fit the four baseball models with NUTS and log a report for each one.

    The dataset is read from ``DATA_URL`` and split with ``train_test_split``.
    Then ``fully_pooled``, ``not_pooled``, ``partially_pooled`` and
    ``partially_pooled_with_logit`` are each sampled with MCMC. For every
    model the summary table of one site and the number of divergent
    transitions are logged, and the collected samples are handed to
    ``sample_posterior_predictive`` and ``evaluate_pointwise_pred_density``.

    :param args: object exposing the ``num_samples``, ``warmup_steps``,
        ``num_chains`` and ``jit`` attributes read below (presumably a parsed
        ``argparse.Namespace`` -- confirm against the caller).
    """
    # The separator is passed by keyword: recent pandas releases accept only
    # the file path positionally, so ``read_csv(url, "\t")`` raises TypeError.
    baseball_dataset = pd.read_csv(DATA_URL, sep="\t")
    train, _, player_names = train_test_split(baseball_dataset)
    at_bats, hits = train[:, 0], train[:, 1]
    logging.info("Original Dataset:")
    logging.info(baseball_dataset)

    def run_and_report(mcmc, model, title, rule, site, site_label, transforms=None):
        """Run ``mcmc`` on the training columns, log its report and evaluate ``model``."""
        mcmc.run(at_bats, hits)
        samples = mcmc.get_samples()
        logging.info(title)
        logging.info(rule)
        logging.info(site_label)
        # Forward ``transforms`` only when one was supplied, so that
        # get_summary_table falls back to its own default otherwise.
        extra = {} if transforms is None else {"transforms": transforms}
        summary = get_summary_table(mcmc.get_samples(group_by_chain=True),
                                    sites=[site],
                                    player_names=player_names,
                                    diagnostics=True,
                                    group_by_chain=True,
                                    **extra)
        logging.info(summary[site])
        # Total the lengths of every entry in the "divergences" diagnostics mapping.
        num_divergences = sum(map(len, mcmc.diagnostics()["divergences"].values()))
        logging.info("\nNumber of divergent transitions: {}\n".format(num_divergences))
        sample_posterior_predictive(model, samples, baseball_dataset)
        evaluate_pointwise_pred_density(model, samples, baseball_dataset)

    def make_mcmc(kernel, **extra):
        """Build an MCMC runner with the sampling settings taken from ``args``."""
        return MCMC(kernel,
                    num_samples=args.num_samples,
                    warmup_steps=args.warmup_steps,
                    num_chains=args.num_chains,
                    **extra)

    # (1) Full Pooling Model
    # In this model, we illustrate how to use MCMC with general potential_fn.
    init_params, potential_fn, transforms, _ = initialize_model(
        fully_pooled, model_args=(at_bats, hits), num_chains=args.num_chains,
        jit_compile=args.jit, skip_jit_warnings=True)
    run_and_report(
        make_mcmc(NUTS(potential_fn=potential_fn),
                  initial_params=init_params,
                  transforms=transforms),
        fully_pooled,
        "\nModel: Fully Pooled",
        "===================",
        "phi",
        "\nphi:")

    # (2) No Pooling Model
    run_and_report(
        make_mcmc(NUTS(not_pooled, jit_compile=args.jit, ignore_jit_warnings=True)),
        not_pooled,
        "\nModel: Not Pooled",
        "=================",
        "phi",
        "\nphi:")

    # (3) Partially Pooled Model
    run_and_report(
        make_mcmc(NUTS(partially_pooled, jit_compile=args.jit, ignore_jit_warnings=True)),
        partially_pooled,
        "\nModel: Partially Pooled",
        "=======================",
        "phi",
        "\nphi:")

    # (4) Partially Pooled with Logit Model
    # The summary for "alpha" is reported after applying torch.sigmoid.
    run_and_report(
        make_mcmc(NUTS(partially_pooled_with_logit, jit_compile=args.jit,
                       ignore_jit_warnings=True)),
        partially_pooled_with_logit,
        "\nModel: Partially Pooled with Logit",
        "==================================",
        "alpha",
        "\nSigmoid(alpha):",
        transforms={"alpha": torch.sigmoid})