Commit c0b5bdc5 authored by shoefer

implemented new iterator for pairwise similarity

parent 7f383994
@@ -55,21 +55,24 @@ def load_dataset(tr_data, tr_data_url, test_data, test_data_url):
X_valid = npz_train['X_valid']
Y_valid = np.cast['int32'](npz_train['Y_valid'])
#C_valid = npz_train['C_valid']
C_valid = npz_train['C_valid']
X_test = npz_test['X_test']
Y_test = np.cast['int32'](npz_test['Y_test'])
#C_test = npz_train['C_test']
C_test = npz_test['C_test']
return (npz_train, npz_test),\
(X, Y, C, X_valid, Y_valid, X_test, Y_test)
(X, Y, C, X_valid, Y_valid, C_valid, X_test, Y_test, C_test)
def load_direct_context_dataset():
tr_data = "cl_synth_direct_d-50_e-0_n-500_seed-12340.npz"
tr_data_url = ""
test_data = "cl_synth_direct_d-50_e-0_ntest-50000_seed-12340.npz"
test_data_url = ""
return load_dataset(tr_data, tr_data_url, test_data, test_data_url)[1]
_, res = load_dataset(tr_data, tr_data_url, test_data, test_data_url)
(X, Y, C, X_valid, Y_valid, C_valid, X_test, Y_test, C_test) = res
return (X, Y, C, X_valid, Y_valid, X_test, Y_test)
def load_embedding_context_dataset():
tr_data = "cl_synth_embedding_d-50_e-25_n-500_seed-12340.npz"
@@ -83,7 +86,8 @@ def load_embedding_context_dataset():
# sanity check
assert (np.mean (npz_train['Q'] - npz_test['Q']) == 0.)
return res
(X, Y, C, X_valid, Y_valid, C_valid, X_test, Y_test, C_test) = res
return (X, Y, C, X_valid, Y_valid, X_test, Y_test)
def load_relative_context_dataset():
tr_data = "cl_synth_relative_d-50_e-0_n-500_seed-12340.npz"
@@ -91,7 +95,7 @@ def load_relative_context_dataset():
test_data = "cl_synth_relative_d-50_e-0_ntest-50000_seed-12340.npz"
test_data_url = ""
(_,_), (X, Y, C, X_valid, Y_valid, X_test, Y_test) \
(_,_), (X, Y, C, X_valid, Y_valid, C_valid, X_test, Y_test, C_test) \
= load_dataset(tr_data, tr_data_url, test_data, test_data_url)
# the context training data C contains stacked "x_j" and "y_ij"
@@ -102,7 +106,26 @@ def load_relative_context_dataset():
return X, Y, CX, CY, X_valid, Y_valid, X_test, Y_test
def load_similarity_context_dataset():
return load_relative_context_dataset()
tr_data = "cl_synth_relative_d-50_e-0_n-500_seed-12340.npz"
tr_data_url = ""
test_data = "cl_synth_relative_d-50_e-0_ntest-50000_seed-12340.npz"
test_data_url = ""
(_,_), (X, Y, C, X_valid, Y_valid, C_valid, X_test, Y_test, C_test) \
= load_dataset(tr_data, tr_data_url, test_data, test_data_url)
# the context training data C contains stacked "x_j" and "y_ij"
# which are aligned with the x_i in matrix X
CX = C[:, :X.shape[1]]
CY = C[:, X.shape[1]:]
CX_valid = C_valid[:, :X.shape[1]]
CY_valid = C_valid[:, X.shape[1]:]
CX_test = C_test[:, :X.shape[1]]
CY_test = C_test[:, X.shape[1]:]
return X, Y, CX, CY, X_valid, Y_valid, CX_valid, CY_valid, X_test, Y_test, CX_test, CY_test
# ############################# Helper functions #################################
def build_linear_simple(input_layer, n_out, nonlinearity=None, name=None):
@@ -244,7 +267,9 @@ def iterate_pairwise_transformation_aligned_minibatches(inputs, targets, batchsi
def build_pw_similarity_pattern(input_var, target_var,
context_sim_i, context_sim_j, context_dissim_i, context_dissim_j,
n, m, num_classes):
n, m, num_classes,
sim_weight=0.5, dissim_weight=0.5, min_margin=1.
):
input_layer = lasagne.layers.InputLayer(shape=(None, n),
input_var=input_var)
phi = build_linear_simple( input_layer, m, name="phi")
@@ -258,11 +283,13 @@ def build_pw_similarity_pattern(input_var, target_var,
context_dissim_i=context_dissim_i,
context_dissim_j=context_dissim_j,
sim_dissim_type='margin',
sim_weight=sim_weight,
dissim_weight=dissim_weight,
min_margin=min_margin,
)
return psp
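# For intuition: a minimal numpy sketch of a margin-based pairwise similarity
# loss of the kind selected above via sim_dissim_type='margin'. This is an
# assumption about the pattern's internals, not its actual implementation
# (numpy is already imported as np at the top of this file).
def _example_margin_similarity_loss(s_i, s_j, d_i, d_j,
                                    sim_weight=0.5, dissim_weight=0.5,
                                    min_margin=1.):
    # Euclidean distances between representations of similar / dissimilar pairs
    dist_sim = np.sqrt(np.sum((s_i - s_j) ** 2, axis=1))
    dist_dissim = np.sqrt(np.sum((d_i - d_j) ** 2, axis=1))
    # similar pairs are pulled together (squared distance), dissimilar pairs
    # are pushed apart until they are at least min_margin away (hinge)
    sim_term = np.mean(dist_sim ** 2)
    dissim_term = np.mean(np.maximum(0., min_margin - dist_dissim) ** 2)
    return sim_weight * sim_term + dissim_weight * dissim_term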
#def iterate_pairwise_ordered_minibatches(inputs, inputs_for_context, targets, targets_for_context, batchsize, shuffle=False):
def iterate_pairwise_ordered_minibatches(inputs, targets, inputs_for_context, batchsize, shuffle=False):
""" Iterator for pairwise similarity pattern, assuming input_for_context
to contain an ordered sequence of input samples; subsequent samples
@@ -296,18 +323,70 @@ def iterate_pairwise_ordered_minibatches(inputs, targets, inputs_for_context, ba
excerpt_consec2 = indices_consec2[start_idx:start_idx + batchsize]
yield inputs[excerpt], targets[excerpt], \
inputs_for_context[excerpt_rnd1], inputs_for_context[excerpt_rnd2], \
inputs_for_context[excerpt_consec1], inputs_for_context[excerpt_consec2],
inputs_for_context[excerpt_consec1], inputs_for_context[excerpt_consec2], \
inputs_for_context[excerpt_rnd1], inputs_for_context[excerpt_rnd2],
def iterate_pairwise_binary_target_minibatches(inputs, targets, inputs_for_context, targets_for_context, batchsize, shuffle=False):
""" Iterator for pairwise similarity pattern.
Here we use some binary label information (targets_for_context) for
distinguishing. This can be easily adapted for non-binary targets.
Note that inputs and inputs_for_context can be identical - we separate
them to give more flexibility, in case few labels (targets) are
available.
"""
assert (len(inputs) == len(targets))
indices = np.arange(len(inputs))
assert (len(inputs_for_context) == len(targets_for_context))
indices_lbl0 = np.where(targets_for_context == 0)[0]
indices_lbl1 = np.where(targets_for_context == 1)[0]
# split each label group in half to form same-label (similar) pairs;
# cast to int since numpy slice indices must be integers
indices_consec1 = np.concatenate([indices_lbl0[:int(np.floor(len(indices_lbl0)/2))],
indices_lbl1[:int(np.floor(len(indices_lbl1)/2))] ])
indices_consec2 = np.concatenate([indices_lbl0[int(np.ceil(len(indices_lbl0)/2)):],
indices_lbl1[int(np.ceil(len(indices_lbl1)/2)):] ])
# indices_consec1, indices_consec2, = map(np.asarray, zip(*np.random.permutation(
# zip(indices_consec1, indices_consec2))))
np.random.shuffle(indices_lbl0)
np.random.shuffle(indices_lbl1)
len_min = np.min( [ len(indices_lbl0), len(indices_lbl1) ])
indices_rnd1 = indices_lbl0[:len_min]
indices_rnd2 = indices_lbl1[:len_min]
if shuffle:
np.random.shuffle(indices)
for start_idx in range(0, len(inputs) - batchsize + 1, batchsize):
excerpt = indices[start_idx:start_idx + batchsize]
excerpt_rnd1 = indices_rnd1[start_idx:start_idx + batchsize]
excerpt_rnd2 = indices_rnd2[start_idx:start_idx + batchsize]
excerpt_consec1 = indices_consec1[start_idx:start_idx + batchsize]
excerpt_consec2 = indices_consec2[start_idx:start_idx + batchsize]
# print ("-----------")
# print (excerpt_rnd1, excerpt_rnd2, excerpt_consec1, excerpt_consec2)
# print ( map(lambda x: targets_for_context[x[0]], (excerpt_rnd1, excerpt_rnd2, excerpt_consec1, excerpt_consec2)))
yield inputs[excerpt], targets[excerpt], \
inputs_for_context[excerpt_consec1], inputs_for_context[excerpt_consec2], \
inputs_for_context[excerpt_rnd1], inputs_for_context[excerpt_rnd2],
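# Hedged usage sketch for the two pairwise iterators above (illustrative
# only, not used by main() below): both yield six arrays per minibatch, in
# the order expected by train_fn_inputs in main(): inputs, targets, the
# similar pair (i, j) and the dissimilar pair (i, j). `train_fn` is an
# assumed name for a compiled Theano training function taking these inputs.
def _example_pairwise_training_loop(train_fn, X_train, y_train,
                                    X_context, y_context, batchsize=50):
    for batch in iterate_pairwise_binary_target_minibatches(
            X_train, y_train, X_context, y_context, batchsize, shuffle=True):
        inputs, targets, sim_i, sim_j, dissim_i, dissim_j = batch
        train_fn(inputs, targets, sim_i, sim_j, dissim_i, dissim_j)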
# ########################## Main ###############################
#def main(pattern_type, data, num_epochs=500, batchsize=50):
if __name__ == "__main__":
pattern_type="pairwise_similarity"
data='relative'
num_epochs=500
batchsize=50
def main(pattern_type, data, num_epochs=500, batchsize=50):
#if __name__ == "__main__":
# pattern_type="pairwise_similarity"
# data='relative'
# num_epochs=500
# batchsize=50
#theano.config.on_unused_input = 'ignore'
print ("Pattern: %s" % pattern_type)
@@ -343,12 +422,12 @@ if __name__ == "__main__":
print("Loading embedding data...")
X_train, y_train, C_train, X_val, y_val, X_test, y_test = load_embedding_context_dataset()
# input dimension of X
n = X_train.shape[1]
# dimensionality of C
m = C_train.shape[1]
# dimensionality of intermediate representation S
d = 1
if pattern_type == "direct":
# d == m
@@ -397,7 +476,8 @@ if __name__ == "__main__":
elif pattern_type in ['pairwise_similarity']:
# Load the dataset
print("Loading similarity data...")
X_train, y_train, CX_train, Cy_train, X_val, y_val, X_test, y_test = load_similarity_context_dataset()
X_train, y_train, CX_train, Cy_train, X_val, y_val, CX_val, Cy_val,\
X_test, y_test, CX_test, Cy_test = load_similarity_context_dataset()
context_sim_i = T.matrix('context_sim_i')
context_sim_j = T.matrix('context_sim_j')
@@ -410,14 +490,22 @@ if __name__ == "__main__":
m = Cy_train.shape[1]
pattern = build_pw_similarity_pattern(input_var, target_var, \
context_sim_i, context_sim_j, context_dissim_i, context_dissim_j, n, m, num_classes)
context_sim_i, context_sim_j, context_dissim_i, context_dissim_j,
n, m, num_classes,
sim_weight=.5, dissim_weight=.5,
min_margin=5)
# using temporal coherence
iterate_context_minibatches = iterate_pairwise_ordered_minibatches
iterate_context_minibatches_args = (X_train, y_train, X_test, batchsize, True)
# using labels
# iterate_context_minibatches = iterate_pairwise_binary_target_minibatches
# iterate_context_minibatches_args = (X_train, y_train, X_test, y_test, batchsize, True)
train_fn_inputs = [input_var, target_var, \
context_sim_i, context_sim_j, context_dissim_i, context_dissim_j ]
context_sim_i, context_sim_j, context_dissim_i, context_dissim_j, ]
learning_rate=0.00001
loss_weights = {'target_weight':0.75, 'context_weight':0.25}
learning_rate=0.0001
loss_weights = {'target_weight':0., 'context_weight':1.0}
# ------------------------------------------------------
@@ -497,19 +585,19 @@ if __name__ == "__main__":
print(" test accuracy:\t\t{:.2f} %".format(
test_acc / test_batches * 100))
# return pattern
return pattern
# ------------------------------------------------------
#if __name__ == '__main__':
# parser = argparse.ArgumentParser()
# parser.add_argument("pattern", type=str, help="which pattern to use",
# default='direct',
# choices=['direct', 'multitask', 'multiview', 'pairwise_transformation', 'pairwise_similarity'])
# parser.add_argument("data", type=str, help="which context data to load",
# default='direct',
# choices=['direct', 'embedding', 'relative', 'similarity'])
# parser.add_argument("--num_epochs", type=int, help="number of epochs for SGD", default=500, required=False)
# parser.add_argument("--batchsize", type=int, help="batch size for SGD", default=50, required=False)
# args = parser.parse_args()
#
# pattern = main(args.pattern, args.data, args.num_epochs, args.batchsize)
if __name__ == '__main__':
parser = argparse.ArgumentParser()
parser.add_argument("pattern", type=str, help="which pattern to use",
default='direct',
choices=['direct', 'multitask', 'multiview', 'pairwise_transformation', 'pairwise_similarity'])
parser.add_argument("data", type=str, help="which context data to load",
default='direct',
choices=['direct', 'embedding', 'relative', 'similarity'])
parser.add_argument("--num_epochs", type=int, help="number of epochs for SGD", default=500, required=False)
parser.add_argument("--batchsize", type=int, help="batch size for SGD", default=50, required=False)
args = parser.parse_args()
pattern = main(args.pattern, args.data, args.num_epochs, args.batchsize)
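# Hedged example invocation (the script's filename is assumed here):
#   python synthetic_example.py pairwise_similarity similarity --num_epochs 500 --batchsize 50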