Commit f921ce8d authored by Jian Kang

merge pytorch tmp

parent 9b6a5c60
@@ -36,26 +36,29 @@ If you are interested in BigEarthNet with the original CLC Level-3 class nomencl
We provide code and model weights for the following deep learning models that have been pre-trained on BigEarthNet with the new class nomenclature (BigEarthNet-19) for scene classification:
| Model Names | Pre-Trained TensorFlow Models |
| ------------- |-------------------------------------------------------------------------------------------------------------------------------------------------------------|
| K-Branch CNN | [http://bigearth.net/static/pretrained-models/BigEarthNet-19_labels/K-BranchCNN.zip](http://bigearth.net/static/pretrained-models/BigEarthNet-19_labels/K-BranchCNN.zip)|
| VGG16 | [http://bigearth.net/static/pretrained-models/BigEarthNet-19_labels/VGG16.zip](http://bigearth.net/static/pretrained-models/BigEarthNet-19_labels/VGG16.zip) |
| VGG19 | [http://bigearth.net/static/pretrained-models/BigEarthNet-19_labels/VGG19.zip](http://bigearth.net/static/pretrained-models/BigEarthNet-19_labels/VGG19.zip) |
| ResNet50 | [http://bigearth.net/static/pretrained-models/BigEarthNet-19_labels/ResNet50.zip](http://bigearth.net/static/pretrained-models/BigEarthNet-19_labels/ResNet50.zip) |
| ResNet101 | [http://bigearth.net/static/pretrained-models/BigEarthNet-19_labels/ResNet101.zip](http://bigearth.net/static/pretrained-models/BigEarthNet-19_labels/ResNet101.zip) |
| ResNet152 | [http://bigearth.net/static/pretrained-models/BigEarthNet-19_labels/ResNet152.zip](http://bigearth.net/static/pretrained-models/BigEarthNet-19_labels/ResNet152.zip) |
| Model Names | Pre-Trained TensorFlow Models | Pre-Trained PyTorch Models |
| ------------ | ------------------------------------------------------------ | ------------------------------------------------------------ |
| K-Branch CNN | [K-BranchCNN.zip](http://bigearth.net/static/pretrained-models/BigEarthNet-19_labels/K-BranchCNN.zip) | Coming soon |
| VGG16 | [VGG16.zip](http://bigearth.net/static/pretrained-models/BigEarthNet-19_labels/VGG16.zip) | Coming soon |
| VGG19 | [VGG19.zip](http://bigearth.net/static/pretrained-models/BigEarthNet-19_labels/VGG19.zip) | Coming soon |
| ResNet50 | [ResNet50.zip](http://bigearth.net/static/pretrained-models/BigEarthNet-19_labels/ResNet50.zip) | [ResNet50.pth.tar](http://bigearth.net/static/pretrained-models-pytorch/BigEarthNet-19_labels/ResNet50.pth.tar) |
| ResNet101 | [ResNet101.zip](http://bigearth.net/static/pretrained-models/BigEarthNet-19_labels/ResNet101.zip) | Coming soon |
| ResNet152 | [ResNet152.zip](http://bigearth.net/static/pretrained-models/BigEarthNet-19_labels/ResNet152.zip) | Coming soon |
The TensorFlow code for these models can be found [here](https://gitlab.tu-berlin.de/rsim/bigearthnet-models-tf).
The PyTorch code for these models can be found [here](https://gitlab.tubit.tu-berlin.de/rsim/bigearthnet-models-pytorch).
# Generation of Training/Test/Validation Splits
After downloading the raw images from https://www.bigearth.net, they need to be prepared for your ML application. We provide the script `prep_splits_BigEarthNet-19.py` for this purpose. It generates consumable data files (i.e., TFRecord) for training, validation and test splits which are suitable to use with TensorFlow. Suggested splits can be found with corresponding csv files under `splits` folder. The following command line arguments for `prep_splits_BigEarthNet-19.py` can be specified:
After downloading the raw images from https://www.bigearth.net, they need to be prepared for your ML application. We provide the script `prep_splits_BigEarthNet-19.py` for this purpose. It generates consumable data files (i.e., TFRecord files for TensorFlow or an LMDB file for PyTorch) for the training, validation and test splits. Suggested splits are provided as CSV files under the `splits` folder. The following command line arguments of `prep_splits_BigEarthNet-19.py` can be specified:
* `-r` or `--root_folder`: The root folder containing the raw images you have previously downloaded.
* `-o` or `--out_folder`: The output folder where the resulting files will be created.
* `-n` or `--splits`: A list of CSV files, each of which contains the patch names of the corresponding split.
* `-l` or `--library`: The ML library for which data files will be prepared: TensorFlow or PyTorch.
* `--update_json`: A flag indicating that the script should also update the original JSON metadata files of BigEarthNet with the BigEarthNet-19 labels.
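A typical invocation might look like the following; the folder paths and split file names here are illustrative placeholders, not prescribed by the script:

```shell
# Prepare TFRecord files for suggested train/validation/test splits.
# Adjust the paths and CSV file names to your local setup.
python prep_splits_BigEarthNet-19.py \
    -r /path/to/BigEarthNet-v1.0 \
    -o /path/to/output \
    -n splits/train.csv splits/val.csv splits/test.csv \
    -l tensorflow
```

Passing `-l pytorch` instead produces an LMDB file for use with PyTorch.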
To run the script, either the GDAL or the rasterio package should be installed. The TensorFlow package should also be installed. The script is tested with Python 2.7, TensorFlow 1.3 and Ubuntu 16.04.
To run the script, either the GDAL or the rasterio package should be installed. Depending on the selected library, the TensorFlow or PyTorch package should also be installed. The script is tested with Python 2.7, TensorFlow 1.3, PyTorch 1.2 and Ubuntu 16.04.
**Note**: BigEarthNet patches with high density snow, cloud and cloud shadow are not included in the training, test and validation sets constructed by the provided scripts (see the list of patches with seasonal snow [here](http://bigearth.net/static/documents/patches_with_seasonal_snow.csv) and that of cloud and cloud shadow [here](http://bigearth.net/static/documents/patches_with_cloud_and_shadow.csv)).
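The exclusion described in this note can be sketched in a few lines. This is my illustration, not part of the provided scripts; it assumes the CSV layout of the linked lists (one patch name per row):

```python
import csv

def load_patch_names(csv_path):
    """Read one patch name per row from a CSV file."""
    with open(csv_path) as f:
        return {row[0].strip() for row in csv.reader(f) if row}

def exclude_patches(all_patches, *exclusion_csvs):
    """Drop patches listed in any exclusion CSV
    (e.g. the seasonal snow or cloud/cloud shadow lists)."""
    excluded = set()
    for path in exclusion_csvs:
        excluded |= load_patch_names(path)
    return [p for p in all_patches if p not in excluded]
```

The suggested split CSVs already have these patches removed, so this step is only needed when building custom splits.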
@@ -65,9 +68,19 @@ Authors
**Gencer Sümbül**
http://www.user.tu-berlin.de/gencersumbul/
**Jian Kang**
https://www.rsim.tu-berlin.de/menue/team/dring_jian_kang/
**Tristan Kreuziger**
https://www.rsim.tu-berlin.de/menue/team/tristan_kreuziger/
Maintained by
-------
**Gencer Sümbül** for TensorFlow models
**Jian Kang** for PyTorch models
# License
The BigEarthNet Archive is licensed under the **Community Data License Agreement – Permissive, Version 1.0** ([Text](https://cdla.io/permissive-1-0/)).
prep_splits_BigEarthNet-19.py
@@ -7,8 +7,8 @@
#
# prep_splits_BigEarthNet-19.py --help can be used to learn how to use this script.
#
# Author: Gencer Sumbul, http://www.user.tu-berlin.de/gencersumbul/
# Email: gencer.suembuel@tu-berlin.de
# Author: Gencer Sumbul, http://www.user.tu-berlin.de/gencersumbul/, Jian Kang, https://www.rsim.tu-berlin.de/menue/team/dring_jian_kang/
# Email: gencer.suembuel@tu-berlin.de, jian.kang@tu-berlin.de
# Date: 16 Jan 2020
# Version: 1.0.1
# Usage: prep_splits_BigEarthNet-19.py [-h] [-r ROOT_FOLDER] [-o OUT_FOLDER] [--update_json]
@@ -19,10 +19,9 @@ import argparse
import os
import csv
import json
from tensorflow_utils import prep_tf_record_files
from pytorch_utils import prep_lmdb_files
# Spectral band names to read related GeoTIFF files
band_names = ['B01', 'B02', 'B03', 'B04', 'B05',
'B06', 'B07', 'B08', 'B8A', 'B09', 'B11', 'B12']
GDAL_EXISTED = False
RASTERIO_EXISTED = False
UPDATE_JSON = False
@@ -30,114 +29,18 @@ UPDATE_JSON = False
with open('label_indices.json', 'rb') as f:
label_indices = json.load(f)
label_conversion = label_indices['label_conversion']
BigEarthNet_19_label_idx = {v: k for k, v in label_indices['BigEarthNet-19_labels'].iteritems()}
def prep_example(bands, labels, labels_multi_hot, patch_name):
return tf.train.Example(
features=tf.train.Features(
feature={
'B01': tf.train.Feature(
int64_list=tf.train.Int64List(value=np.ravel(bands['B01']))),
'B02': tf.train.Feature(
int64_list=tf.train.Int64List(value=np.ravel(bands['B02']))),
'B03': tf.train.Feature(
int64_list=tf.train.Int64List(value=np.ravel(bands['B03']))),
'B04': tf.train.Feature(
int64_list=tf.train.Int64List(value=np.ravel(bands['B04']))),
'B05': tf.train.Feature(
int64_list=tf.train.Int64List(value=np.ravel(bands['B05']))),
'B06': tf.train.Feature(
int64_list=tf.train.Int64List(value=np.ravel(bands['B06']))),
'B07': tf.train.Feature(
int64_list=tf.train.Int64List(value=np.ravel(bands['B07']))),
'B08': tf.train.Feature(
int64_list=tf.train.Int64List(value=np.ravel(bands['B08']))),
'B8A': tf.train.Feature(
int64_list=tf.train.Int64List(value=np.ravel(bands['B8A']))),
'B09': tf.train.Feature(
int64_list=tf.train.Int64List(value=np.ravel(bands['B09']))),
'B11': tf.train.Feature(
int64_list=tf.train.Int64List(value=np.ravel(bands['B11']))),
'B12': tf.train.Feature(
int64_list=tf.train.Int64List(value=np.ravel(bands['B12']))),
'BigEarthNet-19_labels': tf.train.Feature(
bytes_list=tf.train.BytesList(
value=[i.encode('utf-8') for i in labels])),
'BigEarthNet-19_labels_multi_hot': tf.train.Feature(
int64_list=tf.train.Int64List(value=labels_multi_hot)),
'patch_name': tf.train.Feature(
bytes_list=tf.train.BytesList(value=[patch_name.encode('utf-8')]))
}))
def create_split(root_folder, patch_names, TFRecord_writer):
progress_bar = tf.contrib.keras.utils.Progbar(target = len(patch_names))
for patch_idx, patch_name in enumerate(patch_names):
patch_folder_path = os.path.join(root_folder, patch_name)
bands = {}
for band_name in band_names:
# First finds related GeoTIFF path and reads values as an array
band_path = os.path.join(
patch_folder_path, patch_name + '_' + band_name + '.tif')
if GDAL_EXISTED:
band_ds = gdal.Open(band_path, gdal.GA_ReadOnly)
raster_band = band_ds.GetRasterBand(1)
band_data = raster_band.ReadAsArray()
bands[band_name] = np.array(band_data)
elif RASTERIO_EXISTED:
band_ds = rasterio.open(band_path)
band_data = np.array(band_ds.read(1))
bands[band_name] = np.array(band_data)
original_labels_multi_hot = np.zeros(
len(label_indices['original_labels'].keys()), dtype=int)
BigEarthNet_19_labels_multi_hot = np.zeros(len(label_conversion),dtype=int)
patch_json_path = os.path.join(
patch_folder_path, patch_name + '_labels_metadata.json')
with open(patch_json_path, 'rb') as f:
patch_json = json.load(f)
original_labels = patch_json['labels']
for label in original_labels:
original_labels_multi_hot[label_indices['original_labels'][label]] = 1
for i in range(len(label_conversion)):
BigEarthNet_19_labels_multi_hot[i] = (
np.sum(original_labels_multi_hot[label_conversion[i]]) > 0
).astype(int)
BigEarthNet_19_labels = []
for i in np.where(BigEarthNet_19_labels_multi_hot == 1)[0]:
BigEarthNet_19_labels.append(BigEarthNet_19_label_idx[i])
if UPDATE_JSON:
patch_json['BigEarthNet_19_labels'] = BigEarthNet_19_labels
with open(patch_json_path, 'wb') as f:
json.dump(patch_json, f)
example = prep_example(
bands,
BigEarthNet_19_labels,
BigEarthNet_19_labels_multi_hot,
patch_name
)
TFRecord_writer.write(example.SerializeToString())
progress_bar.update(patch_idx)
if __name__ == "__main__":
parser = argparse.ArgumentParser(description=
'This script creates TFRecord files for the BigEarthNet-19 train, validation and test splits')
'This script creates TFRecord files for the BigEarthNet train, validation and test splits')
parser.add_argument('-r', '--root_folder', dest = 'root_folder',
help = 'root folder path contains multiple patch folders')
parser.add_argument('-o', '--out_folder', dest = 'out_folder',
help = 'folder path containing resulting TFRecord files')
help = 'folder path containing resulting TFRecord or LMDB files')
parser.add_argument('--update_json', default = False, action = "store_true", help =
'flag for adding BigEarthNet-19 labels to the json file of each patch')
parser.add_argument('-n', '--splits', dest = 'splits', help =
'csv files each of which contains a list of patch names; patches with snow, clouds, and shadows are already excluded', nargs = '+')
parser.add_argument('-l', '--library', type=str, dest = 'library', help="ML library for which data files will be prepared", choices=['tensorflow', 'pytorch'])
args = parser.parse_args()
@@ -164,11 +67,7 @@ if __name__ == "__main__":
except ImportError:
print('ERROR: please install either GDAL or rasterio package to read GeoTIFF files')
exit()
try:
import tensorflow as tf
except ImportError:
print('ERROR: please install tensorflow package to create TFRecord files')
exit()
try:
import numpy as np
except ImportError:
@@ -188,29 +87,37 @@ if __name__ == "__main__":
patch_names_list[-1].append(row[0].strip())
except:
print('ERROR: some csv files either do not exist or have been corrupted')
exit()
try:
writer_list = []
for split_name in split_names:
writer_list.append(
tf.python_io.TFRecordWriter(os.path.join(
args.out_folder, split_name + '.tfrecord'))
)
except:
print('ERROR: TFRecord writer is not able to write files')
exit()
if args.update_json:
UPDATE_JSON = True
for split_idx in range(len(patch_names_list)):
print('INFO: creating the split of', split_names[split_idx], 'is started')
create_split(
if args.library == 'tensorflow':
try:
import tensorflow as tf
except ImportError:
print('ERROR: please install tensorflow package to create TFRecord files')
exit()
prep_tf_record_files(
args.root_folder,
patch_names_list[split_idx],
writer_list[split_idx]
)
writer_list[split_idx].close()
args.out_folder,
split_names,
patch_names_list,
label_indices,
GDAL_EXISTED,
RASTERIO_EXISTED,
UPDATE_JSON
)
elif args.library == 'pytorch':
prep_lmdb_files(
args.root_folder,
args.out_folder,
patch_names_list,
label_indices,
GDAL_EXISTED,
RASTERIO_EXISTED,
UPDATE_JSON
)
pytorch_utils.py
import json
import csv
import os
import numpy as np
from collections import defaultdict
def cls2multiHot(cls_vec, label_indices):
label_conversion = label_indices['label_conversion']
BigEarthNet_19_label_idx = {v: k for k, v in label_indices['BigEarthNet-19_labels'].iteritems()}
BigEarthNet_19_labels = []
BigEartNet_19_labels_multiHot = np.zeros((len(label_conversion),))
original_labels_multiHot = np.zeros((len(label_indices['original_labels']),))
for cls_nm in cls_vec:
original_labels_multiHot[label_indices['original_labels'][cls_nm]] = 1
for i in range(len(label_conversion)):
BigEartNet_19_labels_multiHot[i] = (
np.sum(original_labels_multiHot[label_conversion[i]]) > 0
).astype(int)
BigEarthNet_19_labels = []
for i in np.where(BigEartNet_19_labels_multiHot == 1)[0]:
BigEarthNet_19_labels.append(BigEarthNet_19_label_idx[i])
return BigEartNet_19_labels_multiHot, BigEarthNet_19_labels
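As a toy illustration of the conversion that `cls2multiHot` performs (the class names and indices below are made up for the example, not the real BigEarthNet nomenclature), original labels are folded into the coarser multi-hot vector via `label_conversion`:

```python
import numpy as np

# Hypothetical miniature nomenclature: 4 original classes collapsed into 2.
label_indices = {
    'original_labels': {'Coniferous forest': 0, 'Broad-leaved forest': 1,
                        'Sea and ocean': 2, 'Water bodies': 3},
    'BigEarthNet-19_labels': {'Forest': 0, 'Water': 1},
    # label_conversion[i] lists the original-label indices folded into class i
    'label_conversion': [[0, 1], [2, 3]],
}

def to_multi_hot(labels, label_indices):
    conv = label_indices['label_conversion']
    original = np.zeros(len(label_indices['original_labels']), dtype=int)
    for name in labels:
        original[label_indices['original_labels'][name]] = 1
    # A coarse class is on if any of its constituent original classes is on.
    return np.array([int(original[idx].sum() > 0) for idx in conv])

print(to_multi_hot(['Coniferous forest'], label_indices))  # -> [1 0]
```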
def read_scale_raster(file_path, GDAL_EXISTED, RASTERIO_EXISTED):
"""
read raster file with specified scale
:param file_path:
:param scale:
:return:
"""
if GDAL_EXISTED:
import gdal
elif RASTERIO_EXISTED:
import rasterio
if GDAL_EXISTED:
band_ds = gdal.Open(file_path, gdal.GA_ReadOnly)
raster_band = band_ds.GetRasterBand(1)
band_data = raster_band.ReadAsArray()
elif RASTERIO_EXISTED:
band_ds = rasterio.open(file_path)
band_data = np.array(band_ds.read(1))
return band_data
def parse_json_labels(f_j_path):
"""
parse meta-data json file for big earth to get image labels
:param f_j_path: json file path
:return: list of labels stored in the metadata file
"""
with open(f_j_path, 'r') as f_j:
j_f_c = json.load(f_j)
return j_f_c['labels']
def update_json_labels(f_j_path, BigEarthNet_19_labels):
with open(f_j_path, 'r') as f_j:
j_f_c = json.load(f_j)
j_f_c['BigEarthNet_19_labels'] = BigEarthNet_19_labels
with open(f_j_path, 'wb') as f:
json.dump(j_f_c, f)
class dataGenBigEarthTiff:
def __init__(self, bigEarthDir=None,
bands10=None, bands20=None, bands60=None,
patch_names_list=None, label_indices=None,
RASTERIO_EXISTED=None, GDAL_EXISTED=None,
UPDATE_JSON=None
):
self.bigEarthDir = bigEarthDir
self.bands10 = bands10
self.bands20 = bands20
self.bands60 = bands60
self.label_indices = label_indices
self.GDAL_EXISTED = GDAL_EXISTED
self.RASTERIO_EXISTED = RASTERIO_EXISTED
self.UPDATE_JSON = UPDATE_JSON
self.total_patch = patch_names_list[0] + patch_names_list[1] + patch_names_list[2]
def __len__(self):
return len(self.total_patch)
def __getitem__(self, index):
return self.__data_generation(index)
def __data_generation(self, idx):
imgNm = self.total_patch[idx]
bands10_array = []
bands20_array = []
bands60_array = []
if self.bands10 is not None:
for band in self.bands10:
bands10_array.append(read_scale_raster(os.path.join(self.bigEarthDir, imgNm, imgNm+'_B'+band+'.tif'), self.GDAL_EXISTED, self.RASTERIO_EXISTED))
if self.bands20 is not None:
for band in self.bands20:
bands20_array.append(read_scale_raster(os.path.join(self.bigEarthDir, imgNm, imgNm+'_B'+band+'.tif'), self.GDAL_EXISTED, self.RASTERIO_EXISTED))
if self.bands60 is not None:
for band in self.bands60:
bands60_array.append(read_scale_raster(os.path.join(self.bigEarthDir, imgNm, imgNm+'_B'+band+'.tif'), self.GDAL_EXISTED, self.RASTERIO_EXISTED))
bands10_array = np.asarray(bands10_array).astype(np.float32)
bands20_array = np.asarray(bands20_array).astype(np.float32)
bands60_array = np.asarray(bands60_array).astype(np.float32)
labels = parse_json_labels(os.path.join(self.bigEarthDir, imgNm, imgNm+'_labels_metadata.json'))
BigEartNet_19_labels_multiHot, BigEarthNet_19_labels = cls2multiHot(labels, self.label_indices)
if self.UPDATE_JSON:
update_json_labels(os.path.join(self.bigEarthDir, imgNm, imgNm+'_labels_metadata.json'), BigEarthNet_19_labels)
sample = {'bands10': bands10_array, 'bands20': bands20_array, 'bands60': bands60_array,
'patch_name': imgNm, 'multi_hots':BigEartNet_19_labels_multiHot}
return sample
def dumps_pyarrow(obj):
"""
Serialize an object.
Returns:
Implementation-dependent bytes-like object
"""
import pyarrow as pa
return pa.serialize(obj).to_buffer()
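Note that `pyarrow.serialize` has been deprecated in later pyarrow releases. If you are on a newer version, a drop-in stand-in using the standard library's `pickle` (my substitution, not part of the original code) would be:

```python
import pickle

def dumps_pickled(obj):
    """Serialize an object to bytes with the standard library,
    as a stand-in for the deprecated pyarrow.serialize."""
    return pickle.dumps(obj, protocol=pickle.HIGHEST_PROTOCOL)

def loads_pickled(buf):
    """Inverse of dumps_pickled."""
    return pickle.loads(buf)
```

Whichever serializer writes the LMDB entries must also be used when reading them back.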
def prep_lmdb_files(root_folder, out_folder, patch_names_list, label_indices, GDAL_EXISTED, RASTERIO_EXISTED, UPDATE_JSON):
from torch.utils.data import DataLoader
import lmdb
dataGen = dataGenBigEarthTiff(
bigEarthDir = root_folder,
bands10 = ['02', '03', '04', '08'],
bands20 = ['05', '06', '07', '8A', '11', '12'],
bands60 = ['01','09'],
patch_names_list=patch_names_list,
label_indices=label_indices,
GDAL_EXISTED=GDAL_EXISTED,
RASTERIO_EXISTED=RASTERIO_EXISTED
)
nSamples = len(dataGen)
map_size_ = (dataGen[0]['bands10'].nbytes + dataGen[0]['bands20'].nbytes + dataGen[0]['bands60'].nbytes)*10*len(dataGen)
data_loader = DataLoader(dataGen, num_workers=4, collate_fn=lambda x: x)
db = lmdb.open(os.path.join(out_folder, 'BigEarthNet-19.lmdb'), map_size=map_size_)
txn = db.begin(write=True)
patch_names = []
for idx, data in enumerate(data_loader):
bands10, bands20, bands60, patch_name, multiHots = data[0]['bands10'], data[0]['bands20'], data[0]['bands60'], data[0]['patch_name'], data[0]['multi_hots']
txn.put(u'{}'.format(patch_name).encode('ascii'), dumps_pyarrow((bands10, bands20, bands60, multiHots)))
patch_names.append(patch_name)
if idx % 10000 == 0:
print("[%d/%d]" % (idx, nSamples))
txn.commit()
txn = db.begin(write=True)
txn.commit()
keys = [u'{}'.format(patch_name).encode('ascii') for patch_name in patch_names]
with db.begin(write=True) as txn:
txn.put(b'__keys__', dumps_pyarrow(keys))
txn.put(b'__len__', dumps_pyarrow(len(keys)))
print("Flushing database ...")
db.sync()
db.close()
tensorflow_utils.py
import tensorflow as tf
import numpy as np
import os
import json
# Spectral band names to read related GeoTIFF files
band_names = ['B01', 'B02', 'B03', 'B04', 'B05',
'B06', 'B07', 'B08', 'B8A', 'B09', 'B11', 'B12']
def prep_example(bands, BigEarthNet_19_labels, BigEarthNet_19_labels_multi_hot, patch_name):
return tf.train.Example(
features=tf.train.Features(
feature={
'B01': tf.train.Feature(
int64_list=tf.train.Int64List(value=np.ravel(bands['B01']))),
'B02': tf.train.Feature(
int64_list=tf.train.Int64List(value=np.ravel(bands['B02']))),
'B03': tf.train.Feature(
int64_list=tf.train.Int64List(value=np.ravel(bands['B03']))),
'B04': tf.train.Feature(
int64_list=tf.train.Int64List(value=np.ravel(bands['B04']))),
'B05': tf.train.Feature(
int64_list=tf.train.Int64List(value=np.ravel(bands['B05']))),
'B06': tf.train.Feature(
int64_list=tf.train.Int64List(value=np.ravel(bands['B06']))),
'B07': tf.train.Feature(
int64_list=tf.train.Int64List(value=np.ravel(bands['B07']))),
'B08': tf.train.Feature(
int64_list=tf.train.Int64List(value=np.ravel(bands['B08']))),
'B8A': tf.train.Feature(
int64_list=tf.train.Int64List(value=np.ravel(bands['B8A']))),
'B09': tf.train.Feature(
int64_list=tf.train.Int64List(value=np.ravel(bands['B09']))),
'B11': tf.train.Feature(
int64_list=tf.train.Int64List(value=np.ravel(bands['B11']))),
'B12': tf.train.Feature(
int64_list=tf.train.Int64List(value=np.ravel(bands['B12']))),
'BigEarthNet-19_labels': tf.train.Feature(
bytes_list=tf.train.BytesList(
value=[i.encode('utf-8') for i in BigEarthNet_19_labels])),
'BigEarthNet-19_labels_multi_hot': tf.train.Feature(
int64_list=tf.train.Int64List(value=BigEarthNet_19_labels_multi_hot)),
'patch_name': tf.train.Feature(
bytes_list=tf.train.BytesList(value=[patch_name.encode('utf-8')]))
}))
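Each band is stored flattened in the `Int64List`. As a quick illustration (toy values, not real Sentinel-2 data), `np.ravel` turns the 2-D band array into the 1-D sequence that is written, and the original shape must be known at parse time to restore it:

```python
import numpy as np

# Toy 2x2 "band"; real BigEarthNet bands are 120x120, 60x60 or 20x20 pixels
# depending on their spatial resolution.
band = np.array([[10, 20],
                 [30, 40]])
flat = np.ravel(band)              # row-major flattening
print(flat.tolist())               # -> [10, 20, 30, 40]
restored = flat.reshape(band.shape)  # reshape back when parsing the record
```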
def create_split(root_folder, patch_names, TFRecord_writer, label_indices, GDAL_EXISTED, RASTERIO_EXISTED, UPDATE_JSON):
label_conversion = label_indices['label_conversion']
BigEarthNet_19_label_idx = {v: k for k, v in label_indices['BigEarthNet-19_labels'].iteritems()}
if GDAL_EXISTED:
import gdal
elif RASTERIO_EXISTED:
import rasterio
progress_bar = tf.contrib.keras.utils.Progbar(target = len(patch_names))
for patch_idx, patch_name in enumerate(patch_names):
patch_folder_path = os.path.join(root_folder, patch_name)
bands = {}
for band_name in band_names:
# First finds related GeoTIFF path and reads values as an array
band_path = os.path.join(
patch_folder_path, patch_name + '_' + band_name + '.tif')
if GDAL_EXISTED:
band_ds = gdal.Open(band_path, gdal.GA_ReadOnly)
raster_band = band_ds.GetRasterBand(1)
band_data = raster_band.ReadAsArray()
bands[band_name] = np.array(band_data)
elif RASTERIO_EXISTED:
band_ds = rasterio.open(band_path)
band_data = np.array(band_ds.read(1))
bands[band_name] = np.array(band_data)
original_labels_multi_hot = np.zeros(
len(label_indices['original_labels'].keys()), dtype=int)
BigEarthNet_19_labels_multi_hot = np.zeros(len(label_conversion),dtype=int)
patch_json_path = os.path.join(
patch_folder_path, patch_name + '_labels_metadata.json')
with open(patch_json_path, 'rb') as f:
patch_json = json.load(f)
original_labels = patch_json['labels']
for label in original_labels:
original_labels_multi_hot[label_indices['original_labels'][label]] = 1
for i in range(len(label_conversion)):
BigEarthNet_19_labels_multi_hot[i] = (
np.sum(original_labels_multi_hot[label_conversion[i]]) > 0
).astype(int)
BigEarthNet_19_labels = []
for i in np.where(BigEarthNet_19_labels_multi_hot == 1)[0]:
BigEarthNet_19_labels.append(BigEarthNet_19_label_idx[i])
if UPDATE_JSON:
patch_json['BigEarthNet_19_labels'] = BigEarthNet_19_labels
with open(patch_json_path, 'wb') as f:
json.dump(patch_json, f)
example = prep_example(
bands,
BigEarthNet_19_labels,
BigEarthNet_19_labels_multi_hot,
patch_name
)
TFRecord_writer.write(example.SerializeToString())
progress_bar.update(patch_idx)
def prep_tf_record_files(root_folder, out_folder, split_names, patch_names_list, label_indices, GDAL_EXISTED, RASTERIO_EXISTED, UPDATE_JSON):
try:
writer_list = []
for split_name in split_names:
writer_list.append(
tf.python_io.TFRecordWriter(os.path.join(
out_folder, split_name + '.tfrecord'))
)
except:
print('ERROR: TFRecord writer is not able to write files')
exit()
for split_idx in range(len(patch_names_list)):
print('INFO: creating the split of', split_names[split_idx], 'is started')
create_split(
root_folder,
patch_names_list[split_idx],
writer_list[split_idx],
label_indices,
GDAL_EXISTED,
RASTERIO_EXISTED,
UPDATE_JSON
)
writer_list[split_idx].close()
\ No newline at end of file