harshinde commited on
Commit
abfc282
·
verified ·
1 Parent(s): a9caf08

Upload src/streamlit_app.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. src/streamlit_app.py +349 -37
src/streamlit_app.py CHANGED
@@ -6,6 +6,302 @@ import matplotlib.pyplot as plt
6
  import yaml
7
  import os
8
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
9
  # Import models
10
  from src.mobilenetv2_model import LandslideModel as MobileNetV2Model
11
  from src.vgg16_model import LandslideModel as VGG16Model
@@ -21,6 +317,23 @@ from src.se_resnext50_32x4d_model import LandslideModel as SEResNeXt50_32X4DMode
21
  from segformer_model import LandslideModel as SegFormerB2Model
22
  from inceptionresnetv2_model import LandslideModel as InceptionResNetV2Model
23
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
24
  # Load the configuration file
25
  config = """
26
  model_config:
@@ -55,21 +368,14 @@ logging_config:
55
 
56
  config = yaml.safe_load(config)
57
 
58
- # Model descriptions
59
- model_descriptions = {
60
- "MobileNetV2": {"path": "mobilenetv2.pth", "type": "mobilenet_v2", "description": "MobileNetV2 is a lightweight deep learning model for image classification and segmentation."},
61
- "VGG16": {"path": "vgg16.pth", "type": "vgg16", "description": "VGG16 is a popular deep learning model known for its simplicity and depth."},
62
- "ResNet34": {"path": "resnet34.pth", "type": "resnet34", "description": "ResNet34 is a deep residual network that helps in training very deep networks."},
63
- "EfficientNetB0": {"path": "effucientnetb0.pth", "type": "efficientnet_b0", "description": "EfficientNetB0 is part of the EfficientNet family, known for its efficiency and performance."},
64
- "MiT-B1": {"path": "mitb1.pth", "type": "mit_b1", "description": "MiT-B1 is a transformer-based model designed for segmentation tasks."},
65
- "InceptionV4": {"path": "inceptionv4.pth", "type": "inceptionv4", "description": "InceptionV4 is a convolutional neural network known for its inception modules."},
66
- "DeepLabV3+": {"path": "deeplabv3.pth", "type": "deeplabv3+", "description": "DeepLabV3+ is an advanced model for semantic image segmentation."},
67
- "DenseNet121": {"path": "densenet121.pth", "type": "densenet121", "description": "DenseNet121 is a densely connected convolutional network for image classification and segmentation."},
68
- "ResNeXt50_32X4D": {"path": "resnext50-32x4d.pth", "type": "resnext50_32x4d", "description": "ResNeXt50_32X4D is a highly modularized network aimed at improving accuracy."},
69
- "SEResNet50": {"path": "se_resnet50.pth", "type": "se_resnet50", "description": "SEResNet50 is a ResNet model with squeeze-and-excitation blocks for better feature recalibration."},
70
- "SEResNeXt50_32X4D": {"path": "se_resnext50_32x4d.pth", "type": "se_resnext50_32x4d", "description": "SEResNeXt50_32X4D combines ResNeXt and SE blocks for improved performance."},
71
- "SegFormerB2": {"path": "segformer.pth", "type": "segformer_b2", "description": "SegFormerB2 is a transformer-based model for semantic segmentation."},
72
- "InceptionResNetV2": {"path": "inceptionresnetv2.pth", "type": "inceptionresnetv2", "description": "InceptionResNetV2 is a hybrid model combining Inception and ResNet architectures."},
73
  }
74
 
75
  # Streamlit app
@@ -88,18 +394,13 @@ st.markdown("""
88
  st.sidebar.title("Model Selection")
89
  model_option = st.sidebar.radio("Choose an option", ["Select a single model", "Run all models"])
90
  if model_option == "Select a single model":
91
- model_type = st.sidebar.selectbox("Select Model", list(model_descriptions.keys()))
92
- config['model_config']['model_type'] = model_descriptions[model_type]['type']
93
- if model_type == "DeepLabV3+":
94
- model_class = DeepLabV3PlusModel
95
- else:
96
- model_class = locals()[model_type.replace("-", "") + "Model"]
97
- model_path = model_descriptions[model_type]['path']
98
-
99
  # Display model details in the sidebar
100
- st.sidebar.markdown(f"**Model Type:** {model_descriptions[model_type]['type']}")
101
- st.sidebar.markdown(f"**Model Path:** {model_descriptions[model_type]['path']}")
102
- st.sidebar.markdown(f"**Description:** {model_descriptions[model_type]['description']}")
103
 
104
  # Main content
105
  st.header("Upload Data")
@@ -193,19 +494,30 @@ if uploaded_files:
193
 
194
  else:
195
  # Process the image with each model
196
- for model_name, model_info in model_descriptions.items():
197
- st.write(f"Using model: {model_name}")
198
- if model_name == "DeepLabV3+":
199
- model_class = DeepLabV3PlusModel
200
- else:
201
- model_class = locals()[model_name.replace("-", "") + "Model"]
202
- model_path = model_info['path']
203
  config['model_config']['model_type'] = model_info['type']
 
 
 
 
204
 
205
- # Load the model
206
- model = model_class(config)
207
- model.load_state_dict(torch.load(model_path, map_location=torch.device('cpu')), strict=False)
208
- model.eval()
 
 
 
 
 
 
 
 
 
 
 
 
209
 
210
  # Make prediction
211
  with torch.no_grad():
 
6
  import yaml
7
  import os
8
 
9
+ # Import models
10
+ from src.mobilenetv2_model import LandslideModel as MobileNetV2Model
11
+ from src.vgg16_model import LandslideModel as VGG16Model
12
+ from src.resnet34_model import LandslideModel as ResNet34Model
13
+ from src.efficientnetb0_model import LandslideModel as EfficientNetB0Model
14
+ from src.mitb1_model import LandslideModel as MiTB1Model
15
+ from src.inceptionv4_model import LandslideModel as InceptionV4Model
16
+ from src.densenet121_model import LandslideModel as DenseNet121Model
17
+ from src.deeplabv3plus_model import LandslideModel as DeepLabV3PlusModel
18
+ from src.resnext50_32x4d_model import LandslideModel as ResNeXt50_32X4DModel
19
+ from src.se_resnet50_model import LandslideModel as SEResNet50Model
20
+ from src.se_resnext50_32x4d_model import LandslideModel as SEResNeXt50_32X4DModel
21
+ from src.segformer_model import LandslideModel as SegFormerB2Model
22
+ from src.inceptionresnetv2_model import LandslideModel as InceptionResNetV2Model
23
+
24
+ # Define available models
25
+ AVAILABLE_MODELS = {
26
+ "mobilenetv2": {"name": "MobileNetV2", "type": "mobilenet_v2"},
27
+ "vgg16": {"name": "VGG16", "type": "vgg16"},
28
+ "resnet34": {"name": "ResNet34", "type": "resnet34"},
29
+ "efficientnetb0": {"name": "EfficientNetB0", "type": "efficientnet_b0"},
30
+ "mitb1": {"name": "MiTB1", "type": "mitb1"},
31
+ "inceptionv4": {"name": "InceptionV4", "type": "inception_v4"},
32
+ "densenet121": {"name": "DenseNet121", "type": "densenet121"},
33
+ "deeplabv3plus": {"name": "DeepLabV3Plus", "type": "deeplabv3plus"},
34
+ "resnext50": {"name": "ResNeXt50", "type": "resnext50_32x4d"},
35
+ "seresnet50": {"name": "SEResNet50", "type": "se_resnet50"},
36
+ "seresnext50": {"name": "SEResNeXt50", "type": "se_resnext50_32x4d"},
37
+ "segformerb2": {"name": "SegFormerB2", "type": "segformer_b2"},
38
+ "inceptionresnetv2": {"name": "InceptionResNetV2", "type": "inception_resnet_v2"}
39
+ }
40
+
41
+ # Model descriptions with their respective types and descriptions
42
+ MODEL_DESCRIPTIONS = {
43
+ model_key: {
44
+ "type": model_info["type"],
45
+ "description": f"{model_info['name']} - A model for landslide detection and segmentation.",
46
+ "name": model_info["name"]
47
+ }
48
+ for model_key, model_info in AVAILABLE_MODELS.items()
49
+ }
50
+
51
+ # Load the configuration file
52
+ config = """
53
+ model_config:
54
+ model_type: "mobilenet_v2"
55
+ in_channels: 14
56
+ num_classes: 1
57
+ encoder_weights: "imagenet"
58
+ wce_weight: 0.5
59
+
60
+ dataset_config:
61
+ num_classes: 1
62
+ num_channels: 14
63
+ channels: [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14]
64
+ normalize: False
65
+
66
+ train_config:
67
+ dataset_path: ""
68
+ checkpoint_path: "checkpoints"
69
+ seed: 42
70
+ train_val_split: 0.8
71
+ batch_size: 16
72
+ num_epochs: 100
73
+ lr: 0.001
74
+ device: "cuda:0"
75
+ save_config: True
76
+ experiment_name: "mobilenet_v2"
77
+ """
78
+
79
+ config = yaml.safe_load(config)
80
+
81
+ # Streamlit app
82
+ st.set_page_config(page_title="Landslide Detection", layout="wide")
83
+
84
+ st.title("Landslide Detection")
85
+ st.markdown("""
86
+ ## Instructions
87
+ 1. Select a model from the sidebar or choose to run all models.
88
+ 2. Upload one or more `.h5` files.
89
+ 3. The app will process the files and display the input image, prediction, and overlay.
90
+ 4. You can download the prediction results.
91
+ """)
92
+
93
+ # Sidebar for model selection
94
+ st.sidebar.title("Model Selection")
95
+ model_option = st.sidebar.radio("Choose an option", ["Select a single model", "Run all models"])
96
+
97
+ if model_option == "Select a single model":
98
+ selected_model_key = st.sidebar.selectbox("Select Model", list(MODEL_DESCRIPTIONS.keys()))
99
+ selected_model_info = MODEL_DESCRIPTIONS[selected_model_key]
100
+ config['model_config']['model_type'] = selected_model_info['type']
101
+
102
+ # Display model details in the sidebar
103
+ st.sidebar.markdown("### Model Details")
104
+ st.sidebar.markdown(f"**Model Name:** {selected_model_info['name']}")
105
+ st.sidebar.markdown(f"**Model Type:** {selected_model_info['type']}")
106
+ st.sidebar.markdown(f"**Description:** {selected_model_info['description']}")
107
+
108
+ # Main content
109
+ st.header("Upload Data")
110
+
111
+ # Initialize session state for error tracking if not exists
112
+ if 'upload_errors' not in st.session_state:
113
+ st.session_state.upload_errors = []
114
+
115
+ uploaded_files = st.file_uploader(
116
+ "Choose .h5 files...",
117
+ type="h5",
118
+ accept_multiple_files=True,
119
+ help="Upload your .h5 files here. Maximum file size is 200MB."
120
+ )
121
+
122
+ if uploaded_files:
123
+ for uploaded_file in uploaded_files:
124
+ st.write(f"Processing file: {uploaded_file.name}")
125
+ st.write(f"File size: {uploaded_file.size} bytes")
126
+
127
+ with st.spinner('Classifying...'):
128
+ try:
129
+ # Read the file directly using BytesIO
130
+ import io
131
+ bytes_data = uploaded_file.getvalue()
132
+ bytes_io = io.BytesIO(bytes_data)
133
+
134
+ with h5py.File(bytes_io, 'r') as hdf:
135
+ if 'img' not in hdf:
136
+ st.error(f"Error: 'img' dataset not found in {uploaded_file.name}")
137
+ continue
138
+
139
+ data = np.array(hdf.get('img'))
140
+ data[np.isnan(data)] = 0.000001
141
+ channels = config["dataset_config"]["channels"]
142
+ image = np.zeros((128, 128, len(channels)))
143
+
144
+ for i, band in enumerate(channels):
145
+ image[:, :, i] = data[band-1]
146
+
147
+ selected_channels = [image[:, :, i] for i in range(3)]
148
+ image = np.transpose(image, (2, 0, 1))
149
+
150
+ if model_option == "Select a single model":
151
+ # Get the model class from AVAILABLE_MODELS
152
+ model_class_name = AVAILABLE_MODELS[selected_model_key]['name'].replace('-', '') + 'Model'
153
+ model_class = locals()[model_class_name]
154
+
155
+ # Initialize model downloader
156
+ from model_downloader import ModelDownloader
157
+ downloader = ModelDownloader()
158
+
159
+ try:
160
+ # Download/get model path
161
+ model_path = downloader.download_model(selected_model_key)
162
+ st.info(f"Using model from: {model_path}")
163
+
164
+ # Load the model
165
+ model = model_class(config)
166
+ model.load_state_dict(torch.load(model_path, map_location=torch.device('cpu')), strict=False)
167
+ model.eval()
168
+
169
+ # Make prediction
170
+ with torch.no_grad():
171
+ prediction = model(torch.from_numpy(image).unsqueeze(0).float())
172
+ prediction = torch.sigmoid(prediction).numpy()
173
+
174
+ st.header(f"Prediction Results - {selected_model_info['name']}")
175
+
176
+ # Create columns for input image, prediction, and overlay
177
+ col1, col2, col3 = st.columns(3)
178
+
179
+ # Display input image
180
+ with col1:
181
+ st.write("Input Image")
182
+ plt.figure(figsize=(8, 8))
183
+ plt.imshow(selected_channels[0], cmap='viridis')
184
+ plt.colorbar()
185
+ plt.axis('off')
186
+ st.pyplot(plt)
187
+
188
+ # Display prediction
189
+ with col2:
190
+ st.write("Prediction")
191
+ plt.figure(figsize=(8, 8))
192
+ plt.imshow(prediction.squeeze(), cmap='viridis')
193
+ plt.colorbar()
194
+ plt.axis('off')
195
+ st.pyplot(plt)
196
+
197
+ # Display overlay
198
+ with col3:
199
+ st.write("Overlay")
200
+ plt.figure(figsize=(8, 8))
201
+ plt.imshow(selected_channels[0], cmap='viridis')
202
+ plt.imshow(prediction.squeeze(), cmap='viridis', alpha=0.5)
203
+ plt.colorbar()
204
+ plt.axis('off')
205
+ st.pyplot(plt)
206
+
207
+ # Download button for prediction
208
+ st.write(f"Download the prediction as a .npy file for {selected_model_info['name']}:")
209
+ npy_data = prediction.squeeze()
210
+ st.download_button(
211
+ label=f"Download Prediction - {selected_model_info['name']}",
212
+ data=npy_data.tobytes(),
213
+ file_name=f"{uploaded_file.name.split('.')[0]}_{selected_model_key}_prediction.npy",
214
+ mime="application/octet-stream"
215
+ )
216
+
217
+ except Exception as e:
218
+ st.error(f"Error with model {selected_model_info['name']}: {str(e)}")
219
+ else:
220
+ # Process the image with each model
221
+ for model_key, model_info in MODEL_DESCRIPTIONS.items():
222
+ st.write(f"Using model: {model_info['name']}")
223
+ config['model_config']['model_type'] = model_info['type']
224
+
225
+ # Get the model class from AVAILABLE_MODELS
226
+ model_class_name = AVAILABLE_MODELS[model_key]['name'].replace('-', '') + 'Model'
227
+ model_class = locals()[model_class_name]
228
+
229
+ # Initialize model downloader
230
+ from model_downloader import ModelDownloader
231
+ downloader = ModelDownloader()
232
+
233
+ try:
234
+ # Download/get model path
235
+ model_path = downloader.download_model(model_key)
236
+ st.info(f"Using model from: {model_path}")
237
+
238
+ # Load the model
239
+ model = model_class(config)
240
+ model.load_state_dict(torch.load(model_path, map_location=torch.device('cpu')), strict=False)
241
+ model.eval()
242
+
243
+ # Make prediction
244
+ with torch.no_grad():
245
+ prediction = model(torch.from_numpy(image).unsqueeze(0).float())
246
+ prediction = torch.sigmoid(prediction).numpy()
247
+
248
+ st.header(f"Prediction Results - {model_info['name']}")
249
+
250
+ # Create columns for input image, prediction, and overlay
251
+ col1, col2, col3 = st.columns(3)
252
+
253
+ # Display input image
254
+ with col1:
255
+ st.write("Input Image")
256
+ plt.figure(figsize=(8, 8))
257
+ plt.imshow(selected_channels[0], cmap='viridis')
258
+ plt.colorbar()
259
+ plt.axis('off')
260
+ st.pyplot(plt)
261
+
262
+ # Display prediction
263
+ with col2:
264
+ st.write("Prediction")
265
+ plt.figure(figsize=(8, 8))
266
+ plt.imshow(prediction.squeeze(), cmap='viridis')
267
+ plt.colorbar()
268
+ plt.axis('off')
269
+ st.pyplot(plt)
270
+
271
+ # Display overlay
272
+ with col3:
273
+ st.write("Overlay")
274
+ plt.figure(figsize=(8, 8))
275
+ plt.imshow(selected_channels[0], cmap='viridis')
276
+ plt.imshow(prediction.squeeze(), cmap='viridis', alpha=0.5)
277
+ plt.colorbar()
278
+ plt.axis('off')
279
+ st.pyplot(plt)
280
+
281
+ # Download button for prediction
282
+ st.write(f"Download the prediction as a .npy file for {model_info['name']}:")
283
+ npy_data = prediction.squeeze()
284
+ st.download_button(
285
+ label=f"Download Prediction - {model_info['name']}",
286
+ data=npy_data.tobytes(),
287
+ file_name=f"{uploaded_file.name.split('.')[0]}_{model_key}_prediction.npy",
288
+ mime="application/octet-stream"
289
+ )
290
+
291
+ except Exception as e:
292
+ st.error(f"Error with model {model_info['name']}: {str(e)}")
293
+ continue
294
+
295
+ except Exception as e:
296
+ st.error(f"Error processing file {uploaded_file.name}: {str(e)}")
297
+ continue
298
+ import h5py
299
+ import torch
300
+ import numpy as np
301
+ import matplotlib.pyplot as plt
302
+ import yaml
303
+ import os
304
+
305
  # Import models
306
  from src.mobilenetv2_model import LandslideModel as MobileNetV2Model
307
  from src.vgg16_model import LandslideModel as VGG16Model
 
317
  from segformer_model import LandslideModel as SegFormerB2Model
318
  from inceptionresnetv2_model import LandslideModel as InceptionResNetV2Model
319
 
320
+ # Define available models
321
+ AVAILABLE_MODELS = {
322
+ "mobilenetv2": {"name": "MobileNetV2", "type": "mobilenet_v2"},
323
+ "vgg16": {"name": "VGG16", "type": "vgg16"},
324
+ "resnet34": {"name": "ResNet34", "type": "resnet34"},
325
+ "efficientnetb0": {"name": "EfficientNetB0", "type": "efficientnet_b0"},
326
+ "mitb1": {"name": "MiTB1", "type": "mitb1"},
327
+ "inceptionv4": {"name": "InceptionV4", "type": "inception_v4"},
328
+ "densenet121": {"name": "DenseNet121", "type": "densenet121"},
329
+ "deeplabv3plus": {"name": "DeepLabV3Plus", "type": "deeplabv3plus"},
330
+ "resnext50": {"name": "ResNeXt50", "type": "resnext50_32x4d"},
331
+ "seresnet50": {"name": "SEResNet50", "type": "se_resnet50"},
332
+ "seresnext50": {"name": "SEResNeXt50", "type": "se_resnext50_32x4d"},
333
+ "segformerb2": {"name": "SegFormerB2", "type": "segformer_b2"},
334
+ "inceptionresnetv2": {"name": "InceptionResNetV2", "type": "inception_resnet_v2"}
335
+ }
336
+
337
  # Load the configuration file
338
  config = """
339
  model_config:
 
368
 
369
  config = yaml.safe_load(config)
370
 
371
+ # Model descriptions with their respective types and descriptions
372
+ MODEL_DESCRIPTIONS = {
373
+ model_key: {
374
+ "type": model_info["type"],
375
+ "description": f"{model_info['name']} - A model for landslide detection and segmentation.",
376
+ "name": model_info["name"]
377
+ }
378
+ for model_key, model_info in AVAILABLE_MODELS.items()
 
 
 
 
 
 
 
379
  }
380
 
381
  # Streamlit app
 
394
  st.sidebar.title("Model Selection")
395
  model_option = st.sidebar.radio("Choose an option", ["Select a single model", "Run all models"])
396
  if model_option == "Select a single model":
397
+ selected_model = st.sidebar.selectbox("Select Model", list(MODEL_DESCRIPTIONS.keys()))
398
+ config['model_config']['model_type'] = MODEL_DESCRIPTIONS[selected_model]['type']
399
+
 
 
 
 
 
400
  # Display model details in the sidebar
401
+ st.sidebar.markdown(f"**Model Name:** {MODEL_DESCRIPTIONS[selected_model]['name']}")
402
+ st.sidebar.markdown(f"**Model Type:** {MODEL_DESCRIPTIONS[selected_model]['type']}")
403
+ st.sidebar.markdown(f"**Description:** {MODEL_DESCRIPTIONS[selected_model]['description']}")
404
 
405
  # Main content
406
  st.header("Upload Data")
 
494
 
495
  else:
496
  # Process the image with each model
497
+ for model_key, model_info in MODEL_DESCRIPTIONS.items():
498
+ st.write(f"Using model: {model_info['name']}")
 
 
 
 
 
499
  config['model_config']['model_type'] = model_info['type']
500
+
501
+ # Get the model class from AVAILABLE_MODELS
502
+ model_class_name = AVAILABLE_MODELS[model_key]['name'].replace('-', '') + 'Model'
503
+ model_class = locals()[model_class_name]
504
 
505
+ # Initialize model downloader
506
+ from model_downloader import ModelDownloader
507
+ downloader = ModelDownloader()
508
+
509
+ try:
510
+ # Download/get model path
511
+ model_path = downloader.download_model(model_name.lower())
512
+ st.info(f"Using model from: {model_path}")
513
+
514
+ # Load the model
515
+ model = model_class(config)
516
+ model.load_state_dict(torch.load(model_path, map_location=torch.device('cpu')), strict=False)
517
+ model.eval()
518
+ except Exception as e:
519
+ st.error(f"Error loading model {model_name}: {str(e)}")
520
+ continue
521
 
522
  # Make prediction
523
  with torch.no_grad():