isaaccorley commited on
Commit
342c3d8
·
verified ·
1 Parent(s): 44bed20

Update README.md

Browse files
Files changed (1) hide show
  1. README.md +65 -3
README.md CHANGED
@@ -1,3 +1,65 @@
1
- ---
2
- license: cc-by-3.0
3
- ---
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ---
2
+ license: cc-by-3.0
3
+ actor: semantic_segmentation
4
+ patch_size: 256
5
+ clip_size: 32
6
+ max_batch_size: 128
7
+ device: cuda
8
+ features: [
9
+ "s2med_harvest:B04",
10
+ "s2med_harvest:B03",
11
+ "s2med_harvest:B02",
12
+ "s2med_harvest:B08",
13
+ "s2med_planting:B04",
14
+ "s2med_planting:B03",
15
+ "s2med_planting:B02",
16
+ "s2med_planting:B08"
17
+ ]
18
+ labels: [
19
+ non_field_background,
20
+ field,
21
+ field_boundaries
22
+ ]
23
+ merge_mode: weighted_average
24
+ ---
25
+
26
+
27
+ Exported using the following code:
28
+
29
+ ```python
30
+ from pathlib import Path
31
+ import torch
32
+ import torchvision.transforms.v2 as T
33
+ from stac_model.torch.export import save
34
+ import segmentation_models_pytorch as smp
35
+
36
+
37
+ path = "FTW-Release-Full-3-class-unet-efficientnetb5-weight0.75-3xlonger.ckpt"
38
+ ckpt = torch.load(path, map_location="cpu", weights_only=False)
39
+ hparams = ckpt["hyper_parameters"]
40
+ state_dict = {k.replace("model.", ""): v for k, v in ckpt["state_dict"].items()}
41
+ del state_dict["criterion.weight"]
42
+ model = smp.Unet(
43
+ encoder_name=hparams["backbone"],
44
+ encoder_weights=None,
45
+ in_channels=hparams["in_channels"],
46
+ classes=hparams["num_classes"],
47
+ )
48
+ model.load_state_dict(state_dict, strict=True)
49
+
50
+ transforms = torch.nn.Sequential(
51
+ torch.nn.Upsample(scale_factor=2, mode="bilinear", align_corners=False),
52
+ T.Normalize(mean=[0.0], std=[3000.0])
53
+ )
54
+
55
+ save(
56
+ output_file=Path("model.pt2"),
57
+ input_shape=[-1, hparams["in_channels"], -1, -1],
58
+ model=model,
59
+ transforms=transforms,
60
+ metadata=None,
61
+ device="cpu",
62
+ dtype=torch.float32,
63
+ aoti_compile_and_package=False
64
+ )
65
+ ```