Skip to content
Snippets Groups Projects
Commit 85e13e3b authored by Ingmar Schubert's avatar Ingmar Schubert
Browse files

Task 02

parent 1b51ad11
No related branches found
No related tags found
No related merge requests found
This is a pdf version of https://git.tu-berlin.de/lis-public/ai-student-workspace/-/blob/main/02/README.md
To obtain the code for this assignment, you will need to fetch and pull new commits from git@git.tu-berlin.de:lis-public/ai-student-workspace.git
As always, only modify the file `solution_??.py`. And even in `solution_??.py`, only modify what the functions do - don't change the function's names. Don't add any additional files, put all your code in `solution_??.py`.
You can run tests by navigating to the task folder `??`, and then simply typing `python3 -m pytest`. If you haven't yet, you will need to install pytest first: `sudo apt install python3-pytest`.
# Assignment 2: Single-Step and Discrete Decision Making
## 2.1: Single Step, Continuous, Data-based (aka supervised learning)
The following information is available:
- You have to choose an action a from a continuous interval $a \in [l, u] \subset \mathbb{R}$.
- Outcomes $y \in \mathbb{R}$ depend on the decision $a$ in some unknown way: There is an underlying function $f$ unknown to you, and observations are Gauss distributed around this function with $\sigma = 0.1$, i.e. $y \sim P(\cdot\mid a) = \mathcal{N}(f(a)\mid \sigma^2)$. You can assume that $f$ is smooth.
- You have a data set $D=\{(a_i, y_i)_{i=1}^n\}$ of previous decisions $a_i \in \mathbb{R}$ and outcomes $y_i \in \mathbb{R}$.
For this part of the exercise, you will need to modify the following function in `solution_02.py`:
```
regression_based_on_data(l, u, a, y)
```
The inputs of this function are as follows:
- `l` (single float) is the left side of the interval in which the action $a$ can be
- `u` (single float) is the right side of the interval in which the action $a$ can be
- `a` is a `np.array` of length `n`, and contains actions $a$
- `y` is a `np.array` of length `n`, and contains the $y \sim P(\cdot\mid a)$ obtained for the `a` in `A`.
This function should return:
- The action $a_\text{max}$ (single float) from the interval $[l, u]$ that maximizes the expectation $E[y\mid a]$. Since we know that $y \sim P(\cdot\mid a) = \mathcal{N}(f(a)\mid \sigma^2)$, this is the action $a_\text{max}$ that maximizes $f(a)$
To do this, you will first need to perform the following steps:
1. Use the data (`a`,`y`) to find the approximate function $\hat{f}$ (this is your statistical estimator for the true $f$
2. Find the $a_\text{max}$ that maximizes $\hat{f}(a)$ on $a \in [l, u]$.
You are free to use any function approximation to find $\hat{f}$ in the first step, but we suggest the following regression method:
1. We choose $F$ different feature functions $\{\phi_0, \phi_1, ..., \phi_{F-1}\}$. For example, we can use polynomial features: $\phi_k(a)=a^k$.
2. We make the ansatz $\hat{f}(a) = \sum_{k=0}^{F-1} \hat{\alpha_k}\ a^k$.
3. The estimated $\hat{\alpha_k}$ are chosen such that our estimated function $\hat{f}$ maximizes the likelihood of the data (`a`,`y`). As will be discussed in the tutorials, this is equivalent to calculating
$$ \hat{\alpha} = (A^T A)^{-1}A^Ty$$
Here, $\hat{\alpha}$ is a vector with elements $\hat{\alpha_k}$, and $A$ is a $n\times F$ matrix whose elements are as follows:
$$ A_{ij} = \phi_j(a_i) \quad ,$$
where the $a_i$ are the elements of the vector `a`.
The larger you choose $F$, the larger the variance of your estimator $\hat{f}$ becomes. The smaller you choose $F$, the larger the bias of your estimator $\hat{f}$ becomes. There is a sweet spot somewhere in the middle! To get an intuition, it could be useful to plot `a`, `y`, and $\hat{f}$ together!
## 2.2: Single Step, Continuous, Simulation-based (aka active learning)
This is the same setting as in 2.1, but this time you don't have data $D=\{(a_i, y_i)_{i=1}^n\}$. Instead, the agent can query $n$ times from $y \sim P(\cdot\mid a)$ (i.e. sample a resulting $y$ for a given value $n$).
This part of the exercise is not scored, but be prepared to present your results in the tutorial. Write the code for this part of the exercise in the separate file `exercise_022.py`, not in `solution_02.py`. At the top of the file, the function `get_y_given_a` implements sampling from the black-box distribution $P(\cdot\mid a)$.
These are the instructions:
1. Sample a data pair for an action of your choice. Let's denote the data after drawing $k\leq n$ samples as $D_k=\{(a_i, y_i)_{i=1}^k\}$. Start by choosing a couple of actions randomly, then proceed with step 3.
2. Generate $B=10$ bootstrap resamples from the data $D_k$, which we denote as $D_k^b$. In other words: Randomly select $k$ data pairs (with replacement!) from $D_k$, resulting in a new dataset $D_k^b$ each time, and repeat this $B=10$ times. For more details on bootstrapping, read [https://isis.tu-berlin.de/pluginfile.php/2284732/mod_resource/content/1/04-notes.pdf](https://isis.tu-berlin.de/pluginfile.php/2284732/mod_resource/content/1/04-notes.pdf).
3. For each bootstrapped dataset $D_k^b$, perform the regression described in 2.1. This results in an estimated function $\hat{f}_k^b$ for each of the datasets. Now you can calculate
1. The bootstrap estimate of $f$: $\hat{f}_k(a) = \frac{1}{B} \sum_b \hat{f}_k^b(a)$
2. The estimated variance of $\hat{f}_k$: $\hat{\sigma}_k(a) = \frac{1}{B-1} \sum_b (\hat{f}_k^b(a) - \hat{f}_k(a))^2$
4. Create a plot of $\hat{f}_k(a) \pm \sqrt{\hat{\sigma}_k(a)}$ alongside with $D_k$ and the true $f$. You can use `matplotlib.pyplot.fill_between` to visualize the error bars.
5. Choose the next action to sample using UCB with $\beta=2$. Then repeat from step 1.
File added
# %%
"""
Things to experiment with:
1. Number of features (what happens to variance?)
2. Action sampling strategy
"""
import numpy as np
import matplotlib.pyplot as plt
sigma = 0.1
# Define problem
## get lower and upper
l = np.random.rand()*2
u = l + np.random.rand()*2 + 1
## get some parameters for f
peak_position = (u+l)/2 + (u-l) * (np.random.rand()-0.5)/2
width = (u-l)/2 * (1 + np.random.rand())/2
scale = 0.5 + np.random.rand()
def f_of_a(a):
"""
Blackbox implementation of f(a)
"""
## get params of true function
## define f as inline
return scale * np.exp(-(a-peak_position)**2/2/width)
## define sampling process as inline
def get_y_given_a(a):
"""
Blackbox implementation of P(.|a)
"""
if hasattr(a, "__len__"):
return f_of_a(a) + np.random.normal(0, sigma, size=a.shape)
return f_of_a(a) + np.random.normal(0, sigma)
####################################################################
import numpy as np
def regression_based_on_data(l, u, a, y):
"""
Please modify the body of this function according to the description in exercise 2.1
"""
raise NotImplementedError
"""
Test cases for you to check your solutions.
run by typing `pytest` in the task folder.
Do not change the content of this file!
"""
import numpy as np
import pytest
from solution_02 import regression_based_on_data
@pytest.mark.parametrize(
"l,u,a,y,best_action",
[
(
1.376311057991322,
4.326125835094503,
np.array([1.37631106, 1.43651136, 1.49671166, 1.55691196, 1.61711226,
1.67731257, 1.73751287, 1.79771317, 1.85791347, 1.91811377,
1.97831407, 2.03851438, 2.09871468, 2.15891498, 2.21911528,
2.27931558, 2.33951588, 2.39971618, 2.45991649, 2.52011679,
2.58031709, 2.64051739, 2.70071769, 2.76091799, 2.8211183 ,
2.8813186 , 2.9415189 , 3.0017192 , 3.0619195 , 3.1221198 ,
3.18232011, 3.24252041, 3.30272071, 3.36292101, 3.42312131,
3.48332161, 3.54352191, 3.60372222, 3.66392252, 3.72412282,
3.78432312, 3.84452342, 3.90472372, 3.96492403, 4.02512433,
4.08532463, 4.14552493, 4.20572523, 4.26592553, 4.32612584]),
np.array([0.07444134, 0.10072154, 0.17155658, 0.09752348, 0.0857364 ,
0.04541672, 0.14981702, 0.2015575 , 0.12025267, 0.21334166,
0.53762892, 0.26543295, 0.16483137, 0.30528716, 0.44170021,
0.49462482, 0.2701016 , 0.37657232, 0.54416779, 0.65141215,
0.30387195, 0.48041147, 0.45248468, 0.51798842, 0.60359193,
0.47714452, 0.58798299, 0.56155803, 0.58899838, 0.68905315,
0.68169771, 0.68908075, 0.55080193, 0.74452369, 0.50126603,
0.88421286, 0.53759901, 0.56359476, 0.62104628, 0.82497294,
0.4704031 , 0.82829816, 0.90355641, 0.66547357, 0.6676321 ,
0.44229733, 0.61052731, 0.45220587, 0.67128588, 0.54934763]),
3.5152384759845754
),
(
0.7472696582591833,
3.559664643195159,
np.array([0.74726966, 0.80466547, 0.86206129, 0.91945711, 0.97685292,
1.03424874, 1.09164455, 1.14904037, 1.20643619, 1.263832 ,
1.32122782, 1.37862363, 1.43601945, 1.49341527, 1.55081108,
1.6082069 , 1.66560271, 1.72299853, 1.78039435, 1.83779016,
1.89518598, 1.95258179, 2.00997761, 2.06737343, 2.12476924,
2.18216506, 2.23956087, 2.29695669, 2.35435251, 2.41174832,
2.46914414, 2.52653995, 2.58393577, 2.64133159, 2.6987274 ,
2.75612322, 2.81351903, 2.87091485, 2.92831067, 2.98570648,
3.0431023 , 3.10049812, 3.15789393, 3.21528975, 3.27268556,
3.33008138, 3.3874772 , 3.44487301, 3.50226883, 3.55966464]),
np.array([0.77378903, 0.80983148, 0.65459869, 0.81758114, 0.85429728,
0.93294482, 1.0803754 , 1.0517089 , 1.12646874, 1.03698352,
1.16632296, 1.37628955, 1.40197123, 1.28053244, 1.26822654,
1.23988955, 1.28571848, 1.30242709, 1.2977698 , 1.57160917,
1.45509417, 1.57017556, 1.4390608 , 1.57415826, 1.55315486,
1.45626521, 1.62429765, 1.39825196, 1.33835364, 1.26044113,
1.20158646, 1.16107052, 1.0927094 , 1.20850408, 1.08274557,
0.91718687, 1.09890939, 0.9861561 , 1.07590623, 0.90693924,
0.65526108, 0.68132896, 0.78334199, 0.64412536, 0.5325601 ,
0.39194022, 0.50304258, 0.52393478, 0.31853307, 0.39118306]),
1.9528523781485918
),
(
1.3306672668627233,
2.4346597220662876,
np.array([1.33066727, 1.35319773, 1.37572818, 1.39825864, 1.4207891 ,
1.44331956, 1.46585002, 1.48838047, 1.51091093, 1.53344139,
1.55597185, 1.57850231, 1.60103277, 1.62356322, 1.64609368,
1.66862414, 1.6911546 , 1.71368506, 1.73621552, 1.75874597,
1.78127643, 1.80380689, 1.82633735, 1.84886781, 1.87139827,
1.89392872, 1.91645918, 1.93898964, 1.9615201 , 1.98405056,
2.00658101, 2.02911147, 2.05164193, 2.07417239, 2.09670285,
2.11923331, 2.14176376, 2.16429422, 2.18682468, 2.20935514,
2.2318856 , 2.25441606, 2.27694651, 2.29947697, 2.32200743,
2.34453789, 2.36706835, 2.38959881, 2.41212926, 2.43465972]),
np.array([1.02321 , 0.98904306, 1.06952685, 1.13167993, 1.053393 ,
1.07563878, 0.96065283, 0.98587522, 1.10846709, 1.02206226,
0.97644355, 0.88877659, 1.15397695, 0.86628782, 1.17649864,
1.23761648, 1.24050706, 1.0966679 , 1.14244526, 0.91848734,
1.00303682, 1.0737494 , 0.9034855 , 1.04208792, 0.82839053,
1.11280388, 1.00182428, 1.01973662, 1.03536875, 0.79787061,
1.07156466, 0.97157251, 1.02069085, 0.85254815, 1.00025544,
0.87981211, 0.78423405, 0.72420883, 0.81507134, 0.76448586,
0.71390331, 0.86245015, 0.82323186, 0.7787344 , 0.6750226 ,
0.71416252, 0.52104852, 0.87272917, 0.67154903, 0.52325259]),
1.652446854419188
),
(
0.38409932218881093,
1.5036602960827787,
np.array([0.38409932, 0.40694751, 0.42979569, 0.45264387, 0.47549205,
0.49834024, 0.52118842, 0.5440366 , 0.56688479, 0.58973297,
0.61258115, 0.63542934, 0.65827752, 0.6811257 , 0.70397389,
0.72682207, 0.74967025, 0.77251844, 0.79536662, 0.8182148 ,
0.84106299, 0.86391117, 0.88675935, 0.90960753, 0.93245572,
0.9553039 , 0.97815208, 1.00100027, 1.02384845, 1.04669663,
1.06954482, 1.092393 , 1.11524118, 1.13808937, 1.16093755,
1.18378573, 1.20663392, 1.2294821 , 1.25233028, 1.27517846,
1.29802665, 1.32087483, 1.34372301, 1.3665712 , 1.38941938,
1.41226756, 1.43511575, 1.45796393, 1.48081211, 1.5036603 ]),
np.array([0.72858483, 0.67359292, 0.82297195, 0.63363984, 0.77135652,
0.76180939, 0.90148338, 0.89768836, 0.86046132, 0.83385383,
0.85758401, 1.08826333, 1.0427012 , 0.95097736, 1.0332322 ,
1.03096033, 0.8713267 , 1.02777476, 1.05348153, 0.93208115,
1.01402438, 1.15750482, 0.97343204, 1.02404006, 1.08062312,
0.9599618 , 0.95471978, 1.10645223, 1.00998633, 1.07231987,
1.18601178, 0.96449302, 1.15962812, 1.03164456, 1.12568525,
1.21294478, 1.33669226, 1.04406794, 1.31982016, 1.17011975,
1.03055596, 1.14435503, 1.0091734 , 1.16668654, 1.04584704,
1.0114161 , 0.9956996 , 1.06870349, 1.08429899, 1.07615535]),
1.1295611710770421
),
(
1.4133807809076466,
4.169346250620812,
np.array([1.41338078, 1.46962497, 1.52586917, 1.58211336, 1.63835755,
1.69460175, 1.75084594, 1.80709013, 1.86333433, 1.91957852,
1.97582271, 2.03206691, 2.0883111 , 2.14455529, 2.20079949,
2.25704368, 2.31328787, 2.36953207, 2.42577626, 2.48202045,
2.53826465, 2.59450884, 2.65075303, 2.70699723, 2.76324142,
2.81948561, 2.87572981, 2.931974 , 2.98821819, 3.04446239,
3.10070658, 3.15695077, 3.21319497, 3.26943916, 3.32568335,
3.38192754, 3.43817174, 3.49441593, 3.55066012, 3.60690432,
3.66314851, 3.7193927 , 3.7756369 , 3.83188109, 3.88812528,
3.94436948, 4.00061367, 4.05685786, 4.11310206, 4.16934625]),
np.array([0.40490359, 0.49346239, 0.56676352, 0.34671098, 0.63900495,
0.57105849, 0.76102817, 0.70949478, 0.59719803, 0.44417089,
0.64972368, 0.75127685, 0.76022228, 0.65559522, 0.78397685,
0.60962922, 0.9643949 , 0.70588868, 0.73050517, 0.91491863,
0.69345186, 0.54832673, 0.59202946, 0.74825668, 0.66918741,
0.67738221, 0.55447525, 0.44530864, 0.60779754, 0.50484359,
0.49748762, 0.25146808, 0.35824415, 0.48302387, 0.44771223,
0.33074683, 0.27129176, 0.32289134, 0.35352856, 0.25542049,
0.35930394, 0.12828533, 0.15433078, 0.10128796, 0.09494405,
0.13920905, 0.13143921, 0.21591585, 0.16965837, 0.16628598]),
2.3615375367860496
),
(
1.2902316540060546,
4.282784505767093,
np.array([1.29023165, 1.35130416, 1.41237667, 1.47344918, 1.53452168,
1.59559419, 1.6566667 , 1.7177392 , 1.77881171, 1.83988422,
1.90095673, 1.96202923, 2.02310174, 2.08417425, 2.14524675,
2.20631926, 2.26739177, 2.32846428, 2.38953678, 2.45060929,
2.5116818 , 2.5727543 , 2.63382681, 2.69489932, 2.75597183,
2.81704433, 2.87811684, 2.93918935, 3.00026186, 3.06133436,
3.12240687, 3.18347938, 3.24455188, 3.30562439, 3.3666969 ,
3.42776941, 3.48884191, 3.54991442, 3.61098693, 3.67205943,
3.73313194, 3.79420445, 3.85527696, 3.91634946, 3.97742197,
4.03849448, 4.09956698, 4.16063949, 4.221712 , 4.28278451]),
np.array([0.29819128, 0.28873117, 0.36489756, 0.32654122, 0.60308445,
0.41869311, 0.49883312, 0.69546354, 0.57889736, 0.57338874,
0.77547897, 0.68984688, 0.69679499, 0.89851936, 1.00112855,
0.81699379, 1.05764859, 1.13539563, 1.05908627, 1.10197011,
1.0461601 , 1.10926668, 1.15476113, 1.33131476, 1.31444651,
1.46318771, 1.36730327, 1.14982189, 1.30415708, 1.49533404,
1.39388937, 1.32538319, 1.10136551, 1.24927986, 1.41624876,
1.24152407, 1.24267505, 1.29650492, 1.41541085, 1.11778596,
0.92560006, 1.23188006, 1.12761123, 1.00538112, 0.90350357,
0.94416536, 0.65771739, 0.85870303, 0.75278262, 0.72614481]),
3.07119827624406
),
(
1.1949830489712425,
2.773843232037444,
np.array([1.19498305, 1.22720469, 1.25942632, 1.29164796, 1.32386959,
1.35609123, 1.38831287, 1.4205345 , 1.45275614, 1.48497778,
1.51719941, 1.54942105, 1.58164269, 1.61386432, 1.64608596,
1.67830759, 1.71052923, 1.74275087, 1.7749725 , 1.80719414,
1.83941578, 1.87163741, 1.90385905, 1.93608069, 1.96830232,
2.00052396, 2.0327456 , 2.06496723, 2.09718887, 2.1294105 ,
2.16163214, 2.19385378, 2.22607541, 2.25829705, 2.29051869,
2.32274032, 2.35496196, 2.3871836 , 2.41940523, 2.45162687,
2.4838485 , 2.51607014, 2.54829178, 2.58051341, 2.61273505,
2.64495669, 2.67717832, 2.70939996, 2.7416216 , 2.77384323]),
np.array([0.3793393 , 0.37542424, 0.46820038, 0.28045481, 0.55225387,
0.52285749, 0.55759825, 0.55541301, 0.54273424, 0.58899492,
0.54873157, 0.67284187, 0.52399589, 0.82977319, 0.66014881,
0.8500501 , 0.52174529, 0.69882554, 0.77476359, 0.8086202 ,
0.57787331, 0.7360081 , 0.91274542, 0.73080065, 0.81392503,
0.81867983, 0.96940677, 0.48918518, 0.78240487, 0.85844144,
0.68085878, 0.80973961, 0.753002 , 0.9047564 , 0.73421552,
0.71009024, 0.73003265, 0.71415818, 0.60599512, 0.6444571 ,
0.69929698, 0.75010118, 0.66448182, 0.75237484, 0.59096271,
0.6745164 , 0.53312755, 0.53416456, 0.64552088, 0.50352999]),
2.08256212117585
),
(
1.4403744802547955,
2.6964897969511976,
np.array([1.44037448, 1.46600949, 1.49164449, 1.5172795 , 1.54291451,
1.56854951, 1.59418452, 1.61981953, 1.64545453, 1.67108954,
1.69672454, 1.72235955, 1.74799456, 1.77362956, 1.79926457,
1.82489958, 1.85053458, 1.87616959, 1.9018046 , 1.9274396 ,
1.95307461, 1.97870962, 2.00434462, 2.02997963, 2.05561464,
2.08124964, 2.10688465, 2.13251965, 2.15815466, 2.18378967,
2.20942467, 2.23505968, 2.26069469, 2.28632969, 2.3119647 ,
2.33759971, 2.36323471, 2.38886972, 2.41450473, 2.44013973,
2.46577474, 2.49140975, 2.51704475, 2.54267976, 2.56831476,
2.59394977, 2.61958478, 2.64521978, 2.67085479, 2.6964898 ]),
np.array([0.47170882, 0.36332812, 0.4591127 , 0.44715866, 0.68120332,
0.71148963, 0.40955849, 0.53708929, 0.53962812, 0.57249441,
0.52033561, 0.6977331 , 0.44933004, 0.54403723, 0.55945913,
0.50184738, 0.52224838, 0.64708654, 0.56455093, 0.53884591,
0.8296299 , 0.61635789, 0.34852081, 0.44902487, 0.72934694,
0.6737639 , 0.59239064, 0.51771012, 0.55030474, 0.50338181,
0.51388977, 0.52323197, 0.53205544, 0.47382758, 0.31454426,
0.46451816, 0.44031419, 0.4877412 , 0.33020709, 0.46063573,
0.42127234, 0.34122936, 0.35103346, 0.21101455, 0.51129454,
0.41320748, 0.35416231, 0.36201248, 0.36462864, 0.15323838]),
1.8675866174351226
),
(
0.3570819947603825,
3.2434619540053555,
np.array([0.35708199, 0.41598771, 0.47489342, 0.53379914, 0.59270485,
0.65161056, 0.71051628, 0.76942199, 0.8283277 , 0.88723342,
0.94613913, 1.00504484, 1.06395056, 1.12285627, 1.18176198,
1.2406677 , 1.29957341, 1.35847912, 1.41738484, 1.47629055,
1.53519626, 1.59410198, 1.65300769, 1.7119134 , 1.77081912,
1.82972483, 1.88863054, 1.94753626, 2.00644197, 2.06534768,
2.1242534 , 2.18315911, 2.24206483, 2.30097054, 2.35987625,
2.41878197, 2.47768768, 2.53659339, 2.59549911, 2.65440482,
2.71331053, 2.77221625, 2.83112196, 2.89002767, 2.94893339,
3.0078391 , 3.06674481, 3.12565053, 3.18455624, 3.24346195]),
np.array([-0.09290566, 0.00982314, 0.13646142, 0.11243038, 0.11072218,
-0.01414125, -0.14944912, 0.09414059, 0.14374925, 0.19779021,
0.03508091, 0.34100795, 0.27020969, 0.15945921, 0.38725372,
0.37486682, 0.22435582, 0.23684771, 0.24665014, 0.18246121,
0.34818399, 0.40012908, 0.33654425, 0.45042927, 0.4939493 ,
0.57791053, 0.51398335, 0.53698197, 0.50616079, 0.56413169,
0.49961171, 0.79000434, 0.73900315, 0.79937715, 0.76483373,
0.57755459, 0.43027561, 0.65268296, 0.51629475, 0.71783169,
0.86789085, 0.48635384, 0.61860143, 0.7174154 , 0.66789158,
0.6342673 , 0.53464439, 0.34855856, 0.47742571, 0.31660777]),
2.466586289050439
),
(
0.4616234662945171,
2.9429380148134685,
np.array([0.46162347, 0.51226254, 0.56290161, 0.61354068, 0.66417976,
0.71481883, 0.7654579 , 0.81609697, 0.86673605, 0.91737512,
0.96801419, 1.01865326, 1.06929234, 1.11993141, 1.17057048,
1.22120955, 1.27184862, 1.3224877 , 1.37312677, 1.42376584,
1.47440491, 1.52504399, 1.57568306, 1.62632213, 1.6769612 ,
1.72760028, 1.77823935, 1.82887842, 1.87951749, 1.93015657,
1.98079564, 2.03143471, 2.08207378, 2.13271286, 2.18335193,
2.233991 , 2.28463007, 2.33526915, 2.38590822, 2.43654729,
2.48718636, 2.53782544, 2.58846451, 2.63910358, 2.68974265,
2.74038173, 2.7910208 , 2.84165987, 2.89229894, 2.94293801]),
np.array([0.71467699, 0.77899176, 1.04415426, 0.93498513, 0.92074146,
0.98773594, 0.95654816, 1.30092409, 1.28703556, 1.15093418,
1.24532464, 1.35484699, 1.29969977, 1.52840986, 1.47977026,
1.60490531, 1.35572502, 1.32088853, 1.5979054 , 1.36126115,
1.51240815, 1.60152903, 1.53489047, 1.26693377, 1.45913383,
1.49810754, 1.20375782, 1.3216294 , 1.43141034, 1.31127146,
1.39324891, 1.39255155, 1.29547 , 1.2077037 , 1.0778348 ,
1.11899009, 0.90777015, 1.06863381, 0.94917329, 0.91206529,
0.79605204, 0.650758 , 0.7124014 , 0.68123873, 0.55190608,
0.52631777, 0.46615908, 0.50541683, 0.23616255, 0.47073758]),
1.525737079953606
)
]
)
def test_regression_based_on_data(l, u, a, y, best_action):
"""
Test cases for exercise 2.1
"""
# restrict relative deviation
assert abs(best_action - regression_based_on_data(l, u, a, y)) < 0.1 * (u-l)
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment