Skip to content

Commit e7a9f1d

Browse files
committed
major update to renewal
1 parent a33c8d2 commit e7a9f1d

File tree

3 files changed

+534
-27
lines changed

3 files changed

+534
-27
lines changed

surpyval/renewal/__init__.py

+2
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,2 @@
1+
from .generalized_one_renewal import GeneralizedOneRenewal
2+
from .generalized_renewal import GeneralizedRenewal

surpyval/renewal/generalized_one_renewal.py

+246-4
Original file line numberDiff line numberDiff line change
@@ -11,11 +11,73 @@
1111
DT_WARN = "Small increment encountered, may have trouble reaching T."
1212

1313

14-
class G1Renewal:
14+
class GeneralizedOneRenewal:
15+
"""
16+
A class to handle the generalized renewal process with different Kijima
17+
models.
18+
19+
Since the Generalised One Renewal Process does not have closed form
20+
solutions for the instantaneous intensity function and the cumulative
21+
intensity function these values cannot be calculated directly with this
22+
class. Instead, the model can be used to simulate recurrence data which is
23+
fitted to a ``NonParametricCounting`` model. This model can then be used
24+
to calculate the cumulative intensity function.
25+
26+
Examples
27+
--------
28+
>>> from surpyval import GeneralizedOneRenewal, Weibull
29+
>>> import numpy as np
30+
>>>
31+
>>> x = np.array([1, 2, 3, 4, 4.5, 5, 5.5, 5.7, 6])
32+
>>>
33+
>>> model = GeneralizedOneRenewal.fit(x, dist=Weibull)
34+
>>> model
35+
G1 Renewal SurPyval Model
36+
=========================
37+
Distribution : Weibull
38+
Fitted by : MLE
39+
Restoration Factor : -0.1730184624683848
40+
Parameters :
41+
alpha: 1.3919045967886952
42+
beta: 5.0088611892336115
43+
>>>
44+
>>> np.random.seed(0)
45+
>>> np_model = model.count_terminated_simulation(len(x), 5000)
46+
>>> np_model.mcf(np.array([1, 2, 3, 4, 5, 6]))
47+
array([0.1696 , 1.181 , 2.287 , 3.6694 , 5.58237925,
48+
8.54474531])
49+
"""
50+
1551
def __init__(self, model, q):
1652
self.model = model
1753
self.q = q
1854

55+
def __repr__(self):
56+
out = (
57+
"G1 Renewal SurPyval Model"
58+
+ "\n========================="
59+
+ f"\nDistribution : {self.model.dist.name}"
60+
+ "\nFitted by : MLE"
61+
+ f"\nRestoration Factor : {self.q}"
62+
)
63+
64+
param_string = "\n".join(
65+
[
66+
"{:>10}".format(name) + ": " + str(p)
67+
for p, name in zip(
68+
self.model.params, self.model.dist.param_names
69+
)
70+
]
71+
)
72+
73+
out = (
74+
out
75+
+ "\nParameters :\n"
76+
+ "{params}".format(params=param_string)
77+
)
78+
79+
return out
80+
1981
def initialize_simulation(self):
2082
self.us = uniform.rvs(size=100_000).tolist()
2183

@@ -30,6 +92,23 @@ def get_uniform_random_number(self):
3092
return self.us.pop()
3193

3294
def count_terminated_simulation(self, events, items=1):
95+
"""
96+
Simulate count-terminated recurrence data based on the fitted model.
97+
98+
Parameters
99+
----------
100+
101+
events: int
102+
Number of events to simulate.
103+
items: int, optional
104+
Number of items (or sequences) to simulate. Default is 1.
105+
106+
Returns
107+
-------
108+
109+
NonParametricCounting
110+
An NonParametricCounting model built from the simulated data.
111+
"""
33112
life, *scale = self.model.params
34113
q = self.q
35114
self.initialize_simulation()
@@ -61,6 +140,33 @@ def count_terminated_simulation(self, events, items=1):
61140
return model
62141

63142
def time_terminated_simulation(self, T, items=1, tol=1e-5):
143+
"""
144+
Simulate time-terminated recurrence data based on the fitted model.
145+
146+
Parameters
147+
----------
148+
149+
T: float
150+
Time termination value.
151+
items: int, optional
152+
Number of items (or sequences) to simulate. Default is 1.
153+
tol: float, optional
154+
Tolerance for interarrival times to stop an individual sequence.
155+
156+
Returns
157+
-------
158+
159+
NonParametricCounting
160+
An NonParametricCounting model built from the simulated data.
161+
162+
Warnings
163+
--------
164+
165+
If any of the simulated sequences seem to not reach the time
166+
termination value T due to possible asymptote, a warning message will
167+
be printed to notify the user about potential convergence problems in
168+
the simulation.
169+
"""
64170
life, *scale = self.model.params
65171
q = self.q
66172
self.initialize_simulation()
@@ -136,6 +242,50 @@ def negll_func(params):
136242

137243
@classmethod
138244
def fit_from_recurrent_data(cls, data, dist=Weibull, init=None):
245+
"""
246+
Fit the generalized renewal model from recurrent data.
247+
248+
Parameters
249+
----------
250+
251+
data : RecurrentData
252+
Data containing the recurrence details.
253+
dist : Distribution, optional
254+
A surpyval distribution object. Default is Weibull.
255+
kijima : str, optional
256+
Type of Kijima model to use, either "i" or "ii". Default is "i".
257+
init : list, optional
258+
Initial parameters for the optimization algorithm.
259+
260+
Returns
261+
-------
262+
263+
GeneralizedOneRenewal
264+
A fitted GeneralizedOneRenewal object.
265+
266+
Example
267+
-------
268+
269+
>>> from surpyval import GeneralizedOneRenewal, handle_xicn
270+
>>> import numpy as np
271+
>>>
272+
>>> x = np.array([1, 3, 6, 9, 10, 1.4, 3, 6.7, 8.9, 11, 1, 2])
273+
>>> c = np.array([0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0 , 1])
274+
>>> i = np.array([1, 1, 1, 1, 1, 2, 2, 2, 2, 2, 3, 3])
275+
>>>
276+
>>> rec_data = handle_xicn(x, i, c)
277+
>>>
278+
>>> model = GeneralizedOneRenewal.fit_from_recurrent_data(rec_data)
279+
>>> model
280+
G1 Renewal SurPyval Model
281+
=========================
282+
Distribution : Weibull
283+
Fitted by : MLE
284+
Restoration Factor : 0.4270960618530103
285+
Parameters :
286+
alpha: 1.3494830373118245
287+
beta: 2.7838386997223212
288+
"""
139289
if init is None:
140290
dist_params = dist.fit(
141291
data.interarrival_times, data.c, data.n
@@ -183,8 +333,100 @@ def fit_from_recurrent_data(cls, data, dist=Weibull, init=None):
183333
return out
184334

185335
@classmethod
186-
def fit(cls, x, i, c, n, dist=Weibull, init=None):
187-
# Wrangle data
188-
# Rest of the data assumes values are in ascending order.
336+
def fit(cls, x, i=None, c=None, n=None, dist=Weibull, init=None):
337+
"""
338+
Fit the generalized renewal model.
339+
340+
Parameters
341+
----------
342+
343+
x : array_like
344+
An array of event times.
345+
i : array_like, optional
346+
An array of item indices.
347+
c : array_like, optional
348+
An array of censoring indicators.
349+
n : array_like, optional
350+
An array of counts.
351+
dist : object, optional
352+
A surpyval distribution object. Default is Weibull.
353+
kijima : str, optional
354+
Type of Kijima model to use, either "i" or "ii". Default is "i".
355+
init : list, optional
356+
Initial parameters for the optimization algorithm.
357+
358+
Returns
359+
-------
360+
361+
GeneralizedOneRenewal
362+
A fitted GeneralizedOneRenewal object.
363+
364+
Example
365+
-------
366+
367+
>>> from surpyval import GeneralizedOneRenewal
368+
>>> import numpy as np
369+
>>>
370+
>>> x = np.array([1, 3, 6, 9, 10, 1.4, 3, 6.7, 8.9, 11, 1, 2])
371+
>>> c = np.array([0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0 , 1])
372+
>>> i = np.array([1, 1, 1, 1, 1, 2, 2, 2, 2, 2, 3, 3])
373+
>>>
374+
>>> model = GeneralizedOneRenewal.fit(x, i, c=c)
375+
>>> model
376+
G1 Renewal SurPyval Model
377+
=========================
378+
Distribution : Weibull
379+
Fitted by : MLE
380+
Restoration Factor : 0.4270960618530103
381+
Parameters :
382+
alpha: 1.3494830373118245
383+
beta: 2.7838386997223212
384+
"""
189385
data = handle_xicn(x, i, c, n, as_recurrent_data=True)
190386
return cls.fit_from_recurrent_data(data, dist=dist, init=init)
387+
388+
@classmethod
389+
def fit_from_parameters(cls, params, q, dist=Weibull):
390+
"""
391+
Fit the generalized renewal model from given parameters.
392+
393+
Parameters
394+
----------
395+
396+
params : list
397+
A list of parameters for the survival analysis distribution.
398+
q : float
399+
Restoration factor used in the Kijima models.
400+
kijima : str, optional
401+
Type of Kijima model to use, either "i" or "ii". Default is "i".
402+
dist : object, optional
403+
A surpyval distribution object. Default is Weibull.
404+
405+
Returns
406+
-------
407+
408+
GeneralizedOneRenewal
409+
A fitted GeneralizedOneRenewal object.
410+
411+
Example
412+
-------
413+
414+
>>> from surpyval import GeneralizedOneRenewal, Normal
415+
>>>
416+
>>> model = GeneralizedOneRenewal.fit_from_parameters(
417+
[10, 2],
418+
0.2,
419+
dist=Normal
420+
)
421+
>>> model
422+
G1 Renewal SurPyval Model
423+
=========================
424+
Distribution : Normal
425+
Fitted by : MLE
426+
Restoration Factor : 0.2
427+
Parameters :
428+
mu: 10
429+
sigma: 2
430+
"""
431+
model = dist.from_params(params)
432+
return cls(model, q)

0 commit comments

Comments
 (0)