From 828c5ccb585209f4f7a6f1de84a69d53f76b9081 Mon Sep 17 00:00:00 2001 From: Daniel Weipert Date: Mon, 23 Sep 2024 11:13:23 +0200 Subject: ifs and imports --- henshin | 169 ++++++++++++++++++++++++++++++++++++++++++++++---- test/test.test | 21 ++++++- test/test_import.test | 4 ++ 3 files changed, 181 insertions(+), 13 deletions(-) create mode 100644 test/test_import.test diff --git a/henshin b/henshin index 44ba911..9b1c4fd 100755 --- a/henshin +++ b/henshin @@ -8,6 +8,7 @@ import ply.lex as lex import ply.yacc as yacc from enum import Enum import collections.abc +import os @@ -71,8 +72,7 @@ reserved = { "and": "AND", "or": "OR", - - "for": "LOOP", + "not": "NOT", } tokens = [ @@ -85,6 +85,13 @@ tokens = [ "OPERATOR_PIPE", "OPERATOR_PIPE_REPLACEMENT", + "COMPARE_EQUAL", + "COMPARE_NOT_EQUAL", + "COMPARE_GREATER", + "COMPARE_LESSER", + "COMPARE_GREATER_EQUAL", + "COMPARE_LESSER_EQUAL", + "PARENTHESIS_LEFT", "PARENTHESIS_RIGHT", @@ -100,7 +107,6 @@ tokens = [ "COMMA", "COLON", - "SEMICOLON", "IDENTIFIER", @@ -120,6 +126,13 @@ t_ASSIGN = r'=' t_OPERATOR_PIPE = r'=>' t_OPERATOR_PIPE_REPLACEMENT = r'\$' +t_COMPARE_EQUAL = r'==' +t_COMPARE_NOT_EQUAL = r'!=' +t_COMPARE_GREATER = r'>' +t_COMPARE_LESSER = r'<' +t_COMPARE_GREATER_EQUAL = r'>=' +t_COMPARE_LESSER_EQUAL = r'<=' + t_PARENTHESIS_LEFT = r'\(' t_PARENTHESIS_RIGHT = r'\)' @@ -135,7 +148,6 @@ t_NAMESPACE_ACCESSOR = r'\.' t_COMMA = r',' t_COLON = r':' -t_SEMICOLON = r';' @@ -208,6 +220,14 @@ class AstNodeExpressionType(Enum): ARRAY = 'array' MAP = 'map' +class AstNodeComparatorType(Enum): + EQUAL = "==" + NOT_EQUAL = "!=" + GREATER = ">" + LESSER = "<" + GREATER_EQUAL = ">=" + LESSER_EQUAL = "<=" + class AstNode: pass class AstNodeVariableDeclarationStatement(AstNode): @@ -296,6 +316,27 @@ class AstNodeNamespaceAccess(AstNode): self.left = left self.right = right +class AstNodeIf(AstNode): + def __init__(self, arms): + self.arms = arms + +class AstNodeIfArm(AstNode): + def __init__(self, comparison, body): + self.comparison = comparison + self.body = body + +class AstNodeCompare(AstNode): + def __init__(self, left, comparator, right): + self.left = left + self.comparator = comparator + self.right = right + +class AstNodeCompareExpressions(AstNode): + def __init__(self, operator, left, right): + self.operator = operator + self.left = left + self.right = right + precedence = ( ('left', 'OPERATOR_PLUS', 'OPERATOR_MINUS'), @@ -346,7 +387,8 @@ def p_program_statement(p): '''program_statement : variable_declaration_statement | variable_reassignment_statement | function_call - | pipe''' + | pipe + | if_statement''' p[0] = p[1] @@ -368,7 +410,8 @@ def p_statement(p): | function_call | type_declaration | pipe - | return_statement''' + | return_statement + | if_statement''' p[0] = p[1] @@ -632,6 +675,48 @@ def p_namespace_access(p): p[0] = AstNodeNamespaceAccess(p[1], p[3]) +def p_if_statement(p): + '''if_statement : if_arm ELSE if_statement + | if_arm''' + + arms = [p[1]] + if len(p) == 4: + arms.extend(p[3].arms) + + p[0] = AstNodeIf(arms) + +def p_if_arm(p): + '''if_arm : IF PARENTHESIS_LEFT compare_expressions PARENTHESIS_RIGHT BRACE_LEFT statements BRACE_RIGHT''' + + p[0] = AstNodeIfArm(p[3], p[6]) + +def p_compare_expression(p): + '''compare_expression : expression comparator expression''' + + p[0] = AstNodeCompare(p[1], p[2], p[3]) + +def p_compare_expressions(p): + '''compare_expressions : compare_expression AND compare_expressions + | compare_expression OR compare_expressions + | compare_expression''' + + if len(p) == 4: + p[0] = AstNodeCompareExpressions(p[2], p[1], p[3]) + else: + p[0] = p[1] + +def p_comparator(p): + '''comparator : COMPARE_EQUAL + | COMPARE_NOT_EQUAL + | COMPARE_GREATER + | COMPARE_LESSER + | COMPARE_GREATER_EQUAL + | COMPARE_LESSER_EQUAL''' + + # p[0] = AstNodeComparator(AstNodeComparatorType) + p[0] = p[1] + + def p_expression(p): '''expression : identifier | number @@ -643,6 +728,7 @@ def p_expression(p): | type_declaration | variable_type | namespace_access + | compare_expressions | expression OPERATOR_PLUS expression | expression OPERATOR_MINUS expression | expression OPERATOR_MULTIPLY expression @@ -671,9 +757,17 @@ parser = yacc.yacc(debug=True, debuglog=log) result = parser.parse(input) +def function_import(file_path): + return interpret( + parser.parse( + open(os.path.dirname(args.filename) + "/" + file_path).read() + ), + variables.copy() + ) variables = { "print": print, + "import": function_import, } types = {} @@ -789,7 +883,9 @@ def evaluate_statement(statement, context): # in-built if callable(function): - function(*[evaluate_expression(parameter, context) for parameter in statement.parameters]) + return_value = function(*[evaluate_expression(parameter, context) for parameter in statement.parameters]) + if return_value: + scoped_variables = return_value # else else: @@ -833,6 +929,16 @@ def evaluate_statement(statement, context): else: scoped_variables = right + # if + elif type(statement) is AstNodeIf: + for arm in statement.arms: + comparison = evaluate_expression(arm.comparison, context) + + if comparison: + for arm_statement in arm.body: + scoped_variables |= evaluate_statement(arm_statement, context) + break + return scoped_variables @@ -884,10 +990,14 @@ def evaluate_expression(expression, scoped_variables): else: function = scoped_variables[expression.name] - # function call + # custom function call if type(function) is AstNodeFunctionDeclaration: result = evaluate_statement(expression, scoped_variables) + # in-built function call + elif callable(function): + result = evaluate_statement(expression, scoped_variables) + # type instantiation elif type(function) is dict: instance_variables = evaluate_expression(function, scoped_variables) @@ -904,12 +1014,14 @@ def evaluate_expression(expression, scoped_variables): if result == "$": result = scoped_variables["$"] + # namespace access elif type(expression) is AstNodeNamespaceAccess: if "NAMESPACE_ACCESS" in scoped_variables: result = get_namespace_path(expression, scoped_variables) else: result = evaluate_statement(expression, scoped_variables) + # type declaration elif type(expression) is AstNodeTypeDeclaration: instance_variables = {} for node in expression.body: @@ -917,16 +1029,51 @@ def evaluate_expression(expression, scoped_variables): result = instance_variables + # comparison + elif type(expression) is AstNodeCompare: + left = evaluate_expression(expression.left, scoped_variables) + right = evaluate_expression(expression.right, scoped_variables) + + if expression.comparator == "==": + result = left == right + elif expression.comparator == "!=": + result = left != right + elif expression.comparator == ">": + result = left > right + elif expression.comparator == "<": + result = left < right + elif expression.comparator == ">=": + result = left >= right + elif expression.comparator == "<=": + result = left <= right + else: + print("ERROR: Couldn't evaluate comparison expression") + elif type(expression) is AstNodeCompareExpressions: + left = evaluate_expression(expression.left, scoped_variables) + right = evaluate_expression(expression.right, scoped_variables) + + if expression.operator == "and": + result = left and right + elif expression.operator == "or": + result = left or right + return result -if result: - for node in result: + +def interpret(parsed_result, context): + for node in parsed_result: if args.debug: print(node) - variables |= evaluate_statement(node, variables) + context |= evaluate_statement(node, context) + + return context + + +if result: + interpret(result, variables.copy()) diff --git a/test/test.test b/test/test.test index 58be29b..9fea5b2 100644 --- a/test/test.test +++ b/test/test.test @@ -28,8 +28,8 @@ const map: [string][string or integer] = [ }, ] -main("pipe1") => print($) -main("pipe2") => main($) => print($) +main("pipe1!") => print($) +main("pipe2!") => main($) => print($) const test_type: type = { const test_field: string = "test" @@ -59,3 +59,20 @@ object.test_field = "hey" object.test_function() => print($) object.nested.another_field = 5 object.nested.another_function(add = 2) => print($) + +if (henshin == 2) { + print("nice") +} +print(henshin == 3 or henshin > 1) +henshin = 4 +if (henshin == 3) { + print("shouldn't be") +} else if (henshin == 2 or henshin != 3) { + print("else will") +} else if (henshin != 3 and henshin != 2) { + print("else won't") +} + +const scoped: [string][string or function] = import("test_import.test") +print(scoped) +scoped.exported_function() => print($) diff --git a/test/test_import.test b/test/test_import.test new file mode 100644 index 0000000..99b3a08 --- /dev/null +++ b/test/test_import.test @@ -0,0 +1,4 @@ +const exported: string = "I am exported" +const exported_function: function = (): string { + return "I am also exported!" +} -- cgit v1.2.3