Skip to content
Snippets Groups Projects
Commit 52e775a6 authored by Vincent Wall's avatar Vincent Wall
Browse files

demo for LNDW2022 (shows class images)

parent 814491b7
No related branches found
No related tags found
No related merge requests found
...@@ -32,7 +32,7 @@ from A_record import MODEL_NAME ...@@ -32,7 +32,7 @@ from A_record import MODEL_NAME
# USER SETTINGS # USER SETTINGS
# ================== # ==================
BASE_DIR = "." BASE_DIR = "."
SENSORMODEL_FILENAME = "sensor_model.pkl" SENSORMODEL_FILENAME = "sensor_model_lndw2022_AAS016.pkl"
TEST_SIZE = 0 # percentage of samples left out of training and used for reporting test score TEST_SIZE = 0 # percentage of samples left out of training and used for reporting test score
SHOW_PLOTS = True SHOW_PLOTS = True
# ================== # ==================
...@@ -122,6 +122,14 @@ def plot_spectra(spectra, labels): ...@@ -122,6 +122,14 @@ def plot_spectra(spectra, labels):
fig.show() fig.show()
def plot_cm(y_true, y_pred, classes, title="Confusion Matrix"):
import os, sys
sys.path.insert(0, os.path.realpath("../Acoustic_sensing/"))
from plot import plot_confusion_matrix
from sklearn.metrics import confusion_matrix
plot_confusion_matrix(confusion_matrix(y_true, y_pred), classes, title)
def main(): def main():
print("Running for model '{}'".format(MODEL_NAME)) print("Running for model '{}'".format(MODEL_NAME))
global DATA_DIR global DATA_DIR
...@@ -146,9 +154,14 @@ def main(): ...@@ -146,9 +154,14 @@ def main():
print("Fitted sensor model to data!") print("Fitted sensor model to data!")
print("Training score: {:.2f}".format(train_score)) print("Training score: {:.2f}".format(train_score))
if SHOW_PLOTS:
plot_cm(y_train, clf.predict(X_train), clf.classes_, "CM Train")
if TEST_SIZE > 0: if TEST_SIZE > 0:
test_score = clf.score(X_test, y_test) test_score = clf.score(X_test, y_test)
print("Test score: {:.2f}".format(test_score)) print("Test score: {:.2f}".format(test_score))
if SHOW_PLOTS:
plot_cm(y_test, clf.predict(X_test), clf.classes_, "CM Test")
save_sensor_model(DATA_DIR, clf, SENSORMODEL_FILENAME) save_sensor_model(DATA_DIR, clf, SENSORMODEL_FILENAME)
print("\nSaved model to '{}'".format(os.path.join(DATA_DIR, SENSORMODEL_FILENAME))) print("\nSaved model to '{}'".format(os.path.join(DATA_DIR, SENSORMODEL_FILENAME)))
......
...@@ -28,6 +28,7 @@ from sklearn.neighbors import KNeighborsClassifier ...@@ -28,6 +28,7 @@ from sklearn.neighbors import KNeighborsClassifier
from jacktools.jacksignal import JackSignal from jacktools.jacksignal import JackSignal
import matplotlib.pyplot as plt import matplotlib.pyplot as plt
from matplotlib.widgets import Button from matplotlib.widgets import Button
import matplotlib.image as mpimg
from A_record import MODEL_NAME from A_record import MODEL_NAME
from B_train import SENSORMODEL_FILENAME from B_train import SENSORMODEL_FILENAME
...@@ -38,8 +39,18 @@ from B_train import sound_to_spectrum, sound_to_spectrum_stft ...@@ -38,8 +39,18 @@ from B_train import sound_to_spectrum, sound_to_spectrum_stft
# ================== # ==================
BASE_DIR = "." BASE_DIR = "."
CONTINUOUSLY = True # chose between continuous sensing or manually triggered CONTINUOUSLY = True # chose between continuous sensing or manually triggered
SHOW_CLASS_IMAGES = True
# ================== # ==================
label_renamer_lndw2022 = {
"tip": "Vorne",
"middle": "Mitte",
"base": "Basis",
"back": "Rückseite",
"none": "Ohne"
}
label_renamer = label_renamer_lndw2022
CHANNELS = 1 CHANNELS = 1
SR = 48000 SR = 48000
...@@ -86,6 +97,10 @@ class LiveAcousticSensor(object): ...@@ -86,6 +97,10 @@ class LiveAcousticSensor(object):
self.clf = pickle.load(f) self.clf = pickle.load(f)
print(self.clf.classes_) print(self.clf.classes_)
self.class_imgs = {}
for cl in self.clf.classes_:
self.class_imgs[cl] = mpimg.imread(f"img/{cl}.jpg")
def setup_window(self): def setup_window(self):
f = plt.figure(1) f = plt.figure(1)
...@@ -103,13 +118,21 @@ class LiveAcousticSensor(object): ...@@ -103,13 +118,21 @@ class LiveAcousticSensor(object):
self.spectrumlines, = ax2.plot(sound_to_spectrum_stft(self.Ains[0])) self.spectrumlines, = ax2.plot(sound_to_spectrum_stft(self.Ains[0]))
ax2.set_ylim([0, 250]) ax2.set_ylim([0, 250])
ax3 = f.add_subplot(2, 1, 2) ax3a = f.add_subplot(2, 2, 3)
ax3.text(0.0, 0.8, "Sensing result:", dict(size=40)) ax3a.text(0.0, 0.8, "Gemessener Kontakt:", dict(size=40))
self.predictiontext = ax3.text(0.25, 0.25, "", dict(size=70)) self.predictiontext = ax3a.text(0.25, 0.25, "", dict(size=70))
ax3.set_xticklabels([]) ax3a.set_xticklabels([])
ax3.set_yticklabels([]) ax3a.set_yticklabels([])
# ax3.set_title("Contact location") # ax3.set_title("Contact location")
ax3.axis('off') ax3a.axis('off')
ax3b = f.add_subplot(2, 2, 4)
ax3b.set_xticklabels([])
ax3b.set_yticklabels([])
ax3b.axis('off')
self.class_img = ax3b.imshow(numpy.zeros((217, 169), float))
ax_pause = plt.axes([0.91, 0.025, 0.05, 0.075]) ax_pause = plt.axes([0.91, 0.025, 0.05, 0.075])
self.b_pause = Button(ax_pause, '[P]ause') self.b_pause = Button(ax_pause, '[P]ause')
...@@ -128,7 +151,13 @@ class LiveAcousticSensor(object): ...@@ -128,7 +151,13 @@ class LiveAcousticSensor(object):
self.wavelines.set_ydata(self.Ains[0].reshape(-1)) self.wavelines.set_ydata(self.Ains[0].reshape(-1))
self.spectrumlines.set_ydata(spectrum) self.spectrumlines.set_ydata(spectrum)
self.predictiontext.set_text(prediction[0]) if prediction[0] in label_renamer:
self.predictiontext.set_text(label_renamer[prediction[0]])
else:
self.predictiontext.set_text(prediction[0])
if SHOW_CLASS_IMAGES:
self.class_img.set_data(self.class_imgs[prediction[0]])
plt.draw() plt.draw()
plt.pause(0.00001) plt.pause(0.00001)
......
img/back.jpg

544 KiB

img/base.jpg

475 KiB

img/middle.jpg

493 KiB

img/none.jpg

693 KiB

img/tip.jpg

643 KiB

File added
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment