Skip to content
Snippets Groups Projects
Commit a644680c authored by Christoph Lange's avatar Christoph Lange
Browse files

first version of a tf data handler

parent d8831e90
No related branches found
No related tags found
1 merge request: !9 "Lstm prototype"
import functools
import matplotlib.pyplot as plt
import numpy as np
import random
import tensorflow as tf
# Default lower bound on the number of samples kept on either side of the
# random cut-off used by `DataWindowHandler.train_dataset_randomly`.
MIN_INPUT = 10


class DataWindowHandler():
    """Build `tf.data.Dataset` objects from collections of 1-D time series.

    Two families of datasets are offered:

    * fixed windows (`train_dataset`, `validation_dataset`), which need
      integer `input_width` and `label_width`;
    * variable-length splits (`train_dataset_randomly`,
      `validation_dataset_full`), where each series is cut into an input
      prefix and a label remainder.

    Args:
        input_width: number of time steps used as model input. May be `None`
            when only `train_dataset_randomly` is used.
        label_width: number of time steps to predict. May be `None` when only
            the variable-length datasets are used.
        train_data: sequence of 1-D series. NOTE(review): the code takes the
            length of the first series as the length of all of them and calls
            `.astype` on each series, so equal-length numpy arrays are
            presumably expected -- confirm against callers.
        validation_data: sequence of 1-D series used for validation.
        minimal_input_length: smallest allowed cut-off position for the
            random split; also the number of samples always left for the
            labels.
    """

    def __init__(
        self,
        input_width,
        label_width,
        train_data,
        validation_data,
        minimal_input_length=MIN_INPUT,
    ):
        # Store the raw data.
        self.train = train_data
        self.validation_data = validation_data
        # Length of the first training series; used as the common length.
        self.total_length = len(train_data[0])
        # Work out the window parameters.
        self.input_width = input_width
        self.label_width = label_width
        self.minimal_input_length = minimal_input_length

    def __repr__(self):
        """Return a multi-line summary of the window configuration."""
        # The widths may legitimately be None (variable-length mode); adding
        # them unconditionally would raise a TypeError.
        if self.input_width is None or self.label_width is None:
            total_window_size = None
        else:
            total_window_size = self.input_width + self.label_width
        return '\n'.join([
            f'Total window size: {total_window_size}',
            f'Input width: {self.input_width}',
            f'Label width: {self.label_width}',
        ])

    def split_window(self, features):
        """Split a batch of windows into `(inputs, labels)`.

        Args:
            features: tensor indexed as (batch, time); a trailing feature
                axis of size 1 is appended by this method.

        Returns:
            Tuple `(inputs, labels)` with static shapes
            `(None, input_width, 1)` and `(None, label_width, 1)`.
        """
        window_end = self.input_width + self.label_width
        # dimensions (batch, time, features==1)
        inputs = features[:, :self.input_width, tf.newaxis]
        labels = features[:, self.input_width:window_end, tf.newaxis]
        # Slicing doesn't preserve static shape information, so set the shapes
        # manually. This way the `tf.data.Datasets` are easier to inspect.
        inputs.set_shape([None, self.input_width, 1])
        labels.set_shape([None, self.label_width, 1])
        return inputs, labels

    def plot(self, model=None):
        """Plot two validation windows, optionally with model predictions.

        Args:
            model: optional callable applied to `inputs[tf.newaxis, :]`; its
                output is drawn at the label positions.
        """
        _, axes = plt.subplots(2, 1, figsize=(25, 15))
        label_positions = range(
            self.input_width,
            self.input_width + self.label_width,
        )
        examples = self.validation_dataset.unbatch().take(2)
        for ax, validation_data in zip(axes.flatten(), examples):
            inputs, labels = validation_data
            ax.plot(
                range(self.input_width),
                inputs,
                label='Inputs',
                marker='.',
            )
            ax.scatter(
                label_positions,
                labels,
                edgecolors='k',
                label='Labels',
                c='#2ca02c',
            )
            if model is not None:
                ax.scatter(
                    label_positions,
                    model(inputs[tf.newaxis, :]),
                    marker='X',
                    edgecolors='k',
                    label='Predictions',
                    c='#ff7f0e',
                )
            # Decorate through the Axes object: the pyplot-level helpers act
            # on the "current" axes only, not on the subplot being drawn.
            ax.set_ylabel('voltage [normed]')
            ax.set_xlabel('Time [min]')
            ax.legend()

    def make_dataset(self, data, *, sequence_stride=10, batch_size=8, shuffle=True):
        """Create one windowed dataset covering every series in `data`.

        Each series gets its own sliding-window dataset (so windows never
        span two series, and shuffling happens per series); the per-series
        datasets are then concatenated.

        Args:
            data: non-empty iterable of 1-D arrays supporting `.astype`.
            sequence_stride: step between the starts of consecutive windows.
            batch_size: number of windows per batch.
            shuffle: whether to shuffle the windows of each series.

        Returns:
            A `tf.data.Dataset` yielding `(inputs, labels)` batches as
            produced by `split_window`.

        Raises:
            TypeError: if `data` is empty (`functools.reduce` has no initial
                value) or if the widths are `None`.
        """
        window_length = self.input_width + self.label_width
        per_series_datasets = [
            tf.keras.preprocessing.timeseries_dataset_from_array(
                data=one_series.astype(np.float32),
                targets=None,
                sequence_length=window_length,
                sequence_stride=sequence_stride,
                shuffle=shuffle,
                batch_size=batch_size,
            ).map(self.split_window)
            for one_series in data
        ]
        return functools.reduce(
            tf.data.Dataset.concatenate,
            per_series_datasets,
        )

    @property
    def train_dataset(self):
        """Fixed-window dataset built from the training series."""
        return self.make_dataset(self.train)

    @property
    def train_dataset_randomly(self):
        """Endless dataset of training series split at a random cut-off.

        For every pass over the training series one cut-off is drawn from
        `[minimal_input_length, total_length - minimal_input_length)`. Each
        series is split into `series[:cut_off]` (inputs) and
        `series[cut_off:total_length]` (labels). The labels are zero-padded
        so that they always have `total_length - minimal_input_length`
        entries. One batch holds one pass, i.e. `len(self.train)` series.

        Raises:
            ValueError: if the series are too short for any cut-off to exist.
        """
        lowest_cut_off = self.minimal_input_length
        cut_off_limit = self.total_length - self.minimal_input_length
        # Fail early with a clear message instead of deep inside the
        # generator once TensorFlow starts iterating.
        if cut_off_limit <= lowest_cut_off:
            raise ValueError(
                'Time series of length '
                f'{self.total_length} are too short for '
                f'minimal_input_length={self.minimal_input_length}.'
            )

        def random_generator():
            while True:
                cut_off = random.randrange(lowest_cut_off, cut_off_limit)
                for time_series in self.train:
                    # dimensions (time, features==1); the batch axis is added
                    # by `padded_batch` below.
                    inputs = tf.convert_to_tensor(
                        time_series[:cut_off, tf.newaxis],
                        dtype=tf.float32,
                    )
                    # Pad with zeros so every label tensor has the same
                    # length regardless of the cut-off.
                    padded_labels = np.concatenate((
                        time_series[cut_off:self.total_length],
                        np.zeros(cut_off - self.minimal_input_length),
                    ))
                    labels = tf.convert_to_tensor(
                        padded_labels[:, np.newaxis],
                        dtype=tf.float32,
                    )
                    yield inputs, labels

        return tf.data.Dataset.from_generator(
            random_generator,
            (tf.float32, tf.float32),
            (tf.TensorShape([None, 1]), tf.TensorShape([None, 1])),
        ).padded_batch(len(self.train))

    @property
    def validation_dataset(self):
        """Fixed-window dataset built from the validation series."""
        return self.make_dataset(self.validation_data)

    @property
    def validation_dataset_full(self):
        """Dataset of whole validation series split at `input_width`.

        Each series contributes its first `input_width` samples as inputs and
        everything after them as labels. All series are delivered in a single
        batch of size `len(self.validation_data)`.
        """
        def fixed_generator():
            for time_series in self.validation_data:
                # dimensions (time, features==1); the batch axis is added by
                # `batch` below.
                inputs = time_series[:self.input_width, tf.newaxis]
                labels = time_series[self.input_width:, tf.newaxis]
                yield inputs, labels

        return tf.data.Dataset.from_generator(
            fixed_generator,
            (tf.float32, tf.float32),
            (tf.TensorShape([None, 1]), tf.TensorShape([None, 1])),
        ).batch(len(self.validation_data))
import numpy as np
from glucose_ts.data import DataWindowHandler
def test_training_dataset_variable_length():
    """Inputs followed by labels must equal each series plus zero padding.

    The randomly cut training batch is glued back together along the time
    axis and compared with the original series followed by
    `inputs.shape[1] - 3` zeros, where 3 is the `minimal_input_length`
    passed to the handler.
    """
    # given: four noisy linear series with 21 samples each
    slope = 2
    initial_value = 5
    timestamps = np.linspace(0, 5, 21)
    test_labels = []
    for _ in range(4):
        trend = np.array([slope * time + initial_value for time in timestamps])
        test_labels.append(trend + np.random.randn((21)))

    data_handler = DataWindowHandler(
        input_width=None,
        label_width=None,
        train_data=test_labels,
        validation_data=test_labels,
        minimal_input_length=3,
    )

    # when: the first batch of the endless random dataset is drawn
    first_batch = next(iter(data_handler.train_dataset_randomly.take(1)))
    inputs, labels = first_batch

    # then
    reassembled = np.concatenate((inputs.numpy(), labels.numpy()), axis=1)
    zero_padding = np.zeros((len(test_labels), (inputs.shape[1] - 3), 1))
    expected = np.concatenate(
        (np.array(test_labels)[:, :, np.newaxis], zero_padding),
        axis=1,
    )
    assert np.allclose(reassembled, expected)
def test_validation_dataset_full_timespan():
    """Inputs followed by labels must reproduce every validation series.

    With `input_width=10` the full validation dataset yields one batch; its
    inputs and labels, joined along the time axis, are compared with the
    original series.
    """
    # given: four noisy linear series with 21 samples each
    slope = 2
    initial_value = 5
    timestamps = np.linspace(0, 5, 21)
    test_labels = []
    for _ in range(4):
        trend = np.array([slope * time + initial_value for time in timestamps])
        test_labels.append(trend + np.random.randn((21)))

    data_handler = DataWindowHandler(
        input_width=10,
        label_width=None,
        train_data=test_labels,
        validation_data=test_labels,
        minimal_input_length=3,
    )

    # when: the single validation batch is drawn
    first_batch = next(iter(data_handler.validation_dataset_full.take(1)))
    inputs, labels = first_batch

    # then
    reassembled = np.concatenate((inputs.numpy(), labels.numpy()), axis=1)
    expected = np.array(test_labels)[:, :, np.newaxis]
    assert np.allclose(reassembled, expected)
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment