Commit 4511f5a4 authored by Gencer Sumbul's avatar Gencer Sumbul
Browse files

final changes

parent 11a4a36a
......@@ -7,7 +7,5 @@
"embed_model": "densenet169",
"decoder_lr": 0.001,
"batch_size": 384,
"print_freq": 1,
"data_mean": [101.77617734389132, 104.53789821808728, 94.61747324265178],
"data_std": [38.472136790400675, 35.23932449267246, 34.110112414874365]
"print_freq": 1
}
\ No newline at end of file
......@@ -5,7 +5,7 @@ import json
from nltk.translate.bleu_score import corpus_bleu, SmoothingFunction
from helpers import (
ConfigFactory, Vocabulary, StateFactory, Tester, SummarizationProvider
ConfigFactory, Vocabulary, StateFactory, Tester
)
......@@ -40,14 +40,7 @@ def evaluate(config_file):
results = []
summarization_provider = SummarizationProvider(
config, tester.dataset, vocab, verbose=False
)
for img_id, prediction, tokenized_ref in zip(img_ids, predictions, references):
summarized = summarization_provider.image_index[img_id].to_dense()
summarized = ids_to_sentence(vocab, summarized.argmax(1).tolist())
img_bleus = [
corpus_bleu([tokenized_ref], [prediction],
weights=weight, smoothing_function=smoothing_function)
......@@ -56,7 +49,6 @@ def evaluate(config_file):
results.append({
'img_id': img_id,
'summarized': summarized,
'prediction': ids_to_sentence(vocab, prediction),
'bleu_scores': img_bleus,
'references': references_index[img_id]
......
......@@ -10,29 +10,16 @@ DEFAULTS = {
'vocabulary_size': 50000,
'emb_dim': 512,
'decoder_dim': 512,
'dropout': 0.5,
'device': torch.device("cuda" if torch.cuda.is_available() else "cpu"),
'input_img_size': 224,
'start_epoch': 0,
'embed_model': 'resnet152',
'epochs': 150,
'batch_size': 512,
'decoder_lr': 8e-5,
'grad_clip': 5.,
'alpha_c': 1.,
'data_mean': [101.77617734389132, 104.53789821808728, 94.61747324265178],
'data_std': [38.472136790400675, 35.23932449267246, 34.110112414874365],
'print_freq': 100,
'checkpoint': 'auto',
'checkpoints_path': 'checkpoints',
'results_file': 'results.json',
'summarization_model_path': '../models/summarization.tar',
'summarization_hidden_dim': 512,
'summarization_emb_dim': 256,
'summarization_max_enc_steps': 55,
'summarization_max_dec_steps': 7,
'summarization_rand_unif_init_mag': 0.02,
'summarization_trunc_norm_init_std': 1e-4
}
Config = namedtuple('Config', list(DEFAULTS.keys()))
......
......@@ -47,7 +47,8 @@ class StateFactory():
embed_dim=config.emb_dim,
decoder_dim=config.decoder_dim,
vocab_size=vocab_size,
dropout=config.dropout )
dropout=0.5
)
decoder_optimizer = torch.optim.Adam(
params=filter(lambda param: param.requires_grad, decoder.parameters()),
......
......@@ -101,10 +101,12 @@ class Tester(object):
return bleu4
def __build_test_loader(self):
data_mean = [0.485, 0.456, 0.406]
data_std = [0.229, 0.224, 0.225]
transform = transforms.Compose([
transforms.Resize(self.config.input_img_size,interpolation=3),
transforms.Resize(224,interpolation=3),
transforms.ToTensor(),
transforms.Normalize(self.config.data_mean, self.config.data_std)
transforms.Normalize(data_mean, data_std)
])
dataset = CaptioningDataset(
......
......@@ -54,10 +54,9 @@ class Trainer():
self.state.encoder_optimizer.zero_grad()
loss.backward()
if self.config.grad_clip is not None:
clip_gradient(self.state.decoder_optimizer, self.config.grad_clip)
if self.state.encoder_optimizer is not None:
clip_gradient(self.state.encoder_optimizer, self.config.grad_clip)
clip_gradient(self.state.decoder_optimizer, 5.0)
if self.state.encoder_optimizer is not None:
clip_gradient(self.state.encoder_optimizer, 5.0)
self.state.decoder_optimizer.step()
if self.state.encoder_optimizer is not None:
......@@ -73,11 +72,13 @@ class Trainer():
del final_probs
def __build_train_loader(self):
data_mean = [0.485, 0.456, 0.406]
data_std = [0.229, 0.224, 0.225]
transform = transforms.Compose([
transforms.Resize(self.config.input_img_size,interpolation=3),
transforms.Resize(224,interpolation=3),
transforms.RandomHorizontalFlip(),
transforms.ToTensor(),
transforms.Normalize(self.config.data_mean, self.config.data_std)
transforms.Normalize(data_mean, data_std)
])
dataset = CaptioningDataset(
......
......@@ -82,11 +82,13 @@ class Validator():
return bleu4
def __build_val_loader(self):
data_mean = [0.485, 0.456, 0.406]
data_std = [0.229, 0.224, 0.225]
transform = transforms.Compose([
transforms.Resize(self.config.input_img_size,interpolation=3),
transforms.Resize(224,interpolation=3),
transforms.RandomHorizontalFlip(),
transforms.ToTensor(),
transforms.Normalize(self.config.data_mean, self.config.data_std)
transforms.Normalize(data_mean, data_std)
])
dataset = CaptioningDataset(
......
......@@ -3,13 +3,16 @@ import torch.nn as nn
from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence
from torch.nn.functional import relu, softmax
HIDDEN_DIM = 512
EMBED_DIM = 256
def init_lstm_wt(lstm, config):
for name, _ in lstm.named_parameters():
if 'weight' in name:
weight = getattr(lstm, name)
weight.data.uniform_(
-config.summarization_rand_unif_init_mag,
config.summarization_rand_unif_init_mag
-0.02,
0.02
)
elif 'bias' in name:
bias = getattr(lstm, name)
......@@ -19,12 +22,12 @@ def init_lstm_wt(lstm, config):
bias.data[start:end].fill_(1.)
def init_linear_wt(linear, config):
linear.weight.data.normal_(std=config.summarization_trunc_norm_init_std)
linear.weight.data.normal_(std=1e-4)
if linear.bias is not None:
linear.bias.data.normal_(std=config.summarization_trunc_norm_init_std)
linear.bias.data.normal_(std=1e-4)
def init_wt_normal(wt, config):
wt.data.normal_(std=config.summarization_trunc_norm_init_std)
wt.data.normal_(std=1e-4)
class Encoder(nn.Module):
......@@ -34,19 +37,19 @@ class Encoder(nn.Module):
self.config = config
self.lstm = nn.LSTM(
config.summarization_emb_dim, config.summarization_hidden_dim,
EMBED_DIM, HIDDEN_DIM,
num_layers=1, batch_first=True, bidirectional=True
)
init_lstm_wt(self.lstm, config)
self.reduce_h = nn.Linear(
config.summarization_hidden_dim * 2,
config.summarization_hidden_dim
HIDDEN_DIM * 2,
HIDDEN_DIM
)
init_linear_wt(self.reduce_h, config)
self.reduce_c = nn.Linear(
config.summarization_hidden_dim * 2,
config.summarization_hidden_dim
HIDDEN_DIM * 2,
HIDDEN_DIM
)
init_linear_wt(self.reduce_c, config)
......@@ -70,15 +73,15 @@ class EncoderAttention(nn.Module):
self.config = config
self.W_h = nn.Linear(
config.summarization_hidden_dim * 2,
config.summarization_hidden_dim * 2, bias=False
HIDDEN_DIM * 2,
HIDDEN_DIM * 2, bias=False
)
self.W_s = nn.Linear(
config.summarization_hidden_dim * 2,
config.summarization_hidden_dim * 2
HIDDEN_DIM * 2,
HIDDEN_DIM * 2
)
self.v = nn.Linear(
config.summarization_hidden_dim * 2,
HIDDEN_DIM * 2,
1, bias=False
)
......@@ -117,14 +120,14 @@ class DecoderAttention(nn.Module):
self.config = config
self.W_prev = nn.Linear(
config.summarization_hidden_dim,
config.summarization_hidden_dim, bias=False
HIDDEN_DIM,
HIDDEN_DIM, bias=False
)
self.W_s = nn.Linear(
config.summarization_hidden_dim,
config.summarization_hidden_dim
HIDDEN_DIM,
HIDDEN_DIM
)
self.v = nn.Linear(config.summarization_hidden_dim, 1, bias=False)
self.v = nn.Linear(HIDDEN_DIM, 1, bias=False)
def forward(self, s_t, prev_s):
if prev_s is None:
......@@ -151,26 +154,26 @@ class Decoder(nn.Module):
self.enc_attention = EncoderAttention(config)
self.dec_attention = DecoderAttention(config)
self.x_context = nn.Linear(
config.summarization_hidden_dim * 2 + config.summarization_emb_dim,
config.summarization_emb_dim
HIDDEN_DIM * 2 + EMBED_DIM,
EMBED_DIM
)
self.lstm = nn.LSTMCell(
config.summarization_emb_dim,
config.summarization_hidden_dim
EMBED_DIM,
HIDDEN_DIM
)
init_lstm_wt(self.lstm, config)
self.p_gen_linear = nn.Linear(
config.summarization_hidden_dim * 5 + config.summarization_emb_dim,
HIDDEN_DIM * 5 + EMBED_DIM,
1
)
self.V = nn.Linear(
config.summarization_hidden_dim * 4,
config.summarization_hidden_dim
HIDDEN_DIM * 4,
HIDDEN_DIM
)
self.V1 = nn.Linear(config.summarization_hidden_dim, vocab_size)
self.V1 = nn.Linear(HIDDEN_DIM, vocab_size)
init_linear_wt(self.V1, config)
def forward(self, x_t, s_t, enc_out, enc_padding_mask, ct_e, extra_zeros, enc_batch_extend_vocab, sum_temporal_srcs, prev_s):
......@@ -205,7 +208,7 @@ class Model(nn.Module):
super().__init__()
self.encoder = Encoder(config)
self.decoder = Decoder(config, vocab_size)
self.embeds = nn.Embedding(vocab_size, config.summarization_emb_dim)
self.embeds = nn.Embedding(vocab_size, EMBED_DIM)
init_wt_normal(self.embeds.weight, config)
self.encoder = self.encoder.to(config.device)
......
......@@ -4,9 +4,10 @@ import torch
from torch.distributions import Categorical
from torch.nn.utils.rnn import pack_padded_sequence
from torch.nn import Sequential
from summarization.models import Model
from summarization.models import Model, HIDDEN_DIM
from helpers import AverageMeter
def get_enc_data(batch, config):
batch_size = len(batch.enc_lens)
enc_batch = torch.from_numpy(batch.enc_batch).long()
......@@ -14,7 +15,7 @@ def get_enc_data(batch, config):
enc_lens = batch.enc_lens
ct_e = torch.zeros(batch_size, 2 * config.summarization_hidden_dim)
ct_e = torch.zeros(batch_size, 2 * HIDDEN_DIM)
enc_batch = enc_batch.to(config.device)
enc_padding_mask = enc_padding_mask.to(config.device)
......@@ -63,7 +64,8 @@ class Evaluate():
mask = torch.Tensor(len(enc_out)).fill_(1).long().to(self.config.device)
decoder_macs = 0.
for _ in range(self.config.summarization_max_dec_steps):
summarization_max_dec_steps = 7
for _ in range(summarization_max_dec_steps):
x_t = self.model.embeds(x_t)
final_dist, s_t, ct_e, sum_temporal_srcs, prev_s = self.model.decoder(
......@@ -95,10 +97,10 @@ class Evaluate():
class Example(object):
def __init__(self, config, article, vocab):
self.vocab = vocab
summarization_max_enc_steps = 55
article_words = article.split()
if len(article_words) > config.summarization_max_enc_steps:
article_words = article_words[:config.summarization_max_enc_steps]
if len(article_words) > summarization_max_enc_steps:
article_words = article_words[:summarization_max_enc_steps]
self.enc_len = len(article_words)
......
......@@ -2,7 +2,7 @@ import os
import json
import argparse
from helpers import (
ConfigFactory, Vocabulary, StateFactory, Trainer, BeamValidator, Tester,
ConfigFactory, Vocabulary, StateFactory, Trainer, BeamValidator,
Checkpoint, adjust_learning_rate
)
......
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment