Skip to content
Snippets Groups Projects
A_record.py 6.43 KiB
Newer Older
#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""
Example script for the "Acoustic Sensing Starter Kit"
[Zöller, Gabriel, Vincent Wall, and Oliver Brock. “Active Acoustic Contact Sensing for Soft Pneumatic Actuators.” In Proceedings of the IEEE International Conference on Robotics and Automation (ICRA). IEEE, 2020.]

This script _records_ data samples for different classes, e.g. contact locations.

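In 'USER SETTINGS' define:
BASE_DIR - path where data should be stored.
SOUND_NAME - type of active sound to use. Choose from SOUNDS or create your own.
CLASS_LABELS - labels of the different prediction classes, e.g. contact locations.
SAMPLES_PER_CLASS - how many samples to record per class.
MODEL_NAME - name of the model. It is used as the name of the data folder inside BASE_DIR.
SHUFFLE_RECORDING_ORDER - whether or not to randomize the recording order.
APPEND_TO_EXISTING_FILES - if True, continue numbering after the highest-numbered .wav file already in the model folder instead of starting again at 1.
SR - sample rate (Hz) of the generated sounds and of the stored .wav files.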

Before running the script, make sure to start QjackCtl.

@author: Vincent Wall, Gabriel Zöller
@copyright 2020 Robotics and Biology Lab, TU Berlin
@licence: BSD Licence
"""
import os
import random
from glob import glob

import librosa
import numpy
import scipy.io.wavfile
from matplotlib import pyplot
from matplotlib.widgets import Button
from jacktools.jacksignal import JackSignal

# ==================
# USER SETTINGS
# ==================
BASE_DIR = "."
SOUND_NAME = "sweep"  # sound to use
CLASS_LABELS = ["back", "left", "right", "none"]  # classes to train
SAMPLES_PER_CLASS = 10
MODEL_NAME = "berlinsummit_plus_sweep_1s"
SHUFFLE_RECORDING_ORDER = False
APPEND_TO_EXISTING_FILES = True
SR = 48000  # sample rate (Hz) of the generated sounds and of the written .wav files

# Number of JACK output/input port pairs created in setup_jack(). Only channel 0 is
# plotted and stored, so one channel is enough. (This name was referenced but never defined.)
CHANNELS = 1

# Example sounds
# The microphone has about 0.15 s delay in recording the sound. The input buffers in
# setup_jack() have the same length as the played sound, so this trailing silence
# lets the delayed recording still capture the end of the sound.
RECORDING_DELAY_SILENCE = numpy.zeros(int(SR * 0.15), dtype='float32')
SOUNDS = {
    # chirp() arguments are passed by keyword: librosa >= 0.10 makes them keyword-only,
    # and keywords also work on older librosa versions.
    "sweep":            numpy.hstack([librosa.chirp(fmin=20, fmax=20000, sr=SR, duration=1).astype('float32'),
                                      RECORDING_DELAY_SILENCE]),
    "white_noise":      numpy.hstack([numpy.random.uniform(low=-0.999, high=1.0, size=(SR)).astype('float32'),
                                      RECORDING_DELAY_SILENCE]),
    "silence":          numpy.hstack([numpy.zeros((SR,), dtype='float32'), RECORDING_DELAY_SILENCE]),
}


def main():
    """Script entry point: report the configuration, then set up and run the recorder.

    Creates the folder BASE_DIR/MODEL_NAME and stores its path in the module global
    DATA_DIR. Then it builds the recording schedule, configures JACK and finally
    opens the matplotlib GUI.
    """
    global DATA_DIR

    print("Running for model '{}'".format(MODEL_NAME))
    print("Using sound: {}".format(SOUND_NAME))
    print("and classes: {}".format(CLASS_LABELS))

    # NOTE(review): there is no prompt to load vs. re-record existing data; reuse of
    # an existing folder is controlled only by APPEND_TO_EXISTING_FILES.
    DATA_DIR = mkpath(BASE_DIR, MODEL_NAME)

    setup_experiment()
    setup_jack(SOUND_NAME)
    setup_matplotlib()


def setup_experiment():
    """Build the recording schedule and choose the index of the first sample to record.

    Sets two module globals:
      * ``label_list`` - one label per sample to record, in recording order
        (CLASS_LABELS repeated SAMPLES_PER_CLASS times, optionally shuffled).
      * ``current_idx`` - index into ``label_list`` of the next sample to record.

    If APPEND_TO_EXISTING_FILES is set and DATA_DIR already holds .wav files named
    ``<index>_<label>.wav``, the schedule is prefixed with one empty placeholder per
    existing index. Numbering then continues after the highest index found.
    Files whose name does not start with an integer index are ignored.
    """
    global label_list
    global current_idx

    label_list = CLASS_LABELS * SAMPLES_PER_CLASS
    if SHUFFLE_RECORDING_ORDER:
        random.shuffle(label_list)
    current_idx = 0

    if APPEND_TO_EXISTING_FILES:
        existing_ids = []
        for wav_path in glob(os.path.join(DATA_DIR, "*.wav")):
            # basename() instead of split("/") so Windows path separators work too
            prefix = os.path.basename(wav_path).split("_")[0]
            try:
                existing_ids.append(int(prefix))
            except ValueError:
                # not part of the "<index>_<label>.wav" naming scheme -> ignore
                continue
        if existing_ids:
            max_id = max(existing_ids)
            label_list = [""] * max_id + label_list
            current_idx = max_id

def setup_jack(sound_name):
    """Create the JACK client, connect its ports and load the active sound.

    One output/input port pair is created per channel (``CHANNELS``). Output ``i``
    is connected to ``system:playback_{i+1}`` and input ``i`` to
    ``system:capture_{i+1}``. Every output is given ``SOUNDS[sound_name]``. Every
    input is given a zeroed float32 buffer of the same length. The active sound is
    also written to ``DATA_DIR/0_<sound_name>.wav`` for reference.

    Sets the module globals ``J`` (the JackSignal client) and ``Ains`` (input buffers).

    Args:
        sound_name: key into SOUNDS.

    Returns:
        Tuple ``(J, Aouts, Ains)``: the client, the output buffers and the input buffers.

    Raises:
        KeyError: if ``sound_name`` is not a key of SOUNDS (checked before JACK is touched).
        RuntimeError: if the JackSignal client reports a negative state.
    """
    global J
    global Ains

    # Look the sound up first so an unknown name fails before any JACK client is created.
    sound = SOUNDS[sound_name]

    J = JackSignal("JS")
    state = J.get_state()
    print(state)
    # Explicit raise instead of `assert`, which is stripped when Python runs with -O.
    if state < 0:
        raise RuntimeError("Creating JackSignal failed.")

    # NOTE(review): the second value of get_jack_info() is assumed to be the JACK sample
    # rate (the original unpacked it as `sr`). A mismatch with SR would make the stored
    # .wav headers wrong, so only warn about it.
    name, sr, period = J.get_jack_info()
    if sr != SR:
        print("Warning: JACK reports a sample rate of {} Hz but SR is {} Hz.".format(sr, SR))

    for ch in range(CHANNELS):
        J.create_output(ch, "out_{}".format(ch))
        J.create_input(ch, "in_{}".format(ch))
        J.connect_input(ch, "system:capture_{}".format(ch + 1))
        J.connect_output(ch, "system:playback_{}".format(ch + 1))
    J.silence()

    Aouts = [sound] * CHANNELS
    Ains = [numpy.zeros_like(sound, dtype=numpy.float32) for __ in range(CHANNELS)]
    for ch, (out_buf, in_buf) in enumerate(zip(Aouts, Ains)):
        J.set_output_data(ch, out_buf)
        J.set_input_data(ch, in_buf)

    # store active sound for reference
    sound_file = os.path.join(DATA_DIR, "{}_{}.wav".format(0, sound_name))
    scipy.io.wavfile.write(sound_file, SR, sound)
    return J, Aouts, Ains


def setup_matplotlib():
    """Build the recording GUI and hand control to ``pyplot.show()``.

    The figure plots the first input buffer (``Ains[0]``) with a fixed y-range of
    [-1, 1]. It has two buttons: "[B]ack", which calls back(), and "[R]ecord",
    which calls record(). Pressing "b" or "r" does the same via on_key(). The title
    shows the current position in the recording schedule.

    Sets the module globals ``LINES`` (the plotted line), ``TITLE`` (the title text)
    and ``b_rec`` / ``b_back`` (the buttons).
    """
    global LINES
    global TITLE
    global b_rec
    # Both buttons are kept referenced at module level: matplotlib widgets stop
    # responding once they are garbage-collected, e.g. if show() does not block.
    global b_back
    fig, ax = pyplot.subplots(1)
    ax.set_ylim(-1, 1)
    pyplot.subplots_adjust(bottom=0.2)  # leave room below the axes for the buttons
    LINES, = ax.plot(Ains[0])
    ax_back = pyplot.axes([0.59, 0.05, 0.1, 0.075])
    b_back = Button(ax_back, '[B]ack')
    b_back.on_clicked(back)
    ax_rec = pyplot.axes([0.81, 0.05, 0.1, 0.075])
    b_rec = Button(ax_rec, '[R]ecord')
    b_rec.on_clicked(record)
    fig.canvas.mpl_connect('key_press_event', on_key)
    TITLE = ax.set_title(get_current_title())
    pyplot.show()


def on_key(event):
    """Keyboard shortcuts for the GUI: "b" steps back, "r" records; other keys do nothing."""
    pressed = event.key
    if pressed == "b":
        back(event)
    elif pressed == "r":
        record(event)


def l(i):
    """Return the label at index ``i`` of ``label_list``, or "" if there is none.

    Out-of-range indices yield "" instead of wrapping around to the end of the list.
    This includes negative ones, e.g. the "previous" label while at index 0.
    """
    if 0 <= i < len(label_list):
        return label_list[i]
    return ""


def get_current_title():
    """Build the three-line figure title for the current recording position.

    Line 1: the model name, with underscores shown as spaces.
    Line 2: the previous, current (in brackets) and next labels.
    Line 3: a 1-based sample counter "#k/N" and the current label, with "DONE!"
    appended once every scheduled sample has been recorded.
    """
    total = len(label_list)
    name = "Model: {}".format(MODEL_NAME.replace("_", " "))
    labels = "previous: {}   current: [{}]   next: {}".format(l(current_idx-1), l(current_idx), l(current_idx+1))
    # When finished, current_idx == total; clamp so the counter reads "#N/N", not "#N+1/N".
    number = "#{}/{}: {}".format(min(current_idx + 1, total), total, l(current_idx))
    if current_idx >= total:
        number += "DONE!"
    return "{}\n{}\n{}".format(name, labels, number)


def back(event):
    """Step one position back in the recording schedule (never below 0) and refresh the title.

    ``event`` is the triggering matplotlib event and is not used.
    """
    global current_idx
    if current_idx > 0:
        current_idx -= 1
    else:
        current_idx = 0
    update()


def record(event):
    """Record one sample for the current label and advance to the next one.

    Once every scheduled sample has been recorded, this only prints a message.
    Otherwise it runs one JACK cycle (J.process(), then J.wait()), plots the
    channel-0 buffer, saves it with store(), moves to the next index and refreshes
    the title. ``event`` is the triggering matplotlib event and is not used.
    """
    global current_idx
    total = len(label_list)
    if current_idx >= total:
        print("current_idx: {}  >= len(label_list): {}".format(current_idx, total))
        return

    # The user touches the object now; process()/wait() play the active sound and
    # fill the input buffers registered in setup_jack().
    J.process()
    J.wait()

    captured = Ains[0]
    LINES.set_ydata(captured.reshape(-1))
    store()

    current_idx = current_idx + 1
    update()


def store():
    """Write the channel-0 recording to DATA_DIR as ``<current_idx + 1>_<label>.wav`` at SR Hz."""
    label = l(current_idx)
    file_name = "{}_{}.wav".format(current_idx + 1, label)
    target = os.path.join(DATA_DIR, file_name)
    scipy.io.wavfile.write(target, SR, Ains[0])


def mkpath(*args):
    """Join path parts, make sure the containing directory exists, and return the joined path.

    If the joined path has a file extension, it is treated as a file and only its
    parent directory is created; the file itself is never created. Otherwise the
    path itself is created as a directory. A bare file name (no directory part)
    needs no directory.

    NOTE: the file/directory decision uses os.path.splitext, so a directory name
    containing a dot (e.g. "run_1.5s") is mistaken for a file and not created.

    Example:
        figure_path = mkpath(PLOT_DIR, "experiment", "figure.svg")

    Args:
        *args: path components passed to ``os.path.join``.

    Returns:
        The joined path.
    """
    path = os.path.join(*args)
    if os.path.splitext(path)[1]:  # path has a file extension -> only create its parent
        base_path = os.path.dirname(path)
    else:
        base_path = path
    # An empty base_path would make os.makedirs("") raise. exist_ok guards against the
    # directory appearing between the exists() check and makedirs().
    if base_path and not os.path.exists(base_path):
        os.makedirs(base_path, exist_ok=True)
    return path


def update():
    """Refresh the figure title for the current position and redraw the figure."""
    new_title = get_current_title()
    TITLE.set_text(new_title)
    pyplot.draw()


# Run the recorder only when this file is executed as a script, not when imported.
if __name__ == "__main__":
    main()