hvoss-techfak committed on
Commit
e0eea98
·
1 Parent(s): 99553eb
Files changed (2)
  1. app.py +30 -61
  2. auto_forge.py +0 -1089
app.py CHANGED
@@ -1,14 +1,13 @@
1
  import json
2
  import string
3
- import traceback
4
  import uuid
5
  import os
6
  import logging
7
  import zipfile
 
8
  import wandb
9
  from contextlib import redirect_stdout, redirect_stderr
10
- import auto_forge
11
-
12
 
13
  USE_WANDB = "WANDB_API_KEY" in os.environ
14
  if USE_WANDB:
@@ -99,7 +98,7 @@ def get_script_args_info(exclude_args=None):
99
  {
100
  "name": "--iterations",
101
  "type": "number",
102
- "default": 6000,
103
  "help": "Number of optimization iterations",
104
  },
105
  {
@@ -160,7 +159,7 @@ def get_script_args_info(exclude_args=None):
160
  {
161
  "name": "--pruning_max_swaps",
162
  "type": "number",
163
- "default": 50,
164
  "precision": 0,
165
  "help": "Max number of swaps allowed after pruning",
166
  },
@@ -183,7 +182,7 @@ def get_script_args_info(exclude_args=None):
183
  {
184
  "name": "--learning_rate_warmup_fraction",
185
  "type": "slider",
186
- "default": 0.01,
187
  "min": 0.0,
188
  "max": 1.0,
189
  "step": 0.01,
@@ -215,7 +214,7 @@ def get_script_args_info(exclude_args=None):
215
  {
216
  "name": "--num_init_rounds",
217
  "type": "number",
218
- "default": 32,
219
  "precision": 0,
220
  "help": "Number of rounds to choose the starting height map from.",
221
  },
@@ -296,23 +295,16 @@ else:
296
  def run_autoforge_process(cmd, log_path):
297
  from joblib import parallel_backend
298
  cli_args = cmd[1:]
 
299
 
300
  exit_code = 0
301
- # Ensure local project dir is first on sys.path so `import auto_forge` imports the file in this repo
302
- script_dir = os.path.dirname(os.path.abspath(__file__))
303
- if script_dir not in sys.path:
304
- sys.path.insert(0, script_dir)
305
-
306
- with open(log_path, "w", buffering=1, encoding="utf-8") as log_f, redirect_stdout(log_f), redirect_stderr(log_f), parallel_backend("threading", n_jobs=4):
307
  try:
308
- # Force a fresh import of the local module by removing any cached module
309
- if "auto_forge" in sys.modules:
310
- del sys.modules["auto_forge"]
311
- auto_forge = __import__("auto_forge")
312
  sys.argv = ["autoforge"] + cli_args
313
- auto_forge.main()
314
  except SystemExit as e:
315
- exit_code = e.code if isinstance(e.code, int) or e.code is None else 0
316
  except Exception as e:
317
  log_f.write(f"\nERROR: {e}\n")
318
  exit_code = -1
@@ -680,6 +672,7 @@ with gr.Blocks(theme=gr.themes.Soft()) as demo:
680
  visible=False,
681
  )
682
 
 
683
  def execute_autoforge_script(
684
  current_filaments_df_state_val, input_image, *accordion_param_values
685
  ):
@@ -774,48 +767,24 @@ with gr.Blocks(theme=gr.themes.Soft()) as demo:
774
  import threading
775
 
776
  class Worker(threading.Thread):
777
- def __init__(self, cmd, log_path):
778
- super().__init__(daemon=True)
779
- self.cmd, self.log_path = cmd, log_path
780
- self.returncode = None
781
- self.exc = None
782
-
783
- def run(self):
784
- """Import and run the local `auto_forge.py` module in-process.
785
-
786
- We load the script from the project dir as a fresh module using
787
- importlib.util.spec_from_file_location to ensure decorators like
788
- @spaces.GPU are executed at import time. Stdout/stderr are redirected
789
- to the run log to preserve the live console stream.
790
- """
791
- try:
792
- # Ensure the project directory is on sys.path so a plain `import auto_forge` finds the local file
793
- script_dir = os.path.dirname(os.path.abspath(__file__))
794
- if script_dir not in sys.path:
795
- sys.path.insert(0, script_dir)
796
-
797
- with open(self.log_path, "a", encoding="utf-8") as lf, redirect_stdout(lf), redirect_stderr(lf):
798
- try:
799
- # Provide argv for the module's CLI parsing and call main()
800
- sys.argv = ["autoforge"] + (self.cmd[1:] if len(self.cmd) > 1 else [])
801
- auto_forge.main()
802
- self.returncode = 0
803
- except Exception as e:
804
- lf.write(f"\nERROR while importing/running auto_forge: {exc_text(e)}\n")
805
- traceback.print_exc()
806
- self.exc = e
807
- if isinstance(e, SystemExit):
808
- self.returncode = e.code if isinstance(e.code, int) or e.code is None else 1
809
- else:
810
- self.returncode = -1
811
- except Exception as outer_e:
812
- self.exc = outer_e
813
- try:
814
- with open(self.log_path, "a", encoding="utf-8") as lf:
815
- lf.write(f"\nERROR loading autoforge.auto_forge: {exc_text(outer_e)}\n")
816
- except Exception:
817
- pass
818
- self.returncode = -1
819
 
820
  try:
821
  worker = Worker(command, log_file)
 
1
  import json
2
  import string
 
3
  import uuid
4
  import os
5
  import logging
6
  import zipfile
7
+ import importlib
8
  import wandb
9
  from contextlib import redirect_stdout, redirect_stderr
10
+ import spaces
 
11
 
12
  USE_WANDB = "WANDB_API_KEY" in os.environ
13
  if USE_WANDB:
 
98
  {
99
  "name": "--iterations",
100
  "type": "number",
101
+ "default": 4000,
102
  "help": "Number of optimization iterations",
103
  },
104
  {
 
159
  {
160
  "name": "--pruning_max_swaps",
161
  "type": "number",
162
+ "default": 20,
163
  "precision": 0,
164
  "help": "Max number of swaps allowed after pruning",
165
  },
 
182
  {
183
  "name": "--learning_rate_warmup_fraction",
184
  "type": "slider",
185
+ "default": 0.2,
186
  "min": 0.0,
187
  "max": 1.0,
188
  "step": 0.01,
 
214
  {
215
  "name": "--num_init_rounds",
216
  "type": "number",
217
+ "default": 8,
218
  "precision": 0,
219
  "help": "Number of rounds to choose the starting height map from.",
220
  },
 
295
  def run_autoforge_process(cmd, log_path):
296
  from joblib import parallel_backend
297
  cli_args = cmd[1:]
298
+ autoforge_main = importlib.import_module("autoforge.__main__")
299
 
300
  exit_code = 0
301
+ with open(log_path, "w", buffering=1, encoding="utf-8") as log_f, \
302
 + redirect_stdout(log_f), redirect_stderr(log_f), parallel_backend("threading", n_jobs=-1):
303
  try:
304
  sys.argv = ["autoforge"] + cli_args
305
+ autoforge_main.main()
306
  except SystemExit as e:
307
+ exit_code = e.code
308
  except Exception as e:
309
  log_f.write(f"\nERROR: {e}\n")
310
  exit_code = -1
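run_autoforge_process above imports the packaged autoforge.__main__ entry point and calls it in-process while redirecting stdout/stderr to the run log. A minimal standalone sketch of that pattern, assuming the installed autoforge package exposes main() in autoforge.__main__ (the helper name run_cli_in_process is illustrative):

import importlib
import sys
from contextlib import redirect_stdout, redirect_stderr

def run_cli_in_process(cli_args, log_path):
    # Load the installed package's CLI entry point instead of a local script copy.
    autoforge_main = importlib.import_module("autoforge.__main__")
    exit_code = 0
    # Line-buffered log plus stdout/stderr redirection preserves the live console stream.
    with open(log_path, "w", buffering=1, encoding="utf-8") as log_f, \
            redirect_stdout(log_f), redirect_stderr(log_f):
        try:
            sys.argv = ["autoforge"] + list(cli_args)  # argparse inside main() reads sys.argv
            autoforge_main.main()
        except SystemExit as e:
            # argparse and explicit sys.exit() raise SystemExit; its code may be None or a string.
            exit_code = e.code if isinstance(e.code, int) else 0
        except Exception as e:
            log_f.write(f"\nERROR: {e}\n")
            exit_code = -1
    return exit_code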
 
672
  visible=False,
673
  )
674
 
675
+ @spaces.GPU(duration=150)
676
  def execute_autoforge_script(
677
  current_filaments_df_state_val, input_image, *accordion_param_values
678
  ):
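execute_autoforge_script above is wrapped in @spaces.GPU so that, on a ZeroGPU Space, a GPU is attached only for the duration of the call. A minimal sketch of the decorator pattern, assuming the Hugging Face spaces package is available (duration is the requested allocation window in seconds; heavy_step is an illustrative name):

import spaces
import torch

@spaces.GPU(duration=150)  # request a GPU slot for up to ~150 s per call
def heavy_step(x):
    # CUDA is only visible inside @spaces.GPU-decorated functions on ZeroGPU hardware.
    device = "cuda" if torch.cuda.is_available() else "cpu"
    return (torch.as_tensor(x, dtype=torch.float32, device=device) * 2).cpu().numpy()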
 
767
  import threading
768
 
769
  class Worker(threading.Thread):
770
+ def __init__(self, cmd, log_path):
771
+ super().__init__(daemon=True)
772
+ self.cmd, self.log_path = cmd, log_path
773
+ self.returncode = None
774
+ self.exc = None
775
+
776
+ def run(self):
777
+ try:
778
+ self.returncode = run_autoforge_process(self.cmd, self.log_path)
779
+ except Exception as e:
780
+ self.exc = e
781
+ with open(self.log_path, "a", encoding="utf-8") as lf:
782
+ lf.write(
783
+ "\nERROR: {}. This usually means there was no GPU or the process took too long.\n".format(
784
+ exc_text(e)
785
+ )
786
+ )
787
 + self.returncode = -1
788
 
789
  try:
790
  worker = Worker(command, log_file)
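A hedged usage sketch of the Worker thread defined above: pre-create the log, start the daemon thread, tail the log while it runs, then read returncode. The one-second poll and print-based tailing are illustrative, not the Gradio app's actual streaming logic.

import time

open(log_file, "a", encoding="utf-8").close()  # make sure the log exists before tailing
worker = Worker(command, log_file)             # command: CLI argument list, log_file: run log path
worker.start()

with open(log_file, "r", encoding="utf-8") as lf:
    while worker.is_alive():
        chunk = lf.read()                      # emit whatever the run has written so far
        if chunk:
            print(chunk, end="")
        time.sleep(1.0)
    print(lf.read(), end="")                   # flush the remainder after the thread exits

worker.join()
if worker.returncode != 0:
    print(f"autoforge run failed (returncode={worker.returncode}, exc={worker.exc})")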
auto_forge.py DELETED
@@ -1,1089 +0,0 @@
1
- """auto_forge.py
2
-
3
- High-level orchestration module for the AutoForge optimization pipeline.
4
-
5
- Responsibilities:
6
- - Parse CLI / config file arguments.
7
- - Load image and material properties.
8
- - (Optionally) auto-select a background filament color based on dominant image color.
9
- - Initialize a height map using one of several strategies (k-means clustering or depth estimation).
10
- - Build and run the filament optimization loop (differentiable + periodic discretization checks).
11
- - Optionally prune the solution to respect practical printer constraints (materials, swaps, layers).
12
- - Export final artifacts: preview PNG, STL(s), swap instructions, project file, metadata.
13
-
14
- The implementation intentionally keeps side-effects (disk writes / prints) order-stable to
15
- preserve prior behavior. Helper functions are factored out for readability; no functional
16
- behavior should have changed relative to the previous monolithic version.
17
- """
18
- import argparse
19
- import sys
20
- import os
21
- import traceback
22
- from typing import Optional, Tuple, List
23
-
24
- import configargparse
25
- import cv2
26
- try:
27
- import spaces
28
- except Exception:
29
- # Provide a minimal shim so @spaces.GPU can be used even when 'spaces' isn't installed.
30
- def _spaces_noop_decorator(fn=None):
31
- # Support usage as @spaces.GPU or @spaces.GPU()
32
- if fn is None:
33
- def _inner(f):
34
- return f
35
- return _inner
36
- return fn
37
-
38
- class _DummySpaces:
39
- GPU = staticmethod(_spaces_noop_decorator)
40
-
41
- spaces = _DummySpaces()
42
- import torch
43
- import numpy as np
44
- from tqdm import tqdm
45
-
46
- from autoforge.Helper import PruningHelper
47
- from autoforge.Helper.FilamentHelper import hex_to_rgb, load_materials
48
- from autoforge.Helper.Heightmaps.ChristofidesHeightMap import (
49
- run_init_threads,
50
- )
51
-
52
- from autoforge.Helper.ImageHelper import resize_image, imread
53
- from autoforge.Helper.OtherHelper import set_seed, perform_basic_check, get_device
54
- from autoforge.Helper.OutputHelper import (
55
- generate_stl,
56
- generate_swap_instructions,
57
- generate_project_file,
58
- generate_flatforge_stls,
59
- )
60
- from autoforge.Modules.Optimizer import FilamentOptimizer
61
-
62
- # check if we can use torch.set_float32_matmul_precision('high')
63
- if torch.__version__ >= "2.0.0":
64
- try:
65
- torch.set_float32_matmul_precision("high")
66
- except Exception as e:
67
- print("Warning: Could not set float32 matmul precision to high. Error:", e)
68
- pass
69
-
70
-
71
- def parse_args() -> argparse.Namespace:
72
- """Create and parse command-line & config-file arguments.
73
-
74
- Returns:
75
- argparse.Namespace: Populated arguments structure. Some parameters may be adjusted later
76
- (e.g., num_init_cluster_layers when -1 to infer from max_layers).
77
- """
78
- parser = configargparse.ArgParser()
79
- parser.add_argument("--config", is_config_file=True, help="Path to config file")
80
-
81
- parser.add_argument(
82
- "--input_image", type=str, required=True, help="Path to input image"
83
- )
84
- parser.add_argument(
85
- "--csv_file",
86
- type=str,
87
- default="",
88
- help="Path to CSV file with material data",
89
- )
90
- parser.add_argument(
91
- "--json_file",
92
- type=str,
93
- default="",
94
- help="Path to json file with material data",
95
- )
96
- parser.add_argument(
97
- "--output_folder", type=str, default="output", help="Folder to write outputs"
98
- )
99
-
100
- parser.add_argument(
101
- "--iterations", type=int, default=6000, help="Number of optimization iterations"
102
- )
103
-
104
- parser.add_argument(
105
- "--warmup_fraction",
106
- type=float,
107
- default=1.0,
108
- help="Fraction of iterations for keeping the tau at the initial value",
109
- )
110
-
111
- parser.add_argument(
112
- "--learning_rate_warmup_fraction",
113
- type=float,
114
- default=0.01,
115
- help="Fraction of iterations that the learning rate is increasing (warmup)",
116
- )
117
-
118
- parser.add_argument(
119
- "--init_tau",
120
- type=float,
121
- default=1.0,
122
- help="Initial tau value for Gumbel-Softmax",
123
- )
124
-
125
- parser.add_argument(
126
- "--final_tau",
127
- type=float,
128
- default=0.01,
129
- help="Final tau value for Gumbel-Softmax",
130
- )
131
-
132
- parser.add_argument(
133
- "--learning_rate",
134
- type=float,
135
- default=0.015,
136
- help="Learning rate for optimization",
137
- )
138
-
139
- parser.add_argument(
140
- "--layer_height", type=float, default=0.04, help="Layer thickness in mm"
141
- )
142
-
143
- parser.add_argument(
144
- "--max_layers", type=int, default=75, help="Maximum number of layers"
145
- )
146
-
147
- parser.add_argument(
148
- "--min_layers",
149
- type=int,
150
- default=0,
151
- help="Minimum number of layers. Used for pruning.",
152
- )
153
-
154
- parser.add_argument(
155
- "--background_height",
156
- type=float,
157
- default=0.24,
158
- help="Height of the background in mm",
159
- )
160
-
161
- parser.add_argument(
162
- "--background_color", type=str, default="#000000", help="Background color"
163
- )
164
-
165
- parser.add_argument(
166
- "--auto_background_color",
167
- default=True,
168
- help="Automatically set background color to the closest filament color matching the dominant image color. Overrides --background_color.",
169
- )
170
-
171
- parser.add_argument(
172
- "--visualize",
173
- type=bool,
174
- default=True,
175
- help="Enable visualization during optimization",
176
- action=argparse.BooleanOptionalAction,
177
- )
178
-
179
- # Instead of an output_size parameter, we use stl_output_size and nozzle_diameter.
180
- parser.add_argument(
181
- "--stl_output_size",
182
- type=int,
183
- default=150,
184
- help="Size of the longest dimension of the output STL file in mm",
185
- )
186
-
187
- parser.add_argument(
188
- "--processing_reduction_factor",
189
- type=int,
190
- default=2,
191
- help="Reduction factor for reducing the processing size compared to the output size (default: 2 - half resolution)",
192
- )
193
-
194
- parser.add_argument(
195
- "--nozzle_diameter",
196
- type=float,
197
- default=0.4,
198
- help="Diameter of the printer nozzle in mm (details smaller than half this value will be ignored)",
199
- )
200
-
201
- parser.add_argument(
202
- "--early_stopping",
203
- type=int,
204
- default=2000,
205
- help="Number of steps without improvement before stopping",
206
- )
207
-
208
- parser.add_argument(
209
- "--perform_pruning",
210
- type=bool,
211
- default=True,
212
- help="Perform pruning after optimization",
213
- action=argparse.BooleanOptionalAction,
214
- )
215
-
216
- parser.add_argument(
217
- "--fast_pruning",
218
- type=bool,
219
- default=True,
220
- help="Use fast pruning method",
221
- action=argparse.BooleanOptionalAction,
222
- )
223
- parser.add_argument(
224
- "--fast_pruning_percent",
225
- type=float,
226
- default=0.25,
227
- help="Percentage of increment search for fast pruning",
228
- )
229
-
230
- parser.add_argument(
231
- "--pruning_max_colors",
232
- type=int,
233
- default=100,
234
- help="Max number of colors allowed after pruning",
235
- )
236
- parser.add_argument(
237
- "--pruning_max_swaps",
238
- type=int,
239
- default=100,
240
- help="Max number of swaps allowed after pruning",
241
- )
242
-
243
- parser.add_argument(
244
- "--pruning_max_layer",
245
- type=int,
246
- default=75,
247
- help="Max number of layers allowed after pruning",
248
- )
249
-
250
- parser.add_argument(
251
- "--random_seed",
252
- type=int,
253
- default=0,
254
- help="Specify the random seed, or use 0 for automatic generation",
255
- )
256
-
257
- parser.add_argument(
258
- "--mps",
259
- action="store_true",
260
- help="Use the Metal Performance Shaders (MPS) backend, if available.",
261
- )
262
-
263
- parser.add_argument(
264
- "--run_name", type=str, help="Name of the run used for TensorBoard logging"
265
- )
266
-
267
- parser.add_argument(
268
- "--tensorboard", action="store_true", help="Enable TensorBoard logging"
269
- )
270
-
271
- parser.add_argument(
272
- "--num_init_rounds",
273
- type=int,
274
- default=16,
275
- help="Number of rounds to choose the starting height map from.",
276
- )
277
-
278
- parser.add_argument(
279
- "--num_init_cluster_layers",
280
- type=int,
281
- default=-1,
282
- help="Number of layers to cluster the image into.",
283
- )
284
-
285
- parser.add_argument(
286
- "--disable_visualization_for_gradio",
287
- type=int,
288
- default=0,
289
- help="Simple switch to disable the matplotlib render window for gradio rendering.",
290
- )
291
-
292
- parser.add_argument(
293
- "--best_of",
294
- type=int,
295
- default=1,
296
- help="Run the program multiple times and output the best result.",
297
- )
298
-
299
- parser.add_argument(
300
- "--discrete_check",
301
- type=int,
302
- default=100,
303
- help="Modulo how often to check for new discrete results.",
304
- )
305
-
306
- parser.add_argument(
307
- "--flatforge",
308
- type=bool,
309
- default=False,
310
- help="Enable FlatForge mode to generate separate STL files for each color",
311
- action=argparse.BooleanOptionalAction,
312
- )
313
-
314
- parser.add_argument(
315
- "--cap_layers",
316
- type=int,
317
- default=0,
318
- help="Number of complete clear/transparent layers to add on top in FlatForge mode",
319
- )
320
-
321
- # New: choose heightmap initializer
322
- parser.add_argument(
323
- "--init_heightmap_method",
324
- type=str,
325
- choices=["kmeans", "depth"],
326
- default="kmeans",
327
- help="Initializer for the height map: 'kmeans' (fast, default) or 'depth' (requires transformers).",
328
- )
329
- # New priority mask argument (optional)
330
- parser.add_argument(
331
- "--priority_mask",
332
- type=str,
333
- default="",
334
- help="Optional path to a priority mask image (same dimensions as input image). Non-empty: apply weighted loss (0.1 outside, 1.0 at max inside).",
335
- )
336
-
337
- args = parser.parse_args()
338
- return args
339
-
340
-
341
- def _compute_dominant_image_color(
342
- img_rgb: np.ndarray, alpha: Optional[np.ndarray]
343
- ) -> Optional[Tuple[str, np.ndarray]]:
344
- """Compute an approximate dominant color of the input image.
345
-
346
- Strategy:
347
- - Optionally downscale very large images for efficiency.
348
- - Ignore (mostly) transparent pixels if alpha channel is provided.
349
- - Use frequency counts (np.unique) over exact RGB triplets.
350
-
351
- Args:
352
- img_rgb: Image array in RGB order (H,W,3) uint8.
353
- alpha: Optional alpha mask (H,W,1) or (H,W) uint8; pixels <128 are ignored.
354
-
355
- Returns:
356
- (hex_color, normalized_rgb) where hex_color is a '#RRGGBB' string and normalized_rgb
357
- is float32 in [0,1]^3. Returns None if no valid pixels remain.
358
- """
359
- try:
360
- # Downscale if needed (max side 300 px)
361
- h, w = img_rgb.shape[:2]
362
- max_side = max(h, w)
363
- target_side = 300
364
- alpha_small: Optional[np.ndarray] = None
365
- if max_side > target_side:
366
- scale = target_side / max_side
367
- new_w = max(1, int(w * scale))
368
- new_h = max(1, int(h * scale))
369
- img_small = cv2.resize(
370
- img_rgb, (new_w, new_h), interpolation=cv2.INTER_AREA
371
- )
372
- if alpha is not None:
373
- alpha_small = cv2.resize(
374
- alpha, (new_w, new_h), interpolation=cv2.INTER_NEAREST
375
- )
376
- else:
377
- img_small = img_rgb
378
- alpha_small = alpha
379
- # Build mask for valid pixels (ignore transparent)
380
- if alpha_small is not None:
381
- valid_mask = (
382
- alpha_small[..., 0] if alpha_small.ndim == 3 else alpha_small
383
- ) >= 128
384
- else:
385
- valid_mask = np.ones(img_small.shape[:2], dtype=bool)
386
- if valid_mask.sum() == 0:
387
- return None
388
- pixels = img_small[valid_mask]
389
- # Use np.unique to find most frequent RGB triplet
390
- unique_colors, counts = np.unique(
391
- pixels.reshape(-1, 3), axis=0, return_counts=True
392
- )
393
- idx = int(np.argmax(counts))
394
- dom_rgb_uint8 = unique_colors[idx]
395
- dom_rgb_norm = dom_rgb_uint8.astype(np.float32) / 255.0
396
- hex_color = "#" + "".join(f"{c:02X}" for c in dom_rgb_uint8)
397
- return hex_color, dom_rgb_norm
398
- except Exception:
399
- traceback.print_exc()
400
- return None
401
-
402
-
403
- def _auto_select_background_color(
404
- args,
405
- img_rgb: np.ndarray,
406
- alpha: Optional[np.ndarray],
407
- material_colors_np: np.ndarray,
408
- material_names: List[str],
409
- colors_list: List[str],
410
- ) -> None:
411
- """Optionally override the user-provided background color with a closest material color.
412
-
413
- When --auto_background_color is set:
414
- - Determine dominant image color (ignoring transparency).
415
- - Find closest filament (Euclidean in normalized RGB).
416
- - Persist metadata to 'auto_background_color.txt'.
417
-
418
- Side effects: Mutates args.background_color and attaches background_material_* fields.
419
-
420
- Args:
421
- args: Global argument namespace (mutated).
422
- img_rgb: Full-resolution RGB image (uint8).
423
- alpha: Optional alpha channel for transparency filtering.
424
- material_colors_np: (N,3) array of filament RGB colors in [0,1].
425
- material_names: List of filament names.
426
- colors_list: List of filament hex color strings (#RRGGBB).
427
- """
428
- if not args.auto_background_color:
429
- return
430
- res = _compute_dominant_image_color(img_rgb, alpha)
431
- if res is not None:
432
- dominant_hex, dominant_rgb = res
433
- diffs = material_colors_np - dominant_rgb[None, :]
434
- dists = np.linalg.norm(diffs, axis=1)
435
- closest_idx = int(np.argmin(dists))
436
- chosen_hex = colors_list[closest_idx]
437
- print(
438
- f"Auto background color: dominant image color {dominant_hex} -> closest filament {chosen_hex} (index {closest_idx})."
439
- )
440
- args.background_color = chosen_hex
441
- args.background_material_index = closest_idx
442
- try:
443
- args.background_material_name = material_names[closest_idx]
444
- except Exception:
445
- args.background_material_name = None
446
- try:
447
- with open(
448
- os.path.join(args.output_folder, "auto_background_color.txt"), "w"
449
- ) as f:
450
- f.write(f"dominant_image_color={dominant_hex}\n")
451
- f.write(f"chosen_filament_color={chosen_hex}\n")
452
- f.write(f"closest_filament_index={closest_idx}\n")
453
- if getattr(args, "background_material_name", None):
454
- f.write(
455
- f"closest_filament_name={args.background_material_name}\n"
456
- )
457
- except Exception:
458
- traceback.print_exc()
459
- else:
460
- print(
461
- "Warning: Auto background color computation failed; using provided --background_color."
462
- )
463
-
464
-
465
- def _prepare_background_and_materials(
466
- args, device: torch.device, material_colors_np: np.ndarray, material_TDs_np: np.ndarray
467
- ) -> Tuple[Tuple[int, int, int], torch.Tensor, torch.Tensor, torch.Tensor]:
468
- """Create torch tensors for materials & background color.
469
-
470
- Args:
471
- args: Global arguments (uses background_color hex string).
472
- device: Torch device for tensor placement.
473
- material_colors_np: (N,3) float32 array in [0,1].
474
- material_TDs_np: (N,*) array of material transmission / diffusion parameters.
475
-
476
- Returns:
477
- (bgr_tuple_uint8, background_tensor, material_colors_tensor, material_TDs_tensor)
478
- """
479
- bgr_tuple = hex_to_rgb(args.background_color)
480
- background = torch.tensor(bgr_tuple, dtype=torch.float32, device=device)
481
- material_colors = torch.tensor(
482
- material_colors_np, dtype=torch.float32, device=device
483
- )
484
- material_TDs = torch.tensor(material_TDs_np, dtype=torch.float32, device=device)
485
- return bgr_tuple, background, material_colors, material_TDs
486
-
487
-
488
- def _compute_pixel_sizes(args) -> Tuple[int, int]:
489
- """Derive pixel dimensions for solving vs. output STL size.
490
-
491
- We oversample relative to nozzle_diameter to capture detail, then optionally downscale
492
- for the differentiable optimization pass.
493
-
494
- Returns:
495
- (computed_output_size, computed_processing_size)
496
- """
497
- computed_output_size = int(round(args.stl_output_size * 2 / args.nozzle_diameter))
498
- computed_processing_size = int(
499
- round(computed_output_size / args.processing_reduction_factor)
500
- )
501
- print(f"Computed solving pixel size: {computed_output_size}")
502
- return computed_output_size, computed_processing_size
503
-
504
-
505
- def _load_priority_mask(
506
- args, output_img_np: np.ndarray, device: torch.device
507
- ) -> Optional[torch.Tensor]:
508
- """Load and resize a priority / focus mask if provided.
509
-
510
- The mask scales heights during initialization and can later weight loss terms.
511
-
512
- Behavior:
513
- - Reads image; converts RGBA/RGB to grayscale.
514
- - Resizes to full-resolution output size.
515
- - Persists a diagnostic PNG after normalization.
516
-
517
- Returns:
518
- focus_map_full: Float32 tensor (H,W) in [0,1] or None if no mask provided.
519
- """
520
- focus_map_full = None
521
- if args.priority_mask != "":
522
- pm = imread(args.priority_mask, cv2.IMREAD_UNCHANGED)
523
- if pm.ndim == 3:
524
- if pm.shape[2] == 4:
525
- pm = pm[:, :, :3]
526
- pm = cv2.cvtColor(pm, cv2.COLOR_BGR2GRAY)
527
- tgt_h, tgt_w = output_img_np.shape[:2]
528
- pm_resized = cv2.resize(pm, (tgt_w, tgt_h), interpolation=cv2.INTER_LINEAR)
529
- pm_float = pm_resized.astype(np.float32) / 255.0
530
- focus_map_full = torch.tensor(pm_float, dtype=torch.float32, device=device)
531
- cv2.imwrite(
532
- os.path.join(args.output_folder, "priority_mask_resized.png"),
533
- (pm_float * 255).astype(np.uint8),
534
- )
535
- return focus_map_full
536
-
537
-
538
- def _initialize_heightmap(
539
- args,
540
- output_img_np: np.ndarray,
541
- bgr_tuple: Tuple[int, int, int],
542
- material_colors_np: np.ndarray,
543
- random_seed: int,
544
- ) -> Tuple[np.ndarray, Optional[np.ndarray], np.ndarray]:
545
- """Initialize the height map logits & labels using selected method.
546
-
547
- Methods:
548
- depth : Uses an external depth estimation model (requires transformers).
549
- kmeans : Clusters pixel colors into layer assignments (default).
550
-
551
- Returns:
552
- pixel_height_logits_init: (H,W) float32 numpy array of raw logits.
553
- global_logits_init : (L,*) global logits array or None (depth variant may not use it).
554
- pixel_height_labels : (H,W) int array of discrete initial layer indices.
555
- """
556
- print("Initalizing height map. This can take a moment...")
557
- if args.init_heightmap_method == "depth":
558
- try:
559
- from autoforge.Helper.Heightmaps.DepthEstimateHeightMap import (
560
- init_height_map_depth_color_adjusted,
561
- )
562
- except Exception:
563
- print(
564
- "Error: depth initializer requested but could not be imported. Install 'transformers' and try again.",
565
- file=sys.stderr,
566
- )
567
- raise
568
- pixel_height_logits_init, pixel_height_labels = (
569
- init_height_map_depth_color_adjusted(
570
- output_img_np,
571
- args.max_layers,
572
- random_seed=random_seed,
573
- focus_map=None,
574
- )
575
- )
576
- global_logits_init = None
577
- else:
578
- pixel_height_logits_init, global_logits_init, pixel_height_labels = (
579
- run_init_threads(
580
- output_img_np,
581
- args.max_layers,
582
- args.layer_height,
583
- bgr_tuple,
584
- random_seed=random_seed,
585
- num_threads=4,
586
- num_runs=args.num_init_rounds,
587
- init_method="kmeans",
588
- cluster_layers=args.num_init_cluster_layers,
589
- material_colors=material_colors_np,
590
- focus_map=None,
591
- )
592
- )
593
-
594
- return pixel_height_logits_init, global_logits_init, pixel_height_labels
595
-
596
-
597
- def _prepare_processing_targets(
598
- output_img_np: np.ndarray,
599
- computed_processing_size: int,
600
- device: torch.device,
601
- focus_map_full: Optional[torch.Tensor],
602
- ) -> Tuple[np.ndarray, torch.Tensor, Optional[torch.Tensor]]:
603
- """Create downscaled optimization target & focus map for faster iterations.
604
-
605
- Args:
606
- output_img_np: Full-resolution RGB image (float or uint8 expected).
607
- computed_processing_size: Target square size for processing (maintains aspect via resize helper).
608
- device: Torch device.
609
- focus_map_full: Optional full-resolution focus map tensor.
610
-
611
- Returns:
612
- processing_img_np : Downscaled numpy image (H_p,W_p,3).
613
- processing_target : Torch tensor version (float32) on device.
614
- focus_map_proc : Optional downscaled focus map tensor (H_p,W_p).
615
- """
616
- processing_img_np = resize_image(output_img_np, computed_processing_size)
617
- processing_target = torch.tensor(
618
- processing_img_np, dtype=torch.float32, device=device
619
- )
620
-
621
- focus_map_proc = None
622
- if focus_map_full is not None:
623
- fm_proc_np = cv2.resize(
624
- focus_map_full.cpu().numpy().astype(np.float32),
625
- (processing_target.shape[1], processing_target.shape[0]),
626
- interpolation=cv2.INTER_LINEAR,
627
- )
628
- focus_map_proc = torch.tensor(fm_proc_np, dtype=torch.float32, device=device)
629
-
630
- return processing_img_np, processing_target, focus_map_proc
631
-
632
-
633
- def _build_optimizer(
634
- args,
635
- processing_target: torch.Tensor,
636
- processing_pixel_height_logits_init: np.ndarray,
637
- processing_pixel_height_labels: np.ndarray,
638
- global_logits_init,
639
- material_colors: torch.Tensor,
640
- material_TDs: torch.Tensor,
641
- background: torch.Tensor,
642
- device: torch.device,
643
- perception_loss_module,
644
- focus_map_proc: Optional[torch.Tensor],
645
- ) -> FilamentOptimizer:
646
- """Instantiate the FilamentOptimizer with initial tensors and configuration.
647
-
648
- Args mirror the optimizer's constructor; this function simply centralizes assembly.
649
-
650
- Returns:
651
- FilamentOptimizer: Ready-to-run optimizer instance.
652
- """
653
- optimizer = FilamentOptimizer(
654
- args=args,
655
- target=processing_target,
656
- pixel_height_logits_init=processing_pixel_height_logits_init,
657
- pixel_height_labels=processing_pixel_height_labels,
658
- global_logits_init=global_logits_init,
659
- material_colors=material_colors,
660
- material_TDs=material_TDs,
661
- background=background,
662
- device=device,
663
- perception_loss_module=perception_loss_module,
664
- focus_map=focus_map_proc,
665
- )
666
- return optimizer
667
-
668
- @spaces.GPU
669
- def _run_optimization_loop(optimizer: FilamentOptimizer, args, device: torch.device) -> None:
670
- """Execute the main gradient-based optimization iterations.
671
-
672
- Features:
673
- - Automatic mixed precision (bfloat16 unless MPS).
674
- - Periodic visualization & tensorboard logging (every 100 iterations).
675
- - Discrete solution snapshots controlled via --discrete_check.
676
- - Early stopping after a patience window (--early_stopping).
677
-
678
- Args:
679
- optimizer: Configured FilamentOptimizer instance.
680
- args: Global argument namespace.
681
- device: Torch device for autocast context.
682
- """
683
- print("Starting optimization...")
684
- tbar = tqdm(range(args.iterations))
685
- dtype = torch.bfloat16 if not args.mps else torch.float32
686
- with torch.autocast(device.type, dtype=dtype):
687
- for i in tbar:
688
- loss_val = optimizer.step(record_best=i % args.discrete_check == 0)
689
-
690
- optimizer.visualize(interval=100)
691
- optimizer.log_to_tensorboard(interval=100)
692
-
693
- if (i + 1) % 100 == 0:
694
- tbar.set_description(
695
- f"Iteration {i + 1}, Loss = {loss_val:.4f}, best validation Loss = {optimizer.best_discrete_loss:.4f}, learning_rate= {optimizer.current_learning_rate:.6f}"
696
- )
697
- if (
698
- optimizer.best_step is not None
699
- and optimizer.num_steps_done - optimizer.best_step > args.early_stopping
700
- ):
701
- print(
702
- "Early stopping after",
703
- args.early_stopping,
704
- "steps without improvement.",
705
- )
706
- break
707
-
708
-
709
-
710
- def _post_optimize_and_export(
711
- args,
712
- optimizer: FilamentOptimizer,
713
- pixel_height_logits_init: np.ndarray,
714
- pixel_height_labels: np.ndarray,
715
- output_target: torch.Tensor,
716
- alpha: Optional[np.ndarray],
717
- material_colors_np: np.ndarray,
718
- material_TDs_np: np.ndarray,
719
- material_names: List[str],
720
- bgr_tuple: Tuple[int, int, int],
721
- device: torch.device,
722
- focus_map_full: Optional[torch.Tensor],
723
- focus_map_proc: Optional[torch.Tensor],
724
- ) -> float:
725
- """Finalize solution, optionally prune, and write all output artifacts.
726
-
727
- Steps:
728
- - Restore full-resolution logits to optimizer and (optionally) height residual.
729
- - Replace focus map with full-res version if used.
730
- - Perform pruning (respecting color slots for background & clear in FlatForge mode).
731
- - Compute final loss estimate and persist to file.
732
- - Export preview PNG, STL(s), swap instructions & project file.
733
-
734
- Returns:
735
- float: The final reported loss (post-pruning).
736
- """
737
- post_opt_step = 0
738
-
739
- optimizer.log_to_tensorboard(
740
- interval=1, namespace="post_opt", step=(post_opt_step := post_opt_step + 1)
741
- )
742
-
743
- optimizer.pixel_height_logits = torch.from_numpy(pixel_height_logits_init)
744
- optimizer.best_params["pixel_height_logits"] = torch.from_numpy(
745
- pixel_height_logits_init
746
- ).to(device)
747
- optimizer.target = output_target
748
- optimizer.pixel_height_labels = torch.tensor(
749
- pixel_height_labels, dtype=torch.int32, device=device
750
- )
751
- if focus_map_proc is not None and focus_map_full is not None:
752
- optimizer.focus_map = focus_map_full
753
-
754
- dtype = torch.bfloat16 if not args.mps else torch.float32
755
- with torch.no_grad():
756
- with torch.autocast(device.type, dtype=dtype):
757
- if args.perform_pruning:
758
- # Adjust pruning_max_colors to account for background and clear filament
759
- # pruning_max_colors = total filaments needed
760
- # Need to reserve slots: 1 for background (always), 1 for clear (FlatForge only)
761
- max_colors_for_pruning = args.pruning_max_colors
762
-
763
- if args.flatforge:
764
- # FlatForge: pruning_max_colors = colored + clear + background
765
- # Reserve 2 slots (1 clear + 1 background)
766
- max_colors_for_pruning = max(1, args.pruning_max_colors - 2)
767
- else:
768
- # Traditional: pruning_max_colors = colored + background
769
- # Reserve 1 slot for background
770
- max_colors_for_pruning = max(1, args.pruning_max_colors - 1)
771
-
772
- post_opt_step = run_pruning(args, max_colors_for_pruning, optimizer, post_opt_step)
773
-
774
- disc_global, disc_height_image = optimizer.get_discretized_solution(
775
- best=True
776
- )
777
-
778
- final_loss = PruningHelper.get_initial_loss(
779
- optimizer.best_params["global_logits"].shape[0], optimizer
780
- )
781
- with open(os.path.join(args.output_folder, "final_loss.txt"), "w") as f:
782
- f.write(f"{final_loss}")
783
-
784
- print("Done. Saving outputs...")
785
- comp_disc = optimizer.get_best_discretized_image()
786
- args.max_layers = optimizer.max_layers
787
-
788
- optimizer.log_to_tensorboard(
789
- interval=1,
790
- namespace="post_opt",
791
- step=(post_opt_step := post_opt_step + 1),
792
- )
793
-
794
- comp_disc_np = comp_disc.cpu().numpy().astype(np.uint8)
795
- comp_disc_np = cv2.cvtColor(comp_disc_np, cv2.COLOR_RGB2BGR)
796
- cv2.imwrite(
797
- os.path.join(args.output_folder, "final_model.png"), comp_disc_np
798
- )
799
-
800
- # Generate STL files
801
- if args.flatforge:
802
- # FlatForge mode: Generate separate STL files for each color
803
- print("FlatForge mode enabled. Generating separate STL files...")
804
- generate_flatforge_stls(
805
- disc_global.cpu().numpy(),
806
- disc_height_image.cpu().numpy(),
807
- material_colors_np,
808
- material_names,
809
- material_TDs_np,
810
- args.layer_height,
811
- args.background_height,
812
- args.background_color,
813
- args.stl_output_size,
814
- args.output_folder,
815
- cap_layers=args.cap_layers,
816
- alpha_mask=alpha,
817
- )
818
- else:
819
- # Traditional mode: Generate single STL file
820
- stl_filename = os.path.join(args.output_folder, "final_model.stl")
821
- height_map_mm = (
822
- disc_height_image.cpu().numpy().astype(np.float32)
823
- ) * args.layer_height
824
- generate_stl(
825
- height_map_mm,
826
- stl_filename,
827
- args.background_height,
828
- maximum_x_y_size=args.stl_output_size,
829
- alpha_mask=alpha,
830
- )
831
-
832
- if not args.flatforge:
833
- background_layers = int(args.background_height // args.layer_height)
834
- swap_instructions = generate_swap_instructions(
835
- disc_global.cpu().numpy(),
836
- disc_height_image.cpu().numpy(),
837
- args.layer_height,
838
- background_layers,
839
- args.background_height,
840
- material_names,
841
- getattr(args, "background_material_name", None),
842
- )
843
- with open(
844
- os.path.join(args.output_folder, "swap_instructions.txt"), "w"
845
- ) as f:
846
- for line in swap_instructions:
847
- f.write(line + "\n")
848
-
849
- project_filename = os.path.join(args.output_folder, "project_file.hfp")
850
- generate_project_file(
851
- project_filename,
852
- args,
853
- disc_global.cpu().numpy(),
854
- disc_height_image.cpu().numpy(),
855
- output_target.shape[1],
856
- output_target.shape[0],
857
- os.path.join(args.output_folder, "final_model.stl"),
858
- args.csv_file,
859
- )
860
-
861
- print("All done. Outputs in:", args.output_folder)
862
- print("Happy Printing!")
863
- return final_loss
864
-
865
- @spaces.GPU
866
- def run_pruning(args, max_colors_for_pruning: int, optimizer: FilamentOptimizer, post_opt_step: int) -> int:
867
- optimizer.prune(
868
- max_colors_allowed=max_colors_for_pruning,
869
- max_swaps_allowed=args.pruning_max_swaps,
870
- min_layers_allowed=args.min_layers,
871
- max_layers_allowed=args.pruning_max_layer,
872
- search_seed=True,
873
- fast_pruning=args.fast_pruning,
874
- fast_pruning_percent=args.fast_pruning_percent,
875
- )
876
- optimizer.log_to_tensorboard(
877
- interval=1,
878
- namespace="post_opt",
879
- step=(post_opt_step := post_opt_step + 1),
880
- )
881
- return post_opt_step
882
-
883
-
884
- def start(args) -> float:
885
- """Entry point for a single optimization run.
886
-
887
- Orchestrates the entire pipeline:
888
- - Validation & device selection.
889
- - Material & image loading (+ optional auto background selection).
890
- - Resolution computation & resizing.
891
- - Heightmap initialization.
892
- - Optimizer construction & iterative optimization loop.
893
- - Post-processing, pruning, and output generation.
894
-
895
- Args:
896
- args: Parsed argument namespace.
897
-
898
- Returns:
899
- float: Final loss value for this run (after pruning/export).
900
- """
901
- if args.num_init_cluster_layers == -1:
902
- args.num_init_cluster_layers = args.max_layers
903
-
904
- # check if csv or json is given
905
- if args.csv_file == "" and args.json_file == "":
906
- print("Error: No CSV or JSON file given. Please provide one of them.")
907
- sys.exit(1)
908
-
909
- device = torch.device("cpu")
910
-
911
- os.makedirs(args.output_folder, exist_ok=True)
912
-
913
- perform_basic_check(args)
914
-
915
- random_seed = set_seed(args)
916
-
917
- # Load materials (we keep colors_list for potential auto background)
918
- material_colors_np, material_TDs_np, material_names, colors_list = load_materials(
919
- args
920
- )
921
-
922
- # Read input image early (needed for auto background color)
923
- img = imread(args.input_image, cv2.IMREAD_UNCHANGED)
924
- alpha = None
925
- if img.shape[2] == 4:
926
- alpha = img[:, :, 3]
927
- alpha = alpha[..., None]
928
- img = img[:, :, :3]
929
-
930
- # Convert image from BGR to RGB for color analysis
931
- img_rgb = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
932
-
933
- # Auto background color selection (optional)
934
- _auto_select_background_color(
935
- args, img_rgb, alpha, material_colors_np, material_names, colors_list
936
- )
937
-
938
- # Prepare background color tensor and material tensors
939
- bgr_tuple, background, material_colors, material_TDs = _prepare_background_and_materials(
940
- args, device, material_colors_np, material_TDs_np
941
- )
942
-
943
- # Compute sizes
944
- computed_output_size, computed_processing_size = _compute_pixel_sizes(args)
945
-
946
- # Resize alpha if present (match final resolution) after computing size
947
- if alpha is not None:
948
- alpha = resize_image(alpha, computed_output_size)
949
-
950
- # For the final resolution
951
- output_img_np = resize_image(img_rgb, computed_output_size)
952
- output_target = torch.tensor(output_img_np, dtype=torch.float32, device=device)
953
-
954
- # Priority mask handling (full-res)
955
- focus_map_full = _load_priority_mask(args, output_img_np, device)
956
-
957
- # Initialize heightmap
958
- pixel_height_logits_init, global_logits_init, pixel_height_labels = _initialize_heightmap(
959
- args,
960
- output_img_np,
961
- bgr_tuple,
962
- material_colors_np,
963
- random_seed,
964
- )
965
-
966
- # Prepare processing targets and focus map (processing-res)
967
- processing_img_np, processing_target, focus_map_proc = _prepare_processing_targets(
968
- output_img_np, computed_processing_size, device, focus_map_full
969
- )
970
-
971
- # Downscale initial logits/labels to processing resolution
972
- processing_pixel_height_logits_init = cv2.resize(
973
- src=pixel_height_logits_init,
974
- interpolation=cv2.INTER_NEAREST,
975
- dsize=(processing_target.shape[1], processing_target.shape[0]),
976
- )
977
- processing_pixel_height_labels = cv2.resize(
978
- src=pixel_height_labels,
979
- interpolation=cv2.INTER_NEAREST,
980
- dsize=(processing_target.shape[1], processing_target.shape[0]),
981
- )
982
-
983
- # Apply alpha mask to full-res logits (keep original order/behavior)
984
- if alpha is not None:
985
- pixel_height_logits_init[alpha < 128] = -13.815512
986
-
987
- perception_loss_module = None
988
-
989
- # Build optimizer
990
- optimizer = _build_optimizer(
991
- args,
992
- processing_target,
993
- processing_pixel_height_logits_init,
994
- processing_pixel_height_labels,
995
- global_logits_init,
996
- material_colors,
997
- material_TDs,
998
- background,
999
- device,
1000
- perception_loss_module,
1001
- focus_map_proc,
1002
- )
1003
-
1004
- # Run optimization loop
1005
- _run_optimization_loop(optimizer, args, torch.device("cuda"))
1006
-
1007
- # Post-process, prune, and export outputs
1008
- final_loss = _post_optimize_and_export(
1009
- args,
1010
- optimizer,
1011
- pixel_height_logits_init,
1012
- pixel_height_labels,
1013
- output_target,
1014
- alpha,
1015
- material_colors_np,
1016
- material_TDs_np,
1017
- material_names,
1018
- bgr_tuple,
1019
- torch.device("cuda"),
1020
- focus_map_full,
1021
- focus_map_proc,
1022
- )
1023
-
1024
- return final_loss
1025
-
1026
-
1027
- def main() -> None:
1028
- """Support multi-run execution via --best_of; persist best run artifacts.
1029
-
1030
- If --best_of == 1, simply invokes a single start(). Otherwise:
1031
- - Creates temporary run subfolders.
1032
- - Tracks losses, reports statistics (best / median / std).
1033
- - Moves files from best run folder into the final output folder.
1034
-
1035
- Note: Memory is periodically reclaimed (gc + CUDA cache clears + closing matplotlib figures).
1036
- """
1037
- args = parse_args()
1038
- final_output_folder = args.output_folder
1039
- run_best_loss = 1000000000
1040
- if args.best_of == 1:
1041
- start(args)
1042
- else:
1043
- temp_output_folder = os.path.join(args.output_folder, "temp")
1044
- ret = []
1045
- for i in range(args.best_of):
1046
- try:
1047
- print(f"Run {i + 1}/{args.best_of}")
1048
- run_folder = os.path.join(temp_output_folder, f"run_{i + 1}")
1049
- args.output_folder = run_folder
1050
- os.makedirs(args.output_folder, exist_ok=True)
1051
- run_loss = start(args)
1052
- print(f"Run {i + 1} finished with loss: {run_loss}")
1053
- if run_loss < run_best_loss:
1054
- run_best_loss = run_loss
1055
- print(f"New best loss found: {run_best_loss} in run {i + 1}")
1056
- ret.append((run_folder, run_loss))
1057
- torch.cuda.empty_cache()
1058
- import gc
1059
-
1060
- gc.collect()
1061
- torch.cuda.empty_cache()
1062
- import matplotlib.pyplot as plt
1063
-
1064
- plt.close("all")
1065
- except Exception:
1066
- traceback.print_exc()
1067
- best_run = min(ret, key=lambda x: x[1])
1068
- best_run_folder = best_run[0]
1069
- best_loss = best_run[1]
1070
-
1071
- losses = [x[1] for x in ret]
1072
- median_loss = np.median(losses)
1073
- std_loss = np.std(losses)
1074
- print(f"Best run folder: {best_run_folder}")
1075
- print(f"Best run loss: {best_loss}")
1076
- print(f"Median loss: {median_loss}")
1077
- print(f"Standard deviation of losses: {std_loss}")
1078
-
1079
- if not os.path.exists(final_output_folder):
1080
- os.makedirs(final_output_folder)
1081
- for file in os.listdir(best_run_folder):
1082
- src_file = os.path.join(best_run_folder, file)
1083
- dst_file = os.path.join(final_output_folder, file)
1084
- if os.path.isfile(src_file):
1085
- os.rename(src_file, dst_file)
1086
-
1087
-
1088
- if __name__ == "__main__":
1089
- main()