orbit/utils/knots.py (59 lines of code) (raw):
import numpy as np
import pandas as pd
from ..exceptions import IllegalArgument
def get_knot_dates(start_date, knot_idx, freq):
    """Translate knot indices into calendar dates.

    A regular date range is generated from ``start_date`` at frequency ``freq``,
    long enough to reach the largest index in ``knot_idx``; the entries at the
    requested indices are then selected.

    Parameters
    ----------
    start_date : datetime-like
        first date of the generated range (passed as ``start`` to ``pd.date_range``)
    knot_idx : ndarray
        1D array containing index with `int` type.
    freq : date frequency
        frequency passed through to ``pd.date_range``

    Returns
    -------
    pandas.DatetimeIndex
        knot dates selected from the generated range, in the order given by ``knot_idx``
    """
    # the range must span indices 0..max(knot_idx), hence the "+ 1"
    num_periods = max(knot_idx) + 1
    calendar = pd.date_range(start=start_date, periods=num_periods, freq=freq)
    return calendar[knot_idx]
def get_dates_delta(start_date, end_date, time_delta):
    """Return the number of ``time_delta`` steps between ``start_date`` and ``end_date``.

    The date difference is divided by ``time_delta`` and rounded to the nearest
    integer, so the result can be used directly as a knot index.

    Parameters
    ----------
    start_date : numpy datetime
        reference date the steps are counted from
    end_date : numpy datetime array
        date(s) to convert into step counts
    time_delta : time delta between dates
        length of a single step

    Returns
    -------
    integer array-like
        rounded step count for each entry of ``end_date``
    """
    elapsed = end_date - start_date
    # fractional number of steps; rounding snaps each date to the nearest step
    steps = elapsed / time_delta
    return np.round(steps).astype(int)
def get_knot_idx_by_dist(num_of_obs, knot_distance):
    """Calculate knot indices spaced ``knot_distance`` apart.

    Positions are generated backwards from the last observation
    (``num_of_obs - 1``) in steps of ``knot_distance`` and rounded to the
    nearest integer. Index 0 is always included in the result.

    Parameters
    ----------
    num_of_obs : int
        number of observations
    knot_distance : int or float
        distance between two consecutive knots; may be fractional

    Returns
    -------
    ndarray
        sorted 1D integer array of unique, non-negative knot indices
    """
    # start from the ending point and walk backwards
    positions = np.arange(num_of_obs - 1, -1, -knot_distance)
    knot_idx = np.round(positions).astype("int")
    # with a fractional distance below 1, positions inside (-1, 0) can round
    # to -1; those are not valid observation indices
    knot_idx = knot_idx[knot_idx >= 0]
    # np.unique sorts and removes duplicates created by rounding; appending 0
    # first guarantees a knot at the very first observation
    knot_idx = np.unique(np.append(knot_idx, 0))
    return knot_idx
def get_knot_idx(
    num_of_obs=None,
    num_of_segments=None,
    knot_distance=None,
    date_array=None,
    knot_dates=None,
):
    """Calculate and return the knot locations as integer indices.

    This function will be used in KTRLite and KTRX model.
    There are three ways to get the knot index, checked in this order:

    1. With observations date array and knot dates provided, derive knots location directly based on
       the implied observation frequency (mean difference between consecutive dates).
    2. With number of observations supplied, calculate the knots location and indices based on
       knot distance specified such that there will be additional knots in the first and end provided
    3. With number of observations supplied, calculate the knots location and indices based on
       number of segments specified and knot indices will be evenly distributed

    Parameters
    ----------
    num_of_obs : int
        number of observations to derive segments and knots; will be ignored if knot_dates is not None
    num_of_segments : int
        number of segments, which will be used to calculate the knot distance; values larger than
        ``num_of_obs - 1`` are capped at ``num_of_obs - 1`` and values below 1 yield a single
        knot at index 0
    knot_distance : int
        distance between every two knots; must be a positive integer (Python or NumPy)
    date_array : datetime array
        only used when knot_dates is not None
    knot_dates : list or array of numpy datetime
        list of dates in string format (%Y-%m-%d) or numpy datetime array which will be used as the knot locations;
        dates outside the range of ``date_array`` are dropped

    Returns
    -------
    an array of integers, which are the knot location indices (starts at 0).

    Raises
    ------
    IllegalArgument
        if neither ``knot_dates`` nor ``num_of_obs`` is provided, if ``knot_dates`` is provided
        without ``date_array``, if ``knot_distance`` is not a positive integer, or if none of
        ``knot_dates``, ``knot_distance`` and ``num_of_segments`` is specified
    """
    if knot_dates is None and num_of_obs is None:
        raise IllegalArgument("Either knot_dates or num_of_obs needs to be provided.")

    if knot_dates is not None:
        if date_array is None:
            raise IllegalArgument(
                "When knot_dates are supplied, users need to supply date_array as well."
            )
        knot_dates = np.array(knot_dates, dtype="datetime64")
        # compute the bounds once instead of once per candidate date
        first_date = date_array.min()
        last_date = date_array.max()
        # filter out knot dates that fall outside the observed date range
        _knot_dates = pd.to_datetime(
            [x for x in knot_dates if (x <= last_date) and (x >= first_date)]
        )
        # implied observation frequency
        time_delta = np.diff(date_array).mean()
        knot_idx = get_dates_delta(
            start_date=date_array[0], end_date=_knot_dates, time_delta=time_delta
        )
    elif knot_distance is not None:
        # accept NumPy integer scalars as well as Python ints
        if not isinstance(knot_distance, (int, np.integer)):
            raise IllegalArgument("knot_distance must be an int.")
        # zero makes np.arange fail and a negative step silently yields a single knot
        if knot_distance < 1:
            raise IllegalArgument("knot_distance must be a positive int.")
        knot_idx = get_knot_idx_by_dist(num_of_obs, knot_distance)
    elif num_of_segments is not None:
        # at most num_of_obs - 1 segments fit between integer positions; asking
        # for more would give a distance below 1 (or 0 when num_of_obs == 1)
        effective_segments = min(num_of_segments, num_of_obs - 1)
        if effective_segments >= 1:
            knot_distance = (num_of_obs - 1) / effective_segments
            knot_idx = get_knot_idx_by_dist(num_of_obs, knot_distance)
        else:
            # one single knot at the beginning
            knot_idx = np.array([0])
    else:
        raise IllegalArgument(
            "please specify at least one of the followings to determine the knot locations: "
            "knot_dates, knot_distance, or num_of_segments."
        )

    # knot_idx starts with 0; callers need to add 1 when calculating the fraction
    return knot_idx