API Reference

Python API reference for grai.build.


Core Models

Core Pydantic models for grai.build.

This module defines the data structures for entities, relations, and properties that form the foundation of declarative knowledge graph modeling in grai.build.

Entity

Bases: BaseModel

Represents a node/entity in the knowledge graph.

Attributes:

entity (str): The entity type name (becomes the node label in Neo4j).
source (Union[str, SourceConfig]): The data source - can be a string or SourceConfig object.
keys (List[str]): List of property names that uniquely identify this entity.
properties (List[Property]): List of properties/attributes for this entity.
description (Optional[str]): Optional description of the entity.
metadata (Dict[str, Any]): Optional additional metadata.

Source code in grai/core/models.py
class Entity(BaseModel):
    """
    Represents a node/entity in the knowledge graph.

    Attributes:
        entity: The entity type name (becomes the node label in Neo4j).
        source: The data source - can be a string or SourceConfig object.
        keys: List of property names that uniquely identify this entity.
        properties: List of properties/attributes for this entity.
        description: Optional description of the entity.
        metadata: Optional additional metadata.
    """

    entity: str = Field(..., min_length=1, description="Entity type name")
    source: Union[str, SourceConfig] = Field(..., description="Data source identifier or config")
    keys: List[str] = Field(..., min_length=1, description="Key properties for uniqueness")
    properties: List[Property] = Field(
        default_factory=list, description="Entity properties/attributes"
    )
    description: Optional[str] = Field(default=None, description="Entity description")
    metadata: Dict[str, Any] = Field(default_factory=dict, description="Additional metadata")

    @field_validator("source")
    @classmethod
    def validate_source(cls, v: Union[str, SourceConfig]) -> SourceConfig:
        """Convert string source to SourceConfig for consistency."""
        if isinstance(v, str):
            return SourceConfig.from_string(v)
        return v

    @field_validator("entity")
    @classmethod
    def validate_entity_name(cls, v: str) -> str:
        """Validate that entity name is a valid identifier."""
        if not v.replace("_", "").isalnum():
            raise ValueError(f"Entity name must be alphanumeric with underscores: {v}")
        return v

    @field_validator("keys")
    @classmethod
    def validate_keys(cls, v: List[str]) -> List[str]:
        """Validate that all keys are non-empty."""
        if not all(k.strip() for k in v):
            raise ValueError("All keys must be non-empty strings")
        return v

    def get_property(self, name: str) -> Optional[Property]:
        """
        Get a property by name.

        Args:
            name: The property name to look up.

        Returns:
            The Property object if found, None otherwise.
        """
        return next((p for p in self.properties if p.name == name), None)

    def get_key_properties(self) -> List[Property]:
        """
        Get all properties that are designated as keys.

        Returns:
            List of Property objects that are keys.
        """
        return [p for p in self.properties if p.name in self.keys]

    def get_source_name(self) -> str:
        """
        Get the source name as a string.

        Returns:
            Source name string.
        """
        if isinstance(self.source, SourceConfig):
            return self.source.name
        return str(self.source)

    def get_source_config(self) -> SourceConfig:
        """
        Get the full source configuration.

        Returns:
            SourceConfig object.
        """
        if isinstance(self.source, SourceConfig):
            return self.source
        return SourceConfig.from_string(str(self.source))

get_key_properties()

Get all properties that are designated as keys.

Returns:

List[Property]: List of Property objects that are keys.

Source code in grai/core/models.py
def get_key_properties(self) -> List[Property]:
    """
    Get all properties that are designated as keys.

    Returns:
        List of Property objects that are keys.
    """
    return [p for p in self.properties if p.name in self.keys]

get_property(name)

Get a property by name.

Parameters:

name (str, required): The property name to look up.

Returns:

Optional[Property]: The Property object if found, None otherwise.

Source code in grai/core/models.py
def get_property(self, name: str) -> Optional[Property]:
    """
    Get a property by name.

    Args:
        name: The property name to look up.

    Returns:
        The Property object if found, None otherwise.
    """
    return next((p for p in self.properties if p.name == name), None)

get_source_config()

Get the full source configuration.

Returns:

SourceConfig: SourceConfig object.

Source code in grai/core/models.py
def get_source_config(self) -> SourceConfig:
    """
    Get the full source configuration.

    Returns:
        SourceConfig object.
    """
    if isinstance(self.source, SourceConfig):
        return self.source
    return SourceConfig.from_string(str(self.source))

get_source_name()

Get the source name as a string.

Returns:

str: Source name string.

Source code in grai/core/models.py
def get_source_name(self) -> str:
    """
    Get the source name as a string.

    Returns:
        Source name string.
    """
    if isinstance(self.source, SourceConfig):
        return self.source.name
    return str(self.source)

validate_entity_name(v) classmethod

Validate that entity name is a valid identifier.

Source code in grai/core/models.py
@field_validator("entity")
@classmethod
def validate_entity_name(cls, v: str) -> str:
    """Validate that entity name is a valid identifier."""
    if not v.replace("_", "").isalnum():
        raise ValueError(f"Entity name must be alphanumeric with underscores: {v}")
    return v

validate_keys(v) classmethod

Validate that all keys are non-empty.

Source code in grai/core/models.py
@field_validator("keys")
@classmethod
def validate_keys(cls, v: List[str]) -> List[str]:
    """Validate that all keys are non-empty."""
    if not all(k.strip() for k in v):
        raise ValueError("All keys must be non-empty strings")
    return v

validate_source(v) classmethod

Convert string source to SourceConfig for consistency.

Source code in grai/core/models.py
@field_validator("source")
@classmethod
def validate_source(cls, v: Union[str, SourceConfig]) -> SourceConfig:
    """Convert string source to SourceConfig for consistency."""
    if isinstance(v, str):
        return SourceConfig.from_string(v)
    return v
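
A minimal usage sketch (not part of the generated reference) showing how a plain string source is normalized by validate_source and how the lookup helpers behave, assuming the models above are importable from grai.core.models:

from grai.core.models import Entity, Property, SourceType

customer = Entity(
    entity="customer",
    source="analytics.customers",   # plain string, normalized to SourceConfig by the validator
    keys=["customer_id"],
    properties=[
        Property(name="customer_id", type="string"),
        Property(name="email", type="string"),
    ],
)

print(customer.get_source_name())                              # analytics.customers
print(customer.get_source_config().type is SourceType.TABLE)   # True (inferred from the dotted name)
print([p.name for p in customer.get_key_properties()])         # ['customer_id']
print(customer.get_property("missing"))                        # None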

Project

Bases: BaseModel

Represents a complete grai.build project configuration.

Attributes:

name (str): The project name.
version (str): The project version.
entities (List[Entity]): List of entity definitions in the project.
relations (List[Relation]): List of relation definitions in the project.
config (Dict[str, Any]): Optional project-level configuration.

Source code in grai/core/models.py
class Project(BaseModel):
    """
    Represents a complete grai.build project configuration.

    Attributes:
        name: The project name.
        version: The project version.
        entities: List of entity definitions in the project.
        relations: List of relation definitions in the project.
        config: Optional project-level configuration.
    """

    name: str = Field(..., min_length=1, description="Project name")
    version: str = Field(default="1.0.0", description="Project version")
    entities: List[Entity] = Field(default_factory=list, description="Entity definitions")
    relations: List[Relation] = Field(default_factory=list, description="Relation definitions")
    config: Dict[str, Any] = Field(default_factory=dict, description="Project configuration")

    def get_entity(self, name: str) -> Optional[Entity]:
        """
        Get an entity by name.

        Args:
            name: The entity name to look up.

        Returns:
            The Entity object if found, None otherwise.
        """
        return next((e for e in self.entities if e.entity == name), None)

    def get_relation(self, name: str) -> Optional[Relation]:
        """
        Get a relation by name.

        Args:
            name: The relation name to look up.

        Returns:
            The Relation object if found, None otherwise.
        """
        return next((r for r in self.relations if r.relation == name), None)

get_entity(name)

Get an entity by name.

Parameters:

name (str, required): The entity name to look up.

Returns:

Optional[Entity]: The Entity object if found, None otherwise.

Source code in grai/core/models.py
def get_entity(self, name: str) -> Optional[Entity]:
    """
    Get an entity by name.

    Args:
        name: The entity name to look up.

    Returns:
        The Entity object if found, None otherwise.
    """
    return next((e for e in self.entities if e.entity == name), None)

get_relation(name)

Get a relation by name.

Parameters:

name (str, required): The relation name to look up.

Returns:

Optional[Relation]: The Relation object if found, None otherwise.

Source code in grai/core/models.py
def get_relation(self, name: str) -> Optional[Relation]:
    """
    Get a relation by name.

    Args:
        name: The relation name to look up.

    Returns:
        The Relation object if found, None otherwise.
    """
    return next((r for r in self.relations if r.relation == name), None)
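
A short sketch of the lookup helpers (illustrative entity names, not from the reference above):

from grai.core.models import Entity, Project

project = Project(
    name="retail-graph",
    entities=[
        Entity(entity="customer", source="analytics.customers", keys=["customer_id"]),
        Entity(entity="product", source="analytics.products", keys=["product_id"]),
    ],
)

print(project.get_entity("customer").entity)   # customer
print(project.get_entity("order"))             # None
print(project.get_relation("PURCHASED"))       # None (no relations defined yet)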

Property

Bases: BaseModel

Represents a property (attribute) of an entity or relation.

Attributes:

name (str): The property name.
type (PropertyType): The data type of the property.
required (bool): Whether this property must have a value.
description (Optional[str]): Optional description of the property.
default (Optional[Any]): Optional default value for the property.

Source code in grai/core/models.py
class Property(BaseModel):
    """
    Represents a property (attribute) of an entity or relation.

    Attributes:
        name: The property name.
        type: The data type of the property.
        required: Whether this property must have a value.
        description: Optional description of the property.
        default: Optional default value for the property.
    """

    name: str = Field(..., min_length=1, description="Property name")
    type: PropertyType = Field(..., description="Property data type")
    required: bool = Field(default=False, description="Whether the property is required")
    description: Optional[str] = Field(default=None, description="Property description")
    default: Optional[Any] = Field(default=None, description="Default value")

    @field_validator("name")
    @classmethod
    def validate_name(cls, v: str) -> str:
        """Validate that property name is a valid identifier."""
        if not v.replace("_", "").isalnum():
            raise ValueError(f"Property name must be alphanumeric with underscores: {v}")
        return v

validate_name(v) classmethod

Validate that property name is a valid identifier.

Source code in grai/core/models.py
@field_validator("name")
@classmethod
def validate_name(cls, v: str) -> str:
    """Validate that property name is a valid identifier."""
    if not v.replace("_", "").isalnum():
        raise ValueError(f"Property name must be alphanumeric with underscores: {v}")
    return v

PropertyType

Bases: str, Enum

Supported property types for entity and relation attributes.

Source code in grai/core/models.py
class PropertyType(str, Enum):
    """Supported property types for entity and relation attributes."""

    STRING = "string"
    INTEGER = "integer"
    FLOAT = "float"
    BOOLEAN = "boolean"
    DATE = "date"
    DATETIME = "datetime"
    JSON = "json"

Relation

Bases: BaseModel

Represents an edge/relation in the knowledge graph.

Attributes:

relation (str): The relation type name (becomes the edge label in Neo4j).
from_entity (str): The source entity type.
to_entity (str): The target entity type.
source (Union[str, SourceConfig]): The data source - can be a string or SourceConfig object.
mappings (RelationMapping): How source and target entities connect via keys.
properties (List[Property]): List of properties/attributes for this relation.
description (Optional[str]): Optional description of the relation.
metadata (Dict[str, Any]): Optional additional metadata.

Source code in grai/core/models.py
class Relation(BaseModel):
    """
    Represents an edge/relation in the knowledge graph.

    Attributes:
        relation: The relation type name (becomes the edge label in Neo4j).
        from_entity: The source entity type.
        to_entity: The target entity type.
        source: The data source - can be a string or SourceConfig object.
        mappings: How source and target entities connect via keys.
        properties: List of properties/attributes for this relation.
        description: Optional description of the relation.
        metadata: Optional additional metadata.
    """

    relation: str = Field(..., min_length=1, description="Relation type name")
    from_entity: str = Field(..., min_length=1, alias="from", description="Source entity type")
    to_entity: str = Field(..., min_length=1, alias="to", description="Target entity type")
    source: Union[str, SourceConfig] = Field(..., description="Data source identifier or config")
    mappings: RelationMapping = Field(..., description="Key mappings between entities")
    properties: List[Property] = Field(
        default_factory=list, description="Relation properties/attributes"
    )
    description: Optional[str] = Field(default=None, description="Relation description")
    metadata: Dict[str, Any] = Field(default_factory=dict, description="Additional metadata")

    model_config = ConfigDict(populate_by_name=True)

    @field_validator("source")
    @classmethod
    def validate_source(cls, v: Union[str, SourceConfig]) -> SourceConfig:
        """Convert string source to SourceConfig for consistency."""
        if isinstance(v, str):
            return SourceConfig.from_string(v)
        return v

    @field_validator("relation")
    @classmethod
    def validate_relation_name(cls, v: str) -> str:
        """Validate that relation name is uppercase and valid."""
        if not v.isupper():
            raise ValueError(f"Relation name should be uppercase: {v}")
        if not v.replace("_", "").isalnum():
            raise ValueError(f"Relation name must be alphanumeric with underscores: {v}")
        return v

    def get_property(self, name: str) -> Optional[Property]:
        """
        Get a property by name.

        Args:
            name: The property name to look up.

        Returns:
            The Property object if found, None otherwise.
        """
        return next((p for p in self.properties if p.name == name), None)

    def get_source_name(self) -> str:
        """
        Get the source name as a string.

        Returns:
            Source name string.
        """
        if isinstance(self.source, SourceConfig):
            return self.source.name
        return str(self.source)

    def get_source_config(self) -> SourceConfig:
        """
        Get the full source configuration.

        Returns:
            SourceConfig object.
        """
        if isinstance(self.source, SourceConfig):
            return self.source
        return SourceConfig.from_string(str(self.source))

get_property(name)

Get a property by name.

Parameters:

name (str, required): The property name to look up.

Returns:

Optional[Property]: The Property object if found, None otherwise.

Source code in grai/core/models.py
def get_property(self, name: str) -> Optional[Property]:
    """
    Get a property by name.

    Args:
        name: The property name to look up.

    Returns:
        The Property object if found, None otherwise.
    """
    return next((p for p in self.properties if p.name == name), None)

get_source_config()

Get the full source configuration.

Returns:

SourceConfig: SourceConfig object.

Source code in grai/core/models.py
def get_source_config(self) -> SourceConfig:
    """
    Get the full source configuration.

    Returns:
        SourceConfig object.
    """
    if isinstance(self.source, SourceConfig):
        return self.source
    return SourceConfig.from_string(str(self.source))

get_source_name()

Get the source name as a string.

Returns:

str: Source name string.

Source code in grai/core/models.py
def get_source_name(self) -> str:
    """
    Get the source name as a string.

    Returns:
        Source name string.
    """
    if isinstance(self.source, SourceConfig):
        return self.source.name
    return str(self.source)

validate_relation_name(v) classmethod

Validate that relation name is uppercase and valid.

Source code in grai/core/models.py
@field_validator("relation")
@classmethod
def validate_relation_name(cls, v: str) -> str:
    """Validate that relation name is uppercase and valid."""
    if not v.isupper():
        raise ValueError(f"Relation name should be uppercase: {v}")
    if not v.replace("_", "").isalnum():
        raise ValueError(f"Relation name must be alphanumeric with underscores: {v}")
    return v

validate_source(v) classmethod

Convert string source to SourceConfig for consistency.

Source code in grai/core/models.py
@field_validator("source")
@classmethod
def validate_source(cls, v: Union[str, SourceConfig]) -> SourceConfig:
    """Convert string source to SourceConfig for consistency."""
    if isinstance(v, str):
        return SourceConfig.from_string(v)
    return v

RelationMapping

Bases: BaseModel

Defines how entities are connected in a relation.

Attributes:

from_key (str): The key property name on the source entity.
to_key (str): The key property name on the target entity.

Source code in grai/core/models.py
class RelationMapping(BaseModel):
    """
    Defines how entities are connected in a relation.

    Attributes:
        from_key: The key property name on the source entity.
        to_key: The key property name on the target entity.
    """

    from_key: str = Field(..., min_length=1, description="Source entity key property")
    to_key: str = Field(..., min_length=1, description="Target entity key property")

SourceConfig

Bases: BaseModel

Configuration for entity/relation data sources.

Supports both simple string format (backward compatible) and detailed config.

Attributes:

name (str): Source identifier (e.g., table name, file path, API endpoint).
type (Optional[SourceType]): Type of data source.
connection (Optional[str]): Optional connection string or identifier.
db_schema (Optional[str]): Optional database schema name.
database (Optional[str]): Optional database name.
format (Optional[str]): Optional data format details.
metadata (Dict[str, Any]): Optional additional source metadata.

Source code in grai/core/models.py
class SourceConfig(BaseModel):
    """
    Configuration for entity/relation data sources.

    Supports both simple string format (backward compatible) and detailed config.

    Attributes:
        name: Source identifier (e.g., table name, file path, API endpoint).
        type: Type of data source.
        connection: Optional connection string or identifier.
        schema: Optional database schema name.
        database: Optional database name.
        format: Optional data format details.
        metadata: Optional additional source metadata.
    """

    name: str = Field(..., min_length=1, description="Source identifier")
    type: Optional[SourceType] = Field(default=None, description="Source type")
    connection: Optional[str] = Field(default=None, description="Connection identifier")
    db_schema: Optional[str] = Field(default=None, description="Database schema")
    database: Optional[str] = Field(default=None, description="Database name")
    format: Optional[str] = Field(default=None, description="Data format details")
    metadata: Dict[str, Any] = Field(default_factory=dict, description="Additional metadata")

    @classmethod
    def from_string(cls, source: str) -> "SourceConfig":
        """
        Create a SourceConfig from a simple string.

        Maintains backward compatibility with existing entity definitions.

        Args:
            source: Simple source string (e.g., "analytics.customers")

        Returns:
            SourceConfig with inferred type if possible.
        """
        # Try to infer type from string
        source_type = None
        if "." in source:
            # Likely a database table (schema.table)
            source_type = SourceType.TABLE
        elif source.endswith(".csv"):
            source_type = SourceType.CSV
        elif source.endswith(".json"):
            source_type = SourceType.JSON
        elif source.endswith(".parquet"):
            source_type = SourceType.PARQUET
        elif source.startswith("http://") or source.startswith("https://"):
            source_type = SourceType.API

        return cls(name=source, type=source_type)

from_string(source) classmethod

Create a SourceConfig from a simple string.

Maintains backward compatibility with existing entity definitions.

Parameters:

source (str, required): Simple source string (e.g., "analytics.customers").

Returns:

SourceConfig: SourceConfig with inferred type if possible.

Source code in grai/core/models.py
@classmethod
def from_string(cls, source: str) -> "SourceConfig":
    """
    Create a SourceConfig from a simple string.

    Maintains backward compatibility with existing entity definitions.

    Args:
        source: Simple source string (e.g., "analytics.customers")

    Returns:
        SourceConfig with inferred type if possible.
    """
    # Try to infer type from string
    source_type = None
    if "." in source:
        # Likely a database table (schema.table)
        source_type = SourceType.TABLE
    elif source.endswith(".csv"):
        source_type = SourceType.CSV
    elif source.endswith(".json"):
        source_type = SourceType.JSON
    elif source.endswith(".parquet"):
        source_type = SourceType.PARQUET
    elif source.startswith("http://") or source.startswith("https://"):
        source_type = SourceType.API

    return cls(name=source, type=source_type)
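
A quick sketch of the type inference. Per the source listed above, the dotted-name check runs before the file-extension checks, so dotted file names are still inferred as tables:

from grai.core.models import SourceConfig, SourceType

assert SourceConfig.from_string("analytics.customers").type is SourceType.TABLE
assert SourceConfig.from_string("events_stream").type is None                 # nothing to infer
assert SourceConfig.from_string("customers.csv").type is SourceType.TABLE     # "." is checked first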

SourceType

Bases: str, Enum

Supported source types for entities and relations.

Source code in grai/core/models.py
class SourceType(str, Enum):
    """Supported source types for entities and relations."""

    DATABASE = "database"
    TABLE = "table"
    CSV = "csv"
    JSON = "json"
    PARQUET = "parquet"
    API = "api"
    STREAM = "stream"
    OTHER = "other"

Entity Example

from grai.core.models import Entity, Property

entity = Entity(
    entity="customer",
    source="analytics.customers",
    keys=["customer_id"],
    properties=[
        Property(name="customer_id", type="string"),
        Property(name="email", type="string"),
    ]
)

Relation Example

from grai.core.models import Relation, RelationMapping

relation = Relation(
    relation="PURCHASED",
    from_entity="customer",
    to_entity="product",
    source="analytics.orders",
    mappings=RelationMapping(
        from_key="customer_id",
        to_key="product_id"
    )
)

Parser

Parser module for loading YAML definitions into Pydantic models.

ParserError

Bases: Exception

Base exception for parser errors.

Source code in grai/core/parser/yaml_parser.py
class ParserError(Exception):
    """Base exception for parser errors."""

    def __init__(self, message: str, file_path: Optional[Path] = None):
        """
        Initialize parser error.

        Args:
            message: Error message.
            file_path: Optional path to the file that caused the error.
        """
        self.file_path = file_path
        if file_path:
            super().__init__(f"{file_path}: {message}")
        else:
            super().__init__(message)

__init__(message, file_path=None)

Initialize parser error.

Parameters:

message (str, required): Error message.
file_path (Optional[Path], default None): Optional path to the file that caused the error.

Source code in grai/core/parser/yaml_parser.py
def __init__(self, message: str, file_path: Optional[Path] = None):
    """
    Initialize parser error.

    Args:
        message: Error message.
        file_path: Optional path to the file that caused the error.
    """
    self.file_path = file_path
    if file_path:
        super().__init__(f"{file_path}: {message}")
    else:
        super().__init__(message)

ValidationParserError

Bases: ParserError

Exception raised when Pydantic validation fails.

Source code in grai/core/parser/yaml_parser.py
class ValidationParserError(ParserError):
    """Exception raised when Pydantic validation fails."""

    pass

YAMLParseError

Bases: ParserError

Exception raised when YAML parsing fails.

Source code in grai/core/parser/yaml_parser.py
class YAMLParseError(ParserError):
    """Exception raised when YAML parsing fails."""

    pass

load_entities_from_directory(directory)

Load all entity definitions from a directory.

Parameters:

directory (Union[str, Path], required): Path to directory containing entity YAML files.

Returns:

List[Entity]: List of Entity instances.

Raises:

ParserError: If parsing any file fails.

Source code in grai/core/parser/yaml_parser.py
def load_entities_from_directory(directory: Union[str, Path]) -> List[Entity]:
    """
    Load all entity definitions from a directory.

    Args:
        directory: Path to directory containing entity YAML files.

    Returns:
        List of Entity instances.

    Raises:
        ParserError: If parsing any file fails.
    """
    path = Path(directory)
    if not path.exists():
        raise ParserError(f"Directory not found: {path}")

    yaml_files = discover_yaml_files(path)
    entities = []
    errors = []

    for file_path in yaml_files:
        try:
            entity = parse_entity_file(file_path)
            entities.append(entity)
        except ParserError as e:
            errors.append(str(e))

    if errors:
        error_msg = "\n".join(errors)
        raise ParserError(f"Failed to load entities:\n{error_msg}")

    return entities

load_project(project_root, entities_dir='entities', relations_dir='relations', manifest_file='grai.yml')

Load a complete grai.build project from a directory structure.

Expected structure:

project_root/
├── grai.yml
├── entities/
│   ├── entity1.yml
│   └── entity2.yml
└── relations/
    └── relation1.yml

Parameters:

project_root (Union[str, Path], required): Root directory of the project.
entities_dir (str, default 'entities'): Subdirectory containing entity definitions.
relations_dir (str, default 'relations'): Subdirectory containing relation definitions.
manifest_file (str, default 'grai.yml'): Name of the project manifest file.

Returns:

Project: Project instance with all entities and relations loaded.

Raises:

ParserError: If loading fails.

Source code in grai/core/parser/yaml_parser.py
def load_project(
    project_root: Union[str, Path],
    entities_dir: str = "entities",
    relations_dir: str = "relations",
    manifest_file: str = "grai.yml",
) -> Project:
    """
    Load a complete grai.build project from a directory structure.

    Expected structure:
        project_root/
        ├── grai.yml
        ├── entities/
        │   ├── entity1.yml
        │   └── entity2.yml
        └── relations/
            └── relation1.yml

    Args:
        project_root: Root directory of the project.
        entities_dir: Subdirectory containing entity definitions (default: "entities").
        relations_dir: Subdirectory containing relation definitions (default: "relations").
        manifest_file: Name of the project manifest file (default: "grai.yml").

    Returns:
        Project instance with all entities and relations loaded.

    Raises:
        ParserError: If loading fails.
    """
    root = Path(project_root)

    if not root.exists():
        raise ParserError(f"Project root not found: {root}")

    # Load manifest
    manifest_path = root / manifest_file
    try:
        manifest = load_project_manifest(manifest_path)
    except ParserError as e:
        raise ParserError(f"Failed to load project manifest: {e}")

    # Load entities
    entities_path = root / entities_dir
    entities = []
    if entities_path.exists():
        try:
            entities = load_entities_from_directory(entities_path)
        except ParserError as e:
            raise ParserError(f"Failed to load entities: {e}")

    # Load relations
    relations_path = root / relations_dir
    relations = []
    if relations_path.exists():
        try:
            relations = load_relations_from_directory(relations_path)
        except ParserError as e:
            raise ParserError(f"Failed to load relations: {e}")

    # Create project
    try:
        project = Project(
            name=manifest.get("name", "unnamed-project"),
            version=manifest.get("version", "1.0.0"),
            entities=entities,
            relations=relations,
            config=manifest.get("config", {}),
        )
        return project
    except ValidationError as e:
        raise ValidationParserError(f"Invalid project configuration: {e}")
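
A minimal sketch of loading a project and surfacing parser failures; the project directory here is hypothetical:

from grai.core.parser.yaml_parser import ParserError, load_project

try:
    project = load_project("examples/retail")   # hypothetical project directory
except ParserError as exc:
    print(f"Failed to load project: {exc}")
else:
    print(f"Loaded {len(project.entities)} entities and {len(project.relations)} relations")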

load_project_manifest(file_path='grai.yml')

Load the project manifest (grai.yml).

Parameters:

file_path (Union[str, Path], default 'grai.yml'): Path to the grai.yml file.

Returns:

Dict[str, Any]: Dictionary containing project configuration.

Raises:

ParserError: If the file cannot be loaded.

Source code in grai/core/parser/yaml_parser.py
def load_project_manifest(file_path: Union[str, Path] = "grai.yml") -> Dict[str, Any]:
    """
    Load the project manifest (grai.yml).

    Args:
        file_path: Path to the grai.yml file (default: "grai.yml").

    Returns:
        Dictionary containing project configuration.

    Raises:
        ParserError: If the file cannot be loaded.
    """
    path = Path(file_path)
    return load_yaml_file(path)

load_relations_from_directory(directory)

Load all relation definitions from a directory.

Parameters:

directory (Union[str, Path], required): Path to directory containing relation YAML files.

Returns:

List[Relation]: List of Relation instances.

Raises:

ParserError: If parsing any file fails.

Source code in grai/core/parser/yaml_parser.py
def load_relations_from_directory(directory: Union[str, Path]) -> List[Relation]:
    """
    Load all relation definitions from a directory.

    Args:
        directory: Path to directory containing relation YAML files.

    Returns:
        List of Relation instances.

    Raises:
        ParserError: If parsing any file fails.
    """
    path = Path(directory)
    if not path.exists():
        raise ParserError(f"Directory not found: {path}")

    yaml_files = discover_yaml_files(path)
    relations = []
    errors = []

    for file_path in yaml_files:
        try:
            relation = parse_relation_file(file_path)
            relations.append(relation)
        except ParserError as e:
            errors.append(str(e))

    if errors:
        error_msg = "\n".join(errors)
        raise ParserError(f"Failed to load relations:\n{error_msg}")

    return relations

parse_entity_file(file_path)

Parse an entity definition from a YAML file.

Parameters:

file_path (Union[str, Path], required): Path to the entity YAML file.

Returns:

Entity: Entity instance.

Raises:

ParserError: If parsing fails.

Source code in grai/core/parser/yaml_parser.py
def parse_entity_file(file_path: Union[str, Path]) -> Entity:
    """
    Parse an entity definition from a YAML file.

    Args:
        file_path: Path to the entity YAML file.

    Returns:
        Entity instance.

    Raises:
        ParserError: If parsing fails.
    """
    path = Path(file_path)
    data = load_yaml_file(path)
    return parse_entity(data, path)

parse_relation_file(file_path)

Parse a relation definition from a YAML file.

Parameters:

file_path (Union[str, Path], required): Path to the relation YAML file.

Returns:

Relation: Relation instance.

Raises:

ParserError: If parsing fails.

Source code in grai/core/parser/yaml_parser.py
def parse_relation_file(file_path: Union[str, Path]) -> Relation:
    """
    Parse a relation definition from a YAML file.

    Args:
        file_path: Path to the relation YAML file.

    Returns:
        Relation instance.

    Raises:
        ParserError: If parsing fails.
    """
    path = Path(file_path)
    data = load_yaml_file(path)
    return parse_relation(data, path)

YAML Parser

from grai.core.parser.yaml_parser import (
    load_project,
    parse_entity_file,
    parse_relation_file,
)

# Parse an entity file
entity = parse_entity_file("entities/customer.yml")

# Parse a relation file
relation = parse_relation_file("relations/purchased.yml")

# Load the entire project
project = load_project(".")

Validator

Validator module for checking project consistency and correctness.

EntityReferenceError

Bases: ValidationError

Exception raised when an entity reference is invalid.

Source code in grai/core/validator/validator.py
class EntityReferenceError(ValidationError):
    """Exception raised when an entity reference is invalid."""

    pass

KeyMappingError

Bases: ValidationError

Exception raised when a key mapping is invalid.

Source code in grai/core/validator/validator.py
class KeyMappingError(ValidationError):
    """Exception raised when a key mapping is invalid."""

    pass

ValidationError

Bases: Exception

Base exception for validation errors.

Source code in grai/core/validator/validator.py
class ValidationError(Exception):
    """Base exception for validation errors."""

    def __init__(self, message: str, context: Optional[str] = None):
        """
        Initialize validation error.

        Args:
            message: Error message.
            context: Optional context (e.g., entity or relation name).
        """
        self.context = context
        if context:
            super().__init__(f"{context}: {message}")
        else:
            super().__init__(message)

__init__(message, context=None)

Initialize validation error.

Parameters:

message (str, required): Error message.
context (Optional[str], default None): Optional context (e.g., entity or relation name).

Source code in grai/core/validator/validator.py
def __init__(self, message: str, context: Optional[str] = None):
    """
    Initialize validation error.

    Args:
        message: Error message.
        context: Optional context (e.g., entity or relation name).
    """
    self.context = context
    if context:
        super().__init__(f"{context}: {message}")
    else:
        super().__init__(message)

ValidationResult

Result of a validation operation.

Attributes:

valid (bool): Whether validation passed.
errors (List[str]): List of validation errors.
warnings (List[str]): List of validation warnings.

Source code in grai/core/validator/validator.py
class ValidationResult:
    """
    Result of a validation operation.

    Attributes:
        valid: Whether validation passed.
        errors: List of validation errors.
        warnings: List of validation warnings.
    """

    def __init__(self):
        """Initialize validation result."""
        self.valid: bool = True
        self.errors: List[str] = []
        self.warnings: List[str] = []

    def add_error(self, message: str, context: Optional[str] = None) -> None:
        """
        Add an error to the result.

        Args:
            message: Error message.
            context: Optional context.
        """
        self.valid = False
        if context:
            self.errors.append(f"{context}: {message}")
        else:
            self.errors.append(message)

    def add_warning(self, message: str, context: Optional[str] = None) -> None:
        """
        Add a warning to the result.

        Args:
            message: Warning message.
            context: Optional context.
        """
        if context:
            self.warnings.append(f"{context}: {message}")
        else:
            self.warnings.append(message)

    def __bool__(self) -> bool:
        """Return whether validation passed."""
        return self.valid

    def __str__(self) -> str:
        """Return string representation of validation result."""
        lines = []
        if self.errors:
            lines.append(f"Errors ({len(self.errors)}):")
            for error in self.errors:
                lines.append(f"  • {error}")
        if self.warnings:
            lines.append(f"Warnings ({len(self.warnings)}):")
            for warning in self.warnings:
                lines.append(f"  • {warning}")
        if not self.errors and not self.warnings:
            lines.append("✅ Validation passed with no errors or warnings")
        return "\n".join(lines)

__bool__()

Return whether validation passed.

Source code in grai/core/validator/validator.py
def __bool__(self) -> bool:
    """Return whether validation passed."""
    return self.valid

__init__()

Initialize validation result.

Source code in grai/core/validator/validator.py
def __init__(self):
    """Initialize validation result."""
    self.valid: bool = True
    self.errors: List[str] = []
    self.warnings: List[str] = []

__str__()

Return string representation of validation result.

Source code in grai/core/validator/validator.py
def __str__(self) -> str:
    """Return string representation of validation result."""
    lines = []
    if self.errors:
        lines.append(f"Errors ({len(self.errors)}):")
        for error in self.errors:
            lines.append(f"  • {error}")
    if self.warnings:
        lines.append(f"Warnings ({len(self.warnings)}):")
        for warning in self.warnings:
            lines.append(f"  • {warning}")
    if not self.errors and not self.warnings:
        lines.append("✅ Validation passed with no errors or warnings")
    return "\n".join(lines)

add_error(message, context=None)

Add an error to the result.

Parameters:

message (str, required): Error message.
context (Optional[str], default None): Optional context.

Source code in grai/core/validator/validator.py
def add_error(self, message: str, context: Optional[str] = None) -> None:
    """
    Add an error to the result.

    Args:
        message: Error message.
        context: Optional context.
    """
    self.valid = False
    if context:
        self.errors.append(f"{context}: {message}")
    else:
        self.errors.append(message)

add_warning(message, context=None)

Add a warning to the result.

Parameters:

message (str, required): Warning message.
context (Optional[str], default None): Optional context.

Source code in grai/core/validator/validator.py
def add_warning(self, message: str, context: Optional[str] = None) -> None:
    """
    Add a warning to the result.

    Args:
        message: Warning message.
        context: Optional context.
    """
    if context:
        self.warnings.append(f"{context}: {message}")
    else:
        self.warnings.append(message)
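
ValidationResult can also be built up directly; a quick sketch based on the class listed above:

from grai.core.validator.validator import ValidationResult

result = ValidationResult()
result.add_warning("Key 'sku' has no property definition", context="Entity product")
print(bool(result))   # True - warnings alone do not flip valid

result.add_error("References non-existent entity 'order'", context="Relation PURCHASED")
print(bool(result))   # False
print(result)
# Errors (1):
#   • Relation PURCHASED: References non-existent entity 'order'
# Warnings (1):
#   • Entity product: Key 'sku' has no property definition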

validate_entity(entity)

Validate a single entity.

Parameters:

entity (Entity, required): The entity to validate.

Returns:

ValidationResult: ValidationResult with any errors or warnings.

Source code in grai/core/validator/validator.py
def validate_entity(entity: Entity) -> ValidationResult:
    """
    Validate a single entity.

    Args:
        entity: The entity to validate.

    Returns:
        ValidationResult with any errors or warnings.
    """
    result = ValidationResult()

    # Check for empty keys
    if not entity.keys:
        result.add_error("Entity must have at least one key", context=f"Entity {entity.entity}")

    # Check for duplicate property names
    property_names = [p.name for p in entity.properties]
    duplicates = set([name for name in property_names if property_names.count(name) > 1])
    if duplicates:
        result.add_error(
            f"Duplicate property names: {', '.join(duplicates)}",
            context=f"Entity {entity.entity}",
        )

    # Check that keys have properties
    property_name_set = set(property_names)
    for key in entity.keys:
        if key not in property_name_set:
            result.add_warning(
                f"Key '{key}' does not have a corresponding property definition",
                context=f"Entity {entity.entity}",
            )

    # Check source
    source_name = entity.get_source_name()
    if not source_name or not source_name.strip():
        result.add_error("Entity has empty or missing source", context=f"Entity {entity.entity}")

    return result
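
A small sketch: an entity whose key has no matching property definition produces a warning but still passes validation (per the checks listed above):

from grai.core.models import Entity
from grai.core.validator.validator import validate_entity

entity = Entity(entity="customer", source="analytics.customers", keys=["customer_id"])

result = validate_entity(entity)
print(bool(result))      # True - warnings alone do not fail validation
print(result.warnings)   # ["Entity customer: Key 'customer_id' does not have a corresponding property definition"]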

validate_entity_references(relations, entity_index, result=None)

Validate that all entity references in relations exist.

Parameters:

relations (List[Relation], required): List of relations to validate.
entity_index (Dict[str, Entity], required): Index of entities by name.
result (Optional[ValidationResult], default None): Optional existing ValidationResult to add to.

Returns:

ValidationResult: ValidationResult with any errors found.

Source code in grai/core/validator/validator.py
def validate_entity_references(
    relations: List[Relation],
    entity_index: Dict[str, Entity],
    result: Optional[ValidationResult] = None,
) -> ValidationResult:
    """
    Validate that all entity references in relations exist.

    Args:
        relations: List of relations to validate.
        entity_index: Index of entities by name.
        result: Optional existing ValidationResult to add to.

    Returns:
        ValidationResult with any errors found.
    """
    if result is None:
        result = ValidationResult()

    for relation in relations:
        # Check from_entity exists
        if relation.from_entity not in entity_index:
            result.add_error(
                f"References non-existent entity '{relation.from_entity}'",
                context=f"Relation {relation.relation}",
            )

        # Check to_entity exists
        if relation.to_entity not in entity_index:
            result.add_error(
                f"References non-existent entity '{relation.to_entity}'",
                context=f"Relation {relation.relation}",
            )

    return result

validate_key_mappings(relations, entity_index, result=None)

Validate that key mappings in relations reference valid entity keys.

Parameters:

relations (List[Relation], required): List of relations to validate.
entity_index (Dict[str, Entity], required): Index of entities by name.
result (Optional[ValidationResult], default None): Optional existing ValidationResult to add to.

Returns:

ValidationResult: ValidationResult with any errors found.

Source code in grai/core/validator/validator.py
def validate_key_mappings(
    relations: List[Relation],
    entity_index: Dict[str, Entity],
    result: Optional[ValidationResult] = None,
) -> ValidationResult:
    """
    Validate that key mappings in relations reference valid entity keys.

    Args:
        relations: List of relations to validate.
        entity_index: Index of entities by name.
        result: Optional existing ValidationResult to add to.

    Returns:
        ValidationResult with any errors found.
    """
    if result is None:
        result = ValidationResult()

    for relation in relations:
        # Skip if entities don't exist (caught by entity_references validation)
        if relation.from_entity not in entity_index or relation.to_entity not in entity_index:
            continue

        from_entity = entity_index[relation.from_entity]
        to_entity = entity_index[relation.to_entity]

        # Check from_key exists in from_entity
        if relation.mappings.from_key not in from_entity.keys:
            result.add_error(
                f"Key '{relation.mappings.from_key}' not found in entity '{relation.from_entity}' keys: {from_entity.keys}",
                context=f"Relation {relation.relation}",
            )

        # Check to_key exists in to_entity
        if relation.mappings.to_key not in to_entity.keys:
            result.add_error(
                f"Key '{relation.mappings.to_key}' not found in entity '{relation.to_entity}' keys: {to_entity.keys}",
                context=f"Relation {relation.relation}",
            )

    return result

validate_project(project, strict=True)

Validate an entire project for consistency and correctness.

Parameters:

project (Project, required): The project to validate.
strict (bool, default True): If True, warnings will be treated as errors.

Returns:

ValidationResult: ValidationResult with all errors and warnings.

Raises:

ValidationError: If strict=True and validation fails.

Source code in grai/core/validator/validator.py
def validate_project(
    project: Project,
    strict: bool = True,
) -> ValidationResult:
    """
    Validate an entire project for consistency and correctness.

    Args:
        project: The project to validate.
        strict: If True, warnings will be treated as errors.

    Returns:
        ValidationResult with all errors and warnings.

    Raises:
        ValidationError: If strict=True and validation fails.
    """
    result = ValidationResult()

    # Build entity index for quick lookups
    entity_index = build_entity_index(project.entities)

    # Run all validations
    validate_unique_names(project.entities, project.relations, result)
    validate_sources(project.entities, project.relations, result)
    validate_entity_references(project.relations, entity_index, result)
    validate_key_mappings(project.relations, entity_index, result)
    validate_property_definitions(project.entities, project.relations, result)

    # Note: We don't check for circular dependencies because they are
    # normal and expected in graph structures (e.g., bidirectional relationships,
    # social networks, organizational hierarchies, etc.)

    # In strict mode, treat warnings as errors
    if strict and result.warnings:
        for warning in result.warnings:
            result.add_error(f"[Strict mode] {warning}")
        result.warnings.clear()

    return result

validate_relation(relation, entity_index=None)

Validate a single relation.

Parameters:

relation (Relation, required): The relation to validate.
entity_index (Optional[Dict[str, Entity]], default None): Optional index of entities for reference checking.

Returns:

ValidationResult: ValidationResult with any errors or warnings.

Source code in grai/core/validator/validator.py
def validate_relation(
    relation: Relation,
    entity_index: Optional[Dict[str, Entity]] = None,
) -> ValidationResult:
    """
    Validate a single relation.

    Args:
        relation: The relation to validate.
        entity_index: Optional index of entities for reference checking.

    Returns:
        ValidationResult with any errors or warnings.
    """
    result = ValidationResult()

    # Check for duplicate property names
    property_names = [p.name for p in relation.properties]
    duplicates = set([name for name in property_names if property_names.count(name) > 1])
    if duplicates:
        result.add_error(
            f"Duplicate property names: {', '.join(duplicates)}",
            context=f"Relation {relation.relation}",
        )

    # Check source
    source_name = relation.get_source_name()
    if not source_name or not source_name.strip():
        result.add_error(
            "Relation has empty or missing source",
            context=f"Relation {relation.relation}",
        )

    # If entity_index provided, check references
    if entity_index is not None:
        if relation.from_entity not in entity_index:
            result.add_error(
                f"References non-existent entity '{relation.from_entity}'",
                context=f"Relation {relation.relation}",
            )

        if relation.to_entity not in entity_index:
            result.add_error(
                f"References non-existent entity '{relation.to_entity}'",
                context=f"Relation {relation.relation}",
            )

        # Check key mappings if entities exist
        if relation.from_entity in entity_index and relation.to_entity in entity_index:
            from_entity = entity_index[relation.from_entity]
            to_entity = entity_index[relation.to_entity]

            if relation.mappings.from_key not in from_entity.keys:
                result.add_error(
                    f"Key '{relation.mappings.from_key}' not found in entity '{relation.from_entity}' keys: {from_entity.keys}",
                    context=f"Relation {relation.relation}",
                )

            if relation.mappings.to_key not in to_entity.keys:
                result.add_error(
                    f"Key '{relation.mappings.to_key}' not found in entity '{relation.to_entity}' keys: {to_entity.keys}",
                    context=f"Relation {relation.relation}",
                )

    return result
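
A sketch of reference checking with an ad-hoc entity index (a plain dict keyed by entity name, mirroring what validate_project builds internally); entity and relation names are illustrative:

from grai.core.models import Entity, Relation, RelationMapping
from grai.core.validator.validator import validate_relation

entities = [
    Entity(entity="customer", source="analytics.customers", keys=["customer_id"]),
    Entity(entity="product", source="analytics.products", keys=["product_id"]),
]
entity_index = {e.entity: e for e in entities}

relation = Relation(
    relation="PURCHASED",
    from_entity="customer",
    to_entity="product",
    source="analytics.orders",
    mappings=RelationMapping(from_key="customer_id", to_key="product_id"),
)

result = validate_relation(relation, entity_index=entity_index)
print(bool(result))   # True - both entities exist and the mapped keys are valid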

Schema Validator

from grai.core.validator.validator import validate_project
from grai.core.models import Project

project = Project(...)

# Validate project
result = validate_project(project)

if result.valid:
    print("✅ Validation passed")
else:
    for error in result.errors:
        print(f"❌ {error}")

Compiler

Compiler module for generating database queries from models.

CompilerError

Bases: Exception

Base exception for compiler errors.

Source code in grai/core/compiler/cypher_compiler.py
class CompilerError(Exception):
    """Base exception for compiler errors."""

    pass

compile_and_write(project, output_dir='target/neo4j', filename='compiled.cypher', include_header=True, include_constraints=True)

Compile a project and write the Cypher script to a file.

Parameters:

project (Project, required): Project to compile.
output_dir (Union[str, Path], default 'target/neo4j'): Directory to write the output file.
filename (str, default 'compiled.cypher'): Name of the output file.
include_header (bool, default True): If True, include script header.
include_constraints (bool, default True): If True, include constraint statements.

Returns:

Path: Path to the written file.

Raises:

CompilerError: If compilation or writing fails.

Source code in grai/core/compiler/cypher_compiler.py
def compile_and_write(
    project: Project,
    output_dir: Union[str, Path] = "target/neo4j",
    filename: str = "compiled.cypher",
    include_header: bool = True,
    include_constraints: bool = True,
) -> Path:
    """
    Compile a project and write the Cypher script to a file.

    Args:
        project: Project to compile.
        output_dir: Directory to write the output file.
        filename: Name of the output file.
        include_header: If True, include script header.
        include_constraints: If True, include constraint statements.

    Returns:
        Path to the written file.

    Raises:
        CompilerError: If compilation or writing fails.
    """
    # Compile the project
    cypher = compile_project(
        project,
        include_header=include_header,
        include_constraints=include_constraints,
    )

    # Write to file
    output_path = Path(output_dir) / filename
    return write_cypher_file(cypher, output_path)
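
A minimal end-to-end sketch using the default output directory and filename; the project path is hypothetical:

from grai.core.compiler.cypher_compiler import compile_and_write
from grai.core.parser.yaml_parser import load_project

project = load_project("examples/retail")   # hypothetical project directory
output_path = compile_and_write(project)    # writes target/neo4j/compiled.cypher
print(f"Wrote Cypher script to {output_path}")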

compile_entity(entity)

Compile an entity into a Cypher MERGE statement.

Parameters:

entity (Entity, required): Entity model to compile.

Returns:

str: Cypher MERGE statement for creating/updating nodes.

Example
// Create customer nodes
MERGE (n:customer {customer_id: row.customer_id})
SET n.name = row.name,
    n.email = row.email;
Source code in grai/core/compiler/cypher_compiler.py
def compile_entity(entity: Entity) -> str:
    """
    Compile an entity into a Cypher MERGE statement.

    Args:
        entity: Entity model to compile.

    Returns:
        Cypher MERGE statement for creating/updating nodes.

    Example:
        ```cypher
        // Create customer nodes
        MERGE (n:customer {customer_id: row.customer_id})
        SET n.name = row.name,
            n.email = row.email;
        ```
    """
    # Build the MERGE clause with key properties
    key_conditions = []
    for key in entity.keys:
        placeholder = get_cypher_property_placeholder(key)
        key_conditions.append(f"{key}: {placeholder}")

    merge_clause = f"MERGE (n:{entity.entity} {{{', '.join(key_conditions)}}})"

    # Build the SET clause for non-key properties
    non_key_properties = [p for p in entity.properties if p.name not in entity.keys]

    if non_key_properties:
        set_clause = compile_property_set(non_key_properties)
        cypher = f"{merge_clause}\n{set_clause};"
    else:
        cypher = f"{merge_clause};"

    # Add comment header
    header = f"// Create {entity.entity} nodes"
    return f"{header}\n{cypher}"

compile_project(project, include_header=True, include_constraints=True)

Compile a complete project into a Cypher script.

Parameters:

project (Project, required): Project model to compile.
include_header (bool, default True): If True, include script header with project info.
include_constraints (bool, default True): If True, include constraint creation statements.

Returns:

str: Complete Cypher script as a string.

Source code in grai/core/compiler/cypher_compiler.py
def compile_project(
    project: Project,
    include_header: bool = True,
    include_constraints: bool = True,
) -> str:
    """
    Compile a complete project into a Cypher script.

    Args:
        project: Project model to compile.
        include_header: If True, include script header with project info.
        include_constraints: If True, include constraint creation statements.

    Returns:
        Complete Cypher script as a string.
    """
    lines = []

    # Add header
    if include_header:
        lines.append(f"// Generated Cypher script for project: {project.name}")
        lines.append(f"// Version: {project.version}")
        lines.append("// Generated by grai.build")
        lines.append("")

    # Add constraints (unique constraints on entity keys)
    if include_constraints and project.entities:
        lines.append(
            "// ============================================================================="
        )
        lines.append("// CONSTRAINTS")
        lines.append(
            "// ============================================================================="
        )
        lines.append("")

        for entity in project.entities:
            for key in entity.keys:
                constraint_name = f"constraint_{entity.entity}_{key}"
                constraint = (
                    f"CREATE CONSTRAINT {constraint_name} IF NOT EXISTS "
                    f"FOR (n:{entity.entity}) REQUIRE n.{key} IS UNIQUE;"
                )
                lines.append(constraint)

        lines.append("")

    # Add entities
    if project.entities:
        lines.append(
            "// ============================================================================="
        )
        lines.append("// ENTITIES (NODES)")
        lines.append(
            "// ============================================================================="
        )
        lines.append("")

        for entity in project.entities:
            lines.append(compile_entity(entity))
            lines.append("")

    # Add relations
    if project.relations:
        lines.append(
            "// ============================================================================="
        )
        lines.append("// RELATIONS (EDGES)")
        lines.append(
            "// ============================================================================="
        )
        lines.append("")

        for relation in project.relations:
            lines.append(compile_relation(relation))
            lines.append("")

    return "\n".join(lines).rstrip() + "\n"
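
A minimal sketch of a full-project compile. The Project constructor fields shown (name, version, entities, relations) are assumed from the attributes compile_project reads above; adjust to the actual model:

from grai.core.compiler.cypher_compiler import compile_project
from grai.core.models import Entity, Project

# Hypothetical minimal project; Project field names are assumed here.
project = Project(
    name="retail_graph",
    version="0.1.0",
    entities=[
        Entity(entity="customer", source="raw.customers", keys=["customer_id"]),
    ],
    relations=[],
)

script = compile_project(project, include_header=True, include_constraints=True)
print(script)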

compile_relation(relation)

Compile a relation into Cypher MATCH...MERGE statements.

Parameters:

Name Type Description Default
relation Relation

Relation model to compile.

required

Returns:

Type Description
str

Cypher statements for creating relationships.

Example
// Create PURCHASED relationships
MATCH (from:customer {customer_id: row.customer_id})
MATCH (to:product {product_id: row.product_id})
MERGE (from)-[r:PURCHASED]->(to)
SET r.order_id = row.order_id,
    r.order_date = row.order_date;
Source code in grai/core/compiler/cypher_compiler.py
def compile_relation(relation: Relation) -> str:
    """
    Compile a relation into Cypher MATCH...MERGE statements.

    Args:
        relation: Relation model to compile.

    Returns:
        Cypher statements for creating relationships.

    Example:
        ```cypher
        // Create PURCHASED relationships
        MATCH (from:customer {customer_id: row.customer_id})
        MATCH (to:product {product_id: row.product_id})
        MERGE (from)-[r:PURCHASED]->(to)
        SET r.order_id = row.order_id,
            r.order_date = row.order_date;
        ```
    """
    # Build MATCH clause for source node
    from_key = relation.mappings.from_key
    from_placeholder = get_cypher_property_placeholder(from_key)
    match_from = f"MATCH (from:{relation.from_entity} {{{from_key}: {from_placeholder}}})"

    # Build MATCH clause for target node
    to_key = relation.mappings.to_key
    to_placeholder = get_cypher_property_placeholder(to_key)
    match_to = f"MATCH (to:{relation.to_entity} {{{to_key}: {to_placeholder}}})"

    # Build MERGE clause for relationship
    merge_rel = f"MERGE (from)-[r:{relation.relation}]->(to)"

    # Build SET clause for relationship properties
    if relation.properties:
        set_clause = compile_property_set(relation.properties, node_var="r")
        cypher = f"{match_from}\n{match_to}\n{merge_rel}\n{set_clause};"
    else:
        cypher = f"{match_from}\n{match_to}\n{merge_rel};"

    # Add comment header
    header = f"// Create {relation.relation} relationships"
    return f"{header}\n{cypher}"

compile_schema_only(project)

Compile only the schema (constraints and indexes) without data loading.

Parameters:

Name Type Description Default
project Project

Project to compile schema for.

required

Returns:

Type Description
str

Cypher script with only schema definitions.

Source code in grai/core/compiler/cypher_compiler.py
def compile_schema_only(project: Project) -> str:
    """
    Compile only the schema (constraints and indexes) without data loading.

    Args:
        project: Project to compile schema for.

    Returns:
        Cypher script with only schema definitions.
    """
    lines = [
        f"// Schema definition for project: {project.name}",
        f"// Version: {project.version}",
        "",
        "// =============================================================================",
        "// CONSTRAINTS",
        "// =============================================================================",
        "",
    ]

    for entity in project.entities:
        for key in entity.keys:
            constraint_name = f"constraint_{entity.entity}_{key}"
            constraint = (
                f"CREATE CONSTRAINT {constraint_name} IF NOT EXISTS "
                f"FOR (n:{entity.entity}) REQUIRE n.{key} IS UNIQUE;"
            )
            lines.append(constraint)

    lines.append("")
    lines.append("// =============================================================================")
    lines.append("// INDEXES")
    lines.append("// =============================================================================")
    lines.append("")

    # Create indexes on non-key properties that might be used in queries
    for entity in project.entities:
        for prop in entity.properties:
            if prop.name not in entity.keys:
                index_name = f"index_{entity.entity}_{prop.name}"
                index = (
                    f"CREATE INDEX {index_name} IF NOT EXISTS "
                    f"FOR (n:{entity.entity}) ON (n.{prop.name});"
                )
                lines.append(index)

    return "\n".join(lines) + "\n"

generate_load_csv_statements(project, data_dir='data')

Generate LOAD CSV statements for entities and relations.

Parameters:

Name Type Description Default
project Project

Project to generate load statements for.

required
data_dir str

Directory containing CSV files.

'data'

Returns:

Type Description
Dict[str, str]

Dictionary mapping entity/relation names to LOAD CSV statements.

Source code in grai/core/compiler/cypher_compiler.py
def generate_load_csv_statements(
    project: Project,
    data_dir: str = "data",
) -> Dict[str, str]:
    """
    Generate LOAD CSV statements for entities and relations.

    Args:
        project: Project to generate load statements for.
        data_dir: Directory containing CSV files.

    Returns:
        Dictionary mapping entity/relation names to LOAD CSV statements.
    """
    statements = {}

    # Generate entity load statements
    for entity in project.entities:
        source_name = entity.get_source_name()
        csv_file = f"{data_dir}/{source_name.replace('.', '_')}.csv"

        # Build LOAD CSV statement
        merge_keys = {key: f"row.{key}" for key in entity.keys}
        key_clause = ", ".join([f"{k}: {v}" for k, v in merge_keys.items()])

        lines = [
            f"// Load {entity.entity} from CSV",
            f"LOAD CSV WITH HEADERS FROM 'file:///{csv_file}' AS row",
            f"MERGE (n:{entity.entity} {{{key_clause}}})",
        ]

        # Add SET clause for other properties
        non_key_props = [p for p in entity.properties if p.name not in entity.keys]
        if non_key_props:
            set_clauses = [f"n.{p.name} = row.{p.name}" for p in non_key_props]
            lines.append("SET " + ",\n    ".join(set_clauses))

        lines.append(";")
        statements[entity.entity] = "\n".join(lines)

    # Generate relation load statements
    for relation in project.relations:
        source_name = relation.get_source_name()
        csv_file = f"{data_dir}/{source_name.replace('.', '_')}.csv"

        lines = [
            f"// Load {relation.relation} from CSV",
            f"LOAD CSV WITH HEADERS FROM 'file:///{csv_file}' AS row",
            f"MATCH (from:{relation.from_entity} {{{relation.mappings.from_key}: row.{relation.mappings.from_key}}})",
            f"MATCH (to:{relation.to_entity} {{{relation.mappings.to_key}: row.{relation.mappings.to_key}}})",
            f"MERGE (from)-[r:{relation.relation}]->(to)",
        ]

        # Add SET clause for relationship properties
        if relation.properties:
            set_clauses = [f"r.{p.name} = row.{p.name}" for p in relation.properties]
            lines.append("SET " + ",\n    ".join(set_clauses))

        lines.append(";")
        statements[relation.relation] = "\n".join(lines)

    return statements
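
A sketch of consuming the returned dictionary (assumes `project` is a Project model as in the compile_project sketch above):

from grai.core.compiler.cypher_compiler import generate_load_csv_statements

statements = generate_load_csv_statements(project, data_dir="data")

for name, load_csv in statements.items():
    print(f"-- {name} --")
    print(load_csv)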

write_cypher_file(cypher, output_path, create_dirs=True)

Write Cypher script to a file.

Parameters:

Name Type Description Default
cypher str

Cypher script content.

required
output_path Union[str, Path]

Path to write the file.

required
create_dirs bool

If True, create parent directories if they don't exist.

True

Returns:

Type Description
Path

Path to the written file.

Raises:

Type Description
CompilerError

If file cannot be written.

Source code in grai/core/compiler/cypher_compiler.py
def write_cypher_file(
    cypher: str,
    output_path: Union[str, Path],
    create_dirs: bool = True,
) -> Path:
    """
    Write Cypher script to a file.

    Args:
        cypher: Cypher script content.
        output_path: Path to write the file.
        create_dirs: If True, create parent directories if they don't exist.

    Returns:
        Path to the written file.

    Raises:
        CompilerError: If file cannot be written.
    """
    path = Path(output_path)

    try:
        if create_dirs:
            path.parent.mkdir(parents=True, exist_ok=True)

        with open(path, "w", encoding="utf-8") as f:
            f.write(cypher)

        return path

    except Exception as e:
        raise CompilerError(f"Failed to write Cypher file to {path}: {e}")
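
A sketch combining compile_schema_only and write_cypher_file to persist just the constraints and indexes (the output path is illustrative; `project` is a Project model as in the compile_project sketch above):

from grai.core.compiler.cypher_compiler import compile_schema_only, write_cypher_file

schema_cypher = compile_schema_only(project)
output_path = write_cypher_file(schema_cypher, "target/neo4j/schema.cypher")
print(f"Schema written to {output_path}")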

Cypher Compiler

from grai.core.compiler.cypher_compiler import (
    compile_entity,
    compile_project,
    compile_relation,
)

# Compile a single entity (an Entity model)
entity_cypher = compile_entity(entity)

# Compile a single relation (a Relation model)
relation_cypher = compile_relation(relation)

# Compile the entire project
project_cypher = compile_project(project)

Loader

Loader module for executing Cypher against Neo4j and loading data from warehouses.

Neo4jConnection dataclass

Neo4j connection configuration.

Attributes:

Name Type Description
uri str

Neo4j connection URI (e.g., bolt://localhost:7687)

user str

Username for authentication

password str

Password for authentication

database str

Database name (default: neo4j)

encrypted bool

Whether to use encrypted connection

max_retry_time int

Maximum time to retry connection (seconds)

Source code in grai/core/loader/neo4j_loader.py
@dataclass
class Neo4jConnection:
    """
    Neo4j connection configuration.

    Attributes:
        uri: Neo4j connection URI (e.g., bolt://localhost:7687)
        user: Username for authentication
        password: Password for authentication
        database: Database name (default: neo4j)
        encrypted: Whether to use encrypted connection
        max_retry_time: Maximum time to retry connection (seconds)
    """

    uri: str
    user: str
    password: str
    database: str = "neo4j"
    encrypted: bool = False
    max_retry_time: int = 30

check_apoc_available(driver, database='neo4j')

Check if APOC library is installed and available in Neo4j.

Parameters:

Name Type Description Default
driver Driver

Neo4j driver instance.

required
database str

Database name to check.

'neo4j'

Returns:

Type Description
bool

True if APOC is available, False otherwise.

Example
driver = connect_neo4j(...)
if check_apoc_available(driver):
    print("APOC is available - can use advanced features!")
else:
    print("APOC not installed - using standard Cypher")
Source code in grai/core/loader/neo4j_loader.py
def check_apoc_available(driver: Driver, database: str = "neo4j") -> bool:
    """
    Check if APOC library is installed and available in Neo4j.

    Args:
        driver: Neo4j driver instance.
        database: Database name to check.

    Returns:
        True if APOC is available, False otherwise.

    Example:
        ```python
        driver = connect_neo4j(...)
        if check_apoc_available(driver):
            print("APOC is available - can use advanced features!")
        else:
            print("APOC not installed - using standard Cypher")
        ```
    """
    check_neo4j_available()

    try:
        with driver.session(database=database) as session:
            # Try to call a basic APOC function
            result = session.run("RETURN apoc.version() AS version")
            record = result.single()

            # The record should have a 'version' key if APOC is available
            # Check if record exists and has the version field
            if record:
                version_value = record.get("version")
                # APOC returns a string version like "5.12.0"
                return version_value is not None and len(str(version_value)) > 0

    except Exception:
        # APOC not available or an error occurred while checking
        pass

    return False

check_indexes_exist(driver, label, properties, database='neo4j')

Check if indexes exist for specified label and properties.

Parameters:

Name Type Description Default
driver Driver

Neo4j driver instance.

required
label str

Node label to check.

required
properties List[str]

List of property names to check for indexes.

required
database str

Database name to check.

'neo4j'

Returns:

Type Description
Dict[str, bool]

Dictionary mapping property names to whether an index exists.

Example
driver = connect_neo4j(...)
results = check_indexes_exist(driver, "Customer", ["customer_id", "email"])
# Returns: {"customer_id": True, "email": False}
Source code in grai/core/loader/neo4j_loader.py
def check_indexes_exist(
    driver: Driver, label: str, properties: List[str], database: str = "neo4j"
) -> Dict[str, bool]:
    """
    Check if indexes exist for specified label and properties.

    Args:
        driver: Neo4j driver instance.
        label: Node label to check.
        properties: List of property names to check for indexes.
        database: Database name to check.

    Returns:
        Dictionary mapping property names to whether an index exists.

    Example:
        ```python
        driver = connect_neo4j(...)
        results = check_indexes_exist(driver, "Customer", ["customer_id", "email"])
        # Returns: {"customer_id": True, "email": False}
        ```
    """
    check_neo4j_available()

    index_status = {prop: False for prop in properties}

    try:
        with driver.session(database=database) as session:
            # Query to show all indexes
            result = session.run("SHOW INDEXES")
            indexes = [record for record in result]

            # Check each property
            for prop in properties:
                for index in indexes:
                    # Check if this index is for our label and property
                    # Index structure varies by Neo4j version, check multiple fields
                    labels_or_entities = index.get("labelsOrTypes", []) or index.get("entityType")
                    props = index.get("properties", [])

                    # Handle both list and string formats
                    if isinstance(labels_or_entities, str):
                        labels_or_entities = [labels_or_entities]
                    if isinstance(props, str):
                        props = [props]

                    # Check if label matches and property is in the index
                    if label in labels_or_entities and prop in props:
                        index_status[prop] = True
                        break

    except Exception:
        # If we can't check indexes (old Neo4j version, permissions, etc.),
        # assume they don't exist to be safe
        pass

    return index_status

close_connection(driver)

Close Neo4j driver connection.

Parameters:

Name Type Description Default
driver Driver

Neo4j driver instance to close.

required
Example
driver = connect_neo4j(...)
# ... use driver ...
close_connection(driver)
Source code in grai/core/loader/neo4j_loader.py
def close_connection(driver: Driver) -> None:
    """
    Close Neo4j driver connection.

    Args:
        driver: Neo4j driver instance to close.

    Example:
        ```python
        driver = connect_neo4j(...)
        # ... use driver ...
        close_connection(driver)
        ```
    """
    if driver:
        driver.close()

connect_neo4j(connection_or_uri=None, user=None, password=None, database='neo4j', encrypted=False, max_retry_time=30, uri=None)

Connect to Neo4j database.

Parameters:

Name Type Description Default
connection_or_uri Optional[Union[Neo4jConnection, str]]

Either a Neo4jConnection object or a URI string

None
user Optional[str]

Username for authentication (required if connection_or_uri is a string)

None
password Optional[str]

Password for authentication (required if connection_or_uri is a string)

None
database str

Database name (default: neo4j)

'neo4j'
encrypted bool

Whether to use encrypted connection

False
max_retry_time int

Maximum time to retry connection (seconds)

30
uri Optional[str]

(Legacy) URI string - use connection_or_uri instead

None

Returns:

Type Description
Driver

Neo4j driver instance.

Raises:

Type Description
ImportError

If neo4j driver is not installed.

ServiceUnavailable

If cannot connect to Neo4j.

AuthError

If authentication fails.

Example
# Option 1: Using Neo4jConnection object
conn = Neo4jConnection(uri="bolt://localhost:7687", user="neo4j", password="password")
driver = connect_neo4j(conn)

# Option 2: Using individual parameters
driver = connect_neo4j(
    uri="bolt://localhost:7687",
    user="neo4j",
    password="password"
)

# Option 3: Positional URI
driver = connect_neo4j(
    "bolt://localhost:7687",
    user="neo4j",
    password="password"
)
Source code in grai/core/loader/neo4j_loader.py
def connect_neo4j(
    connection_or_uri: Optional[Union[Neo4jConnection, str]] = None,
    user: Optional[str] = None,
    password: Optional[str] = None,
    database: str = "neo4j",
    encrypted: bool = False,
    max_retry_time: int = 30,
    # Support legacy keyword argument
    uri: Optional[str] = None,
) -> Driver:
    """
    Connect to Neo4j database.

    Args:
        connection_or_uri: Either a Neo4jConnection object or a URI string
        user: Username for authentication (required if connection_or_uri is a string)
        password: Password for authentication (required if connection_or_uri is a string)
        database: Database name (default: neo4j)
        encrypted: Whether to use encrypted connection
        max_retry_time: Maximum time to retry connection (seconds)
        uri: (Legacy) URI string - use connection_or_uri instead

    Returns:
        Neo4j driver instance.

    Raises:
        ImportError: If neo4j driver is not installed.
        ServiceUnavailable: If cannot connect to Neo4j.
        AuthError: If authentication fails.

    Example:
        ```python
        # Option 1: Using Neo4jConnection object
        conn = Neo4jConnection(uri="bolt://localhost:7687", user="neo4j", password="password")
        driver = connect_neo4j(conn)

        # Option 2: Using individual parameters
        driver = connect_neo4j(
            uri="bolt://localhost:7687",
            user="neo4j",
            password="password"
        )

        # Option 3: Positional URI
        driver = connect_neo4j(
            "bolt://localhost:7687",
            user="neo4j",
            password="password"
        )
        ```
    """
    check_neo4j_available()

    # Handle legacy uri keyword argument
    if uri is not None and connection_or_uri is None:
        connection_or_uri = uri

    # connection_or_uri is now required
    if connection_or_uri is None:
        raise ValueError("Must provide either connection_or_uri or uri parameter")

    # Handle both Neo4jConnection object and individual parameters
    if isinstance(connection_or_uri, Neo4jConnection):
        uri = connection_or_uri.uri
        user = connection_or_uri.user
        password = connection_or_uri.password
        database = connection_or_uri.database  # noqa: F841 - database used by caller in sessions
        encrypted = connection_or_uri.encrypted
        max_retry_time = connection_or_uri.max_retry_time
    else:
        uri = connection_or_uri
        if user is None or password is None:
            raise ValueError("user and password are required when using URI string")

    try:
        driver = GraphDatabase.driver(
            uri,
            auth=(user, password),
            encrypted=encrypted,
            max_connection_lifetime=3600,
            max_connection_pool_size=50,
            connection_acquisition_timeout=max_retry_time,
        )

        # Verify connectivity
        driver.verify_connectivity()

        return driver

    except AuthError as e:
        raise AuthError(f"Authentication failed: {e}")
    except ServiceUnavailable as e:
        raise ServiceUnavailable(f"Cannot connect to Neo4j at {uri}: {e}")
    except Exception as e:
        raise RuntimeError(f"Error connecting to Neo4j: {e}")

execute_apoc_periodic_iterate(driver, batch_data, cypher_statement, batch_size=10000, parallel=False, database='neo4j')

Execute APOC periodic.iterate for efficient bulk loading.

This uses APOC's apoc.periodic.iterate, which automatically handles:

- Sub-batch commits (reduces memory usage)
- Parallel processing (optional)
- Better performance for large datasets

Parameters:

Name Type Description Default
driver Driver

Neo4j driver instance.

required
batch_data List[Dict[str, Any]]

List of dictionaries to load.

required
cypher_statement str

Cypher statement to execute for each row. Use 'row' to reference each item.

required
batch_size int

Number of rows to process per sub-batch.

10000
parallel bool

Whether to use parallel processing.

False
database str

Database name.

'neo4j'

Returns:

Type Description
ExecutionResult

ExecutionResult with execution details.

Example
data = [
    {"customer_id": "C001", "name": "Alice"},
    {"customer_id": "C002", "name": "Bob"}
]

cypher = '''
MERGE (c:Customer {customer_id: row.customer_id})
SET c.name = row.name
'''

result = execute_apoc_periodic_iterate(
    driver, data, cypher, batch_size=10000
)
print(f"Loaded {result.records_affected} rows")
Source code in grai/core/loader/neo4j_loader.py
def execute_apoc_periodic_iterate(
    driver: Driver,
    batch_data: List[Dict[str, Any]],
    cypher_statement: str,
    batch_size: int = 10000,
    parallel: bool = False,
    database: str = "neo4j",
) -> ExecutionResult:
    """
    Execute APOC periodic.iterate for efficient bulk loading.

    This uses APOC's apoc.periodic.iterate which automatically handles:
    - Sub-batch commits (reduces memory usage)
    - Parallel processing (optional)
    - Better performance for large datasets

    Args:
        driver: Neo4j driver instance.
        batch_data: List of dictionaries to load.
        cypher_statement: Cypher statement to execute for each row.
                         Use 'row' to reference each item.
        batch_size: Number of rows to process per sub-batch.
        parallel: Whether to use parallel processing.
        database: Database name.

    Returns:
        ExecutionResult with execution details.

    Example:
        ```python
        data = [
            {"customer_id": "C001", "name": "Alice"},
            {"customer_id": "C002", "name": "Bob"}
        ]

        cypher = '''
        MERGE (c:Customer {customer_id: row.customer_id})
        SET c.name = row.name
        '''

        result = execute_apoc_periodic_iterate(
            driver, data, cypher, batch_size=10000
        )
        print(f"Loaded {result.records_affected} rows")
        ```
    """
    check_neo4j_available()

    if not batch_data:
        return ExecutionResult(
            success=True,
            statements_executed=0,
            records_affected=0,
            errors=[],
            execution_time=0.0,
        )

    result = ExecutionResult(
        success=False,
        statements_executed=0,
        records_affected=0,
        errors=[],
        execution_time=0.0,
    )

    start_time = time.time()

    try:
        with driver.session(database=database) as session:
            # Use APOC periodic.iterate
            apoc_query = """
            CALL apoc.periodic.iterate(
                'UNWIND $batch AS row RETURN row',
                $statement,
                {batchSize: $batchSize, parallel: $parallel, params: {batch: $batch}}
            )
            YIELD batches, total, errorMessages
            RETURN batches, total, errorMessages
            """

            apoc_result = session.run(
                apoc_query,
                batch=batch_data,
                statement=cypher_statement,
                batchSize=batch_size,
                parallel=parallel,
            )

            record = apoc_result.single()
            if record:
                result.statements_executed = record.get("batches", 0)
                result.records_affected = record.get("total", 0)

                error_msgs = record.get("errorMessages", [])
                if error_msgs:
                    result.errors = error_msgs
                    result.success = False
                else:
                    result.success = True

                # Track node/relationship changes
                counters = apoc_result.consume().counters
                result.nodes_created = counters.nodes_created
                result.relationships_created = counters.relationships_created
                result.properties_set = counters.properties_set

    except Exception as e:
        result.errors.append(str(e))
        result.success = False

    result.execution_time = time.time() - start_time

    return result

execute_cypher(driver, cypher, parameters=None, database='neo4j')

Execute Cypher statement(s) against Neo4j.

Parameters:

Name Type Description Default
driver Driver

Neo4j driver instance.

required
cypher str

Cypher statement(s) to execute.

required
parameters Optional[Dict[str, Any]]

Optional parameters for the query.

None
database str

Database name to execute against.

'neo4j'

Returns:

Type Description
ExecutionResult

ExecutionResult with execution details.

Example
driver = connect_neo4j(...)
result = execute_cypher(
    driver,
    "CREATE (n:Person {name: $name}) RETURN n",
    parameters={"name": "Alice"}
)
print(f"Success: {result.success}")
print(f"Statements executed: {result.statements_executed}")
Source code in grai/core/loader/neo4j_loader.py
def execute_cypher(
    driver: Driver,
    cypher: str,
    parameters: Optional[Dict[str, Any]] = None,
    database: str = "neo4j",
) -> ExecutionResult:
    """
    Execute Cypher statement(s) against Neo4j.

    Args:
        driver: Neo4j driver instance.
        cypher: Cypher statement(s) to execute.
        parameters: Optional parameters for the query.
        database: Database name to execute against.

    Returns:
        ExecutionResult with execution details.

    Example:
        ```python
        driver = connect_neo4j(...)
        result = execute_cypher(
            driver,
            "CREATE (n:Person {name: $name}) RETURN n",
            parameters={"name": "Alice"}
        )
        print(f"Success: {result.success}")
        print(f"Statements executed: {result.statements_executed}")
        ```
    """
    check_neo4j_available()

    start_time = time.time()
    result = ExecutionResult(success=False)

    try:
        # Split into individual statements
        statements = split_cypher_statements(cypher)

        with driver.session(database=database) as session:
            for statement in statements:
                try:
                    # Execute statement
                    query_result = session.run(statement, parameters or {})

                    # Consume results to ensure execution
                    summary = query_result.consume()

                    # Track counters
                    counters = summary.counters
                    result.nodes_created += counters.nodes_created
                    result.nodes_deleted += counters.nodes_deleted
                    result.relationships_created += counters.relationships_created
                    result.relationships_deleted += counters.relationships_deleted
                    result.properties_set += counters.properties_set
                    result.records_affected += (
                        counters.nodes_created
                        + counters.nodes_deleted
                        + counters.relationships_created
                        + counters.relationships_deleted
                        + counters.properties_set
                    )

                    result.statements_executed += 1

                except Neo4jError as e:
                    result.errors.append(f"Error executing statement: {e}")
                    result.success = False
                    return result

            # All statements executed successfully
            result.success = True

    except Exception as e:
        result.errors.append(f"Execution error: {e}")
        result.success = False

    finally:
        result.execution_time = time.time() - start_time

    return result

execute_cypher_file(driver, file_path, database='neo4j', batch_size=None)

Execute Cypher statements from a file.

Parameters:

Name Type Description Default
driver Driver

Neo4j driver instance.

required
file_path Union[str, Path]

Path to Cypher file.

required
database str

Database name to execute against.

'neo4j'
batch_size Optional[int]

Optional batch size for large files.

None

Returns:

Type Description
ExecutionResult

ExecutionResult with execution details.

Raises:

Type Description
FileNotFoundError

If file does not exist.

Example
driver = connect_neo4j(...)
result = execute_cypher_file(
    driver,
    "target/neo4j/compiled.cypher"
)
print(f"Executed {result.statements_executed} statements")
print(f"Affected {result.records_affected} records")
Source code in grai/core/loader/neo4j_loader.py
def execute_cypher_file(
    driver: Driver,
    file_path: Union[str, Path],
    database: str = "neo4j",
    batch_size: Optional[int] = None,
) -> ExecutionResult:
    """
    Execute Cypher statements from a file.

    Args:
        driver: Neo4j driver instance.
        file_path: Path to Cypher file.
        database: Database name to execute against.
        batch_size: Optional batch size for large files.

    Returns:
        ExecutionResult with execution details.

    Raises:
        FileNotFoundError: If file does not exist.

    Example:
        ```python
        driver = connect_neo4j(...)
        result = execute_cypher_file(
            driver,
            "target/neo4j/compiled.cypher"
        )
        print(f"Executed {result.statements_executed} statements")
        print(f"Affected {result.records_affected} records")
        ```
    """
    check_neo4j_available()

    file_path = Path(file_path)

    if not file_path.exists():
        raise FileNotFoundError(f"Cypher file not found: {file_path}")

    # Read file
    cypher = file_path.read_text()

    # Execute
    return execute_cypher(driver, cypher, database=database)

get_apoc_version(driver, database='neo4j')

Get the installed APOC version.

Parameters:

Name Type Description Default
driver Driver

Neo4j driver instance.

required
database str

Database name.

'neo4j'

Returns:

Type Description
Optional[str]

APOC version string or None if not available.

Example
version = get_apoc_version(driver)
print(f"APOC version: {version}")
Source code in grai/core/loader/neo4j_loader.py
def get_apoc_version(driver: Driver, database: str = "neo4j") -> Optional[str]:
    """
    Get the installed APOC version.

    Args:
        driver: Neo4j driver instance.
        database: Database name.

    Returns:
        APOC version string or None if not available.

    Example:
        ```python
        version = get_apoc_version(driver)
        print(f"APOC version: {version}")
        ```
    """
    check_neo4j_available()

    try:
        with driver.session(database=database) as session:
            result = session.run("RETURN apoc.version() AS version")
            record = result.single()
            if record:
                return record.get("version")
    except Exception:
        pass

    return None

get_database_info(driver, database='neo4j')

Get information about the Neo4j database.

Parameters:

Name Type Description Default
driver Driver

Neo4j driver instance.

required
database str

Database name.

'neo4j'

Returns:

Type Description
Dict[str, Any]

Dictionary with database information.

Example
driver = connect_neo4j(...)
info = get_database_info(driver)
print(f"Node count: {info['node_count']}")
print(f"Relationship count: {info['relationship_count']}")
Source code in grai/core/loader/neo4j_loader.py
def get_database_info(driver: Driver, database: str = "neo4j") -> Dict[str, Any]:
    """
    Get information about the Neo4j database.

    Args:
        driver: Neo4j driver instance.
        database: Database name.

    Returns:
        Dictionary with database information.

    Example:
        ```python
        driver = connect_neo4j(...)
        info = get_database_info(driver)
        print(f"Node count: {info['node_count']}")
        print(f"Relationship count: {info['relationship_count']}")
        ```
    """
    check_neo4j_available()

    info = {
        "node_count": 0,
        "relationship_count": 0,
        "labels": [],
        "relationship_types": [],
        "constraints": [],
        "indexes": [],
    }

    try:
        with driver.session(database=database) as session:
            # Get node count
            result = session.run("MATCH (n) RETURN count(n) AS count")
            info["node_count"] = result.single()["count"]

            # Get relationship count
            result = session.run("MATCH ()-[r]->() RETURN count(r) AS count")
            info["relationship_count"] = result.single()["count"]

            # Get labels
            result = session.run("CALL db.labels()")
            info["labels"] = [record["label"] for record in result]

            # Get relationship types
            result = session.run("CALL db.relationshipTypes()")
            info["relationship_types"] = [record["relationshipType"] for record in result]

            # Get constraints
            result = session.run("SHOW CONSTRAINTS")
            info["constraints"] = [dict(record) for record in result]

            # Get indexes
            result = session.run("SHOW INDEXES")
            info["indexes"] = [dict(record) for record in result]

    except Exception as e:
        info["error"] = str(e)

    return info

verify_connection(driver, database='neo4j')

Verify that connection to Neo4j is working.

Parameters:

Name Type Description Default
driver Driver

Neo4j driver instance.

required
database str

Database name to test.

'neo4j'

Returns:

Type Description
bool

True if connection is working, False otherwise.

Example
driver = connect_neo4j(...)
if verify_connection(driver):
    print("Connected!")
Source code in grai/core/loader/neo4j_loader.py
def verify_connection(driver: Driver, database: str = "neo4j") -> bool:
    """
    Verify that connection to Neo4j is working.

    Args:
        driver: Neo4j driver instance.
        database: Database name to test.

    Returns:
        True if connection is working, False otherwise.

    Example:
        ```python
        driver = connect_neo4j(...)
        if verify_connection(driver):
            print("Connected!")
        ```
    """
    check_neo4j_available()

    try:
        with driver.session(database=database) as session:
            result = session.run("RETURN 1 AS test")
            record = result.single()
            return record["test"] == 1
    except Exception:
        return False

verify_indexes_and_warn(driver, label, key_properties, database='neo4j', verbose=True)

Verify that indexes exist for key properties and print warnings if missing.

Parameters:

Name Type Description Default
driver Driver

Neo4j driver instance.

required
label str

Node label to check.

required
key_properties List[str]

List of key property names that should have indexes.

required
database str

Database name to check.

'neo4j'
verbose bool

Whether to print warnings.

True

Returns:

Type Description
bool

True if all indexes exist, False otherwise.

Example
driver = connect_neo4j(...)
all_good = verify_indexes_and_warn(driver, "Customer", ["customer_id"])
if not all_good:
    print("Warning: Missing indexes!")
Source code in grai/core/loader/neo4j_loader.py
def verify_indexes_and_warn(
    driver: Driver,
    label: str,
    key_properties: List[str],
    database: str = "neo4j",
    verbose: bool = True,
) -> bool:
    """
    Verify that indexes exist for key properties and print warnings if missing.

    Args:
        driver: Neo4j driver instance.
        label: Node label to check.
        key_properties: List of key property names that should have indexes.
        database: Database name to check.
        verbose: Whether to print warnings.

    Returns:
        True if all indexes exist, False otherwise.

    Example:
        ```python
        driver = connect_neo4j(...)
        all_good = verify_indexes_and_warn(driver, "Customer", ["customer_id"])
        if not all_good:
            print("Warning: Missing indexes!")
        ```
    """
    index_status = check_indexes_exist(driver, label, key_properties, database)
    missing_indexes = [prop for prop, exists in index_status.items() if not exists]

    if missing_indexes and verbose:
        print("\n⚠️  WARNING: Missing indexes detected!")
        print("   Loading data without indexes will be significantly slower.")
        print(f"   Label: {label}")
        print(f"   Missing indexes on: {', '.join(missing_indexes)}")
        print("\n   Recommended action:")
        for prop in missing_indexes:
            print(f"   CREATE INDEX IF NOT EXISTS FOR (n:{label}) ON (n.{prop});")
        print()

    return len(missing_indexes) == 0

Neo4j Loader

from grai.core.loader.neo4j_loader import (
    connect_neo4j,
    execute_cypher,
    close_connection,
)

# Connect
driver = connect_neo4j(
    uri="bolt://localhost:7687",
    user="neo4j",
    password="password",
    database="neo4j"
)

# Execute Cypher
cypher = "CREATE CONSTRAINT ..."
result = execute_cypher(driver, cypher)

if result.success:
    print(f"✅ Executed {result.statements_executed} statements")
    print(f"   Nodes created: {result.nodes_created}")
    print(f"   Properties set: {result.properties_set}")
else:
    for error in result.errors:
        print(f"❌ {error}")

# Close
close_connection(driver)

BigQuery Loader

from grai.core.loader.bigquery_loader import (
    BigQueryExtractor,
    load_entity_from_bigquery,
)

# Extract data
extractor = BigQueryExtractor(
    project_id="my-project",
    credentials_path="service-account.json"
)

# Load entity
result = load_entity_from_bigquery(
    entity=entity,
    bigquery_connection=extractor,
    neo4j_connection=driver,
    batch_size=1000,
    limit=None,
    verbose=True
)

print(f"✅ Loaded {result.rows_processed} rows")
print(f"   Duration: {result.duration_seconds}s")

Profiles

Profile management for grai.build connections.

Inspired by dbt's profiles.yml, this module handles connection configurations for data warehouses (BigQuery, PostgreSQL, Snowflake, etc.) and graph databases (Neo4j).

BigQueryProfile

Bases: BaseModel

BigQuery connection profile.

Source code in grai/core/profiles.py
class BigQueryProfile(BaseModel):
    """BigQuery connection profile."""

    type: str = Field(default="bigquery", frozen=True)
    method: str = Field(
        default="oauth",
        description="Authentication method: 'oauth', 'service-account', or 'service-account-json'",
    )
    project: Optional[str] = Field(
        None, description="BigQuery project ID (defaults to gcloud default)"
    )
    dataset: Optional[str] = Field(None, description="Default dataset name")
    location: Optional[str] = Field("US", description="BigQuery location (e.g., 'US', 'EU')")
    keyfile: Optional[str] = Field(None, description="Path to service account JSON keyfile")
    keyfile_json: Optional[Dict[str, Any]] = Field(
        None, description="Service account JSON credentials as dict"
    )
    timeout_seconds: int = Field(300, description="Query timeout in seconds")
    maximum_bytes_billed: Optional[int] = Field(None, description="Maximum bytes billed per query")

    @field_validator("method")
    @classmethod
    def validate_method(cls, v: str) -> str:
        """Validate authentication method."""
        valid_methods = ["oauth", "service-account", "service-account-json"]
        if v not in valid_methods:
            raise ValueError(f"Invalid method '{v}'. Must be one of: {', '.join(valid_methods)}")
        return v

validate_method(v) classmethod

Validate authentication method.

Source code in grai/core/profiles.py
@field_validator("method")
@classmethod
def validate_method(cls, v: str) -> str:
    """Validate authentication method."""
    valid_methods = ["oauth", "service-account", "service-account-json"]
    if v not in valid_methods:
        raise ValueError(f"Invalid method '{v}'. Must be one of: {', '.join(valid_methods)}")
    return v

Neo4jProfile

Bases: BaseModel

Neo4j connection profile.

Source code in grai/core/profiles.py
class Neo4jProfile(BaseModel):
    """Neo4j connection profile."""

    type: str = Field(default="neo4j", frozen=True)
    uri: str = Field(default="bolt://localhost:7687", description="Neo4j connection URI")
    user: str = Field(default="neo4j", description="Neo4j username")
    password: Optional[str] = Field(None, description="Neo4j password")
    database: Optional[str] = Field(None, description="Neo4j database name")
    encrypted: bool = Field(True, description="Use encrypted connection")
    trust: str = Field(
        "TRUST_SYSTEM_CA_SIGNED_CERTIFICATES",
        description="Certificate trust level",
    )

PostgresProfile

Bases: BaseModel

PostgreSQL connection profile.

Source code in grai/core/profiles.py
class PostgresProfile(BaseModel):
    """PostgreSQL connection profile."""

    type: str = Field(default="postgres", frozen=True)
    host: str = Field(..., description="PostgreSQL server hostname or IP")
    port: int = Field(5432, description="PostgreSQL server port")
    database: str = Field(..., description="Database name")
    user: str = Field(..., description="Username for authentication")
    password: Optional[str] = Field(None, description="Password for authentication")
    schema: str = Field("public", description="Default schema")
    sslmode: str = Field(
        "prefer",
        description="SSL mode: 'disable', 'allow', 'prefer', 'require', 'verify-ca', 'verify-full'",
    )

    @field_validator("sslmode")
    @classmethod
    def validate_sslmode(cls, v: str) -> str:
        """Validate SSL mode."""
        valid_modes = ["disable", "allow", "prefer", "require", "verify-ca", "verify-full"]
        if v not in valid_modes:
            raise ValueError(f"Invalid sslmode '{v}'. Must be one of: {', '.join(valid_modes)}")
        return v

validate_sslmode(v) classmethod

Validate SSL mode.

Source code in grai/core/profiles.py
@field_validator("sslmode")
@classmethod
def validate_sslmode(cls, v: str) -> str:
    """Validate SSL mode."""
    valid_modes = ["disable", "allow", "prefer", "require", "verify-ca", "verify-full"]
    if v not in valid_modes:
        raise ValueError(f"Invalid sslmode '{v}'. Must be one of: {', '.join(valid_modes)}")
    return v

Profile

Bases: BaseModel

Complete profile configuration.

Source code in grai/core/profiles.py
class Profile(BaseModel):
    """Complete profile configuration."""

    config: TargetConfig

SnowflakeProfile

Bases: BaseModel

Snowflake connection profile.

Source code in grai/core/profiles.py
class SnowflakeProfile(BaseModel):
    """Snowflake connection profile."""

    type: str = Field(default="snowflake", frozen=True)
    account: str = Field(..., description="Snowflake account identifier")
    user: str = Field(..., description="Snowflake username")
    password: Optional[str] = Field(None, description="Snowflake password")
    role: Optional[str] = Field(None, description="Snowflake role")
    database: Optional[str] = Field(None, description="Default database")
    warehouse: Optional[str] = Field(None, description="Snowflake warehouse")
    schema_name: Optional[str] = Field(None, description="Default schema", alias="schema")
    authenticator: Optional[str] = Field(
        None, description="Authentication method (e.g., 'externalbrowser')"
    )

    model_config = {"populate_by_name": True}

TargetConfig

Bases: BaseModel

Configuration for a specific target (environment).

Source code in grai/core/profiles.py
class TargetConfig(BaseModel):
    """Configuration for a specific target (environment)."""

    outputs: Dict[str, Any] = Field(..., description="Output configurations (warehouse + graph)")
    target: str = Field(..., description="Default output to use")

create_default_profiles_file()

Create a default profiles.yml file.

Returns:

Type Description
Path

Path to created profiles file

Source code in grai/core/profiles.py
def create_default_profiles_file() -> Path:
    """
    Create a default profiles.yml file.

    Returns:
        Path to created profiles file
    """
    profiles_dir = get_profiles_dir()
    profiles_dir.mkdir(parents=True, exist_ok=True)

    profile_path = profiles_dir / "profiles.yml"

    default_content = """# grai.build profiles configuration
# Similar to dbt profiles.yml, this file manages connections to data warehouses and graph databases
#
# Environment variables can be referenced using: {{ env_var('VAR_NAME') }}
# Set GRAI_TARGET environment variable to override the default target

default:
  target: dev
  outputs:
    dev:
      # Data warehouse configuration
      warehouse:
        type: bigquery
        method: oauth  # or 'service-account'
        project: "{{ env_var('GCP_PROJECT') }}"
        dataset: analytics
        location: US
        # keyfile: /path/to/service-account.json  # for service-account method
        timeout_seconds: 300

      # Graph database configuration
      graph:
        type: neo4j
        uri: bolt://localhost:7687
        user: neo4j
        password: "{{ env_var('NEO4J_PASSWORD') }}"
        database: neo4j
        encrypted: true

    prod:
      warehouse:
        type: bigquery
        method: service-account
        project: my-prod-project
        dataset: analytics_prod
        location: US
        keyfile: "{{ env_var('GCP_KEYFILE_PATH') }}"
        timeout_seconds: 600

      graph:
        type: neo4j
        uri: "{{ env_var('NEO4J_URI') }}"
        user: neo4j
        password: "{{ env_var('NEO4J_PASSWORD') }}"
        database: neo4j
        encrypted: true

# Example with Snowflake
# snowflake_project:
#   target: dev
#   outputs:
#     dev:
#       warehouse:
#         type: snowflake
#         account: abc12345.us-east-1
#         user: "{{ env_var('SNOWFLAKE_USER') }}"
#         password: "{{ env_var('SNOWFLAKE_PASSWORD') }}"
#         role: ANALYST
#         database: ANALYTICS
#         warehouse: COMPUTE_WH
#         schema: PUBLIC
#
#       graph:
#         type: neo4j
#         uri: bolt://localhost:7687
#         user: neo4j
#         password: "{{ env_var('NEO4J_PASSWORD') }}"
"""

    with open(profile_path, "w") as f:
        f.write(default_content)

    return profile_path
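
A minimal sketch (note that this writes, and overwrites, profiles.yml in the profiles directory):

from grai.core.profiles import create_default_profiles_file, get_profile_path

path = create_default_profiles_file()
print(f"Created {path}")           # e.g. ~/.grai/profiles.yml
print(path == get_profile_path())  # True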

get_connection_info(profile_name='default', target_name=None)

Get warehouse and graph connection info from profiles.

Parameters:

Name Type Description Default
profile_name str

Name of the profile to use (defaults to 'default')

'default'
target_name Optional[str]

Name of the target (defaults to profile's default or GRAI_TARGET)

None

Returns:

Type Description
tuple[Any, Neo4jProfile]

Tuple of (warehouse_profile, graph_profile)

Raises:

Type Description
FileNotFoundError

If profiles.yml doesn't exist

KeyError

If profile or target doesn't exist

Source code in grai/core/profiles.py
def get_connection_info(
    profile_name: str = "default", target_name: Optional[str] = None
) -> tuple[Any, Neo4jProfile]:
    """
    Get warehouse and graph connection info from profiles.

    Args:
        profile_name: Name of the profile to use (defaults to 'default')
        target_name: Name of the target (defaults to profile's default or GRAI_TARGET)

    Returns:
        Tuple of (warehouse_profile, graph_profile)

    Raises:
        FileNotFoundError: If profiles.yml doesn't exist
        KeyError: If profile or target doesn't exist
    """
    target_config = get_target_config(profile_name, target_name)

    # Resolve environment variables
    target_config = resolve_env_vars(target_config)

    # Parse warehouse config
    warehouse_config = target_config.get("warehouse")
    if not warehouse_config:
        raise ValueError(
            f"No warehouse configuration found in target '{target_name}' "
            f"of profile '{profile_name}'"
        )
    warehouse_profile = parse_warehouse_profile(warehouse_config)

    # Parse graph config
    graph_config = target_config.get("graph")
    if not graph_config:
        raise ValueError(
            f"No graph configuration found in target '{target_name}' "
            f"of profile '{profile_name}'"
        )
    graph_profile = parse_graph_profile(graph_config)

    return warehouse_profile, graph_profile
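
A minimal sketch (requires a valid profiles.yml and any environment variables it references to be set):

from grai.core.profiles import get_connection_info

warehouse_profile, graph_profile = get_connection_info(
    profile_name="default",
    target_name="dev",
)

print(warehouse_profile.type)  # e.g. "bigquery"
print(graph_profile.uri)       # e.g. "bolt://localhost:7687"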

get_profile(profile_name)

Get a specific profile by name.

Parameters:

Name Type Description Default
profile_name str

Name of the profile to retrieve

required

Returns:

Type Description
Dict[str, Any]

Profile configuration dictionary

Raises:

Type Description
KeyError

If profile doesn't exist

Source code in grai/core/profiles.py
def get_profile(profile_name: str) -> Dict[str, Any]:
    """
    Get a specific profile by name.

    Args:
        profile_name: Name of the profile to retrieve

    Returns:
        Profile configuration dictionary

    Raises:
        KeyError: If profile doesn't exist
    """
    profiles = load_profiles()

    if profile_name not in profiles:
        available = ", ".join(profiles.keys())
        raise KeyError(f"Profile '{profile_name}' not found. Available profiles: {available}")

    return profiles[profile_name]
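
A minimal sketch, assuming a profiles.yml shaped like the default one created above:

from grai.core.profiles import get_profile

profile = get_profile("default")
print(profile["target"])         # default target name, e.g. "dev"
print(list(profile["outputs"]))  # available targets, e.g. ["dev", "prod"]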

get_profile_path()

Get the path to profiles.yml.

Returns:

Type Description
Path

Path to profiles.yml file

Source code in grai/core/profiles.py
def get_profile_path() -> Path:
    """
    Get the path to profiles.yml.

    Returns:
        Path to profiles.yml file
    """
    return get_profiles_dir() / "profiles.yml"

get_profiles_dir()

Get the profiles directory path.

Checks in order:

1. GRAI_PROFILES_DIR environment variable
2. ~/.grai/ directory

Returns:

Type Description
Path

Path to profiles directory

Source code in grai/core/profiles.py
def get_profiles_dir() -> Path:
    """
    Get the profiles directory path.

    Checks in order:
    1. GRAI_PROFILES_DIR environment variable
    2. ~/.grai/ directory

    Returns:
        Path to profiles directory
    """
    profiles_dir_env = os.getenv("GRAI_PROFILES_DIR")
    if profiles_dir_env:
        return Path(profiles_dir_env)

    return Path.home() / ".grai"
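
A minimal sketch showing the GRAI_PROFILES_DIR override:

import os

from grai.core.profiles import get_profile_path, get_profiles_dir

print(get_profiles_dir())  # ~/.grai by default

os.environ["GRAI_PROFILES_DIR"] = "/tmp/grai-profiles"
print(get_profiles_dir())  # /tmp/grai-profiles
print(get_profile_path())  # /tmp/grai-profiles/profiles.yml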

get_target_config(profile_name, target_name=None)

Get target configuration from a profile.

Parameters:

Name Type Description Default
profile_name str

Name of the profile

required
target_name Optional[str]

Name of the target (defaults to profile's default target)

None

Returns:

Type Description
Dict[str, Any]

Target configuration dictionary with 'warehouse' and 'graph' outputs

Raises:

Type Description
KeyError

If profile or target doesn't exist

Source code in grai/core/profiles.py
def get_target_config(profile_name: str, target_name: Optional[str] = None) -> Dict[str, Any]:
    """
    Get target configuration from a profile.

    Args:
        profile_name: Name of the profile
        target_name: Name of the target (defaults to profile's default target)

    Returns:
        Target configuration dictionary with 'warehouse' and 'graph' outputs

    Raises:
        KeyError: If profile or target doesn't exist
    """
    profile = get_profile(profile_name)

    # Get target name (from arg, env var, or profile default)
    if target_name is None:
        target_name = os.getenv("GRAI_TARGET")

    if target_name is None:
        target_name = profile.get("target")

    if target_name is None:
        raise ValueError(
            f"No target specified for profile '{profile_name}'. "
            f"Set target in profiles.yml or use GRAI_TARGET env var."
        )

    # Get outputs
    outputs = profile.get("outputs", {})
    if target_name not in outputs:
        available = ", ".join(outputs.keys())
        raise KeyError(
            f"Target '{target_name}' not found in profile '{profile_name}'. "
            f"Available targets: {available}"
        )

    return outputs[target_name]
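
A minimal sketch (target names and keys assume the default profiles.yml layout shown earlier):

import os

from grai.core.profiles import get_target_config

# Explicit target...
dev_config = get_target_config("default", target_name="dev")

# ...or let GRAI_TARGET / the profile's default decide.
os.environ["GRAI_TARGET"] = "prod"
prod_config = get_target_config("default")

print(dev_config["graph"]["uri"])  # e.g. bolt://localhost:7687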

load_profiles()

Load profiles from profiles.yml.

Returns:

Type Description
Dict[str, Any]

Dictionary of profile configurations

Raises:

Type Description
FileNotFoundError

If profiles.yml doesn't exist

YAMLError

If profiles.yml is invalid

Source code in grai/core/profiles.py
def load_profiles() -> Dict[str, Any]:
    """
    Load profiles from profiles.yml.

    Returns:
        Dictionary of profile configurations

    Raises:
        FileNotFoundError: If profiles.yml doesn't exist
        yaml.YAMLError: If profiles.yml is invalid
    """
    profile_path = get_profile_path()

    if not profile_path.exists():
        raise FileNotFoundError(
            f"Profile file not found at {profile_path}. "
            f"Run 'grai init' to create one, or set GRAI_PROFILES_DIR."
        )

    with open(profile_path) as f:
        profiles = yaml.safe_load(f)

    if not profiles:
        raise ValueError(f"Profile file {profile_path} is empty")

    return profiles
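
A minimal sketch:

from grai.core.profiles import load_profiles

profiles = load_profiles()
for name, profile in profiles.items():
    print(name, "->", profile.get("target"))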

parse_graph_profile(config)

Parse graph database configuration into profile model.

Parameters:

Name Type Description Default
config Dict[str, Any]

Graph database configuration dictionary

required

Returns:

Type Description
Neo4jProfile

Neo4jProfile model

Raises:

Type Description
ValueError

If graph type is unsupported

Source code in grai/core/profiles.py
def parse_graph_profile(config: Dict[str, Any]) -> Neo4jProfile:
    """
    Parse graph database configuration into profile model.

    Args:
        config: Graph database configuration dictionary

    Returns:
        Neo4jProfile model

    Raises:
        ValueError: If graph type is unsupported
    """
    graph_type = config.get("type", "neo4j")

    if graph_type == "neo4j":
        return Neo4jProfile(**config)
    else:
        raise ValueError(
            f"Unsupported graph type: {graph_type}. Currently only 'neo4j' is supported."
        )

parse_warehouse_profile(config)

Parse warehouse configuration into appropriate profile model.

Parameters:

Name Type Description Default
config Dict[str, Any]

Warehouse configuration dictionary

required

Returns:

Type Description
Any

Profile model (BigQueryProfile, SnowflakeProfile, PostgresProfile, etc.)

Raises:

Type Description
ValueError

If warehouse type is unsupported

Source code in grai/core/profiles.py
def parse_warehouse_profile(config: Dict[str, Any]) -> Any:
    """
    Parse warehouse configuration into appropriate profile model.

    Args:
        config: Warehouse configuration dictionary

    Returns:
        Profile model (BigQueryProfile, SnowflakeProfile, PostgresProfile, etc.)

    Raises:
        ValueError: If warehouse type is unsupported
    """
    warehouse_type = config.get("type")

    if warehouse_type == "bigquery":
        return BigQueryProfile(**config)
    elif warehouse_type == "snowflake":
        return SnowflakeProfile(**config)
    elif warehouse_type == "postgres":
        return PostgresProfile(**config)
    else:
        raise ValueError(
            f"Unsupported warehouse type: {warehouse_type}. "
            f"Supported types: bigquery, snowflake, postgres"
        )
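
A minimal sketch covering both parsers (field values are illustrative):

from grai.core.profiles import parse_graph_profile, parse_warehouse_profile

warehouse = parse_warehouse_profile({
    "type": "postgres",
    "host": "localhost",
    "database": "analytics",
    "user": "analyst",
})
graph = parse_graph_profile({
    "type": "neo4j",
    "uri": "bolt://localhost:7687",
    "user": "neo4j",
    "password": "password",
})

print(type(warehouse).__name__)  # PostgresProfile
print(graph.uri)                 # bolt://localhost:7687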

resolve_env_vars(config)

Resolve environment variable references in configuration.

Replaces strings like "{{ env_var('MY_VAR') }}" with environment variable values.

Parameters:

Name Type Description Default
config Dict[str, Any]

Configuration dictionary

required

Returns:

Type Description
Dict[str, Any]

Configuration with environment variables resolved

Source code in grai/core/profiles.py
def resolve_env_vars(config: Dict[str, Any]) -> Dict[str, Any]:
    """
    Resolve environment variable references in configuration.

    Replaces strings like "{{ env_var('MY_VAR') }}" with environment variable values.

    Args:
        config: Configuration dictionary

    Returns:
        Configuration with environment variables resolved
    """
    import re

    def replace_env_var(match: re.Match) -> str:
        var_name = match.group(1)
        value = os.getenv(var_name)
        if value is None:
            raise ValueError(f"Environment variable '{var_name}' is not set")
        return value

    result = {}
    env_var_pattern = re.compile(r"{{\s*env_var\(['\"]([^'\"]+)['\"]\)\s*}}")

    for key, value in config.items():
        if isinstance(value, str):
            result[key] = env_var_pattern.sub(replace_env_var, value)
        elif isinstance(value, dict):
            result[key] = resolve_env_vars(value)
        else:
            result[key] = value

    return result
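
A short sketch of the templating behaviour, using a made-up environment variable name:

import os

os.environ["NEO4J_PASSWORD"] = "s3cret"   # hypothetical, set here only for illustration

config = {
    "uri": "bolt://localhost:7687",
    "password": "{{ env_var('NEO4J_PASSWORD') }}",
    "advanced": {"password": "{{ env_var('NEO4J_PASSWORD') }}"},   # nested dicts are resolved too
}

resolved = resolve_env_vars(config)
# resolved["password"] == "s3cret"

# An unset variable raises ValueError:
# resolve_env_vars({"password": "{{ env_var('SOME_MISSING_VAR') }}"})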

Profile Configuration

from grai.core.profiles import (
    BigQueryProfile,
    Neo4jProfile,
    TargetConfig,
)

# BigQuery profile
bq_profile = BigQueryProfile(
    project_id="my-project",
    dataset="my_dataset",
    credentials_path="/path/to/credentials.json"
)

# Neo4j profile
neo4j_profile = Neo4jProfile(
    uri="bolt://localhost:7687",
    user="neo4j",
    password="password",
    database="neo4j"
)

# Target config
target = TargetConfig(
    bigquery=bq_profile,
    neo4j=neo4j_profile
)

Lineage

Lineage tracking module for knowledge graph analysis.

Exports lineage tracking functions for analyzing entity relationships, dependencies, and change impact.

LineageEdge dataclass

Represents an edge in the lineage graph.

Attributes:

Name Type Description
from_node str

Source node ID

to_node str

Target node ID

relation_type str

Type of relationship (e.g., "depends_on", "produces")

metadata Dict

Additional metadata about the edge

Source code in grai/core/lineage/lineage_tracker.py
@dataclass
class LineageEdge:
    """
    Represents an edge in the lineage graph.

    Attributes:
        from_node: Source node ID
        to_node: Target node ID
        relation_type: Type of relationship (e.g., "depends_on", "produces")
        metadata: Additional metadata about the edge
    """

    from_node: str
    to_node: str
    relation_type: str
    metadata: Dict = field(default_factory=dict)

LineageGraph dataclass

Represents the complete lineage graph.

Attributes:

Name Type Description
nodes Dict[str, LineageNode]

Dictionary mapping node IDs to LineageNode objects

edges List[LineageEdge]

List of LineageEdge objects

entity_map Dict[str, str]

Mapping of entity names to node IDs

relation_map Dict[str, str]

Mapping of relation names to node IDs

source_map Dict[str, str]

Mapping of source names to node IDs

Source code in grai/core/lineage/lineage_tracker.py
@dataclass
class LineageGraph:
    """
    Represents the complete lineage graph.

    Attributes:
        nodes: Dictionary mapping node IDs to LineageNode objects
        edges: List of LineageEdge objects
        entity_map: Mapping of entity names to node IDs
        relation_map: Mapping of relation names to node IDs
        source_map: Mapping of source names to node IDs
    """

    nodes: Dict[str, LineageNode] = field(default_factory=dict)
    edges: List[LineageEdge] = field(default_factory=list)
    entity_map: Dict[str, str] = field(default_factory=dict)
    relation_map: Dict[str, str] = field(default_factory=dict)
    source_map: Dict[str, str] = field(default_factory=dict)

    def add_node(self, node: LineageNode) -> None:
        """Add a node to the graph."""
        self.nodes[node.id] = node

        if node.type == NodeType.ENTITY:
            self.entity_map[node.name] = node.id
        elif node.type == NodeType.RELATION:
            self.relation_map[node.name] = node.id
        elif node.type == NodeType.SOURCE:
            self.source_map[node.name] = node.id

    def add_edge(self, edge: LineageEdge) -> None:
        """Add an edge to the graph."""
        self.edges.append(edge)

    def get_node(self, node_id: str) -> Optional[LineageNode]:
        """Get node by ID."""
        return self.nodes.get(node_id)

    def get_edges_from(self, node_id: str) -> List[LineageEdge]:
        """Get all edges originating from a node."""
        return [edge for edge in self.edges if edge.from_node == node_id]

    def get_edges_to(self, node_id: str) -> List[LineageEdge]:
        """Get all edges pointing to a node."""
        return [edge for edge in self.edges if edge.to_node == node_id]

add_edge(edge)

Add an edge to the graph.

Source code in grai/core/lineage/lineage_tracker.py
def add_edge(self, edge: LineageEdge) -> None:
    """Add an edge to the graph."""
    self.edges.append(edge)

add_node(node)

Add a node to the graph.

Source code in grai/core/lineage/lineage_tracker.py
def add_node(self, node: LineageNode) -> None:
    """Add a node to the graph."""
    self.nodes[node.id] = node

    if node.type == NodeType.ENTITY:
        self.entity_map[node.name] = node.id
    elif node.type == NodeType.RELATION:
        self.relation_map[node.name] = node.id
    elif node.type == NodeType.SOURCE:
        self.source_map[node.name] = node.id

get_edges_from(node_id)

Get all edges originating from a node.

Source code in grai/core/lineage/lineage_tracker.py
def get_edges_from(self, node_id: str) -> List[LineageEdge]:
    """Get all edges originating from a node."""
    return [edge for edge in self.edges if edge.from_node == node_id]

get_edges_to(node_id)

Get all edges pointing to a node.

Source code in grai/core/lineage/lineage_tracker.py
def get_edges_to(self, node_id: str) -> List[LineageEdge]:
    """Get all edges pointing to a node."""
    return [edge for edge in self.edges if edge.to_node == node_id]

get_node(node_id)

Get node by ID.

Source code in grai/core/lineage/lineage_tracker.py
def get_node(self, node_id: str) -> Optional[LineageNode]:
    """Get node by ID."""
    return self.nodes.get(node_id)

LineageNode dataclass

Represents a node in the lineage graph.

Attributes:

Name Type Description
id str

Unique identifier for the node

name str

Node name (entity name, relation name, or source)

type NodeType

Type of node (entity, relation, or source)

metadata Dict

Additional metadata about the node

Source code in grai/core/lineage/lineage_tracker.py
@dataclass
class LineageNode:
    """
    Represents a node in the lineage graph.

    Attributes:
        id: Unique identifier for the node
        name: Node name (entity name, relation name, or source)
        type: Type of node (entity, relation, or source)
        metadata: Additional metadata about the node
    """

    id: str
    name: str
    type: NodeType
    metadata: Dict = field(default_factory=dict)

    def __hash__(self):
        return hash(self.id)

    def __eq__(self, other):
        return isinstance(other, LineageNode) and self.id == other.id

NodeType

Bases: Enum

Type of lineage node.

Source code in grai/core/lineage/lineage_tracker.py
class NodeType(Enum):
    """Type of lineage node."""

    ENTITY = "entity"
    RELATION = "relation"
    SOURCE = "source"
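
Taken together, the dataclasses above can be assembled by hand; the node IDs and names below are made up for illustration:

graph = LineageGraph()

graph.add_node(LineageNode(id="source:orders_csv", name="orders_csv", type=NodeType.SOURCE))
graph.add_node(LineageNode(id="entity:Order", name="Order", type=NodeType.ENTITY))
graph.add_edge(
    LineageEdge(from_node="source:orders_csv", to_node="entity:Order", relation_type="produces")
)

graph.entity_map                     # {"Order": "entity:Order"}
graph.get_edges_to("entity:Order")   # [the single "produces" edge]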

build_lineage_graph(project)

Build a complete lineage graph from a project.

Parameters:

Name Type Description Default
project Project

Project to analyze

required

Returns:

Type Description
LineageGraph

LineageGraph with all entities, relations, and sources

Source code in grai/core/lineage/lineage_tracker.py
def build_lineage_graph(project: Project) -> LineageGraph:
    """
    Build a complete lineage graph from a project.

    Args:
        project: Project to analyze

    Returns:
        LineageGraph with all entities, relations, and sources
    """
    graph = LineageGraph()

    # Add entity nodes
    for entity in project.entities:
        source_config = entity.get_source_config()
        source_name = source_config.name

        node_id = f"entity:{entity.entity}"
        node = LineageNode(
            id=node_id,
            name=entity.entity,
            type=NodeType.ENTITY,
            metadata={
                "source": source_name,
                "source_type": source_config.type.value if source_config.type else None,
                "keys": entity.keys,
                "property_count": len(entity.properties),
                "description": getattr(entity, "description", None),
            },
        )
        graph.add_node(node)

        # Add source node if not exists
        source_id = f"source:{source_name}"
        if source_id not in graph.nodes:
            source_node = LineageNode(
                id=source_id,
                name=source_name,
                type=NodeType.SOURCE,
                metadata={
                    "type": "data_source",
                    "source_type": source_config.type.value if source_config.type else None,
                },
            )
            graph.add_node(source_node)

        # Add edge from source to entity
        graph.add_edge(
            LineageEdge(
                from_node=source_id,
                to_node=node_id,
                relation_type="produces",
                metadata={"keys": entity.keys},
            )
        )

    # Add relation nodes and edges
    for relation in project.relations:
        source_config = relation.get_source_config()
        source_name = source_config.name

        node_id = f"relation:{relation.relation}"
        node = LineageNode(
            id=node_id,
            name=relation.relation,
            type=NodeType.RELATION,
            metadata={
                "source": source_name,
                "source_type": source_config.type.value if source_config.type else None,
                "from_entity": relation.from_entity,
                "to_entity": relation.to_entity,
                "property_count": len(relation.properties),
                "description": getattr(relation, "description", None),
            },
        )
        graph.add_node(node)

        # Add source node if not exists
        source_id = f"source:{source_name}"
        if source_id not in graph.nodes:
            source_node = LineageNode(
                id=source_id,
                name=source_name,
                type=NodeType.SOURCE,
                metadata={
                    "type": "data_source",
                    "source_type": source_config.type.value if source_config.type else None,
                },
            )
            graph.add_node(source_node)

        # Add edge from source to relation
        graph.add_edge(
            LineageEdge(from_node=source_id, to_node=node_id, relation_type="produces", metadata={})
        )

        # Add edges from entities to relation
        from_entity_id = f"entity:{relation.from_entity}"
        to_entity_id = f"entity:{relation.to_entity}"

        graph.add_edge(
            LineageEdge(
                from_node=from_entity_id,
                to_node=node_id,
                relation_type="participates_in",
                metadata={"role": "from", "key": relation.mappings.from_key},
            )
        )

        graph.add_edge(
            LineageEdge(
                from_node=node_id,
                to_node=to_entity_id,
                relation_type="connects_to",
                metadata={"role": "to", "key": relation.mappings.to_key},
            )
        )

    return graph
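
In practice the graph is built from a parsed project rather than by hand; the project loading call mirrors the parser examples used elsewhere in this reference:

from pathlib import Path
from grai.core.parser.yaml_parser import load_project

project = load_project(Path("."))
graph = build_lineage_graph(project)

len(graph.nodes), len(graph.edges)   # counts depend on your project
graph.entity_map                     # entity names -> "entity:<name>" node IDs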

calculate_impact_analysis(graph, entity_name)

Calculate the impact of changes to an entity.

Parameters:

Name Type Description Default
graph LineageGraph

Lineage graph

required
entity_name str

Name of the entity to analyze

required

Returns:

Type Description
Dict

Dictionary with impact analysis

Source code in grai/core/lineage/lineage_tracker.py
def calculate_impact_analysis(graph: LineageGraph, entity_name: str) -> Dict:
    """
    Calculate the impact of changes to an entity.

    Args:
        graph: Lineage graph
        entity_name: Name of the entity to analyze

    Returns:
        Dictionary with impact analysis
    """
    node_id = graph.entity_map.get(entity_name)
    if not node_id:
        return {"error": f"Entity '{entity_name}' not found"}

    # Find all affected entities and relations
    downstream_entities = find_downstream_entities(graph, entity_name)

    # Find affected relations
    affected_relations = set()
    for edge in graph.get_edges_from(node_id):
        to_node = graph.get_node(edge.to_node)
        if to_node and to_node.type == NodeType.RELATION:
            affected_relations.add(to_node.name)

    # Calculate impact score (simple: count of affected nodes)
    impact_score = len(downstream_entities) + len(affected_relations)

    return {
        "entity": entity_name,
        "impact_score": impact_score,
        "affected_entities": sorted(downstream_entities),
        "affected_relations": sorted(affected_relations),
        "impact_level": _calculate_impact_level(impact_score),
    }
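
Continuing with the graph built above, a sketch of inspecting the result for a hypothetical Customer entity:

impact = calculate_impact_analysis(graph, "Customer")
impact["impact_score"]          # downstream entity count plus affected relation count
impact["affected_relations"]    # e.g. ["PLACED_ORDER"] -- relation names are project-specific
impact["impact_level"]          # derived from the score

calculate_impact_analysis(graph, "DoesNotExist")   # {"error": "Entity 'DoesNotExist' not found"}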

export_lineage_to_dict(graph)

Export lineage graph to dictionary format.

Parameters:

Name Type Description Default
graph LineageGraph

Lineage graph

required

Returns:

Type Description
Dict

Dictionary representation of the graph

Source code in grai/core/lineage/lineage_tracker.py
def export_lineage_to_dict(graph: LineageGraph) -> Dict:
    """
    Export lineage graph to dictionary format.

    Args:
        graph: Lineage graph

    Returns:
        Dictionary representation of the graph
    """
    return {
        "nodes": [
            {
                "id": node.id,
                "name": node.name,
                "type": node.type.value,
                "metadata": node.metadata,
            }
            for node in graph.nodes.values()
        ],
        "edges": [
            {
                "from": edge.from_node,
                "to": edge.to_node,
                "type": edge.relation_type,
                "metadata": edge.metadata,
            }
            for edge in graph.edges
        ],
        "statistics": get_lineage_statistics(graph),
    }
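
The returned dictionary contains only plain JSON-serializable values, so it can be written straight to disk; the output filename is illustrative:

import json
from pathlib import Path

Path("lineage.json").write_text(json.dumps(export_lineage_to_dict(graph), indent=2))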

find_downstream_entities(graph, entity_name, max_depth=10)

Find all downstream entities (recursive).

Parameters:

Name Type Description Default
graph LineageGraph

Lineage graph

required
entity_name str

Name of the entity

required
max_depth int

Maximum depth to traverse

10

Returns:

Type Description
Set[str]

Set of downstream entity names

Source code in grai/core/lineage/lineage_tracker.py
def find_downstream_entities(
    graph: LineageGraph, entity_name: str, max_depth: int = 10
) -> Set[str]:
    """
    Find all downstream entities (recursive).

    Args:
        graph: Lineage graph
        entity_name: Name of the entity
        max_depth: Maximum depth to traverse

    Returns:
        Set of downstream entity names
    """
    node_id = graph.entity_map.get(entity_name)
    if not node_id:
        return set()

    visited = set()
    downstream = set()

    def traverse(current_id: str, depth: int):
        if depth >= max_depth or current_id in visited:
            return

        visited.add(current_id)
        edges = graph.get_edges_from(current_id)

        for edge in edges:
            to_node = graph.get_node(edge.to_node)
            if to_node and to_node.type == NodeType.ENTITY:
                downstream.add(to_node.name)
                traverse(edge.to_node, depth + 1)
            elif to_node and to_node.type == NodeType.RELATION:
                # Traverse through relation to find entities
                traverse(edge.to_node, depth + 1)

    traverse(node_id, 0)
    return downstream
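
A sketch, reusing the hypothetical entity names from the earlier examples:

find_downstream_entities(graph, "Customer")                # e.g. {"Order"}
find_downstream_entities(graph, "Customer", max_depth=2)   # cap how far the traversal recurses
find_downstream_entities(graph, "Unknown")                 # set() for an unknown entity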

find_entity_path(graph, from_entity, to_entity)

Find shortest path between two entities.

Parameters:

Name Type Description Default
graph LineageGraph

Lineage graph

required
from_entity str

Starting entity name

required
to_entity str

Target entity name

required

Returns:

Type Description
Optional[List[str]]

List of node names representing the path, or None if no path exists

Source code in grai/core/lineage/lineage_tracker.py
def find_entity_path(graph: LineageGraph, from_entity: str, to_entity: str) -> Optional[List[str]]:
    """
    Find shortest path between two entities.

    Args:
        graph: Lineage graph
        from_entity: Starting entity name
        to_entity: Target entity name

    Returns:
        List of node names representing the path, or None if no path exists
    """
    from_id = graph.entity_map.get(from_entity)
    to_id = graph.entity_map.get(to_entity)

    if not from_id or not to_id:
        return None

    # BFS to find shortest path
    queue = [(from_id, [from_entity])]
    visited = {from_id}

    while queue:
        current_id, path = queue.pop(0)

        if current_id == to_id:
            return path

        # Check outgoing edges
        for edge in graph.get_edges_from(current_id):
            if edge.to_node not in visited:
                visited.add(edge.to_node)
                node = graph.get_node(edge.to_node)
                queue.append((edge.to_node, path + [node.name]))

    return None
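
The returned path lists node names in traversal order, so relations appear between the entities they connect; the names below are hypothetical:

find_entity_path(graph, "Customer", "Order")
# e.g. ["Customer", "PLACED_ORDER", "Order"]

find_entity_path(graph, "Customer", "Unrelated")   # None when either entity is missing or unreachable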

find_upstream_entities(graph, entity_name, max_depth=10)

Find all upstream entities (recursive).

Parameters:

Name Type Description Default
graph LineageGraph

Lineage graph

required
entity_name str

Name of the entity

required
max_depth int

Maximum depth to traverse

10

Returns:

Type Description
Set[str]

Set of upstream entity names

Source code in grai/core/lineage/lineage_tracker.py
def find_upstream_entities(graph: LineageGraph, entity_name: str, max_depth: int = 10) -> Set[str]:
    """
    Find all upstream entities (recursive).

    Args:
        graph: Lineage graph
        entity_name: Name of the entity
        max_depth: Maximum depth to traverse

    Returns:
        Set of upstream entity names
    """
    node_id = graph.entity_map.get(entity_name)
    if not node_id:
        return set()

    visited = set()
    upstream = set()

    def traverse(current_id: str, depth: int):
        if depth >= max_depth or current_id in visited:
            return

        visited.add(current_id)
        edges = graph.get_edges_to(current_id)

        for edge in edges:
            from_node = graph.get_node(edge.from_node)
            if from_node and from_node.type == NodeType.ENTITY:
                upstream.add(from_node.name)
                traverse(edge.from_node, depth + 1)
            elif from_node and from_node.type == NodeType.RELATION:
                # Traverse through relation to find entities
                traverse(edge.from_node, depth + 1)

    traverse(node_id, 0)
    return upstream

get_entity_lineage(graph, entity_name)

Get complete lineage information for an entity.

Parameters:

Name Type Description Default
graph LineageGraph

Lineage graph

required
entity_name str

Name of the entity

required

Returns:

Type Description
Dict

Dictionary with lineage information

Source code in grai/core/lineage/lineage_tracker.py
def get_entity_lineage(graph: LineageGraph, entity_name: str) -> Dict:
    """
    Get complete lineage information for an entity.

    Args:
        graph: Lineage graph
        entity_name: Name of the entity

    Returns:
        Dictionary with lineage information
    """
    node_id = graph.entity_map.get(entity_name)
    if not node_id:
        return {"error": f"Entity '{entity_name}' not found"}

    node = graph.get_node(node_id)

    # Get upstream (sources)
    upstream_edges = graph.get_edges_to(node_id)
    upstream = [
        {
            "node": graph.get_node(edge.from_node).name,
            "type": graph.get_node(edge.from_node).type.value,
            "relation": edge.relation_type,
        }
        for edge in upstream_edges
    ]

    # Get downstream (relations)
    downstream_edges = graph.get_edges_from(node_id)
    downstream = [
        {
            "node": graph.get_node(edge.to_node).name,
            "type": graph.get_node(edge.to_node).type.value,
            "relation": edge.relation_type,
        }
        for edge in downstream_edges
    ]

    return {
        "entity": entity_name,
        "source": node.metadata.get("source"),
        "upstream": upstream,
        "downstream": downstream,
        "metadata": node.metadata,
    }
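
A sketch of the returned structure for a hypothetical Customer entity:

lineage = get_entity_lineage(graph, "Customer")
lineage["source"]       # name of the data source that produces the entity
lineage["upstream"]     # e.g. [{"node": "raw_customers", "type": "source", "relation": "produces"}]
lineage["downstream"]   # relations the entity participates in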

get_lineage_statistics(graph)

Get statistics about the lineage graph.

Parameters:

Name Type Description Default
graph LineageGraph

Lineage graph

required

Returns:

Type Description
Dict

Dictionary with statistics

Source code in grai/core/lineage/lineage_tracker.py
def get_lineage_statistics(graph: LineageGraph) -> Dict:
    """
    Get statistics about the lineage graph.

    Args:
        graph: Lineage graph

    Returns:
        Dictionary with statistics
    """
    entity_count = len([n for n in graph.nodes.values() if n.type == NodeType.ENTITY])
    relation_count = len([n for n in graph.nodes.values() if n.type == NodeType.RELATION])
    source_count = len([n for n in graph.nodes.values() if n.type == NodeType.SOURCE])

    # Calculate connectivity
    max_downstream = 0
    most_connected_entity = None

    for entity_name in graph.entity_map.keys():
        downstream = find_downstream_entities(graph, entity_name)
        if len(downstream) > max_downstream:
            max_downstream = len(downstream)
            most_connected_entity = entity_name

    return {
        "total_nodes": len(graph.nodes),
        "total_edges": len(graph.edges),
        "entity_count": entity_count,
        "relation_count": relation_count,
        "source_count": source_count,
        "max_downstream_connections": max_downstream,
        "most_connected_entity": most_connected_entity,
    }
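
Continuing with the same graph:

stats = get_lineage_statistics(graph)
stats["total_nodes"], stats["total_edges"]
stats["most_connected_entity"]   # entity with the largest downstream fan-out (None if none has any)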

get_relation_lineage(graph, relation_name)

Get complete lineage information for a relation.

Parameters:

Name Type Description Default
graph LineageGraph

Lineage graph

required
relation_name str

Name of the relation

required

Returns:

Type Description
Dict

Dictionary with lineage information

Source code in grai/core/lineage/lineage_tracker.py
def get_relation_lineage(graph: LineageGraph, relation_name: str) -> Dict:
    """
    Get complete lineage information for a relation.

    Args:
        graph: Lineage graph
        relation_name: Name of the relation

    Returns:
        Dictionary with lineage information
    """
    node_id = graph.relation_map.get(relation_name)
    if not node_id:
        return {"error": f"Relation '{relation_name}' not found"}

    node = graph.get_node(node_id)

    # Get upstream (sources and entities)
    upstream_edges = graph.get_edges_to(node_id)
    upstream = [
        {
            "node": graph.get_node(edge.from_node).name,
            "type": graph.get_node(edge.from_node).type.value,
            "relation": edge.relation_type,
        }
        for edge in upstream_edges
    ]

    # Get downstream (entities)
    downstream_edges = graph.get_edges_from(node_id)
    downstream = [
        {
            "node": graph.get_node(edge.to_node).name,
            "type": graph.get_node(edge.to_node).type.value,
            "relation": edge.relation_type,
        }
        for edge in downstream_edges
    ]

    return {
        "relation": relation_name,
        "source": node.metadata.get("source"),
        "from_entity": node.metadata.get("from_entity"),
        "to_entity": node.metadata.get("to_entity"),
        "upstream": upstream,
        "downstream": downstream,
        "metadata": node.metadata,
    }

visualize_lineage_graphviz(graph, focus_entity=None)

Generate Graphviz DOT representation of lineage.

Parameters:

Name Type Description Default
graph LineageGraph

Lineage graph

required
focus_entity Optional[str]

Optional entity to focus on (shows only related nodes)

None

Returns:

Type Description
str

Graphviz DOT diagram as string

Source code in grai/core/lineage/lineage_tracker.py
def visualize_lineage_graphviz(graph: LineageGraph, focus_entity: Optional[str] = None) -> str:
    """
    Generate Graphviz DOT representation of lineage.

    Args:
        graph: Lineage graph
        focus_entity: Optional entity to focus on (shows only related nodes)

    Returns:
        Graphviz DOT diagram as string
    """
    lines = ["digraph lineage {"]
    lines.append("    rankdir=LR;")
    lines.append("    node [shape=box, style=rounded];")

    # Filter nodes if focus entity specified
    if focus_entity:
        node_id = graph.entity_map.get(focus_entity)
        if node_id:
            # Get related nodes
            related_ids = {node_id}
            for edge in graph.edges:
                if edge.from_node == node_id:
                    related_ids.add(edge.to_node)
                if edge.to_node == node_id:
                    related_ids.add(edge.from_node)

            nodes_to_show = {nid: graph.nodes[nid] for nid in related_ids if nid in graph.nodes}
            edges_to_show = [
                e for e in graph.edges if e.from_node in related_ids and e.to_node in related_ids
            ]
        else:
            nodes_to_show = graph.nodes
            edges_to_show = graph.edges
    else:
        nodes_to_show = graph.nodes
        edges_to_show = graph.edges

    # Add node definitions with styling
    for node in nodes_to_show.values():
        node_id = node.id.replace(":", "_")
        if node.type == NodeType.ENTITY:
            lines.append(
                f'    {node_id} [label="{node.name}", fillcolor="#e1f5ff", style="filled,rounded"];'
            )
        elif node.type == NodeType.RELATION:
            lines.append(
                f'    {node_id} [label="{node.name}", shape=diamond, fillcolor="#fff9c4", style="filled"];'
            )
        elif node.type == NodeType.SOURCE:
            lines.append(
                f'    {node_id} [label="{node.name}", shape=cylinder, fillcolor="#f3e5f5", style="filled"];'
            )

    # Add edges
    for edge in edges_to_show:
        from_id = edge.from_node.replace(":", "_")
        to_id = edge.to_node.replace(":", "_")
        lines.append(f'    {from_id} -> {to_id} [label="{edge.relation_type}"];')

    lines.append("}")
    return "\n".join(lines)
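
The DOT output is plain text, so it can be saved and rendered with the standard Graphviz CLI; the paths and focus entity below are illustrative:

from pathlib import Path

dot = visualize_lineage_graphviz(graph, focus_entity="Customer")
Path("lineage.dot").write_text(dot)
# then render outside Python, e.g.:  dot -Tpng lineage.dot -o lineage.png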

visualize_lineage_mermaid(graph, focus_entity=None)

Generate Mermaid diagram representation of lineage.

Parameters:

Name Type Description Default
graph LineageGraph

Lineage graph

required
focus_entity Optional[str]

Optional entity to focus on (shows only related nodes)

None

Returns:

Type Description
str

Mermaid diagram as string

Source code in grai/core/lineage/lineage_tracker.py
def visualize_lineage_mermaid(graph: LineageGraph, focus_entity: Optional[str] = None) -> str:
    """
    Generate Mermaid diagram representation of lineage.

    Args:
        graph: Lineage graph
        focus_entity: Optional entity to focus on (shows only related nodes)

    Returns:
        Mermaid diagram as string
    """
    lines = ["graph LR"]

    # Filter nodes if focus entity specified
    if focus_entity:
        node_id = graph.entity_map.get(focus_entity)
        if node_id:
            # Get related nodes
            related_ids = {node_id}
            for edge in graph.edges:
                if edge.from_node == node_id:
                    related_ids.add(edge.to_node)
                if edge.to_node == node_id:
                    related_ids.add(edge.from_node)

            nodes_to_show = {nid: graph.nodes[nid] for nid in related_ids if nid in graph.nodes}
            edges_to_show = [
                e for e in graph.edges if e.from_node in related_ids and e.to_node in related_ids
            ]
        else:
            nodes_to_show = graph.nodes
            edges_to_show = graph.edges
    else:
        nodes_to_show = graph.nodes
        edges_to_show = graph.edges

    # Add node definitions with styling
    for node in nodes_to_show.values():
        node.name.replace(" ", "_")
        if node.type == NodeType.ENTITY:
            lines.append(f'    {node.id.replace(":", "_")}["{node.name}"]')
            lines.append(f'    style {node.id.replace(":", "_")} fill:#e1f5ff,stroke:#0288d1')
        elif node.type == NodeType.RELATION:
            lines.append(f'    {node.id.replace(":", "_")}{{"{node.name}"}}')
            lines.append(f'    style {node.id.replace(":", "_")} fill:#fff9c4,stroke:#f57f17')
        elif node.type == NodeType.SOURCE:
            lines.append(f'    {node.id.replace(":", "_")}[("{node.name}")]')
            lines.append(f'    style {node.id.replace(":", "_")} fill:#f3e5f5,stroke:#7b1fa2')

    # Add edges
    for edge in edges_to_show:
        from_id = edge.from_node.replace(":", "_")
        to_id = edge.to_node.replace(":", "_")
        lines.append(f"    {from_id} -->|{edge.relation_type}| {to_id}")

    return "\n".join(lines)
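
Likewise, the Mermaid output can be embedded in any Markdown file or pasted into the Mermaid live editor; the filename is illustrative:

from pathlib import Path

mermaid = visualize_lineage_mermaid(graph)
Path("lineage.md").write_text("```mermaid\n" + mermaid + "\n```")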

Lineage Tracker

from grai.core.lineage.lineage_tracker import LineageTracker

tracker = LineageTracker()

# Build lineage graph
lineage = tracker.build_lineage(project)

# Export as Mermaid
mermaid = tracker.export_mermaid(lineage)

# Export as DOT
dot = tracker.export_dot(lineage)

# Export as JSON
json_data = tracker.export_json(lineage)

Visualizer

Interactive visualization module for knowledge graphs.

Provides HTML-based interactive visualizations using D3.js and other web technologies.

generate_cytoscape_visualization(project, output_path, title=None, width=1200, height=800)

Generate interactive Cytoscape.js visualization of the knowledge graph.

Creates an HTML file with an interactive graph using Cytoscape.js.

Parameters:

Name Type Description Default
project Project

The Project to visualize

required
output_path Path

Path to save the HTML file

required
title Optional[str]

Optional title for the visualization (defaults to project name)

None
width int

Width of the visualization canvas in pixels

1200
height int

Height of the visualization canvas in pixels

800
Example

from grai.core.parser.yaml_parser import load_project
project = load_project(Path("."))
generate_cytoscape_visualization(project, Path("graph.html"))

Source code in grai/core/visualizer/__init__.py
def generate_cytoscape_visualization(
    project: Project,
    output_path: Path,
    title: Optional[str] = None,
    width: int = 1200,
    height: int = 800,
) -> None:
    """
    Generate interactive Cytoscape.js visualization of the knowledge graph.

    Creates an HTML file with an interactive graph using Cytoscape.js.

    Args:
        project: The Project to visualize
        output_path: Path to save the HTML file
        title: Optional title for the visualization (defaults to project name)
        width: Width of the visualization canvas in pixels
        height: Height of the visualization canvas in pixels

    Example:
        >>> from grai.core.parser.yaml_parser import load_project
        >>> project = load_project(Path("."))
        >>> generate_cytoscape_visualization(project, Path("graph.html"))
    """
    # Build lineage graph
    graph = build_lineage_graph(project)
    graph_data = export_lineage_to_dict(graph)
    stats = get_lineage_statistics(graph)

    # Use project name as default title
    if title is None:
        title = project.name

    # Generate HTML with embedded Cytoscape.js visualization
    html_content = _generate_cytoscape_html(
        title=title,
        graph_data=graph_data,
        stats=stats,
        width=width,
        height=height,
    )

    # Write to file
    output_path.parent.mkdir(parents=True, exist_ok=True)
    output_path.write_text(html_content, encoding="utf-8")

generate_d3_visualization(project, output_path, title=None, width=1200, height=800)

Generate interactive D3.js visualization of the knowledge graph.

Creates an HTML file with an interactive force-directed graph using D3.js.

Parameters:

Name Type Description Default
project Project

The Project to visualize

required
output_path Path

Path to save the HTML file

required
title Optional[str]

Optional title for the visualization (defaults to project name)

None
width int

Width of the visualization canvas in pixels

1200
height int

Height of the visualization canvas in pixels

800
Example

from grai.core.parser.yaml_parser import load_project
project = load_project(Path("."))
generate_d3_visualization(project, Path("graph.html"))

Source code in grai/core/visualizer/__init__.py
def generate_d3_visualization(
    project: Project,
    output_path: Path,
    title: Optional[str] = None,
    width: int = 1200,
    height: int = 800,
) -> None:
    """
    Generate interactive D3.js visualization of the knowledge graph.

    Creates an HTML file with an interactive force-directed graph using D3.js.

    Args:
        project: The Project to visualize
        output_path: Path to save the HTML file
        title: Optional title for the visualization (defaults to project name)
        width: Width of the visualization canvas in pixels
        height: Height of the visualization canvas in pixels

    Example:
        >>> from grai.core.parser.yaml_parser import load_project
        >>> project = load_project(Path("."))
        >>> generate_d3_visualization(project, Path("graph.html"))
    """
    # Build lineage graph
    graph = build_lineage_graph(project)
    graph_data = export_lineage_to_dict(graph)
    stats = get_lineage_statistics(graph)

    # Use project name as default title
    if title is None:
        title = project.name

    # Generate HTML with embedded D3.js visualization
    html_content = _generate_d3_html(
        title=title,
        graph_data=graph_data,
        stats=stats,
        width=width,
        height=height,
    )

    # Write to file
    output_path.parent.mkdir(parents=True, exist_ok=True)
    output_path.write_text(html_content, encoding="utf-8")

Graph Visualizer

from grai.core.visualizer.visualizer import Visualizer

visualizer = Visualizer()

# Generate D3.js visualization
html = visualizer.generate_d3(project, output="graph.html")

# Generate Cytoscape.js visualization
html = visualizer.generate_cytoscape(project, output="graph.html")

# Generate custom visualization
html = visualizer.generate_custom(
    project,
    template="custom_template.html",
    output="graph.html"
)

CLI

Main CLI

from grai.cli.main import main_cli

# Programmatically invoke CLI
if __name__ == "__main__":
    main_cli()

Type Definitions

Common Types

from typing import Optional, List, Dict, Any

# Entity types
EntityName = str
PropertyName = str
PropertyType = str

# Source types
SourceReference = Optional[str]

# Cypher types
CypherStatement = str
CypherStatements = List[CypherStatement]

# Result types
ValidationResult = Dict[str, Any]
CompilationResult = Dict[str, Any]
ExecutionResult = Dict[str, Any]

Exceptions

Custom Exceptions

from grai.core.exceptions import (
    GraiError,
    ValidationError,
    CompilationError,
    ConnectionError,
    ExecutionError,
)

try:
    result = validator.validate(project)
except ValidationError as e:
    print(f"Validation failed: {e}")
except GraiError as e:
    print(f"General error: {e}")

Utilities

Common Utilities

from grai.core.utils import (
    load_yaml,
    write_yaml,
    ensure_dir,
    hash_file,
)

# Load YAML
data = load_yaml("grai.yml")

# Write YAML
write_yaml(data, "output.yml")

# Ensure directory exists
ensure_dir("target/neo4j")

# Hash file for caching
hash_value = hash_file("entities/customer.yml")

Usage Examples

Complete Workflow

from pathlib import Path
from grai.core.parser.yaml_parser import YAMLParser
from grai.core.validator.validator import Validator
from grai.core.compiler.cypher_compiler import CypherCompiler
from grai.core.loader.neo4j_loader import connect_neo4j, execute_cypher

# 1. Parse project
parser = YAMLParser()
project = parser.parse_project(Path.cwd())

# 2. Validate
validator = Validator()
result = validator.validate(project)

if not result.is_valid:
    for error in result.errors:
        print(f"❌ {error}")
    exit(1)

# 3. Compile
compiler = CypherCompiler()
cypher = compiler.compile_project(project)

# 4. Execute
driver = connect_neo4j(
    uri="bolt://localhost:7687",
    user="neo4j",
    password="password"
)

result = execute_cypher(driver, cypher)
print(f"✅ Executed successfully")
print(f"   Nodes created: {result.nodes_created}")
print(f"   Relationships created: {result.relationships_created}")

Data Loading Workflow

from grai.core.loader.bigquery_loader import (
    BigQueryExtractor,
    load_entity_from_bigquery,
)
from grai.core.loader.neo4j_loader import connect_neo4j

# `project` is assumed to be an already-parsed Project, as in the workflow above

# Set up connections
extractor = BigQueryExtractor(
    project_id="my-project",
    credentials_path="credentials.json"
)

driver = connect_neo4j(
    uri="bolt://localhost:7687",
    user="neo4j",
    password="password"
)

# Load each entity
for entity in project.entities:
    print(f"Loading {entity.entity}...")

    result = load_entity_from_bigquery(
        entity=entity,
        bigquery_connection=extractor,
        neo4j_connection=driver,
        batch_size=1000,
        verbose=True
    )

    if result.success:
        print(f"✅ Loaded {result.rows_processed} rows")
    else:
        print(f"❌ Failed: {result.errors}")

See Also