Instructions for using xusenlin/uie-base with libraries, inference providers, notebooks, and local apps. Follow the links below to get started.
- Libraries
- Transformers
How to use xusenlin/uie-base with Transformers:
```python
# Use a pipeline as a high-level helper
from transformers import pipeline

pipe = pipeline("feature-extraction", model="xusenlin/uie-base", trust_remote_code=True)
```

```python
# Load model directly
from transformers import AutoTokenizer, AutoModel

tokenizer = AutoTokenizer.from_pretrained("xusenlin/uie-base", trust_remote_code=True)
model = AutoModel.from_pretrained("xusenlin/uie-base", trust_remote_code=True)
```

- Notebooks
- Google Colab
- Kaggle
The custom modeling code pulled in by `trust_remote_code=True` implements the UIE decoder reproduced below.

```python
import logging
import math
import os
import queue
import re
from multiprocessing import Queue
from typing import (
    List,
    Tuple,
    Union,
    Dict,
    Any,
    Set,
    TYPE_CHECKING,
    Optional,
    Literal,
)

import numpy as np
import torch
import torch.multiprocessing as mp
import torch.nn as nn
from tqdm import tqdm
from transformers import is_torch_npu_available

if TYPE_CHECKING:
    from transformers import PreTrainedTokenizer

os.environ["PYTHONWARNINGS"] = "ignore"

logger = logging.getLogger("FASTIE")
```
```python
def get_id_and_prob(spans, offset_map):
    """Map token-level spans back to character offsets and span probabilities."""
    prompt_length = 0
    for i in range(1, len(offset_map)):
        if offset_map[i] != [0, 0]:
            prompt_length += 1
        else:
            break

    # Shift prompt offsets below zero so prompt spans can later be told apart
    # from text spans (see _convert_ids_to_results).
    for i in range(1, prompt_length + 1):
        offset_map[i][0] -= (prompt_length + 1)
        offset_map[i][1] -= (prompt_length + 1)

    sentence_id = []
    prob = []
    for start, end in spans:
        prob.append(float(start[1] * end[1]))
        sentence_id.append((offset_map[start[0]][0], offset_map[end[0]][1]))

    return sentence_id, prob


def get_span(
    start_ids: Union[List[int], List[Tuple[int, float]]],
    end_ids: Union[List[int], List[Tuple[int, float]]],
    with_prob: bool = False
) -> Set[Tuple[int, int]]:
    """
    Get span set from position start and end lists.

    Args:
        start_ids (List[int]/List[tuple]): The start index list.
        end_ids (List[int]/List[tuple]): The end index list.
        with_prob (bool): If True, each element of start_ids and end_ids is a tuple like (index, probability).

    Returns:
        set: The span set without overlapping; every id can only be used once.
    """
    if with_prob:
        start_ids = sorted(start_ids, key=lambda x: x[0])
        end_ids = sorted(end_ids, key=lambda x: x[0])
    else:
        start_ids = sorted(start_ids)
        end_ids = sorted(end_ids)

    start_pointer = 0
    end_pointer = 0
    len_start = len(start_ids)
    len_end = len(end_ids)
    couple_dict = {}
    # Pair the start/end token ids of each span (nearest match, assuming spans do not overlap).
    while start_pointer < len_start and end_pointer < len_end:
        if with_prob:
            start_id = start_ids[start_pointer][0]
            end_id = end_ids[end_pointer][0]
        else:
            start_id = start_ids[start_pointer]
            end_id = end_ids[end_pointer]

        if start_id == end_id:
            couple_dict[end_ids[end_pointer]] = start_ids[start_pointer]
            start_pointer += 1
            end_pointer += 1
            continue
        if start_id < end_id:
            couple_dict[end_ids[end_pointer]] = start_ids[start_pointer]
            start_pointer += 1
            continue
        if start_id > end_id:
            end_pointer += 1
            continue

    result = [(couple_dict[end], end) for end in couple_dict]
    return set(result)
```
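As a quick illustration of the nearest-match pairing (toy indices, not model output):

```python
# Start 1 pairs with end 3, start 5 with end 7 (set order may vary).
print(get_span([1, 5], [3, 7]))  # {(1, 3), (5, 7)}

# Equal start/end indices yield a single-token span.
print(get_span([2], [2]))        # {(2, 2)}
```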
```python
def get_bool_ids_greater_than(
    probs: List[List[float]], limit: float = 0.5, return_prob: bool = False
) -> List[List[int]]:
    """
    Get the indices along the last dimension of a probability array whose values exceed a limit.

    Args:
        probs (List[List[float]]): The input probability arrays.
        limit (float): The probability threshold.
        return_prob (bool): Whether to return the probability alongside each index.

    Returns:
        List[List[int]]: The indices of the last dimension that meet the condition.
    """
    probs = np.array(probs)
    dim_len = len(probs.shape)
    if dim_len > 1:
        result = []
        for p in probs:
            result.append(get_bool_ids_greater_than(p, limit, return_prob))
        return result
    else:
        result = []
        for i, p in enumerate(probs):
            if p > limit:
                if return_prob:
                    result.append((i, p))
                else:
                    result.append(i)
        return result
```
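Putting the two helpers together on toy per-token probabilities:

```python
# Threshold start/end probabilities, then pair the surviving indices into spans.
start_ids = get_bool_ids_greater_than([[0.1, 0.9, 0.2, 0.1]], limit=0.5)
end_ids = get_bool_ids_greater_than([[0.1, 0.1, 0.2, 0.8]], limit=0.5)
print(start_ids, end_ids)              # [[1]] [[3]]
print(get_span(start_ids[0], end_ids[0]))  # {(1, 3)}
```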
```python
def dbc2sbc(s) -> str:
    """Convert full-width characters in a string to their half-width equivalents."""
    rs = ""
    for char in s:
        code = ord(char)
        if code == 0x3000:
            code = 0x0020
        else:
            code -= 0xfee0
        if not (0x0021 <= code <= 0x7e):
            rs += char
            continue
        rs += chr(code)
    return rs


def cut_chinese_sent(para: str) -> List[str]:
    """
    Cut Chinese text into sentences, adapted from
    https://blog.csdn.net/blmoistawinde/article/details/82379256.
    """
    para = re.sub(r'([。!?\?])([^”’])', r'\1\n\2', para)
    para = re.sub(r'(\.{6})([^”’])', r'\1\n\2', para)
    para = re.sub(r'(…{2})([^”’])', r'\1\n\2', para)
    para = re.sub(r'([。!?\?][”’])([^,。!?\?])', r'\1\n\2', para)
    para = para.rstrip()
    return para.split("\n")
```
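A small demonstration of the two text utilities on toy strings:

```python
# Full-width ASCII letters, digits, and punctuation become half-width.
print(dbc2sbc("ABC!123"))  # ABC!123

# Sentence splitting on Chinese terminal punctuation.
print(cut_chinese_sent("今天天气很好。我们去公园散步!好啊?"))
# ['今天天气很好。', '我们去公园散步!', '好啊?']
```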
```python
class UIEDecoder(nn.Module):

    keys_to_ignore_on_gpu = ["offset_mapping", "texts"]

    def predict(
        self,
        tokenizer: "PreTrainedTokenizer",
        texts: Union[List[str], str],
        schema: Optional[Any] = None,
        batch_size: int = 64,
        max_length: int = 512,
        split_sentence: bool = False,
        position_prob: float = 0.5,
        language: Optional[str] = "zh",
        show_progress_bar: Optional[bool] = None,
        device: Optional[str] = None,
    ) -> List[Any]:
        self.eval()
        self.is_english = language.lower() not in ["zh", "zh-cn", "chinese"]
        if schema is not None:
            self.set_schema(schema)

        if show_progress_bar is None:
            show_progress_bar = logger.getEffectiveLevel() in (logging.INFO, logging.DEBUG)

        # Cast an individual text to a list of length 1
        if isinstance(texts, str) or not hasattr(texts, "__len__"):
            texts = [texts]

        if device is None:
            device = next(self.parameters()).device
        self.to(device)

        return self._multi_stage_predict(
            tokenizer, texts, batch_size, max_length, split_sentence, position_prob, show_progress_bar
        )
```
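Assuming the remote-code model class mixes in `UIEDecoder` (which is how the `AutoModel` snippet above would gain a `predict` method), extraction looks roughly like this; the schema and sentence are illustrative, not taken from the repository:

```python
# A usage sketch, assuming model/tokenizer loaded as in the snippet above.
from transformers import AutoTokenizer, AutoModel

tokenizer = AutoTokenizer.from_pretrained("xusenlin/uie-base", trust_remote_code=True)
model = AutoModel.from_pretrained("xusenlin/uie-base", trust_remote_code=True)

results = model.predict(
    tokenizer,
    ["2月8日上午北京冬奥会自由式滑雪女子大跳台决赛中中国选手谷爱凌以188.25分获得金牌!"],
    schema=["时间", "选手", "赛事名称"],  # entity types to extract (illustrative)
    position_prob=0.5,  # keep spans whose start/end probabilities exceed 0.5
)
print(results)
```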
```python
    def set_schema(self, schema):
        if isinstance(schema, (dict, str)):
            schema = [schema]
        self._schema_tree = self._build_tree(schema)

    def _multi_stage_predict(
        self,
        tokenizer: "PreTrainedTokenizer",
        texts: List[str],
        batch_size: int = 64,
        max_length: int = 512,
        split_sentence: bool = False,
        position_prob: float = 0.5,
        show_progress_bar: bool = False,
    ) -> List[Any]:
        """Traverse the schema tree and run multi-stage prediction."""
        results = [{} for _ in range(len(texts))]
        if len(texts) < 1 or self._schema_tree is None:
            return results

        schema_list = self._schema_tree.children[:]
        while len(schema_list) > 0:
            node = schema_list.pop(0)
            examples = []
            input_map = {}
            cnt = 0
            idx = 0
            if not node.prefix:
                for data in texts:
                    examples.append({"text": data, "prompt": dbc2sbc(node.name)})
                    input_map[cnt] = [idx]
                    idx += 1
                    cnt += 1
            else:
                for pre, data in zip(node.prefix, texts):
                    if len(pre) == 0:
                        input_map[cnt] = []
                    else:
                        for p in pre:
                            if self.is_english:
                                if re.search(r'\[.*?\]$', node.name):
                                    prompt_prefix = node.name[:node.name.find("[", 1)].strip()
                                    cls_options = re.search(r'\[.*?\]$', node.name).group()
                                    # Sentiment classification of xxx [positive, negative]
                                    prompt = prompt_prefix + p + " " + cls_options
                                else:
                                    prompt = node.name + p
                            else:
                                prompt = p + node.name
                            examples.append(
                                {
                                    "text": data,
                                    "prompt": dbc2sbc(prompt)
                                }
                            )
                        input_map[cnt] = [i + idx for i in range(len(pre))]
                        idx += len(pre)
                    cnt += 1

            result_list = self._single_stage_predict(
                tokenizer, examples, batch_size, max_length, split_sentence, position_prob, show_progress_bar
            ) if examples else []

            if not node.parent_relations:
                relations = [[] for _ in range(len(texts))]
                for k, v in input_map.items():
                    for idx in v:
                        if len(result_list[idx]) == 0:
                            continue
                        if node.name not in results[k].keys():
                            results[k][node.name] = result_list[idx]
                        else:
                            results[k][node.name].extend(result_list[idx])
                    if node.name in results[k].keys():
                        relations[k].extend(results[k][node.name])
            else:
                relations = node.parent_relations
                for k, v in input_map.items():
                    for i in range(len(v)):
                        if len(result_list[v[i]]) == 0:
                            continue
                        if "relations" not in relations[k][i].keys():
                            relations[k][i]["relations"] = {node.name: result_list[v[i]]}
                        elif node.name not in relations[k][i]["relations"].keys():
                            relations[k][i]["relations"][node.name] = result_list[v[i]]
                        else:
                            relations[k][i]["relations"][node.name].extend(result_list[v[i]])

                new_relations = [[] for _ in range(len(texts))]
                for i in range(len(relations)):
                    for j in range(len(relations[i])):
                        if "relations" in relations[i][j].keys() and node.name in relations[i][j]["relations"].keys():
                            for k in range(len(relations[i][j]["relations"][node.name])):
                                new_relations[i].append(relations[i][j]["relations"][node.name][k])
                relations = new_relations

            prefix = [[] for _ in range(len(texts))]
            for k, v in input_map.items():
                for idx in v:
                    for i in range(len(result_list[idx])):
                        if self.is_english:
                            prefix[k].append(" of " + result_list[idx][i]["text"])
                        else:
                            prefix[k].append(result_list[idx][i]["text"] + "的")

            for child in node.children:
                child.prefix = prefix
                child.parent_relations = relations
                schema_list.append(child)

        return results
```
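For orientation, a hypothetical result for a relation schema, showing how child spans nest under `relations` (values are made up):

```python
# Hypothetical output for schema {"竞赛名称": ["主办方"]} on a single text:
[
    {
        "竞赛名称": [
            {
                "text": "北京冬奥会",
                "start": 6,
                "end": 11,
                "probability": 0.98,
                # Child results are attached under "relations", keyed by child name.
                "relations": {
                    "主办方": [
                        {"text": "国际奥委会", "start": 12, "end": 17, "probability": 0.90}
                    ]
                },
            }
        ]
    }
]
```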
```python
    def _convert_ids_to_results(self, examples, sentence_ids, probs):
        """Convert ids to raw text in a single stage."""
        results = []
        for example, sentence_id, prob in zip(examples, sentence_ids, probs):
            if len(sentence_id) == 0:
                results.append([])
                continue
            result_list = []
            text = example["text"]
            prompt = example["prompt"]
            for i in range(len(sentence_id)):
                start, end = sentence_id[i]
                if start < 0 and end >= 0:
                    continue
                if end < 0:
                    # Negative offsets point into the prompt (classification options).
                    start += len(prompt) + 1
                    end += len(prompt) + 1
                    result = {"text": prompt[start: end], "probability": float(prob[i])}
                else:
                    result = {"text": text[start: end], "start": start, "end": end, "probability": float(prob[i])}
                result_list.append(result)
            results.append(result_list)
        return results

    def _auto_splitter(self, input_texts, max_text_len, split_sentence=False):
        """
        Split the raw texts automatically for model inference.

        Args:
            input_texts (List[str]): Input raw texts.
            max_text_len (int): Cutting length.
            split_sentence (bool): If True, sentence-level splitting is performed.

        Returns:
            short_input_texts (List[str]): The short input texts for model inference.
            input_mapping (dict): Mapping between the raw texts and the short input texts.
        """
        input_mapping = {}
        short_input_texts = []
        cnt_short = 0
        for cnt_org, text in enumerate(input_texts):
            sens = cut_chinese_sent(text) if split_sentence else [text]
            for sen in sens:
                lens = len(sen)
                if lens <= max_text_len:
                    short_input_texts.append(sen)
                    if cnt_org in input_mapping:
                        input_mapping[cnt_org].append(cnt_short)
                    else:
                        input_mapping[cnt_org] = [cnt_short]
                    cnt_short += 1
                else:
                    temp_text_list = [sen[i: i + max_text_len] for i in range(0, lens, max_text_len)]
                    short_input_texts.extend(temp_text_list)
                    short_idx = cnt_short
                    cnt_short += math.ceil(lens / max_text_len)
                    temp_text_id = [short_idx + i for i in range(cnt_short - short_idx)]
                    if cnt_org in input_mapping:
                        input_mapping[cnt_org].extend(temp_text_id)
                    else:
                        input_mapping[cnt_org] = temp_text_id
        return short_input_texts, input_mapping
```
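Because `_auto_splitter` never touches `self`, its bookkeeping can be sanity-checked standalone (toy inputs):

```python
# self is unused inside _auto_splitter, so passing None works for a quick check.
short_texts, mapping = UIEDecoder._auto_splitter(None, ["abcdefghij", "hi"], max_text_len=4)
print(short_texts)  # ['abcd', 'efgh', 'ij', 'hi']
print(mapping)      # {0: [0, 1, 2], 1: [3]}
```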
```python
    def _single_stage_predict(
        self,
        tokenizer: "PreTrainedTokenizer",
        inputs: List[dict],
        batch_size: int = 64,
        max_length: int = 512,
        split_sentence: bool = False,
        position_prob: float = 0.5,
        show_progress_bar: bool = False,
    ) -> List[Any]:
        input_texts = []
        prompts = []
        for i in range(len(inputs)):
            input_texts.append(inputs[i]["text"])
            prompts.append(inputs[i]["prompt"])
        # The maximum predict length should exclude the longest prompt and the special tokens.
        # Note: `len(max(prompts))` would pick the lexicographically greatest prompt,
        # so take the maximum length explicitly.
        max_predict_len = max_length - max(len(p) for p in prompts) - 3
        short_input_texts, input_mapping = self._auto_splitter(
            input_texts, max_predict_len, split_sentence=split_sentence
        )

        short_texts_prompts = []
        for k, v in input_mapping.items():
            short_texts_prompts.extend([prompts[k] for _ in range(len(v))])
        short_inputs = [
            {"text": short_input_texts[i], "prompt": short_texts_prompts[i]}
            for i in range(len(short_input_texts))
        ]

        encoded_inputs = tokenizer(
            text=short_texts_prompts,
            text_pair=short_input_texts,
            stride=2,
            truncation=True,
            max_length=max_length,
            padding="max_length",
            add_special_tokens=True,
            return_offsets_mapping=True,
            return_tensors="np",
        )
        offset_maps = encoded_inputs["offset_mapping"]

        start_prob_concat, end_prob_concat = [], []
        batch_iterator = tqdm(
            range(0, len(short_input_texts), batch_size), desc="Batches", disable=not show_progress_bar
        )
        for batch_start in batch_iterator:
            batch = {
                key: np.array(value[batch_start: batch_start + batch_size], dtype="int64")
                for key, value in encoded_inputs.items()
                if key not in self.keys_to_ignore_on_gpu
            }
            for k, v in batch.items():
                # self.device is expected to be provided by the PreTrainedModel
                # subclass that mixes in this decoder.
                batch[k] = torch.tensor(v, device=self.device)

            outputs = self(**batch)
            start_prob, end_prob = outputs[0], outputs[1]
            if self.device != torch.device("cpu"):
                start_prob, end_prob = start_prob.cpu(), end_prob.cpu()
            start_prob_concat.append(start_prob.detach().numpy())
            end_prob_concat.append(end_prob.detach().numpy())

        start_prob_concat = np.concatenate(start_prob_concat)
        end_prob_concat = np.concatenate(end_prob_concat)

        start_ids_list = get_bool_ids_greater_than(start_prob_concat, limit=position_prob, return_prob=True)
        end_ids_list = get_bool_ids_greater_than(end_prob_concat, limit=position_prob, return_prob=True)

        input_ids = encoded_inputs["input_ids"].tolist()
        sentence_ids, probs = [], []
        for start_ids, end_ids, ids, offset_map in zip(start_ids_list, end_ids_list, input_ids, offset_maps):
            span_list = get_span(start_ids, end_ids, with_prob=True)
            sentence_id, prob = get_id_and_prob(span_list, offset_map.tolist())
            sentence_ids.append(sentence_id)
            probs.append(prob)

        results = self._convert_ids_to_results(short_inputs, sentence_ids, probs)
        results = self._auto_joiner(results, short_input_texts, input_mapping)
        return results

    def _auto_joiner(self, short_results, short_inputs, input_mapping):
        concat_results = []
        is_cls_task = False
        for short_result in short_results:
            if not short_result:
                continue
            elif "start" not in short_result[0].keys() and "end" not in short_result[0].keys():
                is_cls_task = True
                break
            else:
                break

        for k, vs in input_mapping.items():
            single_results = []
            if is_cls_task:
                cls_options = {}
                for v in vs:
                    if len(short_results[v]) == 0:
                        continue
                    if short_results[v][0]["text"] in cls_options:
                        cls_options[short_results[v][0]["text"]][0] += 1
                        cls_options[short_results[v][0]["text"]][1] += short_results[v][0]["probability"]
                    else:
                        cls_options[short_results[v][0]["text"]] = [1, short_results[v][0]["probability"]]
                if cls_options:
                    cls_res, cls_info = max(cls_options.items(), key=lambda x: x[1])
                    concat_results.append(
                        [{"text": cls_res, "probability": cls_info[1] / cls_info[0]}]
                    )
                else:
                    concat_results.append([])
            else:
                offset = 0
                for v in vs:
                    if v == 0:
                        single_results = short_results[v]
                        offset += len(short_inputs[v])
                    else:
                        for i in range(len(short_results[v])):
                            if "start" not in short_results[v][i] or "end" not in short_results[v][i]:
                                continue
                            short_results[v][i]["start"] += offset
                            short_results[v][i]["end"] += offset
                        offset += len(short_inputs[v])
                        single_results.extend(short_results[v])
                concat_results.append(single_results)
        return concat_results

    @classmethod
    def _build_tree(cls, schema, name="root"):
        """Build the schema tree."""
        schema_tree = SchemaTree(name)
        for s in schema:
            if isinstance(s, str):
                schema_tree.add_child(SchemaTree(s))
            elif isinstance(s, dict):
                for k, v in s.items():
                    if isinstance(v, str):
                        child = [v]
                    elif isinstance(v, list):
                        child = v
                    else:
                        raise TypeError(
                            f"Invalid schema: the value in a key/value pair should be a list or a string, "
                            f"but {type(v)} was received."
                        )
                    schema_tree.add_child(cls._build_tree(child, name=k))
            else:
                raise TypeError(f"Invalid schema: each element should be a string or a dict, but {type(s)} was received.")
        return schema_tree
```
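For reference, the three schema shapes `_build_tree` accepts; the task names are illustrative:

```python
# All three shapes are accepted by set_schema (names here are illustrative):
entity_schema = ["时间", "地点", "人物"]               # flat list of entity types
relation_schema = {"竞赛名称": ["主办方", "承办方"]}    # dict: parent span -> child prompts
cls_schema = "情感倾向[正向,负向]"                     # single string, options in brackets
```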
```python
    def start_multi_process_pool(
        self, target_devices: Optional[List[str]] = None
    ) -> Dict[Literal["input", "output", "processes"], Any]:
        """
        Start a multi-process pool that runs prediction in several independent processes.

        This method is recommended when predicting on multiple GPUs or CPUs; it is
        advisable to start only one process per GPU.

        Args:
            target_devices (List[str], optional): PyTorch target devices, e.g. ["cuda:0", "cuda:1", ...],
                ["npu:0", "npu:1", ...], or ["cpu", "cpu", "cpu", "cpu"]. If target_devices is None and CUDA/NPU
                is available, then all available CUDA/NPU devices will be used. If target_devices is None and
                CUDA/NPU is not available, then 4 CPU devices will be used.

        Returns:
            Dict[str, Any]: A dictionary with the target processes, an input queue, and an output queue.
        """
        if target_devices is None:
            if torch.cuda.is_available():
                target_devices = ["cuda:{}".format(i) for i in range(torch.cuda.device_count())]
            elif is_torch_npu_available():
                target_devices = ["npu:{}".format(i) for i in range(torch.npu.device_count())]
            else:
                logger.info("CUDA/NPU is not available. Starting 4 CPU workers")
                target_devices = ["cpu"] * 4

        logger.info("Start multi-process pool on devices: {}".format(", ".join(map(str, target_devices))))

        self.to("cpu")
        self.share_memory()
        ctx = mp.get_context("spawn")
        input_queue = ctx.Queue()
        output_queue = ctx.Queue()
        processes = []
        for device_id in target_devices:
            p = ctx.Process(
                target=UIEDecoder._predict_multi_process_worker,
                args=(device_id, self, input_queue, output_queue),
                daemon=True,
            )
            p.start()
            processes.append(p)

        return {"input": input_queue, "output": output_queue, "processes": processes}

    @staticmethod
    def stop_multi_process_pool(pool: Dict[Literal["input", "output", "processes"], Any]) -> None:
        """
        Stop all processes started with start_multi_process_pool.

        Args:
            pool (Dict[str, object]): A dictionary containing the input queue, output queue, and process list.

        Returns:
            None
        """
        for p in pool["processes"]:
            p.terminate()

        for p in pool["processes"]:
            p.join()
            p.close()

        pool["input"].close()
        pool["output"].close()

    def predict_multi_process(
        self,
        tokenizer: "PreTrainedTokenizer",
        texts: List[str],
        pool: Dict[Literal["input", "output", "processes"], Any],
        batch_size: int = 64,
        max_length: int = 512,
        split_sentence: bool = False,
        language: Optional[str] = "zh",
        position_prob: float = 0.5,
        chunk_size: Optional[int] = None,
    ) -> List[Any]:
        if chunk_size is None:
            chunk_size = min(math.ceil(len(texts) / len(pool["processes"]) / 10), 5000)

        logger.debug(f"Chunk data into {math.ceil(len(texts) / chunk_size)} packages of size {chunk_size}")

        input_queue = pool["input"]
        last_chunk_id = 0
        chunk = []
        for text in texts:
            chunk.append(text)
            if len(chunk) >= chunk_size:
                input_queue.put(
                    [last_chunk_id, tokenizer, batch_size, chunk, max_length, split_sentence, language, position_prob]
                )
                last_chunk_id += 1
                chunk = []

        if len(chunk) > 0:
            input_queue.put(
                [last_chunk_id, tokenizer, batch_size, chunk, max_length, split_sentence, language, position_prob]
            )
            last_chunk_id += 1

        output_queue = pool["output"]
        results_list = sorted([output_queue.get() for _ in range(last_chunk_id)], key=lambda x: x[0])
        return sum([result[1] for result in results_list], [])

    @staticmethod
    def _predict_multi_process_worker(
        target_device: str, model: "UIEDecoder", input_queue: Queue, results_queue: Queue
    ) -> None:
        """Internal worker process for prediction in the multi-process setup."""
        while True:
            try:
                chunk_id, tokenizer, batch_size, chunk, max_length, split_sentence, language, position_prob = (
                    input_queue.get()
                )
                results = model.predict(
                    tokenizer,
                    chunk,
                    batch_size=batch_size,
                    max_length=max_length,
                    split_sentence=split_sentence,
                    language=language,
                    show_progress_bar=False,
                    device=target_device,
                )
                results_queue.put([chunk_id, results])
            except queue.Empty:
                break
```
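A sketch of the multi-process path under the assumptions of the earlier usage example (model and tokenizer already loaded; texts are placeholders). Note that the schema must be set before the pool is started, since each worker receives a copy of the model:

```python
if __name__ == "__main__":  # required for the "spawn" start method
    texts = ["文本一", "文本二"] * 500            # hypothetical corpus
    model.set_schema(["时间", "选手"])            # illustrative schema, set before spawning
    pool = model.start_multi_process_pool()      # one worker per visible device
    try:
        results = model.predict_multi_process(
            tokenizer, texts, pool, batch_size=64, position_prob=0.5
        )
    finally:
        model.stop_multi_process_pool(pool)
```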
```python
class SchemaTree(object):
    """Implementation of the schema tree."""

    def __init__(self, name='root', children=None):
        self.name = name
        self.children = []
        self.prefix = None
        self.parent_relations = None
        if children is not None:
            for child in children:
                self.add_child(child)

    def __repr__(self):
        return self.name

    def add_child(self, node):
        assert isinstance(
            node, SchemaTree
        ), "The children of a node should be instances of SchemaTree."
        self.children.append(node)
```
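Finally, an illustrative check of how `set_schema`/`_build_tree` turn a nested schema into a `SchemaTree` (schema names are hypothetical):

```python
# Build a tree for a relation schema; children of the root are the parent prompts.
tree = UIEDecoder._build_tree([{"竞赛名称": ["主办方", "承办方"]}])
print(tree.children)              # [竞赛名称]
print(tree.children[0].children)  # [主办方, 承办方]
```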