snakenet.rb: You can choose between using float or integer weights. floats seem to give better results and the speed didn't improve a lot.
This commit is contained in:
parent
b644006036
commit
a902addf94
@ -4,6 +4,7 @@ require 'pp'
|
||||
require 'thread/pool'
|
||||
|
||||
GAMES_PER_ROUND = 50
|
||||
FLOAT = true
|
||||
|
||||
class Game
|
||||
WIDTH = 16
|
||||
@ -177,10 +178,10 @@ class AI
|
||||
@rounds = 1
|
||||
@id = rand(0xFFFFFF)
|
||||
if w==nil
|
||||
@weights = Array.new(network_size()) { rand() * 2.0 - 1.0 }
|
||||
@weights = Array.new(network_size()) { FLOAT ? rand() * 2.0 - 1.0 : rand(256) - 128 }
|
||||
puts "Initialized with random values: #{@weights}" if @debug
|
||||
else
|
||||
if w[0].is_a? Integer
|
||||
if w[0].is_a?(Integer) && FLOAT
|
||||
@weights = w.map{|s| s.to_s(16).rjust(8, "0").split("").each_slice(2).to_a.map(&:join).map{|s| s.to_i(16).chr}.join.unpack("g")}.flatten
|
||||
else
|
||||
@weights = w
|
||||
@ -220,7 +221,7 @@ class AI
|
||||
(1...(NETWORK_LAYOUT.count)).each do |i|
|
||||
c_in = NETWORK_LAYOUT[i-1]
|
||||
c_out = NETWORK_LAYOUT[i]
|
||||
outputs = Array.new(c_out){0.0}
|
||||
outputs = Array.new(c_out){FLOAT ? 0.0 : 0}
|
||||
(0...c_out).each do |o|
|
||||
(0...c_in).each do |i|
|
||||
outputs[o] += inputs[i] * @weights[x]
|
||||
@ -253,33 +254,53 @@ class AI
|
||||
# w[i2] = temp
|
||||
if action==0 #change single value
|
||||
i = rand(network_size())
|
||||
diff = rand() * 0.2 - 0.1
|
||||
diff = FLOAT ? rand() * 0.2 - 0.1 : rand(256) - 128
|
||||
w2 = w.dup
|
||||
w[i] += diff
|
||||
w[i] = 1.0 if w[i]>1.0
|
||||
w[i] = -1.0 if w[i]<-1.0
|
||||
if FLOAT
|
||||
w[i] = 1.0 if w[i]>1.0
|
||||
w[i] = -1.0 if w[i]<-1.0
|
||||
else
|
||||
w[i] = 127 if w[i]>127
|
||||
w[i] = -128 if w[i]<-128
|
||||
end
|
||||
w2[i] -= diff
|
||||
w2[i] = 1.0 if w2[i]>1.0
|
||||
w2[i] = -1.0 if w2[i]<-1.0
|
||||
if FLOAT
|
||||
w2[i] = 1.0 if w2[i]>1.0
|
||||
w2[i] = -1.0 if w2[i]<-1.0
|
||||
else
|
||||
w2[i] = 127 if w2[i]>127
|
||||
w2[i] = -128 if w2[i]<-128
|
||||
end
|
||||
return [AI.new(w), AI.new(w2)]
|
||||
elsif action==1 #invert single value
|
||||
i = rand(network_size())
|
||||
w[i] *= -1.0
|
||||
w[i] *= FLOAT ? -1.0 : -1
|
||||
elsif action==2
|
||||
(0...network_size()).each do |i|
|
||||
w[i] = rand() * 2 - 1.0 if rand(5)==0
|
||||
w[i] = (FLOAT ? rand() * 2 - 1.0 : rand(256) - 128) if rand(5)==0
|
||||
end
|
||||
else #change multiple values
|
||||
w2 = w.dup
|
||||
(0...network_size()).each do |i|
|
||||
if (rand(5)==0)
|
||||
diff = rand() * 0.2 - 0.1
|
||||
diff = FLOAT ? rand() * 0.2 - 0.1 : rand(256) - 128
|
||||
w[i] += diff
|
||||
w[i] = 1.0 if w[i]>1.0
|
||||
w[i] = -1.0 if w[i]<-1.0
|
||||
if FLOAT
|
||||
w[i] = 1.0 if w[i]>1.0
|
||||
w[i] = -1.0 if w[i]<-1.0
|
||||
else
|
||||
w[i] = 127 if w[i]>127
|
||||
w[i] = -128 if w[i]<-128
|
||||
end
|
||||
w2[i] -= diff
|
||||
w2[i] = 1.0 if w2[i]>1.0
|
||||
w2[i] = -1.0 if w2[i]<-1.0
|
||||
if FLOAT
|
||||
w2[i] = 1.0 if w2[i]>1.0
|
||||
w2[i] = -1.0 if w2[i]<-1.0
|
||||
else
|
||||
w2[i] = 127 if w2[i]>127
|
||||
w2[i] = -128 if w2[i]<-128
|
||||
end
|
||||
end
|
||||
end
|
||||
return [AI.new(w), AI.new(w2)]
|
||||
@ -303,13 +324,17 @@ class AI
|
||||
w = @weights.dup
|
||||
w2 = ai.weights
|
||||
(0...network_size()).each do |i|
|
||||
w[i] = (w[i] + w2[i]) / 2.0
|
||||
w[i] = (w[i] + w2[i]) / (FLOAT ? 2.0 : 2)
|
||||
end
|
||||
return AI.new(w)
|
||||
end
|
||||
|
||||
def dump
|
||||
puts "const uint32_t _weights[#{network_size()}] = {#{@weights.map{|x| "0x" + [x].pack('g').split("").map(&:ord).map{|i| i.to_s(16).rjust(2, '0')}.join}.join(", ")}};"
|
||||
if FLOAT
|
||||
puts "const uint32_t _weights[#{network_size()}] = {#{@weights.map{|x| "0x" + [x].pack('g').split("").map(&:ord).map{|i| i.to_s(16).rjust(2, '0')}.join}.join(", ")}};"
|
||||
else
|
||||
puts "const int8_t _weights[#{network_size()}] = {#{@weights.join(", ")}};"
|
||||
end
|
||||
#puts "Simplified: #{simplified}"
|
||||
end
|
||||
end
|
||||
|
Loading…
x
Reference in New Issue
Block a user