Coretex
task_config.py
1 # Copyright (C) 2023 Coretex LLC
2 
3 # This file is part of Coretex.ai
4 
5 # This program is free software: you can redistribute it and/or modify
6 # it under the terms of the GNU Affero General Public License as
7 # published by the Free Software Foundation, either version 3 of the
8 # License, or (at your option) any later version.
9 
10 # This program is distributed in the hope that it will be useful,
11 # but WITHOUT ANY WARRANTY; without even the implied warranty of
12 # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
13 # GNU Affero General Public License for more details.
14 
15 # You should have received a copy of the GNU Affero General Public License
16 # along with this program. If not, see <https://www.gnu.org/licenses/>.
17 
18 from typing import Optional, List, Dict, Any
19 from typing_extensions import Self
20 from pathlib import Path
21 
22 import yaml
23 
24 from ...entities import BaseParameter, parameter_factory
25 from ...codable import Codable, KeyDescriptor
26 
27 
# Location of the task configuration file, relative to the current working directory
TASK_CONFIG_PATH = Path(".", "task.yaml")
29 
30 
class ParamGroup(Codable):

    """
        A named group of task parameters, as read from the task configuration file.

        Attributes
        ----------
        name : str
            name of the group
        params : Optional[List[BaseParameter]]
            parameters belonging to the group; may be None when the
            group declares an empty ``params`` entry
    """

    name: str
    params: Optional[List[BaseParameter]]

    @classmethod
    def _decodeValue(cls, key: str, value: Any) -> Any:
        """
            Decodes a single key of a parameter group.

            The ``params`` key is converted into a list of parameter objects
            using ``parameter_factory.create`` (or kept as None if the value
            is null); every other key is decoded by ``Codable``.

            Parameters
            ----------
            key : str
                encoded key name
            value : Any
                raw encoded value for that key

            Returns
            -------
            Any -> decoded value
        """

        if key == "params":
            # An empty "params:" entry in YAML loads as None; iterating over
            # it would raise TypeError, so keep it as None (params is Optional)
            if value is None:
                return None

            return [parameter_factory.create(obj) for obj in value]

        return super()._decodeValue(key, value)
42 
43 
class TaskConfig(Codable):

    """
        Decoded contents of the task configuration file.

        Attributes
        ----------
        paramGroups : Optional[List[ParamGroup]]
            parameter groups, stored under the ``param_groups`` key in the file
    """

    paramGroups: Optional[List[ParamGroup]]

    @classmethod
    def _keyDescriptors(cls) -> Dict[str, KeyDescriptor]:
        descriptors = super()._keyDescriptors()
        # presumably (encoded name, element type, collection type):
        # "param_groups" in the file -> paramGroups, a list of ParamGroup
        descriptors["paramGroups"] = KeyDescriptor("param_groups", ParamGroup, list)

        return descriptors

    @classmethod
    def decode(cls, params: dict) -> Self:
        """
            Decodes a TaskConfig from a dictionary (e.g. parsed YAML).

            A missing or null ``param_groups`` entry is treated as an empty
            list. The passed dictionary is not modified.

            Parameters
            ----------
            params : dict
                encoded configuration

            Returns
            -------
            Self -> decoded configuration
        """

        if params.get("param_groups") is None:
            # Work on a shallow copy so the caller's dictionary is left untouched
            params = {**params, "param_groups": []}

        return super().decode(params)
61 
62 
def readTaskConfig() -> List[BaseParameter]:
    """
        Reads the task parameters from ``task.yaml`` in the current working directory.

        Returns
        -------
        List[BaseParameter] -> parameters of all parameter groups, in the order
        they appear in the file; an empty list if the file does not exist
        or is empty
    """

    if not TASK_CONFIG_PATH.exists():
        return []

    with TASK_CONFIG_PATH.open("r", encoding="utf-8") as configFile:
        rawConfig = yaml.safe_load(configFile)

    # yaml.safe_load returns None for an empty document; decode it as an empty config
    if rawConfig is None:
        rawConfig = {}

    config = TaskConfig.decode(rawConfig)

    parameters: List[BaseParameter] = []
    for group in config.paramGroups or []:
        if group.params is not None:
            parameters.extend(group.params)

    return parameters