Commit 278242b3 authored by Mahdyar Ravanbakhsh's avatar Mahdyar Ravanbakhsh

initial commit

parent f6c9d19f
# CHNR: An Unsupervised Cross-Modal Hashing Method Robust to Noisy Training Image-Text Correspondences in Remote Sensing
This repository contains the code of the paper "An Unsupervised Cross-Modal Hashing Method Robust to Noisy Training Image-Text Correspondences in Remote Sensing". This work has been done at the [Remote Sensing Image Analysis group](https://www.rsim.tu-berlin.de/menue/remote_sensing_image_analysis_group/) by [Georgii Mikriukov](https://www.rsim.tu-berlin.de/index.php?id=217673), [Mahdyar Ravanbakhsh](https://rsim.berlin/team/members/mahdyar-ravanbakhsh) and [Begüm Demir](https://begumdemir.com/).
> Paper placeholder
If you use the code from this repository in your research, please cite the following paper:
```
Bibtex placeholder
```
---
## Structure
![structure.png](images/main_diagram_noisy_3.png)
---
## Requirements
* Python 3.8
* PyTorch 1.8
* Transformers 4.4
To install the required libraries:
```
pip install -r requirements.txt
```
---
## Data
[Augmented and non-augmented image and caption features](https://tubcloud.tu-berlin.de/s/DykEC54PxRM93TP) for the UCMerced and RSICD datasets, encoded with ResNet18 and BERT respectively. Place them in the `./data/` folder.
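The features are stored as HDF5 files. A minimal sketch for inspecting one of them (the file name below is taken from `./configs/config.py`; the dataset keys `image_emb` and `caption_emb` are the ones read by the data handler):
```
import h5py

# Inspect one of the downloaded feature files placed in ./data/
with h5py.File('./data/image_emb_UCM_aug_center_crop_only.h5', 'r') as f:
    image_emb = f['image_emb'][:]  # ResNet18 image features, one row per image
    print(image_emb.shape)
```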
---
## Configs
`./configs/base_config.py`
Base configuration class (inherited by other configs):
* CUDA device
* seed
* data and dataset paths
`./configs/config.py`
DUCH-NR learning configuration:
* learning hyperparameters (see the usage sketch below)
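The resulting configuration is exposed as a module-level `cfg` object that the rest of the code imports directly; a minimal usage sketch:
```
# values are set from the command-line arguments parsed in ./configs/config.py
from configs.config import cfg

print(cfg.hash_dim)   # hash code length, set via --bit
print(cfg.max_epoch)  # number of training epochs, set via --epochs
```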
---
## Learning
```
main.py [-h] [--test] [--bit BIT] [--model MODEL] [--epochs EPOCHS]
        [--tag TAG] [--dataset DATASET] [--preset PRESET]
        [--alpha ALPHA] [--beta BETA] [--gamma GAMMA]
        [--contrastive-weights CONTRASTIVE_WEIGHTS CONTRASTIVE_WEIGHTS CONTRASTIVE_WEIGHTS]
        [--img-aug-emb IMG_AUG_EMB] [--txt-aug-emb TXT_AUG_EMB]
        [--noise-wrong-caption NOISE_WRONG_CAPTION]
        [--clean-captions CLEAN_CAPTIONS]
        [--noise-weights {normal,exp,dis,ones}]
        [--clean-epochs CLEAN_EPOCHS]

optional arguments:
  -h, --help            show this help message and exit
  --test                train or test
  --bit BIT             hash bit
  --model MODEL         model type
  --epochs EPOCHS       training epochs
  --tag TAG             model tag
  --dataset DATASET     ucm or rsicd
  --preset PRESET       data presets, see available in config.py
  --alpha ALPHA         alpha hyperparameter (La)
  --beta BETA           beta hyperparameter (Lq)
  --gamma GAMMA         gamma hyperparameter (Lbb)
  --contrastive-weights CONTRASTIVE_WEIGHTS CONTRASTIVE_WEIGHTS CONTRASTIVE_WEIGHTS
                        contrastive loss component weights: [inter, intra_img,
                        intra_txt]
  --img-aug-emb IMG_AUG_EMB
                        overrides augmented image embeddings file (u-curve)
  --txt-aug-emb TXT_AUG_EMB
                        overrides augmented text embeddings file (noise)
  --noise-wrong-caption NOISE_WRONG_CAPTION
                        probability of 'wrong caption' noise
  --clean-captions CLEAN_CAPTIONS
                        size of the clean dataset for meta-training captions
                        in dataset
  --noise-weights {normal,exp,dis,ones}
                        sample weight types: normal, exponential, discrete or
                        1
  --clean-epochs CLEAN_EPOCHS
                        number of meta-training epochs
```
Examples:
1. Train a model for 64-bit hash code retrieval on UCM data. The clean dataset for meta-training is 30% of the training data, 50% noise is injected into the training data, and discrete joint-feature weights are used:
```
main.py --dataset ucm --bit 64 --tag my_model --noise-wrong-caption 0.5 --clean-captions 0.3 --noise-weights dis
```
2. Train a model for 64-bit hash code retrieval on UCM data. The clean dataset for meta-training is 20% of the training data, 20% noise is injected into the training data, and normal joint-feature weights are used:
```
main.py --dataset ucm --bit 64 --tag my_model --noise-wrong-caption 0.2 --clean-captions 0.2 --noise-weights normal
```
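3. Evaluate a trained model by adding `--test` together with the tag used during training (a minimal sketch; flags that affect checkpoint naming, such as `--dataset` and `--bit`, should match the training run):
```
main.py --test --dataset ucm --bit 64 --tag my_model
```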
---
## License
The code is available under the terms of the MIT license:
```
Copyright (c) 2021 Georgii Mikriukov
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
```
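# Plotting script: draws mAP@20 versus noise injection rate for the DUCH-NR variants
# and the JDSH/DJSRH baselines, and writes the plotted values to a CSV table.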
import pickle
import os
import matplotlib.pyplot as plt
import pandas as pd
from matplotlib.ticker import AutoMinorLocator, MultipleLocator
def open_pkl(path):
    with open(path, 'rb') as f:
        return pickle.load(f)
dataset = 'ucm' # 'ucm', 'rsicd'
clean = 0.2 if dataset == 'rsicd' else 0.3
probs = [0.05, 0.1, 0.2, 0.3, 0.4, 0.5]
bit = 64
markers = ['s', 'd', 'v', 'o', 'H', 'p', '^']
colors = ['red', 'purple', 'blue', 'green', 'orange', 'brown', 'palevioletred']
models = ['duch', 'duch', 'duch', 'duch', 'duch', 'jdsh', 'djsrh']
tags = ['default', 'default', 'default', 'default', 'default_baseline', 'noisy', 'noisy']
weights = ['normal', 'exp', 'dis', 'ones', 'ones', 'normal', 'normal']
weight_names = ['DUCH-NR-NW', 'DUCH-NR-EW', 'DUCH-NR-DW', 'DUCH-PTC', 'DUCH', 'JDSH', 'DJSRH']
map_names = ['I \u2192 T', 'T \u2192 I'] # ['i2t', 't2i', 'i2i', 't2t'] # ['i2t', 't2i', 'i2i', 't2t', 'avg']
mns = ['i2t', 't2i']
experiment_names = ['k', 'k', 'k', 'hr', 'hr']
experiment_val = [5, 10, 20, 0, 5]
paths = {'duch': r'/home/george/Code/noisy_captions/checkpoints',
         'jdsh': r'/home/george/Code/jdsh_noisy/checkpoints',
         'djsrh': r'/home/george/Code/jdsh_noisy/checkpoints'}
data = {}
# read data
for weight, tag, model in zip(weights, tags, models):
    for prob in probs:
        token = model + weight + str(prob) + tag
        if model == 'duch':
            folder = '_'.join([dataset, str(bit), tag, str(prob), str(clean), weight])
        if model in ['jdsh', 'djsrh']:
            folder = '_'.join([model, str(bit), dataset, tag, weight, str(prob)]).upper()
        data[token] = open_pkl(os.path.join(paths[model], folder, 'maps_eval.pkl'))
final_table = [['probs'] + probs]
plot_num = ['(a)', '(b)']
ylims = {'ucm': [0.3, 1.0], 'rsicd': [0.3, 1.0]}
for i, map_name in enumerate(map_names):
    fig = plt.figure(figsize=(10, 10))
    ax = plt.subplot(1, 1, 1)
    final_table.append([map_name])
    for weight, weight_name, tag, model, color, marker in zip(weights, weight_names, tags, models, colors, markers):
        y = []
        for prob in probs:
            token = model + weight + str(prob) + tag
            y.append(data[token][2][i])
        ax.plot(probs, y, label=weight_name, color=color, marker=marker, markersize=12, mew=1, mec='black', linewidth=3)
        final_table.append([weight_name] + y)
    ax.legend(fontsize=24, framealpha=0.5)
    plt.xticks(rotation=45)
    ax.xaxis.set_major_locator(MultipleLocator(0.05))
    # ax.xaxis.set_minor_locator(AutoMinorLocator(5))
    ax.grid(axis='both', which='major', alpha=0.8, linestyle='-')
    ax.grid(axis='both', which='minor', alpha=0.4, linestyle=':')
    ax.set_ylim(ylims[dataset])
    # ax.set_xlim([0.05, 0.5])
    ax.tick_params(axis='both', labelsize=22)
    plt.xlabel('Noise injection rate', size=30)
    plt.ylabel('mAP@20', size=30)
    plt.title(map_name.upper(), fontsize=30, fontweight='bold')
    # plt.suptitle(dataset.upper(), size=20, weight='medium')
    plt.tight_layout()
    os.makedirs('plots', exist_ok=True)  # make sure the output folder exists
    print(os.path.join('plots', '{}_{}_{}_{}.png'.format('noise', dataset, clean, mns[i])))
    plt.savefig(os.path.join('plots', '{}_{}_{}_{}.png'.format('noise', dataset, clean, mns[i])))
    plt.savefig(os.path.join('plots', '{}_{}_{}_{}.pdf'.format('noise', dataset, clean, mns[i])))
df = pd.DataFrame(final_table)
print('plots/plots_table_{}.csv'.format(dataset))
df.to_csv('plots/plots_table_{}.csv'.format(dataset), index=False, header=False)
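# ./configs/base_config.py: base configuration class (seed, CUDA device, dataset paths),
# inherited by the other configs.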
class BaseConfig:
    def __init__(self, args):
        super(BaseConfig, self).__init__()
        self.seed = 42  # random seed
        self.cuda_device = 'cuda:0'  # CUDA device to use
        self.dataset = args.dataset.lower()
        self.dataset_json_file = "./data/augmented_{}.json".format(args.dataset.upper())
        if args.dataset == 'ucm':
            self.dataset_image_folder_path = "/home/george/Dropbox/UCM_Captions/images"
        if args.dataset == 'rsicd':
            self.dataset_image_folder_path = "/home/george/Dropbox/RSICD/images"
        self._print_config()

    def _print_config(self):
        print('Configuration:', self.__class__.__name__)
        for v in self.__dir__():
            if not v.startswith('_'):
                print('\t{0}: {1}'.format(v, getattr(self, v)))
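# ./configs/config.py: command-line arguments and the learning configuration.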
import argparse
from .base_config import BaseConfig
parser = argparse.ArgumentParser(description='')
parser.add_argument('--test', default=False, help='train or test', action='store_true')
parser.add_argument('--bit', default=64, help='hash bit', type=int)
parser.add_argument('--model', default='UNHD', help='model type', type=str)
parser.add_argument('--epochs', default=150, help='training epochs', type=int)
parser.add_argument('--tag', default='test', help='model tag', type=str)
parser.add_argument('--dataset', default='ucm', help='ucm or rsicd', type=str)
parser.add_argument('--preset', default='clean', help='data presets, see available in config.py', type=str)
parser.add_argument('--alpha', default=0, help='alpha hyperparameter (La)', type=float)
parser.add_argument('--beta', default=0.001, help='beta hyperparameter (Lq)', type=float)
parser.add_argument('--gamma', default=0, help='gamma hyperparameter (Lbb)', type=float)
parser.add_argument('--contrastive-weights', default=[1.0, 0.0, 0.0], type=float, nargs=3,
                    help='contrastive loss component weights: [inter, intra_img, intra_txt]')
parser.add_argument('--img-aug-emb', default=None, type=str, help='overrides augmented image embeddings file (u-curve)')
parser.add_argument('--txt-aug-emb', default=None, type=str, help='overrides augmented text embeddings file (noise)')
parser.add_argument('--noise-wrong-caption', default=.5, type=float, help="probability of 'wrong caption' noise")
parser.add_argument('--clean-captions', default=.2, type=float, help="size of the clean dataset for meta-training captions in dataset")
parser.add_argument('--noise-weights', default='normal', type=str, choices=['normal', 'exp', 'dis', 'ones'], help="sample weight types: normal, exponential, discrete or 1")
parser.add_argument('--clean-epochs', default=75, help='number of meta-training epochs', type=int)
args = parser.parse_args()
dataset = args.dataset
preset = args.preset
alpha = args.alpha
beta = args.beta
gamma = args.gamma
contrastive_weights = args.contrastive_weights
wrong_noise_caption_prob = args.noise_wrong_caption
clean_captions = args.clean_captions
noise_weights = args.noise_weights
clean_epochs = args.clean_epochs
class ConfigModel(BaseConfig):
    preset = preset.lower()

    if preset == 'clean':
        # default for texts
        image_emb_for_model = "./data/image_emb_{}_aug_center_crop_only.h5".format(dataset.upper())
        caption_emb_for_model = "./data/caption_emb_{}_aug.h5".format(dataset.upper())
        image_emb_aug_for_model = "./data/image_emb_{}_aug_aug_center.h5".format(dataset.upper())
        caption_emb_aug_for_model = "./data/caption_emb_{}_aug.h5".format(dataset.upper())
        dataset_json_for_model = "./data/augmented_{}.json".format(dataset.upper())
    else:
        raise Exception('Nonexistent preset: {}'.format(preset))

    if args.img_aug_emb is not None:
        image_emb_aug_for_model = args.img_aug_emb
    if args.txt_aug_emb is not None:
        caption_emb_aug_for_model = args.txt_aug_emb

    if dataset == 'ucm':
        label_dim = 21
        # dataset settings
        dataset_file = "../data/dataset_UCM_aug_captions_images.h5"  # resulting dataset file
        dataset_train_split = 0.5  # part of all data that will be used for training
        # (1 - dataset_train_split) - evaluation data
        dataset_query_split = 0.2  # part of evaluation data that will be used for queries
        # (1 - dataset_train_split) * (1 - dataset_query_split) - retrieval data

    if dataset == 'rsicd':
        label_dim = 31
        # dataset settings
        dataset_file = "../data/dataset_RSICD_aug_captions_images.h5"  # resulting dataset file
        dataset_train_split = 0.5  # part of all data that will be used for training
        # (1 - dataset_train_split) - evaluation data
        dataset_query_split = 0.2  # part of evaluation data that will be used for queries
        # (1 - dataset_train_split) * (1 - dataset_query_split) - retrieval data

    build_plots = False

    wrong_noise_caption_prob = wrong_noise_caption_prob
    clean_captions = clean_captions
    noise_weights = noise_weights

    model_type = 'UNHD'
    batch_size = 256
    image_dim = 512
    text_dim = 768
    hidden_dim = 1024 * 4
    hash_dim = 128
    noise_dim = image_dim + text_dim

    lr = 0.0001
    clean_epochs = clean_epochs
    max_epoch = 100
    valid = True  # validation
    valid_freq = 150  # validation frequency (epochs)

    alpha = alpha  # adv loss
    beta = beta  # quant loss
    gamma = gamma  # bb loss
    contrastive_weights = contrastive_weights  # [inter, intra_img, intra_txt]

    retrieval_map_k = 5

    tag = 'test'

    def __init__(self, args):
        super(ConfigModel, self).__init__(args)
        self.test = args.test
        self.hash_dim = args.bit
        self.model_type = args.model
        self.max_epoch = args.epochs
        self.tag = args.tag
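# Module-level config instance: importing this module parses the command-line
# arguments once and exposes the result as `cfg`.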
cfg = ConfigModel(args)
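# Data handling: loads precomputed image/caption embeddings and labels, then splits
# them into training, query and database (retrieval) subsets.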
import random
from torch.utils.data import DataLoader
from configs.config import cfg
from utils import read_hdf5, read_json, get_labels
import numpy as np
class DataHandler:

    def __init__(self):
        super().__init__()

    def load_train_query_db_data(self):
        """
        Load and split (train, query, db)

        :return: tuples of (images, captions, labels), each element is an array
        """
        random.seed(cfg.seed)
        images, captions, labels = load_dataset()
        train, query, db = self.split_data(images, captions, labels)
        return train, query, db

    @staticmethod
    def split_data(images, captions, labels):
        """
        Split dataset to get training, query and db subsets

        :param: images: image embeddings array
        :param: captions: caption embeddings array
        :param: labels: labels array

        :return: tuples of (images, captions, labels), each element is an array
        """
        idx_tr, idx_q, idx_db = get_split_idxs(len(images))
        idx_tr_cap, idx_q_cap, idx_db_cap = get_caption_idxs(idx_tr, idx_q, idx_db)
        train = images[idx_tr], captions[idx_tr_cap], labels[idx_tr], (idx_tr, idx_tr_cap)
        query = images[idx_q], captions[idx_q_cap], labels[idx_q], (idx_q, idx_q_cap)
        db = images[idx_db], captions[idx_db_cap], labels[idx_db], (idx_db, idx_db_cap)
        return train, query, db
class DataHandlerAugmentedTxt:

    def __init__(self):
        super().__init__()

    def load_train_query_db_data(self):
        """
        Load and split (train, query, db)

        :return: tuples of (images, captions, labels), each element is an array
        """
        random.seed(cfg.seed)
        images, captions, labels, captions_aug = load_dataset(txt_aug=True)
        train, query, db = self.split_data(images, captions, labels, captions_aug)
        return train, query, db

    @staticmethod
    def split_data(images, captions, labels, captions_aug):
        """
        Split dataset to get training, query and db subsets

        :param: images: image embeddings array
        :param: captions: caption embeddings array
        :param: labels: labels array
        :param: captions_aug: augmented caption embeddings

        :return: tuples of (images, captions, labels), each element is an array
        """
        idx_tr, idx_q, idx_db = get_split_idxs(len(images))
        idx_tr_cap, idx_q_cap, idx_db_cap = get_caption_idxs(idx_tr, idx_q, idx_db)
        train = images[idx_tr], captions[idx_tr_cap], labels[idx_tr], (idx_tr, idx_tr_cap), captions_aug[idx_tr_cap]
        query = images[idx_q], captions[idx_q_cap], labels[idx_q], (idx_q, idx_q_cap), captions_aug[idx_q_cap]
        db = images[idx_db], captions[idx_db_cap], labels[idx_db], (idx_db, idx_db_cap), captions_aug[idx_db_cap]
        return train, query, db
class DataHandlerAugmentedTxtImg:

    def __init__(self):
        super().__init__()

    def load_train_query_db_data(self):
        """
        Load and split (train, query, db)

        :return: tuples of (images, captions, labels), each element is an array
        """
        random.seed(cfg.seed)
        images, captions, labels, captions_aug, images_aug = load_dataset(img_aug=True, txt_aug=True)
        train, query, db = self.split_data(images, captions, labels, captions_aug, images_aug)
        return train, query, db

    @staticmethod
    def split_data(images, captions, labels, captions_aug, images_aug):
        """
        Split dataset to get training, query and db subsets

        :param: images: image embeddings array
        :param: captions: caption embeddings array
        :param: labels: labels array
        :param: captions_aug: augmented caption embeddings
        :param: images_aug: augmented image embeddings

        :return: tuples of (images, captions, labels), each element is an array
        """
        idx_tr, idx_q, idx_db = get_split_idxs(len(images))
        idx_tr_cap, idx_q_cap, idx_db_cap = get_caption_idxs(idx_tr, idx_q, idx_db)
        train = images[idx_tr], captions[idx_tr_cap], labels[idx_tr], (idx_tr, idx_tr_cap), \
            captions_aug[idx_tr_cap], images_aug[idx_tr]
        query = images[idx_q], captions[idx_q_cap], labels[idx_q], (idx_q, idx_q_cap), \
            captions_aug[idx_q_cap], images_aug[idx_q]
        db = images[idx_db], captions[idx_db_cap], labels[idx_db], (idx_db, idx_db_cap), \
            captions_aug[idx_db_cap], images_aug[idx_db]
        return train, query, db
def load_dataset(img_aug=False, txt_aug=False):
    """
    Load dataset

    :return: image and caption embeddings, labels
    """
    images = read_hdf5(cfg.image_emb_for_model, 'image_emb', normalize=True)
    captions = read_hdf5(cfg.caption_emb_for_model, 'caption_emb', normalize=True)
    labels = np.array(get_labels(read_json(cfg.dataset_json_for_model), suppress_console_info=True))
    if img_aug and txt_aug:
        captions_aug = read_hdf5(cfg.caption_emb_aug_for_model, 'caption_emb', normalize=True)
        images_aug = read_hdf5(cfg.image_emb_aug_for_model, 'image_emb', normalize=True)
        return images, captions, labels, captions_aug, images_aug
    elif img_aug:
        images_aug = read_hdf5(cfg.image_emb_aug_for_model, 'image_emb', normalize=True)
        return images, captions, labels, images_aug
    elif txt_aug:
        captions_aug = read_hdf5(cfg.caption_emb_aug_for_model, 'caption_emb', normalize=True)
        return images, captions, labels, captions_aug
    else:
        return images, captions, labels
def get_split_idxs(arr_len):
    """
    Get indexes for training, query and db subsets

    :param: arr_len: array length

    :return: indexes for training, query and db subsets
    """
    idx_all = list(range(arr_len))
    idx_train, idx_eval = split_indexes(idx_all, cfg.dataset_train_split)
    idx_query, idx_db = split_indexes(idx_eval, cfg.dataset_query_split)
    return idx_train, idx_query, idx_db
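# Example: with dataset_train_split = 0.5 and dataset_query_split = 0.2 (the values
# set in configs/config.py), the data splits into 50% training, 10% query and 40% db.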
def split_indexes(idx_all, split):
    """
    Splits a list in two parts.

    :param idx_all: array to split
    :param split: portion to split

    :return: split lists
    """
    idx_length = len(idx_all)
    selection_length = int(idx_length * split)
    idx_selection = sorted(random.sample(idx_all, selection_length))
    idx_rest = sorted(list(set(idx_all).difference(set(idx_selection))))
    return idx_selection, idx_rest
def get_caption_idxs(idx_train, idx_query, idx_db):
    """
    Get caption indexes.

    :param: idx_train: train image (and label) indexes
    :param: idx_query: query image (and label) indexes
    :param: idx_db: db image (and label) indexes

    :return: caption indexes for corresponding index sets
    """
    idx_train_cap = get_caption_idxs_from_img_idxs(idx_train)
    idx_query_cap = get_caption_idxs_from_img_idxs(idx_query)
    idx_db_cap = get_caption_idxs_from_img_idxs(idx_db)
    return idx_train_cap, idx_query_cap, idx_db_cap
def get_caption_idxs_from_img_idxs(img_idxs):
    """
    Get caption indexes. There are 5 captions for each image (and label).
    Say, img indexes - [0, 10, 100]
    Then, caption indexes - [0, 1, 2, 3, 4, 50, 51, 52, 53, 54, 500, 501, 502, 503, 504]

    :param: img_idxs: image (and label) indexes

    :return: caption indexes
    """
    caption_idxs = []
    for idx in img_idxs:
        for i in range(5):  # each image has 5 captions
            caption_idxs.append(idx * 5 + i)
    return caption_idxs
def get_dataloaders(data_handler, ds_train, ds_train_clean, ds_query, ds_db):