Update README.md
Browse files
README.md
CHANGED
|
@@ -1,3 +1,65 @@
|
|
| 1 |
-
---
|
| 2 |
-
license: cc-by-3.0
|
| 3 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
---
|
| 2 |
+
license: cc-by-3.0
|
| 3 |
+
actor: semantic_segmentation
|
| 4 |
+
patch_size: 256
|
| 5 |
+
clip_size: 32
|
| 6 |
+
max_batch_size: 128
|
| 7 |
+
device: cuda
|
| 8 |
+
features: [
|
| 9 |
+
"s2med_harvest:B04",
|
| 10 |
+
"s2med_harvest:B03",
|
| 11 |
+
"s2med_harvest:B02",
|
| 12 |
+
"s2med_harvest:B08",
|
| 13 |
+
"s2med_planting:B04",
|
| 14 |
+
"s2med_planting:B03",
|
| 15 |
+
"s2med_planting:B02",
|
| 16 |
+
"s2med_planting:B08"
|
| 17 |
+
]
|
| 18 |
+
labels: [
|
| 19 |
+
non_field_background,
|
| 20 |
+
field,
|
| 21 |
+
field_boundaries
|
| 22 |
+
]
|
| 23 |
+
merge_mode: weighted_average
|
| 24 |
+
---
|
| 25 |
+
|
| 26 |
+
|
| 27 |
+
Exported using the following code:
|
| 28 |
+
|
| 29 |
+
```python
|
| 30 |
+
from pathlib import Path
|
| 31 |
+
import torch
|
| 32 |
+
import torchvision.transforms.v2 as T
|
| 33 |
+
from stac_model.torch.export import save
|
| 34 |
+
import segmentation_models_pytorch as smp
|
| 35 |
+
|
| 36 |
+
|
| 37 |
+
path = "FTW-Release-Full-3-class-unet-efficientnetb5-weight0.75-3xlonger.ckpt"
|
| 38 |
+
ckpt = torch.load(path, map_location="cpu", weights_only=False)
|
| 39 |
+
hparams = ckpt["hyper_parameters"]
|
| 40 |
+
state_dict = {k.replace("model.", ""): v for k, v in ckpt["state_dict"].items()}
|
| 41 |
+
del state_dict["criterion.weight"]
|
| 42 |
+
model = smp.Unet(
|
| 43 |
+
encoder_name=hparams["backbone"],
|
| 44 |
+
encoder_weights=None,
|
| 45 |
+
in_channels=hparams["in_channels"],
|
| 46 |
+
classes=hparams["num_classes"],
|
| 47 |
+
)
|
| 48 |
+
model.load_state_dict(state_dict, strict=True)
|
| 49 |
+
|
| 50 |
+
transforms = torch.nn.Sequential(
|
| 51 |
+
torch.nn.Upsample(scale_factor=2, mode="bilinear", align_corners=False),
|
| 52 |
+
T.Normalize(mean=[0.0], std=[3000.0])
|
| 53 |
+
)
|
| 54 |
+
|
| 55 |
+
save(
|
| 56 |
+
output_file=Path("model.pt2"),
|
| 57 |
+
input_shape=[-1, hparams["in_channels"], -1, -1],
|
| 58 |
+
model=model,
|
| 59 |
+
transforms=transforms,
|
| 60 |
+
metadata=None,
|
| 61 |
+
device="cpu",
|
| 62 |
+
dtype=torch.float32,
|
| 63 |
+
aoti_compile_and_package=False
|
| 64 |
+
)
|
| 65 |
+
```
|