Skip to content

Commit

Permalink
fix: prevent call function with callee parent enviroment
Browse files Browse the repository at this point in the history
  • Loading branch information
pmqueiroz committed Sep 23, 2024
1 parent f8deebf commit fe24887
Show file tree
Hide file tree
Showing 4 changed files with 60 additions and 27 deletions.
19 changes: 19 additions & 0 deletions exception/runtime.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
package exception

import (
"fmt"
)

type RuntimeError struct {
message string
}

func (e *RuntimeError) Error() string {
return fmt.Sprintf("RuntimeError: %s", e.message)
}

func NewRuntimeError(message string) error {
return &RuntimeError{
message: message,
}
}
9 changes: 9 additions & 0 deletions interpreter/environment.go
Original file line number Diff line number Diff line change
@@ -1,5 +1,12 @@
package interpreter

import (
"fmt"
"os"

"github.com/pmqueiroz/umbra/exception"
)

type Environment struct {
values map[string]interface{}
parent *Environment
Expand Down Expand Up @@ -33,6 +40,8 @@ func (env *Environment) Set(name string, value interface{}) bool {

func (env *Environment) Create(name string, value interface{}) bool {
if _, exists := env.Get(name); exists {
fmt.Println(exception.NewRuntimeError(fmt.Sprintf("variable %s already exists", name)))
os.Exit(1)
return false
}
env.values[name] = value
Expand Down
8 changes: 4 additions & 4 deletions interpreter/eval.go
Original file line number Diff line number Diff line change
Expand Up @@ -82,19 +82,19 @@ func Evaluate(expression ast.Expression, env *Environment) (interface{}, error)
return nil, err
}

if function, ok := callee.(ast.FunctionStatement); ok {
funcEnv := NewEnvironment(env)
if function, ok := callee.(FunctionDeclaration); ok {
funcEnv := NewEnvironment(function.Environment)

for i, arg := range expr.Arguments {
argValue, err := Evaluate(arg, env)
if err != nil {
return nil, err
}
funcEnv.Create(function.Params[i].Name.Raw.Value, argValue)
funcEnv.Create(function.Itself.Params[i].Name.Raw.Value, argValue)
}

var result interface{}
for _, stmt := range function.Body {
for _, stmt := range function.Itself.Body {
if err := Interpret(stmt, funcEnv); err != nil {
if returnValue, ok := err.(Return); ok {
result = returnValue.value
Expand Down
51 changes: 28 additions & 23 deletions interpreter/interpreter.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,11 @@ func (r Break) Error() string {
return "for loop break"
}

type FunctionDeclaration struct {
Itself *ast.FunctionStatement
Environment *Environment
}

func extractVarName(stmt ast.Statement) string {
switch s := stmt.(type) {
case ast.VarStatement:
Expand All @@ -30,10 +35,10 @@ func extractVarName(stmt ast.Statement) string {
}
}

func Interpret(stmt ast.Statement, env *Environment) error {
switch s := stmt.(type) {
func Interpret(statement ast.Statement, env *Environment) error {
switch stmt := statement.(type) {
case ast.PrintStatement:
value, err := Evaluate(s.Expression, env)
value, err := Evaluate(stmt.Expression, env)
if err != nil {
return err
}
Expand All @@ -50,60 +55,60 @@ func Interpret(stmt ast.Statement, env *Environment) error {
case ast.VarStatement:
var value interface{}
var err error
if s.Initializer != nil {
value, err = Evaluate(s.Initializer, env)
if stmt.Initializer != nil {
value, err = Evaluate(stmt.Initializer, env)
if err != nil {
return err
}
}
env.Create(s.Name.Raw.Value, value)
env.Create(stmt.Name.Raw.Value, value)
return nil
case ast.BlockStatement:
newEnv := NewEnvironment(env)
for _, stmt := range s.Statements {
for _, stmt := range stmt.Statements {
if err := Interpret(stmt, newEnv); err != nil {
return err
}
}
return nil
case ast.ModuleStatement:
for _, stmt := range s.Declarations {
for _, stmt := range stmt.Declarations {
if err := Interpret(stmt, env); err != nil {
return err
}
}
return nil
case ast.IfStatement:
condition, err := Evaluate(s.Condition, env)
condition, err := Evaluate(stmt.Condition, env)
if err != nil {
return err
}

if condition.(bool) {
return Interpret(s.ThenBranch, env)
} else if s.ElseBranch != nil {
return Interpret(s.ElseBranch, env)
return Interpret(stmt.ThenBranch, env)
} else if stmt.ElseBranch != nil {
return Interpret(stmt.ElseBranch, env)
}
return nil
case ast.ReturnStatement:
value, err := Evaluate(s.Value, env)
value, err := Evaluate(stmt.Value, env)
if err != nil {
return err
}
return Return{value: value}
case ast.FunctionStatement:
env.Create(s.Name.Raw.Value, s)
env.Create(stmt.Name.Raw.Value, FunctionDeclaration{Itself: &stmt, Environment: env})
return nil
case ast.ExpressionStatement:
_, err := Evaluate(s.Expression, env)
_, err := Evaluate(stmt.Expression, env)
return err
case ast.InitializedForStatement:
forEnv := NewEnvironment(env)
if err := Interpret(s.Start, forEnv); err != nil {
if err := Interpret(stmt.Start, forEnv); err != nil {
return err
}

initializedVarName := extractVarName(s.Start)
initializedVarName := extractVarName(stmt.Start)

for {
loopEnv := NewEnvironment(forEnv)
Expand All @@ -112,7 +117,7 @@ func Interpret(stmt ast.Statement, env *Environment) error {
return fmt.Errorf("control variable not found in environment: %s", initializedVarName)
}

stop, err := Evaluate(s.Stop, loopEnv)
stop, err := Evaluate(stmt.Stop, loopEnv)
if err != nil {
return err
}
Expand All @@ -128,14 +133,14 @@ func Interpret(stmt ast.Statement, env *Environment) error {
break
}

if err := Interpret(s.Body, loopEnv); err != nil {
if err := Interpret(stmt.Body, loopEnv); err != nil {
if _, ok := err.(Break); ok {
break
}
return err
}

stepValue, err := Evaluate(s.Step, loopEnv)
stepValue, err := Evaluate(stmt.Step, loopEnv)
if err != nil {
return err
}
Expand All @@ -153,7 +158,7 @@ func Interpret(stmt ast.Statement, env *Environment) error {
for {
loopEnv := NewEnvironment(env)

condition, err := Evaluate(s.Condition, loopEnv)
condition, err := Evaluate(stmt.Condition, loopEnv)
if err != nil {
return err
}
Expand All @@ -167,7 +172,7 @@ func Interpret(stmt ast.Statement, env *Environment) error {
break
}

if err := Interpret(s.Body, loopEnv); err != nil {
if err := Interpret(stmt.Body, loopEnv); err != nil {
if _, ok := err.(Break); ok {
break
}
Expand All @@ -178,6 +183,6 @@ func Interpret(stmt ast.Statement, env *Environment) error {
case ast.BreakStatement:
return Break{}
default:
return fmt.Errorf("unknown declaration: %T", stmt)
return fmt.Errorf("unknown declaration: %T", statement)
}
}

0 comments on commit fe24887

Please sign in to comment.