Path: blob/master/extensions-builtin/Lora/network_ia3.py
2447 views
import network


# Factory that recognizes IA3 weight sets and wraps them in NetworkModuleIa3.
class ModuleTypeIa3(network.ModuleType):
    def create_module(self, net: network.Network, weights: network.NetworkWeights):
        """Return a NetworkModuleIa3 when ``weights.w`` has a "weight" entry, else None.

        Only the "weight" key is checked here. NetworkModuleIa3.__init__ also
        reads "on_input", so a weight set that lacks that key raises KeyError
        during construction instead of being rejected with None.
        """
        if "weight" not in weights.w:
            return None

        return NetworkModuleIa3(net, weights)


class NetworkModuleIa3(network.NetworkModule):
    """Network module whose update is the original weight scaled elementwise by ``w``.

    Attributes:
        w: the "weight" tensor taken from the loaded weight set.
        on_input: Python scalar obtained from the "on_input" entry via
            ``.item()``; its truthiness selects how ``w`` is broadcast
            against the original weight in ``calc_updown``.
    """

    def __init__(self, net: network.Network, weights: network.NetworkWeights):
        super().__init__(net, weights)

        self.w = weights.w["weight"]
        self.on_input = weights.w["on_input"].item()

    def calc_updown(self, orig_weight):
        """Scale ``orig_weight`` elementwise by ``self.w`` and finalize the result.

        ``self.w`` is moved to ``orig_weight``'s device (its dtype is left
        unchanged) and multiplied into ``orig_weight`` using broadcasting.
        The product, ``orig_weight``, and an output-shape list are passed to
        ``self.finalize_updown``, and its result is returned.
        """
        scale = self.w.to(orig_weight.device)
        n_scale = scale.size(0)
        n_cols = orig_weight.size(1)

        if self.on_input:
            # scale is used as-is, so broadcasting aligns it with the LAST
            # dimension of orig_weight.
            # NOTE(review): for that broadcast to work, scale's length must be
            # compatible with orig_weight's last dim. For a 2-D orig_weight,
            # both entries below then equal the input size. Confirm this is the
            # shape finalize_updown expects.
            output_shape = [n_cols, n_scale]
        else:
            output_shape = [n_scale, n_cols]
            # Column vector (-1, 1). Presumably this gives one factor per row
            # (dim 0) of a 2-D orig_weight. TODO confirm orig_weight rank.
            scale = scale.reshape(-1, 1)

        return self.finalize_updown(orig_weight * scale, orig_weight, output_shape)