in src/inf_utils.py [0:0]
def get_model_prediction(bucket, data_loc, inf_endpt, coords=None):
    """Score a labelled CSV stored in S3 against an inference endpoint.

    The object ``data_loc`` in ``bucket`` is downloaded to ``total.csv`` in
    the current working directory via ``s3.download_file``. The ``label``
    column is split off as ground truth. Each remaining row is sent as one
    header-less ``text/csv`` request to ``sm_runtime.invoke_endpoint``.
    (``s3`` / ``sm_runtime`` are presumably boto3 S3 and SageMaker-runtime
    clients defined at module level — confirm.)

    Args:
        bucket: Name of the bucket holding the data.
        data_loc: Key of the CSV object inside ``bucket``. It must contain a
            ``label`` column.
        inf_endpt: Endpoint name passed as ``EndpointName``.
        coords: Optional column label(s) (anything ``DataFrame.drop``
            accepts) to remove from the features before inference. A copy of
            these columns is returned separately.

    Returns:
        ``(true_labels, pred_labels)`` when ``coords`` is None, otherwise
        ``(true_labels, pred_labels, df_coord)``. Here ``true_labels`` is
        the ``label`` column, ``pred_labels`` is a list of ints in row order
        (one endpoint response per row), and ``df_coord`` is a copy of the
        ``coords`` columns.

    Raises:
        AttributeError: If the CSV has no ``label`` column.
        KeyError: If any of ``coords`` is missing from the CSV.
        ValueError: If an endpoint response body is not an integer literal.
    """
    local_download = "total.csv"
    s3.download_file(bucket, data_loc, local_download)
    df_bands = pd.read_csv(local_download)

    true_labels = df_bands.label
    df_bands = df_bands.drop(["label"], axis=1)

    df_coord = None
    if coords is not None:
        df_coord = df_bands[coords].copy()
        df_bands = df_bands.drop(coords, axis=1)

    # Serialize the feature rows in memory. The old code overwrote the
    # downloaded file with this header-less CSV and re-read it line by line.
    # The payload rows are the same, but the write/re-read added disk I/O
    # and clobbered the download.
    csv_rows = df_bands.to_csv(header=False, index=False).splitlines()

    pred_labels = []
    for payload in csv_rows:
        response = sm_runtime.invoke_endpoint(
            EndpointName=inf_endpt,
            ContentType="text/csv",
            Body=payload,
        )
        # Each response body is expected to be a single integer class label.
        pred_labels.append(int(response["Body"].read().decode().strip()))

    if coords is None:
        return true_labels, pred_labels
    return true_labels, pred_labels, df_coord