Skip to content
Snippets Groups Projects
Commit 97084b0e authored by Gencer Sumbul's avatar Gencer Sumbul
Browse files

fix

parent 1849a325
No related branches found
No related tags found
No related merge requests found
......@@ -115,7 +115,7 @@ def select_triplets_with_threshold(label_distances, positive_threshold, negative
# create_positive_mask(positive_indices, batch_size),
# create_negative_mask(negative_indices, batch_size))
def select_random_triplets(label_distances, num_elements, positive_threshold=0.8, negative_threshold=0.2):
def select_random_triplets(label_distances, num_elements, positive_threshold=0.7, negative_threshold=0.3):
batch_size = len(label_distances)
base_mask = np.ndarray((batch_size, batch_size, batch_size), dtype=bool)
......@@ -150,9 +150,9 @@ def select_random_triplets(label_distances, num_elements, positive_threshold=0.8
# negative_mask_temp[a, :, n] = label_distances[a, n] >= negative_threshold
if label_distances[a, n] > negative_threshold:
neg_indices.append(n)
import ipdb; ipdb.set_trace()
positive_mask[a, np.random.choice(pos_indices, size=num_elements, replace=False), :] = True
negative_mask[a, :, np.random.choice(neg_indices, size=num_elements, replace=False)] = True
# import ipdb; ipdb.set_trace()
positive_mask[a, np.random.choice(pos_indices, size=num_elements, replace=True), :] = True
negative_mask[a, :, np.random.choice(neg_indices, size=num_elements, replace=True)] = True
return tf.cast(tf.constant(np.logical_and(positive_mask, negative_mask)), tf.bool)
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment