Book a Demo!
CoCalc Logo Icon
StoreFeaturesDocsShareSupportNewsAboutPoliciesSign UpSign In
AUTOMATIC1111
GitHub Repository: AUTOMATIC1111/stable-diffusion-webui
Path: blob/master/extensions-builtin/Lora/network_norm.py
2447 views
1
import network
2
3
4
class ModuleTypeNorm(network.ModuleType):
    """Module type that recognizes weight sets carrying ``w_norm``/``b_norm`` entries.

    NOTE(review): the key names suggest normalization-layer weight and bias
    tensors -- confirm against the network loader.
    """

    def create_module(self, net: network.Network, weights: network.NetworkWeights):
        """Build a ``NetworkModuleNorm`` for ``weights`` if they match this type.

        Args:
            net: The network the new module will belong to.
            weights: Weight container; its ``w`` mapping is inspected for keys.

        Returns:
            A ``NetworkModuleNorm`` when both ``"w_norm"`` and ``"b_norm"`` are
            present in ``weights.w``, otherwise ``None`` so other module types
            can try to claim these weights.
        """
        required_keys = ("w_norm", "b_norm")

        # Guard clause: bail out as soon as any required key is missing.
        if any(key not in weights.w for key in required_keys):
            return None

        return NetworkModuleNorm(net, weights)
10
11
12
class NetworkModuleNorm(network.NetworkModule):
    """Network module whose weight update is the stored ``w_norm`` tensor.

    An optional ``b_norm`` tensor is forwarded to ``finalize_updown`` as the
    extra bias.
    """

    def __init__(self, net: network.Network, weights: network.NetworkWeights):
        super().__init__(net, weights)

        # Looked up with .get(), so each attribute is None when its key is absent.
        source = weights.w
        self.w_norm = source.get("w_norm")
        self.b_norm = source.get("b_norm")

    def calc_updown(self, orig_weight):
        """Compute this module's update relative to ``orig_weight``.

        Args:
            orig_weight: The original layer weight. Only its ``device`` is read
                here, and it is passed through to ``finalize_updown``.

        Returns:
            Whatever ``finalize_updown`` returns. It is given ``w_norm`` moved to
            ``orig_weight``'s device, ``w_norm``'s original shape as the output
            shape, and ``b_norm`` moved to the same device (or ``None`` when no
            bias was loaded).
        """
        norm_weight = self.w_norm
        target_shape = norm_weight.shape
        moved_weight = norm_weight.to(orig_weight.device)

        norm_bias = self.b_norm
        moved_bias = None if norm_bias is None else norm_bias.to(orig_weight.device)

        return self.finalize_updown(moved_weight, orig_weight, target_shape, moved_bias)
29
30