From 08ef171be3516b7ff6cccbf0718cd0404de34b04 Mon Sep 17 00:00:00 2001
From: "jose.luis.holgado" <jose.l.holgadoalvarez@campus.tu-berlin.de>
Date: Mon, 3 Aug 2020 14:30:15 +0200
Subject: [PATCH] patch

---
 data/alignedrs_dataset.py | 29 +++++++++++++----------------
 1 file changed, 13 insertions(+), 16 deletions(-)

diff --git a/data/alignedrs_dataset.py b/data/alignedrs_dataset.py
index 0262fca..c356d86 100644
--- a/data/alignedrs_dataset.py
+++ b/data/alignedrs_dataset.py
@@ -61,35 +61,32 @@ class AlignedRSDataset(BaseDataset):
         Parent_dir = self.dir_AB
         AB_path = self.AB_paths[index]
         AB_index = AB_path.split("_")[1]
-        # print(AB_path, Parent_dir)
+
+
+        if self.opt.direction == "BtoA":
+            tensor_a = "B"
+            tensor_b = "A"
+
+        else:
+            tensor_a = "A"
+            tensor_b = "B"
+
         imgs_path = os.path.join(Parent_dir, AB_path)
-        # print("flag0000000000000000000000000000000")
-        # AB_id = id_extractor(AB_path)
+
         dataset = GetDataset(imgs_path, self.opt.load_size)
         # split AB image into A and B
-        A = dataset.load_image('A')
+        A = dataset.load_image(tensor_a)
         if self.model_mode:
             B_list = []
             for i in range(4):
                 B_list.append(cv.GaussianBlur(A[i], (11, 11), 0))
             B = np.array(B_list)
         else:
-            B = dataset.load_image('B')
-        # coord = dataset.load_image('COORD')
+            B = dataset.load_image(tensor_b)
 
-        # A_resize = cv.resize(A, dsize=(self.opt.load_size, self.opt.load_size), interpolation=cv.INTER_CUBIC)
-        # B_resize = cv.resize(B, dsize=(self.opt.load_size, self.opt.load_size), interpolation=cv.INTER_CUBIC)
 
         Y_Mask = dataset.load_image('MASK')
 
-        # g = extraction.reconstruct_from_patches_2d(y[:,:,0], (262, 262))
-        # f, axarr = plt.subplots(1, 3)
-        # ground_truth
-
-        # axarr[0].imshow(A.transpose([1,2,0]))
-        # axarr[1].imshow(B.transpose([1,2,0]))
-        # axarr[2].imshow(Y_Mask)
-        # plt.show()
 
         return {'A': A.astype(np.float32),
                 'B': B.astype(np.float32),
-- 
GitLab