import dataclasses
import typing as t

import yaml

from zntrack.config import PARAMS_FILE_PATH, FieldTypes
from zntrack.fields.base import field
from zntrack.node import Node
from zntrack.utils.filesystem import resolve_state_file_path

_T = t.TypeVar("_T")


def _params_getter(self: "Node", name: str):
    """Read one parameter value for this node from the parameter file.

    The file location is obtained from ``resolve_state_file_path`` using the
    node state's filesystem and path together with ``PARAMS_FILE_PATH``. The
    file is parsed as YAML and the value stored under ``self.name`` and then
    ``name`` is returned.

    Parameters
    ----------
    self : Node
        The node whose parameter is looked up.
    name : str
        Key of the parameter inside the node's section of the file.

    Returns
    -------
    Any
        Whatever the YAML document holds at ``[self.name][name]``.
    """
    filesystem = self.state.fs
    location = resolve_state_file_path(filesystem, self.state.path, PARAMS_FILE_PATH)
    with filesystem.open(location) as stream:
        document = yaml.safe_load(stream)
        # First level is keyed by node name, second level by parameter name.
        node_section = document[self.name]
        return node_section[name]


# Typing-only signatures: they let a type checker infer the attribute type
# from ``default`` / ``default_factory``.


@t.overload
def params() -> t.Any: ...


@t.overload
def params(default: _T, **kwargs) -> _T: ...


@t.overload
def params(*, default_factory: t.Callable[[], _T], **kwargs) -> _T: ...
def params(
    default=dataclasses.MISSING,
    *,
    default_factory=dataclasses.MISSING,
    **kwargs,
) -> t.Any:
    """ZnTrack parameter field.

    A field to define a parameter for a ZnTrack node.

    Parameters
    ----------
    default : dict|int|str|float|list|None, optional
        Set a default parameter value.
    default_factory : callable, optional
        A callable that returns the default value.
        Should be used instead of `default` if the default value is mutable.
    **kwargs
        Extra keyword arguments forwarded to ``field``. ``cache`` (default
        ``True``) and ``independent`` (default ``False``) are extracted and
        passed explicitly; everything else is passed through unchanged.

    Returns
    -------
    Any
        The result of ``field(...)`` configured with
        ``FieldTypes.PARAMS`` and ``_params_getter`` as ``load_fn``.

    Examples
    --------
    >>> import zntrack
    >>> class MyNode(zntrack.Node):
    ...     param: int = zntrack.params(default=42)
    ...
    ...     def run(self) -> None: ...
    ...
    >>> a = MyNode()
    >>> a.param
    42
    >>> b = MyNode(param=43)
    >>> b.param
    43
    """
    # TODO: check types, do not allow e.g. connections
    #  or anything that can not be serialized

    # Pull these out first so they are not passed twice via ``**kwargs``.
    cache = kwargs.pop("cache", True)
    independent = kwargs.pop("independent", False)

    return field(
        default=default,
        default_factory=default_factory,
        field_type=FieldTypes.PARAMS,
        load_fn=_params_getter,
        suffix=None,
        cache=cache,
        independent=independent,
        **kwargs,
    )